Skip to content

Commit

Permalink
refactor: migration
Browse files Browse the repository at this point in the history
  • Loading branch information
si3nloong committed Jul 1, 2024
1 parent 8b738b1 commit b0225ee
Show file tree
Hide file tree
Showing 52 changed files with 642 additions and 564 deletions.
12 changes: 7 additions & 5 deletions codegen/codegen.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,14 @@ func (b tableInfo) Keys() []sequel.ColumnSchema {
return b.keys
}

func (b tableInfo) Columns() []sequel.ColumnSchema {
func (b tableInfo) Indexes() []sequel.IndexSchema {
return nil
}

func (b tableInfo) Columns() []sequel.ColumnSchema {
return b.columns
}

func (b tableInfo) Implements(T *types.Interface) (*types.Func, bool) {
return types.MissingMethod(b.t, T, true)
}
Expand All @@ -163,7 +167,7 @@ var (
_ (sequel.ColumnSchema) = (*columnInfo)(nil)
)

func (i columnInfo) SQLValuer() sequel.SQLFunc {
func (i columnInfo) SQLValuer() sequel.QueryFunc {
if i.model == nil {
return nil
}
Expand All @@ -172,7 +176,7 @@ func (i columnInfo) SQLValuer() sequel.SQLFunc {
}
}

func (i columnInfo) SQLScanner() sequel.SQLFunc {
func (i columnInfo) SQLScanner() sequel.QueryFunc {
if i.model == nil {
return nil
}
Expand Down Expand Up @@ -326,7 +330,6 @@ func Generate(c *config.Config) error {
// true,
"",
cfg.Database.Package,
cfg.Getter.Prefix,
cfg.Database.Dir,
cfg.Database.Filename,
); err != nil {
Expand All @@ -341,7 +344,6 @@ func Generate(c *config.Config) error {
"operator.go.tpl",
"",
cfg.Database.Operator.Package,
cfg.Getter.Prefix,
cfg.Database.Operator.Dir,
cfg.Database.Operator.Filename,
); err != nil {
Expand Down
1 change: 0 additions & 1 deletion codegen/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ func renderTemplate(
tmplName string,
pkgPath string,
pkgName string,
getter string,
dstDir string,
dstFilename string,
) error {
Expand Down
55 changes: 30 additions & 25 deletions codegen/templates/db.go.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@

type autoIncrKeyInserter interface {
sequel.AutoIncrKeyer
sequel.SingleInserter
sequel.PrimaryKeyer
}

{{ if eq driver "postgres" -}}
{{- /* postgres */ -}}
func InsertOne[T sequel.TableColumnValuer[T], Ptr interface {
sequel.TableColumnValuer[T]
func InsertOne[T sequel.TableColumnValuer, Ptr interface {
sequel.TableColumnValuer
sequel.PtrScanner[T]
}](ctx context.Context, sqlConn sequel.DB, model Ptr) error {
switch v := any(model).(type) {
Expand All @@ -40,8 +40,8 @@ func InsertOne[T sequel.TableColumnValuer[T], Ptr interface {
}
}
{{ else }}
func InsertOne[T sequel.TableColumnValuer[T], Ptr interface {
sequel.TableColumnValuer[T]
func InsertOne[T sequel.TableColumnValuer, Ptr interface {
sequel.TableColumnValuer
sequel.PtrScanner[T]
}](ctx context.Context, sqlConn sequel.DB, model Ptr) (sql.Result, error) {
switch v := any(model).(type) {
Expand Down Expand Up @@ -80,7 +80,7 @@ func InsertOne[T sequel.TableColumnValuer[T], Ptr interface {
{{ if eq driver "postgres" -}}
{{- /* postgres */ -}}
// Insert is a helper function to insert multiple records.
func Insert[T sequel.TableColumnValuer[T], Ptr sequel.PtrScanner[T]](ctx context.Context, sqlConn sequel.DB, data []T) (sql.Result, error) {
func Insert[T sequel.TableColumnValuer, Ptr sequel.PtrScanner[T]](ctx context.Context, sqlConn sequel.DB, data []T) (sql.Result, error) {
noOfData := len(data)
if noOfData == 0 {
return new(sequel.EmptyResult), nil
Expand Down Expand Up @@ -176,7 +176,7 @@ func Insert[T sequel.TableColumnValuer[T], Ptr sequel.PtrScanner[T]](ctx context
}
{{ else }}
// Insert is a helper function to insert multiple records.
func Insert[T sequel.TableColumnValuer[T]](ctx context.Context, sqlConn sequel.DB, data []T) (sql.Result, error) {
func Insert[T sequel.TableColumnValuer](ctx context.Context, sqlConn sequel.DB, data []T) (sql.Result, error) {
noOfData := len(data)
if noOfData == 0 {
return new(sequel.EmptyResult), nil
Expand Down Expand Up @@ -995,31 +995,36 @@ type sqlStmt struct {
args []any
}
func (s *sqlStmt) Var(query string, value any) {
var (
_ sequel.Stmt = (*sqlStmt)(nil)
)
func (s *sqlStmt) Var(value any) string {
s.pos++
s.args = append(s.args, value)
{{ if isStaticVar -}}
s.WriteString(query+"?")
return {{ quote varRune }}
{{ else -}}
s.WriteString(query+wrapVar(s.pos))
return wrapVar(s.pos)
{{ end -}}
s.args = append(s.args, value)
}
func (s *sqlStmt) Vars(query string, values []any) {
s.WriteString(query)
func (s *sqlStmt) Vars(values []any) string {
noOfLen := len(values)
{{ if isStaticVar -}}
s.WriteString("(" + strings.Repeat(",?", noOfLen)[1:] + ")")
s.args = append(s.args, values...)
return "(" + strings.Repeat(",{{ varRune }}", noOfLen)[1:] + ")"
{{ else -}}
s.WriteByte('(')
buf := new(strings.Builder)
buf.WriteByte('(')
i := s.pos
s.pos += noOfLen
for ; i < s.pos; i++ {
s.WriteString(wrapVar(i + 1))
buf.WriteString(wrapVar(i + 1))
}
s.WriteByte(')')
buf.WriteByte(')')
return buf.String()
{{ end -}}
s.args = append(s.args, values...)
}
func (s sqlStmt) Args() []any {
Expand All @@ -1039,20 +1044,20 @@ func DbTable[T sequel.Tabler](model T) string {
return model.TableName()
}
func dbName(model any) string {
if v, ok := model.(sequel.Databaser); ok {
return v.DatabaseName() + "."
}
return ""
}
func Columns[T sequel.Columner](model T) []string {
if v, ok := any(model).(sequel.SQLColumner); ok {
return v.SQLColumns()
}
return model.ColumnNames()
}
func dbName(model any) string {
if v, ok := model.(sequel.Databaser); ok {
return v.DatabaseName() + "."
}
return ""
}
{{ if not isStaticVar -}}
func wrapVar(i int) string {
return {{ quote varRune }}+ strconv.Itoa(i)
Expand Down
28 changes: 14 additions & 14 deletions codegen/templates/operator.go.tpl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
{{- reserveImport "github.com/si3nloong/sqlgen/sequel" }}

func And(stmts ...sequel.WhereClause) sequel.WhereClause {
return func(stmt sequel.StmtBuilder) {
stmt.WriteByte('(')
Expand Down Expand Up @@ -27,13 +28,13 @@ func Or(stmts ...sequel.WhereClause) sequel.WhereClause {

func Equal[T comparable](f sequel.ColumnValuer[T], value T) sequel.WhereClause {
return func(stmt sequel.StmtBuilder) {
stmt.Var(f.ColumnName()+" = ", f.Convert(value))
stmt.WriteString(f.ColumnName() + " = " + stmt.Var(f.Convert(value)))
}
}

func NotEqual[T comparable](f sequel.ColumnValuer[T], value T) sequel.WhereClause {
return func(stmt sequel.StmtBuilder) {
stmt.Var(f.ColumnName()+" <> ", f.Convert(value))
stmt.WriteString(f.ColumnName() + " <> " + stmt.Var(f.Convert(value)))
}
}

Expand All @@ -43,7 +44,7 @@ func In[T any](f sequel.ColumnValuer[T], values ...T) sequel.WhereClause {
for idx := range values {
args[idx] = f.Convert(values[idx])
}
stmt.Vars(f.ColumnName()+" IN ", args)
stmt.WriteString(f.ColumnName() + " IN " + stmt.Vars(args))
}
}

Expand All @@ -53,43 +54,43 @@ func NotIn[T any](f sequel.ColumnValuer[T], values ...T) sequel.WhereClause {
for idx := range values {
args[idx] = f.Convert(values[idx])
}
stmt.Vars(f.ColumnName()+" NOT IN ", args)
stmt.WriteString(f.ColumnName() + " NOT IN " + stmt.Vars(args))
}
}

func GreaterThan[T comparable](f sequel.ColumnValuer[T], value T) sequel.WhereClause {
return func(stmt sequel.StmtBuilder) {
stmt.Var(f.ColumnName()+" > ", f.Convert(value))
stmt.WriteString(f.ColumnName() + " > " + stmt.Var(f.Convert(value)))
}
}

func GreaterThanOrEqual[T comparable](f sequel.ColumnValuer[T], value T) sequel.WhereClause {
return func(stmt sequel.StmtBuilder) {
stmt.Var(f.ColumnName()+" >= ", f.Convert(value))
stmt.WriteString(f.ColumnName() + " >= " + stmt.Var(f.Convert(value)))
}
}

func LessThan[T comparable](f sequel.ColumnValuer[T], value T) sequel.WhereClause {
return func(stmt sequel.StmtBuilder) {
stmt.Var(f.ColumnName()+" < ", f.Convert(value))
stmt.WriteString(f.ColumnName() + " < " + stmt.Var(f.Convert(value)))
}
}

func LessThanOrEqual[T comparable](f sequel.ColumnValuer[T], value T) sequel.WhereClause {
return func(stmt sequel.StmtBuilder) {
stmt.Var(f.ColumnName()+" <= ", f.Convert(value))
stmt.WriteString(f.ColumnName() + " <= " + stmt.Var(f.Convert(value)))
}
}

func Like[T comparable](f sequel.ColumnValuer[T], value T) sequel.WhereClause {
return func(stmt sequel.StmtBuilder) {
stmt.Var(f.ColumnName()+" LIKE ", f.Convert(value))
stmt.WriteString(f.ColumnName() + " LIKE " + stmt.Var(f.Convert(value)))
}
}

func NotLike[T comparable](f sequel.ColumnValuer[T], value T) sequel.WhereClause {
return func(stmt sequel.StmtBuilder) {
stmt.Var(f.ColumnName()+" NOT LIKE ", f.Convert(value))
stmt.WriteString(f.ColumnName() + " NOT LIKE " + stmt.Var(f.Convert(value)))
}
}

Expand All @@ -107,8 +108,7 @@ func IsNotNull[T any](f sequel.ColumnValuer[T]) sequel.WhereClause {

func Between[T comparable](f sequel.ColumnValuer[T], from, to T) sequel.WhereClause {
return func(stmt sequel.StmtBuilder) {
stmt.Var(f.ColumnName()+" BETWEEN ", from)
stmt.Var(" AND ", to)
stmt.WriteString(f.ColumnName() + " BETWEEN " + stmt.Var(from) + " AND " + stmt.Var(to))
}
}

Expand All @@ -118,7 +118,7 @@ func Set[T any](f sequel.ColumnValuer[T], value ...T) sequel.SetClause {
if len(value) > 0 {
defaultValue = f.Convert(value[0])
}
stmt.Var(f.ColumnName()+" = ", defaultValue)
stmt.WriteString(f.ColumnName() + " = " + stmt.Var(defaultValue))
}
}

Expand All @@ -132,4 +132,4 @@ func Desc[T any](f sequel.ColumnValuer[T]) sequel.OrderByClause {
return func(sw sequel.StmtWriter) {
sw.WriteString(f.ColumnName() + " DESC")
}
}
}
38 changes: 18 additions & 20 deletions examples/db/mysql/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,8 @@ import (
"github.com/si3nloong/sqlgen/sequel/strpool"
)

type autoIncrKeyInserter interface {
sequel.AutoIncrKeyer
sequel.SingleInserter
}

func InsertOne[T sequel.TableColumnValuer[T], Ptr interface {
sequel.TableColumnValuer[T]
func InsertOne[T sequel.TableColumnValuer, Ptr interface {
sequel.TableColumnValuer
sequel.PtrScanner[T]
}](ctx context.Context, sqlConn sequel.DB, model Ptr) (sql.Result, error) {
switch v := any(model).(type) {
Expand Down Expand Up @@ -55,7 +50,7 @@ func InsertOne[T sequel.TableColumnValuer[T], Ptr interface {
}

// Insert is a helper function to insert multiple records.
func Insert[T sequel.TableColumnValuer[T]](ctx context.Context, sqlConn sequel.DB, data []T) (sql.Result, error) {
func Insert[T sequel.TableColumnValuer](ctx context.Context, sqlConn sequel.DB, data []T) (sql.Result, error) {
noOfData := len(data)
if noOfData == 0 {
return new(sequel.EmptyResult), nil
Expand Down Expand Up @@ -565,17 +560,20 @@ type sqlStmt struct {
args []any
}

func (s *sqlStmt) Var(query string, value any) {
var (
_ sequel.Stmt = (*sqlStmt)(nil)
)

func (s *sqlStmt) Var(value any) string {
s.pos++
s.WriteString(query + "?")
s.args = append(s.args, value)
return `?`
}

func (s *sqlStmt) Vars(query string, values []any) {
s.WriteString(query)
func (s *sqlStmt) Vars(values []any) string {
noOfLen := len(values)
s.WriteString("(" + strings.Repeat(",?", noOfLen)[1:] + ")")
s.args = append(s.args, values...)
return "(" + strings.Repeat(",?", noOfLen)[1:] + ")"
}

func (s sqlStmt) Args() []any {
Expand All @@ -595,16 +593,16 @@ func DbTable[T sequel.Tabler](model T) string {
return model.TableName()
}

func dbName(model any) string {
if v, ok := model.(sequel.Databaser); ok {
return v.DatabaseName() + "."
}
return ""
}

func Columns[T sequel.Columner](model T) []string {
if v, ok := any(model).(sequel.SQLColumner); ok {
return v.SQLColumns()
}
return model.ColumnNames()
}

func dbName(model any) string {
if v, ok := model.(sequel.Databaser); ok {
return v.DatabaseName() + "."
}
return ""
}
Loading

0 comments on commit b0225ee

Please sign in to comment.