Skip to content

Commit

Permalink
fix: fixed bug creating table when model has no columns
Browse files Browse the repository at this point in the history
If you wanted to create a table using only the ColumnExpr function and
not provide columns from the struct model. Then you would receive syntax
errors as the ColumnExpr columns always have a `,` prepended before them
even if no columns were appended before them. This makes sure that the
`,` is properly prepended before adding columns from the ColumnExpr. It
also evaluates the AUTO_INCREMENT feature from the provided formatter
dialect rather than the DB. This will make testing the AppendQuery
method easier in the future as the library evolves as it this method is
now _slightly_ less dependent on a DB connection to be called.
  • Loading branch information
elliotcourant committed Oct 27, 2021
1 parent 9a44f93 commit 042c50b
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 8 deletions.
47 changes: 47 additions & 0 deletions internal/dbtest/query_test.go
Expand Up @@ -9,6 +9,8 @@ import (
"time"

"github.com/bradleyjkemp/cupaloy"
"github.com/stretchr/testify/assert"
"github.com/uptrace/bun/dialect"

"github.com/uptrace/bun"
"github.com/uptrace/bun/schema"
Expand Down Expand Up @@ -645,3 +647,48 @@ func TestQuery(t *testing.T) {
}
})
}

func TestCreateTableQuery_AppendQuery(t *testing.T) {
testEachDB(t, func(t *testing.T, dbName string, db *bun.DB) {
t.Run("with columns from model", func(t *testing.T) {
type Login struct {
Email string
Password string
}

output, err := bun.NewCreateTableQuery(db).
Model(&Login{}).
AppendQuery(schema.NewFormatter(db.Dialect()), nil)
assert.NoError(t, err, "should be able to generate a simple create table query without error")
assert.NotEmpty(t, output, "the resulting query byte array should not be empty")
switch db.Dialect().Name() {
case dialect.PG, dialect.SQLite:
assert.Equal(t, `CREATE TABLE "logins" ("email" VARCHAR, "password" VARCHAR)`, string(output))
case dialect.MySQL:
assert.Equal(t, "CREATE TABLE `logins` (`email` VARCHAR(255), `password` VARCHAR(255))", string(output))
default:
t.Fatalf("unknown dialect: %+v", db.Dialect().Name())
}
})

t.Run("with columns from column expr", func(t *testing.T) {
type Login struct{}

output, err := bun.NewCreateTableQuery(db).
Model(&Login{}).
ColumnExpr(`email VARCHAR`).
ColumnExpr(`password VARCHAR`).
AppendQuery(schema.NewFormatter(db.Dialect()), nil)
assert.NoError(t, err, "should be able to generate a simple create table query without error")
assert.NotEmpty(t, output, "the resulting query byte array should not be empty")
switch db.Dialect().Name() {
case dialect.PG, dialect.SQLite:
assert.Equal(t, `CREATE TABLE "logins" (email VARCHAR, password VARCHAR)`, string(output))
case dialect.MySQL:
assert.Equal(t, "CREATE TABLE `logins` (email VARCHAR, password VARCHAR)", string(output))
default:
t.Fatalf("unknown dialect: %+v", db.Dialect().Name())
}
})
})
}
21 changes: 13 additions & 8 deletions query_table_create.go
Expand Up @@ -44,7 +44,7 @@ func (q *CreateTableQuery) Model(model interface{}) *CreateTableQuery {
return q
}

//------------------------------------------------------------------------------
// ------------------------------------------------------------------------------

func (q *CreateTableQuery) Table(tables ...string) *CreateTableQuery {
for _, table := range tables {
Expand All @@ -68,7 +68,7 @@ func (q *CreateTableQuery) ColumnExpr(query string, args ...interface{}) *Create
return q
}

//------------------------------------------------------------------------------
// ------------------------------------------------------------------------------

func (q *CreateTableQuery) Temp() *CreateTableQuery {
q.temp = true
Expand Down Expand Up @@ -128,7 +128,7 @@ func (q *CreateTableQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []by
if field.NotNull {
b = append(b, " NOT NULL"...)
}
if q.db.features.Has(feature.AutoIncrement) && field.AutoIncrement {
if fmter.Dialect().Features().Has(feature.AutoIncrement) && field.AutoIncrement {
b = append(b, " AUTO_INCREMENT"...)
}
if field.SQLDefault != "" {
Expand All @@ -137,8 +137,13 @@ func (q *CreateTableQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []by
}
}

for _, col := range q.columns {
b = append(b, ", "...)
for i, col := range q.columns {
// Only pre-pend the comma if we are on subsequent iterations, or if there were fields/columns appended before
// this. This way if we are only appending custom column expressions we will not produce a syntax error with a
// leading comma.
if i > 0 || len(q.table.Fields) > 0 {
b = append(b, ", "...)
}
b, err = col.AppendQuery(fmter, b)
if err != nil {
return nil, err
Expand All @@ -147,7 +152,7 @@ func (q *CreateTableQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []by

b = q.appendPKConstraint(b, q.table.PKs)
b = q.appendUniqueConstraints(fmter, b)
b, err = q.appenFKConstraints(fmter, b)
b, err = q.appendFKConstraints(fmter, b)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -226,7 +231,7 @@ func (q *CreateTableQuery) appendUniqueConstraint(
return b
}

func (q *CreateTableQuery) appenFKConstraints(
func (q *CreateTableQuery) appendFKConstraints(
fmter schema.Formatter, b []byte,
) (_ []byte, err error) {
for _, fk := range q.fks {
Expand All @@ -250,7 +255,7 @@ func (q *CreateTableQuery) appendPKConstraint(b []byte, pks []*schema.Field) []b
return b
}

//------------------------------------------------------------------------------
// ------------------------------------------------------------------------------

func (q *CreateTableQuery) Exec(ctx context.Context, dest ...interface{}) (sql.Result, error) {
if err := q.beforeCreateTableHook(ctx); err != nil {
Expand Down

0 comments on commit 042c50b

Please sign in to comment.