diff --git a/codegen/codegen.go b/codegen/codegen.go index ab4c3f3..bff1ea4 100644 --- a/codegen/codegen.go +++ b/codegen/codegen.go @@ -136,10 +136,14 @@ func (b tableInfo) Keys() []sequel.ColumnSchema { return b.keys } -func (b tableInfo) Columns() []sequel.ColumnSchema { +func (b tableInfo) Indexes() []sequel.IndexSchema { return nil } +func (b tableInfo) Columns() []sequel.ColumnSchema { + return b.columns +} + func (b tableInfo) Implements(T *types.Interface) (*types.Func, bool) { return types.MissingMethod(b.t, T, true) } @@ -163,7 +167,7 @@ var ( _ (sequel.ColumnSchema) = (*columnInfo)(nil) ) -func (i columnInfo) SQLValuer() sequel.SQLFunc { +func (i columnInfo) SQLValuer() sequel.QueryFunc { if i.model == nil { return nil } @@ -172,7 +176,7 @@ func (i columnInfo) SQLValuer() sequel.SQLFunc { } } -func (i columnInfo) SQLScanner() sequel.SQLFunc { +func (i columnInfo) SQLScanner() sequel.QueryFunc { if i.model == nil { return nil } @@ -326,7 +330,6 @@ func Generate(c *config.Config) error { // true, "", cfg.Database.Package, - cfg.Getter.Prefix, cfg.Database.Dir, cfg.Database.Filename, ); err != nil { @@ -341,7 +344,6 @@ func Generate(c *config.Config) error { "operator.go.tpl", "", cfg.Database.Operator.Package, - cfg.Getter.Prefix, cfg.Database.Operator.Dir, cfg.Database.Operator.Filename, ); err != nil { diff --git a/codegen/init.go b/codegen/init.go index 5cf8854..8a9f9bd 100644 --- a/codegen/init.go +++ b/codegen/init.go @@ -37,7 +37,6 @@ func renderTemplate( tmplName string, pkgPath string, pkgName string, - getter string, dstDir string, dstFilename string, ) error { diff --git a/codegen/templates/db.go.tpl b/codegen/templates/db.go.tpl index 0b0ba7d..716e4df 100644 --- a/codegen/templates/db.go.tpl +++ b/codegen/templates/db.go.tpl @@ -8,13 +8,13 @@ type autoIncrKeyInserter interface { sequel.AutoIncrKeyer - sequel.SingleInserter + sequel.PrimaryKeyer } {{ if eq driver "postgres" -}} {{- /* postgres */ -}} -func InsertOne[T sequel.TableColumnValuer[T], Ptr interface { - sequel.TableColumnValuer[T] +func InsertOne[T sequel.TableColumnValuer, Ptr interface { + sequel.TableColumnValuer sequel.PtrScanner[T] }](ctx context.Context, sqlConn sequel.DB, model Ptr) error { switch v := any(model).(type) { @@ -40,8 +40,8 @@ func InsertOne[T sequel.TableColumnValuer[T], Ptr interface { } } {{ else }} -func InsertOne[T sequel.TableColumnValuer[T], Ptr interface { - sequel.TableColumnValuer[T] +func InsertOne[T sequel.TableColumnValuer, Ptr interface { + sequel.TableColumnValuer sequel.PtrScanner[T] }](ctx context.Context, sqlConn sequel.DB, model Ptr) (sql.Result, error) { switch v := any(model).(type) { @@ -80,7 +80,7 @@ func InsertOne[T sequel.TableColumnValuer[T], Ptr interface { {{ if eq driver "postgres" -}} {{- /* postgres */ -}} // Insert is a helper function to insert multiple records. -func Insert[T sequel.TableColumnValuer[T], Ptr sequel.PtrScanner[T]](ctx context.Context, sqlConn sequel.DB, data []T) (sql.Result, error) { +func Insert[T sequel.TableColumnValuer, Ptr sequel.PtrScanner[T]](ctx context.Context, sqlConn sequel.DB, data []T) (sql.Result, error) { noOfData := len(data) if noOfData == 0 { return new(sequel.EmptyResult), nil @@ -176,7 +176,7 @@ func Insert[T sequel.TableColumnValuer[T], Ptr sequel.PtrScanner[T]](ctx context } {{ else }} // Insert is a helper function to insert multiple records. -func Insert[T sequel.TableColumnValuer[T]](ctx context.Context, sqlConn sequel.DB, data []T) (sql.Result, error) { +func Insert[T sequel.TableColumnValuer](ctx context.Context, sqlConn sequel.DB, data []T) (sql.Result, error) { noOfData := len(data) if noOfData == 0 { return new(sequel.EmptyResult), nil @@ -995,31 +995,36 @@ type sqlStmt struct { args []any } -func (s *sqlStmt) Var(query string, value any) { +var ( + _ sequel.Stmt = (*sqlStmt)(nil) +) + +func (s *sqlStmt) Var(value any) string { s.pos++ + s.args = append(s.args, value) {{ if isStaticVar -}} - s.WriteString(query+"?") + return {{ quote varRune }} {{ else -}} - s.WriteString(query+wrapVar(s.pos)) + return wrapVar(s.pos) {{ end -}} - s.args = append(s.args, value) } -func (s *sqlStmt) Vars(query string, values []any) { - s.WriteString(query) +func (s *sqlStmt) Vars(values []any) string { noOfLen := len(values) {{ if isStaticVar -}} - s.WriteString("(" + strings.Repeat(",?", noOfLen)[1:] + ")") + s.args = append(s.args, values...) + return "(" + strings.Repeat(",{{ varRune }}", noOfLen)[1:] + ")" {{ else -}} - s.WriteByte('(') + buf := new(strings.Builder) + buf.WriteByte('(') i := s.pos s.pos += noOfLen for ; i < s.pos; i++ { - s.WriteString(wrapVar(i + 1)) + buf.WriteString(wrapVar(i + 1)) } - s.WriteByte(')') + buf.WriteByte(')') + return buf.String() {{ end -}} - s.args = append(s.args, values...) } func (s sqlStmt) Args() []any { @@ -1039,13 +1044,6 @@ func DbTable[T sequel.Tabler](model T) string { return model.TableName() } -func dbName(model any) string { - if v, ok := model.(sequel.Databaser); ok { - return v.DatabaseName() + "." - } - return "" -} - func Columns[T sequel.Columner](model T) []string { if v, ok := any(model).(sequel.SQLColumner); ok { return v.SQLColumns() @@ -1053,6 +1051,13 @@ func Columns[T sequel.Columner](model T) []string { return model.ColumnNames() } +func dbName(model any) string { + if v, ok := model.(sequel.Databaser); ok { + return v.DatabaseName() + "." + } + return "" +} + {{ if not isStaticVar -}} func wrapVar(i int) string { return {{ quote varRune }}+ strconv.Itoa(i) diff --git a/codegen/templates/operator.go.tpl b/codegen/templates/operator.go.tpl index 3cf177f..1513041 100644 --- a/codegen/templates/operator.go.tpl +++ b/codegen/templates/operator.go.tpl @@ -1,4 +1,5 @@ {{- reserveImport "github.com/si3nloong/sqlgen/sequel" }} + func And(stmts ...sequel.WhereClause) sequel.WhereClause { return func(stmt sequel.StmtBuilder) { stmt.WriteByte('(') @@ -27,13 +28,13 @@ func Or(stmts ...sequel.WhereClause) sequel.WhereClause { func Equal[T comparable](f sequel.ColumnValuer[T], value T) sequel.WhereClause { return func(stmt sequel.StmtBuilder) { - stmt.Var(f.ColumnName()+" = ", f.Convert(value)) + stmt.WriteString(f.ColumnName() + " = " + stmt.Var(f.Convert(value))) } } func NotEqual[T comparable](f sequel.ColumnValuer[T], value T) sequel.WhereClause { return func(stmt sequel.StmtBuilder) { - stmt.Var(f.ColumnName()+" <> ", f.Convert(value)) + stmt.WriteString(f.ColumnName() + " <> " + stmt.Var(f.Convert(value))) } } @@ -43,7 +44,7 @@ func In[T any](f sequel.ColumnValuer[T], values ...T) sequel.WhereClause { for idx := range values { args[idx] = f.Convert(values[idx]) } - stmt.Vars(f.ColumnName()+" IN ", args) + stmt.WriteString(f.ColumnName() + " IN " + stmt.Vars(args)) } } @@ -53,43 +54,43 @@ func NotIn[T any](f sequel.ColumnValuer[T], values ...T) sequel.WhereClause { for idx := range values { args[idx] = f.Convert(values[idx]) } - stmt.Vars(f.ColumnName()+" NOT IN ", args) + stmt.WriteString(f.ColumnName() + " NOT IN " + stmt.Vars(args)) } } func GreaterThan[T comparable](f sequel.ColumnValuer[T], value T) sequel.WhereClause { return func(stmt sequel.StmtBuilder) { - stmt.Var(f.ColumnName()+" > ", f.Convert(value)) + stmt.WriteString(f.ColumnName() + " > " + stmt.Var(f.Convert(value))) } } func GreaterThanOrEqual[T comparable](f sequel.ColumnValuer[T], value T) sequel.WhereClause { return func(stmt sequel.StmtBuilder) { - stmt.Var(f.ColumnName()+" >= ", f.Convert(value)) + stmt.WriteString(f.ColumnName() + " >= " + stmt.Var(f.Convert(value))) } } func LessThan[T comparable](f sequel.ColumnValuer[T], value T) sequel.WhereClause { return func(stmt sequel.StmtBuilder) { - stmt.Var(f.ColumnName()+" < ", f.Convert(value)) + stmt.WriteString(f.ColumnName() + " < " + stmt.Var(f.Convert(value))) } } func LessThanOrEqual[T comparable](f sequel.ColumnValuer[T], value T) sequel.WhereClause { return func(stmt sequel.StmtBuilder) { - stmt.Var(f.ColumnName()+" <= ", f.Convert(value)) + stmt.WriteString(f.ColumnName() + " <= " + stmt.Var(f.Convert(value))) } } func Like[T comparable](f sequel.ColumnValuer[T], value T) sequel.WhereClause { return func(stmt sequel.StmtBuilder) { - stmt.Var(f.ColumnName()+" LIKE ", f.Convert(value)) + stmt.WriteString(f.ColumnName() + " LIKE " + stmt.Var(f.Convert(value))) } } func NotLike[T comparable](f sequel.ColumnValuer[T], value T) sequel.WhereClause { return func(stmt sequel.StmtBuilder) { - stmt.Var(f.ColumnName()+" NOT LIKE ", f.Convert(value)) + stmt.WriteString(f.ColumnName() + " NOT LIKE " + stmt.Var(f.Convert(value))) } } @@ -107,8 +108,7 @@ func IsNotNull[T any](f sequel.ColumnValuer[T]) sequel.WhereClause { func Between[T comparable](f sequel.ColumnValuer[T], from, to T) sequel.WhereClause { return func(stmt sequel.StmtBuilder) { - stmt.Var(f.ColumnName()+" BETWEEN ", from) - stmt.Var(" AND ", to) + stmt.WriteString(f.ColumnName() + " BETWEEN " + stmt.Var(from) + " AND " + stmt.Var(to)) } } @@ -118,7 +118,7 @@ func Set[T any](f sequel.ColumnValuer[T], value ...T) sequel.SetClause { if len(value) > 0 { defaultValue = f.Convert(value[0]) } - stmt.Var(f.ColumnName()+" = ", defaultValue) + stmt.WriteString(f.ColumnName() + " = " + stmt.Var(defaultValue)) } } @@ -132,4 +132,4 @@ func Desc[T any](f sequel.ColumnValuer[T]) sequel.OrderByClause { return func(sw sequel.StmtWriter) { sw.WriteString(f.ColumnName() + " DESC") } -} \ No newline at end of file +} diff --git a/examples/db/mysql/db.go b/examples/db/mysql/db.go index b877221..86cc758 100755 --- a/examples/db/mysql/db.go +++ b/examples/db/mysql/db.go @@ -13,13 +13,8 @@ import ( "github.com/si3nloong/sqlgen/sequel/strpool" ) -type autoIncrKeyInserter interface { - sequel.AutoIncrKeyer - sequel.SingleInserter -} - -func InsertOne[T sequel.TableColumnValuer[T], Ptr interface { - sequel.TableColumnValuer[T] +func InsertOne[T sequel.TableColumnValuer, Ptr interface { + sequel.TableColumnValuer sequel.PtrScanner[T] }](ctx context.Context, sqlConn sequel.DB, model Ptr) (sql.Result, error) { switch v := any(model).(type) { @@ -55,7 +50,7 @@ func InsertOne[T sequel.TableColumnValuer[T], Ptr interface { } // Insert is a helper function to insert multiple records. -func Insert[T sequel.TableColumnValuer[T]](ctx context.Context, sqlConn sequel.DB, data []T) (sql.Result, error) { +func Insert[T sequel.TableColumnValuer](ctx context.Context, sqlConn sequel.DB, data []T) (sql.Result, error) { noOfData := len(data) if noOfData == 0 { return new(sequel.EmptyResult), nil @@ -565,17 +560,20 @@ type sqlStmt struct { args []any } -func (s *sqlStmt) Var(query string, value any) { +var ( + _ sequel.Stmt = (*sqlStmt)(nil) +) + +func (s *sqlStmt) Var(value any) string { s.pos++ - s.WriteString(query + "?") s.args = append(s.args, value) + return `?` } -func (s *sqlStmt) Vars(query string, values []any) { - s.WriteString(query) +func (s *sqlStmt) Vars(values []any) string { noOfLen := len(values) - s.WriteString("(" + strings.Repeat(",?", noOfLen)[1:] + ")") s.args = append(s.args, values...) + return "(" + strings.Repeat(",?", noOfLen)[1:] + ")" } func (s sqlStmt) Args() []any { @@ -595,16 +593,16 @@ func DbTable[T sequel.Tabler](model T) string { return model.TableName() } -func dbName(model any) string { - if v, ok := model.(sequel.Databaser); ok { - return v.DatabaseName() + "." - } - return "" -} - func Columns[T sequel.Columner](model T) []string { if v, ok := any(model).(sequel.SQLColumner); ok { return v.SQLColumns() } return model.ColumnNames() } + +func dbName(model any) string { + if v, ok := model.(sequel.Databaser); ok { + return v.DatabaseName() + "." + } + return "" +} diff --git a/examples/db/mysql/migrate.go b/examples/db/mysql/migrate.go new file mode 100644 index 0000000..6515318 --- /dev/null +++ b/examples/db/mysql/migrate.go @@ -0,0 +1,159 @@ +package mysqldb + +import ( + "context" + "strings" + "sync" + + "github.com/si3nloong/sqlgen/sequel" + "github.com/si3nloong/sqlgen/sequel/strpool" +) + +var ( + mu sync.RWMutex +) + +type Column struct { + Name string + Type string + DataType string + Size int + Pos int + IsNullable bool + Default any + Comment string +} + +type Index struct { + Name string + Type string + Nullable bool +} + +func TableExists(ctx context.Context, sqlConn sequel.DB, dbName, tableName string) (bool, error) { + var count int64 + if err := sqlConn.QueryRowContext(ctx, `SELECT COUNT(*) FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?;`, dbName, tableName).Scan(&count); err != nil { + return false, err + } + return count > 0, nil +} + +func UnsafeMigrate[T interface { + sequel.Tabler + sequel.Migrator +}](ctx context.Context, sqlConn sequel.DB, dbName string) error { + mu.Lock() + defer mu.Unlock() + + var v T + tableName := v.TableName() + exists, err := TableExists(ctx, sqlConn, dbName, tableName) + if err != nil { + return err + } + def := v.Schemas() + stmt := strpool.AcquireString() + defer strpool.ReleaseString(stmt) + // If the table exists, we will use alter table + if exists { + // Alter table need to check primary key, foreign key and indexes + stmt.WriteString("ALTER TABLE " + tableName + " (") + colDict := make(map[string]Column) + if err := tableColumns(ctx, sqlConn, dbName, tableName, func(c Column, _ int) { + colDict[c.Name] = c + }); err != nil { + return err + } + for i, col := range def.Columns { + if _, ok := colDict[col.Name]; !ok { + stmt.WriteString(",ADD COLUMN " + col.Definition) + } else { + stmt.WriteString(",MODIFY COLUMN " + col.Definition) + } + if i > 0 { + stmt.WriteString(" FIRST") + } else { + stmt.WriteString(" AFTER " + def.Columns[i-1].Name) + } + } + clear(colDict) + idxDict := make(map[string]Index) + if err := tableIndexes(ctx, sqlConn, dbName, tableName, func(idx Index, _ int) { + idxDict[idx.Name] = idx + }); err != nil { + return err + } + for _, idx := range def.Indexes { + if _, ok := idxDict[idx.Name]; !ok { + stmt.WriteString(",DROP INDEX " + idx.Name) + } else { + stmt.WriteString(",ADD " + idx.Definition) + } + } + clear(idxDict) + stmt.WriteString(") ENGINE=INNODB CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;") + } else { + stmt.WriteString("CREATE TABLE IF NOT EXISTS " + tableName + " (") + for i, col := range def.Columns { + if i > 0 { + stmt.WriteString("," + col.Definition + " AFTER " + def.Columns[i-1].Name) + } else { + stmt.WriteString(col.Definition + " FIRST") + } + } + switch vi := any(v).(type) { + case sequel.PrimaryKeyer: + pkName, _, _ := vi.PK() + stmt.WriteString(",PRIMARY KEY(" + pkName + ")") + case sequel.CompositeKeyer: + keys, _, _ := vi.CompositeKey() + stmt.WriteString(",PRIMARY KEY(" + strings.Join(keys, ",") + ")") + } + for _, idx := range def.Indexes { + stmt.WriteString(",ADD " + idx.Definition) + } + stmt.WriteString(") ENGINE=INNODB CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;") + } + if _, err := sqlConn.ExecContext(ctx, stmt.String()); err != nil { + return err + } + return nil +} + +func tableColumns(ctx context.Context, sqlConn sequel.DB, dbName, tableName string, reduceFunc func(Column, int)) error { + rows, err := sqlConn.QueryContext(ctx, `SELECT ORDINAL_POSITION, COLUMN_NAME, COLUMN_TYPE, COLUMN_DEFAULT, IS_NULLABLE, DATA_TYPE, COLUMN_COMMENT FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ? ORDER BY ORDINAL_POSITION;`, dbName, tableName) + if err != nil { + return err + } + defer rows.Close() + + var i int + for rows.Next() { + var column Column + if err := rows.Scan(&column.Pos, &column.Name, &column.Type, &column.Default, &column.IsNullable, &column.DataType, &column.Comment); err != nil { + return err + } + reduceFunc(column, i) + i++ + } + return rows.Close() +} + +func tableIndexes(ctx context.Context, sqlConn sequel.DB, dbName, tableName string, reduceFunc func(Index, int)) error { + rows, err := sqlConn.QueryContext(ctx, `SELECT DISTINCT INDEX_NAME, INDEX_TYPE, NON_UNIQUE FROM INFORMATION_SCHEMA.STATISTICS WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?;`, dbName, tableName) + if err != nil { + return err + } + defer rows.Close() + + var i int + for rows.Next() { + var index Index + if err := rows.Scan(&index.Name, &index.Type, &index.Nullable); err != nil { + return err + } + reduceFunc(index, i) + i++ + } + return rows.Close() +} diff --git a/examples/db/mysql/operator.go b/examples/db/mysql/operator.go index 5f28e62..0acb1ce 100755 --- a/examples/db/mysql/operator.go +++ b/examples/db/mysql/operator.go @@ -32,13 +32,13 @@ func Or(stmts ...sequel.WhereClause) sequel.WhereClause { func Equal[T comparable](f sequel.ColumnValuer[T], value T) sequel.WhereClause { return func(stmt sequel.StmtBuilder) { - stmt.Var(f.ColumnName()+" = ", f.Convert(value)) + stmt.WriteString(f.ColumnName() + " = " + stmt.Var(f.Convert(value))) } } func NotEqual[T comparable](f sequel.ColumnValuer[T], value T) sequel.WhereClause { return func(stmt sequel.StmtBuilder) { - stmt.Var(f.ColumnName()+" <> ", f.Convert(value)) + stmt.WriteString(f.ColumnName() + " <> " + stmt.Var(f.Convert(value))) } } @@ -48,7 +48,7 @@ func In[T any](f sequel.ColumnValuer[T], values ...T) sequel.WhereClause { for idx := range values { args[idx] = f.Convert(values[idx]) } - stmt.Vars(f.ColumnName()+" IN ", args) + stmt.WriteString(f.ColumnName() + " IN " + stmt.Vars(args)) } } @@ -58,43 +58,43 @@ func NotIn[T any](f sequel.ColumnValuer[T], values ...T) sequel.WhereClause { for idx := range values { args[idx] = f.Convert(values[idx]) } - stmt.Vars(f.ColumnName()+" NOT IN ", args) + stmt.WriteString(f.ColumnName() + " NOT IN " + stmt.Vars(args)) } } func GreaterThan[T comparable](f sequel.ColumnValuer[T], value T) sequel.WhereClause { return func(stmt sequel.StmtBuilder) { - stmt.Var(f.ColumnName()+" > ", f.Convert(value)) + stmt.WriteString(f.ColumnName() + " > " + stmt.Var(f.Convert(value))) } } func GreaterThanOrEqual[T comparable](f sequel.ColumnValuer[T], value T) sequel.WhereClause { return func(stmt sequel.StmtBuilder) { - stmt.Var(f.ColumnName()+" >= ", f.Convert(value)) + stmt.WriteString(f.ColumnName() + " >= " + stmt.Var(f.Convert(value))) } } func LessThan[T comparable](f sequel.ColumnValuer[T], value T) sequel.WhereClause { return func(stmt sequel.StmtBuilder) { - stmt.Var(f.ColumnName()+" < ", f.Convert(value)) + stmt.WriteString(f.ColumnName() + " < " + stmt.Var(f.Convert(value))) } } func LessThanOrEqual[T comparable](f sequel.ColumnValuer[T], value T) sequel.WhereClause { return func(stmt sequel.StmtBuilder) { - stmt.Var(f.ColumnName()+" <= ", f.Convert(value)) + stmt.WriteString(f.ColumnName() + " <= " + stmt.Var(f.Convert(value))) } } func Like[T comparable](f sequel.ColumnValuer[T], value T) sequel.WhereClause { return func(stmt sequel.StmtBuilder) { - stmt.Var(f.ColumnName()+" LIKE ", f.Convert(value)) + stmt.WriteString(f.ColumnName() + " LIKE " + stmt.Var(f.Convert(value))) } } func NotLike[T comparable](f sequel.ColumnValuer[T], value T) sequel.WhereClause { return func(stmt sequel.StmtBuilder) { - stmt.Var(f.ColumnName()+" NOT LIKE ", f.Convert(value)) + stmt.WriteString(f.ColumnName() + " NOT LIKE " + stmt.Var(f.Convert(value))) } } @@ -112,8 +112,7 @@ func IsNotNull[T any](f sequel.ColumnValuer[T]) sequel.WhereClause { func Between[T comparable](f sequel.ColumnValuer[T], from, to T) sequel.WhereClause { return func(stmt sequel.StmtBuilder) { - stmt.Var(f.ColumnName()+" BETWEEN ", from) - stmt.Var(" AND ", to) + stmt.WriteString(f.ColumnName() + " BETWEEN " + stmt.Var(from) + " AND " + stmt.Var(to)) } } @@ -123,7 +122,7 @@ func Set[T any](f sequel.ColumnValuer[T], value ...T) sequel.SetClause { if len(value) > 0 { defaultValue = f.Convert(value[0]) } - stmt.Var(f.ColumnName()+" = ", defaultValue) + stmt.WriteString(f.ColumnName() + " = " + stmt.Var(defaultValue)) } } diff --git a/examples/db/postgres/db.go b/examples/db/postgres/db.go index d7a9718..90dc111 100755 --- a/examples/db/postgres/db.go +++ b/examples/db/postgres/db.go @@ -12,13 +12,8 @@ import ( "github.com/si3nloong/sqlgen/sequel/strpool" ) -type autoIncrKeyInserter interface { - sequel.AutoIncrKeyer - sequel.SingleInserter -} - -func InsertOne[T sequel.TableColumnValuer[T], Ptr interface { - sequel.TableColumnValuer[T] +func InsertOne[T sequel.TableColumnValuer, Ptr interface { + sequel.TableColumnValuer sequel.PtrScanner[T] }](ctx context.Context, sqlConn sequel.DB, model Ptr) error { switch v := any(model).(type) { @@ -45,7 +40,7 @@ func InsertOne[T sequel.TableColumnValuer[T], Ptr interface { } // Insert is a helper function to insert multiple records. -func Insert[T sequel.TableColumnValuer[T], Ptr sequel.PtrScanner[T]](ctx context.Context, sqlConn sequel.DB, data []T) (sql.Result, error) { +func Insert[T sequel.TableColumnValuer, Ptr sequel.PtrScanner[T]](ctx context.Context, sqlConn sequel.DB, data []T) (sql.Result, error) { noOfData := len(data) if noOfData == 0 { return new(sequel.EmptyResult), nil @@ -704,23 +699,27 @@ type sqlStmt struct { args []any } -func (s *sqlStmt) Var(query string, value any) { +var ( + _ sequel.Stmt = (*sqlStmt)(nil) +) + +func (s *sqlStmt) Var(value any) string { s.pos++ - s.WriteString(query + wrapVar(s.pos)) s.args = append(s.args, value) + return wrapVar(s.pos) } -func (s *sqlStmt) Vars(query string, values []any) { - s.WriteString(query) +func (s *sqlStmt) Vars(values []any) string { noOfLen := len(values) - s.WriteByte('(') + buf := new(strings.Builder) + buf.WriteByte('(') i := s.pos s.pos += noOfLen for ; i < s.pos; i++ { - s.WriteString(wrapVar(i + 1)) + buf.WriteString(wrapVar(i + 1)) } - s.WriteByte(')') - s.args = append(s.args, values...) + buf.WriteByte(')') + return buf.String() } func (s sqlStmt) Args() []any { @@ -740,13 +739,6 @@ func DbTable[T sequel.Tabler](model T) string { return model.TableName() } -func dbName(model any) string { - if v, ok := model.(sequel.Databaser); ok { - return v.DatabaseName() + "." - } - return "" -} - func Columns[T sequel.Columner](model T) []string { if v, ok := any(model).(sequel.SQLColumner); ok { return v.SQLColumns() @@ -754,6 +746,13 @@ func Columns[T sequel.Columner](model T) []string { return model.ColumnNames() } +func dbName(model any) string { + if v, ok := model.(sequel.Databaser); ok { + return v.DatabaseName() + "." + } + return "" +} + func wrapVar(i int) string { return "$" + strconv.Itoa(i) } diff --git a/examples/db/postgres/operator.go b/examples/db/postgres/operator.go index c622cac..5412140 100755 --- a/examples/db/postgres/operator.go +++ b/examples/db/postgres/operator.go @@ -32,13 +32,13 @@ func Or(stmts ...sequel.WhereClause) sequel.WhereClause { func Equal[T comparable](f sequel.ColumnValuer[T], value T) sequel.WhereClause { return func(stmt sequel.StmtBuilder) { - stmt.Var(f.ColumnName()+" = ", f.Convert(value)) + stmt.WriteString(f.ColumnName() + " = " + stmt.Var(f.Convert(value))) } } func NotEqual[T comparable](f sequel.ColumnValuer[T], value T) sequel.WhereClause { return func(stmt sequel.StmtBuilder) { - stmt.Var(f.ColumnName()+" <> ", f.Convert(value)) + stmt.WriteString(f.ColumnName() + " <> " + stmt.Var(f.Convert(value))) } } @@ -48,7 +48,7 @@ func In[T any](f sequel.ColumnValuer[T], values ...T) sequel.WhereClause { for idx := range values { args[idx] = f.Convert(values[idx]) } - stmt.Vars(f.ColumnName()+" IN ", args) + stmt.WriteString(f.ColumnName() + " IN " + stmt.Vars(args)) } } @@ -58,43 +58,43 @@ func NotIn[T any](f sequel.ColumnValuer[T], values ...T) sequel.WhereClause { for idx := range values { args[idx] = f.Convert(values[idx]) } - stmt.Vars(f.ColumnName()+" NOT IN ", args) + stmt.WriteString(f.ColumnName() + " NOT IN " + stmt.Vars(args)) } } func GreaterThan[T comparable](f sequel.ColumnValuer[T], value T) sequel.WhereClause { return func(stmt sequel.StmtBuilder) { - stmt.Var(f.ColumnName()+" > ", f.Convert(value)) + stmt.WriteString(f.ColumnName() + " > " + stmt.Var(f.Convert(value))) } } func GreaterThanOrEqual[T comparable](f sequel.ColumnValuer[T], value T) sequel.WhereClause { return func(stmt sequel.StmtBuilder) { - stmt.Var(f.ColumnName()+" >= ", f.Convert(value)) + stmt.WriteString(f.ColumnName() + " >= " + stmt.Var(f.Convert(value))) } } func LessThan[T comparable](f sequel.ColumnValuer[T], value T) sequel.WhereClause { return func(stmt sequel.StmtBuilder) { - stmt.Var(f.ColumnName()+" < ", f.Convert(value)) + stmt.WriteString(f.ColumnName() + " < " + stmt.Var(f.Convert(value))) } } func LessThanOrEqual[T comparable](f sequel.ColumnValuer[T], value T) sequel.WhereClause { return func(stmt sequel.StmtBuilder) { - stmt.Var(f.ColumnName()+" <= ", f.Convert(value)) + stmt.WriteString(f.ColumnName() + " <= " + stmt.Var(f.Convert(value))) } } func Like[T comparable](f sequel.ColumnValuer[T], value T) sequel.WhereClause { return func(stmt sequel.StmtBuilder) { - stmt.Var(f.ColumnName()+" LIKE ", f.Convert(value)) + stmt.WriteString(f.ColumnName() + " LIKE " + stmt.Var(f.Convert(value))) } } func NotLike[T comparable](f sequel.ColumnValuer[T], value T) sequel.WhereClause { return func(stmt sequel.StmtBuilder) { - stmt.Var(f.ColumnName()+" NOT LIKE ", f.Convert(value)) + stmt.WriteString(f.ColumnName() + " NOT LIKE " + stmt.Var(f.Convert(value))) } } @@ -112,8 +112,7 @@ func IsNotNull[T any](f sequel.ColumnValuer[T]) sequel.WhereClause { func Between[T comparable](f sequel.ColumnValuer[T], from, to T) sequel.WhereClause { return func(stmt sequel.StmtBuilder) { - stmt.Var(f.ColumnName()+" BETWEEN ", from) - stmt.Var(" AND ", to) + stmt.WriteString(f.ColumnName() + " BETWEEN " + stmt.Var(from) + " AND " + stmt.Var(to)) } } @@ -123,7 +122,7 @@ func Set[T any](f sequel.ColumnValuer[T], value ...T) sequel.SetClause { if len(value) > 0 { defaultValue = f.Convert(value[0]) } - stmt.Var(f.ColumnName()+" = ", defaultValue) + stmt.WriteString(f.ColumnName() + " = " + stmt.Var(defaultValue)) } } diff --git a/examples/db/sqlite/db.go b/examples/db/sqlite/db.go index c8e3a03..647d8d7 100755 --- a/examples/db/sqlite/db.go +++ b/examples/db/sqlite/db.go @@ -13,13 +13,8 @@ import ( "github.com/si3nloong/sqlgen/sequel/strpool" ) -type autoIncrKeyInserter interface { - sequel.AutoIncrKeyer - sequel.SingleInserter -} - -func InsertOne[T sequel.TableColumnValuer[T], Ptr interface { - sequel.TableColumnValuer[T] +func InsertOne[T sequel.TableColumnValuer, Ptr interface { + sequel.TableColumnValuer sequel.PtrScanner[T] }](ctx context.Context, sqlConn sequel.DB, model Ptr) (sql.Result, error) { switch v := any(model).(type) { @@ -55,7 +50,7 @@ func InsertOne[T sequel.TableColumnValuer[T], Ptr interface { } // Insert is a helper function to insert multiple records. -func Insert[T sequel.TableColumnValuer[T]](ctx context.Context, sqlConn sequel.DB, data []T) (sql.Result, error) { +func Insert[T sequel.TableColumnValuer](ctx context.Context, sqlConn sequel.DB, data []T) (sql.Result, error) { noOfData := len(data) if noOfData == 0 { return new(sequel.EmptyResult), nil @@ -565,17 +560,20 @@ type sqlStmt struct { args []any } -func (s *sqlStmt) Var(query string, value any) { +var ( + _ sequel.Stmt = (*sqlStmt)(nil) +) + +func (s *sqlStmt) Var(value any) string { s.pos++ - s.WriteString(query + "?") s.args = append(s.args, value) + return "?" } -func (s *sqlStmt) Vars(query string, values []any) { - s.WriteString(query) +func (s *sqlStmt) Vars(values []any) string { noOfLen := len(values) - s.WriteString("(" + strings.Repeat(",?", noOfLen)[1:] + ")") s.args = append(s.args, values...) + return "(" + strings.Repeat(",?", noOfLen)[1:] + ")" } func (s sqlStmt) Args() []any { @@ -595,16 +593,16 @@ func DbTable[T sequel.Tabler](model T) string { return model.TableName() } -func dbName(model any) string { - if v, ok := model.(sequel.Databaser); ok { - return v.DatabaseName() + "." - } - return "" -} - func Columns[T sequel.Columner](model T) []string { if v, ok := any(model).(sequel.SQLColumner); ok { return v.SQLColumns() } return model.ColumnNames() } + +func dbName(model any) string { + if v, ok := model.(sequel.Databaser); ok { + return v.DatabaseName() + "." + } + return "" +} diff --git a/examples/db/sqlite/operator.go b/examples/db/sqlite/operator.go index 4cab677..5e5899e 100755 --- a/examples/db/sqlite/operator.go +++ b/examples/db/sqlite/operator.go @@ -32,13 +32,13 @@ func Or(stmts ...sequel.WhereClause) sequel.WhereClause { func Equal[T comparable](f sequel.ColumnValuer[T], value T) sequel.WhereClause { return func(stmt sequel.StmtBuilder) { - stmt.Var(f.ColumnName()+" = ", f.Convert(value)) + stmt.WriteString(f.ColumnName() + " = " + stmt.Var(f.Convert(value))) } } func NotEqual[T comparable](f sequel.ColumnValuer[T], value T) sequel.WhereClause { return func(stmt sequel.StmtBuilder) { - stmt.Var(f.ColumnName()+" <> ", f.Convert(value)) + stmt.WriteString(f.ColumnName() + " <> " + stmt.Var(f.Convert(value))) } } @@ -48,7 +48,7 @@ func In[T any](f sequel.ColumnValuer[T], values ...T) sequel.WhereClause { for idx := range values { args[idx] = f.Convert(values[idx]) } - stmt.Vars(f.ColumnName()+" IN ", args) + stmt.WriteString(f.ColumnName() + " IN " + stmt.Vars(args)) } } @@ -58,43 +58,43 @@ func NotIn[T any](f sequel.ColumnValuer[T], values ...T) sequel.WhereClause { for idx := range values { args[idx] = f.Convert(values[idx]) } - stmt.Vars(f.ColumnName()+" NOT IN ", args) + stmt.WriteString(f.ColumnName() + " NOT IN " + stmt.Vars(args)) } } func GreaterThan[T comparable](f sequel.ColumnValuer[T], value T) sequel.WhereClause { return func(stmt sequel.StmtBuilder) { - stmt.Var(f.ColumnName()+" > ", f.Convert(value)) + stmt.WriteString(f.ColumnName() + " > " + stmt.Var(f.Convert(value))) } } func GreaterThanOrEqual[T comparable](f sequel.ColumnValuer[T], value T) sequel.WhereClause { return func(stmt sequel.StmtBuilder) { - stmt.Var(f.ColumnName()+" >= ", f.Convert(value)) + stmt.WriteString(f.ColumnName() + " >= " + stmt.Var(f.Convert(value))) } } func LessThan[T comparable](f sequel.ColumnValuer[T], value T) sequel.WhereClause { return func(stmt sequel.StmtBuilder) { - stmt.Var(f.ColumnName()+" < ", f.Convert(value)) + stmt.WriteString(f.ColumnName() + " < " + stmt.Var(f.Convert(value))) } } func LessThanOrEqual[T comparable](f sequel.ColumnValuer[T], value T) sequel.WhereClause { return func(stmt sequel.StmtBuilder) { - stmt.Var(f.ColumnName()+" <= ", f.Convert(value)) + stmt.WriteString(f.ColumnName() + " <= " + stmt.Var(f.Convert(value))) } } func Like[T comparable](f sequel.ColumnValuer[T], value T) sequel.WhereClause { return func(stmt sequel.StmtBuilder) { - stmt.Var(f.ColumnName()+" LIKE ", f.Convert(value)) + stmt.WriteString(f.ColumnName() + " LIKE " + stmt.Var(f.Convert(value))) } } func NotLike[T comparable](f sequel.ColumnValuer[T], value T) sequel.WhereClause { return func(stmt sequel.StmtBuilder) { - stmt.Var(f.ColumnName()+" NOT LIKE ", f.Convert(value)) + stmt.WriteString(f.ColumnName() + " NOT LIKE " + stmt.Var(f.Convert(value))) } } @@ -112,8 +112,7 @@ func IsNotNull[T any](f sequel.ColumnValuer[T]) sequel.WhereClause { func Between[T comparable](f sequel.ColumnValuer[T], from, to T) sequel.WhereClause { return func(stmt sequel.StmtBuilder) { - stmt.Var(f.ColumnName()+" BETWEEN ", from) - stmt.Var(" AND ", to) + stmt.WriteString(f.ColumnName() + " BETWEEN " + stmt.Var(from) + " AND " + stmt.Var(to)) } } @@ -123,7 +122,7 @@ func Set[T any](f sequel.ColumnValuer[T], value ...T) sequel.SetClause { if len(value) > 0 { defaultValue = f.Convert(value[0]) } - stmt.Var(f.ColumnName()+" = ", defaultValue) + stmt.WriteString(f.ColumnName() + " = " + stmt.Var(defaultValue)) } } diff --git a/examples/testcase/main/generated.go b/examples/testcase/main/generated.go index e7ee91e..1fca8ca 100755 --- a/examples/testcase/main/generated.go +++ b/examples/testcase/main/generated.go @@ -7,6 +7,9 @@ import ( "github.com/si3nloong/sqlgen/sequel" ) +func (A) Schemas() sequel.TableDefinition { + return sequel.TableDefinition{} +} func (A) TableName() string { return "`a`" } diff --git a/examples/testcase/schema/custom-declare/generated.go b/examples/testcase/schema/custom-declare/generated.go index 4d8882e..f2d8809 100755 --- a/examples/testcase/schema/custom-declare/generated.go +++ b/examples/testcase/schema/custom-declare/generated.go @@ -6,6 +6,9 @@ import ( "github.com/si3nloong/sqlgen/sequel" ) +func (A) Schemas() sequel.TableDefinition { + return sequel.TableDefinition{} +} func (A) InsertPlaceholders(row int) string { return "(?)" } diff --git a/examples/testcase/schema/dynamic-table-name/generated.go b/examples/testcase/schema/dynamic-table-name/generated.go index 5d3e110..89fb23f 100755 --- a/examples/testcase/schema/dynamic-table-name/generated.go +++ b/examples/testcase/schema/dynamic-table-name/generated.go @@ -7,6 +7,9 @@ import ( "github.com/si3nloong/sqlgen/sequel/types" ) +func (Model) Schemas() sequel.TableDefinition { + return sequel.TableDefinition{} +} func (Model) ColumnNames() []string { return []string{"`name`"} } @@ -26,6 +29,9 @@ func (v Model) GetName() sequel.ColumnValuer[string] { return sequel.Column("`name`", v.Name, func(val string) driver.Value { return string(val) }) } +func (A) Schemas() sequel.TableDefinition { + return sequel.TableDefinition{} +} func (A) HasPK() {} func (v A) PK() (string, int, any) { return "`id`", 0, int64(v.ID) diff --git a/examples/testcase/schema/generated.go b/examples/testcase/schema/generated.go deleted file mode 100755 index 688ea34..0000000 --- a/examples/testcase/schema/generated.go +++ /dev/null @@ -1,121 +0,0 @@ -package schema - -import ( - "database/sql" - "database/sql/driver" - "time" - - "github.com/si3nloong/sqlgen/sequel" - "github.com/si3nloong/sqlgen/sequel/types" -) - -func (A) TableName() string { - return "`Apple`" -} -func (A) ColumnNames() []string { - return []string{"`id`", "`text`", "`created_at`"} -} -func (v A) Values() []any { - return []any{string(v.ID), string(v.Text), time.Time(v.CreatedAt)} -} -func (v *A) Addrs() []any { - return []any{types.String(&v.ID), types.String(&v.Text), (*time.Time)(&v.CreatedAt)} -} -func (A) InsertPlaceholders(row int) string { - return "(?,?,?)" -} -func (v A) InsertOneStmt() (string, []any) { - return "INSERT INTO `Apple` (`id`,`text`,`created_at`) VALUES (?,?,?);", v.Values() -} -func (v A) GetID() sequel.ColumnValuer[string] { - return sequel.Column("`id`", v.ID, func(val string) driver.Value { return string(val) }) -} -func (v A) GetText() sequel.ColumnValuer[LongText] { - return sequel.Column("`text`", v.Text, func(val LongText) driver.Value { return string(val) }) -} -func (v A) GetCreatedAt() sequel.ColumnValuer[time.Time] { - return sequel.Column("`created_at`", v.CreatedAt, func(val time.Time) driver.Value { return time.Time(val) }) -} - -func (B) TableName() string { - return "`b`" -} -func (B) ColumnNames() []string { - return []string{"`id`", "`created_at`"} -} -func (v B) Values() []any { - return []any{string(v.ID), time.Time(v.CreatedAt)} -} -func (v *B) Addrs() []any { - return []any{types.String(&v.ID), (*time.Time)(&v.CreatedAt)} -} -func (B) InsertPlaceholders(row int) string { - return "(?,?)" -} -func (v B) InsertOneStmt() (string, []any) { - return "INSERT INTO `b` (`id`,`created_at`) VALUES (?,?);", v.Values() -} -func (v B) GetID() sequel.ColumnValuer[string] { - return sequel.Column("`id`", v.ID, func(val string) driver.Value { return string(val) }) -} -func (v B) GetCreatedAt() sequel.ColumnValuer[time.Time] { - return sequel.Column("`created_at`", v.CreatedAt, func(val time.Time) driver.Value { return time.Time(val) }) -} - -func (C) TableName() string { - return "`c`" -} -func (C) HasPK() {} -func (v C) PK() (string, int, any) { - return "`id`", 0, int64(v.ID) -} -func (C) ColumnNames() []string { - return []string{"`id`"} -} -func (v C) Values() []any { - return []any{int64(v.ID)} -} -func (v *C) Addrs() []any { - return []any{types.Integer(&v.ID)} -} -func (C) InsertPlaceholders(row int) string { - return "(?)" -} -func (v C) InsertOneStmt() (string, []any) { - return "INSERT INTO `c` (`id`) VALUES (?);", v.Values() -} -func (v C) FindOneByPKStmt() (string, []any) { - return "SELECT `id` FROM `c` WHERE `id` = ? LIMIT 1;", []any{int64(v.ID)} -} -func (v C) GetID() sequel.ColumnValuer[int64] { - return sequel.Column("`id`", v.ID, func(val int64) driver.Value { return int64(val) }) -} - -func (D) TableName() string { - return "`d`" -} -func (D) HasPK() {} -func (v D) PK() (string, int, any) { - return "`id`", 0, (driver.Valuer)(v.ID) -} -func (D) ColumnNames() []string { - return []string{"`id`"} -} -func (v D) Values() []any { - return []any{(driver.Valuer)(v.ID)} -} -func (v *D) Addrs() []any { - return []any{(sql.Scanner)(&v.ID)} -} -func (D) InsertPlaceholders(row int) string { - return "(?)" -} -func (v D) InsertOneStmt() (string, []any) { - return "INSERT INTO `d` (`id`) VALUES (?);", v.Values() -} -func (v D) FindOneByPKStmt() (string, []any) { - return "SELECT `id` FROM `d` WHERE `id` = ? LIMIT 1;", []any{(driver.Valuer)(v.ID)} -} -func (v D) GetID() sequel.ColumnValuer[sql.NullString] { - return sequel.Column("`id`", v.ID, func(val sql.NullString) driver.Value { return (driver.Valuer)(val) }) -} diff --git a/examples/testcase/schema/nopk/generated.go b/examples/testcase/schema/nopk/generated.go index 4a736ff..bf3a9a0 100755 --- a/examples/testcase/schema/nopk/generated.go +++ b/examples/testcase/schema/nopk/generated.go @@ -7,6 +7,9 @@ import ( "github.com/si3nloong/sqlgen/sequel/types" ) +func (Customer) Schemas() sequel.TableDefinition { + return sequel.TableDefinition{} +} func (Customer) TableName() string { return "`customer`" } diff --git a/examples/testcase/schema/table-name/generated.go b/examples/testcase/schema/table-name/generated.go index 2407456..e2170f4 100755 --- a/examples/testcase/schema/table-name/generated.go +++ b/examples/testcase/schema/table-name/generated.go @@ -7,6 +7,9 @@ import ( "github.com/si3nloong/sqlgen/sequel/types" ) +func (CustomTableName1) Schemas() sequel.TableDefinition { + return sequel.TableDefinition{} +} func (CustomTableName1) TableName() string { return "`CustomTableName_1`" } @@ -29,6 +32,9 @@ func (v CustomTableName1) GetText() sequel.ColumnValuer[string] { return sequel.Column("`text`", v.Text, func(val string) driver.Value { return string(val) }) } +func (CustomTableName2) Schemas() sequel.TableDefinition { + return sequel.TableDefinition{} +} func (CustomTableName2) TableName() string { return "`table_2`" } @@ -51,6 +57,9 @@ func (v CustomTableName2) GetText() sequel.ColumnValuer[string] { return sequel.Column("`text`", v.Text, func(val string) driver.Value { return string(val) }) } +func (CustomTableName3) Schemas() sequel.TableDefinition { + return sequel.TableDefinition{} +} func (CustomTableName3) TableName() string { return "`table_3`" } diff --git a/examples/testcase/struct-field/alias/generated.go b/examples/testcase/struct-field/alias/generated.go index 77473e8..fbb5799 100755 --- a/examples/testcase/struct-field/alias/generated.go +++ b/examples/testcase/struct-field/alias/generated.go @@ -9,6 +9,9 @@ import ( "github.com/si3nloong/sqlgen/sequel/types" ) +func (AliasStruct) Schemas() sequel.TableDefinition { + return sequel.TableDefinition{} +} func (AliasStruct) TableName() string { return "`alias_struct`" } @@ -62,6 +65,9 @@ func (v AliasStruct) GetUpdated() sequel.ColumnValuer[time.Time] { return sequel.Column("`updated`", v.model.Updated, func(val time.Time) driver.Value { return time.Time(val) }) } +func (B) Schemas() sequel.TableDefinition { + return sequel.TableDefinition{} +} func (B) TableName() string { return "`b`" } @@ -84,6 +90,9 @@ func (v B) GetName() sequel.ColumnValuer[string] { return sequel.Column("`name`", v.Name, func(val string) driver.Value { return string(val) }) } +func (C) Schemas() sequel.TableDefinition { + return sequel.TableDefinition{} +} func (C) TableName() string { return "`c`" } diff --git a/examples/testcase/struct-field/binary/generated.go b/examples/testcase/struct-field/binary/generated.go index 64cdc81..27c35e3 100755 --- a/examples/testcase/struct-field/binary/generated.go +++ b/examples/testcase/struct-field/binary/generated.go @@ -10,6 +10,9 @@ import ( "github.com/si3nloong/sqlgen/sequel/types" ) +func (Binary) Schemas() sequel.TableDefinition { + return sequel.TableDefinition{} +} func (Binary) TableName() string { return "`binary`" } diff --git a/examples/testcase/struct-field/custom/generated.go b/examples/testcase/struct-field/custom/generated.go index f472b9e..823bc96 100755 --- a/examples/testcase/struct-field/custom/generated.go +++ b/examples/testcase/struct-field/custom/generated.go @@ -12,6 +12,9 @@ import ( "github.com/si3nloong/sqlgen/sequel/types" ) +func (Address) Schemas() sequel.TableDefinition { + return sequel.TableDefinition{} +} func (Address) TableName() string { return "`address`" } @@ -55,6 +58,9 @@ func (v Address) GetCountryCode() sequel.ColumnValuer[CountryCode] { return sequel.Column("`country_code`", v.CountryCode, func(val CountryCode) driver.Value { return string(val) }) } +func (Customer) Schemas() sequel.TableDefinition { + return sequel.TableDefinition{} +} func (Customer) TableName() string { return "`customer`" } diff --git a/examples/testcase/struct-field/date/generated.go b/examples/testcase/struct-field/date/generated.go index 14e72bf..6df05da 100755 --- a/examples/testcase/struct-field/date/generated.go +++ b/examples/testcase/struct-field/date/generated.go @@ -10,6 +10,9 @@ import ( "github.com/si3nloong/sqlgen/sequel/types" ) +func (User) Schemas() sequel.TableDefinition { + return sequel.TableDefinition{} +} func (User) TableName() string { return "`user`" } diff --git a/examples/testcase/struct-field/enum/generated.go b/examples/testcase/struct-field/enum/generated.go index 9d1d566..ab6104c 100755 --- a/examples/testcase/struct-field/enum/generated.go +++ b/examples/testcase/struct-field/enum/generated.go @@ -7,6 +7,9 @@ import ( "github.com/si3nloong/sqlgen/sequel/types" ) +func (Custom) Schemas() sequel.TableDefinition { + return sequel.TableDefinition{} +} func (Custom) TableName() string { return "`custom`" } diff --git a/examples/testcase/struct-field/imported/generated.go b/examples/testcase/struct-field/imported/generated.go index cf20782..22f9a16 100755 --- a/examples/testcase/struct-field/imported/generated.go +++ b/examples/testcase/struct-field/imported/generated.go @@ -9,6 +9,9 @@ import ( "github.com/si3nloong/sqlgen/sequel/types" ) +func (Model) Schemas() sequel.TableDefinition { + return sequel.TableDefinition{} +} func (Model) TableName() string { return "`model`" } @@ -49,6 +52,9 @@ func (v Model) GetTime() sequel.ColumnValuer[sql.NullTime] { return sequel.Column("`time`", v.Time, func(val sql.NullTime) driver.Value { return (driver.Valuer)(val) }) } +func (Some) Schemas() sequel.TableDefinition { + return sequel.TableDefinition{} +} func (Some) TableName() string { return "`some`" } diff --git a/examples/testcase/struct-field/pk/auto-incr/generated.go b/examples/testcase/struct-field/pk/auto-incr/generated.go index bee305e..0f2eca1 100755 --- a/examples/testcase/struct-field/pk/auto-incr/generated.go +++ b/examples/testcase/struct-field/pk/auto-incr/generated.go @@ -7,6 +7,9 @@ import ( "github.com/si3nloong/sqlgen/sequel/types" ) +func (Model) Schemas() sequel.TableDefinition { + return sequel.TableDefinition{} +} func (Model) TableName() string { return "`AutoIncrPK`" } diff --git a/examples/testcase/struct-field/pk/composite/generated.go b/examples/testcase/struct-field/pk/composite/generated.go index 8f5df04..f221917 100755 --- a/examples/testcase/struct-field/pk/composite/generated.go +++ b/examples/testcase/struct-field/pk/composite/generated.go @@ -9,6 +9,9 @@ import ( "github.com/si3nloong/sqlgen/sequel/types" ) +func (Composite) Schemas() sequel.TableDefinition { + return sequel.TableDefinition{} +} func (Composite) TableName() string { return "`composite`" } diff --git a/examples/testcase/struct-field/pk/generated.go b/examples/testcase/struct-field/pk/generated.go index 4c15675..70cb1b5 100755 --- a/examples/testcase/struct-field/pk/generated.go +++ b/examples/testcase/struct-field/pk/generated.go @@ -8,6 +8,9 @@ import ( "github.com/si3nloong/sqlgen/sequel/types" ) +func (Car) Schemas() sequel.TableDefinition { + return sequel.TableDefinition{} +} func (Car) TableName() string { return "`car`" } @@ -49,6 +52,9 @@ func (v Car) GetManucDate() sequel.ColumnValuer[time.Time] { return sequel.Column("`manuc_date`", v.ManucDate, func(val time.Time) driver.Value { return time.Time(val) }) } +func (User) Schemas() sequel.TableDefinition { + return sequel.TableDefinition{} +} func (User) TableName() string { return "`user`" } @@ -90,6 +96,9 @@ func (v User) GetEmail() sequel.ColumnValuer[string] { return sequel.Column("`email`", v.Email, func(val string) driver.Value { return string(val) }) } +func (House) Schemas() sequel.TableDefinition { + return sequel.TableDefinition{} +} func (House) TableName() string { return "`house`" } diff --git a/examples/testcase/struct-field/pk/uuid/generated.go b/examples/testcase/struct-field/pk/uuid/generated.go index 9e30b27..fa47aff 100755 --- a/examples/testcase/struct-field/pk/uuid/generated.go +++ b/examples/testcase/struct-field/pk/uuid/generated.go @@ -9,6 +9,9 @@ import ( "github.com/si3nloong/sqlgen/sequel/types" ) +func (User) Schemas() sequel.TableDefinition { + return sequel.TableDefinition{} +} func (User) TableName() string { return "`user`" } diff --git a/examples/testcase/struct-field/pointer/generated.go b/examples/testcase/struct-field/pointer/generated.go index 33542bc..f328de9 100755 --- a/examples/testcase/struct-field/pointer/generated.go +++ b/examples/testcase/struct-field/pointer/generated.go @@ -8,6 +8,9 @@ import ( "github.com/si3nloong/sqlgen/sequel/types" ) +func (Ptr) Schemas() sequel.TableDefinition { + return sequel.TableDefinition{} +} func (Ptr) TableName() string { return "`ptr`" } diff --git a/examples/testcase/struct-field/primitive/generated.go b/examples/testcase/struct-field/primitive/generated.go index a8782f8..1a3eaf4 100755 --- a/examples/testcase/struct-field/primitive/generated.go +++ b/examples/testcase/struct-field/primitive/generated.go @@ -8,6 +8,9 @@ import ( "github.com/si3nloong/sqlgen/sequel/types" ) +func (Primitive) Schemas() sequel.TableDefinition { + return sequel.TableDefinition{} +} func (Primitive) TableName() string { return "`primitive`" } diff --git a/examples/testcase/struct-field/size/generated.go b/examples/testcase/struct-field/size/generated.go index f685599..e284cc1 100755 --- a/examples/testcase/struct-field/size/generated.go +++ b/examples/testcase/struct-field/size/generated.go @@ -8,6 +8,9 @@ import ( "github.com/si3nloong/sqlgen/sequel/types" ) +func (Size) Schemas() sequel.TableDefinition { + return sequel.TableDefinition{} +} func (Size) TableName() string { return "`size`" } diff --git a/examples/testcase/struct-field/slice/generated.go b/examples/testcase/struct-field/slice/generated.go index f35c230..72f66b8 100755 --- a/examples/testcase/struct-field/slice/generated.go +++ b/examples/testcase/struct-field/slice/generated.go @@ -8,6 +8,9 @@ import ( "github.com/si3nloong/sqlgen/sequel/types" ) +func (Slice) Schemas() sequel.TableDefinition { + return sequel.TableDefinition{} +} func (Slice) TableName() string { return "`slice`" } diff --git a/examples/testcase/struct-field/valuer/generated.go b/examples/testcase/struct-field/valuer/generated.go index 5894917..d01fe5a 100755 --- a/examples/testcase/struct-field/valuer/generated.go +++ b/examples/testcase/struct-field/valuer/generated.go @@ -7,6 +7,9 @@ import ( "github.com/si3nloong/sqlgen/sequel/types" ) +func (B) Schemas() sequel.TableDefinition { + return sequel.TableDefinition{} +} func (B) TableName() string { return "`b`" } diff --git a/examples/testcase/struct-field/version/generated.go b/examples/testcase/struct-field/version/generated.go index 2974581..f62b944 100755 --- a/examples/testcase/struct-field/version/generated.go +++ b/examples/testcase/struct-field/version/generated.go @@ -8,6 +8,9 @@ import ( "github.com/si3nloong/sqlgen/sequel" ) +func (Version) Schemas() sequel.TableDefinition { + return sequel.TableDefinition{} +} func (Version) TableName() string { return "`version`" } diff --git a/examples/testcase/struct/alias/generated.go b/examples/testcase/struct/alias/generated.go index 9e073ee..adc8b70 100755 --- a/examples/testcase/struct/alias/generated.go +++ b/examples/testcase/struct/alias/generated.go @@ -8,6 +8,9 @@ import ( "github.com/si3nloong/sqlgen/sequel/types" ) +func (A) Schemas() sequel.TableDefinition { + return sequel.TableDefinition{} +} func (A) ColumnNames() []string { return []string{"`date`", "`time`"} } @@ -30,6 +33,9 @@ func (v A) GetTime() sequel.ColumnValuer[civil.Time] { return sequel.Column("`time`", v.Time, func(val civil.Time) driver.Value { return types.TextMarshaler(val) }) } +func (C) Schemas() sequel.TableDefinition { + return sequel.TableDefinition{} +} func (C) TableName() string { return "`c`" } diff --git a/examples/testcase/struct/embedded/imported/generated.go b/examples/testcase/struct/embedded/imported/generated.go index d1d9788..f04b75e 100755 --- a/examples/testcase/struct/embedded/imported/generated.go +++ b/examples/testcase/struct/embedded/imported/generated.go @@ -8,6 +8,9 @@ import ( "github.com/si3nloong/sqlgen/sequel/types" ) +func (B) Schemas() sequel.TableDefinition { + return sequel.TableDefinition{} +} func (B) TableName() string { return "`b`" } diff --git a/examples/testcase/struct/embedded/local/generated.go b/examples/testcase/struct/embedded/local/generated.go index 4b0e083..5f669ab 100755 --- a/examples/testcase/struct/embedded/local/generated.go +++ b/examples/testcase/struct/embedded/local/generated.go @@ -8,6 +8,9 @@ import ( "github.com/si3nloong/sqlgen/sequel/types" ) +func (B) Schemas() sequel.TableDefinition { + return sequel.TableDefinition{} +} func (B) TableName() string { return "`b`" } diff --git a/sequel/column.go b/sequel/column.go index cc9a0e1..e687f4a 100644 --- a/sequel/column.go +++ b/sequel/column.go @@ -26,14 +26,14 @@ func Column[T any](columnName string, value T, convert ConvertFunc[T]) ColumnVal type sqlCol[T any] struct { column[T] - sqlValuer SQLFunc + sqlValuer QueryFunc } func (c sqlCol[T]) SQLValue(placeholder string) string { return c.sqlValuer(placeholder) } -func SQLColumn[T any](columnName string, value T, sqlValue SQLFunc, convert ConvertFunc[T]) SQLColumnValuer[T] { +func SQLColumn[T any](columnName string, value T, sqlValue QueryFunc, convert ConvertFunc[T]) SQLColumnValuer[T] { c := sqlCol[T]{} c.colName = columnName c.v = convert(value) diff --git a/sequel/dialect/mysql/alter_table_stmt.go b/sequel/dialect/mysql/alter_table_stmt.go deleted file mode 100644 index e69bc57..0000000 --- a/sequel/dialect/mysql/alter_table_stmt.go +++ /dev/null @@ -1,37 +0,0 @@ -package mysql - -import ( - "github.com/si3nloong/sqlgen/codegen/templates" - "github.com/si3nloong/sqlgen/sequel/strpool" -) - -func (d *mysqlDriver) AlterTableStmt(n string, model *templates.Model) string { - buf := strpool.AcquireString() - defer strpool.ReleaseString(buf) - buf.WriteString(`"ALTER TABLE "+ ` + n + `.TableName() +" (`) - for i, f := range model.Fields { - if i > 0 { - buf.WriteByte(',') - } - buf.WriteString("MODIFY " + d.QuoteIdentifier(f.ColumnName) + " " + dataType(f)) - if model.IsAutoIncr && f == model.Keys[0] { - buf.WriteString(" AUTO_INCREMENT") - } - if i > 0 { - // buf.WriteString(" FIRST") - buf.WriteString(" AFTER " + d.QuoteIdentifier(model.Fields[i-1].ColumnName)) - } - } - if len(model.Keys) > 0 { - buf.WriteString(",PRIMARY KEY (") - for i, k := range model.Keys { - if i > 0 { - buf.WriteByte(',') - } - buf.WriteString(d.QuoteIdentifier(k.ColumnName)) - } - buf.WriteByte(')') - } - buf.WriteString(`);"`) - return buf.String() -} diff --git a/sequel/dialect/mysql/create_table_stmt.go b/sequel/dialect/mysql/create_table_stmt.go deleted file mode 100644 index 5f09f4c..0000000 --- a/sequel/dialect/mysql/create_table_stmt.go +++ /dev/null @@ -1,38 +0,0 @@ -package mysql - -import ( - "github.com/si3nloong/sqlgen/codegen/templates" - "github.com/si3nloong/sqlgen/sequel/strpool" -) - -func (d *mysqlDriver) CreateTableStmt(n string, model *templates.Model) string { - buf := strpool.AcquireString() - defer strpool.ReleaseString(buf) - - if model.HasTableName { - buf.WriteString(`"CREATE TABLE IF NOT EXISTS "+ ` + n + `.TableName() +" (`) - } else { - buf.WriteString(`"CREATE TABLE IF NOT EXISTS ` + d.QuoteIdentifier(model.TableName) + ` (`) - } - for i, f := range model.Fields { - if i > 0 { - buf.WriteByte(',') - } - buf.WriteString(d.QuoteIdentifier(f.ColumnName) + " " + dataType(f)) - if model.IsAutoIncr && f == model.Keys[0] { - buf.WriteString(" AUTO_INCREMENT") - } - } - if len(model.Keys) > 0 { - buf.WriteString(",PRIMARY KEY (") - for i, k := range model.Keys { - if i > 0 { - buf.WriteByte(',') - } - buf.WriteString(d.QuoteIdentifier(k.ColumnName)) - } - buf.WriteByte(')') - } - buf.WriteString(`);"`) - return buf.String() -} diff --git a/sequel/dialect/mysql/data_type.go b/sequel/dialect/mysql/data_type.go index 76b8c51..feed10f 100644 --- a/sequel/dialect/mysql/data_type.go +++ b/sequel/dialect/mysql/data_type.go @@ -4,17 +4,20 @@ import ( "database/sql" "fmt" "go/types" + "strconv" "strings" + "unsafe" - "github.com/si3nloong/sqlgen/codegen/templates" + "github.com/si3nloong/sqlgen/sequel" ) -func dataType(f *templates.Field) (dataType string) { +func dataType(f sequel.ColumnSchema) (dataType string) { var ( ptrs = make([]types.Type, 0) - t = f.Type + t = f.Type() prev types.Type ) + for t != nil { switch v := t.(type) { case *types.Pointer: @@ -30,63 +33,66 @@ func dataType(f *templates.Field) (dataType string) { switch t.String() { case "rune": - return "CHAR(1)" + notNull(len(ptrs) > 0) + return "CHAR(1)" + notNullDefault(ptrs) case "bool": - return "BOOL" + notNull(len(ptrs) > 0) + return "BOOL" + notNullDefault(ptrs, false) case "int8": - return "TINYINT" + notNull(len(ptrs) > 0) + return "TINYINT" + notNullDefault(ptrs, 0) case "int16": - return "SMALLINT" + notNull(len(ptrs) > 0) + return "SMALLINT" + notNullDefault(ptrs, 0) case "int32": - return "MEDIUMINT" + notNull(len(ptrs) > 0) + return "MEDIUMINT" + notNullDefault(ptrs, 0) case "int64": - return "BIGINT" + notNull(len(ptrs) > 0) + return "BIGINT" + notNullDefault(ptrs, 0) case "int", "uint": - return "INTEGER" + notNull(len(ptrs) > 0) + return "INTEGER" + notNullDefault(ptrs, 0) case "uint8": - return "TINYINT UNSIGNED" + notNull(len(ptrs) > 0) + return "TINYINT UNSIGNED" + notNullDefault(ptrs, 0) case "uint16": - return "SMALLINT UNSIGNED" + notNull(len(ptrs) > 0) + return "SMALLINT UNSIGNED" + notNullDefault(ptrs, 0) case "uint32": - return "MEDIUMINT UNSIGNED" + notNull(len(ptrs) > 0) + return "MEDIUMINT UNSIGNED" + notNullDefault(ptrs, 0) case "uint64": - return "BIGINT UNSIGNED" + notNull(len(ptrs) > 0) + return "BIGINT UNSIGNED" + notNullDefault(ptrs, 0) case "float32": - return "FLOAT" + notNull(len(ptrs) > 0) + return "FLOAT" + notNullDefault(ptrs, 0.0) case "float64": - return "FLOAT" + notNull(len(ptrs) > 0) + return "FLOAT" + notNullDefault(ptrs, 0.0) + case "cloud.google.com/go/civil.Time": + return "TIME" + notNullDefault(ptrs) case "cloud.google.com/go/civil.Date": - return "DATE" + notNull(len(ptrs) > 0) - case "time.Time": - var size int - if f.Size > 0 && f.Size < 7 { - size = f.Size + return "DATE" + notNullDefault(ptrs) + case "cloud.google.com/go/civil.DateTime": + if size := f.Size(); size > 0 { + return fmt.Sprintf("DATETIME(%d)", size) + notNullDefault(ptrs, sql.RawBytes(fmt.Sprintf("CURRENT_TIMESTAMP(%d)", size))) } - if size > 0 { - return fmt.Sprintf("DATETIME(%d)", size) + notNull(len(ptrs) > 0) + return "DATETIME" + notNullDefault(ptrs, sql.RawBytes(`CURRENT_TIMESTAMP`)) + case "time.Time": + if size := f.Size(); size > 0 { + return fmt.Sprintf("TIMESTAMP(%d)", size) + notNullDefault(ptrs, sql.RawBytes(fmt.Sprintf("CURRENT_TIMESTAMP(%d)", size))) } - return "DATETIME" + notNullDefault(ptrs, sql.RawBytes(`NOW()`)) + return "TIMESTAMP" + notNullDefault(ptrs, sql.RawBytes(`CURRENT_TIMESTAMP`)) case "string": size := 255 - if f.Size > 0 { - size = f.Size + if v := f.Size(); v > 0 { + size = v } - return fmt.Sprintf("VARCHAR(%d)", size) + notNullDefault(ptrs) + return fmt.Sprintf("VARCHAR(%d)", size) + notNullDefault(ptrs, "") case "[]byte": - return "BLOB" + notNull(len(ptrs) > 0) + return "BLOB" + notNullDefault(ptrs) case "[16]byte": - if f.IsBinary { - return "BINARY(16)" - } - return "VARCHAR(36)" + // if f.IsBinary { + // return "BINARY(16)" + // } + return "VARCHAR(36)" + notNullDefault(ptrs, sql.RawBytes(`UUID()`)) case "encoding/json.RawMessage": - return "JSON" + notNull(len(ptrs) > 0) + return "JSON" + notNullDefault(ptrs) default: switch { case strings.HasPrefix(t.String(), "[]"): - return "JSON" + notNull(len(ptrs) > 0) + return "JSON" + notNullDefault(ptrs) case strings.HasPrefix(t.String(), "map"): - return "JSON" + notNull(len(ptrs) > 0) + return "JSON" + notNullDefault(ptrs) } } if prev == t { @@ -94,23 +100,34 @@ func dataType(f *templates.Field) (dataType string) { } t = prev } - return "VARCHAR(255)" + notNull(len(ptrs) > 0) -} - -func notNull(isNull bool) string { - if isNull { - return "" - } - return " NOT NULL" + return "VARCHAR(255)" + notNullDefault(ptrs) } func notNullDefault(ptrs []types.Type, defaultValue ...any) string { if len(ptrs) > 0 { return "" } - str := " NOT NULL" if len(defaultValue) > 0 { - str += fmt.Sprintf(" DEFAULT %v", defaultValue[0]) + return " NOT NULL DEFAULT " + format(defaultValue[0]) + } + return " NOT NULL" +} + +func format(v any) string { + switch vi := v.(type) { + case string: + return "'" + vi + "'" + case bool: + return strconv.FormatBool(vi) + case int: + return strconv.Itoa(vi) + case float32: + return strconv.FormatFloat(float64(vi), 'f', -1, 64) + case float64: + return strconv.FormatFloat(vi, 'f', -1, 64) + case sql.RawBytes: + return unsafe.String(unsafe.SliceData(vi), len(vi)) + default: + panic("unsupported data type") } - return str } diff --git a/sequel/dialect/mysql/schema.go b/sequel/dialect/mysql/schema.go new file mode 100644 index 0000000..f3a5ccf --- /dev/null +++ b/sequel/dialect/mysql/schema.go @@ -0,0 +1,37 @@ +package mysql + +import ( + "fmt" + "strings" + + "github.com/samber/lo" + "github.com/si3nloong/sqlgen/sequel" +) + +func (s *mysqlDriver) TableSchemas(table sequel.TableSchema) sequel.TableDefinition { + def := sequel.TableDefinition{} + for _, col := range table.Columns() { + if k, ok := table.AutoIncrKey(); ok && k == col { + def.Columns = append(def.Columns, sequel.ColumnDefinition{ + Definition: dataType(col) + " AUTO_INCREMENT", + }) + } else { + def.Columns = append(def.Columns, sequel.ColumnDefinition{ + Definition: dataType(col), + }) + } + } + if keys := table.Keys(); len(keys) > 0 { + keyCols := lo.Map(keys, func(v sequel.ColumnSchema, _ int) string { + return v.ColumnName() + }) + def.PK.Columns = append(def.PK.Columns, keyCols...) + def.PK.Definition = fmt.Sprintf("PRIMARY KEY (%s)", strings.Join(keyCols, ",")) + } + if idxs := table.Indexes(); len(idxs) > 0 { + def.Indexes = append(def.Indexes, sequel.IndexDefinition{ + Definition: "", + }) + } + return def +} diff --git a/sequel/dialect/postgres/alter_table_stmt.go b/sequel/dialect/postgres/alter_table_stmt.go deleted file mode 100644 index ed7a599..0000000 --- a/sequel/dialect/postgres/alter_table_stmt.go +++ /dev/null @@ -1,34 +0,0 @@ -package postgres - -// import ( -// "github.com/si3nloong/sqlgen/codegen/templates" -// "github.com/si3nloong/sqlgen/sequel/strpool" -// ) - -// func (d *postgresDriver) AlterTableStmt(n string, model *templates.Model) string { -// buf := strpool.AcquireString() -// defer strpool.ReleaseString(buf) -// if model.HasTableName { -// buf.WriteString("`ALTER TABLE `+ " + n + ".TableName() +` (") -// } else { -// buf.WriteString("`ALTER TABLE " + d.QuoteIdentifier(model.TableName) + " (") -// } -// for i, f := range model.Fields { -// if i > 0 { -// buf.WriteByte(',') -// } -// buf.WriteString("MODIFY " + d.QuoteIdentifier(f.ColumnName) + " " + d.dataType(f)) -// } -// if len(model.Keys) > 0 { -// buf.WriteString(",PRIMARY KEY (") -// for i, k := range model.Keys { -// if i > 0 { -// buf.WriteByte(',') -// } -// buf.WriteString(d.QuoteIdentifier(k.ColumnName)) -// } -// buf.WriteByte(')') -// } -// buf.WriteString(");`") -// return buf.String() -// } diff --git a/sequel/dialect/postgres/create_table_stmt.go b/sequel/dialect/postgres/create_table_stmt.go deleted file mode 100644 index 4520fe1..0000000 --- a/sequel/dialect/postgres/create_table_stmt.go +++ /dev/null @@ -1,31 +0,0 @@ -package postgres - -import ( - "github.com/si3nloong/sqlgen/codegen/templates" - "github.com/si3nloong/sqlgen/sequel/strpool" -) - -func (d *postgresDriver) CreateTableStmt(n string, model *templates.Model) string { - buf := strpool.AcquireString() - defer strpool.ReleaseString(buf) - - buf.WriteString("`CREATE TABLE IF NOT EXISTS `+ " + n + ".TableName() +` (") - for i, f := range model.Fields { - if i > 0 { - buf.WriteByte(',') - } - buf.WriteString(d.QuoteIdentifier(f.ColumnName) + " " + d.dataType(f)) - } - if len(model.Keys) > 0 { - buf.WriteString(",PRIMARY KEY (") - for i, k := range model.Keys { - if i > 0 { - buf.WriteByte(',') - } - buf.WriteString(d.QuoteIdentifier(k.ColumnName)) - } - buf.WriteByte(')') - } - buf.WriteString(");`") - return buf.String() -} diff --git a/sequel/dialect/postgres/data_type.go b/sequel/dialect/postgres/data_type.go index 26ae4dc..3314ff4 100644 --- a/sequel/dialect/postgres/data_type.go +++ b/sequel/dialect/postgres/data_type.go @@ -8,15 +8,16 @@ import ( "strings" "unsafe" - "github.com/si3nloong/sqlgen/codegen/templates" + "github.com/si3nloong/sqlgen/sequel" ) -func (d *postgresDriver) dataType(f *templates.Field) (dataType string) { +func dataType(f sequel.ColumnSchema) (dataType string) { var ( ptrs = make([]types.Type, 0) - t = f.Type + t = f.Type() prev types.Type ) + for t != nil { switch v := t.(type) { case *types.Pointer: @@ -42,30 +43,33 @@ func (d *postgresDriver) dataType(f *templates.Field) (dataType string) { case "bool": return "BOOL" + notNullDefault(ptrs, false) case "uint8", "uint16", "byte": - return "INT2" + notNullDefault(ptrs, 0) + " CHECK(" + d.QuoteIdentifier(f.ColumnName) + " >= 0)" + return "INT2" + notNullDefault(ptrs, 0) + " CHECK(" + f.ColumnName() + " >= 0)" case "uint32", "uint": - return "INT" + notNullDefault(ptrs, 0) + " CHECK(" + d.QuoteIdentifier(f.ColumnName) + " >= 0)" + return "INT" + notNullDefault(ptrs, 0) + " CHECK(" + f.ColumnName() + " >= 0)" case "uint64": - return "INT8" + notNullDefault(ptrs, 0) + " CHECK(" + d.QuoteIdentifier(f.ColumnName) + " >= 0)" + return "INT8" + notNullDefault(ptrs, 0) + " CHECK(" + f.ColumnName() + " >= 0)" case "float32": return "DOUBLE PRECISION" + notNullDefault(ptrs, 0.0) case "float64": return "DOUBLE PRECISION" + notNullDefault(ptrs, 0.0) + case "cloud.google.com/go/civil.Time": + return "TIME" + notNullDefault(ptrs, sql.RawBytes(`CURRENT_TIME`)) case "cloud.google.com/go/civil.Date": return "DATE" + notNullDefault(ptrs, sql.RawBytes(`CURRENT_DATE`)) - case "time.Time": - var size int - if f.Size > 0 && f.Size < 7 { - size = f.Size - } - if size > 0 { + case "cloud.google.com/go/civil.DateTime": + if size := f.Size(); size > 0 { return fmt.Sprintf("TIMESTAMP(%d)", size) + notNullDefault(ptrs, sql.RawBytes(`NOW()`)) } return "TIMESTAMP" + notNullDefault(ptrs, sql.RawBytes(`NOW()`)) + case "time.Time": + if size := f.Size(); size > 0 { + return fmt.Sprintf("TIMESTAMP(%d) WITH TIME ZONE", size) + notNullDefault(ptrs, sql.RawBytes(`NOW()`)) + } + return "TIMESTAMP WITH TIME ZONE" + notNullDefault(ptrs, sql.RawBytes(`NOW()`)) case "string": size := 255 - if f.Size > 0 { - size = f.Size + if v := f.Size(); v > 0 { + size = v } return fmt.Sprintf("VARCHAR(%d)", size) + notNullDefault(ptrs, "") case "[]rune": @@ -73,17 +77,17 @@ func (d *postgresDriver) dataType(f *templates.Field) (dataType string) { case "[]byte": return "BYTEA" + notNullDefault(ptrs) case "[16]byte": - if f.IsBinary { - return "BIT(16)" - } + // if f.IsBinary { + // return "BIT(16)" + // } return "VARBIT(36)" case "encoding/json.RawMessage": return "VARBIT" + notNullDefault(ptrs) default: if strings.HasPrefix(t.String(), "[]") { - if f.IsBinary { - return "JSONB" + notNullDefault(ptrs) - } + // if f.IsBinary { + // return "JSONB" + notNullDefault(ptrs) + // } return "JSON" + notNullDefault(ptrs) } } diff --git a/sequel/dialect/postgres/schema.go b/sequel/dialect/postgres/schema.go new file mode 100644 index 0000000..458c1db --- /dev/null +++ b/sequel/dialect/postgres/schema.go @@ -0,0 +1,22 @@ +package postgres + +import ( + "github.com/si3nloong/sqlgen/sequel" +) + +func (s *postgresDriver) TableSchemas(table sequel.TableSchema) sequel.TableDefinition { + // schemas := make([]string, 0) + // for _, col := range table.Columns() { + // schemas = append(schemas, dataType(col)) + // } + // for _, idx := range table.Indexes() { + // log.Println(idx) + // } + // if len(table.Keys()) > 0 { + // log.Println("PRIMARY KEY (" + strings.Join(lo.Map(table.Keys(), func(f sequel.ColumnSchema, _ int) string { + // return f.ColumnName() + // }), ",") + ")") + // } + // return schemas + return sequel.TableDefinition{} +} diff --git a/sequel/dialect/sqlite/alter_table_stmt.go b/sequel/dialect/sqlite/alter_table_stmt.go deleted file mode 100644 index c5595b3..0000000 --- a/sequel/dialect/sqlite/alter_table_stmt.go +++ /dev/null @@ -1,27 +0,0 @@ -package sqlite - -import ( - "github.com/si3nloong/sqlgen/codegen/templates" - "github.com/si3nloong/sqlgen/sequel/strpool" -) - -func (d sqliteDriver) AlterTableStmt(n string, model *templates.Model) string { - buf := strpool.AcquireString() - defer strpool.ReleaseString(buf) - buf.WriteString(`"ALTER TABLE "+ ` + n + `.TableName() +" (`) - for i, f := range model.Fields { - if i > 0 { - buf.WriteByte(',') - } - buf.WriteString("MODIFY " + d.QuoteIdentifier(f.ColumnName) + " " + dataType(f)) - if model.IsAutoIncr && f == model.Keys[0] { - buf.WriteString(" AUTO_INCREMENT") - } - if i > 0 { - // buf.WriteString(" FIRST") - buf.WriteString(" AFTER " + d.QuoteIdentifier(model.Fields[i-1].ColumnName)) - } - } - buf.WriteString(`);"`) - return buf.String() -} diff --git a/sequel/dialect/sqlite/create_table_stmt.go b/sequel/dialect/sqlite/create_table_stmt.go deleted file mode 100644 index 89684b3..0000000 --- a/sequel/dialect/sqlite/create_table_stmt.go +++ /dev/null @@ -1,34 +0,0 @@ -package sqlite - -import ( - "github.com/si3nloong/sqlgen/codegen/templates" - "github.com/si3nloong/sqlgen/sequel/strpool" -) - -func (d sqliteDriver) CreateTableStmt(n string, model *templates.Model) string { - buf := strpool.AcquireString() - defer strpool.ReleaseString(buf) - - buf.WriteString("`CREATE TABLE IF NOT EXISTS `+ " + n + ".TableName() +` (") - for i, f := range model.Fields { - if i > 0 { - buf.WriteByte(',') - } - buf.WriteString(d.QuoteIdentifier(f.ColumnName) + " " + dataType(f)) - if model.IsAutoIncr && f == model.Keys[0] { - buf.WriteString(" AUTO_INCREMENT") - } - } - if len(model.Keys) > 0 { - buf.WriteString(",PRIMARY KEY (") - for i, k := range model.Keys { - if i > 0 { - buf.WriteByte(',') - } - buf.WriteString(d.QuoteIdentifier(k.ColumnName)) - } - buf.WriteByte(')') - } - buf.WriteString(");`") - return buf.String() -} diff --git a/sequel/dialect/sqlite/data_type.go b/sequel/dialect/sqlite/data_type.go index f097112..da16507 100644 --- a/sequel/dialect/sqlite/data_type.go +++ b/sequel/dialect/sqlite/data_type.go @@ -5,13 +5,13 @@ import ( "go/types" "strings" - "github.com/si3nloong/sqlgen/codegen/templates" + "github.com/si3nloong/sqlgen/sequel" ) -func dataType(f *templates.Field) (dataType string) { +func dataType(f sequel.ColumnSchema) (dataType string) { var ( ptrs = make([]types.Type, 0) - t = f.Type + t = f.Type() prev types.Type ) for t != nil { @@ -55,26 +55,22 @@ func dataType(f *templates.Field) (dataType string) { case "cloud.google.com/go/civil.Date": return "DATE" case "time.Time": - var size int - if f.Size > 0 && f.Size < 7 { - size = f.Size - } - if size > 0 { + if size := f.Size(); size > 0 { return fmt.Sprintf("DATETIME(%d)", size) + notNull(len(ptrs) > 0) } return "DATETIME" + notNull(len(ptrs) > 0) case "string": size := 255 - if f.Size > 0 { - size = f.Size + if v := f.Size(); v > 0 { + size = v } return fmt.Sprintf("VARCHAR(%d)", size) + notNull(len(ptrs) > 0) case "[]byte": return "BLOB" + notNull(len(ptrs) > 0) case "[16]byte": - if f.IsBinary { - return "BINARY(16)" - } + // if f.IsBinary { + // return "BINARY(16)" + // } return "VARCHAR(36)" case "encoding/json.RawMessage": return "JSON" + notNull(len(ptrs) > 0) diff --git a/sequel/dialect/sqlite/schema.go b/sequel/dialect/sqlite/schema.go new file mode 100644 index 0000000..e7ccbe7 --- /dev/null +++ b/sequel/dialect/sqlite/schema.go @@ -0,0 +1,37 @@ +package sqlite + +import ( + "fmt" + "strings" + + "github.com/samber/lo" + "github.com/si3nloong/sqlgen/sequel" +) + +func (s *sqliteDriver) TableSchemas(table sequel.TableSchema) sequel.TableDefinition { + def := sequel.TableDefinition{} + for _, col := range table.Columns() { + if k, ok := table.AutoIncrKey(); ok && k == col { + def.Columns = append(def.Columns, sequel.ColumnDefinition{ + Definition: dataType(col) + " AUTO_INCREMENT", + }) + } else { + def.Columns = append(def.Columns, sequel.ColumnDefinition{ + Definition: dataType(col), + }) + } + } + if keys := table.Keys(); len(keys) > 0 { + keyCols := lo.Map(keys, func(v sequel.ColumnSchema, _ int) string { + return v.ColumnName() + }) + def.PK.Columns = append(def.PK.Columns, keyCols...) + def.PK.Definition = fmt.Sprintf("PRIMARY KEY (%s)", strings.Join(keyCols, ",")) + } + if idxs := table.Indexes(); len(idxs) > 0 { + def.Indexes = append(def.Indexes, sequel.IndexDefinition{ + Definition: "", + }) + } + return def +} diff --git a/sequel/interface.go b/sequel/interface.go index 8c62e08..7ffa147 100644 --- a/sequel/interface.go +++ b/sequel/interface.go @@ -82,11 +82,11 @@ type SingleInserter interface { } type Inserter interface { - Columner + TableColumnValuer InsertPlaceholders(row int) string } -type TableColumnValuer[T any] interface { +type TableColumnValuer interface { Tabler Columner Valuer @@ -118,6 +118,6 @@ type Stmt interface { type StmtBuilder interface { StmtWriter - Var(query string, v any) - Vars(query string, v []any) + Var(v any) string + Vars(vals []any) string } diff --git a/sequel/migration.go b/sequel/migration.go new file mode 100644 index 0000000..0ab43ee --- /dev/null +++ b/sequel/migration.go @@ -0,0 +1,24 @@ +package sequel + +type TableDefinition struct { + PK *PrimaryKeyDefinition + Columns []ColumnDefinition + Indexes []IndexDefinition +} + +type PrimaryKeyDefinition struct { + Columns []string + Definition string +} + +type ColumnDefinition struct { + Name string + Definition string +} + +type IndexDefinition struct { + Name string + Columns []string + Type string + Definition string +} diff --git a/sequel/sequel.go b/sequel/sequel.go index 3641fba..4d107c0 100644 --- a/sequel/sequel.go +++ b/sequel/sequel.go @@ -9,7 +9,7 @@ import ( type ( ConvertFunc[T any] func(T) driver.Value - SQLFunc func(placeholder string) string + QueryFunc func(placeholder string) string ) type DB interface { @@ -28,7 +28,7 @@ type Dialect interface { QuoteIdentifier(v string) string QuoteRune() rune - CreateTableStmt(n string, model TableSchema) string + TableSchemas(model TableSchema) TableDefinition } type TableSchema interface { @@ -36,9 +36,10 @@ type TableSchema interface { DatabaseName() string TableName() string AutoIncrKey() (ColumnSchema, bool) - Implements(*types.Interface) (wrongType bool) Keys() []ColumnSchema Columns() []ColumnSchema + Indexes() []IndexSchema + Implements(*types.Interface) (*types.Func, bool) } type ColumnSchema interface { @@ -47,11 +48,16 @@ type ColumnSchema interface { // GoTag() reflect.StructTag ColumnName() string ColumnPos() int + Size() int Type() types.Type - SQLValuer() SQLFunc - SQLScanner() SQLFunc - // ActualType() string - Implements(*types.Interface) (wrongType bool) + SQLValuer() QueryFunc + SQLScanner() QueryFunc +} + +type IndexSchema interface { + Name() string + Type() string + ColumnNames() []string } type ColumnValuer[T any] interface { @@ -68,7 +74,5 @@ type SQLColumnValuer[T any] interface { } type Migrator interface { - // [0] is column name - // [1] is column data type - Schemas() [][2]string + Schemas() TableDefinition }