-
Notifications
You must be signed in to change notification settings - Fork 3
/
sql2code.go
151 lines (130 loc) · 3.6 KB
/
sql2code.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
package sql2code
import (
"errors"
"fmt"
"os"
"github.com/zhufuyi/pkg/sql2code/parser"
)
// Args 参数
type Args struct {
SQL string // DDL sql
DDLFile string // 读取文件的DDL sql
DBDsn string // 从db获取表的DDL sql
DBTable string
Package string // 生成字段的包名(只有model类型有效)
GormType bool // 是否显示gorm type名称(只有model类型代码有效)
JSONTag bool // 是否包括json tag
JSONNamedType int // json命名类型,0:和列名一致,其他值表示驼峰
IsEmbed bool // 是否嵌入gorm.Model
CodeType string // 指定生成代码用途,支持4中类型,分别是 model(默认), json, dao, handler
ForceTableName bool
Charset string
Collation string
TablePrefix string
ColumnPrefix string
NoNullType bool
NullStyle string
}
func (a *Args) checkValid() error {
if a.SQL == "" && a.DDLFile == "" && (a.DBDsn == "" && a.DBTable == "") {
return errors.New("you must specify sql or ddl file")
}
return nil
}
func getSQL(args *Args) (string, error) {
if args.SQL != "" {
return args.SQL, nil
}
sql := ""
if args.DDLFile != "" {
b, err := os.ReadFile(args.DDLFile)
if err != nil {
return sql, fmt.Errorf("read %s failed, %s", args.DDLFile, err)
}
return string(b), nil
} else if args.DBDsn != "" {
if args.DBTable == "" {
return sql, errors.New("miss mysql table")
}
sqlStr, err := parser.GetTableInfo(args.DBDsn, args.DBTable)
if err != nil {
return sql, err
}
return sqlStr, nil
}
return sql, errors.New("no SQL input(-sql|-f|-db-dsn)")
}
func getOptions(args *Args) []parser.Option {
var opts []parser.Option
if args.Charset != "" {
opts = append(opts, parser.WithCharset(args.Charset))
}
if args.Collation != "" {
opts = append(opts, parser.WithCollation(args.Collation))
}
if args.JSONTag {
opts = append(opts, parser.WithJSONTag(args.JSONNamedType))
}
if args.TablePrefix != "" {
opts = append(opts, parser.WithTablePrefix(args.TablePrefix))
}
if args.ColumnPrefix != "" {
opts = append(opts, parser.WithColumnPrefix(args.ColumnPrefix))
}
if args.NoNullType {
opts = append(opts, parser.WithNoNullType())
}
if args.IsEmbed {
opts = append(opts, parser.WithEmbed())
}
if args.NullStyle != "" {
switch args.NullStyle {
case "sql":
opts = append(opts, parser.WithNullStyle(parser.NullInSql))
case "ptr":
opts = append(opts, parser.WithNullStyle(parser.NullInPointer))
default:
fmt.Printf("invalid null style: %s\n", args.NullStyle)
return nil
}
} else {
opts = append(opts, parser.WithNullStyle(parser.NullDisable))
}
if args.Package != "" {
opts = append(opts, parser.WithPackage(args.Package))
}
if args.GormType {
opts = append(opts, parser.WithGormType())
}
if args.ForceTableName {
opts = append(opts, parser.WithForceTableName())
}
return opts
}
// GenerateOne 根据sql生成gorm代码,sql可以从参数、文件、db三种方式获取,优先从高到低
func GenerateOne(args *Args) (string, error) {
codes, err := Generate(args)
if err != nil {
return "", err
}
if args.CodeType == "" {
args.CodeType = parser.CodeTypeModel // 默认为model code
}
out, ok := codes[args.CodeType]
if !ok {
return "", fmt.Errorf("unknown code type %s", args.CodeType)
}
return out, nil
}
// Generate 生成model, json, dao, handler不同用途代码
func Generate(args *Args) (map[string]string, error) {
if err := args.checkValid(); err != nil {
return nil, err
}
sql, err := getSQL(args)
if err != nil {
return nil, err
}
opt := getOptions(args)
return parser.ParseSQL(sql, opt...)
}