From c07579c0b20f5f5e12788fb0704c3a9138882b58 Mon Sep 17 00:00:00 2001 From: James Phillips Date: Thu, 15 Nov 2018 09:23:18 -0800 Subject: [PATCH] ensure default resolver can traverse embedded structs (#2341) Signed-off-by: James Phillips --- backend/apid/graphql/globalid/entity.go | 2 +- backend/apid/graphql/globalid/handlers.go | 2 +- backend/apid/graphql/globalid/silences.go | 2 +- graphql/resolvers.go | 74 +++++++++++++---------- graphql/resolvers_test.go | 73 ++++++++++++++++++++++ 5 files changed, 118 insertions(+), 35 deletions(-) create mode 100644 graphql/resolvers_test.go diff --git a/backend/apid/graphql/globalid/entity.go b/backend/apid/graphql/globalid/entity.go index 1a1eebc486..597cc88a9a 100644 --- a/backend/apid/graphql/globalid/entity.go +++ b/backend/apid/graphql/globalid/entity.go @@ -13,7 +13,7 @@ var entityName = "entities" // EntityTranslator global ID resource var EntityTranslator = commonTranslator{ name: entityName, - encodeFunc: standardEncoder(entityName, "ID"), + encodeFunc: standardEncoder(entityName, "Name"), decodeFunc: standardDecoder, isResponsibleFunc: func(record interface{}) bool { _, ok := record.(*types.Entity) diff --git a/backend/apid/graphql/globalid/handlers.go b/backend/apid/graphql/globalid/handlers.go index 9bab79487b..1e6c9501e6 100644 --- a/backend/apid/graphql/globalid/handlers.go +++ b/backend/apid/graphql/globalid/handlers.go @@ -11,7 +11,7 @@ var handlerName = "handlers" // HandlerTranslator global ID resource var HandlerTranslator = commonTranslator{ name: handlerName, - encodeFunc: standardEncoder(handlerName, "ID"), + encodeFunc: standardEncoder(handlerName, "Name"), decodeFunc: standardDecoder, isResponsibleFunc: func(record interface{}) bool { _, ok := record.(*types.Handler) diff --git a/backend/apid/graphql/globalid/silences.go b/backend/apid/graphql/globalid/silences.go index e9ce1a575b..0987458800 100644 --- a/backend/apid/graphql/globalid/silences.go +++ b/backend/apid/graphql/globalid/silences.go @@ -11,7 +11,7 @@ var silenceName = "silences" // SilenceTranslator global ID resource var SilenceTranslator = commonTranslator{ name: silenceName, - encodeFunc: standardEncoder(silenceName, "ID"), + encodeFunc: standardEncoder(silenceName, "Name"), decodeFunc: standardDecoder, isResponsibleFunc: func(record interface{}) bool { _, ok := record.(*types.Silenced) diff --git a/graphql/resolvers.go b/graphql/resolvers.go index b62d4553a6..0a97576fbd 100644 --- a/graphql/resolvers.go +++ b/graphql/resolvers.go @@ -146,38 +146,8 @@ func DefaultResolver(source interface{}, fieldName string) (interface{}, error) // Struct if sourceVal.Type().Kind() == reflect.Struct { - fieldName = strings.Title(fieldName) - for i := 0; i < sourceVal.NumField(); i++ { - valueField := sourceVal.Field(i) - typeField := sourceVal.Type().Field(i) - if typeField.Name == fieldName { - // If ptr and value is nil return nil - if valueField.Type().Kind() == reflect.Ptr && valueField.IsNil() { - return nil, nil - } - return valueField.Interface(), nil - } - tag := typeField.Tag - checkTag := func(tagName string) bool { - t := tag.Get(tagName) - tOptions := strings.Split(t, ",") - if len(tOptions) == 0 { - return false - } - if tOptions[0] != fieldName { - return false - } - return true - } - if checkTag("json") || checkTag("graphql") { - return valueField.Interface(), nil - } - if valueField.Kind() == reflect.Struct && typeField.Anonymous { - return DefaultResolver(valueField.Interface(), fieldName) - } - continue - } - return nil, nil + _, val, err := findFieldInStruct(sourceVal, fieldName) + return val, err } // map[string]interface @@ -198,6 +168,46 @@ func DefaultResolver(source interface{}, fieldName string) (interface{}, error) return nil, nil } +func findFieldInStruct(source reflect.Value, fieldName string) (bool, interface{}, error) { + for i := 0; i < source.NumField(); i++ { + fieldValue := source.Field(i) + fieldType := source.Type().Field(i) + + if fieldType.Name == strings.Title(fieldName) { + // If ptr and value is nil return nil + if fieldValue.Type().Kind() == reflect.Ptr && fieldValue.IsNil() { + return true, nil, nil + } + return true, fieldValue.Interface(), nil + } + + tag := fieldType.Tag + checkTag := func(tagName string) bool { + t := tag.Get(tagName) + tOptions := strings.Split(t, ",") + if len(tOptions) == 0 { + return false + } + if tOptions[0] != fieldName { + return false + } + return true + } + if checkTag("json") || checkTag("graphql") { + return true, fieldValue.Interface(), nil + } + + if fieldValue.Kind() == reflect.Struct && fieldType.Anonymous { + if ok, val, err := findFieldInStruct(fieldValue, fieldName); ok { + return ok, val, err + } + } + continue + } + + return false, nil, nil +} + type typeResolver interface { ResolveType(interface{}, ResolveTypeParams) *Type } diff --git a/graphql/resolvers_test.go b/graphql/resolvers_test.go new file mode 100644 index 0000000000..7486b65854 --- /dev/null +++ b/graphql/resolvers_test.go @@ -0,0 +1,73 @@ +package graphql + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDefaultResolver(t *testing.T) { + type meta struct { + Name string `json:"firstname"` + } + + type animal struct { + meta + Age int + } + + fren := animal{meta: meta{Name: "bob"}, Age: 10} + + testCases := []struct { + desc string + source interface{} + field string + out interface{} + }{ + { + desc: "field on struct", + source: fren, + field: "name", + out: "bob", + }, + { + desc: "second field on struct", + source: fren, + field: "age", + out: 10, + }, + { + desc: "field on struct w/ tag", + source: fren, + field: "firstname", + out: "bob", + }, + { + desc: "missing field on struct", + source: fren, + field: "surname", + out: nil, + }, + { + desc: "field on map", + source: map[string]interface{}{"name": "bob"}, + field: "name", + out: "bob", + }, + { + desc: "missing field on map", + source: map[string]interface{}{"name": "bob"}, + field: "firstname", + out: nil, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + result, err := DefaultResolver(tc.source, tc.field) + require.NoError(t, err) + assert.EqualValues(t, result, tc.out) + }) + } +}