Skip to content

Commit

Permalink
Actually handle enum query params as strings
Browse files Browse the repository at this point in the history
With this change zanzibar will treat enums in query params as strings mapped to integers.
Previous implementation treated "enum" as an alias for int32 which is incorrect.
  • Loading branch information
argouber committed Oct 31, 2019
1 parent 8175628 commit 5effab0
Show file tree
Hide file tree
Showing 15 changed files with 1,613 additions and 387 deletions.
209 changes: 119 additions & 90 deletions codegen/method.go
Expand Up @@ -1012,14 +1012,14 @@ func getQueryMethodForPrimitiveType(typeSpec compile.TypeSpec) string {
queryMethod = "GetQueryInt16"
case *compile.I32Spec:
queryMethod = "GetQueryInt32"
case *compile.EnumSpec:
queryMethod = "GetQueryInt32"
case *compile.I64Spec:
queryMethod = "GetQueryInt64"
case *compile.DoubleSpec:
queryMethod = "GetQueryFloat64"
case *compile.StringSpec:
queryMethod = "GetQueryValue"
case *compile.EnumSpec:
queryMethod = "GetQueryValue"
default:
panic(fmt.Sprintf(
"Unsupported type (%T) for %s as query string parameter",
Expand Down Expand Up @@ -1057,18 +1057,12 @@ func getQueryEncodeExprPrimitive(typeSpec compile.TypeSpec) string {
} else {
encodeExpression = "strconv.FormatBool(%s)"
}
case *compile.I8Spec:
encodeExpression = "strconv.Itoa(int(%s))"
case *compile.I16Spec:
encodeExpression = "strconv.Itoa(int(%s))"
case *compile.I32Spec:
encodeExpression = "strconv.Itoa(int(%s))"
case *compile.I64Spec:
if isTypedef {
encodeExpression = "strconv.FormatInt(int64(%s), 10)"
} else {
encodeExpression = "strconv.FormatInt(%s, 10)"
}
case
*compile.I8Spec,
*compile.I16Spec,
*compile.I32Spec,
*compile.I64Spec:
encodeExpression = "strconv.FormatInt(int64(%s), 10)"
case *compile.DoubleSpec:
if isTypedef {
encodeExpression = "strconv.FormatFloat(float64(%s), 'G', -1, 64)"
Expand All @@ -1082,7 +1076,7 @@ func getQueryEncodeExprPrimitive(typeSpec compile.TypeSpec) string {
encodeExpression = "%s"
}
case *compile.EnumSpec:
encodeExpression = "strconv.Itoa(int(%s))"
encodeExpression = "(%s).String()"
default:
// This is intentional -- lets evaluate why we would want other types here before opening the flood gates
panic(fmt.Sprintf(
Expand Down Expand Up @@ -1273,6 +1267,16 @@ func makeUniqIdent(identifier string, seen map[string]int) string {
return identifier
}

func getCustomType(pkgHelper *PackageHelper, itemType compile.TypeSpec) (string, error) {
switch itemType.(type) {
case
*compile.TypedefSpec,
*compile.EnumSpec:
return GoType(pkgHelper, itemType)
}
return "", nil
}

func (ms *MethodSpec) setParseQueryParamStatements(
funcSpec *compile.FunctionSpec, packageHelper *PackageHelper, hasNoBody bool,
) error {
Expand All @@ -1283,6 +1287,7 @@ func (ms *MethodSpec) setParseQueryParamStatements(
var finalError error
var stack = []string{}
seenIdents := map[string]int{}
indent := ""

visitor := func(
goPrefix string, thriftPrefix string, field *compile.FieldSpec,
Expand All @@ -1303,45 +1308,33 @@ func (ms *MethodSpec) setParseQueryParamStatements(
}
}

var err error
var typedef string
if _, ok := field.Type.(*compile.TypedefSpec); ok {
typedef, err = GoType(packageHelper, field.Type)
if err != nil {
finalError = err
return true
}
}

if _, ok := field.Type.(*compile.EnumSpec); ok {
typedef, err = GoType(packageHelper, field.Type)
if err != nil {
finalError = err
return true
}
customType, err := getCustomType(packageHelper, field.Type)
if err != nil {
finalError = err
return true
}

var aggrValueTypedef string
var isList, isSet bool
var customElemType string
var isEnumElem bool
switch t := realType.(type) {
// Before you ask -- yes duplicated code because ValueSpec is not defined in the generic interface
case *compile.ListSpec:
isList = true
if _, ok := t.ValueSpec.(*compile.TypedefSpec); ok {
aggrValueTypedef, err = GoType(packageHelper, t.ValueSpec)
if err != nil {
finalError = err
return true
}
customElemType, err = getCustomType(packageHelper, t.ValueSpec)
if err != nil {
finalError = err
return true
}
_, isEnumElem = t.ValueSpec.(*compile.EnumSpec)
case *compile.SetSpec:
isSet = true
if _, ok := t.ValueSpec.(*compile.TypedefSpec); ok {
aggrValueTypedef, err = GoType(packageHelper, t.ValueSpec)
if err != nil {
finalError = err
return true
}
customElemType, err = getCustomType(packageHelper, t.ValueSpec)
if err != nil {
finalError = err
return true
}
_, isEnumElem = t.ValueSpec.(*compile.EnumSpec)
case *compile.StructSpec:
typeName, err := GoType(packageHelper, realType)
if err != nil {
Expand All @@ -1362,84 +1355,120 @@ func (ms *MethodSpec) setParseQueryParamStatements(
statements.append("if _queryNeeded {")
}

statements.appendf("if requestBody%s == nil {", longFieldName)
statements.appendf("\trequestBody%s = &%s{}", longFieldName, typeName)
statements.append("}")
statements.appendf("%sif requestBody%s == nil {", indent, longFieldName)
statements.appendf("%s\trequestBody%s = &%s{}", indent, longFieldName, typeName)
statements.append(indent, "}")

return false
}
isAggregate := isList || isSet // we do not support maps

// For disambiguation of similar names
baseIdent := makeUniqIdent(CamelCase(longQueryName), seenIdents)
identifierName := baseIdent + "Query"
okIdentifierName := baseIdent + "Ok"

// make sure value is present
if field.Required {
statements.appendf("%s := req.CheckQueryValue(%q)",
okIdentifierName, shortQueryParam,
)
statements.appendf("if !%s {", okIdentifierName)
statements.append("\treturn")
statements.append("}")
statements.appendf("%s%s := req.CheckQueryValue(%q)", indent, okIdentifierName, shortQueryParam)
statements.appendf("%sif !%s {", indent, okIdentifierName)
statements.append(indent, "\treturn")
statements.append(indent, "}")
} else {
statements.appendf("%s := req.HasQueryValue(%q)",
okIdentifierName, shortQueryParam,
)
statements.appendf("if %s {", okIdentifierName)
statements.appendf("%s%s := req.HasQueryValue(%q)", indent, okIdentifierName, shortQueryParam)
statements.appendf("%sif %s {", indent, okIdentifierName)
indent += "\t"
}

queryMethodName := getQueryMethodForType(realType)

statements.appendf("%s, ok := req.%s(%q)",
identifierName, queryMethodName, shortQueryParam,
)
queryRValue := fmt.Sprintf("req.%s(%q)", getQueryMethodForType(realType), shortQueryParam)

// Transform if enum
if _, isEnumType := field.Type.(*compile.EnumSpec); isEnumType {
statements.appendf("var %s %s", identifierName, customType)
tmpVar := "_tmp" + identifierName
statements.appendf("%s%s, ok := %s", indent, tmpVar, queryRValue)
statements.append(indent, "if ok {")
statements.appendf("%s\tif err := %s.UnmarshalText([]byte(%s)); err != nil {",
indent, identifierName, tmpVar)
statements.appendf("%s\t\treq.LogAndSendQueryError(err, %q, %q, %s)",
indent, "enum", shortQueryParam, tmpVar)
statements.append(indent, "\t\tok = false")
statements.append(indent, "\t}")
statements.append(indent, "}")
} else {
statements.appendf("%s%s, ok := %s", indent, identifierName, queryRValue)
}

statements.append("if !ok {")
statements.append("\treturn")
statements.append("}")

target := identifierName

indent := ""
if !field.Required {
indent += "\t"
}

// if field is a list and list value is typedef, list values must be converted first
if aggrValueTypedef != "" {
target = fmt.Sprintf("%sFinal", identifierName)
// If field is an "aggregate" with custom element types, we need to convert them first
// Note that enums and typedefs are what get in here
if customElemType != "" {
target += "Final"
valVar := "v"
if isList {
statements.appendf(
"%s%s := make([]%s, len(%s))",
indent, target, aggrValueTypedef, identifierName,
indent, target, customElemType, identifierName,
)
statements.appendf("%sfor i, v := range %s {", indent, identifierName)
statements.appendf("%s%s[i] = %s(v)", indent+"\t", target, aggrValueTypedef)
statements.appendf("%s}", indent)
statements.appendf("%sfor i, %s := range %s {", indent, valVar, identifierName)
indent += "\t"
if isEnumElem {
tmpVar := "_tmp" + valVar
statements.appendf("%svar %s %s", indent, tmpVar, customElemType)
statements.appendf("%sif err := %s.UnmarshalText([]byte(%s)); err != nil {",
indent, tmpVar, valVar)
statements.appendf("%s\treq.LogAndSendQueryError(err, %q, %q, %s)",
indent, "enum", shortQueryParam, valVar)
// TODO: @argo: we should ideally log a warning here
statements.append(indent, "\treturn")
statements.append(indent, "}")
valVar = tmpVar
}
statements.appendf("%s\t%s[i] = %s(%s)", indent, target, customElemType, valVar)
indent = indent[:len(indent)-1]
statements.append(indent, "}")
} else if isSet {
statements.appendf(
"%s%s := make(map[%s]struct{}, len(%s))",
indent, target, aggrValueTypedef, identifierName,
indent, target, customElemType, identifierName,
)
statements.appendf("%sfor _, v := range %s {", indent, identifierName)
statements.appendf("%s%s[%s(v)] = struct{}{}", indent+"\t", target, aggrValueTypedef)
statements.appendf("%s}", indent)
indent += "\t"
if isEnumElem {
tmpVar := "_tmp" + valVar
statements.appendf("%svar %s %s", indent, tmpVar, customElemType)
statements.appendf("%sif err := %s.UnmarshalText([]byte(%s)); err != nil {",
indent, tmpVar, valVar)
statements.appendf("%s\treq.LogAndSendQueryError(err, %q, %q, %s)",
indent, "enum", shortQueryParam, valVar)
statements.append(indent, "\treturn")
statements.append(indent, "}")
valVar = tmpVar
}
statements.appendf("%s\t%s[%s(%s)] = struct{}{}", indent, target, customElemType, valVar)
indent = indent[:len(indent)-1]
statements.append(indent, "}")
}
}

if field.Required || isList || isSet {
if typedef != "" {
statements.appendf("%srequestBody%s = %s(%s)", indent, longFieldName, typedef, target)
} else {
statements.appendf("%srequestBody%s = %s", indent, longFieldName, target)
}

} else {
target = fmt.Sprintf("ptr.%s(%s)", pointerMethodType(realType), identifierName)
if typedef != "" {
statements.appendf("%srequestBody%s = (*%s)(%s)", indent, longFieldName, typedef, target)
} else {
statements.appendf("%srequestBody%s = %s", indent, longFieldName, target)
var deref string
if !field.Required && !isAggregate {
deref = "*"
targetName := identifierName
if customType != "" {
targetName = fmt.Sprintf("%s(%s)", strings.ToLower(pointerMethodType(realType)), targetName)
}
target = fmt.Sprintf("ptr.%s(%s)", pointerMethodType(realType), targetName)
}
if customType != "" {
target = fmt.Sprintf("(%s%s)(%s)", deref, customType, target)
}
statements.appendf("%srequestBody%s = %s", indent, longFieldName, target)

if !field.Required {
statements.append("}")
Expand Down

0 comments on commit 5effab0

Please sign in to comment.