Skip to content

Commit

Permalink
fix multi level nesting parametrization(#1435)
Browse files Browse the repository at this point in the history
  • Loading branch information
gaborpongracz committed Jan 5, 2023
1 parent f916213 commit f617051
Show file tree
Hide file tree
Showing 8 changed files with 255 additions and 15 deletions.
4 changes: 2 additions & 2 deletions generics.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ func getExtendedGenericFieldType(file *ast.File, field ast.Expr, genericParamTyp
}
return tSpec.TypeName(), nil
default:
return getFieldType(file, field)
return getFieldType(file, field, genericParamTypeDefs)
}
}

Expand Down Expand Up @@ -288,7 +288,7 @@ func getGenericFieldType(file *ast.File, field ast.Expr, genericParamTypeDefs ma
if file.Name == nil {
return "", errors.New("file name is nil")
}
packageName, _ = getFieldType(file, file.Name)
packageName, _ = getFieldType(file, file.Name, genericParamTypeDefs)
}

return strings.TrimLeft(fmt.Sprintf("%s.%s", packageName, fullName), "."), nil
Expand Down
26 changes: 26 additions & 0 deletions generics_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,21 @@ func TestParseGenericsNested(t *testing.T) {
assert.Equal(t, string(expected), string(b))
}

func TestParseGenericsMultiLevelNesting(t *testing.T) {
t.Parallel()

searchDir := "testdata/generics_multi_level_nesting"
expected, err := os.ReadFile(filepath.Join(searchDir, "expected.json"))
assert.NoError(t, err)

p := New()
err = p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth)
assert.NoError(t, err)
b, err := json.MarshalIndent(p.swagger, "", " ")
assert.NoError(t, err)
assert.Equal(t, string(expected), string(b))
}

func TestParseGenericsProperty(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -228,6 +243,7 @@ func TestGetGenericFieldType(t *testing.T) {
X: &ast.Ident{Name: "types", Obj: &ast.Object{Decl: &ast.TypeSpec{Name: &ast.Ident{Name: "Field"}}}},
Indices: []ast.Expr{&ast.Ident{Name: "string"}},
},
nil,
)
assert.NoError(t, err)
assert.Equal(t, "test.Field[string]", field)
Expand All @@ -238,6 +254,7 @@ func TestGetGenericFieldType(t *testing.T) {
X: &ast.Ident{Name: "types", Obj: &ast.Object{Decl: &ast.TypeSpec{Name: &ast.Ident{Name: "Field"}}}},
Indices: []ast.Expr{&ast.Ident{Name: "string"}},
},
nil,
)
assert.NoError(t, err)
assert.Equal(t, "Field[string]", field)
Expand All @@ -248,6 +265,7 @@ func TestGetGenericFieldType(t *testing.T) {
X: &ast.Ident{Name: "types", Obj: &ast.Object{Decl: &ast.TypeSpec{Name: &ast.Ident{Name: "Field"}}}},
Indices: []ast.Expr{&ast.Ident{Name: "string"}, &ast.Ident{Name: "int"}},
},
nil,
)
assert.NoError(t, err)
assert.Equal(t, "test.Field[string,int]", field)
Expand All @@ -258,6 +276,7 @@ func TestGetGenericFieldType(t *testing.T) {
X: &ast.Ident{Name: "types", Obj: &ast.Object{Decl: &ast.TypeSpec{Name: &ast.Ident{Name: "Field"}}}},
Indices: []ast.Expr{&ast.Ident{Name: "string"}, &ast.ArrayType{Elt: &ast.Ident{Name: "int"}}},
},
nil,
)
assert.NoError(t, err)
assert.Equal(t, "test.Field[string,[]int]", field)
Expand All @@ -268,6 +287,7 @@ func TestGetGenericFieldType(t *testing.T) {
X: &ast.BadExpr{},
Indices: []ast.Expr{&ast.Ident{Name: "string"}, &ast.Ident{Name: "int"}},
},
nil,
)
assert.Error(t, err)

Expand All @@ -277,37 +297,43 @@ func TestGetGenericFieldType(t *testing.T) {
X: &ast.Ident{Name: "types", Obj: &ast.Object{Decl: &ast.TypeSpec{Name: &ast.Ident{Name: "Field"}}}},
Indices: []ast.Expr{&ast.Ident{Name: "string"}, &ast.ArrayType{Elt: &ast.BadExpr{}}},
},
nil,
)
assert.Error(t, err)

field, err = getFieldType(
&ast.File{Name: &ast.Ident{Name: "test"}},
&ast.IndexExpr{X: &ast.Ident{Name: "Field"}, Index: &ast.Ident{Name: "string"}},
nil,
)
assert.NoError(t, err)
assert.Equal(t, "test.Field[string]", field)

field, err = getFieldType(
&ast.File{Name: nil},
&ast.IndexExpr{X: &ast.Ident{Name: "Field"}, Index: &ast.Ident{Name: "string"}},
nil,
)
assert.Error(t, err)

field, err = getFieldType(
&ast.File{Name: &ast.Ident{Name: "test"}},
&ast.IndexExpr{X: &ast.BadExpr{}, Index: &ast.Ident{Name: "string"}},
nil,
)
assert.Error(t, err)

field, err = getFieldType(
&ast.File{Name: &ast.Ident{Name: "test"}},
&ast.IndexExpr{X: &ast.Ident{Name: "Field"}, Index: &ast.BadExpr{}},
nil,
)
assert.Error(t, err)

field, err = getFieldType(
&ast.File{Name: &ast.Ident{Name: "test"}},
&ast.IndexExpr{X: &ast.SelectorExpr{X: &ast.Ident{Name: "field"}, Sel: &ast.Ident{Name: "Name"}}, Index: &ast.Ident{Name: "string"}},
nil,
)
assert.NoError(t, err)
assert.Equal(t, "field.Name[string]", field)
Expand Down
12 changes: 6 additions & 6 deletions parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -1311,7 +1311,7 @@ func (parser *Parser) parseStructField(file *ast.File, field *ast.Field) (map[st
}

if fieldName == "" {
typeName, err := getFieldType(file, field.Type)
typeName, err := getFieldType(file, field.Type, nil)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -1344,7 +1344,7 @@ func (parser *Parser) parseStructField(file *ast.File, field *ast.Field) (map[st
}

if schema == nil {
typeName, err := getFieldType(file, field.Type)
typeName, err := getFieldType(file, field.Type, nil)
if err == nil {
// named type
schema, err = parser.getTypeSchema(typeName, file, true)
Expand Down Expand Up @@ -1377,26 +1377,26 @@ func (parser *Parser) parseStructField(file *ast.File, field *ast.Field) (map[st
return map[string]spec.Schema{fieldName: *schema}, tagRequired, nil
}

func getFieldType(file *ast.File, field ast.Expr) (string, error) {
func getFieldType(file *ast.File, field ast.Expr, genericParamTypeDefs map[string]*genericTypeSpec) (string, error) {
switch fieldType := field.(type) {
case *ast.Ident:
return fieldType.Name, nil
case *ast.SelectorExpr:
packageName, err := getFieldType(file, fieldType.X)
packageName, err := getFieldType(file, fieldType.X, genericParamTypeDefs)
if err != nil {
return "", err
}

return fullTypeName(packageName, fieldType.Sel.Name), nil
case *ast.StarExpr:
fullName, err := getFieldType(file, fieldType.X)
fullName, err := getFieldType(file, fieldType.X, genericParamTypeDefs)
if err != nil {
return "", err
}

return fullName, nil
default:
return getGenericFieldType(file, field, nil)
return getGenericFieldType(file, field, genericParamTypeDefs)
}
}

Expand Down
14 changes: 7 additions & 7 deletions parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3694,28 +3694,28 @@ func TestParser_Skip(t *testing.T) {
func TestGetFieldType(t *testing.T) {
t.Parallel()

field, err := getFieldType(&ast.File{}, &ast.Ident{Name: "User"})
field, err := getFieldType(&ast.File{}, &ast.Ident{Name: "User"}, nil)
assert.NoError(t, err)
assert.Equal(t, "User", field)

_, err = getFieldType(&ast.File{}, &ast.FuncType{})
_, err = getFieldType(&ast.File{}, &ast.FuncType{}, nil)
assert.Error(t, err)

field, err = getFieldType(&ast.File{}, &ast.SelectorExpr{X: &ast.Ident{Name: "models"}, Sel: &ast.Ident{Name: "User"}})
field, err = getFieldType(&ast.File{}, &ast.SelectorExpr{X: &ast.Ident{Name: "models"}, Sel: &ast.Ident{Name: "User"}}, nil)
assert.NoError(t, err)
assert.Equal(t, "models.User", field)

_, err = getFieldType(&ast.File{}, &ast.SelectorExpr{X: &ast.FuncType{}, Sel: &ast.Ident{Name: "User"}})
_, err = getFieldType(&ast.File{}, &ast.SelectorExpr{X: &ast.FuncType{}, Sel: &ast.Ident{Name: "User"}}, nil)
assert.Error(t, err)

field, err = getFieldType(&ast.File{}, &ast.StarExpr{X: &ast.Ident{Name: "User"}})
field, err = getFieldType(&ast.File{}, &ast.StarExpr{X: &ast.Ident{Name: "User"}}, nil)
assert.NoError(t, err)
assert.Equal(t, "User", field)

field, err = getFieldType(&ast.File{}, &ast.StarExpr{X: &ast.FuncType{}})
field, err = getFieldType(&ast.File{}, &ast.StarExpr{X: &ast.FuncType{}}, nil)
assert.Error(t, err)

field, err = getFieldType(&ast.File{}, &ast.StarExpr{X: &ast.SelectorExpr{X: &ast.Ident{Name: "models"}, Sel: &ast.Ident{Name: "User"}}})
field, err = getFieldType(&ast.File{}, &ast.StarExpr{X: &ast.SelectorExpr{X: &ast.Ident{Name: "models"}, Sel: &ast.Ident{Name: "User"}}}, nil)
assert.NoError(t, err)
assert.Equal(t, "models.User", field)
}
Expand Down
16 changes: 16 additions & 0 deletions testdata/generics_multi_level_nesting/api/api.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package api

import (
"net/http"
)

// GetPosts
// @Summary Test Generics with multi level nesting
// @Description Test one of the edge cases found in generics
// @Accept json
// @Produce json
// @Success 200 {object} web.TestResponse
// @Router /use-struct-and-generics-with-multi-level-nesting [get]
func GetPosts(w http.ResponseWriter, r *http.Request) {

}
142 changes: 142 additions & 0 deletions testdata/generics_multi_level_nesting/expected.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
{
"swagger": "2.0",
"info": {
"description": "This is a sample server Petstore server.",
"title": "Swagger Example API",
"contact": {},
"version": "1.0"
},
"host": "localhost:4000",
"basePath": "/api",
"paths": {
"/use-struct-and-generics-with-multi-level-nesting": {
"get": {
"description": "Test one of the edge cases found in generics",
"consumes": [
"application/json"
],
"produces": [
"application/json"
],
"summary": "Test Generics with multi level nesting",
"responses": {
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/web.TestResponse"
}
}
}
}
}
},
"definitions": {
"web.DataPoint-float64": {
"type": "object",
"properties": {
"timestamp": {
"type": "integer"
},
"value": {
"type": "number"
}
}
},
"web.DataPoint-int64": {
"type": "object",
"properties": {
"timestamp": {
"type": "integer"
},
"value": {
"type": "integer"
}
}
},
"web.Entity-float64": {
"type": "object",
"properties": {
"line_with_fix_type": {
"type": "array",
"items": {
"$ref": "#/definitions/web.DataPoint-float64"
}
},
"line_with_generic_type": {
"type": "array",
"items": {
"$ref": "#/definitions/web.DataPoint-float64"
}
},
"multiple_lines": {
"type": "array",
"items": {
"$ref": "#/definitions/web.NamedLineData-float64"
}
}
}
},
"web.Entity-int64": {
"type": "object",
"properties": {
"line_with_fix_type": {
"type": "array",
"items": {
"$ref": "#/definitions/web.DataPoint-float64"
}
},
"line_with_generic_type": {
"type": "array",
"items": {
"$ref": "#/definitions/web.DataPoint-int64"
}
},
"multiple_lines": {
"type": "array",
"items": {
"$ref": "#/definitions/web.NamedLineData-int64"
}
}
}
},
"web.NamedLineData-float64": {
"type": "object",
"properties": {
"data": {
"type": "array",
"items": {
"$ref": "#/definitions/web.DataPoint-float64"
}
},
"name": {
"type": "string"
}
}
},
"web.NamedLineData-int64": {
"type": "object",
"properties": {
"data": {
"type": "array",
"items": {
"$ref": "#/definitions/web.DataPoint-int64"
}
},
"name": {
"type": "string"
}
}
},
"web.TestResponse": {
"type": "object",
"properties": {
"field_1": {
"$ref": "#/definitions/web.Entity-int64"
},
"field_2": {
"$ref": "#/definitions/web.Entity-float64"
}
}
}
}
}
17 changes: 17 additions & 0 deletions testdata/generics_multi_level_nesting/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package main

import (
"net/http"

"github.com/swaggo/swag/testdata/generics_nested_my_version/api"
)

// @title Swagger Example API
// @version 1.0
// @description This is a sample server Petstore server.
// @host localhost:4000
// @basePath /api
func main() {
http.HandleFunc("/posts/", api.GetPosts)
http.ListenAndServe(":8080", nil)
}
Loading

0 comments on commit f617051

Please sign in to comment.