Skip to content

Commit

Permalink
feat: add context-aware Go migrations (#534)
Browse files Browse the repository at this point in the history
  • Loading branch information
ori-shalom committed Jun 29, 2023
1 parent 7d9fbaf commit e18fac6
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 33 deletions.
7 changes: 4 additions & 3 deletions create.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,20 +99,21 @@ SELECT 'down SQL query';
var goSQLMigrationTemplate = template.Must(template.New("goose.go-migration").Parse(`package migrations
import (
"context"
"database/sql"
"github.com/pressly/goose/v3"
)
func init() {
goose.AddMigration(up{{.CamelName}}, down{{.CamelName}})
goose.AddMigrationContext(up{{.CamelName}}, down{{.CamelName}})
}
func up{{.CamelName}}(tx *sql.Tx) error {
func up{{.CamelName}}(ctx context.Context, tx *sql.Tx) error {
// This code is executed when the migration is applied.
return nil
}
func down{{.CamelName}}(tx *sql.Tx) error {
func down{{.CamelName}}(ctx context.Context, tx *sql.Tx) error {
// This code is executed when the migration is rolled back.
return nil
}
Expand Down
88 changes: 73 additions & 15 deletions migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,30 +130,58 @@ func (ms Migrations) String() string {
// GoMigration is a Go migration func that is run within a transaction.
type GoMigration func(tx *sql.Tx) error

// GoMigrationContext is a Go migration func that is run within a transaction and receives a context.
type GoMigrationContext func(ctx context.Context, tx *sql.Tx) error

// GoMigrationNoTx is a Go migration func that is run outside a transaction.
type GoMigrationNoTx func(db *sql.DB) error

// GoMigrationNoTxContext is a Go migration func that is run outside a transaction and receives a context.
type GoMigrationNoTxContext func(ctx context.Context, db *sql.DB) error

// AddMigration adds Go migrations.
func AddMigration(up, down GoMigration) {
_, filename, _, _ := runtime.Caller(1)
AddNamedMigration(filename, up, down)
// intentionally don't call to AddMigrationContext so each of these functions can calculate the filename correctly
AddNamedMigrationContext(filename, withContext(up), withContext(down))
}

// AddMigrationContext adds Go migrations.
func AddMigrationContext(up, down GoMigrationContext) {
_, filename, _, _ := runtime.Caller(1)
AddNamedMigrationContext(filename, up, down)
}

// AddNamedMigration adds named Go migrations.
func AddNamedMigration(filename string, up, down GoMigration) {
AddNamedMigrationContext(filename, withContext(up), withContext(down))
}

// AddNamedMigrationContext adds named Go migrations.
func AddNamedMigrationContext(filename string, up, down GoMigrationContext) {
if err := register(filename, true, up, down, nil, nil); err != nil {
panic(err)
}
}

// AddMigrationNoTx adds Go migrations that will be run outside transaction.
func AddMigrationNoTx(up, down GoMigrationNoTx) {
_, filename, _, _ := runtime.Caller(1)
AddNamedMigrationNoTx(filename, up, down)
AddMigrationNoTxContext(withContext(up), withContext(down))
}

// AddMigrationNoTxContext adds Go migrations that will be run outside transaction.
func AddMigrationNoTxContext(up, down GoMigrationNoTxContext) {
_, filename, _, _ := runtime.Caller(2)
AddNamedMigrationNoTxContext(filename, up, down)
}

// AddNamedMigrationNoTx adds named Go migrations that will be run outside transaction.
func AddNamedMigrationNoTx(filename string, up, down GoMigrationNoTx) {
AddNamedMigrationNoTxContext(filename, withContext(up), withContext(down))
}

// AddNamedMigrationNoTxContext adds named Go migrations that will be run outside transaction.
func AddNamedMigrationNoTxContext(filename string, up, down GoMigrationNoTxContext) {
if err := register(filename, false, nil, nil, up, down); err != nil {
panic(err)
}
Expand All @@ -162,8 +190,8 @@ func AddNamedMigrationNoTx(filename string, up, down GoMigrationNoTx) {
func register(
filename string,
useTx bool,
up, down GoMigration,
upNoTx, downNoTx GoMigrationNoTx,
up, down GoMigrationContext,
upNoTx, downNoTx GoMigrationNoTxContext,
) error {
// Sanity check caller did not mix tx and non-tx based functions.
if (up != nil || down != nil) && (upNoTx != nil || downNoTx != nil) {
Expand All @@ -179,16 +207,23 @@ func register(
}
// Add to global as a registered migration.
registeredGoMigrations[v] = &Migration{
Version: v,
Next: -1,
Previous: -1,
Registered: true,
Source: filename,
UseTx: useTx,
UpFn: up,
DownFn: down,
UpFnNoTx: upNoTx,
DownFnNoTx: downNoTx,
Version: v,
Next: -1,
Previous: -1,
Registered: true,
Source: filename,
UseTx: useTx,
UpFnContext: up,
DownFnContext: down,
UpFnNoTxContext: upNoTx,
DownFnNoTxContext: downNoTx,
// These are deprecated and will be removed in the future.
// For backwards compatibility we still save the non-context versions in the struct in case someone is using them.
// Goose does not use these internally anymore and instead uses the context versions.
UpFn: withoutContext(up),
DownFn: withoutContext(down),
UpFnNoTx: withoutContext(upNoTx),
DownFnNoTx: withoutContext(downNoTx),
}
return nil
}
Expand Down Expand Up @@ -378,3 +413,26 @@ func GetDBVersionContext(ctx context.Context, db *sql.DB) (int64, error) {

return version, nil
}

// withContext changes the signature of a function that receives one argument to receive a context and the argument.
func withContext[T any](fn func(T) error) func(context.Context, T) error {
if fn == nil {
return nil
}

return func(ctx context.Context, t T) error {
return fn(t)
}
}

// withoutContext changes the signature of a function that receives a context and one argument to receive only the argument.
// When called the passed context is always context.Background().
func withoutContext[T any](fn func(context.Context, T) error) func(T) error {
if fn == nil {
return nil
}

return func(t T) error {
return fn(context.Background(), t)
}
}
38 changes: 23 additions & 15 deletions migration.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,23 @@ type MigrationRecord struct {

// Migration struct.
type Migration struct {
Version int64
Next int64 // next version, or -1 if none
Previous int64 // previous version, -1 if none
Source string // path to .sql script or go file
Registered bool
UseTx bool
Version int64
Next int64 // next version, or -1 if none
Previous int64 // previous version, -1 if none
Source string // path to .sql script or go file
Registered bool
UseTx bool

// These are deprecated and will be removed in the future.
// For backwards compatibility we still save the non-context versions in the struct in case someone is using them.
// Goose does not use these internally anymore and instead uses the context versions.
UpFn, DownFn GoMigration
UpFnNoTx, DownFnNoTx GoMigrationNoTx
noVersioning bool

// New functions with context
UpFnContext, DownFnContext GoMigrationContext
UpFnNoTxContext, DownFnNoTxContext GoMigrationNoTxContext
noVersioning bool
}

func (m *Migration) String() string {
Expand Down Expand Up @@ -99,9 +107,9 @@ func (m *Migration) run(ctx context.Context, db *sql.DB, direction bool) error {
var empty bool
if m.UseTx {
// Run go-based migration inside a tx.
fn := m.DownFn
fn := m.DownFnContext
if direction {
fn = m.UpFn
fn = m.UpFnContext
}
empty = (fn == nil)
if err := runGoMigration(
Expand All @@ -116,9 +124,9 @@ func (m *Migration) run(ctx context.Context, db *sql.DB, direction bool) error {
}
} else {
// Run go-based migration outside a tx.
fn := m.DownFnNoTx
fn := m.DownFnNoTxContext
if direction {
fn = m.UpFnNoTx
fn = m.UpFnNoTxContext
}
empty = (fn == nil)
if err := runGoMigrationNoTx(
Expand All @@ -145,14 +153,14 @@ func (m *Migration) run(ctx context.Context, db *sql.DB, direction bool) error {
func runGoMigrationNoTx(
ctx context.Context,
db *sql.DB,
fn GoMigrationNoTx,
fn GoMigrationNoTxContext,
version int64,
direction bool,
recordVersion bool,
) error {
if fn != nil {
// Run go migration function.
if err := fn(db); err != nil {
if err := fn(ctx, db); err != nil {
return fmt.Errorf("failed to run go migration: %w", err)
}
}
Expand All @@ -165,7 +173,7 @@ func runGoMigrationNoTx(
func runGoMigration(
ctx context.Context,
db *sql.DB,
fn GoMigration,
fn GoMigrationContext,
version int64,
direction bool,
recordVersion bool,
Expand All @@ -179,7 +187,7 @@ func runGoMigration(
}
if fn != nil {
// Run go migration function.
if err := fn(tx); err != nil {
if err := fn(ctx, tx); err != nil {
_ = tx.Rollback()
return fmt.Errorf("failed to run go migration: %w", err)
}
Expand Down

0 comments on commit e18fac6

Please sign in to comment.