Skip to content

Commit

Permalink
Update for Go 1.19 and support generics.
Browse files Browse the repository at this point in the history
Update go.mod to indicate support for Go 1.19.

Add support for generic interfaces, i.e. interfaces that accept a type
parameter.

For example, you can now have this kind of interface:

```go
type Interface[Kind any] interface {
	DoTheThing() Kind
}
```

and if you run `impl 's StringImpl' 'Interface[string]` it will generate
the following code:

```go
func (s StringImpl) DoTheThing() string {
	// normal impl stub here
}
```

Fixes josharian#44.
  • Loading branch information
paddycarver committed Dec 19, 2022
1 parent 2363157 commit b07c7ab
Show file tree
Hide file tree
Showing 4 changed files with 294 additions and 43 deletions.
7 changes: 6 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
module github.com/josharian/impl

go 1.14
go 1.19

require golang.org/x/tools v0.0.0-20200522201501-cb1345f3a375

require (
golang.org/x/mod v0.2.0 // indirect
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 // indirect
)
145 changes: 117 additions & 28 deletions impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,34 @@ var (
flagRecvPkg = flag.String("recvpkg", "", "package name of the receiver")
)

func parseTypeParams(in string) (string, []string, error) {
firstOpenBracket := strings.Index(in, "[")
if firstOpenBracket < 0 {
return in, []string{}, nil
}
// there are type parameters in our interface
id := in[:firstOpenBracket]
firstCloseBracket := strings.LastIndex(in, "]")
if firstCloseBracket < 0 {
// make sure we're closing our list of type parameters
return "", nil, fmt.Errorf("invalid interface name (cannot have [ without ]): %s", in)
}
if firstCloseBracket != len(in)-1 {
// make sure the first close bracket is actually the last character of the interface name
return "", nil, fmt.Errorf("invalid interface name (cannot have ] anywhere except the last character): %s", in)
}
params := strings.Split(in[firstOpenBracket+1:firstCloseBracket], ",")
typeParams := make([]string, 0, len(params))
for _, param := range params {
typeParams = append(typeParams, strings.TrimSpace(param))
}
if len(typeParams) < 1 {
// make sure if we're declaring type parameters, we declare at least one
return "", nil, fmt.Errorf("invalid interface name (cannot have empty type parameters): %s", in)
}
return id, typeParams, nil
}

// findInterface returns the import path and identifier of an interface.
// For example, given "http.ResponseWriter", findInterface returns
// "net/http", "ResponseWriter".
Expand All @@ -34,9 +62,9 @@ var (
// If an unqualified interface such as "UserDefinedInterface" is given, then
// the interface definition is presumed to be in the package within srcDir and
// findInterface returns "", "UserDefinedInterface".
func findInterface(iface string, srcDir string) (path string, id string, err error) {
if len(strings.Fields(iface)) != 1 {
return "", "", fmt.Errorf("couldn't parse interface: %s", iface)
func findInterface(iface string, srcDir string) (path string, id string, typeParams []string, err error) {
if len(strings.Fields(iface)) != 1 && !strings.Contains(iface, "[") {
return "", "", nil, fmt.Errorf("couldn't parse interface: %s", iface)
}

srcPath := filepath.Join(srcDir, "__go_impl__.go")
Expand All @@ -46,25 +74,31 @@ func findInterface(iface string, srcDir string) (path string, id string, err err
dot := strings.LastIndex(iface, ".")
// make sure iface does not end with "/" (e.g. reject net/http/)
if slash+1 == len(iface) {
return "", "", fmt.Errorf("interface name cannot end with a '/' character: %s", iface)
return "", "", nil, fmt.Errorf("interface name cannot end with a '/' character: %s", iface)
}
// make sure iface does not end with "." (e.g. reject net/http.)
if dot+1 == len(iface) {
return "", "", fmt.Errorf("interface name cannot end with a '.' character: %s", iface)
return "", "", nil, fmt.Errorf("interface name cannot end with a '.' character: %s", iface)
}
// make sure iface has at least one "." after "/" (e.g. reject net/http/httputil)
if strings.Count(iface[slash:], ".") == 0 {
return "", "", fmt.Errorf("invalid interface name: %s", iface)
return "", "", nil, fmt.Errorf("invalid interface name: %s", iface)
}
path = iface[:dot]
id = iface[dot+1:]
id, typeParams, err = parseTypeParams(id)
if err != nil {
return "", "", nil, err
}
return iface[:dot], iface[dot+1:], nil
return path, id, typeParams, nil
}

src := []byte("package hack\n" + "var i " + iface)
// If we couldn't determine the import path, goimports will
// auto fix the import path.
imp, err := imports.Process(srcPath, src, nil)
if err != nil {
return "", "", fmt.Errorf("couldn't parse interface: %s", iface)
return "", "", nil, fmt.Errorf("couldn't parse interface: %s", iface)
}

// imp should now contain an appropriate import.
Expand All @@ -78,7 +112,7 @@ func findInterface(iface string, srcDir string) (path string, id string, err err
qualified := strings.Contains(iface, ".")

if len(f.Imports) == 0 && qualified {
return "", "", fmt.Errorf("unrecognized interface: %s", iface)
return "", "", nil, fmt.Errorf("unrecognized interface: %s", iface)
}

if !qualified {
Expand All @@ -89,10 +123,22 @@ func findInterface(iface string, srcDir string) (path string, id string, err err
// var i Reader
decl := f.Decls[0].(*ast.GenDecl) // var i io.Reader
spec := decl.Specs[0].(*ast.ValueSpec) // i io.Reader
sel := spec.Type.(*ast.Ident)
id = sel.Name // Reader
if indxExpr, ok := spec.Type.(*ast.IndexExpr); ok {
// a generic type with one type parameter shows up as an IndexExpr
id = indxExpr.X.(*ast.Ident).Name
typeParams = append(typeParams, indxExpr.Index.(*ast.Ident).Name)
} else if indxListExpr, ok := spec.Type.(*ast.IndexListExpr); ok {
// a generic type with multiple type parameters shows up as an IndexListExpr
id = indxListExpr.X.(*ast.Ident).Name
for _, typeParam := range indxListExpr.Indices {
typeParams = append(typeParams, typeParam.(*ast.Ident).Name)
}
} else {
sel := spec.Type.(*ast.Ident)
id = sel.Name // Reader
}

return path, id, nil
return path, id, typeParams, nil
}

// If qualified, the code looks like:
Expand All @@ -111,10 +157,22 @@ func findInterface(iface string, srcDir string) (path string, id string, err err
}
decl := f.Decls[1].(*ast.GenDecl) // var i io.Reader
spec := decl.Specs[0].(*ast.ValueSpec) // i io.Reader
sel := spec.Type.(*ast.SelectorExpr) // io.Reader
id = sel.Sel.Name // Reader
if indxExpr, ok := spec.Type.(*ast.IndexExpr); ok {
// a generic type with one type parameter shows up as an IndexExpr
id = indxExpr.X.(*ast.SelectorExpr).Sel.Name
typeParams = append(typeParams, indxExpr.Index.(*ast.Ident).Name)
} else if indxListExpr, ok := spec.Type.(*ast.IndexListExpr); ok {
// a generic type with multiple type parameters shows up as an IndexListExpr
id = indxListExpr.X.(*ast.SelectorExpr).Sel.Name
for _, typeParam := range indxListExpr.Indices {
typeParams = append(typeParams, typeParam.(*ast.Ident).Name)
}
} else {
sel := spec.Type.(*ast.SelectorExpr) // io.Reader
id = sel.Sel.Name // Reader
}

return path, id, nil
return path, id, typeParams, nil
}

// Pkg is a parsed build.Package.
Expand All @@ -129,10 +187,11 @@ type Pkg struct {
type Spec struct {
*ast.TypeSpec
ast.CommentMap
TypeParams map[string]string
}

// typeSpec locates the *ast.TypeSpec for type id in the import path.
func typeSpec(path string, id string, srcDir string) (Pkg, Spec, error) {
func typeSpec(path, id string, typeParams []string, srcDir string) (Pkg, Spec, error) {
var pkg *build.Package
var err error

Expand Down Expand Up @@ -161,17 +220,39 @@ func typeSpec(path string, id string, srcDir string) (Pkg, Spec, error) {
cmap := ast.NewCommentMap(fset, f, f.Comments)

for _, decl := range f.Decls {
decl, ok := decl.(*ast.GenDecl)
if !ok || decl.Tok != token.TYPE {
genDecl, ok := decl.(*ast.GenDecl)
if !ok {
continue
}
decl := genDecl
if decl.Tok != token.TYPE {
continue
}
for _, spec := range decl.Specs {
spec := spec.(*ast.TypeSpec)
if spec.Name.Name != id {
continue
}
tParams := make(map[string]string, len(typeParams))
if spec.TypeParams != nil {
var specParamNames []string
for _, typeParam := range spec.TypeParams.List {
for _, name := range typeParam.Names {
if name == nil {
continue
}
specParamNames = append(specParamNames, name.Name)
}
}
if len(specParamNames) != len(typeParams) {
continue
}
for pos, specParamName := range specParamNames {
tParams[specParamName] = typeParams[pos]
}
}
p := Pkg{Package: pkg, FileSet: fset}
s := Spec{TypeSpec: spec, CommentMap: cmap.Filter(decl)}
s := Spec{TypeSpec: spec, CommentMap: cmap.Filter(decl), TypeParams: tParams}
return p, s, nil
}
}
Expand Down Expand Up @@ -212,9 +293,17 @@ func (p Pkg) fullType(e ast.Expr) string {
return p.gofmt(e)
}

func (p Pkg) params(field *ast.Field) []Param {
func (p Pkg) params(field *ast.Field, genericTypes map[string]string) []Param {
var params []Param
typ := p.fullType(field.Type)
var typ string
ident, ok := field.Type.(*ast.Ident)
if !ok || ident == nil {
typ = p.fullType(field.Type)
} else if genType, ok := genericTypes[ident.Name]; ok {
typ = genType
} else {
typ = p.fullType(field.Type)
}
for _, name := range field.Names {
params = append(params, Param{Name: name.Name, Type: typ})
}
Expand Down Expand Up @@ -253,12 +342,12 @@ const (
WithoutComments EmitComments = false
)

func (p Pkg) funcsig(f *ast.Field, cmap ast.CommentMap, comments EmitComments) Func {
func (p Pkg) funcsig(f *ast.Field, genericParams map[string]string, cmap ast.CommentMap, comments EmitComments) Func {
fn := Func{Name: f.Names[0].Name}
typ := f.Type.(*ast.FuncType)
if typ.Params != nil {
for _, field := range typ.Params.List {
for _, param := range p.params(field) {
for _, param := range p.params(field, genericParams) {
// only for method parameters:
// assign a blank identifier "_" to an anonymous parameter
if param.Name == "" {
Expand All @@ -270,7 +359,7 @@ func (p Pkg) funcsig(f *ast.Field, cmap ast.CommentMap, comments EmitComments) F
}
if typ.Results != nil {
for _, field := range typ.Results.List {
fn.Res = append(fn.Res, p.params(field)...)
fn.Res = append(fn.Res, p.params(field, genericParams)...)
}
}
if commentsBefore(f, cmap.Comments()) && comments == WithComments {
Expand All @@ -295,13 +384,13 @@ func funcs(iface, srcDir, recvPkg string, comments EmitComments) ([]Func, error)
}

// Locate the interface.
path, id, err := findInterface(iface, srcDir)
path, id, typeParams, err := findInterface(iface, srcDir)
if err != nil {
return nil, err
}

// Parse the package and find the interface declaration.
p, spec, err := typeSpec(path, id, srcDir)
p, spec, err := typeSpec(path, id, typeParams, srcDir)
if err != nil {
return nil, fmt.Errorf("interface %s not found: %s", iface, err)
}
Expand All @@ -328,7 +417,7 @@ func funcs(iface, srcDir, recvPkg string, comments EmitComments) ([]Func, error)
continue
}

fn := p.funcsig(fndecl, spec.CommentMap.Filter(fndecl), comments)
fn := p.funcsig(fndecl, spec.TypeParams, spec.CommentMap.Filter(fndecl), comments)
fns = append(fns, fn)
}
return fns, nil
Expand Down Expand Up @@ -476,7 +565,7 @@ to prevent shell globbing.
recvs := strings.Fields(recv)
receiver := recvs[len(recvs)-1] // note that this correctly handles "s *Struct" and "*Struct"
receiver = strings.TrimPrefix(receiver, "*")
pkg, _, err := typeSpec("", receiver, *flagSrcDir)
pkg, _, err := typeSpec("", receiver, nil, *flagSrcDir)
if err == nil {
recvPkg = pkg.Package.Name
}
Expand Down
Loading

0 comments on commit b07c7ab

Please sign in to comment.