Skip to content

Commit

Permalink
Allow query params to be marshaled into set (#617)
Browse files Browse the repository at this point in the history
* Allow query params to be marshaled into thrift set

With this change, we are allowing repeated query params to be mapped into a set
Currently they can only be added to a list.
  • Loading branch information
argouber committed Aug 7, 2019
1 parent e6898db commit 5a3df0d
Show file tree
Hide file tree
Showing 17 changed files with 1,916 additions and 354 deletions.
169 changes: 80 additions & 89 deletions codegen/method.go
Original file line number Diff line number Diff line change
Expand Up @@ -1018,10 +1018,10 @@ func (ms *MethodSpec) setTypeConverters(
return nil
}

func getQueryMethodForType(typeSpec compile.TypeSpec) string {
func getQueryMethodForPrimitiveType(typeSpec compile.TypeSpec) string {
var queryMethod string

switch t := typeSpec.(type) {
switch typeSpec.(type) {
case *compile.BoolSpec:
queryMethod = "GetQueryBool"
case *compile.I8Spec:
Expand All @@ -1038,30 +1038,6 @@ func getQueryMethodForType(typeSpec compile.TypeSpec) string {
queryMethod = "GetQueryFloat64"
case *compile.StringSpec:
queryMethod = "GetQueryValue"
case *compile.ListSpec:
switch compile.RootTypeSpec(t.ValueSpec).(type) {
case *compile.BoolSpec:
queryMethod = "GetQueryBoolList"
case *compile.I8Spec:
queryMethod = "GetQueryInt8List"
case *compile.I16Spec:
queryMethod = "GetQueryInt16List"
case *compile.I32Spec:
queryMethod = "GetQueryInt32List"
case *compile.EnumSpec:
queryMethod = "GetQueryInt32List"
case *compile.I64Spec:
queryMethod = "GetQueryInt64List"
case *compile.DoubleSpec:
queryMethod = "GetQueryFloat64List"
case *compile.StringSpec:
queryMethod = "GetQueryValues"
default:
panic(fmt.Sprintf(
"Unsupported list value type (%T) for %s as query string parameter",
t.ValueSpec, t.ValueSpec.ThriftName(),
))
}
default:
panic(fmt.Sprintf(
"Unsupported type (%T) for %s as query string parameter",
Expand All @@ -1072,14 +1048,27 @@ func getQueryMethodForType(typeSpec compile.TypeSpec) string {
return queryMethod
}

func getQueryEncodeExpression(
typeSpec compile.TypeSpec, valueName string,
) string {
func getQueryMethodForType(typeSpec compile.TypeSpec) string {
var queryMethod string

switch t := typeSpec.(type) {
case *compile.ListSpec:
queryMethod = getQueryMethodForPrimitiveType(compile.RootTypeSpec(t.ValueSpec)) + "List"
case *compile.SetSpec:
queryMethod = getQueryMethodForPrimitiveType(compile.RootTypeSpec(t.ValueSpec)) + "Set"
default:
queryMethod = getQueryMethodForPrimitiveType(typeSpec)
}

return queryMethod
}

func getQueryEncodeExprPrimitive(typeSpec compile.TypeSpec) string {
var encodeExpression string

_, isTypedef := typeSpec.(*compile.TypedefSpec)

switch t := compile.RootTypeSpec(typeSpec).(type) {
switch compile.RootTypeSpec(typeSpec).(type) {
case *compile.BoolSpec:
if isTypedef {
encodeExpression = "strconv.FormatBool(bool(%s))"
Expand Down Expand Up @@ -1112,51 +1101,26 @@ func getQueryEncodeExpression(
}
case *compile.EnumSpec:
encodeExpression = "strconv.Itoa(int(%s))"
case *compile.ListSpec:
_, isValueTypedef := t.ValueSpec.(*compile.TypedefSpec)
switch compile.RootTypeSpec(t.ValueSpec).(type) {
case *compile.BoolSpec:
if isValueTypedef {
encodeExpression = "strconv.FormatBool(bool(%s))"
} else {
encodeExpression = "strconv.FormatBool(%s)"
}
case *compile.I8Spec:
encodeExpression = "strconv.Itoa(int(%s))"
case *compile.I16Spec:
encodeExpression = "strconv.Itoa(int(%s))"
case *compile.I32Spec:
encodeExpression = "strconv.Itoa(int(%s))"
case *compile.I64Spec:
if isValueTypedef {
encodeExpression = "strconv.FormatInt(int64(%s), 10)"
} else {
encodeExpression = "strconv.FormatInt(%s, 10)"
}
case *compile.DoubleSpec:
if isValueTypedef {
encodeExpression = "strconv.FormatFloat(float64(%s), 'G', -1, 64)"
} else {
encodeExpression = "strconv.FormatFloat(%s, 'G', -1, 64)"
}
case *compile.StringSpec:
if isValueTypedef {
encodeExpression = "string(%s)"
} else {
encodeExpression = "%s"
}
default:
panic(fmt.Sprintf(
"Unsupported list value type (%T) for %s as query string parameter",
t.ValueSpec, t.ValueSpec.ThriftName(),
))
}
default:
panic(fmt.Sprintf(
"Unsupported type (%T) for %s as query string parameter",
typeSpec, typeSpec.ThriftName(),
))
}
return encodeExpression
}

func getQueryEncodeExpression(typeSpec compile.TypeSpec, valueName string) string {
var encodeExpression string

switch t := compile.RootTypeSpec(typeSpec).(type) {
case *compile.ListSpec:
encodeExpression = getQueryEncodeExprPrimitive(t.ValueSpec)
case *compile.SetSpec:
encodeExpression = getQueryEncodeExprPrimitive(t.ValueSpec)
default:
encodeExpression = getQueryEncodeExprPrimitive(typeSpec)
}

return fmt.Sprintf(encodeExpression, valueName)
}
Expand Down Expand Up @@ -1210,6 +1174,7 @@ func (ms *MethodSpec) setWriteQueryParamStatements(
longQueryName := ms.getLongQueryName(field, thriftPrefix)
identifierName := CamelCase(longQueryName) + "Query"
_, isList := realType.(*compile.ListSpec)
_, isSet := realType.(*compile.SetSpec)

if !hasQueryFields {
statements.append("queryValues := &url.Values{}")
Expand All @@ -1222,6 +1187,11 @@ func (ms *MethodSpec) setWriteQueryParamStatements(
statements.appendf("for _, value := range %s {", "r"+longFieldName)
statements.appendf("\tqueryValues.Add(\"%s\", %s)", longQueryName, encodeExpr)
statements.append("}")
} else if isSet {
encodeExpr := getQueryEncodeExpression(field.Type, "value")
statements.appendf("for value := range %s {", "r"+longFieldName)
statements.appendf("\tqueryValues.Add(\"%s\", %s)", longQueryName, encodeExpr)
statements.append("}")
} else {
encodeExpr := getQueryEncodeExpression(field.Type, "r"+longFieldName)
statements.appendf("%s := %s", identifierName, encodeExpr)
Expand All @@ -1234,14 +1204,18 @@ func (ms *MethodSpec) setWriteQueryParamStatements(
statements.appendf("for _, value := range %s {", "r"+longFieldName)
statements.appendf("\tqueryValues.Add(\"%s\", %s)", longQueryName, encodeExpr)
statements.append("}")
} else if isSet {
encodeExpr := getQueryEncodeExpression(field.Type, "value")
statements.appendf("for value := range %s {", "r"+longFieldName)
statements.appendf("\tqueryValues.Add(\"%s\", %s)", longQueryName, encodeExpr)
statements.append("}")
} else {
encodeExpr := getQueryEncodeExpression(field.Type, "*r"+longFieldName)
statements.appendf("\t%s := %s", identifierName, encodeExpr)
statements.appendf("\tqueryValues.Set(\"%s\", %s)", longQueryName, identifierName)
}
statements.append("}")
}

return false
}
walkFieldGroups(compile.FieldGroup(funcSpec.ArgsSpec), visitor)
Expand Down Expand Up @@ -1299,21 +1273,29 @@ func (ms *MethodSpec) setParseQueryParamStatements(
}
}

var listValueTypedef string
t, isList := realType.(*compile.ListSpec)
if isList {
var aggrValueTypedef string
var isList, isSet bool
switch t := realType.(type) {
case *compile.ListSpec:
isList = true
if _, ok := t.ValueSpec.(*compile.TypedefSpec); ok {
listValueTypedef, err = GoType(packageHelper, t.ValueSpec)
aggrValueTypedef, err = GoType(packageHelper, t.ValueSpec)
if err != nil {
finalError = err
return true
}
}
}

// If the type is a struct then we cannot really do anything
if _, ok := realType.(*compile.StructSpec); ok {
// if a field is a struct then we must do a nil check
case *compile.SetSpec:
isSet = true
if _, ok := t.ValueSpec.(*compile.TypedefSpec); ok {
aggrValueTypedef, err = GoType(packageHelper, t.ValueSpec)
if err != nil {
finalError = err
return true
}
}
case *compile.StructSpec:
// If the type is a struct then we cannot really do anything

typeName, err := GoType(packageHelper, realType)
if err != nil {
Expand All @@ -1339,7 +1321,6 @@ func (ms *MethodSpec) setParseQueryParamStatements(

return false
}

identifierName := CamelCase(longQueryName) + "Query"

httpRefAnnotation := field.Annotations[ms.annotations.HTTPRef]
Expand Down Expand Up @@ -1380,18 +1361,28 @@ func (ms *MethodSpec) setParseQueryParamStatements(
}

// if field is a list and list value is typedef, list values must be converted first
if listValueTypedef != "" {
if aggrValueTypedef != "" {
target = fmt.Sprintf("%sFinal", identifierName)
statements.appendf(
"%s%s := make([]%s, len(%s))",
indent, target, listValueTypedef, identifierName,
)
statements.appendf("%sfor i, v := range %s {", indent, identifierName)
statements.appendf("%s%s[i] = %s(v)", indent+"\t", target, listValueTypedef)
statements.appendf("%s}", indent)
if isList {
statements.appendf(
"%s%s := make([]%s, len(%s))",
indent, target, aggrValueTypedef, identifierName,
)
statements.appendf("%sfor i, v := range %s {", indent, identifierName)
statements.appendf("%s%s[i] = %s(v)", indent+"\t", target, aggrValueTypedef)
statements.appendf("%s}", indent)
} else if isSet {
statements.appendf(
"%s%s := make(map[%s]struct{}, len(%s))",
indent, target, aggrValueTypedef, identifierName,
)
statements.appendf("%sfor _, v := range %s {", indent, identifierName)
statements.appendf("%s%s[%s(v)] = struct{}{}", indent+"\t", target, aggrValueTypedef)
statements.appendf("%s}", indent)
}
}

if field.Required || isList {
if field.Required || isList || isSet {
if typedef != "" {
statements.appendf("%srequestBody%s = %s(%s)", indent, longFieldName, typedef, target)
} else {
Expand Down
12 changes: 9 additions & 3 deletions codegen/thrift.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func GoType(p PackageNameResolver, spec compile.TypeSpec) (string, error) {
if err != nil {
return "", err
}
if !isHashable(s.ValueSpec) {
if !isHashable(s.ValueSpec) || isSliceSetType(s) {
return fmt.Sprintf("[]%s", v), nil
}
return fmt.Sprintf("map[%s]struct{}", v), nil
Expand Down Expand Up @@ -134,6 +134,13 @@ func isHashable(spec compile.TypeSpec) bool {
}
}

// IsSliceSetType returns true if the given thrift type is a Set implemented as a slice (as opposed to a map)
func isSliceSetType(spec compile.TypeSpec) bool {
spec = compile.RootTypeSpec(spec)
_, isSet := spec.(*compile.SetSpec)
return isSet && spec.ThriftAnnotations()["go.type"] == "slice"
}

func pointerMethodType(typeSpec compile.TypeSpec) string {
var pointerMethod string

Expand Down Expand Up @@ -212,6 +219,7 @@ func walkFieldGroupsInternal(
case *compile.I64Spec:
case *compile.EnumSpec:
case *compile.ListSpec:
case *compile.SetSpec:
case *compile.StructSpec:
bail := walkFieldGroupsInternal(
goPrefix+"."+PascalCase(field.Name),
Expand All @@ -223,8 +231,6 @@ func walkFieldGroupsInternal(
if bail {
return true
}
case *compile.SetSpec:
// TODO: implement
case *compile.MapSpec:
// TODO: implement
default:
Expand Down
10 changes: 10 additions & 0 deletions examples/example-gateway/build/clients/bar/bar.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 5a3df0d

Please sign in to comment.