Skip to content

Commit

Permalink
Actually handle enum query params as strings (#666)
Browse files Browse the repository at this point in the history
* Actually handle enum query params as strings

With this change zanzibar will treat enums in query params as strings mapped to integers.
Previous implementation treated "enum" as an alias for int32 which is incorrect.

* minor fixit

* fix-indent
  • Loading branch information
argouber committed Oct 31, 2019
1 parent 8175628 commit 3270596
Show file tree
Hide file tree
Showing 15 changed files with 1,619 additions and 392 deletions.
220 changes: 125 additions & 95 deletions codegen/method.go
Original file line number Diff line number Diff line change
Expand Up @@ -1012,14 +1012,14 @@ func getQueryMethodForPrimitiveType(typeSpec compile.TypeSpec) string {
queryMethod = "GetQueryInt16"
case *compile.I32Spec:
queryMethod = "GetQueryInt32"
case *compile.EnumSpec:
queryMethod = "GetQueryInt32"
case *compile.I64Spec:
queryMethod = "GetQueryInt64"
case *compile.DoubleSpec:
queryMethod = "GetQueryFloat64"
case *compile.StringSpec:
queryMethod = "GetQueryValue"
case *compile.EnumSpec:
queryMethod = "GetQueryValue"
default:
panic(fmt.Sprintf(
"Unsupported type (%T) for %s as query string parameter",
Expand Down Expand Up @@ -1057,18 +1057,12 @@ func getQueryEncodeExprPrimitive(typeSpec compile.TypeSpec) string {
} else {
encodeExpression = "strconv.FormatBool(%s)"
}
case *compile.I8Spec:
encodeExpression = "strconv.Itoa(int(%s))"
case *compile.I16Spec:
encodeExpression = "strconv.Itoa(int(%s))"
case *compile.I32Spec:
encodeExpression = "strconv.Itoa(int(%s))"
case *compile.I64Spec:
if isTypedef {
encodeExpression = "strconv.FormatInt(int64(%s), 10)"
} else {
encodeExpression = "strconv.FormatInt(%s, 10)"
}
case
*compile.I8Spec,
*compile.I16Spec,
*compile.I32Spec,
*compile.I64Spec:
encodeExpression = "strconv.FormatInt(int64(%s), 10)"
case *compile.DoubleSpec:
if isTypedef {
encodeExpression = "strconv.FormatFloat(float64(%s), 'G', -1, 64)"
Expand All @@ -1082,7 +1076,7 @@ func getQueryEncodeExprPrimitive(typeSpec compile.TypeSpec) string {
encodeExpression = "%s"
}
case *compile.EnumSpec:
encodeExpression = "strconv.Itoa(int(%s))"
encodeExpression = "(%s).String()"
default:
// This is intentional -- lets evaluate why we would want other types here before opening the flood gates
panic(fmt.Sprintf(
Expand Down Expand Up @@ -1273,6 +1267,16 @@ func makeUniqIdent(identifier string, seen map[string]int) string {
return identifier
}

func getCustomType(pkgHelper *PackageHelper, itemType compile.TypeSpec) (string, error) {
switch itemType.(type) {
case
*compile.TypedefSpec,
*compile.EnumSpec:
return GoType(pkgHelper, itemType)
}
return "", nil
}

func (ms *MethodSpec) setParseQueryParamStatements(
funcSpec *compile.FunctionSpec, packageHelper *PackageHelper, hasNoBody bool,
) error {
Expand All @@ -1283,6 +1287,7 @@ func (ms *MethodSpec) setParseQueryParamStatements(
var finalError error
var stack = []string{}
seenIdents := map[string]int{}
indent := ""

visitor := func(
goPrefix string, thriftPrefix string, field *compile.FieldSpec,
Expand All @@ -1299,49 +1304,38 @@ func (ms *MethodSpec) setParseQueryParamStatements(
if len(stack) > 0 {
if !strings.HasPrefix(longFieldName, stack[len(stack)-1]) {
stack = stack[:len(stack)-1]
statements.append("}")
indent = indent[:len(indent)-1]
statements.append(indent, "}")
}
}

var err error
var typedef string
if _, ok := field.Type.(*compile.TypedefSpec); ok {
typedef, err = GoType(packageHelper, field.Type)
if err != nil {
finalError = err
return true
}
}

if _, ok := field.Type.(*compile.EnumSpec); ok {
typedef, err = GoType(packageHelper, field.Type)
if err != nil {
finalError = err
return true
}
customType, err := getCustomType(packageHelper, field.Type)
if err != nil {
finalError = err
return true
}

var aggrValueTypedef string
var isList, isSet bool
var customElemType string
var isEnumElem bool
switch t := realType.(type) {
// Before you ask -- yes duplicated code because ValueSpec is not defined in the generic interface
case *compile.ListSpec:
isList = true
if _, ok := t.ValueSpec.(*compile.TypedefSpec); ok {
aggrValueTypedef, err = GoType(packageHelper, t.ValueSpec)
if err != nil {
finalError = err
return true
}
customElemType, err = getCustomType(packageHelper, t.ValueSpec)
if err != nil {
finalError = err
return true
}
_, isEnumElem = t.ValueSpec.(*compile.EnumSpec)
case *compile.SetSpec:
isSet = true
if _, ok := t.ValueSpec.(*compile.TypedefSpec); ok {
aggrValueTypedef, err = GoType(packageHelper, t.ValueSpec)
if err != nil {
finalError = err
return true
}
customElemType, err = getCustomType(packageHelper, t.ValueSpec)
if err != nil {
finalError = err
return true
}
_, isEnumElem = t.ValueSpec.(*compile.EnumSpec)
case *compile.StructSpec:
typeName, err := GoType(packageHelper, realType)
if err != nil {
Expand All @@ -1362,87 +1356,123 @@ func (ms *MethodSpec) setParseQueryParamStatements(
statements.append("if _queryNeeded {")
}

statements.appendf("if requestBody%s == nil {", longFieldName)
statements.appendf("\trequestBody%s = &%s{}", longFieldName, typeName)
statements.append("}")
statements.appendf("%sif requestBody%s == nil {", indent, longFieldName)
statements.appendf("%s\trequestBody%s = &%s{}", indent, longFieldName, typeName)
statements.append(indent, "}")

return false
}
isAggregate := isList || isSet // we do not support maps

// For disambiguation of similar names
baseIdent := makeUniqIdent(CamelCase(longQueryName), seenIdents)
identifierName := baseIdent + "Query"
okIdentifierName := baseIdent + "Ok"

// make sure value is present
if field.Required {
statements.appendf("%s := req.CheckQueryValue(%q)",
okIdentifierName, shortQueryParam,
)
statements.appendf("if !%s {", okIdentifierName)
statements.append("\treturn")
statements.append("}")
statements.appendf("%s%s := req.CheckQueryValue(%q)", indent, okIdentifierName, shortQueryParam)
statements.appendf("%sif !%s {", indent, okIdentifierName)
statements.append(indent, "\treturn")
statements.append(indent, "}")
} else {
statements.appendf("%s := req.HasQueryValue(%q)",
okIdentifierName, shortQueryParam,
)
statements.appendf("if %s {", okIdentifierName)
statements.appendf("%s%s := req.HasQueryValue(%q)", indent, okIdentifierName, shortQueryParam)
statements.appendf("%sif %s {", indent, okIdentifierName)
indent += "\t"
}

queryMethodName := getQueryMethodForType(realType)

statements.appendf("%s, ok := req.%s(%q)",
identifierName, queryMethodName, shortQueryParam,
)
queryRValue := fmt.Sprintf("req.%s(%q)", getQueryMethodForType(realType), shortQueryParam)

// Transform if enum
if _, isEnumType := field.Type.(*compile.EnumSpec); isEnumType {
statements.appendf("var %s %s", identifierName, customType)
tmpVar := "_tmp" + identifierName
statements.appendf("%s%s, ok := %s", indent, tmpVar, queryRValue)
statements.append(indent, "if ok {")
statements.appendf("%s\tif err := %s.UnmarshalText([]byte(%s)); err != nil {",
indent, identifierName, tmpVar)
statements.appendf("%s\t\treq.LogAndSendQueryError(err, %q, %q, %s)",
indent, "enum", shortQueryParam, tmpVar)
statements.append(indent, "\t\tok = false")
statements.append(indent, "\t}")
statements.append(indent, "}")
} else {
statements.appendf("%s%s, ok := %s", indent, identifierName, queryRValue)
}

statements.append("if !ok {")
statements.append("\treturn")
statements.append("}")
statements.append(indent, "if !ok {")
statements.append(indent, "\treturn")
statements.append(indent, "}")

target := identifierName

indent := ""
if !field.Required {
indent += "\t"
}

// if field is a list and list value is typedef, list values must be converted first
if aggrValueTypedef != "" {
target = fmt.Sprintf("%sFinal", identifierName)
// If field is an "aggregate" with custom element types, we need to convert them first
// Note that enums and typedefs are what get in here
if customElemType != "" {
target += "Final"
valVar := "v"
if isList {
statements.appendf(
"%s%s := make([]%s, len(%s))",
indent, target, aggrValueTypedef, identifierName,
indent, target, customElemType, identifierName,
)
statements.appendf("%sfor i, v := range %s {", indent, identifierName)
statements.appendf("%s%s[i] = %s(v)", indent+"\t", target, aggrValueTypedef)
statements.appendf("%s}", indent)
statements.appendf("%sfor i, %s := range %s {", indent, valVar, identifierName)
indent += "\t"
if isEnumElem {
tmpVar := "_tmp" + valVar
statements.appendf("%svar %s %s", indent, tmpVar, customElemType)
statements.appendf("%sif err := %s.UnmarshalText([]byte(%s)); err != nil {",
indent, tmpVar, valVar)
statements.appendf("%s\treq.LogAndSendQueryError(err, %q, %q, %s)",
indent, "enum", shortQueryParam, valVar)
statements.append(indent, "\treturn")
statements.append(indent, "}")
valVar = tmpVar
}
statements.appendf("%s%s[i] = %s(%s)", indent, target, customElemType, valVar)
indent = indent[:len(indent)-1]
statements.append(indent, "}")
} else if isSet {
statements.appendf(
"%s%s := make(map[%s]struct{}, len(%s))",
indent, target, aggrValueTypedef, identifierName,
indent, target, customElemType, identifierName,
)
statements.appendf("%sfor _, v := range %s {", indent, identifierName)
statements.appendf("%s%s[%s(v)] = struct{}{}", indent+"\t", target, aggrValueTypedef)
statements.appendf("%s}", indent)
indent += "\t"
if isEnumElem {
tmpVar := "_tmp" + valVar
statements.appendf("%svar %s %s", indent, tmpVar, customElemType)
statements.appendf("%sif err := %s.UnmarshalText([]byte(%s)); err != nil {",
indent, tmpVar, valVar)
statements.appendf("%s\treq.LogAndSendQueryError(err, %q, %q, %s)",
indent, "enum", shortQueryParam, valVar)
statements.append(indent, "\treturn")
statements.append(indent, "}")
valVar = tmpVar
}
statements.appendf("%s%s[%s(%s)] = struct{}{}", indent, target, customElemType, valVar)
indent = indent[:len(indent)-1]
statements.append(indent, "}")
}
}

if field.Required || isList || isSet {
if typedef != "" {
statements.appendf("%srequestBody%s = %s(%s)", indent, longFieldName, typedef, target)
} else {
statements.appendf("%srequestBody%s = %s", indent, longFieldName, target)
}

} else {
target = fmt.Sprintf("ptr.%s(%s)", pointerMethodType(realType), identifierName)
if typedef != "" {
statements.appendf("%srequestBody%s = (*%s)(%s)", indent, longFieldName, typedef, target)
} else {
statements.appendf("%srequestBody%s = %s", indent, longFieldName, target)
var deref string
if !field.Required && !isAggregate {
deref = "*"
targetName := identifierName
if customType != "" {
targetName = fmt.Sprintf("%s(%s)", strings.ToLower(pointerMethodType(realType)), targetName)
}
target = fmt.Sprintf("ptr.%s(%s)", pointerMethodType(realType), targetName)
}
if customType != "" {
target = fmt.Sprintf("(%s%s)(%s)", deref, customType, target)
}
statements.appendf("%srequestBody%s = %s", indent, longFieldName, target)

if !field.Required {
statements.append("}")
indent = indent[:len(indent)-1]
statements.append(indent, "}")
}

// new line after block.
Expand Down
Loading

0 comments on commit 3270596

Please sign in to comment.