Skip to content

Commit

Permalink
Add a WithSQLTransformer option to rewrite user-defined SQL in up
Browse files Browse the repository at this point in the history
… and `down` triggers (#329)

Add a new `WithSQLTransformer` option to rewrite the user-defined SQL
used to define `up` and `down` triggers.

The intention is that the transformer be used to sanitize user-input
SQL.

Transformers implement the following interface:

```go
type SQLTransformer interface {
	Transform(sql string) (string, error)
}
```

and are used by the `createTrigger` function to rewrite the `up` or
`down` SQL before using it in the trigger function definition.

Later PRs will use the same transformer to rewrite the `up` and `down`
values used in raw SQL migrations and column `DEFAULT` expressions.
  • Loading branch information
andrew-farries committed Mar 27, 2024
1 parent cf66d1f commit cc8c2d3
Show file tree
Hide file tree
Showing 26 changed files with 239 additions and 83 deletions.
16 changes: 13 additions & 3 deletions pkg/migrations/migrations.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,16 @@ type Operation interface {
// version in the database (through a view)
// update the given views to expose the new schema version
// Returns the table that requires backfilling, if any.
Start(ctx context.Context, conn *sql.DB, stateSchema string, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error)
Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error)

// Complete will update the database schema to match the current version
// after calling Start.
// This method should be called once the previous version is no longer used
Complete(ctx context.Context, conn *sql.DB, s *schema.Schema) error
Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error

// Rollback will revert the changes made by Start. It is not possible to
// rollback a completed migration.
Rollback(ctx context.Context, conn *sql.DB) error
Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error

// Validate returns a descriptive error if the operation cannot be applied to the given schema
Validate(ctx context.Context, s *schema.Schema) error
Expand All @@ -46,6 +46,16 @@ type RequiresSchemaRefreshOperation interface {
RequiresSchemaRefresh()
}

type SQLTransformer interface {
TransformSQL(sql string) (string, error)
}

type SQLTransformerFunc func(string) (string, error)

func (fn SQLTransformerFunc) TransformSQL(sql string) (string, error) {
return fn(sql)
}

type (
Operations []Operation
Migration struct {
Expand Down
8 changes: 4 additions & 4 deletions pkg/migrations/op_add_column.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (

var _ Operation = (*OpAddColumn)(nil)

func (o *OpAddColumn) Start(ctx context.Context, conn *sql.DB, stateSchema string, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
func (o *OpAddColumn) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
table := s.GetTable(o.Table)

if err := addColumn(ctx, conn, *o, table); err != nil {
Expand All @@ -42,7 +42,7 @@ func (o *OpAddColumn) Start(ctx context.Context, conn *sql.DB, stateSchema strin

var tableToBackfill *schema.Table
if o.Up != "" {
err := createTrigger(ctx, conn, triggerConfig{
err := createTrigger(ctx, conn, tr, triggerConfig{
Name: TriggerName(o.Table, o.Column.Name),
Direction: TriggerDirectionUp,
Columns: s.GetTable(o.Table).Columns,
Expand All @@ -65,7 +65,7 @@ func (o *OpAddColumn) Start(ctx context.Context, conn *sql.DB, stateSchema strin
return tableToBackfill, nil
}

func (o *OpAddColumn) Complete(ctx context.Context, conn *sql.DB, s *schema.Schema) error {
func (o *OpAddColumn) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error {
tempName := TemporaryName(o.Column.Name)

_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s RENAME COLUMN %s TO %s",
Expand Down Expand Up @@ -118,7 +118,7 @@ func (o *OpAddColumn) Complete(ctx context.Context, conn *sql.DB, s *schema.Sche
return err
}

func (o *OpAddColumn) Rollback(ctx context.Context, conn *sql.DB) error {
func (o *OpAddColumn) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error {
tempName := TemporaryName(o.Column.Name)

_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s DROP COLUMN IF EXISTS %s",
Expand Down
12 changes: 6 additions & 6 deletions pkg/migrations/op_alter_column.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,22 @@ import (

var _ Operation = (*OpAlterColumn)(nil)

func (o *OpAlterColumn) Start(ctx context.Context, conn *sql.DB, stateSchema string, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
func (o *OpAlterColumn) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
op := o.innerOperation()

return op.Start(ctx, conn, stateSchema, s, cbs...)
return op.Start(ctx, conn, stateSchema, tr, s, cbs...)
}

func (o *OpAlterColumn) Complete(ctx context.Context, conn *sql.DB, s *schema.Schema) error {
func (o *OpAlterColumn) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error {
op := o.innerOperation()

return op.Complete(ctx, conn, s)
return op.Complete(ctx, conn, tr, s)
}

func (o *OpAlterColumn) Rollback(ctx context.Context, conn *sql.DB) error {
func (o *OpAlterColumn) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error {
op := o.innerOperation()

return op.Rollback(ctx, conn)
return op.Rollback(ctx, conn, tr)
}

func (o *OpAlterColumn) Validate(ctx context.Context, s *schema.Schema) error {
Expand Down
10 changes: 5 additions & 5 deletions pkg/migrations/op_change_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ type OpChangeType struct {

var _ Operation = (*OpChangeType)(nil)

func (o *OpChangeType) Start(ctx context.Context, conn *sql.DB, stateSchema string, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
func (o *OpChangeType) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
table := s.GetTable(o.Table)
column := table.GetColumn(o.Column)

Expand All @@ -32,7 +32,7 @@ func (o *OpChangeType) Start(ctx context.Context, conn *sql.DB, stateSchema stri
}

// Add a trigger to copy values from the old column to the new, rewriting values using the `up` SQL.
err := createTrigger(ctx, conn, triggerConfig{
err := createTrigger(ctx, conn, tr, triggerConfig{
Name: TriggerName(o.Table, o.Column),
Direction: TriggerDirectionUp,
Columns: table.Columns,
Expand All @@ -54,7 +54,7 @@ func (o *OpChangeType) Start(ctx context.Context, conn *sql.DB, stateSchema stri
})

// Add a trigger to copy values from the new column to the old, rewriting values using the `down` SQL.
err = createTrigger(ctx, conn, triggerConfig{
err = createTrigger(ctx, conn, tr, triggerConfig{
Name: TriggerName(o.Table, TemporaryName(o.Column)),
Direction: TriggerDirectionDown,
Columns: table.Columns,
Expand All @@ -71,7 +71,7 @@ func (o *OpChangeType) Start(ctx context.Context, conn *sql.DB, stateSchema stri
return table, nil
}

func (o *OpChangeType) Complete(ctx context.Context, conn *sql.DB, s *schema.Schema) error {
func (o *OpChangeType) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error {
// Remove the up function and trigger
_, err := conn.ExecContext(ctx, fmt.Sprintf("DROP FUNCTION IF EXISTS %s CASCADE",
pq.QuoteIdentifier(TriggerFunctionName(o.Table, o.Column))))
Expand Down Expand Up @@ -104,7 +104,7 @@ func (o *OpChangeType) Complete(ctx context.Context, conn *sql.DB, s *schema.Sch
return nil
}

func (o *OpChangeType) Rollback(ctx context.Context, conn *sql.DB) error {
func (o *OpChangeType) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error {
// Drop the new column
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s DROP COLUMN IF EXISTS %s",
pq.QuoteIdentifier(o.Table),
Expand Down
6 changes: 3 additions & 3 deletions pkg/migrations/op_create_index.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (

var _ Operation = (*OpCreateIndex)(nil)

func (o *OpCreateIndex) Start(ctx context.Context, conn *sql.DB, stateSchema string, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
func (o *OpCreateIndex) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
// create index concurrently
_, err := conn.ExecContext(ctx, fmt.Sprintf("CREATE INDEX CONCURRENTLY IF NOT EXISTS %s ON %s (%s)",
pq.QuoteIdentifier(o.Name),
Expand All @@ -23,12 +23,12 @@ func (o *OpCreateIndex) Start(ctx context.Context, conn *sql.DB, stateSchema str
return nil, err
}

func (o *OpCreateIndex) Complete(ctx context.Context, conn *sql.DB, s *schema.Schema) error {
func (o *OpCreateIndex) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error {
// No-op
return nil
}

func (o *OpCreateIndex) Rollback(ctx context.Context, conn *sql.DB) error {
func (o *OpCreateIndex) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error {
// drop the index concurrently
_, err := conn.ExecContext(ctx, fmt.Sprintf("DROP INDEX CONCURRENTLY IF EXISTS %s", o.Name))

Expand Down
6 changes: 3 additions & 3 deletions pkg/migrations/op_create_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (

var _ Operation = (*OpCreateTable)(nil)

func (o *OpCreateTable) Start(ctx context.Context, conn *sql.DB, stateSchema string, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
func (o *OpCreateTable) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
tempName := TemporaryName(o.Name)
_, err := conn.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s (%s)",
pq.QuoteIdentifier(tempName),
Expand Down Expand Up @@ -54,15 +54,15 @@ func (o *OpCreateTable) Start(ctx context.Context, conn *sql.DB, stateSchema str
return nil, nil
}

func (o *OpCreateTable) Complete(ctx context.Context, conn *sql.DB, s *schema.Schema) error {
func (o *OpCreateTable) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error {
tempName := TemporaryName(o.Name)
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s RENAME TO %s",
pq.QuoteIdentifier(tempName),
pq.QuoteIdentifier(o.Name)))
return err
}

func (o *OpCreateTable) Rollback(ctx context.Context, conn *sql.DB) error {
func (o *OpCreateTable) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error {
tempName := TemporaryName(o.Name)

_, err := conn.ExecContext(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s",
Expand Down
8 changes: 4 additions & 4 deletions pkg/migrations/op_drop_column.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ import (

var _ Operation = (*OpDropColumn)(nil)

func (o *OpDropColumn) Start(ctx context.Context, conn *sql.DB, stateSchema string, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
func (o *OpDropColumn) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
if o.Down != "" {
err := createTrigger(ctx, conn, triggerConfig{
err := createTrigger(ctx, conn, tr, triggerConfig{
Name: TriggerName(o.Table, o.Column),
Direction: TriggerDirectionDown,
Columns: s.GetTable(o.Table).Columns,
Expand All @@ -34,7 +34,7 @@ func (o *OpDropColumn) Start(ctx context.Context, conn *sql.DB, stateSchema stri
return nil, nil
}

func (o *OpDropColumn) Complete(ctx context.Context, conn *sql.DB, s *schema.Schema) error {
func (o *OpDropColumn) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error {
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s",
pq.QuoteIdentifier(o.Table),
pq.QuoteIdentifier(o.Column)))
Expand All @@ -48,7 +48,7 @@ func (o *OpDropColumn) Complete(ctx context.Context, conn *sql.DB, s *schema.Sch
return err
}

func (o *OpDropColumn) Rollback(ctx context.Context, conn *sql.DB) error {
func (o *OpDropColumn) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error {
_, err := conn.ExecContext(ctx, fmt.Sprintf("DROP FUNCTION IF EXISTS %s CASCADE",
pq.QuoteIdentifier(TriggerFunctionName(o.Table, o.Column))))

Expand Down
10 changes: 5 additions & 5 deletions pkg/migrations/op_drop_constraint.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (

var _ Operation = (*OpDropConstraint)(nil)

func (o *OpDropConstraint) Start(ctx context.Context, conn *sql.DB, stateSchema string, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
func (o *OpDropConstraint) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
table := s.GetTable(o.Table)
column := table.GetColumn(o.Column)

Expand All @@ -24,7 +24,7 @@ func (o *OpDropConstraint) Start(ctx context.Context, conn *sql.DB, stateSchema
}

// Add a trigger to copy values from the old column to the new, rewriting values using the `up` SQL.
err := createTrigger(ctx, conn, triggerConfig{
err := createTrigger(ctx, conn, tr, triggerConfig{
Name: TriggerName(o.Table, o.Column),
Direction: TriggerDirectionUp,
Columns: table.Columns,
Expand All @@ -46,7 +46,7 @@ func (o *OpDropConstraint) Start(ctx context.Context, conn *sql.DB, stateSchema
})

// Add a trigger to copy values from the new column to the old, rewriting values using the `down` SQL.
err = createTrigger(ctx, conn, triggerConfig{
err = createTrigger(ctx, conn, tr, triggerConfig{
Name: TriggerName(o.Table, TemporaryName(o.Column)),
Direction: TriggerDirectionDown,
Columns: table.Columns,
Expand All @@ -62,7 +62,7 @@ func (o *OpDropConstraint) Start(ctx context.Context, conn *sql.DB, stateSchema
return table, nil
}

func (o *OpDropConstraint) Complete(ctx context.Context, conn *sql.DB, s *schema.Schema) error {
func (o *OpDropConstraint) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error {
// Remove the up function and trigger
_, err := conn.ExecContext(ctx, fmt.Sprintf("DROP FUNCTION IF EXISTS %s CASCADE",
pq.QuoteIdentifier(TriggerFunctionName(o.Table, o.Column))))
Expand Down Expand Up @@ -95,7 +95,7 @@ func (o *OpDropConstraint) Complete(ctx context.Context, conn *sql.DB, s *schema
return err
}

func (o *OpDropConstraint) Rollback(ctx context.Context, conn *sql.DB) error {
func (o *OpDropConstraint) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error {
// Drop the new column
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s DROP COLUMN IF EXISTS %s",
pq.QuoteIdentifier(o.Table),
Expand Down
6 changes: 3 additions & 3 deletions pkg/migrations/op_drop_index.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,19 @@ import (

var _ Operation = (*OpDropIndex)(nil)

func (o *OpDropIndex) Start(ctx context.Context, conn *sql.DB, stateSchema string, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
func (o *OpDropIndex) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
// no-op
return nil, nil
}

func (o *OpDropIndex) Complete(ctx context.Context, conn *sql.DB, s *schema.Schema) error {
func (o *OpDropIndex) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error {
// drop the index concurrently
_, err := conn.ExecContext(ctx, fmt.Sprintf("DROP INDEX CONCURRENTLY IF EXISTS %s", o.Name))

return err
}

func (o *OpDropIndex) Rollback(ctx context.Context, conn *sql.DB) error {
func (o *OpDropIndex) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error {
// no-op
return nil
}
Expand Down
10 changes: 5 additions & 5 deletions pkg/migrations/op_drop_not_null.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ type OpDropNotNull struct {

var _ Operation = (*OpDropNotNull)(nil)

func (o *OpDropNotNull) Start(ctx context.Context, conn *sql.DB, stateSchema string, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
func (o *OpDropNotNull) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
table := s.GetTable(o.Table)
column := table.GetColumn(o.Column)

Expand All @@ -31,7 +31,7 @@ func (o *OpDropNotNull) Start(ctx context.Context, conn *sql.DB, stateSchema str
}

// Add a trigger to copy values from the old column to the new, rewriting values using the `up` SQL.
err := createTrigger(ctx, conn, triggerConfig{
err := createTrigger(ctx, conn, tr, triggerConfig{
Name: TriggerName(o.Table, o.Column),
Direction: TriggerDirectionUp,
Columns: table.Columns,
Expand All @@ -53,7 +53,7 @@ func (o *OpDropNotNull) Start(ctx context.Context, conn *sql.DB, stateSchema str
})

// Add a trigger to copy values from the new column to the old.
err = createTrigger(ctx, conn, triggerConfig{
err = createTrigger(ctx, conn, tr, triggerConfig{
Name: TriggerName(o.Table, TemporaryName(o.Column)),
Direction: TriggerDirectionDown,
Columns: table.Columns,
Expand All @@ -70,7 +70,7 @@ func (o *OpDropNotNull) Start(ctx context.Context, conn *sql.DB, stateSchema str
return table, nil
}

func (o *OpDropNotNull) Complete(ctx context.Context, conn *sql.DB, s *schema.Schema) error {
func (o *OpDropNotNull) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error {
// Drop the old column
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s DROP COLUMN IF EXISTS %s",
pq.QuoteIdentifier(o.Table),
Expand Down Expand Up @@ -103,7 +103,7 @@ func (o *OpDropNotNull) Complete(ctx context.Context, conn *sql.DB, s *schema.Sc
return nil
}

func (o *OpDropNotNull) Rollback(ctx context.Context, conn *sql.DB) error {
func (o *OpDropNotNull) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error {
// Drop the new column
_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s DROP COLUMN IF EXISTS %s",
pq.QuoteIdentifier(o.Table),
Expand Down
6 changes: 3 additions & 3 deletions pkg/migrations/op_drop_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,18 @@ import (

var _ Operation = (*OpDropTable)(nil)

func (o *OpDropTable) Start(ctx context.Context, conn *sql.DB, stateSchema string, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
func (o *OpDropTable) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
s.RemoveTable(o.Name)
return nil, nil
}

func (o *OpDropTable) Complete(ctx context.Context, conn *sql.DB, s *schema.Schema) error {
func (o *OpDropTable) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error {
_, err := conn.ExecContext(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s", pq.QuoteIdentifier(o.Name)))

return err
}

func (o *OpDropTable) Rollback(ctx context.Context, conn *sql.DB) error {
func (o *OpDropTable) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error {
return nil
}

Expand Down
6 changes: 3 additions & 3 deletions pkg/migrations/op_raw_sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,23 @@ import (

var _ Operation = (*OpRawSQL)(nil)

func (o *OpRawSQL) Start(ctx context.Context, conn *sql.DB, stateSchema string, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
func (o *OpRawSQL) Start(ctx context.Context, conn *sql.DB, stateSchema string, tr SQLTransformer, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) {
if !o.OnComplete {
_, err := conn.ExecContext(ctx, o.Up)
return nil, err
}
return nil, nil
}

func (o *OpRawSQL) Complete(ctx context.Context, conn *sql.DB, s *schema.Schema) error {
func (o *OpRawSQL) Complete(ctx context.Context, conn *sql.DB, tr SQLTransformer, s *schema.Schema) error {
if o.OnComplete {
_, err := conn.ExecContext(ctx, o.Up)
return err
}
return nil
}

func (o *OpRawSQL) Rollback(ctx context.Context, conn *sql.DB) error {
func (o *OpRawSQL) Rollback(ctx context.Context, conn *sql.DB, tr SQLTransformer) error {
if o.Down != "" {
_, err := conn.ExecContext(ctx, o.Down)
return err
Expand Down
Loading

0 comments on commit cc8c2d3

Please sign in to comment.