Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extended generics support #1277

Merged
merged 8 commits into from Aug 2, 2022
186 changes: 165 additions & 21 deletions generics.go
Expand Up @@ -4,10 +4,35 @@
package swag

import (
"fmt"
"go/ast"
"strings"
)

var genericsDefinitions = map[*TypeSpecDef]map[string]*TypeSpecDef{}

type genericTypeSpec struct {
ArrayDepth int
TypeSpec *TypeSpecDef
Name string
}

func (s *genericTypeSpec) Type() ast.Expr {
if s.TypeSpec != nil {
return s.TypeSpec.TypeSpec.Type
}

return &ast.Ident{Name: s.Name}
}

func (s *genericTypeSpec) TypeDocName() string {
if s.TypeSpec != nil {
return strings.Replace(TypeDocName(s.TypeSpec.FullName(), s.TypeSpec.TypeSpec), "-", "_", -1)
}

return s.Name
}

func typeSpecFullName(typeSpecDef *TypeSpecDef) string {
fullName := typeSpecDef.FullName()

Expand All @@ -26,29 +51,47 @@ func typeSpecFullName(typeSpecDef *TypeSpecDef) string {
return fullName
}

func (pkgDefs *PackagesDefinitions) parametrizeStruct(original *TypeSpecDef, fullGenericForm string) *TypeSpecDef {
genericParams := strings.Split(strings.TrimRight(fullGenericForm, "]"), "[")
if len(genericParams) == 1 {
return nil
func (pkgDefs *PackagesDefinitions) parametrizeStruct(original *TypeSpecDef, fullGenericForm string, parseDependency bool) *TypeSpecDef {
if spec, ok := genericsDefinitions[original][fullGenericForm]; ok {
return spec
}

genericParams = strings.Split(genericParams[1], ",")
for i, p := range genericParams {
genericParams[i] = strings.TrimSpace(p)
pkgName := strings.Split(fullGenericForm, ".")[0]
genericTypeName, genericParams := splitStructName(fullGenericForm)
if genericParams == nil {
return nil
}
genericParamTypeDefs := map[string]*TypeSpecDef{}

genericParamTypeDefs := map[string]*genericTypeSpec{}
if len(genericParams) != len(original.TypeSpec.TypeParams.List) {
return nil
}

for i, genericParam := range genericParams {
tdef, ok := pkgDefs.uniqueDefinitions[genericParam]
if !ok {
return nil
arrayDepth := 0
for {
var isArray = len(genericParam) > 2 && genericParam[:2] == "[]"
if isArray {
genericParam = genericParam[2:]
ubogdan marked this conversation as resolved.
Show resolved Hide resolved
arrayDepth++
} else {
break
}
}

genericParamTypeDefs[original.TypeSpec.TypeParams.List[i].Names[0].Name] = tdef
tdef := pkgDefs.FindTypeSpec(genericParam, original.File, parseDependency)
if tdef == nil {
genericParamTypeDefs[original.TypeSpec.TypeParams.List[i].Names[0].Name] = &genericTypeSpec{
ArrayDepth: arrayDepth,
TypeSpec: nil,
Name: genericParam,
ubogdan marked this conversation as resolved.
Show resolved Hide resolved
}
} else {
genericParamTypeDefs[original.TypeSpec.TypeParams.List[i].Names[0].Name] = &genericTypeSpec{
ArrayDepth: arrayDepth,
TypeSpec: tdef,
}
}
}

parametrizedTypeSpec := &TypeSpecDef{
Expand All @@ -66,16 +109,29 @@ func (pkgDefs *PackagesDefinitions) parametrizeStruct(original *TypeSpecDef, ful
Obj: original.TypeSpec.Name.Obj,
}

genNameParts := strings.Split(fullGenericForm, "[")
if strings.Contains(genNameParts[0], ".") {
genNameParts[0] = strings.Split(genNameParts[0], ".")[1]
if strings.Contains(genericTypeName, ".") {
genericTypeName = strings.Split(genericTypeName, ".")[1]
}

ident.Name = genNameParts[0] + "-" + strings.Replace(strings.Join(genericParams, "-"), ".", "_", -1)
ident.Name = strings.Replace(strings.Replace(ident.Name, "\t", "", -1), " ", "", -1)
var typeName = []string{TypeDocName(fullTypeName(pkgName, genericTypeName), parametrizedTypeSpec.TypeSpec)}

parametrizedTypeSpec.TypeSpec.Name = ident
for _, def := range original.TypeSpec.TypeParams.List {
if specDef, ok := genericParamTypeDefs[def.Names[0].Name]; ok {
var prefix = ""
if specDef.ArrayDepth > 0 {
prefix = "array_"
if specDef.ArrayDepth > 1 {
prefix = fmt.Sprintf("array%d_", specDef.ArrayDepth)
}
}
typeName = append(typeName, prefix+specDef.TypeDocName())
}
}

ident.Name = strings.Join(typeName, "-")
ident.Name = string(IgnoreNameOverridePrefix) + strings.Replace(strings.Replace(ident.Name, ".", "_", -1), "_", ".", 1)

parametrizedTypeSpec.TypeSpec.Name = ident
origStructType := original.TypeSpec.Type.(*ast.StructType)

newStructTypeDef := &ast.StructType{
Expand All @@ -101,18 +157,106 @@ func (pkgDefs *PackagesDefinitions) parametrizeStruct(original *TypeSpecDef, ful
}

parametrizedTypeSpec.TypeSpec.Type = newStructTypeDef

if genericsDefinitions[original] == nil {
genericsDefinitions[original] = map[string]*TypeSpecDef{}
}
genericsDefinitions[original][fullGenericForm] = parametrizedTypeSpec
return parametrizedTypeSpec
}

func resolveType(expr ast.Expr, field *ast.Field, genericParamTypeDefs map[string]*TypeSpecDef) ast.Expr {
// splitStructName splits a generic struct name in his parts
func splitStructName(fullGenericForm string) (string, []string) {
// split only at the first '[' and remove the last ']'
genericParams := strings.SplitN(strings.TrimSpace(fullGenericForm)[:len(fullGenericForm)-1], "[", 2)
if len(genericParams) == 1 {
return "", nil
}

// generic type name
genericTypeName := genericParams[0]

// generic params
insideBrackets := 0
lastParam := ""
params := strings.Split(genericParams[1], ",")
genericParams = []string{}
for _, p := range params {
numOpened := strings.Count(p, "[")
numClosed := strings.Count(p, "]")
if numOpened == numClosed && insideBrackets == 0 {
genericParams = append(genericParams, strings.TrimSpace(p))
continue
}

insideBrackets += numOpened - numClosed
lastParam += p + ","

if insideBrackets == 0 {
genericParams = append(genericParams, strings.TrimSpace(strings.TrimRight(lastParam, ",")))
lastParam = ""
}
}

return genericTypeName, genericParams
}

func resolveType(expr ast.Expr, field *ast.Field, genericParamTypeDefs map[string]*genericTypeSpec) ast.Expr {
if asIdent, ok := expr.(*ast.Ident); ok {
if genTypeSpec, ok := genericParamTypeDefs[asIdent.Name]; ok {
return genTypeSpec.TypeSpec.Type
if genTypeSpec.ArrayDepth > 0 {
genTypeSpec.ArrayDepth--
return &ast.ArrayType{Elt: resolveType(expr, field, genericParamTypeDefs)}
}
return genTypeSpec.Type()
}
} else if asArray, ok := expr.(*ast.ArrayType); ok {
return &ast.ArrayType{Elt: resolveType(asArray.Elt, field, genericParamTypeDefs), Len: asArray.Len, Lbrack: asArray.Lbrack}
}

ubogdan marked this conversation as resolved.
Show resolved Hide resolved
return field.Type
}

func getGenericFieldType(file *ast.File, field ast.Expr) (string, error) {
switch fieldType := field.(type) {
case *ast.IndexListExpr:
spec := &TypeSpecDef{
File: file,
TypeSpec: getGenericTypeSpec(fieldType.X),
PkgPath: file.Name.Name,
}
fullName := spec.FullName() + "["

for _, index := range fieldType.Indices {
var fieldName string
var err error

switch item := index.(type) {
case *ast.ArrayType:
fieldName, err = getFieldType(file, item.Elt)
fieldName = "[]" + fieldName
default:
fieldName, err = getFieldType(file, index)
}

if err != nil {
return "", err
}

fullName += fieldName + ", "
}

return strings.TrimRight(fullName, ", ") + "]", nil
}

return "", fmt.Errorf("unknown field type %#v", field)
}

func getGenericTypeSpec(field ast.Expr) *ast.TypeSpec {
switch indexType := field.(type) {
case *ast.Ident:
return indexType.Obj.Decl.(*ast.TypeSpec)
case *ast.ArrayType:
return indexType.Elt.(*ast.Ident).Obj.Decl.(*ast.TypeSpec)
}
return nil
}
11 changes: 10 additions & 1 deletion generics_other.go
Expand Up @@ -3,10 +3,19 @@

package swag

import (
"fmt"
"go/ast"
)

func typeSpecFullName(typeSpecDef *TypeSpecDef) string {
return typeSpecDef.FullName()
}

func (pkgDefs *PackagesDefinitions) parametrizeStruct(original *TypeSpecDef, fullGenericForm string) *TypeSpecDef {
func (pkgDefs *PackagesDefinitions) parametrizeStruct(original *TypeSpecDef, fullGenericForm string, parseDependency bool) *TypeSpecDef {
return original
}

func getGenericFieldType(file *ast.File, field ast.Expr) (string, error) {
return "", fmt.Errorf("unknown field type %#v", field)
}