Skip to content

Commit

Permalink
Re-add support for query params in all http methods
Browse files Browse the repository at this point in the history
The original fix had issues, which have now been fixed.  Unit tests added
  • Loading branch information
argouber committed Sep 18, 2019
1 parent e5d20d5 commit be7daec
Show file tree
Hide file tree
Showing 18 changed files with 11,293 additions and 7,449 deletions.
117 changes: 62 additions & 55 deletions codegen/method.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ const (
AntHTTPReqDefBoxed = "%s.http.req.def"
)

const queryAnnotationPrefix = "query."
const headerAnnotationPrefix = "headers."

// PathSegment represents a part of the http path.
type PathSegment struct {
Type string
Expand Down Expand Up @@ -238,14 +241,15 @@ func NewMethod(

method.setValidStatusCodes()

if method.HTTPMethod == "GET" && method.RequestType != "" {
if method.RequestType != "" {
hasNoBody := method.HTTPMethod == "GET" || method.HTTPMethod == "DELETE"
if method.IsEndpoint {
err := method.setParseQueryParamStatements(funcSpec, packageHelper)
err := method.setParseQueryParamStatements(funcSpec, packageHelper, hasNoBody)
if err != nil {
return nil, err
}
} else {
err := method.setWriteQueryParamStatements(funcSpec, packageHelper)
err := method.setWriteQueryParamStatements(funcSpec, packageHelper, hasNoBody)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -278,30 +282,6 @@ func NewMethod(
return method, nil
}

func (ms *MethodSpec) scanForNonParams(funcSpec *compile.FunctionSpec) bool {
hasNonParams := false

visitor := func(
goPrefix string, thriftPrefix string, field *compile.FieldSpec,
) bool {
realType := compile.RootTypeSpec(field.Type)
// ignore nested structs
if _, ok := realType.(*compile.StructSpec); ok {
return false
}

param, ok := field.Annotations[ms.annotations.HTTPRef]
if !ok || strings.HasPrefix(param, "params") {
hasNonParams = true
return true
}

return false
}
walkFieldGroups(compile.FieldGroup(funcSpec.ArgsSpec), visitor)
return hasNonParams
}

// setRequestType sets the request type of the method specification. If the
// "zanzibar.http.req.def.boxed" is true, then the first parameter will be used as
// the request body; otherwise a new struct is generated to bundle the request
Expand Down Expand Up @@ -626,8 +606,8 @@ func (ms *MethodSpec) setEndpointRequestHeaderFields(
}

if param, ok := field.Annotations[ms.annotations.HTTPRef]; ok {
if strings.HasPrefix(param, "headers.") {
headerName := param[8:]
if strings.HasPrefix(param, headerAnnotationPrefix) {
headerName := strings.TrimPrefix(param, headerAnnotationPrefix)
camelHeaderName := CamelCase(headerName)

fieldThriftType, err := GoType(packageHelper, field.Type)
Expand Down Expand Up @@ -752,8 +732,8 @@ func (ms *MethodSpec) setResponseHeaderFields(
goPrefix string, thriftPrefix string, field *compile.FieldSpec,
) bool {
if param, ok := field.Annotations[ms.annotations.HTTPRef]; ok {
if strings.HasPrefix(param, "headers.") {
headerName := param[8:]
if strings.HasPrefix(param, headerAnnotationPrefix) {
headerName := strings.TrimPrefix(param, headerAnnotationPrefix)
ms.ResHeaderFields[headerName] = HeaderFieldInfo{
FieldIdentifier: goPrefix + "." + PascalCase(field.Name),
IsPointer: !field.Required,
Expand Down Expand Up @@ -794,8 +774,8 @@ func (ms *MethodSpec) setClientRequestHeaderFields(
}

if param, ok := field.Annotations[ms.annotations.HTTPRef]; ok {
if strings.HasPrefix(param, "headers.") {
headerName := param[8:]
if strings.HasPrefix(param, headerAnnotationPrefix) {
headerName := strings.TrimPrefix(param, headerAnnotationPrefix)
bodyIdentifier := goPrefix + "." + PascalCase(field.Name)
var headerNameValuePair string
if field.Required {
Expand Down Expand Up @@ -1104,6 +1084,7 @@ func getQueryEncodeExprPrimitive(typeSpec compile.TypeSpec) string {
case *compile.EnumSpec:
encodeExpression = "strconv.Itoa(int(%s))"
default:
// This is intentional -- lets evaluate why we would want other types here before opening the flood gates
panic(fmt.Sprintf(
"Unsupported type (%T) for %s as query string parameter",
typeSpec, typeSpec.ThriftName(),
Expand All @@ -1127,16 +1108,44 @@ func getQueryEncodeExpression(typeSpec compile.TypeSpec, valueName string) strin
return fmt.Sprintf(encodeExpression, valueName)
}

func (ms *MethodSpec) hasQueryParams(field *compile.FieldSpec, defaultIsQuery bool) bool {

httpRefAnnotation := field.Annotations[ms.annotations.HTTPRef]
if strings.HasPrefix(httpRefAnnotation, queryAnnotationPrefix) {
return true
}
// If it is a struct, recursively look to see if any of the fields are query params
if container, ok := compile.RootTypeSpec(field.Type).(*compile.StructSpec); ok {
visitor := func(goPrefix string, thriftPrefix string, field *compile.FieldSpec) bool {
annotation := field.Annotations[ms.annotations.HTTPRef]
if strings.HasPrefix(annotation, queryAnnotationPrefix) {
return true
}
return annotation == "" && defaultIsQuery
}
return walkFieldGroups(container.Fields, visitor)
}
return httpRefAnnotation == "" && defaultIsQuery
}

func (ms *MethodSpec) setWriteQueryParamStatements(
funcSpec *compile.FunctionSpec, packageHelper *PackageHelper,
funcSpec *compile.FunctionSpec, packageHelper *PackageHelper, hasNoBody bool,
) error {
var statements LineBuilder
var hasQueryFields bool
var stack = []string{}
var stack []string
isVoidReturn := funcSpec.ResultSpec.ReturnType == nil

visitor := func(
goPrefix string, thriftPrefix string, field *compile.FieldSpec,
) bool {
// Skip if there are no query params in the field or its components
// However, if by definition there is no request body, untagged fields are mapped to query-params
// because "existing behavior" && cannot easily change that
if !ms.hasQueryParams(field, hasNoBody) {
return false
}

realType := compile.RootTypeSpec(field.Type)
longFieldName := goPrefix + "." + PascalCase(field.Name)

Expand All @@ -1146,19 +1155,21 @@ func (ms *MethodSpec) setWriteQueryParamStatements(
statements.append("}")
}
}

if _, ok := realType.(*compile.StructSpec); ok {
// If a field is a struct then skip

// If a field is a struct we need to look inside
if field.Required {
statements.appendf("if r%s == nil {", longFieldName)
// TODO: generate correct number of nils...
statements.append("\treturn nil, nil, errors.New(")
// Generate correct number of nils...
if isVoidReturn {
statements.append("\treturn nil, errors.New(")
} else {
statements.append("\treturn nil, nil, errors.New(")
}
statements.appendf("\t\t\"The field %s is required\",",
longFieldName,
)
statements.append("\t)")
statements.appendf("}")
statements.append("}")
} else {
stack = append(stack, longFieldName)

Expand All @@ -1168,11 +1179,6 @@ func (ms *MethodSpec) setWriteQueryParamStatements(
return false
}

httpRefAnnotation := field.Annotations[ms.annotations.HTTPRef]
if httpRefAnnotation != "" && !strings.HasPrefix(httpRefAnnotation, "query") {
return false
}

longQueryName := ms.getLongQueryName(field, thriftPrefix)
identifierName := CamelCase(longQueryName) + "Query"
_, isList := realType.(*compile.ListSpec)
Expand Down Expand Up @@ -1234,7 +1240,7 @@ func (ms *MethodSpec) setWriteQueryParamStatements(
}

func (ms *MethodSpec) setParseQueryParamStatements(
funcSpec *compile.FunctionSpec, packageHelper *PackageHelper,
funcSpec *compile.FunctionSpec, packageHelper *PackageHelper, hasNoBody bool,
) error {
// If a thrift field has a http.ref annotation then we
// should not read this field from query parameters.
Expand All @@ -1250,6 +1256,13 @@ func (ms *MethodSpec) setParseQueryParamStatements(
longFieldName := goPrefix + "." + PascalCase(field.Name)
longQueryName := ms.getLongQueryName(field, thriftPrefix)

// Skip if there are no query params in the field or its components
// However, if by definition there is no request body, untagged fields are mapped to query-params
// because "existing behavior" && cannot easily change that
if !ms.hasQueryParams(field, hasNoBody) {
return false
}

if len(stack) > 0 {
if !strings.HasPrefix(longFieldName, stack[len(stack)-1]) {
stack = stack[:len(stack)-1]
Expand Down Expand Up @@ -1325,11 +1338,6 @@ func (ms *MethodSpec) setParseQueryParamStatements(
}
identifierName := CamelCase(longQueryName) + "Query"

httpRefAnnotation := field.Annotations[ms.annotations.HTTPRef]
if httpRefAnnotation != "" && !strings.HasPrefix(httpRefAnnotation, "query") {
return false
}

okIdentifierName := CamelCase(longQueryName) + "Ok"
if field.Required {
statements.appendf("%s := req.CheckQueryValue(%q)",
Expand Down Expand Up @@ -1427,9 +1435,8 @@ func (ms *MethodSpec) getLongQueryName(field *compile.FieldSpec, thriftPrefix st

queryName := field.Name
queryAnnotation := field.Annotations[ms.annotations.HTTPRef]
if strings.HasPrefix(queryAnnotation, "query.") {
// len("query.") == 6
queryName = queryAnnotation[6:]
if strings.HasPrefix(queryAnnotation, queryAnnotationPrefix) {
queryName = strings.TrimPrefix(queryAnnotation, queryAnnotationPrefix)
}

if thriftPrefix == "" {
Expand Down
Loading

0 comments on commit be7daec

Please sign in to comment.