Skip to content

Commit

Permalink
Merge pull request #431 from uber/lu.typedef
Browse files Browse the repository at this point in the history
Support typedef in query params
  • Loading branch information
ChuntaoLu committed Aug 16, 2018
2 parents 05e8279 + efc0709 commit 68c6199
Show file tree
Hide file tree
Showing 21 changed files with 2,206 additions and 111 deletions.
124 changes: 102 additions & 22 deletions codegen/method.go
Original file line number Diff line number Diff line change
Expand Up @@ -1029,7 +1029,7 @@ func getQueryMethodForType(typeSpec compile.TypeSpec) string {
case *compile.StringSpec:
queryMethod = "GetQueryValue"
case *compile.ListSpec:
switch t.ValueSpec.(type) {
switch compile.RootTypeSpec(t.ValueSpec).(type) {
case *compile.BoolSpec:
queryMethod = "GetQueryBoolList"
case *compile.I8Spec:
Expand Down Expand Up @@ -1065,37 +1065,72 @@ func getQueryEncodeExpression(
) string {
var encodeExpression string

switch t := typeSpec.(type) {
_, isTypedef := typeSpec.(*compile.TypedefSpec)

switch t := compile.RootTypeSpec(typeSpec).(type) {
case *compile.BoolSpec:
encodeExpression = "strconv.FormatBool(%s)"
if isTypedef {
encodeExpression = "strconv.FormatBool(bool(%s))"
} else {
encodeExpression = "strconv.FormatBool(%s)"
}
case *compile.I8Spec:
encodeExpression = "strconv.Itoa(int(%s))"
case *compile.I16Spec:
encodeExpression = "strconv.Itoa(int(%s))"
case *compile.I32Spec:
encodeExpression = "strconv.Itoa(int(%s))"
case *compile.I64Spec:
encodeExpression = "strconv.FormatInt(%s, 10)"
if isTypedef {
encodeExpression = "strconv.FormatInt(int64(%s), 10)"
} else {
encodeExpression = "strconv.FormatInt(%s, 10)"
}
case *compile.DoubleSpec:
encodeExpression = "strconv.FormatFloat(%s, 'G', -1, 64)"
if isTypedef {
encodeExpression = "strconv.FormatFloat(float64(%s), 'G', -1, 64)"
} else {
encodeExpression = "strconv.FormatFloat(%s, 'G', -1, 64)"
}
case *compile.StringSpec:
encodeExpression = "%s"
if isTypedef {
encodeExpression = "string(%s)"
} else {
encodeExpression = "%s"
}
case *compile.ListSpec:
switch t.ValueSpec.(type) {
_, isValueTypedef := t.ValueSpec.(*compile.TypedefSpec)
switch compile.RootTypeSpec(t.ValueSpec).(type) {
case *compile.BoolSpec:
encodeExpression = "strconv.FormatBool(%s)"
if isValueTypedef {
encodeExpression = "strconv.FormatBool(bool(%s))"
} else {
encodeExpression = "strconv.FormatBool(%s)"
}
case *compile.I8Spec:
encodeExpression = "strconv.Itoa(int(%s))"
case *compile.I16Spec:
encodeExpression = "strconv.Itoa(int(%s))"
case *compile.I32Spec:
encodeExpression = "strconv.Itoa(int(%s))"
case *compile.I64Spec:
encodeExpression = "strconv.FormatInt(%s, 10)"
if isValueTypedef {
encodeExpression = "strconv.FormatInt(int64(%s), 10)"
} else {
encodeExpression = "strconv.FormatInt(%s, 10)"
}
case *compile.DoubleSpec:
encodeExpression = "strconv.FormatFloat(%s, 'G', -1, 64)"
if isValueTypedef {
encodeExpression = "strconv.FormatFloat(float64(%s), 'G', -1, 64)"
} else {
encodeExpression = "strconv.FormatFloat(%s, 'G', -1, 64)"
}
case *compile.StringSpec:
encodeExpression = "%s"
if isValueTypedef {
encodeExpression = "string(%s)"
} else {
encodeExpression = "%s"
}
default:
panic(fmt.Sprintf(
"Unsupported list value type (%T) %v for query string parameter",
Expand Down Expand Up @@ -1172,24 +1207,24 @@ func (ms *MethodSpec) setWriteQueryParamStatements(

if field.Required {
if isList {
encodeExpr := getQueryEncodeExpression(realType, "value")
encodeExpr := getQueryEncodeExpression(field.Type, "value")
statements.appendf("for _, value := range %s {", "r"+longFieldName)
statements.appendf("\tqueryValues.Add(\"%s\", %s)", longQueryName, encodeExpr)
statements.append("}")
} else {
encodeExpr := getQueryEncodeExpression(realType, "r"+longFieldName)
encodeExpr := getQueryEncodeExpression(field.Type, "r"+longFieldName)
statements.appendf("%s := %s", identifierName, encodeExpr)
statements.appendf("queryValues.Set(\"%s\", %s)", longQueryName, identifierName)
}
} else {
statements.appendf("if r%s != nil {", longFieldName)
if isList {
encodeExpr := getQueryEncodeExpression(realType, "value")
encodeExpr := getQueryEncodeExpression(field.Type, "value")
statements.appendf("for _, value := range %s {", "r"+longFieldName)
statements.appendf("\tqueryValues.Add(\"%s\", %s)", longQueryName, encodeExpr)
statements.append("}")
} else {
encodeExpr := getQueryEncodeExpression(realType, "*r"+longFieldName)
encodeExpr := getQueryEncodeExpression(field.Type, "*r"+longFieldName)
statements.appendf("\t%s := %s", identifierName, encodeExpr)
statements.appendf("\tqueryValues.Set(\"%s\", %s)", longQueryName, identifierName)
}
Expand Down Expand Up @@ -1235,9 +1270,27 @@ func (ms *MethodSpec) setParseQueryParamStatements(
}
}

_, isList := realType.(*compile.ListSpec)
var err error
var typedef string
if _, ok := field.Type.(*compile.TypedefSpec); ok {
typedef, err = GoType(packageHelper, field.Type)
if err != nil {
finalError = err
return true
}
}

var listValueTypedef string
t, isList := realType.(*compile.ListSpec)
if isList {
longQueryName = longQueryName + "[]"
if _, ok := t.ValueSpec.(*compile.TypedefSpec); ok {
listValueTypedef, err = GoType(packageHelper, t.ValueSpec)
if err != nil {
finalError = err
return true
}
}
}

// If the type is a struct then we cannot really do anything
Expand Down Expand Up @@ -1301,15 +1354,42 @@ func (ms *MethodSpec) setParseQueryParamStatements(
statements.append("\treturn")
statements.append("}")

if field.Required {
statements.appendf("requestBody%s = %s", longFieldName, identifierName)
target := identifierName

indent := ""
if !field.Required {
indent += "\t"
}

// if field is a list and list value is typedef, list values must be converted first
if listValueTypedef != "" {
target = fmt.Sprintf("%sFinal", identifierName)
statements.appendf(
"%s%s := make([]%s, len(%s))",
indent, target, listValueTypedef, identifierName,
)
statements.appendf("%sfor i, v := range %s {", indent, identifierName)
statements.appendf("%s%s[i] = %s(v)", indent+"\t", target, listValueTypedef)
statements.appendf("%s}", indent)
}

if field.Required || isList {
if typedef != "" {
statements.appendf("%srequestBody%s = %s(%s)", indent, longFieldName, typedef, target)
} else {
statements.appendf("%srequestBody%s = %s", indent, longFieldName, target)
}

} else {
if isList {
statements.appendf("\trequestBody%s = %s", longFieldName, identifierName)
target = fmt.Sprintf("ptr.%s(%s)", pointerMethodType(realType), identifierName)
if typedef != "" {
statements.appendf("%srequestBody%s = (*%s)(%s)", indent, longFieldName, typedef, target)
} else {
pointerMethod := pointerMethodType(realType)
statements.appendf("\trequestBody%s = ptr.%s(%s)", longFieldName, pointerMethod, identifierName)
statements.appendf("%srequestBody%s = %s", indent, longFieldName, target)
}
}

if !field.Required {
statements.append("}")
}

Expand Down
36 changes: 36 additions & 0 deletions codegen/test_data/bar.json
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,42 @@
"\tanOptFloat64Query := strconv.FormatFloat(*r.AnOptFloat64, 'G', -1, 64)",
"\tqueryValues.Set(\"anOptFloat64\", anOptFloat64Query)",
"}",
"aUUIDQuery := string(r.AUUID)",
"queryValues.Set(\"aUUID\", aUUIDQuery)",
"if r.AnOptUUID != nil {",
"\tanOptUUIDQuery := string(*r.AnOptUUID)",
"\tqueryValues.Set(\"anOptUUID\", anOptUUIDQuery)",
"}",
"for _, value := range r.AListUUID {",
"\tqueryValues.Add(\"aListUUID[]\", string(value))",
"}",
"if r.AnOptListUUID != nil {",
"for _, value := range r.AnOptListUUID {",
"\tqueryValues.Add(\"anOptListUUID[]\", string(value))",
"}",
"}",
"for _, value := range r.AStringList {",
"\tqueryValues.Add(\"aStringList[]\", value)",
"}",
"if r.AnOptStringList != nil {",
"for _, value := range r.AnOptStringList {",
"\tqueryValues.Add(\"anOptStringList[]\", value)",
"}",
"}",
"for _, value := range r.AUUIDList {",
"\tqueryValues.Add(\"aUUIDList[]\", string(value))",
"}",
"if r.AnOptUUIDList != nil {",
"for _, value := range r.AnOptUUIDList {",
"\tqueryValues.Add(\"anOptUUIDList[]\", string(value))",
"}",
"}",
"aTsQuery := strconv.FormatInt(int64(r.ATs), 10)",
"queryValues.Set(\"aTs\", aTsQuery)",
"if r.AnOptTs != nil {",
"\tanOptTsQuery := strconv.FormatInt(int64(*r.AnOptTs), 10)",
"\tqueryValues.Set(\"anOptTs\", anOptTsQuery)",
"}",
"fullURL += \"?\" + queryValues.Encode()"
],
"ReqHeaderGoStatements": null,
Expand Down
36 changes: 36 additions & 0 deletions codegen/test_data/clients/bar.gogen
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,42 @@ func (c *barClient) ArgWithManyQueryParams(
anOptFloat64Query := strconv.FormatFloat(*r.AnOptFloat64, 'G', -1, 64)
queryValues.Set("anOptFloat64", anOptFloat64Query)
}
aUUIDQuery := string(r.AUUID)
queryValues.Set("aUUID", aUUIDQuery)
if r.AnOptUUID != nil {
anOptUUIDQuery := string(*r.AnOptUUID)
queryValues.Set("anOptUUID", anOptUUIDQuery)
}
for _, value := range r.AListUUID {
queryValues.Add("aListUUID[]", string(value))
}
if r.AnOptListUUID != nil {
for _, value := range r.AnOptListUUID {
queryValues.Add("anOptListUUID[]", string(value))
}
}
for _, value := range r.AStringList {
queryValues.Add("aStringList[]", value)
}
if r.AnOptStringList != nil {
for _, value := range r.AnOptStringList {
queryValues.Add("anOptStringList[]", value)
}
}
for _, value := range r.AUUIDList {
queryValues.Add("aUUIDList[]", string(value))
}
if r.AnOptUUIDList != nil {
for _, value := range r.AnOptUUIDList {
queryValues.Add("anOptUUIDList[]", string(value))
}
}
aTsQuery := strconv.FormatInt(int64(r.ATs), 10)
queryValues.Set("aTs", aTsQuery)
if r.AnOptTs != nil {
anOptTsQuery := strconv.FormatInt(int64(*r.AnOptTs), 10)
queryValues.Set("anOptTs", anOptTsQuery)
}
fullURL += "?" + queryValues.Encode()

err := req.WriteJSON("GET", fullURL, headers, nil)
Expand Down
Loading

0 comments on commit 68c6199

Please sign in to comment.