Skip to content

Commit

Permalink
fix: append default VARCHAR length instead of hardcoding it in the ty…
Browse files Browse the repository at this point in the history
…pe definition

+ use strings.EqualFold() for case-insensitive string comparison (SQL is case-insensitive)
+ Dialect interface now requires `DefaultVarcharLen()`
+ CreateTableQuery uses the length set in .Varchar() also for `bun:",type:varchar"`
  • Loading branch information
bevzzz committed Dec 28, 2022
1 parent 4162cfd commit e5079c7
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 12 deletions.
6 changes: 4 additions & 2 deletions dialect/mssqldialect/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,12 @@ func (*Dialect) AppendBool(b []byte, v bool) []byte {
return strconv.AppendUint(b, uint64(num), 10)
}

func (d *Dialect) DefaultVarcharLen() int {
return 255
}

func sqlType(field *schema.Field) string {
switch field.DiscoveredSQLType {
case sqltype.VarChar:
return field.DiscoveredSQLType + "(255)"
case sqltype.Timestamp:
return datetimeType
case sqltype.Boolean:
Expand Down
9 changes: 5 additions & 4 deletions dialect/mysqldialect/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,11 +172,12 @@ func (*Dialect) AppendJSON(b, jsonb []byte) []byte {
return b
}

func (d *Dialect) DefaultVarcharLen() int {
return 255
}

func sqlType(field *schema.Field) string {
switch field.DiscoveredSQLType {
case sqltype.VarChar:
return field.DiscoveredSQLType + "(255)"
case sqltype.Timestamp:
if field.DiscoveredSQLType == sqltype.Timestamp {
return datetimeType
}
return field.DiscoveredSQLType
Expand Down
4 changes: 4 additions & 0 deletions dialect/pgdialect/sqltype.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ var (
jsonRawMessageType = reflect.TypeOf((*json.RawMessage)(nil)).Elem()
)

func (d *Dialect) DefaultVarcharLen() int {
return 0
}

func fieldSQLType(field *schema.Field) string {
if field.UserSQLType != "" {
return field.UserSQLType
Expand Down
4 changes: 4 additions & 0 deletions dialect/sqlitedialect/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ func (d *Dialect) AppendBytes(b []byte, bs []byte) []byte {
return b
}

func (d *Dialect) DefaultVarcharLen() int {
return 0
}

func fieldSQLType(field *schema.Field) string {
switch field.DiscoveredSQLType {
case sqltype.SmallInt, sqltype.BigInt:
Expand Down
17 changes: 13 additions & 4 deletions query_table_create.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"database/sql"
"sort"
"strconv"
"strings"

"github.com/uptrace/bun/dialect/feature"
"github.com/uptrace/bun/dialect/sqltype"
Expand Down Expand Up @@ -32,6 +33,7 @@ func NewCreateTableQuery(db *DB) *CreateTableQuery {
db: db,
conn: db.DB,
},
varchar: db.Dialect().DefaultVarcharLen(),
}
return q
}
Expand Down Expand Up @@ -82,6 +84,10 @@ func (q *CreateTableQuery) IfNotExists() *CreateTableQuery {
return q
}

// Varchar changes the default length for VARCHAR columns.
// Because some dialects require that length is always specified for VARCHAR type,
// we will use the exact user-defined type if length is set explicitly, as in `bun:",type:varchar(5)"`,
// but assume the new default length when it's omitted, e.g. `bun:",type:varchar"`.
func (q *CreateTableQuery) Varchar(n int) *CreateTableQuery {
q.varchar = n
return q
Expand Down Expand Up @@ -120,7 +126,7 @@ func (q *CreateTableQuery) WithForeignKeys() *CreateTableQuery {
return q
}

//------------------------------------------------------------------------------
// ------------------------------------------------------------------------------

func (q *CreateTableQuery) Operation() string {
return "CREATE TABLE"
Expand Down Expand Up @@ -221,12 +227,15 @@ func (q *CreateTableQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []by
}

func (q *CreateTableQuery) appendSQLType(b []byte, field *schema.Field) []byte {
if field.CreateTableSQLType != field.DiscoveredSQLType {
// Most of the time these two will match, but for the cases where DiscoveredSQLType is dialect-specific,
// e.g. pgdialect would change sqltype.SmallInt to pgTypeSmallSerial for columns that have `bun:",autoincrement"`
if !strings.EqualFold(field.CreateTableSQLType, field.DiscoveredSQLType) {
return append(b, field.CreateTableSQLType...)
}

if q.varchar > 0 &&
field.CreateTableSQLType == sqltype.VarChar {
// For all common SQL types except VARCHAR, both UserDefinedSQLType and DiscoveredSQLType specify the correct type,
// and we needn't modify it. For VARCHAR columns, we will stop to check if a valid length has been set in .Varchar(int).
if q.varchar > 0 && strings.EqualFold(field.CreateTableSQLType, sqltype.VarChar) {
b = append(b, "varchar("...)
b = strconv.AppendInt(b, int64(q.varchar), 10)
b = append(b, ")"...)
Expand Down
13 changes: 11 additions & 2 deletions schema/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,14 @@ type Dialect interface {
AppendBytes(b []byte, bs []byte) []byte
AppendJSON(b, jsonb []byte) []byte
AppendBool(b []byte, v bool) []byte

// DefaultVarcharLen should be returned for dialects in which specifying VARCHAR length
// is mandatory in queries that modify the schema (CREATE TABLE / ADD COLUMN, etc).
// Dialects that do not have such requirement may return 0, which should be interpreted so by the caller.
DefaultVarcharLen() int
}

//------------------------------------------------------------------------------
// ------------------------------------------------------------------------------

type BaseDialect struct{}

Expand Down Expand Up @@ -131,7 +136,7 @@ func (BaseDialect) AppendBool(b []byte, v bool) []byte {
return dialect.AppendBool(b, v)
}

//------------------------------------------------------------------------------
// ------------------------------------------------------------------------------

type nopDialect struct {
BaseDialect
Expand Down Expand Up @@ -168,3 +173,7 @@ func (d *nopDialect) OnTable(table *Table) {}
func (d *nopDialect) IdentQuote() byte {
return '"'
}

func (d *nopDialect) DefaultVarcharLen() int {
return 0
}

0 comments on commit e5079c7

Please sign in to comment.