Skip to content

Commit

Permalink
fix: generated codes, should be final version
Browse files Browse the repository at this point in the history
  • Loading branch information
si3nloong committed Jun 12, 2024
1 parent e587124 commit 327b9db
Show file tree
Hide file tree
Showing 31 changed files with 470 additions and 210 deletions.
34 changes: 21 additions & 13 deletions codegen/sequel.go.tpl
Original file line number Diff line number Diff line change
@@ -1,21 +1,10 @@
package sequel

import "database/sql/driver"

// For rename table name
type Table struct{}

type Keyer interface {
PK() (columnName string, pos int, value driver.Value)
}

type AutoIncrKeyer interface {
Keyer
IsAutoIncr()
}

type DuplicateKeyer interface {
OnDuplicateKey() string
type DatabaseNamer interface {
DatabaseName() string
}

type Tabler interface {
Expand All @@ -29,3 +18,22 @@ type Columner interface {
type Valuer interface {
Values() []any
}

type Keyer interface {
HasPK()
}

type PrimaryKeyer interface {
Keyer
PK() (string, int, any)
}

type AutoIncrKeyer interface {
PrimaryKeyer
IsAutoIncr()
}

type CompositeKeyer interface {
Keyer
CompositeKey() ([]string, []int, []any)
}
187 changes: 140 additions & 47 deletions codegen/templates/db.go.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,17 @@ func InsertOne[T sequel.TableColumnValuer[T], Ptr interface {
columns, values := model.Columns(), model.Values()
stmt := strpool.AcquireString()
defer strpool.ReleaseString(stmt)
stmt.WriteString("INSERT INTO " + dbName(model) + model.TableName() + " (" + strings.Join(columns, ",") + ") VALUES (")
cols := strings.Join(columns, ",")
stmt.WriteString("INSERT INTO " + dbName(model) + model.TableName() + " (" + cols + ") VALUES (")
for i := range values {
if i > 0 {
stmt.WriteString(","+wrapVar(i + 1))
stmt.WriteString("," + wrapVar(i+1))
} else {
// argument always started from 1
stmt.WriteString(wrapVar(i + 1))
}
}
stmt.WriteString(") RETURNING "+ strings.Join(columns, ",") +";")
stmt.WriteString(") RETURNING " + cols + ";")
return sqlConn.QueryRowContext(ctx, stmt.String(), values...).Scan(model.Addrs()...)
}
}
Expand Down Expand Up @@ -84,66 +85,158 @@ func InsertOne[T sequel.TableColumnValuer[T], Ptr interface {
}
{{ end }}

{{ if eq driver "postgres" -}}
{{- /* postgres */ -}}
// Insert is a helper function to insert multiple records.
func Insert[T sequel.TableColumnValuer[T]](ctx context.Context, sqlConn sequel.DB, data []T) (sql.Result, error) {
func Insert[T sequel.TableColumnValuer[T], Ptr sequel.Scanner[T]](ctx context.Context, sqlConn sequel.DB, data []T) (sql.Result, error) {
noOfData := len(data)
if noOfData == 0 {
return new(sequel.EmptyResult), nil
}

var (
model = data[0]
columns = model.Columns()
idx = -1
noOfCols = len(columns)
args = make([]any, 0, noOfCols*len(data))
model = data[0]
columns = model.Columns()
stmt = strpool.AcquireString()
)
defer strpool.ReleaseString(stmt)

switch vi := any(model).(type) {
case sequel.Inserter:
query := strings.Repeat(vi.InsertVarQuery()+",", len(data))
return sqlConn.ExecContext(ctx, "INSERT INTO "+dbName(model)+model.TableName()+" ("+strings.Join(columns, ",")+") VALUES "+query[:len(query)-1]+";", args...)
switch v := any(model).(type) {
case sequel.AutoIncrKeyer:
_, idx, _ = vi.PK()
noOfCols--
_, idx, _ := v.PK()
columns = append(columns[:idx], columns[idx+1:]...)
noOfCols := len(columns)
cols := strings.Join(columns, ",")
args := make([]any, 0, noOfCols*noOfData)
stmt.WriteString("INSERT INTO " + dbName(model) + model.TableName() + " (" + cols + ") VALUES ")
var offset int
for i := range data {
if i > 0 {
stmt.WriteString(",(")
} else {
stmt.WriteByte('(')
}
offset = noOfCols * i
for j := 0; j < noOfCols; j++ {
if j > 0 {
stmt.WriteString("," + wrapVar(offset+1+j))
} else {
stmt.WriteString(wrapVar(offset + 1 + j))
}
}
stmt.WriteByte(')')
values := data[i].Values()
values = append(values[:idx], values[idx+1:]...)
args = append(args, values...)
}
stmt.WriteString(" RETURNING " + cols + ";")
rows, err := sqlConn.QueryContext(ctx, stmt.String(), args...)
if err != nil {
return nil, err
}
defer rows.Close()
var i int64
for rows.Next() {
if err := rows.Scan(Ptr(&data[i]).Addrs()...); err != nil {
return nil, err
}
i++
}
return sequel.NewRowsAffectedResult(i), rows.Close()
default:
noOfCols := len(columns)
cols := strings.Join(columns, ",")
args := make([]any, 0, noOfCols*noOfData)
stmt.WriteString("INSERT INTO " + dbName(model) + model.TableName() + " (" + cols + ") VALUES ")
var offset int
for i := range data {
if i > 0 {
stmt.WriteString(",(")
} else {
stmt.WriteByte('(')
}
offset = noOfCols * i
for j := 0; j < noOfCols; j++ {
if j > 0 {
stmt.WriteString("," + wrapVar(offset+1+j))
} else {
stmt.WriteString(wrapVar(offset + 1 + j))
}
}
stmt.WriteByte(')')
args = append(args, data[i].Values()...)
}
stmt.WriteString(" RETURNING " + cols + ";")
rows, err := sqlConn.QueryContext(ctx, stmt.String(), args...)
if err != nil {
return nil, err
}
defer rows.Close()
var i int64
for rows.Next() {
if err := rows.Scan(Ptr(&data[i]).Addrs()...); err != nil {
return nil, err
}
i++
}
return sequel.NewRowsAffectedResult(i), rows.Close()
}
}
{{ else }}
// Insert is a helper function to insert multiple records.
func Insert[T sequel.TableColumnValuer[T]](ctx context.Context, sqlConn sequel.DB, data []T) (sql.Result, error) {
noOfData := len(data)
if noOfData == 0 {
return new(sequel.EmptyResult), nil
}

var stmt = strpool.AcquireString()
var (
model = data[0]
columns = model.Columns()
stmt = strpool.AcquireString()
)
defer strpool.ReleaseString(stmt)
stmt.WriteString("INSERT INTO " + dbName(model) + model.TableName() + " (" + strings.Join(columns, ",") + ") VALUES ")
for i := range data {
if i > 0 {
stmt.WriteString(",(")
} else {
stmt.WriteByte('(')
}
{{ if not isStaticVar -}}
offset := noOfCols * i
{{ end -}}
for j := 0; j < noOfCols; j++ {
if j > 0 {
stmt.WriteByte(',')

switch v := any(model).(type) {
case sequel.AutoIncrKeyer:
_, idx, _ := v.PK()
columns = append(columns[:idx], columns[idx+1:]...)
noOfCols := len(columns)
cols := strings.Join(columns, ",")
args := make([]any, 0, noOfCols*noOfData)
stmt.WriteString("INSERT INTO " + dbName(model) + model.TableName() + " (" + cols + ") VALUES ")
placeholder := "(" + strings.Repeat(",{{ quoteVar 1 }}", noOfCols)[1:] + ")"
for i := range data {
if i > 0 {
stmt.WriteString("," + placeholder)
} else {
stmt.WriteString(placeholder)
}
{{ if isStaticVar -}}
stmt.WriteString({{ quote varRune }})
{{ else -}}
stmt.WriteString(wrapVar(offset + 1 + j))
{{ end -}}
}
if idx > -1 {
values := data[i].Values()
values = append(values[:idx], values[idx+1:]...)
args = append(args, values...)
} else {
}
stmt.WriteByte(';')
return sqlConn.ExecContext(ctx, stmt.String(), args...)
default:
noOfCols := len(columns)
cols := strings.Join(columns, ",")
args := make([]any, 0, noOfCols*noOfData)
stmt.WriteString("INSERT INTO " + dbName(model) + model.TableName() + " (" + cols + ") VALUES ")
placeholder := "(" + strings.Repeat(",{{ quoteVar 1 }}", noOfCols)[1:] + ")"
for i := range data {
if i > 0 {
stmt.WriteString("," + placeholder)
} else {
stmt.WriteString(placeholder)
}
args = append(args, data[i].Values()...)
}
stmt.WriteByte(')')
stmt.WriteByte(';')
return sqlConn.ExecContext(ctx, stmt.String(), args...)
}
stmt.WriteByte(';')
return sqlConn.ExecContext(ctx, stmt.String(), args...)
}
{{ end }}

{{ if eq driver "postgres" -}}
{{- /* postgres */ -}}
Expand Down Expand Up @@ -348,14 +441,14 @@ func Upsert[T sequel.KeyValuer[T], Ptr sequel.Scanner[T]](ctx context.Context, s
}
placeholder := ",(" + strings.Repeat(",?", noOfCols)[1:] + ")"
for i := 0; i < noOfData; i++ {
values := data[i].Values()
values = append(values[:idx], values[idx+1:]...)
args = append(args, values...)
if i > 0 {
stmt.WriteString(placeholder)
} else {
stmt.WriteString(placeholder[1:])
}
values := data[i].Values()
values = append(values[:idx], values[idx+1:]...)
args = append(args, values...)
}
case sequel.PrimaryKeyer:
pkName, _, _ := v.PK()
Expand All @@ -367,12 +460,12 @@ func Upsert[T sequel.KeyValuer[T], Ptr sequel.Scanner[T]](ctx context.Context, s
}
placeholder := ",(" + strings.Repeat(",?", noOfCols)[1:] + ")"
for i := 0; i < noOfData; i++ {
args = append(args, data[i].Values()...)
if i > 0 {
stmt.WriteString(placeholder)
} else {
stmt.WriteString(placeholder[1:])
}
args = append(args, data[i].Values()...)
}
case sequel.CompositeKeyer:
if override {
Expand All @@ -382,12 +475,12 @@ func Upsert[T sequel.KeyValuer[T], Ptr sequel.Scanner[T]](ctx context.Context, s
}
placeholder := ",(" + strings.Repeat(",?", noOfCols)[1:] + ")"
for i := 0; i < noOfData; i++ {
args = append(args, data[i].Values()...)
if i > 0 {
stmt.WriteString(placeholder)
} else {
stmt.WriteString(placeholder[1:])
}
args = append(args, data[i].Values()...)
}
_, idxs, _ := v.CompositeKey()
// Exclude primary key, don't update it
Expand All @@ -397,7 +490,7 @@ func Upsert[T sequel.KeyValuer[T], Ptr sequel.Scanner[T]](ctx context.Context, s
}
if override {
stmt.WriteString(" ON DUPLICATE KEY UPDATE ")
/* don't update primary key when we do upsert */
// Don't update primary key when we do upsert
omitDict := map[string]struct{}{}
for i := range omittedFields {
omitDict[omittedFields[i]] = struct{}{}
Expand Down
4 changes: 4 additions & 0 deletions codegen/templates/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ type Model struct {
// HasRow bool
}

func (m Model) IsCompositeKey() bool {
return len(m.Keys) > 1
}

func (m Model) HasNotOnlyPK() bool {
for i := range m.Fields {
if !lo.Contains(m.Keys, m.Fields[i]) {
Expand Down
8 changes: 7 additions & 1 deletion codegen/templates/model.go.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,15 @@ func ({{ $structName }}) HasPK() {}
{{ if .IsAutoIncr -}}
func ({{ $structName }}) IsAutoIncr() {}
{{ end -}}
func (v {{ $structName }}) PK() ([]string, []int, []any) {
{{ if .IsCompositeKey -}}
func (v {{ $structName }}) CompositeKey() ([]string, []int, []any) {
return {{ `[]string{` }}{{ range $i, $f := .Keys }}{{- if $i }}{{ ", " }}{{ end }}{{ quote (quoteIdentifier $f.ColumnName) }}{{ end }}{{- `},` }}{{ `[]int{` }}{{ range $i, $f := .Keys }}{{- if $i }}{{ ", " }}{{ end }}{{ $f.Index }}{{ end }}{{- `},` }}{{ `[]any{` }}{{ range $i, $f := .Keys }}{{- if $i }}{{ ", " }}{{ end }}{{ castAs $f }}{{ end }}{{- `}` }}
}
{{ else -}}
func (v {{ $structName }}) PK() (string, int, any) {
return {{ quote (quoteIdentifier (index .Keys 0).ColumnName) }}{{ ", " }}{{ (index .Keys 0).Index }}{{ ", " }}{{ castAs (index .Keys 0) }}
}
{{ end -}}
{{ if (and (not $hasCustomTabler) ($hasNotOnlyPK)) -}}
{{- /* If it has static table and columns other than key */ -}}
func ({{ $structName }}) FindByPKStmt() string {
Expand Down
Loading

0 comments on commit 327b9db

Please sign in to comment.