This repository has been archived by the owner on Feb 25, 2023. It is now read-only.
/
util.go
101 lines (84 loc) · 2.42 KB
/
util.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
package cursor
import (
"errors"
"go/ast"
"go/types"
"os"
"strings"
builders "github.com/tdakkota/astbuilders"
"golang.org/x/tools/go/packages"
)
// pkg is the import path of the github.com/tdakkota/cursor package.
const (
	pkg = "github.com/tdakkota/cursor"
)
// createFunction returns a function builder for a method called name.
//
// The method it describes has:
//   - a receiver named m of type typ,
//   - a single parameter named cur whose type is
//     builders.RefFor(cursor.Cursor) (presumably *cursor.Cursor — confirm
//     against astbuilders),
//   - a single named result err of type error,
//   - a body produced by bodyFunc.
func createFunction(name string, typ ast.Expr, bodyFunc builders.BodyFunc) builders.FunctionBuilder {
	receiver := &ast.Field{
		Names: []*ast.Ident{ast.NewIdent("m")},
		Type:  typ,
	}
	curParam := &ast.Field{
		Names: []*ast.Ident{ast.NewIdent("cur")},
		Type:  builders.RefFor(builders.SelectorName("cursor", "Cursor")),
	}
	errResult := &ast.Field{
		Names: []*ast.Ident{ast.NewIdent("err")},
		Type:  ast.NewIdent("error"),
	}

	fn := builders.NewFunctionBuilder(name)
	fn = fn.Recv(receiver)
	fn = fn.AddParameters(curParam)
	fn = fn.AddResults(errResult)
	return fn.Body(bodyFunc)
}
var (
	// ErrFailedToFindCursor reports that the github.com/tdakkota/cursor
	// package could not be imported, or that the requested type could not be
	// found in the loaded packages.
	ErrFailedToFindCursor = errors.New("import cursor package")
)
// load runs packages.Load for the pattern pkg.
//
// It requests type information and the import graph
// (NeedTypes|NeedImports). It passes the current process environment to
// the underlying build tool. Errors are those returned by packages.Load;
// per-package errors are not inspected here.
func load(pkg string) ([]*packages.Package, error) {
	return packages.Load(&packages.Config{
		Mode: packages.NeedImports | packages.NeedTypes,
		Env:  os.Environ(),
	}, pkg)
}
// target searches pkgs, in order, for a package-level object called name.
// It returns the underlying interface of that object's named type.
//
// Packages without type information are skipped. ErrFailedToFindCursor is
// returned when:
//   - no package declares name,
//   - the first matching object's type is not a named type, or
//   - that named type's underlying type is not an interface.
func target(pkgs []*packages.Package, name string) (*types.Interface, error) {
	for _, p := range pkgs {
		// go/packages documents that Types may be missing; skip such
		// packages instead of dereferencing a nil *types.Package.
		if p == nil || p.Types == nil {
			continue
		}

		obj := p.Types.Scope().Lookup(name)
		if obj == nil {
			continue
		}

		named, ok := obj.Type().(*types.Named)
		if !ok {
			return nil, ErrFailedToFindCursor
		}

		// A named non-interface type (e.g. a struct) is reported as an
		// error rather than panicking on an unchecked type assertion.
		iface, ok := named.Underlying().(*types.Interface)
		if !ok {
			return nil, ErrFailedToFindCursor
		}
		return iface, nil
	}
	return nil, ErrFailedToFindCursor
}
// checkErr appends an `if err != nil { return err }` statement to s.
// It returns the resulting builder.
//
// The same *ast.Ident for err is used both in the condition and in the
// return statement. The nil first argument to If presumably means "no init
// statement" — confirm against astbuilders.
func checkErr(s builders.StatementBuilder) builders.StatementBuilder {
	errIdent := ast.NewIdent("err")
	returnErr := func(body builders.StatementBuilder) builders.StatementBuilder {
		return body.Return(errIdent)
	}
	return s.If(nil, builders.NotEq(errIdent, ast.NewIdent("nil")), returnErr)
}
// callCurFunc builds a block containing these statements (presumably
// printed as shown, assuming Define emits `:=`):
//
//	err := selector.name(cur)
//	if err != nil {
//		return err
//	}
//
// The returned error is always nil.
func callCurFunc(selector ast.Expr, name string) (*ast.BlockStmt, error) {
	method := builders.Selector(selector, ast.NewIdent(name))
	call := builders.Call(method, ast.NewIdent("cur"))

	stmts := builders.NewStatementBuilder().Define(ast.NewIdent("err"))(call)
	return checkErr(stmts).CompleteAsBlock(), nil
}
// elemType converts elem into an AST type expression for code generated
// inside pkg.
//
// The type is first rendered with types.TypeString:
//   - types declared in pkg (same import path) are written unqualified;
//   - types from other packages are qualified with the package name, not
//     the import path.
//
// The resulting string is then split on ".":
//   - a single part becomes a plain identifier;
//   - otherwise the parts are passed to builders.SelectorName.
//
// NOTE(review): composite types such as "[]foo.Bar" are split the same way.
// That yields an identifier named "[]foo" as the selector base, which is
// only meaningful if callers print the result rather than inspect it —
// confirm.
func elemType(pkg *types.Package, elem types.Type) ast.Expr {
	qualify := func(other *types.Package) string {
		if other.Path() == pkg.Path() {
			return ""
		}
		return other.Name()
	}

	parts := strings.Split(types.TypeString(elem, qualify), ".")
	if len(parts) == 1 {
		return ast.NewIdent(parts[0])
	}
	return builders.SelectorName(parts[0], parts[1], parts[2:]...)
}