diff --git a/db/alter.go b/db/alter.go index 0a4ffddb..0564d3e9 100644 --- a/db/alter.go +++ b/db/alter.go @@ -12,14 +12,14 @@ type AlterTableSqlBuilder struct { } func (b *AlterTableSqlBuilder) AddColumn(col *Column) *AlterTableSqlBuilder { - if colVal, err := col.String(); err == nil { + if colVal, err := col.ToSQL(b.Dialect); err == nil { b.Changes = append(b.Changes, fmt.Sprintf("ADD COLUMN %s", colVal)) } return b } func (b *AlterTableSqlBuilder) ChangeColumn(name string, col *Column) *AlterTableSqlBuilder { - if colVal, err := col.String(); err == nil { + if colVal, err := col.ToSQL(b.Dialect); err == nil { b.Changes = append(b.Changes, fmt.Sprintf("CHANGE COLUMN %s %s", name, colVal)) } return b diff --git a/db/builder.go b/db/builder.go new file mode 100644 index 00000000..d0f4fe46 --- /dev/null +++ b/db/builder.go @@ -0,0 +1,15 @@ +/* + * Copyright © 2019-2022 A Bunch Tell LLC. + * + * This file is part of WriteFreely. + * + * WriteFreely is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, included + * in the LICENSE file in this source code package. + */ + +package db + +type SQLBuilder interface { + ToSQL() (string, error) +} diff --git a/db/column.go b/db/column.go new file mode 100644 index 00000000..4e0bb6d4 --- /dev/null +++ b/db/column.go @@ -0,0 +1,280 @@ +/* + * Copyright © 2019-2022 A Bunch Tell LLC. + * + * This file is part of WriteFreely. + * + * WriteFreely is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, included + * in the LICENSE file in this source code package. + */ + +package db + +import ( + "fmt" + "strings" +) + +type Column struct { + Name string + Type ColumnType + Nullable bool + PrimaryKey bool +} + +func NullableColumn(name string, ty ColumnType) *Column { + return &Column{ + Name: name, + Type: ty, + Nullable: true, + PrimaryKey: false, + } +} + +func NonNullableColumn(name string, ty ColumnType) *Column { + return &Column{ + Name: name, + Type: ty, + Nullable: false, + PrimaryKey: false, + } +} + +func PrimaryKeyColumn(name string, ty ColumnType) *Column { + return &Column{ + Name: name, + Type: ty, + Nullable: false, + PrimaryKey: true, + } +} + +type ColumnType interface { + Name(DialectType) (string, error) + Default(DialectType) (string, error) +} + +type ColumnTypeInt struct { + IsSigned bool + MaxBytes int + MaxDigits int + HasDefault bool + DefaultVal int +} + +type ColumnTypeString struct { + IsFixedLength bool + MaxChars int + HasDefault bool + DefaultVal string +} + +type ColumnDefault int + +type ColumnTypeBool struct { + DefaultVal ColumnDefault +} + +const ( + NoDefault ColumnDefault = iota + DefaultFalse ColumnDefault = iota + DefaultTrue ColumnDefault = iota + DefaultNow ColumnDefault = iota +) + +type ColumnTypeDateTime struct { + DefaultVal ColumnDefault +} + +func (intCol ColumnTypeInt) Name(d DialectType) (string, error) { + switch d { + case DialectSQLite: + return "INTEGER", nil + + case DialectMySQL: + var colName string + switch intCol.MaxBytes { + case 1: + colName = "TINYINT" + case 2: + colName = "SMALLINT" + case 3: + colName = "MEDIUMINT" + case 4: + colName = "INT" + default: + colName = "BIGINT" + } + if intCol.MaxDigits > 0 { + colName = fmt.Sprintf("%s(%d)", colName, intCol.MaxDigits) + } + if !intCol.IsSigned { + colName += " UNSIGNED" + } + return colName, nil + + default: + return "", fmt.Errorf("dialect %d does not support integer columns", d) + } +} + +func (intCol ColumnTypeInt) Default(d DialectType) (string, error) { + switch d { + case DialectSQLite, DialectMySQL: + if intCol.HasDefault { + return fmt.Sprintf("%d", intCol.DefaultVal), nil + } + return "", nil + default: + return "", fmt.Errorf("dialect %d does not support defaulted integer columns", d) + } +} + +func (strCol ColumnTypeString) Name(d DialectType) (string, error) { + switch d { + case DialectSQLite: + return "TEXT", nil + + case DialectMySQL: + if strCol.IsFixedLength { + if strCol.MaxChars > 0 { + return fmt.Sprintf("CHAR(%d)", strCol.MaxChars), nil + } + return "CHAR", nil + } + + if strCol.MaxChars <= 0 { + return "TEXT", nil + } + if strCol.MaxChars < (1 << 16) { + return fmt.Sprintf("VARCHAR(%d)", strCol.MaxChars), nil + } + return "TEXT", nil + + default: + return "", fmt.Errorf("dialect %d does not support string columns", d) + } +} + +func (strCol ColumnTypeString) Default(d DialectType) (string, error) { + switch d { + case DialectSQLite, DialectMySQL: + if strCol.HasDefault { + return EscapeSimple.SQLEscape(d, strCol.DefaultVal) + } + return "", nil + default: + return "", fmt.Errorf("dialect %d does not support defaulted string columns", d) + } +} + +func (boolCol ColumnTypeBool) Name(d DialectType) (string, error) { + switch d { + case DialectSQLite: + return "INTEGER", nil + case DialectMySQL: + return "BOOL", nil + default: + return "", fmt.Errorf("boolean column type not supported for dialect %d", d) + } +} + +func (boolCol ColumnTypeBool) Default(d DialectType) (string, error) { + switch d { + case DialectSQLite, DialectMySQL: + switch boolCol.DefaultVal { + case NoDefault: + return "", nil + case DefaultFalse: + return "0", nil + case DefaultTrue: + return "1", nil + default: + return "", fmt.Errorf("boolean columns cannot default to %d for dialect %d", boolCol.DefaultVal, d) + } + default: + return "", fmt.Errorf("dialect %d does not support defaulted boolean columns", d) + } +} + +func (dateTimeCol ColumnTypeDateTime) Name(d DialectType) (string, error) { + switch d { + case DialectSQLite, DialectMySQL: + return "DATETIME", nil + default: + return "", fmt.Errorf("datetime column type not supported for dialect %d", d) + } +} + +func (dateTimeCol ColumnTypeDateTime) Default(d DialectType) (string, error) { + switch d { + case DialectSQLite, DialectMySQL: + switch dateTimeCol.DefaultVal { + case NoDefault: + return "", nil + case DefaultNow: + switch d { + case DialectSQLite: + return "CURRENT_TIMESTAMP", nil + case DialectMySQL: + return "NOW()", nil + } + } + return "", fmt.Errorf("datetime columns cannot default to %d for dialect %d", dateTimeCol.DefaultVal, d) + default: + return "", fmt.Errorf("dialect %d does not support defaulted datetime columns", d) + } +} + +func (c *Column) SetName(name string) *Column { + c.Name = name + return c +} + +func (c *Column) SetNullable(nullable bool) *Column { + c.Nullable = nullable + return c +} + +func (c *Column) SetPrimaryKey(pk bool) *Column { + c.PrimaryKey = pk + return c +} + +func (c *Column) SetType(t ColumnType) *Column { + c.Type = t + return c +} + +func (c *Column) ToSQL(d DialectType) (string, error) { + var str strings.Builder + + str.WriteString(c.Name) + + str.WriteString(" ") + typeStr, err := c.Type.Name(d) + if err != nil { + return "", err + } + + str.WriteString(typeStr) + + if !c.Nullable { + str.WriteString(" NOT NULL") + } + + defaultStr, err := c.Type.Default(d) + if err != nil { + return "", err + } + if len(defaultStr) > 0 { + str.WriteString(" DEFAULT ") + str.WriteString(defaultStr) + } + + if c.PrimaryKey { + str.WriteString(" PRIMARY KEY") + } + + return str.String(), nil +} diff --git a/db/create.go b/db/create.go index 648f93ae..a9fad986 100644 --- a/db/create.go +++ b/db/create.go @@ -15,32 +15,6 @@ import ( "strings" ) -type ColumnType int - -type OptionalInt struct { - Set bool - Value int -} - -type OptionalString struct { - Set bool - Value string -} - -type SQLBuilder interface { - ToSQL() (string, error) -} - -type Column struct { - Dialect DialectType - Name string - Nullable bool - Default OptionalString - Type ColumnType - Size OptionalInt - PrimaryKey bool -} - type CreateTableSqlBuilder struct { Dialect DialectType Name string @@ -50,157 +24,6 @@ type CreateTableSqlBuilder struct { Constraints []string } -const ( - ColumnTypeBool ColumnType = iota - ColumnTypeSmallInt ColumnType = iota - ColumnTypeInteger ColumnType = iota - ColumnTypeChar ColumnType = iota - ColumnTypeVarChar ColumnType = iota - ColumnTypeText ColumnType = iota - ColumnTypeDateTime ColumnType = iota -) - -var _ SQLBuilder = &CreateTableSqlBuilder{} - -var UnsetSize OptionalInt = OptionalInt{Set: false, Value: 0} -var UnsetDefault OptionalString = OptionalString{Set: false, Value: ""} - -func (d ColumnType) Format(dialect DialectType, size OptionalInt) (string, error) { - if dialect != DialectMySQL && dialect != DialectSQLite { - return "", fmt.Errorf("unsupported column type %d for dialect %d and size %v", d, dialect, size) - } - switch d { - case ColumnTypeSmallInt: - { - if dialect == DialectSQLite { - return "INTEGER", nil - } - mod := "" - if size.Set { - mod = fmt.Sprintf("(%d)", size.Value) - } - return "SMALLINT" + mod, nil - } - case ColumnTypeInteger: - { - if dialect == DialectSQLite { - return "INTEGER", nil - } - mod := "" - if size.Set { - mod = fmt.Sprintf("(%d)", size.Value) - } - return "INT" + mod, nil - } - case ColumnTypeChar: - { - if dialect == DialectSQLite { - return "TEXT", nil - } - mod := "" - if size.Set { - mod = fmt.Sprintf("(%d)", size.Value) - } - return "CHAR" + mod, nil - } - case ColumnTypeVarChar: - { - if dialect == DialectSQLite { - return "TEXT", nil - } - mod := "" - if size.Set { - mod = fmt.Sprintf("(%d)", size.Value) - } - return "VARCHAR" + mod, nil - } - case ColumnTypeBool: - { - if dialect == DialectSQLite { - return "INTEGER", nil - } - return "TINYINT(1)", nil - } - case ColumnTypeDateTime: - return "DATETIME", nil - case ColumnTypeText: - return "TEXT", nil - } - return "", fmt.Errorf("unsupported column type %d for dialect %d and size %v", d, dialect, size) -} - -func (c *Column) SetName(name string) *Column { - c.Name = name - return c -} - -func (c *Column) SetNullable(nullable bool) *Column { - c.Nullable = nullable - return c -} - -func (c *Column) SetPrimaryKey(pk bool) *Column { - c.PrimaryKey = pk - return c -} - -func (c *Column) SetDefault(value string) *Column { - c.Default = OptionalString{Set: true, Value: value} - return c -} - -func (c *Column) SetDefaultCurrentTimestamp() *Column { - def := "NOW()" - if c.Dialect == DialectSQLite { - def = "CURRENT_TIMESTAMP" - } - c.Default = OptionalString{Set: true, Value: def} - return c -} - -func (c *Column) SetType(t ColumnType) *Column { - c.Type = t - return c -} - -func (c *Column) SetSize(size int) *Column { - c.Size = OptionalInt{Set: true, Value: size} - return c -} - -func (c *Column) String() (string, error) { - var str strings.Builder - - str.WriteString(c.Name) - - str.WriteString(" ") - typeStr, err := c.Type.Format(c.Dialect, c.Size) - if err != nil { - return "", err - } - - str.WriteString(typeStr) - - if !c.Nullable { - str.WriteString(" NOT NULL") - } - - if c.Default.Set { - str.WriteString(" DEFAULT ") - val := c.Default.Value - if val == "" { - val = "''" - } - str.WriteString(val) - } - - if c.PrimaryKey { - str.WriteString(" PRIMARY KEY") - } - - return str.String(), nil -} - func (b *CreateTableSqlBuilder) Column(column *Column) *CreateTableSqlBuilder { if b.Columns == nil { b.Columns = make(map[string]*Column) @@ -241,7 +64,7 @@ func (b *CreateTableSqlBuilder) ToSQL() (string, error) { if !ok { return "", fmt.Errorf("column not found: %s", columnName) } - columnStr, err := column.String() + columnStr, err := column.ToSQL(b.Dialect) if err != nil { return "", err } diff --git a/db/dialect.go b/db/dialect.go index 42514657..ee1eb0f3 100644 --- a/db/dialect.go +++ b/db/dialect.go @@ -9,68 +9,42 @@ const ( DialectMySQL DialectType = iota ) -func (d DialectType) Column(name string, t ColumnType, size OptionalInt) *Column { +func (d DialectType) IsKnown() bool { switch d { - case DialectSQLite: - return &Column{Dialect: DialectSQLite, Name: name, Type: t, Size: size} - case DialectMySQL: - return &Column{Dialect: DialectMySQL, Name: name, Type: t, Size: size} + case DialectSQLite, DialectMySQL: + return true default: - panic(fmt.Sprintf("unexpected dialect: %d", d)) + return false } } -func (d DialectType) Table(name string) *CreateTableSqlBuilder { - switch d { - case DialectSQLite: - return &CreateTableSqlBuilder{Dialect: DialectSQLite, Name: name} - case DialectMySQL: - return &CreateTableSqlBuilder{Dialect: DialectMySQL, Name: name} - default: +func (d DialectType) AssertKnown() { + if !d.IsKnown() { panic(fmt.Sprintf("unexpected dialect: %d", d)) } } +func (d DialectType) Table(name string) *CreateTableSqlBuilder { + d.AssertKnown() + return &CreateTableSqlBuilder{Dialect: d, Name: name} +} + func (d DialectType) AlterTable(name string) *AlterTableSqlBuilder { - switch d { - case DialectSQLite: - return &AlterTableSqlBuilder{Dialect: DialectSQLite, Name: name} - case DialectMySQL: - return &AlterTableSqlBuilder{Dialect: DialectMySQL, Name: name} - default: - panic(fmt.Sprintf("unexpected dialect: %d", d)) - } + d.AssertKnown() + return &AlterTableSqlBuilder{Dialect: d, Name: name} } func (d DialectType) CreateUniqueIndex(name, table string, columns ...string) *CreateIndexSqlBuilder { - switch d { - case DialectSQLite: - return &CreateIndexSqlBuilder{Dialect: DialectSQLite, Name: name, Table: table, Unique: true, Columns: columns} - case DialectMySQL: - return &CreateIndexSqlBuilder{Dialect: DialectMySQL, Name: name, Table: table, Unique: true, Columns: columns} - default: - panic(fmt.Sprintf("unexpected dialect: %d", d)) - } + d.AssertKnown() + return &CreateIndexSqlBuilder{Dialect: d, Name: name, Table: table, Unique: true, Columns: columns} } func (d DialectType) CreateIndex(name, table string, columns ...string) *CreateIndexSqlBuilder { - switch d { - case DialectSQLite: - return &CreateIndexSqlBuilder{Dialect: DialectSQLite, Name: name, Table: table, Unique: false, Columns: columns} - case DialectMySQL: - return &CreateIndexSqlBuilder{Dialect: DialectMySQL, Name: name, Table: table, Unique: false, Columns: columns} - default: - panic(fmt.Sprintf("unexpected dialect: %d", d)) - } + d.AssertKnown() + return &CreateIndexSqlBuilder{Dialect: d, Name: name, Table: table, Unique: false, Columns: columns} } func (d DialectType) DropIndex(name, table string) *DropIndexSqlBuilder { - switch d { - case DialectSQLite: - return &DropIndexSqlBuilder{Dialect: DialectSQLite, Name: name, Table: table} - case DialectMySQL: - return &DropIndexSqlBuilder{Dialect: DialectMySQL, Name: name, Table: table} - default: - panic(fmt.Sprintf("unexpected dialect: %d", d)) - } + d.AssertKnown() + return &DropIndexSqlBuilder{Dialect: d, Name: name, Table: table} } diff --git a/db/escape.go b/db/escape.go new file mode 100644 index 00000000..53b8ef31 --- /dev/null +++ b/db/escape.go @@ -0,0 +1,63 @@ +/* + * Copyright © 2019-2022 A Bunch Tell LLC. + * + * This file is part of WriteFreely. + * + * WriteFreely is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, included + * in the LICENSE file in this source code package. + */ + +package db + +import ( + "strings" +) + +type EscapeContext int + +const ( + EscapeSimple EscapeContext = iota +) + +func (_ EscapeContext) SQLEscape(d DialectType, s string) (string, error) { + builder := strings.Builder{} + switch d { + case DialectSQLite: + builder.WriteRune('\'') + for _, c := range s { + if c == '\'' { + builder.WriteString("''") + } else { + builder.WriteRune(c) + } + } + builder.WriteRune('\'') + case DialectMySQL: + builder.WriteRune('\'') + for _, c := range s { + switch c { + case 0: + builder.WriteString("\\0") + case '\'': + builder.WriteString("\\'") + case '"': + builder.WriteString("\\\"") + case '\b': + builder.WriteString("\\b") + case '\n': + builder.WriteString("\\n") + case '\r': + builder.WriteString("\\r") + case '\t': + builder.WriteString("\\t") + case '\\': + builder.WriteString("\\\\") + default: + builder.WriteRune(c) + } + } + builder.WriteRune('\'') + } + return builder.String(), nil +} diff --git a/migrations/v4.go b/migrations/v4.go index c69dce15..25533dd5 100644 --- a/migrations/v4.go +++ b/migrations/v4.go @@ -26,8 +26,8 @@ func oauth(db *datastore) error { createTableUsersOauth, err := dialect. Table("oauth_users"). SetIfNotExists(false). - Column(dialect.Column("user_id", wf_db.ColumnTypeInteger, wf_db.UnsetSize)). - Column(dialect.Column("remote_user_id", wf_db.ColumnTypeInteger, wf_db.UnsetSize)). + Column(wf_db.NonNullableColumn("user_id", wf_db.ColumnTypeInt{MaxBytes: 4})). + Column(wf_db.NonNullableColumn("remote_user_id", wf_db.ColumnTypeInt{MaxBytes: 4})). ToSQL() if err != nil { return err @@ -35,9 +35,9 @@ func oauth(db *datastore) error { createTableOauthClientState, err := dialect. Table("oauth_client_states"). SetIfNotExists(false). - Column(dialect.Column("state", wf_db.ColumnTypeVarChar, wf_db.OptionalInt{Set: true, Value: 255})). - Column(dialect.Column("used", wf_db.ColumnTypeBool, wf_db.UnsetSize)). - Column(dialect.Column("created_at", wf_db.ColumnTypeDateTime, wf_db.UnsetSize).SetDefaultCurrentTimestamp()). + Column(wf_db.NonNullableColumn("state", wf_db.ColumnTypeString{MaxChars: 255})). + Column(wf_db.NonNullableColumn("used", wf_db.ColumnTypeBool{})). + Column(wf_db.NonNullableColumn("created_at", wf_db.ColumnTypeDateTime{DefaultVal: wf_db.DefaultNow})). UniqueConstraint("state"). ToSQL() if err != nil { diff --git a/migrations/v5.go b/migrations/v5.go index 1fe3e302..01ad2a7d 100644 --- a/migrations/v5.go +++ b/migrations/v5.go @@ -26,39 +26,55 @@ func oauthSlack(db *datastore) error { builders := []wf_db.SQLBuilder{ dialect. AlterTable("oauth_client_states"). - AddColumn(dialect. - Column( + AddColumn(wf_db. + NonNullableColumn( "provider", - wf_db.ColumnTypeVarChar, - wf_db.OptionalInt{Set: true, Value: 24}).SetDefault("")), + wf_db.ColumnTypeString{ + MaxChars: 24, + HasDefault: true, + DefaultVal: "", + })), dialect. AlterTable("oauth_client_states"). - AddColumn(dialect. - Column( + AddColumn(wf_db. + NonNullableColumn( "client_id", - wf_db.ColumnTypeVarChar, - wf_db.OptionalInt{Set: true, Value: 128}).SetDefault("")), + wf_db.ColumnTypeString{ + MaxChars: 128, + HasDefault: true, + DefaultVal: "", + }, + )), dialect. AlterTable("oauth_users"). - AddColumn(dialect. - Column( + AddColumn(wf_db. + NonNullableColumn( "provider", - wf_db.ColumnTypeVarChar, - wf_db.OptionalInt{Set: true, Value: 24}).SetDefault("")), + wf_db.ColumnTypeString{ + MaxChars: 24, + HasDefault: true, + DefaultVal: "", + })), dialect. AlterTable("oauth_users"). - AddColumn(dialect. - Column( + AddColumn(wf_db. + NonNullableColumn( "client_id", - wf_db.ColumnTypeVarChar, - wf_db.OptionalInt{Set: true, Value: 128}).SetDefault("")), + wf_db.ColumnTypeString{ + MaxChars: 128, + HasDefault: true, + DefaultVal: "", + })), dialect. AlterTable("oauth_users"). - AddColumn(dialect. - Column( + AddColumn(wf_db. + NonNullableColumn( "access_token", - wf_db.ColumnTypeVarChar, - wf_db.OptionalInt{Set: true, Value: 512}).SetDefault("")), + wf_db.ColumnTypeString{ + MaxChars: 512, + HasDefault: true, + DefaultVal: "", + })), dialect.CreateUniqueIndex("oauth_users_uk", "oauth_users", "user_id", "provider", "client_id"), } @@ -67,11 +83,12 @@ func oauthSlack(db *datastore) error { builders = append(builders, dialect. AlterTable("oauth_users"). ChangeColumn("remote_user_id", - dialect. - Column( + wf_db. + NonNullableColumn( "remote_user_id", - wf_db.ColumnTypeVarChar, - wf_db.OptionalInt{Set: true, Value: 128}))) + wf_db.ColumnTypeString{ + MaxChars: 128, + }))) } for _, builder := range builders { diff --git a/migrations/v7.go b/migrations/v7.go index 5737b217..7eb89103 100644 --- a/migrations/v7.go +++ b/migrations/v7.go @@ -26,11 +26,13 @@ func oauthAttach(db *datastore) error { builders := []wf_db.SQLBuilder{ dialect. AlterTable("oauth_client_states"). - AddColumn(dialect. - Column( + AddColumn(wf_db. + NullableColumn( "attach_user_id", - wf_db.ColumnTypeInteger, - wf_db.OptionalInt{Set: true, Value: 24}).SetNullable(true)), + wf_db.ColumnTypeInt{ + MaxBytes: 4, + MaxDigits: 24, + })), } for _, builder := range builders { query, err := builder.ToSQL() diff --git a/migrations/v8.go b/migrations/v8.go index 28af523d..00a95cac 100644 --- a/migrations/v8.go +++ b/migrations/v8.go @@ -26,10 +26,10 @@ func oauthInvites(db *datastore) error { builders := []wf_db.SQLBuilder{ dialect. AlterTable("oauth_client_states"). - AddColumn(dialect.Column("invite_code", wf_db.ColumnTypeChar, wf_db.OptionalInt{ - Set: true, - Value: 6, - }).SetNullable(true)), + AddColumn(wf_db.NullableColumn("invite_code", wf_db.ColumnTypeString{ + IsFixedLength: true, + MaxChars: 6, + })), } for _, builder := range builders { query, err := builder.ToSQL()