From 87592390b9126a918ea8e44eedcb4eb9162a2c80 Mon Sep 17 00:00:00 2001 From: Michael Fridman Date: Mon, 8 May 2023 08:17:14 -0400 Subject: [PATCH] fix: use global table name in queries (#515) --- dialect.go | 4 +- internal/dialect/dialectquery/clickhouse.go | 24 ++++--- internal/dialect/dialectquery/dialectquery.go | 10 +-- internal/dialect/dialectquery/mysql.go | 24 ++++--- internal/dialect/dialectquery/postgres.go | 24 ++++--- internal/dialect/dialectquery/redshift.go | 24 ++++--- internal/dialect/dialectquery/sqlite3.go | 24 ++++--- internal/dialect/dialectquery/sqlserver.go | 24 ++++--- internal/dialect/dialectquery/tidb.go | 24 ++++--- internal/dialect/dialectquery/vertica.go | 24 ++++--- internal/dialect/store.go | 66 +++++++++---------- migrate.go | 6 +- migration.go | 8 +-- migration_sql.go | 8 +-- reset.go | 2 +- status.go | 2 +- up.go | 2 +- 17 files changed, 139 insertions(+), 161 deletions(-) diff --git a/dialect.go b/dialect.go index 8dd82b1a1..a14248002 100644 --- a/dialect.go +++ b/dialect.go @@ -7,7 +7,7 @@ import ( ) func init() { - store, _ = dialect.NewStore(dialect.Postgres, TableName()) + store, _ = dialect.NewStore(dialect.Postgres) } var store dialect.Store @@ -36,6 +36,6 @@ func SetDialect(s string) error { return fmt.Errorf("%q: unknown dialect", s) } var err error - store, err = dialect.NewStore(d, TableName()) + store, err = dialect.NewStore(d) return err } diff --git a/internal/dialect/dialectquery/clickhouse.go b/internal/dialect/dialectquery/clickhouse.go index 6f1040afa..ca07f8684 100644 --- a/internal/dialect/dialectquery/clickhouse.go +++ b/internal/dialect/dialectquery/clickhouse.go @@ -2,13 +2,11 @@ package dialectquery import "fmt" -type Clickhouse struct { - Table string -} +type Clickhouse struct{} var _ Querier = (*Clickhouse)(nil) -func (c *Clickhouse) CreateTable() string { +func (c *Clickhouse) CreateTable(tableName string) string { q := `CREATE TABLE IF NOT EXISTS %s ( version_id Int64, is_applied UInt8, @@ -17,25 +15,25 @@ func (c *Clickhouse) CreateTable() string { ) ENGINE = MergeTree() ORDER BY (date)` - return fmt.Sprintf(q, c.Table) + return fmt.Sprintf(q, tableName) } -func (c *Clickhouse) InsertVersion() string { +func (c *Clickhouse) InsertVersion(tableName string) string { q := `INSERT INTO %s (version_id, is_applied) VALUES ($1, $2)` - return fmt.Sprintf(q, c.Table) + return fmt.Sprintf(q, tableName) } -func (c *Clickhouse) DeleteVersion() string { +func (c *Clickhouse) DeleteVersion(tableName string) string { q := `ALTER TABLE %s DELETE WHERE version_id = $1 SETTINGS mutations_sync = 2` - return fmt.Sprintf(q, c.Table) + return fmt.Sprintf(q, tableName) } -func (c *Clickhouse) GetMigrationByVersion() string { +func (c *Clickhouse) GetMigrationByVersion(tableName string) string { q := `SELECT tstamp, is_applied FROM %s WHERE version_id = $1 ORDER BY tstamp DESC LIMIT 1` - return fmt.Sprintf(q, c.Table) + return fmt.Sprintf(q, tableName) } -func (c *Clickhouse) ListMigrations() string { +func (c *Clickhouse) ListMigrations(tableName string) string { q := `SELECT version_id, is_applied FROM %s ORDER BY version_id DESC` - return fmt.Sprintf(q, c.Table) + return fmt.Sprintf(q, tableName) } diff --git a/internal/dialect/dialectquery/dialectquery.go b/internal/dialect/dialectquery/dialectquery.go index a8951a736..482771aa1 100644 --- a/internal/dialect/dialectquery/dialectquery.go +++ b/internal/dialect/dialectquery/dialectquery.go @@ -4,25 +4,25 @@ package dialectquery // specific query. type Querier interface { // CreateTable returns the SQL query string to create the db version table. - CreateTable() string + CreateTable(tableName string) string // InsertVersion returns the SQL query string to insert a new version into // the db version table. - InsertVersion() string + InsertVersion(tableName string) string // DeleteVersion returns the SQL query string to delete a version from // the db version table. - DeleteVersion() string + DeleteVersion(tableName string) string // GetMigrationByVersion returns the SQL query string to get a single // migration by version. // // The query should return the timestamp and is_applied columns. - GetMigrationByVersion() string + GetMigrationByVersion(tableName string) string // ListMigrations returns the SQL query string to list all migrations in // descending order by id. // // The query should return the version_id and is_applied columns. - ListMigrations() string + ListMigrations(tableName string) string } diff --git a/internal/dialect/dialectquery/mysql.go b/internal/dialect/dialectquery/mysql.go index 6a4f0b439..25954cbc2 100644 --- a/internal/dialect/dialectquery/mysql.go +++ b/internal/dialect/dialectquery/mysql.go @@ -2,13 +2,11 @@ package dialectquery import "fmt" -type Mysql struct { - Table string -} +type Mysql struct{} var _ Querier = (*Mysql)(nil) -func (m *Mysql) CreateTable() string { +func (m *Mysql) CreateTable(tableName string) string { q := `CREATE TABLE %s ( id serial NOT NULL, version_id bigint NOT NULL, @@ -16,25 +14,25 @@ func (m *Mysql) CreateTable() string { tstamp timestamp NULL default now(), PRIMARY KEY(id) )` - return fmt.Sprintf(q, m.Table) + return fmt.Sprintf(q, tableName) } -func (m *Mysql) InsertVersion() string { +func (m *Mysql) InsertVersion(tableName string) string { q := `INSERT INTO %s (version_id, is_applied) VALUES (?, ?)` - return fmt.Sprintf(q, m.Table) + return fmt.Sprintf(q, tableName) } -func (m *Mysql) DeleteVersion() string { +func (m *Mysql) DeleteVersion(tableName string) string { q := `DELETE FROM %s WHERE version_id=?` - return fmt.Sprintf(q, m.Table) + return fmt.Sprintf(q, tableName) } -func (m *Mysql) GetMigrationByVersion() string { +func (m *Mysql) GetMigrationByVersion(tableName string) string { q := `SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1` - return fmt.Sprintf(q, m.Table) + return fmt.Sprintf(q, tableName) } -func (m *Mysql) ListMigrations() string { +func (m *Mysql) ListMigrations(tableName string) string { q := `SELECT version_id, is_applied from %s ORDER BY id DESC` - return fmt.Sprintf(q, m.Table) + return fmt.Sprintf(q, tableName) } diff --git a/internal/dialect/dialectquery/postgres.go b/internal/dialect/dialectquery/postgres.go index 1cddeb928..5103390f4 100644 --- a/internal/dialect/dialectquery/postgres.go +++ b/internal/dialect/dialectquery/postgres.go @@ -2,13 +2,11 @@ package dialectquery import "fmt" -type Postgres struct { - Table string -} +type Postgres struct{} var _ Querier = (*Postgres)(nil) -func (p *Postgres) CreateTable() string { +func (p *Postgres) CreateTable(tableName string) string { q := `CREATE TABLE %s ( id serial NOT NULL, version_id bigint NOT NULL, @@ -16,25 +14,25 @@ func (p *Postgres) CreateTable() string { tstamp timestamp NULL default now(), PRIMARY KEY(id) )` - return fmt.Sprintf(q, p.Table) + return fmt.Sprintf(q, tableName) } -func (p *Postgres) InsertVersion() string { +func (p *Postgres) InsertVersion(tableName string) string { q := `INSERT INTO %s (version_id, is_applied) VALUES ($1, $2)` - return fmt.Sprintf(q, p.Table) + return fmt.Sprintf(q, tableName) } -func (p *Postgres) DeleteVersion() string { +func (p *Postgres) DeleteVersion(tableName string) string { q := `DELETE FROM %s WHERE version_id=$1` - return fmt.Sprintf(q, p.Table) + return fmt.Sprintf(q, tableName) } -func (p *Postgres) GetMigrationByVersion() string { +func (p *Postgres) GetMigrationByVersion(tableName string) string { q := `SELECT tstamp, is_applied FROM %s WHERE version_id=$1 ORDER BY tstamp DESC LIMIT 1` - return fmt.Sprintf(q, p.Table) + return fmt.Sprintf(q, tableName) } -func (p *Postgres) ListMigrations() string { +func (p *Postgres) ListMigrations(tableName string) string { q := `SELECT version_id, is_applied from %s ORDER BY id DESC` - return fmt.Sprintf(q, p.Table) + return fmt.Sprintf(q, tableName) } diff --git a/internal/dialect/dialectquery/redshift.go b/internal/dialect/dialectquery/redshift.go index 8dd5fda33..006a0ca6d 100644 --- a/internal/dialect/dialectquery/redshift.go +++ b/internal/dialect/dialectquery/redshift.go @@ -2,13 +2,11 @@ package dialectquery import "fmt" -type Redshift struct { - Table string -} +type Redshift struct{} var _ Querier = (*Redshift)(nil) -func (r *Redshift) CreateTable() string { +func (r *Redshift) CreateTable(tableName string) string { q := `CREATE TABLE %s ( id integer NOT NULL identity(1, 1), version_id bigint NOT NULL, @@ -16,25 +14,25 @@ func (r *Redshift) CreateTable() string { tstamp timestamp NULL default sysdate, PRIMARY KEY(id) )` - return fmt.Sprintf(q, r.Table) + return fmt.Sprintf(q, tableName) } -func (r *Redshift) InsertVersion() string { +func (r *Redshift) InsertVersion(tableName string) string { q := `INSERT INTO %s (version_id, is_applied) VALUES ($1, $2)` - return fmt.Sprintf(q, r.Table) + return fmt.Sprintf(q, tableName) } -func (r *Redshift) DeleteVersion() string { +func (r *Redshift) DeleteVersion(tableName string) string { q := `DELETE FROM %s WHERE version_id=$1` - return fmt.Sprintf(q, r.Table) + return fmt.Sprintf(q, tableName) } -func (r *Redshift) GetMigrationByVersion() string { +func (r *Redshift) GetMigrationByVersion(tableName string) string { q := `SELECT tstamp, is_applied FROM %s WHERE version_id=$1 ORDER BY tstamp DESC LIMIT 1` - return fmt.Sprintf(q, r.Table) + return fmt.Sprintf(q, tableName) } -func (r *Redshift) ListMigrations() string { +func (r *Redshift) ListMigrations(tableName string) string { q := `SELECT version_id, is_applied from %s ORDER BY id DESC` - return fmt.Sprintf(q, r.Table) + return fmt.Sprintf(q, tableName) } diff --git a/internal/dialect/dialectquery/sqlite3.go b/internal/dialect/dialectquery/sqlite3.go index d6526f20d..689900a72 100644 --- a/internal/dialect/dialectquery/sqlite3.go +++ b/internal/dialect/dialectquery/sqlite3.go @@ -2,38 +2,36 @@ package dialectquery import "fmt" -type Sqlite3 struct { - Table string -} +type Sqlite3 struct{} var _ Querier = (*Sqlite3)(nil) -func (s *Sqlite3) CreateTable() string { +func (s *Sqlite3) CreateTable(tableName string) string { q := `CREATE TABLE %s ( id INTEGER PRIMARY KEY AUTOINCREMENT, version_id INTEGER NOT NULL, is_applied INTEGER NOT NULL, tstamp TIMESTAMP DEFAULT (datetime('now')) )` - return fmt.Sprintf(q, s.Table) + return fmt.Sprintf(q, tableName) } -func (s *Sqlite3) InsertVersion() string { +func (s *Sqlite3) InsertVersion(tableName string) string { q := `INSERT INTO %s (version_id, is_applied) VALUES (?, ?)` - return fmt.Sprintf(q, s.Table) + return fmt.Sprintf(q, tableName) } -func (s *Sqlite3) DeleteVersion() string { +func (s *Sqlite3) DeleteVersion(tableName string) string { q := `DELETE FROM %s WHERE version_id=?` - return fmt.Sprintf(q, s.Table) + return fmt.Sprintf(q, tableName) } -func (s *Sqlite3) GetMigrationByVersion() string { +func (s *Sqlite3) GetMigrationByVersion(tableName string) string { q := `SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1` - return fmt.Sprintf(q, s.Table) + return fmt.Sprintf(q, tableName) } -func (s *Sqlite3) ListMigrations() string { +func (s *Sqlite3) ListMigrations(tableName string) string { q := `SELECT version_id, is_applied from %s ORDER BY id DESC` - return fmt.Sprintf(q, s.Table) + return fmt.Sprintf(q, tableName) } diff --git a/internal/dialect/dialectquery/sqlserver.go b/internal/dialect/dialectquery/sqlserver.go index 63e0e53f0..0caa2f644 100644 --- a/internal/dialect/dialectquery/sqlserver.go +++ b/internal/dialect/dialectquery/sqlserver.go @@ -2,33 +2,31 @@ package dialectquery import "fmt" -type Sqlserver struct { - Table string -} +type Sqlserver struct{} var _ Querier = (*Sqlserver)(nil) -func (s *Sqlserver) CreateTable() string { +func (s *Sqlserver) CreateTable(tableName string) string { q := `CREATE TABLE %s ( id INT NOT NULL IDENTITY(1,1) PRIMARY KEY, version_id BIGINT NOT NULL, is_applied BIT NOT NULL, tstamp DATETIME NULL DEFAULT CURRENT_TIMESTAMP )` - return fmt.Sprintf(q, s.Table) + return fmt.Sprintf(q, tableName) } -func (s *Sqlserver) InsertVersion() string { +func (s *Sqlserver) InsertVersion(tableName string) string { q := `INSERT INTO %s (version_id, is_applied) VALUES (@p1, @p2)` - return fmt.Sprintf(q, s.Table) + return fmt.Sprintf(q, tableName) } -func (s *Sqlserver) DeleteVersion() string { +func (s *Sqlserver) DeleteVersion(tableName string) string { q := `DELETE FROM %s WHERE version_id=@p1` - return fmt.Sprintf(q, s.Table) + return fmt.Sprintf(q, tableName) } -func (s *Sqlserver) GetMigrationByVersion() string { +func (s *Sqlserver) GetMigrationByVersion(tableName string) string { q := ` WITH Migrations AS ( @@ -42,10 +40,10 @@ FROM Migrations WHERE RowNumber BETWEEN 1 AND 2 ORDER BY tstamp DESC ` - return fmt.Sprintf(q, s.Table) + return fmt.Sprintf(q, tableName) } -func (s *Sqlserver) ListMigrations() string { +func (s *Sqlserver) ListMigrations(tableName string) string { q := `SELECT version_id, is_applied FROM %s ORDER BY id DESC` - return fmt.Sprintf(q, s.Table) + return fmt.Sprintf(q, tableName) } diff --git a/internal/dialect/dialectquery/tidb.go b/internal/dialect/dialectquery/tidb.go index 8b98dde44..984e60a7a 100644 --- a/internal/dialect/dialectquery/tidb.go +++ b/internal/dialect/dialectquery/tidb.go @@ -2,13 +2,11 @@ package dialectquery import "fmt" -type Tidb struct { - Table string -} +type Tidb struct{} var _ Querier = (*Tidb)(nil) -func (t *Tidb) CreateTable() string { +func (t *Tidb) CreateTable(tableName string) string { q := `CREATE TABLE %s ( id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT UNIQUE, version_id bigint NOT NULL, @@ -16,25 +14,25 @@ func (t *Tidb) CreateTable() string { tstamp timestamp NULL default now(), PRIMARY KEY(id) )` - return fmt.Sprintf(q, t.Table) + return fmt.Sprintf(q, tableName) } -func (t *Tidb) InsertVersion() string { +func (t *Tidb) InsertVersion(tableName string) string { q := `INSERT INTO %s (version_id, is_applied) VALUES (?, ?)` - return fmt.Sprintf(q, t.Table) + return fmt.Sprintf(q, tableName) } -func (t *Tidb) DeleteVersion() string { +func (t *Tidb) DeleteVersion(tableName string) string { q := `DELETE FROM %s WHERE version_id=?` - return fmt.Sprintf(q, t.Table) + return fmt.Sprintf(q, tableName) } -func (t *Tidb) GetMigrationByVersion() string { +func (t *Tidb) GetMigrationByVersion(tableName string) string { q := `SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1` - return fmt.Sprintf(q, t.Table) + return fmt.Sprintf(q, tableName) } -func (t *Tidb) ListMigrations() string { +func (t *Tidb) ListMigrations(tableName string) string { q := `SELECT version_id, is_applied from %s ORDER BY id DESC` - return fmt.Sprintf(q, t.Table) + return fmt.Sprintf(q, tableName) } diff --git a/internal/dialect/dialectquery/vertica.go b/internal/dialect/dialectquery/vertica.go index 7a4025e06..4964aeaf6 100644 --- a/internal/dialect/dialectquery/vertica.go +++ b/internal/dialect/dialectquery/vertica.go @@ -2,13 +2,11 @@ package dialectquery import "fmt" -type Vertica struct { - Table string -} +type Vertica struct{} var _ Querier = (*Vertica)(nil) -func (v *Vertica) CreateTable() string { +func (v *Vertica) CreateTable(tableName string) string { q := `CREATE TABLE %s ( id identity(1,1) NOT NULL, version_id bigint NOT NULL, @@ -16,25 +14,25 @@ func (v *Vertica) CreateTable() string { tstamp timestamp NULL default now(), PRIMARY KEY(id) )` - return fmt.Sprintf(q, v.Table) + return fmt.Sprintf(q, tableName) } -func (v *Vertica) InsertVersion() string { +func (v *Vertica) InsertVersion(tableName string) string { q := `INSERT INTO %s (version_id, is_applied) VALUES (?, ?)` - return fmt.Sprintf(q, v.Table) + return fmt.Sprintf(q, tableName) } -func (v *Vertica) DeleteVersion() string { +func (v *Vertica) DeleteVersion(tableName string) string { q := `DELETE FROM %s WHERE version_id=?` - return fmt.Sprintf(q, v.Table) + return fmt.Sprintf(q, tableName) } -func (v *Vertica) GetMigrationByVersion() string { +func (v *Vertica) GetMigrationByVersion(tableName string) string { q := `SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1` - return fmt.Sprintf(q, v.Table) + return fmt.Sprintf(q, tableName) } -func (v *Vertica) ListMigrations() string { +func (v *Vertica) ListMigrations(tableName string) string { q := `SELECT version_id, is_applied from %s ORDER BY id DESC` - return fmt.Sprintf(q, v.Table) + return fmt.Sprintf(q, tableName) } diff --git a/internal/dialect/store.go b/internal/dialect/store.go index e16e0356a..b51cdcefe 100644 --- a/internal/dialect/store.go +++ b/internal/dialect/store.go @@ -3,7 +3,6 @@ package dialect import ( "context" "database/sql" - "errors" "fmt" "time" @@ -22,55 +21,50 @@ import ( type Store interface { // CreateVersionTable creates the version table within a transaction. // This table is used to store goose migrations. - CreateVersionTable(ctx context.Context, tx *sql.Tx) error + CreateVersionTable(ctx context.Context, tx *sql.Tx, tableName string) error // InsertVersion inserts a version id into the version table within a transaction. - InsertVersion(ctx context.Context, tx *sql.Tx, version int64) error + InsertVersion(ctx context.Context, tx *sql.Tx, tableName string, version int64) error // InsertVersionNoTx inserts a version id into the version table without a transaction. - InsertVersionNoTx(ctx context.Context, db *sql.DB, version int64) error + InsertVersionNoTx(ctx context.Context, db *sql.DB, tableName string, version int64) error // DeleteVersion deletes a version id from the version table within a transaction. - DeleteVersion(ctx context.Context, tx *sql.Tx, version int64) error + DeleteVersion(ctx context.Context, tx *sql.Tx, tableName string, version int64) error // DeleteVersionNoTx deletes a version id from the version table without a transaction. - DeleteVersionNoTx(ctx context.Context, db *sql.DB, version int64) error + DeleteVersionNoTx(ctx context.Context, db *sql.DB, tableName string, version int64) error // GetMigrationRow retrieves a single migration by version id. // // Returns the raw sql error if the query fails. It is the callers responsibility // to assert for the correct error, such as sql.ErrNoRows. - GetMigration(ctx context.Context, db *sql.DB, version int64) (*GetMigrationResult, error) + GetMigration(ctx context.Context, db *sql.DB, tableName string, version int64) (*GetMigrationResult, error) // ListMigrations retrieves all migrations sorted in descending order by id. // // If there are no migrations, an empty slice is returned with no error. - ListMigrations(ctx context.Context, db *sql.DB) ([]*ListMigrationsResult, error) + ListMigrations(ctx context.Context, db *sql.DB, tableName string) ([]*ListMigrationsResult, error) } // NewStore returns a new Store for the given dialect. -// -// The table name is used to store the goose migrations. -func NewStore(d Dialect, table string) (Store, error) { - if table == "" { - return nil, errors.New("table name cannot be empty") - } +func NewStore(d Dialect) (Store, error) { var querier dialectquery.Querier switch d { case Postgres: - querier = &dialectquery.Postgres{Table: table} + querier = &dialectquery.Postgres{} case Mysql: - querier = &dialectquery.Mysql{Table: table} + querier = &dialectquery.Mysql{} case Sqlite3: - querier = &dialectquery.Sqlite3{Table: table} + querier = &dialectquery.Sqlite3{} case Sqlserver: - querier = &dialectquery.Sqlserver{Table: table} + querier = &dialectquery.Sqlserver{} case Redshift: - querier = &dialectquery.Redshift{Table: table} + querier = &dialectquery.Redshift{} case Tidb: - querier = &dialectquery.Tidb{Table: table} + querier = &dialectquery.Tidb{} case Clickhouse: - querier = &dialectquery.Clickhouse{Table: table} + querier = &dialectquery.Clickhouse{} case Vertica: - querier = &dialectquery.Vertica{Table: table} + querier = &dialectquery.Vertica{} default: return nil, fmt.Errorf("unknown querier dialect: %v", d) } @@ -93,38 +87,38 @@ type store struct { var _ Store = (*store)(nil) -func (s *store) CreateVersionTable(ctx context.Context, tx *sql.Tx) error { - q := s.querier.CreateTable() +func (s *store) CreateVersionTable(ctx context.Context, tx *sql.Tx, tableName string) error { + q := s.querier.CreateTable(tableName) _, err := tx.ExecContext(ctx, q) return err } -func (s *store) InsertVersion(ctx context.Context, tx *sql.Tx, version int64) error { - q := s.querier.InsertVersion() +func (s *store) InsertVersion(ctx context.Context, tx *sql.Tx, tableName string, version int64) error { + q := s.querier.InsertVersion(tableName) _, err := tx.ExecContext(ctx, q, version, true) return err } -func (s *store) InsertVersionNoTx(ctx context.Context, db *sql.DB, version int64) error { - q := s.querier.InsertVersion() +func (s *store) InsertVersionNoTx(ctx context.Context, db *sql.DB, tableName string, version int64) error { + q := s.querier.InsertVersion(tableName) _, err := db.ExecContext(ctx, q, version, true) return err } -func (s *store) DeleteVersion(ctx context.Context, tx *sql.Tx, version int64) error { - q := s.querier.DeleteVersion() +func (s *store) DeleteVersion(ctx context.Context, tx *sql.Tx, tableName string, version int64) error { + q := s.querier.DeleteVersion(tableName) _, err := tx.ExecContext(ctx, q, version) return err } -func (s *store) DeleteVersionNoTx(ctx context.Context, db *sql.DB, version int64) error { - q := s.querier.DeleteVersion() +func (s *store) DeleteVersionNoTx(ctx context.Context, db *sql.DB, tableName string, version int64) error { + q := s.querier.DeleteVersion(tableName) _, err := db.ExecContext(ctx, q, version) return err } -func (s *store) GetMigration(ctx context.Context, db *sql.DB, version int64) (*GetMigrationResult, error) { - q := s.querier.GetMigrationByVersion() +func (s *store) GetMigration(ctx context.Context, db *sql.DB, tableName string, version int64) (*GetMigrationResult, error) { + q := s.querier.GetMigrationByVersion(tableName) var timestamp time.Time var isApplied bool err := db.QueryRowContext(ctx, q, version).Scan(×tamp, &isApplied) @@ -137,8 +131,8 @@ func (s *store) GetMigration(ctx context.Context, db *sql.DB, version int64) (*G }, nil } -func (s *store) ListMigrations(ctx context.Context, db *sql.DB) ([]*ListMigrationsResult, error) { - q := s.querier.ListMigrations() +func (s *store) ListMigrations(ctx context.Context, db *sql.DB, tableName string) ([]*ListMigrationsResult, error) { + q := s.querier.ListMigrations(tableName) rows, err := db.QueryContext(ctx, q) if err != nil { return nil, err diff --git a/migrate.go b/migrate.go index baf6680ac..0cf96a860 100644 --- a/migrate.go +++ b/migrate.go @@ -296,7 +296,7 @@ func versionFilter(v, current, target int64) bool { // Create and initialize the DB version table if it doesn't exist. func EnsureDBVersion(db *sql.DB) (int64, error) { ctx := context.Background() - dbMigrations, err := store.ListMigrations(ctx, db) + dbMigrations, err := store.ListMigrations(ctx, db, TableName()) if err != nil { return 0, createVersionTable(ctx, db) } @@ -336,11 +336,11 @@ func createVersionTable(ctx context.Context, db *sql.DB) error { if err != nil { return err } - if err := store.CreateVersionTable(ctx, txn); err != nil { + if err := store.CreateVersionTable(ctx, txn, TableName()); err != nil { _ = txn.Rollback() return err } - if err := store.InsertVersion(ctx, txn, 0); err != nil { + if err := store.InsertVersion(ctx, txn, TableName(), 0); err != nil { _ = txn.Rollback() return err } diff --git a/migration.go b/migration.go index f3727338e..727ecc3d9 100644 --- a/migration.go +++ b/migration.go @@ -188,16 +188,16 @@ func runGoMigration( func insertOrDeleteVersion(ctx context.Context, tx *sql.Tx, version int64, direction bool) error { if direction { - return store.InsertVersion(ctx, tx, version) + return store.InsertVersion(ctx, tx, TableName(), version) } - return store.DeleteVersion(ctx, tx, version) + return store.DeleteVersion(ctx, tx, TableName(), version) } func insertOrDeleteVersionNoTx(ctx context.Context, db *sql.DB, version int64, direction bool) error { if direction { - return store.InsertVersionNoTx(ctx, db, version) + return store.InsertVersionNoTx(ctx, db, TableName(), version) } - return store.DeleteVersionNoTx(ctx, db, version) + return store.DeleteVersionNoTx(ctx, db, TableName(), version) } // NumericComponent looks for migration scripts with names in the form: diff --git a/migration_sql.go b/migration_sql.go index d45fa64b5..f74b70d75 100644 --- a/migration_sql.go +++ b/migration_sql.go @@ -45,13 +45,13 @@ func runSQLMigration( if !noVersioning { if direction { - if err := store.InsertVersion(ctx, tx, v); err != nil { + if err := store.InsertVersion(ctx, tx, TableName(), v); err != nil { verboseInfo("Rollback transaction") _ = tx.Rollback() return fmt.Errorf("failed to insert new goose version: %w", err) } } else { - if err := store.DeleteVersion(ctx, tx, v); err != nil { + if err := store.DeleteVersion(ctx, tx, TableName(), v); err != nil { verboseInfo("Rollback transaction") _ = tx.Rollback() return fmt.Errorf("failed to delete goose version: %w", err) @@ -76,11 +76,11 @@ func runSQLMigration( } if !noVersioning { if direction { - if err := store.InsertVersionNoTx(ctx, db, v); err != nil { + if err := store.InsertVersionNoTx(ctx, db, TableName(), v); err != nil { return fmt.Errorf("failed to insert new goose version: %w", err) } } else { - if err := store.DeleteVersionNoTx(ctx, db, v); err != nil { + if err := store.DeleteVersionNoTx(ctx, db, TableName(), v); err != nil { return fmt.Errorf("failed to delete goose version: %w", err) } } diff --git a/reset.go b/reset.go index 7be46179c..e14d36d22 100644 --- a/reset.go +++ b/reset.go @@ -41,7 +41,7 @@ func Reset(db *sql.DB, dir string, opts ...OptionsFunc) error { } func dbMigrationsStatus(ctx context.Context, db *sql.DB) (map[int64]bool, error) { - dbMigrations, err := store.ListMigrations(ctx, db) + dbMigrations, err := store.ListMigrations(ctx, db, TableName()) if err != nil { return nil, err } diff --git a/status.go b/status.go index 73e73decf..dd1f16c4c 100644 --- a/status.go +++ b/status.go @@ -46,7 +46,7 @@ func Status(db *sql.DB, dir string, opts ...OptionsFunc) error { } func printMigrationStatus(ctx context.Context, db *sql.DB, version int64, script string) error { - m, err := store.GetMigration(ctx, db, version) + m, err := store.GetMigration(ctx, db, TableName(), version) if err != nil && !errors.Is(err, sql.ErrNoRows) { return fmt.Errorf("failed to query the latest migration: %w", err) } diff --git a/up.go b/up.go index bc8ddc789..66ba2d6fe 100644 --- a/up.go +++ b/up.go @@ -225,7 +225,7 @@ func UpByOne(db *sql.DB, dir string, opts ...OptionsFunc) error { // listAllDBVersions returns a list of all migrations, ordered ascending. // TODO(mf): fairly cheap, but a nice-to-have is pagination support. func listAllDBVersions(ctx context.Context, db *sql.DB) (Migrations, error) { - dbMigrations, err := store.ListMigrations(ctx, db) + dbMigrations, err := store.ListMigrations(ctx, db, TableName()) if err != nil { return nil, err }