Skip to content

Commit

Permalink
chore: update template and generated codes
Browse files Browse the repository at this point in the history
  • Loading branch information
si3nloong committed Jun 10, 2024
1 parent b572e57 commit fd97913
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 18 deletions.
2 changes: 1 addition & 1 deletion codegen/codegen.go
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,7 @@ func parseGoPackage(
// Check auto increment
_, model.IsAutoIncr = tag.Lookup(TagOptionAutoIncrement)
if model.IsAutoIncr && len(model.Keys) > 0 {
return fmt.Errorf(`sqlgen: you cannot have a composite key if you already have auto increment key`)
return fmt.Errorf(`sqlgen: you cannot have a composite key if you define auto increment key`)
}
model.Keys = append(model.Keys, tf)
}
Expand Down
141 changes: 127 additions & 14 deletions codegen/templates/db.go.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,39 @@
{{- reserveImport "github.com/si3nloong/sqlgen/sequel" }}
{{- reserveImport "github.com/si3nloong/sqlgen/sequel/strpool" }}

type autoIncrKeySingleInserter interface {
sequel.AutoIncrKeyer
sequel.SingleInserter
}

type primaryKeySingleInserter interface {
sequel.PrimaryKeyer
sequel.SingleInserter
}

func InsertOne[T sequel.TableColumnValuer[T], Ptr interface {
sequel.TableColumnValuer[T]
sequel.Scanner[T]
}](ctx context.Context, sqlConn sequel.DB, model Ptr) (sql.Result, error) {
switch v := any(model).(type) {
case autoIncrKeySingleInserter:
_, idx, _ := v.PK()
values := model.Values()
values = append(values[:idx], values[idx+1:]...)
return sqlConn.ExecContext(ctx, v.InsertOneStmt(), values...)
case primaryKeySingleInserter:
return sqlConn.ExecContext(ctx, v.InsertOneStmt(), model.Values()...)
case sequel.PrimaryKeyer:
_, idx, _ := v.PK()
columns, values := model.Columns(), model.Values()
columns = append(columns[:idx], columns[idx+1:]...)
values = append(values[:idx], values[idx+1:]...)
return sqlConn.ExecContext(ctx, "INSERT INTO "+dbName(model)+model.TableName()+" ("+strings.Join(columns, ",")+") VALUES ("+strings.Repeat(",?", len(columns))[1:]+");", values...)
case sequel.CompositeKeyer:
panic("TODO")
default:
panic("unreachable")
}
var (
args = model.Values()
columns []string
Expand Down Expand Up @@ -122,6 +151,16 @@ func Insert[T sequel.TableColumnValuer[T]](ctx context.Context, sqlConn sequel.D
return sqlConn.ExecContext(ctx, stmt.String(), args...)
}

type autoIncrKeySingleUpserter interface {
sequel.AutoIncrKeyer
sequel.SingleUpserter
}

type primaryKeySingleUpserter interface {
sequel.PrimaryKeyer
sequel.SingleUpserter
}

func UpsertOne[T sequel.KeyValuer[T], Ptr sequel.KeyValueScanner[T]](ctx context.Context, sqlConn sequel.DB, model Ptr, override bool, omittedFields ...string) (sql.Result, error) {
var (
{{ if eq driver "mysql" -}}
Expand Down Expand Up @@ -376,8 +415,44 @@ func Upsert[T sequel.KeyValuer[T], Ptr sequel.Scanner[T]](ctx context.Context, s
{{ end -}}
}

type primaryKeyFinder interface {
sequel.PrimaryKeyer
sequel.KeyFinder
}

// FindByPK is to find single record using primary key.
func FindByPK[T sequel.KeyValuer[T], Ptr sequel.KeyValueScanner[T]](ctx context.Context, sqlConn sequel.DB, model Ptr) error {
switch v := any(model).(type) {
case primaryKeyFinder:
_, _, pk := v.PK()
return sqlConn.QueryRowContext(ctx, v.FindByPKStmt(), pk).Scan(model.Addrs()...)
case sequel.PrimaryKeyer:
columns := model.Columns()
pkName, _, pk := v.PK()
return sqlConn.QueryRowContext(ctx, "SELECT "+strings.Join(columns, ",")+" FROM "+dbName(model)+model.TableName()+" WHERE "+pkName+" = {{ quoteVar 1 }} LIMIT 1;", pk).Scan(model.Addrs()...)
case sequel.CompositeKeyer:
columns := model.Columns()
names, _, keys := v.CompositeKey()
{{ if isStaticVar -}}
return sqlConn.QueryRowContext(ctx, "SELECT "+strings.Join(columns, ",")+" FROM "+dbName(model)+model.TableName()+" WHERE "+strings.Join(names, " = ? AND ")+" = ? LIMIT 1;", keys...).Scan(model.Addrs()...)
{{ else -}}
stmt := strpool.AcquireString()
defer strpool.ReleaseString(stmt)
stmt.WriteString("SELECT "+strings.Join(columns, ",")+" FROM "+dbName(model)+model.TableName()+" WHERE ")
max := len(names)
for i := 1; i <= max; i++ {
if i == 1 {
stmt.WriteString(names[i]+" = "+ wrapVar(i))
} else {
stmt.WriteString(" AND "+ names[i]+" = "+ wrapVar(i))
}
}
stmt.WriteString(" LIMIT 1;")
return sqlConn.QueryRowContext(ctx, stmt.String(), keys...).Scan(model.Addrs()...)
{{ end -}}
default:
panic("unreachable")
}
switch vi := any(model).(type) {
case sequel.KeyFinder:
_, _, pk := vi.PK()
Expand All @@ -392,20 +467,24 @@ func FindByPK[T sequel.KeyValuer[T], Ptr sequel.KeyValueScanner[T]](ctx context.
}
}

type primaryKeyUpdater interface {
sequel.PrimaryKeyer
sequel.KeyUpdater
}

// UpdateByPK is to update single record using primary key.
func UpdateByPK[T sequel.KeyValuer[T]](ctx context.Context, sqlConn sequel.DB, model T) (sql.Result, error) {
var (
pkName, idx, pk = model.PK()
columns, values = model.Columns(), model.Values()
)
switch vi := any(model).(type) {
case sequel.KeyUpdater:
values = append(values[:idx], append(values[idx+1:], pk)...)
return sqlConn.ExecContext(ctx, vi.UpdateByPKStmt(), values...)
default:
columns = append(columns[:idx], columns[idx+1:]...)
values = append(values[:idx], values[idx+1:]...)
switch v := any(model).(type) {
case primaryKeyUpdater:
_, pkIdx, pk := v.PK()
values := model.Values()
values = append(values[:pkIdx], append(values[pkIdx+1:], pk)...)
return sqlConn.ExecContext(ctx, v.UpdateByPKStmt(), values...)
case sequel.PrimaryKeyer:
pkName, pkIdx, pk := v.PK()
values := model.Values()
columns := model.Columns()
values = append(values[:pkIdx], append(values[pkIdx+1:], pk)...)
{{ if isStaticVar -}}
return sqlConn.ExecContext(ctx, "UPDATE "+dbName(model)+model.TableName()+" SET "+strings.Join(columns, " = {{ quoteVar 1 }},")+" = {{ quoteVar 1 }} WHERE "+pkName+" = {{ quoteVar 1 }};", append(values, pk)...)
{{ else -}}
Expand All @@ -421,13 +500,47 @@ func UpdateByPK[T sequel.KeyValuer[T]](ctx context.Context, sqlConn sequel.DB, m
stmt.WriteString(" WHERE "+ pkName +" = "+ wrapVar(len(columns) + 2)+ ";")
return sqlConn.ExecContext(ctx, stmt.String(), append(values, pk)...)
{{ end -}}
default:
panic("unreachable")
}
}

type primaryKeyDeleter interface {
sequel.PrimaryKeyer
sequel.KeyDeleter
}

// DeleteByPK is to update single record using primary key.
func DeleteByPK[T sequel.KeyValuer[T]](ctx context.Context, sqlConn sequel.DB, model T) (sql.Result, error) {
pkName, _, pk := model.PK()
return sqlConn.ExecContext(ctx, "DELETE FROM "+ dbName(model) + model.TableName() +" WHERE "+ pkName +" = {{ quoteVar 1 }};", pk)
switch v := any(model).(type) {
case primaryKeyDeleter:
_, _, pk := v.PK()
return sqlConn.ExecContext(ctx, v.DeleteByPKStmt(), pk)
case sequel.PrimaryKeyer:
pkName, _, pk := v.PK()
return sqlConn.ExecContext(ctx, "DELETE FROM "+dbName(model)+model.TableName()+" WHERE "+pkName+" = {{ quoteVar 1 }};", pk)
case sequel.CompositeKeyer:
names, _, keys := v.CompositeKey()
{{ if isStaticVar -}}
return sqlConn.ExecContext(ctx, "DELETE FROM "+dbName(model)+model.TableName()+" WHERE "+strings.Join(names, " = ? AND ")+" = ?;", keys...)
{{ else -}}
stmt := strpool.AcquireString()
defer strpool.ReleaseString(stmt)
stmt.WriteString("DELETE FROM "+dbName(model)+model.TableName()+" WHERE ")
max := len(names)
for i := 1; i <= max; i++ {
if i == 1 {
stmt.WriteString(names[i]+" = "+ wrapVar(i))
} else {
stmt.WriteString(" AND "+ names[i]+" = "+ wrapVar(i))
}
}
stmt.WriteByte(';')
return sqlConn.ExecContext(ctx, stmt.String(), keys...)
{{ end -}}
default:
panic("unreachable")
}
}

type SelectStmt struct {
Expand Down
16 changes: 13 additions & 3 deletions sequel/model.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package sequel

// For rename table name
type Table struct{}
type Table struct{ Name string }

type DatabaseNamer interface {
DatabaseName() string
Expand All @@ -20,12 +20,22 @@ type Valuer interface {
}

type Keyer interface {
PK() ([]string, []int, []any)
HasPK()
}

type PrimaryKeyer interface {
Keyer
PK() (string, int, any)
}

type AutoIncrKeyer interface {
PrimaryKeyer
AutoIncr()
}

type CompositeKeyer interface {
Keyer
IsAutoIncr()
CompositeKey() ([]string, []int, []any)
}

type DuplicateKeyer interface {
Expand Down
5 changes: 5 additions & 0 deletions sequel/sequel.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ type KeyUpdater interface {
UpdateByPKStmt() string
}

type KeyDeleter interface {
Keyer
DeleteByPKStmt() string
}

type Inserter interface {
Columner
InsertVarQuery() string
Expand Down

0 comments on commit fd97913

Please sign in to comment.