Skip to content

Commit

Permalink
fix: Generic Fields does not handle Arrays (#1311)
Browse files Browse the repository at this point in the history
* fix: Generic Fields does not handle Arrays

- Support for *ast.IndexExpr added
- tests extended to cover use cases

fixes #1306

* test: Extend tests to improve code coverage
  • Loading branch information
FabianMartin committed Aug 28, 2022
1 parent 007219f commit cf1c4a7
Show file tree
Hide file tree
Showing 9 changed files with 419 additions and 76 deletions.
117 changes: 84 additions & 33 deletions generics.go
Expand Up @@ -6,10 +6,13 @@ package swag
import (
"errors"
"fmt"
"github.com/go-openapi/spec"
"go/ast"
"strings"
"sync"
)

var genericDefinitionsMutex = &sync.RWMutex{}
var genericsDefinitions = map[*TypeSpecDef]map[string]*TypeSpecDef{}

type genericTypeSpec struct {
Expand Down Expand Up @@ -55,9 +58,12 @@ func typeSpecFullName(typeSpecDef *TypeSpecDef) string {
return fullName
}

func (pkgDefs *PackagesDefinitions) parametrizeStruct(original *TypeSpecDef, fullGenericForm string, parseDependency bool) *TypeSpecDef {
if spec, ok := genericsDefinitions[original][fullGenericForm]; ok {
return spec
func (pkgDefs *PackagesDefinitions) parametrizeStruct(file *ast.File, original *TypeSpecDef, fullGenericForm string, parseDependency bool) *TypeSpecDef {
genericDefinitionsMutex.RLock()
tSpec, ok := genericsDefinitions[original][fullGenericForm]
genericDefinitionsMutex.RUnlock()
if ok {
return tSpec
}

pkgName := strings.Split(fullGenericForm, ".")[0]
Expand All @@ -81,7 +87,10 @@ func (pkgDefs *PackagesDefinitions) parametrizeStruct(original *TypeSpecDef, ful
arrayDepth++
}

tdef := pkgDefs.FindTypeSpec(genericParam, original.File, parseDependency)
tdef := pkgDefs.FindTypeSpec(genericParam, file, parseDependency)
if tdef != nil && !strings.Contains(genericParam, ".") {
genericParam = fullTypeName(file.Name.Name, genericParam)
}
genericParamTypeDefs[original.TypeSpec.TypeParams.List[i].Names[0].Name] = &genericTypeSpec{
ArrayDepth: arrayDepth,
TypeSpec: tdef,
Expand Down Expand Up @@ -156,6 +165,8 @@ func (pkgDefs *PackagesDefinitions) parametrizeStruct(original *TypeSpecDef, ful
newStructTypeDef.Fields.List = append(newStructTypeDef.Fields.List, newField)
}

genericDefinitionsMutex.Lock()
defer genericDefinitionsMutex.Unlock()
parametrizedTypeSpec.TypeSpec.Type = newStructTypeDef
if genericsDefinitions[original] == nil {
genericsDefinitions[original] = map[string]*TypeSpecDef{}
Expand Down Expand Up @@ -225,78 +236,118 @@ func resolveType(expr ast.Expr, field *ast.Field, genericParamTypeDefs map[strin
return field.Type
}

func getExtendedGenericFieldType(file *ast.File, field ast.Expr) (string, error) {
switch fieldType := field.(type) {
case *ast.ArrayType:
fieldName, err := getExtendedGenericFieldType(file, fieldType.Elt)
return "[]" + fieldName, err
case *ast.StarExpr:
return getExtendedGenericFieldType(file, fieldType.X)
default:
return getFieldType(file, field)
}
}

func getGenericFieldType(file *ast.File, field ast.Expr) (string, error) {
var fullName string
var baseName string
var err error
switch fieldType := field.(type) {
case *ast.IndexListExpr:
fullName, err := getGenericTypeName(file, fieldType.X)
baseName, err = getGenericTypeName(file, fieldType.X)
if err != nil {
return "", err
}
fullName += "["
fullName = baseName + "["

for _, index := range fieldType.Indices {
var fieldName string
var err error

switch item := index.(type) {
case *ast.ArrayType:
fieldName, err = getFieldType(file, item.Elt)
fieldName = "[]" + fieldName
default:
fieldName, err = getFieldType(file, index)
}

fieldName, err := getExtendedGenericFieldType(file, index)
if err != nil {
return "", err
}

fullName += fieldName + ","
}

return strings.TrimRight(fullName, ",") + "]", nil
fullName = strings.TrimRight(fullName, ",") + "]"
case *ast.IndexExpr:
x, err := getFieldType(file, fieldType.X)
baseName, err = getGenericTypeName(file, fieldType.X)
if err != nil {
return "", err
}

i, err := getFieldType(file, fieldType.Index)
indexName, err := getExtendedGenericFieldType(file, fieldType.Index)
if err != nil {
return "", err
}

packageName := ""
if !strings.Contains(x, ".") {
if file.Name == nil {
return "", errors.New("file name is nil")
}
packageName, _ = getFieldType(file, file.Name)
}
fullName = fmt.Sprintf("%s[%s]", baseName, indexName)
}

return strings.TrimLeft(fmt.Sprintf("%s.%s[%s]", packageName, x, i), "."), nil
if fullName == "" {
return "", fmt.Errorf("unknown field type %#v", field)
}

var packageName string
if !strings.Contains(baseName, ".") {
if file.Name == nil {
return "", errors.New("file name is nil")
}
packageName, _ = getFieldType(file, file.Name)
}

return "", fmt.Errorf("unknown field type %#v", field)
return strings.TrimLeft(fmt.Sprintf("%s.%s", packageName, fullName), "."), nil
}

func getGenericTypeName(file *ast.File, field ast.Expr) (string, error) {
switch indexType := field.(type) {
case *ast.Ident:
spec := &TypeSpecDef{
if indexType.Obj == nil {
return getFieldType(file, field)
}

tSpec := &TypeSpecDef{
File: file,
TypeSpec: indexType.Obj.Decl.(*ast.TypeSpec),
PkgPath: file.Name.Name,
}
return spec.FullName(), nil
return tSpec.FullName(), nil
case *ast.ArrayType:
spec := &TypeSpecDef{
tSpec := &TypeSpecDef{
File: file,
TypeSpec: indexType.Elt.(*ast.Ident).Obj.Decl.(*ast.TypeSpec),
PkgPath: file.Name.Name,
}
return spec.FullName(), nil
return tSpec.FullName(), nil
case *ast.SelectorExpr:
return fmt.Sprintf("%s.%s", indexType.X.(*ast.Ident).Name, indexType.Sel.Name), nil
}
return "", fmt.Errorf("unknown type %#v", field)
}

func (parser *Parser) parseGenericTypeExpr(file *ast.File, typeExpr ast.Expr) (*spec.Schema, error) {
switch expr := typeExpr.(type) {
// suppress debug messages for these types
case *ast.InterfaceType:
case *ast.StructType:
case *ast.Ident:
case *ast.StarExpr:
case *ast.SelectorExpr:
case *ast.ArrayType:
case *ast.MapType:
case *ast.FuncType:
case *ast.IndexExpr:
name, err := getExtendedGenericFieldType(file, expr)
if err == nil {
if schema, err := parser.getTypeSchema(name, file, false); err == nil {
return spec.MapProperty(schema), nil
}
}

parser.debug.Printf("Type definition of type '%T' is not supported yet. Using 'object' instead. (%s)\n", typeExpr, err)
default:
parser.debug.Printf("Type definition of type '%T' is not supported yet. Using 'object' instead.\n", typeExpr)
}

return PrimitiveSchema(OBJECT), nil
}
21 changes: 20 additions & 1 deletion generics_other.go
Expand Up @@ -5,17 +5,36 @@ package swag

import (
"fmt"
"github.com/go-openapi/spec"
"go/ast"
)

func typeSpecFullName(typeSpecDef *TypeSpecDef) string {
return typeSpecDef.FullName()
}

func (pkgDefs *PackagesDefinitions) parametrizeStruct(original *TypeSpecDef, fullGenericForm string, parseDependency bool) *TypeSpecDef {
func (pkgDefs *PackagesDefinitions) parametrizeStruct(file *ast.File, original *TypeSpecDef, fullGenericForm string, parseDependency bool) *TypeSpecDef {
return original
}

func getGenericFieldType(file *ast.File, field ast.Expr) (string, error) {
return "", fmt.Errorf("unknown field type %#v", field)
}

func (parser *Parser) parseGenericTypeExpr(file *ast.File, typeExpr ast.Expr) (*spec.Schema, error) {
switch typeExpr.(type) {
// suppress debug messages for these types
case *ast.InterfaceType:
case *ast.StructType:
case *ast.Ident:
case *ast.StarExpr:
case *ast.SelectorExpr:
case *ast.ArrayType:
case *ast.MapType:
case *ast.FuncType:
default:
parser.debug.Printf("Type definition of type '%T' is not supported yet. Using 'object' instead.\n", typeExpr)
}

return PrimitiveSchema(OBJECT), nil
}
67 changes: 67 additions & 0 deletions generics_other_test.go
@@ -0,0 +1,67 @@
//go:build !go1.18
// +build !go1.18

package swag

import (
"fmt"
"github.com/stretchr/testify/assert"
"go/ast"
"testing"
)

type testLogger struct {
Messages []string
}

func (t *testLogger) Printf(format string, v ...interface{}) {
t.Messages = append(t.Messages, fmt.Sprintf(format, v...))
}

func TestParametrizeStruct(t *testing.T) {
t.Parallel()

pd := PackagesDefinitions{
packages: make(map[string]*PackageDefinitions),
}

tSpec := &TypeSpecDef{
TypeSpec: &ast.TypeSpec{
Name: &ast.Ident{Name: "Field"},
Type: &ast.StructType{Struct: 100, Fields: &ast.FieldList{Opening: 101, Closing: 102}},
},
}

tr := pd.parametrizeStruct(&ast.File{}, tSpec, "", false)
assert.Equal(t, tr, tSpec)

tr = pd.parametrizeStruct(&ast.File{}, tSpec, "", true)
assert.Equal(t, tr, tSpec)
}

func TestParseGenericTypeExpr(t *testing.T) {
t.Parallel()

parser := New()
logger := &testLogger{}
SetDebugger(logger)(parser)

_, _ = parser.parseGenericTypeExpr(&ast.File{}, &ast.InterfaceType{})
assert.Empty(t, logger.Messages)
_, _ = parser.parseGenericTypeExpr(&ast.File{}, &ast.StructType{})
assert.Empty(t, logger.Messages)
_, _ = parser.parseGenericTypeExpr(&ast.File{}, &ast.Ident{})
assert.Empty(t, logger.Messages)
_, _ = parser.parseGenericTypeExpr(&ast.File{}, &ast.StarExpr{})
assert.Empty(t, logger.Messages)
_, _ = parser.parseGenericTypeExpr(&ast.File{}, &ast.SelectorExpr{})
assert.Empty(t, logger.Messages)
_, _ = parser.parseGenericTypeExpr(&ast.File{}, &ast.ArrayType{})
assert.Empty(t, logger.Messages)
_, _ = parser.parseGenericTypeExpr(&ast.File{}, &ast.MapType{})
assert.Empty(t, logger.Messages)
_, _ = parser.parseGenericTypeExpr(&ast.File{}, &ast.FuncType{})
assert.Empty(t, logger.Messages)
_, _ = parser.parseGenericTypeExpr(&ast.File{}, &ast.BadExpr{})
assert.NotEmpty(t, logger.Messages)
}

0 comments on commit cf1c4a7

Please sign in to comment.