Skip to content

Commit

Permalink
Merge 5effab0 into 8175628
Browse files Browse the repository at this point in the history
  • Loading branch information
argouber committed Oct 31, 2019
2 parents 8175628 + 5effab0 commit 0375008
Show file tree
Hide file tree
Showing 15 changed files with 1,613 additions and 387 deletions.
209 changes: 119 additions & 90 deletions codegen/method.go
Expand Up @@ -1012,14 +1012,14 @@ func getQueryMethodForPrimitiveType(typeSpec compile.TypeSpec) string {
queryMethod = "GetQueryInt16"
case *compile.I32Spec:
queryMethod = "GetQueryInt32"
case *compile.EnumSpec:
queryMethod = "GetQueryInt32"
case *compile.I64Spec:
queryMethod = "GetQueryInt64"
case *compile.DoubleSpec:
queryMethod = "GetQueryFloat64"
case *compile.StringSpec:
queryMethod = "GetQueryValue"
case *compile.EnumSpec:
queryMethod = "GetQueryValue"
default:
panic(fmt.Sprintf(
"Unsupported type (%T) for %s as query string parameter",
Expand Down Expand Up @@ -1057,18 +1057,12 @@ func getQueryEncodeExprPrimitive(typeSpec compile.TypeSpec) string {
} else {
encodeExpression = "strconv.FormatBool(%s)"
}
case *compile.I8Spec:
encodeExpression = "strconv.Itoa(int(%s))"
case *compile.I16Spec:
encodeExpression = "strconv.Itoa(int(%s))"
case *compile.I32Spec:
encodeExpression = "strconv.Itoa(int(%s))"
case *compile.I64Spec:
if isTypedef {
encodeExpression = "strconv.FormatInt(int64(%s), 10)"
} else {
encodeExpression = "strconv.FormatInt(%s, 10)"
}
case
*compile.I8Spec,
*compile.I16Spec,
*compile.I32Spec,
*compile.I64Spec:
encodeExpression = "strconv.FormatInt(int64(%s), 10)"
case *compile.DoubleSpec:
if isTypedef {
encodeExpression = "strconv.FormatFloat(float64(%s), 'G', -1, 64)"
Expand All @@ -1082,7 +1076,7 @@ func getQueryEncodeExprPrimitive(typeSpec compile.TypeSpec) string {
encodeExpression = "%s"
}
case *compile.EnumSpec:
encodeExpression = "strconv.Itoa(int(%s))"
encodeExpression = "(%s).String()"
default:
// This is intentional -- lets evaluate why we would want other types here before opening the flood gates
panic(fmt.Sprintf(
Expand Down Expand Up @@ -1273,6 +1267,16 @@ func makeUniqIdent(identifier string, seen map[string]int) string {
return identifier
}

func getCustomType(pkgHelper *PackageHelper, itemType compile.TypeSpec) (string, error) {
switch itemType.(type) {
case
*compile.TypedefSpec,
*compile.EnumSpec:
return GoType(pkgHelper, itemType)
}
return "", nil
}

func (ms *MethodSpec) setParseQueryParamStatements(
funcSpec *compile.FunctionSpec, packageHelper *PackageHelper, hasNoBody bool,
) error {
Expand All @@ -1283,6 +1287,7 @@ func (ms *MethodSpec) setParseQueryParamStatements(
var finalError error
var stack = []string{}
seenIdents := map[string]int{}
indent := ""

visitor := func(
goPrefix string, thriftPrefix string, field *compile.FieldSpec,
Expand All @@ -1303,45 +1308,33 @@ func (ms *MethodSpec) setParseQueryParamStatements(
}
}

var err error
var typedef string
if _, ok := field.Type.(*compile.TypedefSpec); ok {
typedef, err = GoType(packageHelper, field.Type)
if err != nil {
finalError = err
return true
}
}

if _, ok := field.Type.(*compile.EnumSpec); ok {
typedef, err = GoType(packageHelper, field.Type)
if err != nil {
finalError = err
return true
}
customType, err := getCustomType(packageHelper, field.Type)
if err != nil {
finalError = err
return true
}

var aggrValueTypedef string
var isList, isSet bool
var customElemType string
var isEnumElem bool
switch t := realType.(type) {
// Before you ask -- yes duplicated code because ValueSpec is not defined in the generic interface
case *compile.ListSpec:
isList = true
if _, ok := t.ValueSpec.(*compile.TypedefSpec); ok {
aggrValueTypedef, err = GoType(packageHelper, t.ValueSpec)
if err != nil {
finalError = err
return true
}
customElemType, err = getCustomType(packageHelper, t.ValueSpec)
if err != nil {
finalError = err
return true
}
_, isEnumElem = t.ValueSpec.(*compile.EnumSpec)
case *compile.SetSpec:
isSet = true
if _, ok := t.ValueSpec.(*compile.TypedefSpec); ok {
aggrValueTypedef, err = GoType(packageHelper, t.ValueSpec)
if err != nil {
finalError = err
return true
}
customElemType, err = getCustomType(packageHelper, t.ValueSpec)
if err != nil {
finalError = err
return true
}
_, isEnumElem = t.ValueSpec.(*compile.EnumSpec)
case *compile.StructSpec:
typeName, err := GoType(packageHelper, realType)
if err != nil {
Expand All @@ -1362,84 +1355,120 @@ func (ms *MethodSpec) setParseQueryParamStatements(
statements.append("if _queryNeeded {")
}

statements.appendf("if requestBody%s == nil {", longFieldName)
statements.appendf("\trequestBody%s = &%s{}", longFieldName, typeName)
statements.append("}")
statements.appendf("%sif requestBody%s == nil {", indent, longFieldName)
statements.appendf("%s\trequestBody%s = &%s{}", indent, longFieldName, typeName)
statements.append(indent, "}")

return false
}
isAggregate := isList || isSet // we do not support maps

// For disambiguation of similar names
baseIdent := makeUniqIdent(CamelCase(longQueryName), seenIdents)
identifierName := baseIdent + "Query"
okIdentifierName := baseIdent + "Ok"

// make sure value is present
if field.Required {
statements.appendf("%s := req.CheckQueryValue(%q)",
okIdentifierName, shortQueryParam,
)
statements.appendf("if !%s {", okIdentifierName)
statements.append("\treturn")
statements.append("}")
statements.appendf("%s%s := req.CheckQueryValue(%q)", indent, okIdentifierName, shortQueryParam)
statements.appendf("%sif !%s {", indent, okIdentifierName)
statements.append(indent, "\treturn")
statements.append(indent, "}")
} else {
statements.appendf("%s := req.HasQueryValue(%q)",
okIdentifierName, shortQueryParam,
)
statements.appendf("if %s {", okIdentifierName)
statements.appendf("%s%s := req.HasQueryValue(%q)", indent, okIdentifierName, shortQueryParam)
statements.appendf("%sif %s {", indent, okIdentifierName)
indent += "\t"
}

queryMethodName := getQueryMethodForType(realType)

statements.appendf("%s, ok := req.%s(%q)",
identifierName, queryMethodName, shortQueryParam,
)
queryRValue := fmt.Sprintf("req.%s(%q)", getQueryMethodForType(realType), shortQueryParam)

// Transform if enum
if _, isEnumType := field.Type.(*compile.EnumSpec); isEnumType {
statements.appendf("var %s %s", identifierName, customType)
tmpVar := "_tmp" + identifierName
statements.appendf("%s%s, ok := %s", indent, tmpVar, queryRValue)
statements.append(indent, "if ok {")
statements.appendf("%s\tif err := %s.UnmarshalText([]byte(%s)); err != nil {",
indent, identifierName, tmpVar)
statements.appendf("%s\t\treq.LogAndSendQueryError(err, %q, %q, %s)",
indent, "enum", shortQueryParam, tmpVar)
statements.append(indent, "\t\tok = false")
statements.append(indent, "\t}")
statements.append(indent, "}")
} else {
statements.appendf("%s%s, ok := %s", indent, identifierName, queryRValue)
}

statements.append("if !ok {")
statements.append("\treturn")
statements.append("}")

target := identifierName

indent := ""
if !field.Required {
indent += "\t"
}

// if field is a list and list value is typedef, list values must be converted first
if aggrValueTypedef != "" {
target = fmt.Sprintf("%sFinal", identifierName)
// If field is an "aggregate" with custom element types, we need to convert them first
// Note that enums and typedefs are what get in here
if customElemType != "" {
target += "Final"
valVar := "v"
if isList {
statements.appendf(
"%s%s := make([]%s, len(%s))",
indent, target, aggrValueTypedef, identifierName,
indent, target, customElemType, identifierName,
)
statements.appendf("%sfor i, v := range %s {", indent, identifierName)
statements.appendf("%s%s[i] = %s(v)", indent+"\t", target, aggrValueTypedef)
statements.appendf("%s}", indent)
statements.appendf("%sfor i, %s := range %s {", indent, valVar, identifierName)
indent += "\t"
if isEnumElem {
tmpVar := "_tmp" + valVar
statements.appendf("%svar %s %s", indent, tmpVar, customElemType)
statements.appendf("%sif err := %s.UnmarshalText([]byte(%s)); err != nil {",
indent, tmpVar, valVar)
statements.appendf("%s\treq.LogAndSendQueryError(err, %q, %q, %s)",
indent, "enum", shortQueryParam, valVar)
// TODO: @argo: we should ideally log a warning here
statements.append(indent, "\treturn")
statements.append(indent, "}")
valVar = tmpVar
}
statements.appendf("%s\t%s[i] = %s(%s)", indent, target, customElemType, valVar)
indent = indent[:len(indent)-1]
statements.append(indent, "}")
} else if isSet {
statements.appendf(
"%s%s := make(map[%s]struct{}, len(%s))",
indent, target, aggrValueTypedef, identifierName,
indent, target, customElemType, identifierName,
)
statements.appendf("%sfor _, v := range %s {", indent, identifierName)
statements.appendf("%s%s[%s(v)] = struct{}{}", indent+"\t", target, aggrValueTypedef)
statements.appendf("%s}", indent)
indent += "\t"
if isEnumElem {
tmpVar := "_tmp" + valVar
statements.appendf("%svar %s %s", indent, tmpVar, customElemType)
statements.appendf("%sif err := %s.UnmarshalText([]byte(%s)); err != nil {",
indent, tmpVar, valVar)
statements.appendf("%s\treq.LogAndSendQueryError(err, %q, %q, %s)",
indent, "enum", shortQueryParam, valVar)
statements.append(indent, "\treturn")
statements.append(indent, "}")
valVar = tmpVar
}
statements.appendf("%s\t%s[%s(%s)] = struct{}{}", indent, target, customElemType, valVar)
indent = indent[:len(indent)-1]
statements.append(indent, "}")
}
}

if field.Required || isList || isSet {
if typedef != "" {
statements.appendf("%srequestBody%s = %s(%s)", indent, longFieldName, typedef, target)
} else {
statements.appendf("%srequestBody%s = %s", indent, longFieldName, target)
}

} else {
target = fmt.Sprintf("ptr.%s(%s)", pointerMethodType(realType), identifierName)
if typedef != "" {
statements.appendf("%srequestBody%s = (*%s)(%s)", indent, longFieldName, typedef, target)
} else {
statements.appendf("%srequestBody%s = %s", indent, longFieldName, target)
var deref string
if !field.Required && !isAggregate {
deref = "*"
targetName := identifierName
if customType != "" {
targetName = fmt.Sprintf("%s(%s)", strings.ToLower(pointerMethodType(realType)), targetName)
}
target = fmt.Sprintf("ptr.%s(%s)", pointerMethodType(realType), targetName)
}
if customType != "" {
target = fmt.Sprintf("(%s%s)(%s)", deref, customType, target)
}
statements.appendf("%srequestBody%s = %s", indent, longFieldName, target)

if !field.Required {
statements.append("}")
Expand Down

0 comments on commit 0375008

Please sign in to comment.