Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix generics issue #1345 #1349

Merged
merged 8 commits into from Oct 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
77 changes: 51 additions & 26 deletions generics.go
Expand Up @@ -144,7 +144,7 @@ func (pkgDefs *PackagesDefinitions) parametrizeGenericType(file *ast.File, origi

parametrizedTypeSpec.TypeSpec.Name = ident

newType := resolveGenericType(original.TypeSpec.Type, genericParamTypeDefs)
newType := pkgDefs.resolveGenericType(original.File, original.TypeSpec.Type, genericParamTypeDefs, parseDependency)

genericDefinitionsMutex.Lock()
defer genericDefinitionsMutex.Unlock()
Expand Down Expand Up @@ -197,22 +197,33 @@ func splitStructName(fullGenericForm string) (string, []string) {
return genericTypeName, genericParams
}

func resolveGenericType(expr ast.Expr, genericParamTypeDefs map[string]*genericTypeSpec) ast.Expr {
func (pkgDefs *PackagesDefinitions) resolveGenericType(file *ast.File, expr ast.Expr, genericParamTypeDefs map[string]*genericTypeSpec, parseDependency bool) ast.Expr {
switch astExpr := expr.(type) {
case *ast.Ident:
if genTypeSpec, ok := genericParamTypeDefs[astExpr.Name]; ok {
if genTypeSpec.ArrayDepth > 0 {
genTypeSpec.ArrayDepth--
return &ast.ArrayType{Elt: resolveGenericType(expr, genericParamTypeDefs)}
retType := genTypeSpec.Type()
for i := 0; i < genTypeSpec.ArrayDepth; i++ {
retType = &ast.ArrayType{Elt: retType}
}
return genTypeSpec.Type()
return retType
}
case *ast.ArrayType:
return &ast.ArrayType{
Elt: resolveGenericType(astExpr.Elt, genericParamTypeDefs),
Elt: pkgDefs.resolveGenericType(file, astExpr.Elt, genericParamTypeDefs, parseDependency),
Len: astExpr.Len,
Lbrack: astExpr.Lbrack,
}
case *ast.StarExpr:
return &ast.StarExpr{
Star: astExpr.Star,
X: pkgDefs.resolveGenericType(file, astExpr.X, genericParamTypeDefs, parseDependency),
}
case *ast.IndexExpr, *ast.IndexListExpr:
fullGenericName, _ := getGenericFieldType(file, expr, genericParamTypeDefs)
typeDef := pkgDefs.findGenericTypeSpec(fullGenericName, file, parseDependency)
if typeDef != nil {
return typeDef.TypeSpec.Type
}
case *ast.StructType:
newStructTypeDef := &ast.StructType{
Struct: astExpr.Struct,
Expand All @@ -225,37 +236,51 @@ func resolveGenericType(expr ast.Expr, genericParamTypeDefs map[string]*genericT

for _, field := range astExpr.Fields.List {
newField := &ast.Field{
Type: field.Type,
Doc: field.Doc,
Names: field.Names,
Tag: field.Tag,
Comment: field.Comment,
}

newField.Type = resolveGenericType(field.Type, genericParamTypeDefs)
if newField.Type == nil {
newField.Type = field.Type
}
newField.Type = pkgDefs.resolveGenericType(file, field.Type, genericParamTypeDefs, parseDependency)

newStructTypeDef.Fields.List = append(newStructTypeDef.Fields.List, newField)
}
return newStructTypeDef
}
return nil
return expr
}

func getExtendedGenericFieldType(file *ast.File, field ast.Expr) (string, error) {
func getExtendedGenericFieldType(file *ast.File, field ast.Expr, genericParamTypeDefs map[string]*genericTypeSpec) (string, error) {
switch fieldType := field.(type) {
case *ast.ArrayType:
fieldName, err := getExtendedGenericFieldType(file, fieldType.Elt)
fieldName, err := getExtendedGenericFieldType(file, fieldType.Elt, genericParamTypeDefs)
return "[]" + fieldName, err
case *ast.StarExpr:
return getExtendedGenericFieldType(file, fieldType.X)
return getExtendedGenericFieldType(file, fieldType.X, genericParamTypeDefs)
case *ast.Ident:
if genericParamTypeDefs != nil {
if typeSpec, ok := genericParamTypeDefs[fieldType.Name]; ok {
return typeSpec.Name, nil
}
}
if fieldType.Obj == nil {
return fieldType.Name, nil
}

tSpec := &TypeSpecDef{
File: file,
TypeSpec: fieldType.Obj.Decl.(*ast.TypeSpec),
PkgPath: file.Name.Name,
}
return tSpec.FullName(), nil
default:
return getFieldType(file, field)
}
}

func getGenericFieldType(file *ast.File, field ast.Expr) (string, error) {
func getGenericFieldType(file *ast.File, field ast.Expr, genericParamTypeDefs map[string]*genericTypeSpec) (string, error) {
var fullName string
var baseName string
var err error
Expand All @@ -268,7 +293,7 @@ func getGenericFieldType(file *ast.File, field ast.Expr) (string, error) {
fullName = baseName + "["

for _, index := range fieldType.Indices {
fieldName, err := getExtendedGenericFieldType(file, index)
fieldName, err := getExtendedGenericFieldType(file, index, genericParamTypeDefs)
if err != nil {
return "", err
}
Expand All @@ -283,7 +308,7 @@ func getGenericFieldType(file *ast.File, field ast.Expr) (string, error) {
return "", err
}

indexName, err := getExtendedGenericFieldType(file, fieldType.Index)
indexName, err := getExtendedGenericFieldType(file, fieldType.Index, genericParamTypeDefs)
if err != nil {
return "", err
}
Expand All @@ -307,27 +332,27 @@ func getGenericFieldType(file *ast.File, field ast.Expr) (string, error) {
}

func getGenericTypeName(file *ast.File, field ast.Expr) (string, error) {
switch indexType := field.(type) {
switch fieldType := field.(type) {
case *ast.Ident:
if indexType.Obj == nil {
return getFieldType(file, field)
if fieldType.Obj == nil {
return fieldType.Name, nil
}

tSpec := &TypeSpecDef{
File: file,
TypeSpec: indexType.Obj.Decl.(*ast.TypeSpec),
TypeSpec: fieldType.Obj.Decl.(*ast.TypeSpec),
PkgPath: file.Name.Name,
}
return tSpec.FullName(), nil
case *ast.ArrayType:
tSpec := &TypeSpecDef{
File: file,
TypeSpec: indexType.Elt.(*ast.Ident).Obj.Decl.(*ast.TypeSpec),
TypeSpec: fieldType.Elt.(*ast.Ident).Obj.Decl.(*ast.TypeSpec),
PkgPath: file.Name.Name,
}
return tSpec.FullName(), nil
case *ast.SelectorExpr:
return fmt.Sprintf("%s.%s", indexType.X.(*ast.Ident).Name, indexType.Sel.Name), nil
return fmt.Sprintf("%s.%s", fieldType.X.(*ast.Ident).Name, fieldType.Sel.Name), nil
}
return "", fmt.Errorf("unknown type %#v", field)
}
Expand All @@ -344,10 +369,10 @@ func (parser *Parser) parseGenericTypeExpr(file *ast.File, typeExpr ast.Expr) (*
case *ast.MapType:
case *ast.FuncType:
case *ast.IndexExpr:
name, err := getExtendedGenericFieldType(file, expr)
name, err := getExtendedGenericFieldType(file, expr, nil)
if err == nil {
if schema, err := parser.getTypeSchema(name, file, false); err == nil {
return spec.MapProperty(schema), nil
return schema, nil
}
}

Expand Down
8 changes: 7 additions & 1 deletion generics_other.go
Expand Up @@ -9,6 +9,12 @@ import (
"go/ast"
)

type genericTypeSpec struct {
ArrayDepth int
TypeSpec *TypeSpecDef
Name string
}

func typeSpecFullName(typeSpecDef *TypeSpecDef) string {
return typeSpecDef.FullName()
}
Expand All @@ -17,7 +23,7 @@ func (pkgDefs *PackagesDefinitions) parametrizeGenericType(file *ast.File, origi
return original
}

func getGenericFieldType(file *ast.File, field ast.Expr) (string, error) {
func getGenericFieldType(file *ast.File, field ast.Expr, genericParamTypeDefs map[string]*genericTypeSpec) (string, error) {
return "", fmt.Errorf("unknown field type %#v", field)
}

Expand Down
2 changes: 1 addition & 1 deletion parser.go
Expand Up @@ -1331,7 +1331,7 @@ func getFieldType(file *ast.File, field ast.Expr) (string, error) {

return fullName, nil
default:
return getGenericFieldType(file, field)
return getGenericFieldType(file, field, nil)
}
}

Expand Down
2 changes: 2 additions & 0 deletions testdata/generics_property/api/api.go
Expand Up @@ -22,6 +22,8 @@ type CreateMovie struct {
Producer types.Field[*Person]
Audience Audience[Person]
AudienceNames Audience[string]
Detail1 types.Field[types.Field[Person]]
Detail2 types.Field[types.Field[string]]
}

type Person struct {
Expand Down