Skip to content

Commit

Permalink
Merge e48a582 into 61de1b6
Browse files Browse the repository at this point in the history
  • Loading branch information
timshannon committed Oct 16, 2020
2 parents 61de1b6 + e48a582 commit 796dda8
Show file tree
Hide file tree
Showing 7 changed files with 165 additions and 45 deletions.
3 changes: 1 addition & 2 deletions find_test.go
Expand Up @@ -139,8 +139,7 @@ var testData = []ItemTest{
Category: "vehicle",
Created: time.Now().AddDate(0, 0, 30),
Color: "pink",
Fruit: "apple",
},
Fruit: "apple"},
{
Key: 12,
ID: 11,
Expand Down
14 changes: 3 additions & 11 deletions get.go
Expand Up @@ -7,7 +7,6 @@ package badgerhold
import (
"errors"
"reflect"
"strings"

"github.com/dgraph-io/badger/v2"
)
Expand Down Expand Up @@ -51,17 +50,10 @@ func (s *Store) TxGet(tx *badger.Txn, key, result interface{}) error {
tp = tp.Elem()
}

var keyField string
keyField, ok := getKeyField(tp)

for i := 0; i < tp.NumField(); i++ {
if strings.Contains(string(tp.Field(i).Tag), BadgerholdKeyTag) {
keyField = tp.Field(i).Name
break
}
}

if keyField != "" {
err := decodeKey(gk, reflect.ValueOf(result).Elem().FieldByName(keyField).Addr().Interface(), storer.Type())
if ok {
err := decodeKey(gk, reflect.ValueOf(result).Elem().FieldByName(keyField.Name).Addr().Interface(), storer.Type())
if err != nil {
return err
}
Expand Down
99 changes: 99 additions & 0 deletions get_test.go
Expand Up @@ -35,3 +35,102 @@ func TestGet(t *testing.T) {
}
})
}

func TestIssue36(t *testing.T) {
testWrap(t, func(store *badgerhold.Store, t *testing.T) {
type Tag1 struct {
ID uint64 `badgerholdKey`
}

type Tag2 struct {
ID uint64 `badgerholdKey:"Key"`
}

type Tag3 struct {
ID uint64 `badgerhold:"key"`
}

type Tag4 struct {
ID uint64 `badgerholdKey:""`
}

data1 := []*Tag1{{}, {}, {}}
for i := range data1 {
ok(t, store.Insert(badgerhold.NextSequence(), data1[i]))
equals(t, uint64(i), data1[i].ID)
}

data2 := []*Tag2{{}, {}, {}}
for i := range data2 {
ok(t, store.Insert(badgerhold.NextSequence(), data2[i]))
equals(t, uint64(i), data2[i].ID)
}

data3 := []*Tag3{{}, {}, {}}
for i := range data3 {
ok(t, store.Insert(badgerhold.NextSequence(), data3[i]))
equals(t, uint64(i), data3[i].ID)
}

data4 := []*Tag4{{}, {}, {}}
for i := range data4 {
ok(t, store.Insert(badgerhold.NextSequence(), data4[i]))
equals(t, uint64(i), data4[i].ID)
}

// Get
for i := range data1 {
get1 := &Tag1{}
ok(t, store.Get(data1[i].ID, get1))
equals(t, data1[i], get1)
}

for i := range data2 {
get2 := &Tag2{}
ok(t, store.Get(data2[i].ID, get2))
equals(t, data2[i], get2)
}

for i := range data3 {
get3 := &Tag3{}
ok(t, store.Get(data3[i].ID, get3))
equals(t, data3[i], get3)
}

for i := range data4 {
get4 := &Tag4{}
ok(t, store.Get(data4[i].ID, get4))
equals(t, data4[i], get4)
}

// Find

for i := range data1 {
var find1 []*Tag1
ok(t, store.Find(&find1, badgerhold.Where(badgerhold.Key).Eq(data1[i].ID)))
assert(t, len(find1) == 1, "incorrect rows returned")
equals(t, find1[0], data1[i])
}

for i := range data2 {
var find2 []*Tag2
ok(t, store.Find(&find2, badgerhold.Where(badgerhold.Key).Eq(data2[i].ID)))
assert(t, len(find2) == 1, "incorrect rows returned")
equals(t, find2[0], data2[i])
}

for i := range data3 {
var find3 []*Tag3
ok(t, store.Find(&find3, badgerhold.Where(badgerhold.Key).Eq(data3[i].ID)))
assert(t, len(find3) == 1, "incorrect rows returned")
equals(t, find3[0], data3[i])
}

for i := range data4 {
var find4 []*Tag4
ok(t, store.Find(&find4, badgerhold.Where(badgerhold.Key).Eq(data4[i].ID)))
assert(t, len(find4) == 1, "incorrect rows returned")
equals(t, find4[0], data4[i])
}
})
}
32 changes: 13 additions & 19 deletions put.go
Expand Up @@ -87,26 +87,20 @@ func (s *Store) TxInsert(tx *badger.Txn, key, data interface{}) error {
if !dataVal.CanSet() {
return nil
}
dataType := dataVal.Type()

for i := 0; i < dataType.NumField(); i++ {
tf := dataType.Field(i)
if _, ok := tf.Tag.Lookup(BadgerholdKeyTag); ok ||
tf.Tag.Get(badgerholdPrefixTag) == badgerholdPrefixKeyValue {
fieldValue := dataVal.Field(i)
keyValue := reflect.ValueOf(key)
if keyValue.Type() != tf.Type {
break
}
if !fieldValue.CanSet() {
break
}
if !reflect.DeepEqual(fieldValue.Interface(), reflect.Zero(tf.Type).Interface()) {
break
}
fieldValue.Set(keyValue)
break

if keyField, ok := getKeyField(dataVal.Type()); ok {
fieldValue := dataVal.FieldByName(keyField.Name)
keyValue := reflect.ValueOf(key)
if keyValue.Type() != keyField.Type {
return nil
}
if !fieldValue.CanSet() {
return nil
}
if !reflect.DeepEqual(fieldValue.Interface(), reflect.Zero(keyField.Type).Interface()) {
return nil
}
fieldValue.Set(keyValue)
}

return nil
Expand Down
16 changes: 3 additions & 13 deletions query.go
Expand Up @@ -800,17 +800,7 @@ func findQuery(tx *badger.Txn, result interface{}, query *Query) error {
tp = tp.Elem()
}

var keyType reflect.Type
var keyField string

for i := 0; i < tp.NumField(); i++ {
if strings.Contains(string(tp.Field(i).Tag), BadgerholdKeyTag) ||
tp.Field(i).Tag.Get(badgerholdPrefixTag) == badgerholdPrefixKeyValue {
keyType = tp.Field(i).Type
keyField = tp.Field(i).Name
break
}
}
keyField, hasKeyField := getKeyField(tp)

val := reflect.New(tp)

Expand All @@ -824,12 +814,12 @@ func findQuery(tx *badger.Txn, result interface{}, query *Query) error {
rowValue = r.value.Elem()
}

if keyType != nil {
if hasKeyField {
rowKey := rowValue
for rowKey.Kind() == reflect.Ptr {
rowKey = rowKey.Elem()
}
err := decodeKey(r.key, rowKey.FieldByName(keyField).Addr().Interface(), tp.Name())
err := decodeKey(r.key, rowKey.FieldByName(keyField.Name).Addr().Interface(), tp.Name())
if err != nil {
return err
}
Expand Down
15 changes: 15 additions & 0 deletions store.go
Expand Up @@ -203,3 +203,18 @@ func (s *Store) getSequence(typeName string) (uint64, error) {
func typePrefix(typeName string) []byte {
return []byte("bh_" + typeName)
}

func getKeyField(tp reflect.Type) (reflect.StructField, bool) {

for i := 0; i < tp.NumField(); i++ {
if strings.HasPrefix(string(tp.Field(i).Tag), BadgerholdKeyTag) {
return tp.Field(i), true
}

if tag := tp.Field(i).Tag.Get(badgerholdPrefixTag); tag == badgerholdPrefixKeyValue {
return tp.Field(i), true
}
}

return reflect.StructField{}, false
}
31 changes: 31 additions & 0 deletions store_test.go
Expand Up @@ -6,6 +6,10 @@ package badgerhold_test

import (
"encoding/json"
"fmt"
"path/filepath"
"reflect"
"runtime"

// "fmt"
"io/ioutil"
Expand Down Expand Up @@ -148,3 +152,30 @@ func tempdir() string {
}
return name
}

// assert fails the test if the condition is false.
func assert(tb testing.TB, condition bool, msg string, v ...interface{}) {
if !condition {
_, file, line, _ := runtime.Caller(1)
fmt.Printf("\033[31m%s:%d: "+msg+"\033[39m\n\n", append([]interface{}{filepath.Base(file), line}, v...)...)
tb.FailNow()
}
}

// ok fails the test if an err is not nil.
func ok(tb testing.TB, err error) {
if err != nil {
_, file, line, _ := runtime.Caller(1)
fmt.Printf("\033[31m%s:%d: unexpected error: %s\033[39m\n\n", filepath.Base(file), line, err.Error())
tb.FailNow()
}
}

// equals fails the test if exp is not equal to act.
func equals(tb testing.TB, exp, act interface{}) {
if !reflect.DeepEqual(exp, act) {
_, file, line, _ := runtime.Caller(1)
fmt.Printf("\033[31m%s:%d:\n\n\texp: %#v\n\n\tgot: %#v\033[39m\n\n", filepath.Base(file), line, exp, act)
tb.FailNow()
}
}

0 comments on commit 796dda8

Please sign in to comment.