Skip to content

Commit

Permalink
Update for Go 1.19 and support generics.
Browse files Browse the repository at this point in the history
Update go.mod to indicate support for Go 1.19.

Add support for generic interfaces, i.e. interfaces that accept a type
parameter.

For example, you can now have this kind of interface:

```go
type Interface[Kind any] interface {
	DoTheThing() Kind
}
```

and if you run `impl 's StringImpl' 'Interface[string]` it will generate
the following code:

```go
func (s StringImpl) DoTheThing() string {
	// normal impl stub here
}
```

Fixes josharian#44.
  • Loading branch information
paddycarver committed Jan 16, 2024
1 parent 30a6beb commit 7f70e7c
Show file tree
Hide file tree
Showing 5 changed files with 305 additions and 92 deletions.
8 changes: 5 additions & 3 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
module github.com/josharian/impl

go 1.14
go 1.19

require golang.org/x/tools v0.4.0

require (
golang.org/x/tools v0.4.0
golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect
golang.org/x/mod v0.7.0 // indirect
golang.org/x/sys v0.3.0 // indirect
)
42 changes: 0 additions & 42 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,48 +1,6 @@
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/mod v0.2.0 h1:KU7oHjnv3XNWfa5COkzUifxZmxp1TyI7ImMXqFxLwvQ=
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.7.0 h1:LapD9S96VoQRhi/GrNTqeBJFrUjs5UHCAtTlgwA5oZA=
golang.org/x/mod v0.7.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.3.0/go.mod h1:MBQ8lrhLObU/6UmLb4fmbmk5OcyYmqtbGd/9yIeKjEE=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.3.0 h1:w8ZOecv6NaNa/zC8944JTU3vz4u6Lagfk4RPQxv92NQ=
golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.3.0/go.mod h1:q750SLmJuPmVoN1blW3UFBPREJfb1KmY3vwxfr+nFDA=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.5.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20200522201501-cb1345f3a375 h1:SjQ2+AKWgZLc1xej6WSzL+Dfs5Uyd5xcZH1mGC411IA=
golang.org/x/tools v0.0.0-20200522201501-cb1345f3a375/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.4.0 h1:7mTAgkunk3fr4GAloyyCasadO6h9zSsQZbwvcaIciV4=
golang.org/x/tools v0.4.0/go.mod h1:UE5sM2OK9E/d67R0ANs2xJizIymRP5gJU295PvKXxjQ=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 h1:H2TDz8ibqkAF6YGhCdN3jS9O0/s90v0rJh3X/OLHEUk=
golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2/go.mod h1:K8+ghG5WaK9qNqU5K3HdILfMLy1f3aNYFI/wnl100a8=
158 changes: 127 additions & 31 deletions impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,34 @@ var (
flagRecvPkg = flag.String("recvpkg", "", "package name of the receiver")
)

func parseTypeParams(in string) (string, []string, error) {
firstOpenBracket := strings.Index(in, "[")
if firstOpenBracket < 0 {
return in, []string{}, nil
}
// there are type parameters in our interface
id := in[:firstOpenBracket]
firstCloseBracket := strings.LastIndex(in, "]")
if firstCloseBracket < 0 {
// make sure we're closing our list of type parameters
return "", nil, fmt.Errorf("invalid interface name (cannot have [ without ]): %s", in)
}
if firstCloseBracket != len(in)-1 {
// make sure the first close bracket is actually the last character of the interface name
return "", nil, fmt.Errorf("invalid interface name (cannot have ] anywhere except the last character): %s", in)
}
params := strings.Split(in[firstOpenBracket+1:firstCloseBracket], ",")
typeParams := make([]string, 0, len(params))
for _, param := range params {
typeParams = append(typeParams, strings.TrimSpace(param))
}
if len(typeParams) < 1 {
// make sure if we're declaring type parameters, we declare at least one
return "", nil, fmt.Errorf("invalid interface name (cannot have empty type parameters): %s", in)
}
return id, typeParams, nil
}

// findInterface returns the import path and identifier of an interface.
// For example, given "http.ResponseWriter", findInterface returns
// "net/http", "ResponseWriter".
Expand All @@ -34,9 +62,9 @@ var (
// If an unqualified interface such as "UserDefinedInterface" is given, then
// the interface definition is presumed to be in the package within srcDir and
// findInterface returns "", "UserDefinedInterface".
func findInterface(iface string, srcDir string) (path string, id string, err error) {
if len(strings.Fields(iface)) != 1 {
return "", "", fmt.Errorf("couldn't parse interface: %s", iface)
func findInterface(iface string, srcDir string) (path string, id string, typeParams []string, err error) {
if len(strings.Fields(iface)) != 1 && !strings.Contains(iface, "[") {
return "", "", nil, fmt.Errorf("couldn't parse interface: %s", iface)
}

srcPath := filepath.Join(srcDir, "__go_impl__.go")
Expand All @@ -46,25 +74,31 @@ func findInterface(iface string, srcDir string) (path string, id string, err err
dot := strings.LastIndex(iface, ".")
// make sure iface does not end with "/" (e.g. reject net/http/)
if slash+1 == len(iface) {
return "", "", fmt.Errorf("interface name cannot end with a '/' character: %s", iface)
return "", "", nil, fmt.Errorf("interface name cannot end with a '/' character: %s", iface)
}
// make sure iface does not end with "." (e.g. reject net/http.)
if dot+1 == len(iface) {
return "", "", fmt.Errorf("interface name cannot end with a '.' character: %s", iface)
return "", "", nil, fmt.Errorf("interface name cannot end with a '.' character: %s", iface)
}
// make sure iface has at least one "." after "/" (e.g. reject net/http/httputil)
if strings.Count(iface[slash:], ".") == 0 {
return "", "", fmt.Errorf("invalid interface name: %s", iface)
return "", "", nil, fmt.Errorf("invalid interface name: %s", iface)
}
path = iface[:dot]
id = iface[dot+1:]
id, typeParams, err = parseTypeParams(id)
if err != nil {
return "", "", nil, err
}
return iface[:dot], iface[dot+1:], nil
return path, id, typeParams, nil
}

src := []byte("package hack\n" + "var i " + iface)
// If we couldn't determine the import path, goimports will
// auto fix the import path.
imp, err := imports.Process(srcPath, src, nil)
if err != nil {
return "", "", fmt.Errorf("couldn't parse interface: %s", iface)
return "", "", nil, fmt.Errorf("couldn't parse interface: %s", iface)
}

// imp should now contain an appropriate import.
Expand All @@ -78,7 +112,7 @@ func findInterface(iface string, srcDir string) (path string, id string, err err
qualified := strings.Contains(iface, ".")

if len(f.Imports) == 0 && qualified {
return "", "", fmt.Errorf("unrecognized interface: %s", iface)
return "", "", nil, fmt.Errorf("unrecognized interface: %s", iface)
}

if !qualified {
Expand All @@ -89,10 +123,22 @@ func findInterface(iface string, srcDir string) (path string, id string, err err
// var i Reader
decl := f.Decls[0].(*ast.GenDecl) // var i io.Reader
spec := decl.Specs[0].(*ast.ValueSpec) // i io.Reader
sel := spec.Type.(*ast.Ident)
id = sel.Name // Reader
if indxExpr, ok := spec.Type.(*ast.IndexExpr); ok {
// a generic type with one type parameter shows up as an IndexExpr
id = indxExpr.X.(*ast.Ident).Name
typeParams = append(typeParams, indxExpr.Index.(*ast.Ident).Name)
} else if indxListExpr, ok := spec.Type.(*ast.IndexListExpr); ok {
// a generic type with multiple type parameters shows up as an IndexListExpr
id = indxListExpr.X.(*ast.Ident).Name
for _, typeParam := range indxListExpr.Indices {
typeParams = append(typeParams, typeParam.(*ast.Ident).Name)
}
} else {
sel := spec.Type.(*ast.Ident)
id = sel.Name // Reader
}

return path, id, nil
return path, id, typeParams, nil
}

// If qualified, the code looks like:
Expand All @@ -111,10 +157,22 @@ func findInterface(iface string, srcDir string) (path string, id string, err err
}
decl := f.Decls[1].(*ast.GenDecl) // var i io.Reader
spec := decl.Specs[0].(*ast.ValueSpec) // i io.Reader
sel := spec.Type.(*ast.SelectorExpr) // io.Reader
id = sel.Sel.Name // Reader
if indxExpr, ok := spec.Type.(*ast.IndexExpr); ok {
// a generic type with one type parameter shows up as an IndexExpr
id = indxExpr.X.(*ast.SelectorExpr).Sel.Name
typeParams = append(typeParams, indxExpr.Index.(*ast.Ident).Name)
} else if indxListExpr, ok := spec.Type.(*ast.IndexListExpr); ok {
// a generic type with multiple type parameters shows up as an IndexListExpr
id = indxListExpr.X.(*ast.SelectorExpr).Sel.Name
for _, typeParam := range indxListExpr.Indices {
typeParams = append(typeParams, typeParam.(*ast.Ident).Name)
}
} else {
sel := spec.Type.(*ast.SelectorExpr) // io.Reader
id = sel.Sel.Name // Reader
}

return path, id, nil
return path, id, typeParams, nil
}

// Pkg is a parsed build.Package.
Expand All @@ -125,20 +183,27 @@ type Pkg struct {
recvPkg string
}

// Spec is ast.TypeSpec with the associated comment map.
type Spec struct {
*ast.TypeSpec
ast.CommentMap
TypeParams map[string]string
}

// typeSpec locates the *ast.TypeSpec for type id in the import path.
func typeSpec(path string, id string, srcDir string) (Pkg, *ast.TypeSpec, error) {
func typeSpec(path, id string, typeParams []string, srcDir string) (Pkg, Spec, error) {
var pkg *build.Package
var err error

if path == "" {
pkg, err = build.ImportDir(srcDir, 0)
if err != nil {
return Pkg{}, nil, fmt.Errorf("couldn't find package in %s: %v", srcDir, err)
return Pkg{}, Spec{}, fmt.Errorf("couldn't find package in %s: %v", srcDir, err)
}
} else {
pkg, err = build.Import(path, srcDir, 0)
if err != nil {
return Pkg{}, nil, fmt.Errorf("couldn't find package %s: %v", path, err)
return Pkg{}, Spec{}, fmt.Errorf("couldn't find package %s: %v", path, err)
}
}

Expand All @@ -153,21 +218,44 @@ func typeSpec(path string, id string, srcDir string) (Pkg, *ast.TypeSpec, error)
}

for _, decl := range f.Decls {
decl, ok := decl.(*ast.GenDecl)
if !ok || decl.Tok != token.TYPE {
genDecl, ok := decl.(*ast.GenDecl)
if !ok {
continue
}
decl := genDecl
if decl.Tok != token.TYPE {
continue
}
for _, spec := range decl.Specs {
spec := spec.(*ast.TypeSpec)
if spec.Name.Name != id {
continue
}
tParams := make(map[string]string, len(typeParams))
if spec.TypeParams != nil {
var specParamNames []string
for _, typeParam := range spec.TypeParams.List {
for _, name := range typeParam.Names {
if name == nil {
continue
}
specParamNames = append(specParamNames, name.Name)
}
}
if len(specParamNames) != len(typeParams) {
continue
}
for pos, specParamName := range specParamNames {
tParams[specParamName] = typeParams[pos]
}
}
p := Pkg{Package: pkg, FileSet: fset}
return p, spec, nil
s := Spec{TypeSpec: spec, TypeParams: tParams}
return p, s, nil
}
}
}
return Pkg{}, nil, fmt.Errorf("type %s not found in %s", id, path)
return Pkg{}, Spec{}, fmt.Errorf("type %s not found in %s", id, path)
}

// gofmt pretty-prints e.
Expand Down Expand Up @@ -203,9 +291,17 @@ func (p Pkg) fullType(e ast.Expr) string {
return p.gofmt(e)
}

func (p Pkg) params(field *ast.Field) []Param {
func (p Pkg) params(field *ast.Field, genericTypes map[string]string) []Param {
var params []Param
typ := p.fullType(field.Type)
var typ string
ident, ok := field.Type.(*ast.Ident)
if !ok || ident == nil {
typ = p.fullType(field.Type)
} else if genType, ok := genericTypes[ident.Name]; ok {
typ = genType
} else {
typ = p.fullType(field.Type)
}
for _, name := range field.Names {
params = append(params, Param{Name: name.Name, Type: typ})
}
Expand Down Expand Up @@ -244,12 +340,12 @@ const (
WithoutComments EmitComments = false
)

func (p Pkg) funcsig(f *ast.Field, comments EmitComments) Func {
func (p Pkg) funcsig(f *ast.Field, genericParams map[string]string, cmap ast.CommentMap, comments EmitComments) Func {
fn := Func{Name: f.Names[0].Name}
typ := f.Type.(*ast.FuncType)
if typ.Params != nil {
for _, field := range typ.Params.List {
for _, param := range p.params(field) {
for _, param := range p.params(field, genericParams) {
// only for method parameters:
// assign a blank identifier "_" to an anonymous parameter
if param.Name == "" {
Expand All @@ -261,7 +357,7 @@ func (p Pkg) funcsig(f *ast.Field, comments EmitComments) Func {
}
if typ.Results != nil {
for _, field := range typ.Results.List {
fn.Res = append(fn.Res, p.params(field)...)
fn.Res = append(fn.Res, p.params(field, genericParams)...)
}
}
if comments == WithComments && f.Doc != nil {
Expand All @@ -286,13 +382,13 @@ func funcs(iface, srcDir, recvPkg string, comments EmitComments) ([]Func, error)
}

// Locate the interface.
path, id, err := findInterface(iface, srcDir)
path, id, typeParams, err := findInterface(iface, srcDir)
if err != nil {
return nil, err
}

// Parse the package and find the interface declaration.
p, spec, err := typeSpec(path, id, srcDir)
p, spec, err := typeSpec(path, id, typeParams, srcDir)
if err != nil {
return nil, fmt.Errorf("interface %s not found: %s", iface, err)
}
Expand All @@ -319,7 +415,7 @@ func funcs(iface, srcDir, recvPkg string, comments EmitComments) ([]Func, error)
continue
}

fn := p.funcsig(fndecl, comments)
fn := p.funcsig(fndecl, spec.TypeParams, spec.CommentMap.Filter(fndecl), comments)
fns = append(fns, fn)
}
return fns, nil
Expand Down Expand Up @@ -450,7 +546,7 @@ to prevent shell globbing.
recvs := strings.Fields(recv)
receiver := recvs[len(recvs)-1] // note that this correctly handles "s *Struct" and "*Struct"
receiver = strings.TrimPrefix(receiver, "*")
pkg, _, err := typeSpec("", receiver, *flagSrcDir)
pkg, _, err := typeSpec("", receiver, nil, *flagSrcDir)
if err == nil {
recvPkg = pkg.Package.Name
}
Expand Down
Loading

0 comments on commit 7f70e7c

Please sign in to comment.