/
main.go
143 lines (128 loc) · 3.05 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
package main
import (
_ "embed"
"flag"
"github.com/reddec/rpc/internal/compile"
"go/token"
"go/types"
"golang.org/x/tools/go/packages"
"os"
"path/filepath"
"strconv"
"strings"
"text/template"
)
// templateText holds the TypeScript output template, embedded at build time
// from ts.gotemplate (go:embed paths are relative to this package's directory).
// getTemplate parses it using "[[" / "]]" as action delimiters.
//
//go:embed ts.gotemplate
var templateText string
// main is the go:generate entry point.
//
// It reads GOLINE, GOFILE and GOPACKAGE (set by `go generate`) and finds the
// package-level named type declared on the line directly below the directive.
// It scans that type's API with the compile package and renders the embedded
// TypeScript template into the output file.
//
// Flags (parsed only after the type is found, because the default output name
// depends on it):
//
//	-out   output file (default: lower-cased type name + ".ts")
//	-shim  comma-separated GoType:TSType overrides, e.g. github.com/jackc/pgtype.JSONB:any
//
// Any failure aborts generation with a panic.
func main() {
	lineNum, err := strconv.Atoi(os.Getenv("GOLINE"))
	if err != nil {
		panic("GOLINE env incorrect")
	}
	fileName, err := filepath.Abs(os.Getenv("GOFILE"))
	if err != nil {
		panic(err)
	}
	packageName := os.Getenv("GOPACKAGE")

	pkgs, err := packages.Load(&packages.Config{
		Mode: packages.NeedTypes | packages.NeedImports | packages.NeedName | packages.NeedSyntax,
	})
	if err != nil {
		panic(err)
	}

	// find correct package
	var pkg *packages.Package
	for _, p := range pkgs {
		if p.Name == packageName {
			pkg = p
			break
		}
	}
	if pkg == nil {
		panic("unknown package " + packageName)
	}

	// The directive sits on the line immediately above the declaration, so look
	// for the package-level object whose name is declared on line GOLINE+1.
	scope := pkg.Types.Scope()
	var typeName string
	for _, name := range scope.Names() {
		pos := pkg.Fset.Position(scope.Lookup(name).Pos())
		if pos.Filename == fileName && pos.Line == lineNum+1 {
			typeName = name
			break
		}
	}
	if typeName == "" {
		panic("directive should be on top of struct declaration")
	}

	output := flag.String("out", strings.ToLower(typeName)+".ts", "Output file")
	shim := flag.String("shim", "", "Comma-separated list of TS types shim (ex: github.com/jackc/pgtype.JSONB:any")
	flag.Parse()

	obj := scope.Lookup(typeName)
	if obj == nil {
		panic("typename not found")
	}
	// Checked assertion: a directive above a func/var/alias would otherwise
	// crash with an opaque runtime type-assertion panic.
	base, ok := obj.Type().(*types.Named)
	if !ok {
		panic("directive should be on top of struct declaration, found " + obj.String())
	}
	tpl := getTemplate()

	tl := compile.New()
	for _, opt := range strings.Split(*shim, ",") {
		// Trim so that "a:any, b:string" works the same as "a:any,b:string".
		sourceType, tsType, ok := strings.Cut(strings.TrimSpace(opt), ":")
		if !ok {
			continue
		}
		tl.Custom(strings.TrimSpace(sourceType), compile.TSVar{Type: strings.TrimSpace(tsType)})
	}

	// A declaration's doc comment is taken to be the comment group covering the
	// last character of the line right above the declaration.
	tl.CommentLookup(func(pos token.Pos) string {
		if !pos.IsValid() {
			return ""
		}
		file := pkg.Fset.File(pos)
		if file == nil {
			return ""
		}
		// Use the unadjusted position: the offset math below must not be
		// affected by //line directives.
		rp := file.PositionFor(pos, false)
		// Column is 1-based: Offset-Column is the newline ending the previous
		// line, so Offset-Column-1 is that line's last character.
		prevOffset := rp.Offset - rp.Column - 1
		if prevOffset < 0 {
			// No previous line with content exists, so there is no comment above.
			return ""
		}
		prevLine := file.Pos(prevOffset)
		for _, s := range pkg.Syntax {
			// Only comments from the same file can contain prevLine.
			if pkg.Fset.File(s.Pos()) != file {
				continue
			}
			for _, g := range s.Comments {
				if prevLine >= g.Pos() && prevLine <= g.End() {
					return strings.TrimSpace(g.Text())
				}
			}
		}
		return ""
	})

	api := tl.ScanAPI(base)
	vc := viewContext{
		API:     api,
		Objects: tl.Objects(),
		Aliases: tl.Aliases(),
	}

	// save
	if err := os.MkdirAll(filepath.Dir(*output), 0755); err != nil {
		panic(err)
	}
	f, err := os.Create(*output)
	if err != nil {
		panic(err)
	}
	if err := tpl.Execute(f, &vc); err != nil {
		_ = f.Close() // best effort; the template error is the one worth reporting
		panic(err)
	}
	// Close explicitly: a failed close can mean the generated file is incomplete.
	if err := f.Close(); err != nil {
		panic(err)
	}
}
// viewContext is the data passed to the TypeScript template; main fills it
// from the compile scanner's results and executes the template with &vc.
type viewContext struct {
	// API is the result of scanning the annotated type (tl.ScanAPI).
	API compile.API
	// Objects is whatever tl.Objects() collected during scanning, keyed by name
	// (presumably the referenced struct types — confirm against compile).
	Objects map[string][]compile.Param
	// Aliases is whatever tl.Aliases() collected during scanning, keyed by name
	// (presumably aliased/named non-struct types — confirm against compile).
	Aliases map[string]compile.Type
}
// getTemplate parses templateText using "[[" / "]]" as action delimiters.
// The delimiters differ from the default "{{" / "}}", presumably so they don't
// clash with TypeScript braces.
//
// Functions available inside the template:
//
//   - join sep list: strings.Join(list, sep); the argument order allows
//     `list | join ", "` in a pipeline.
//   - comment indent text: renders text as "//" line comments. Every line after
//     the first is prefixed with indent spaces, so the template only has to
//     indent the first line. Empty text renders as "".
//   - lower: strings.ToLower.
//
// It panics if templateText fails to parse.
func getTemplate() *template.Template {
	funcs := template.FuncMap{
		"join": func(sep string, list []string) string { return strings.Join(list, sep) },
		"comment": func(indent int, text string) string {
			if text == "" {
				return ""
			}
			if indent < 0 {
				indent = 0 // strings.Repeat panics on a negative count
			}
			lines := strings.Split(text, "\n")
			for i, line := range lines {
				// Blank lines get a bare "//" to avoid trailing whitespace.
				if line == "" {
					lines[i] = "//"
				} else {
					lines[i] = "// " + line
				}
			}
			return strings.Join(lines, "\n"+strings.Repeat(" ", indent))
		},
		"lower": strings.ToLower,
	}
	return template.Must(template.New("").Funcs(funcs).Delims("[[", "]]").Parse(templateText))
}