From c462979327385eb7bf68e5bf2e232e8fbd3f9773 Mon Sep 17 00:00:00 2001 From: Michael Fridman Date: Tue, 14 Mar 2023 08:34:47 -0400 Subject: [PATCH] refactor: create a generic store and stub out dialect queries (#477) --- dialect.go | 365 +----------------- internal/dialect/dialectquery/clickhouse.go | 39 ++ internal/dialect/dialectquery/dialectquery.go | 60 +++ internal/dialect/dialectquery/mysql.go | 38 ++ internal/dialect/dialectquery/postgres.go | 38 ++ internal/dialect/dialectquery/redshift.go | 38 ++ internal/dialect/dialectquery/sqlite3.go | 37 ++ internal/dialect/dialectquery/sqlserver.go | 49 +++ internal/dialect/dialectquery/tidb.go | 38 ++ internal/dialect/dialectquery/vertica.go | 38 ++ internal/dialect/dialects.go | 15 + internal/dialect/store.go | 164 ++++++++ migrate.go | 74 ++-- migration.go | 35 +- migration_sql.go | 52 +-- reset.go | 30 +- status.go | 25 +- up.go | 29 +- 18 files changed, 671 insertions(+), 493 deletions(-) create mode 100644 internal/dialect/dialectquery/clickhouse.go create mode 100644 internal/dialect/dialectquery/dialectquery.go create mode 100644 internal/dialect/dialectquery/mysql.go create mode 100644 internal/dialect/dialectquery/postgres.go create mode 100644 internal/dialect/dialectquery/redshift.go create mode 100644 internal/dialect/dialectquery/sqlite3.go create mode 100644 internal/dialect/dialectquery/sqlserver.go create mode 100644 internal/dialect/dialectquery/tidb.go create mode 100644 internal/dialect/dialectquery/vertica.go create mode 100644 internal/dialect/dialects.go create mode 100644 internal/dialect/store.go diff --git a/dialect.go b/dialect.go index abda163ea..92641a3f7 100644 --- a/dialect.go +++ b/dialect.go @@ -1,364 +1,41 @@ package goose import ( - "database/sql" "fmt" + + "github.com/pressly/goose/v3/internal/dialect" ) -// SQLDialect abstracts the details of specific SQL dialects -// for goose's few SQL specific statements -type SQLDialect interface { - createVersionTableSQL() string // sql string to create the db version table - insertVersionSQL() string // sql string to insert the initial version table row - deleteVersionSQL() string // sql string to delete version - migrationSQL() string // sql string to retrieve migrations - dbVersionQuery(db *sql.DB) (*sql.Rows, error) +func init() { + store, _ = dialect.NewStore(dialect.Postgres, TableName()) } -var dialect SQLDialect = &PostgresDialect{} - -// GetDialect gets the SQLDialect -func GetDialect() SQLDialect { - return dialect -} +var store dialect.Store -// SetDialect sets the SQLDialect -func SetDialect(d string) error { - switch d { +// SetDialect sets the dialect to use for the goose package. +func SetDialect(s string) error { + var d dialect.Dialect + switch s { case "postgres", "pgx": - dialect = &PostgresDialect{} + d = dialect.Postgres case "mysql": - dialect = &MySQLDialect{} + d = dialect.Mysql case "sqlite3", "sqlite": - dialect = &Sqlite3Dialect{} + d = dialect.Sqlite3 case "mssql": - dialect = &SqlServerDialect{} + d = dialect.Sqlserver case "redshift": - dialect = &RedshiftDialect{} + d = dialect.Redshift case "tidb": - dialect = &TiDBDialect{} + d = dialect.Tidb case "clickhouse": - dialect = &ClickHouseDialect{} + d = dialect.Clickhouse case "vertica": - dialect = &VerticaDialect{} + d = dialect.Vertica default: - return fmt.Errorf("%q: unknown dialect", d) - } - - return nil -} - -//////////////////////////// -// Postgres -//////////////////////////// - -// PostgresDialect struct. -type PostgresDialect struct{} - -func (pg PostgresDialect) createVersionTableSQL() string { - return fmt.Sprintf(`CREATE TABLE %s ( - id serial NOT NULL, - version_id bigint NOT NULL, - is_applied boolean NOT NULL, - tstamp timestamp NULL default now(), - PRIMARY KEY(id) - );`, TableName()) -} - -func (pg PostgresDialect) insertVersionSQL() string { - return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES ($1, $2);", TableName()) -} - -func (pg PostgresDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { - rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", TableName())) - if err != nil { - return nil, err - } - - return rows, err -} - -func (m PostgresDialect) migrationSQL() string { - return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=$1 ORDER BY tstamp DESC LIMIT 1", TableName()) -} - -func (pg PostgresDialect) deleteVersionSQL() string { - return fmt.Sprintf("DELETE FROM %s WHERE version_id=$1;", TableName()) -} - -//////////////////////////// -// MySQL -//////////////////////////// - -// MySQLDialect struct. -type MySQLDialect struct{} - -func (m MySQLDialect) createVersionTableSQL() string { - return fmt.Sprintf(`CREATE TABLE %s ( - id serial NOT NULL, - version_id bigint NOT NULL, - is_applied boolean NOT NULL, - tstamp timestamp NULL default now(), - PRIMARY KEY(id) - );`, TableName()) -} - -func (m MySQLDialect) insertVersionSQL() string { - return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES (?, ?);", TableName()) -} - -func (m MySQLDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { - rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", TableName())) - if err != nil { - return nil, err - } - - return rows, err -} - -func (m MySQLDialect) migrationSQL() string { - return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1", TableName()) -} - -func (m MySQLDialect) deleteVersionSQL() string { - return fmt.Sprintf("DELETE FROM %s WHERE version_id=?;", TableName()) -} - -//////////////////////////// -// MSSQL -//////////////////////////// - -// SqlServerDialect struct. -type SqlServerDialect struct{} - -func (m SqlServerDialect) createVersionTableSQL() string { - return fmt.Sprintf(`CREATE TABLE %s ( - id INT NOT NULL IDENTITY(1,1) PRIMARY KEY, - version_id BIGINT NOT NULL, - is_applied BIT NOT NULL, - tstamp DATETIME NULL DEFAULT CURRENT_TIMESTAMP - );`, TableName()) -} - -func (m SqlServerDialect) insertVersionSQL() string { - return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES (@p1, @p2);", TableName()) -} - -func (m SqlServerDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { - rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied FROM %s ORDER BY id DESC", TableName())) - if err != nil { - return nil, err - } - - return rows, err -} - -func (m SqlServerDialect) migrationSQL() string { - const tpl = ` -WITH Migrations AS -( - SELECT tstamp, is_applied, - ROW_NUMBER() OVER (ORDER BY tstamp) AS 'RowNumber' - FROM %s - WHERE version_id=@p1 -) -SELECT tstamp, is_applied -FROM Migrations -WHERE RowNumber BETWEEN 1 AND 2 -ORDER BY tstamp DESC -` - return fmt.Sprintf(tpl, TableName()) -} - -func (m SqlServerDialect) deleteVersionSQL() string { - return fmt.Sprintf("DELETE FROM %s WHERE version_id=@p1;", TableName()) -} - -//////////////////////////// -// sqlite3 -//////////////////////////// - -// Sqlite3Dialect struct. -type Sqlite3Dialect struct{} - -func (m Sqlite3Dialect) createVersionTableSQL() string { - return fmt.Sprintf(`CREATE TABLE %s ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - version_id INTEGER NOT NULL, - is_applied INTEGER NOT NULL, - tstamp TIMESTAMP DEFAULT (datetime('now')) - );`, TableName()) -} - -func (m Sqlite3Dialect) insertVersionSQL() string { - return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES (?, ?);", TableName()) -} - -func (m Sqlite3Dialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { - rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", TableName())) - if err != nil { - return nil, err - } - - return rows, err -} - -func (m Sqlite3Dialect) migrationSQL() string { - return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1", TableName()) -} - -func (m Sqlite3Dialect) deleteVersionSQL() string { - return fmt.Sprintf("DELETE FROM %s WHERE version_id=?;", TableName()) -} - -//////////////////////////// -// Redshift -//////////////////////////// - -// RedshiftDialect struct. -type RedshiftDialect struct{} - -func (rs RedshiftDialect) createVersionTableSQL() string { - return fmt.Sprintf(`CREATE TABLE %s ( - id integer NOT NULL identity(1, 1), - version_id bigint NOT NULL, - is_applied boolean NOT NULL, - tstamp timestamp NULL default sysdate, - PRIMARY KEY(id) - );`, TableName()) -} - -func (rs RedshiftDialect) insertVersionSQL() string { - return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES ($1, $2);", TableName()) -} - -func (rs RedshiftDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { - rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", TableName())) - if err != nil { - return nil, err - } - - return rows, err -} - -func (m RedshiftDialect) migrationSQL() string { - return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=$1 ORDER BY tstamp DESC LIMIT 1", TableName()) -} - -func (rs RedshiftDialect) deleteVersionSQL() string { - return fmt.Sprintf("DELETE FROM %s WHERE version_id=$1;", TableName()) -} - -//////////////////////////// -// TiDB -//////////////////////////// - -// TiDBDialect struct. -type TiDBDialect struct{} - -func (m TiDBDialect) createVersionTableSQL() string { - return fmt.Sprintf(`CREATE TABLE %s ( - id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT UNIQUE, - version_id bigint NOT NULL, - is_applied boolean NOT NULL, - tstamp timestamp NULL default now(), - PRIMARY KEY(id) - );`, TableName()) -} - -func (m TiDBDialect) insertVersionSQL() string { - return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES (?, ?);", TableName()) -} - -func (m TiDBDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { - rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", TableName())) - if err != nil { - return nil, err - } - - return rows, err -} - -func (m TiDBDialect) migrationSQL() string { - return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1", TableName()) -} - -func (m TiDBDialect) deleteVersionSQL() string { - return fmt.Sprintf("DELETE FROM %s WHERE version_id=?;", TableName()) -} - -//////////////////////////// -// ClickHouse -//////////////////////////// - -// ClickHouseDialect struct. -type ClickHouseDialect struct{} - -func (m ClickHouseDialect) createVersionTableSQL() string { - return fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s ( - version_id Int64, - is_applied UInt8, - date Date default now(), - tstamp DateTime default now() - ) - ENGINE = MergeTree() - ORDER BY (date)`, TableName()) -} - -func (m ClickHouseDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { - rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied FROM %s ORDER BY version_id DESC", TableName())) - if err != nil { - return nil, err + return fmt.Errorf("%q: unknown dialect", s) } - return rows, err -} - -func (m ClickHouseDialect) insertVersionSQL() string { - return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES ($1, $2)", TableName()) -} - -func (m ClickHouseDialect) migrationSQL() string { - return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id = $1 ORDER BY tstamp DESC LIMIT 1", TableName()) -} - -func (m ClickHouseDialect) deleteVersionSQL() string { - return fmt.Sprintf("ALTER TABLE %s DELETE WHERE version_id = $1 SETTINGS mutations_sync = 2", TableName()) -} - -//////////////////////////// -// Vertica -//////////////////////////// - -// VerticaDialect struct. -type VerticaDialect struct{} - -func (v VerticaDialect) createVersionTableSQL() string { - return fmt.Sprintf(`CREATE TABLE %s ( - id identity(1,1) NOT NULL, - version_id bigint NOT NULL, - is_applied boolean NOT NULL, - tstamp timestamp NULL default now(), - PRIMARY KEY(id) - );`, TableName()) -} - -func (v VerticaDialect) insertVersionSQL() string { - return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES (?, ?);", TableName()) -} - -func (v VerticaDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { - rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", TableName())) - if err != nil { - return nil, err - } - - return rows, err -} - -func (m VerticaDialect) migrationSQL() string { - return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1", TableName()) -} - -func (v VerticaDialect) deleteVersionSQL() string { - return fmt.Sprintf("DELETE FROM %s WHERE version_id=?;", TableName()) + var err error + store, err = dialect.NewStore(d, TableName()) + return err } diff --git a/internal/dialect/dialectquery/clickhouse.go b/internal/dialect/dialectquery/clickhouse.go new file mode 100644 index 000000000..3104b89ad --- /dev/null +++ b/internal/dialect/dialectquery/clickhouse.go @@ -0,0 +1,39 @@ +package dialectquery + +import "fmt" + +type clickhouse struct { + table string +} + +func (c *clickhouse) CreateTable() string { + q := `CREATE TABLE IF NOT EXISTS %s ( + version_id Int64, + is_applied UInt8, + date Date default now(), + tstamp DateTime default now() + ) + ENGINE = MergeTree() + ORDER BY (date)` + return fmt.Sprintf(q, c.table) +} + +func (c *clickhouse) InsertVersion() string { + q := `INSERT INTO %s (version_id, is_applied) VALUES ($1, $2)` + return fmt.Sprintf(q, c.table) +} + +func (c *clickhouse) DeleteVersion() string { + q := `ALTER TABLE %s DELETE WHERE version_id = $1 SETTINGS mutations_sync = 2` + return fmt.Sprintf(q, c.table) +} + +func (c *clickhouse) GetMigrationByVersion() string { + q := `SELECT tstamp, is_applied FROM %s WHERE version_id = $1 ORDER BY tstamp DESC LIMIT 1` + return fmt.Sprintf(q, c.table) +} + +func (c *clickhouse) ListMigrations() string { + q := `SELECT version_id, is_applied FROM %s ORDER BY version_id DESC` + return fmt.Sprintf(q, c.table) +} diff --git a/internal/dialect/dialectquery/dialectquery.go b/internal/dialect/dialectquery/dialectquery.go new file mode 100644 index 000000000..cea39318c --- /dev/null +++ b/internal/dialect/dialectquery/dialectquery.go @@ -0,0 +1,60 @@ +package dialectquery + +// Querier is the interface that wraps the basic methods to create a dialect +// specific query. +type Querier interface { + // CreateTable returns the SQL query string to create the db version table. + CreateTable() string + + // InsertVersion returns the SQL query string to insert a new version into + // the db version table. + InsertVersion() string + + // DeleteVersion returns the SQL query string to delete a version from + // the db version table. + DeleteVersion() string + + // GetMigrationByVersion returns the SQL query string to get a single + // migration by version. + // + // The query should return the timestamp and is_applied columns. + GetMigrationByVersion() string + + // ListMigrations returns the SQL query string to list all migrations in + // descending order by id. + // + // The query should return the version_id and is_applied columns. + ListMigrations() string +} + +func NewPostgres(table string) Querier { + return &postgres{table: table} +} + +func NewMysql(table string) Querier { + return &mysql{table: table} +} + +func NewSqlite3(table string) Querier { + return &sqlite3{table: table} +} + +func NewSqlserver(table string) Querier { + return &sqlserver{table: table} +} + +func NewRedshift(table string) Querier { + return &redshift{table: table} +} + +func NewTidb(table string) Querier { + return &tidb{table: table} +} + +func NewClickhouse(table string) Querier { + return &clickhouse{table: table} +} + +func NewVertica(table string) Querier { + return &vertica{table: table} +} diff --git a/internal/dialect/dialectquery/mysql.go b/internal/dialect/dialectquery/mysql.go new file mode 100644 index 000000000..17ee7d6f1 --- /dev/null +++ b/internal/dialect/dialectquery/mysql.go @@ -0,0 +1,38 @@ +package dialectquery + +import "fmt" + +type mysql struct { + table string +} + +func (m *mysql) CreateTable() string { + q := `CREATE TABLE %s ( + id serial NOT NULL, + version_id bigint NOT NULL, + is_applied boolean NOT NULL, + tstamp timestamp NULL default now(), + PRIMARY KEY(id) + )` + return fmt.Sprintf(q, m.table) +} + +func (m *mysql) InsertVersion() string { + q := `INSERT INTO %s (version_id, is_applied) VALUES (?, ?)` + return fmt.Sprintf(q, m.table) +} + +func (m *mysql) DeleteVersion() string { + q := `DELETE FROM %s WHERE version_id=?` + return fmt.Sprintf(q, m.table) +} + +func (m *mysql) GetMigrationByVersion() string { + q := `SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1` + return fmt.Sprintf(q, m.table) +} + +func (m *mysql) ListMigrations() string { + q := `SELECT version_id, is_applied from %s ORDER BY id DESC` + return fmt.Sprintf(q, m.table) +} diff --git a/internal/dialect/dialectquery/postgres.go b/internal/dialect/dialectquery/postgres.go new file mode 100644 index 000000000..6f9fc9ab5 --- /dev/null +++ b/internal/dialect/dialectquery/postgres.go @@ -0,0 +1,38 @@ +package dialectquery + +import "fmt" + +type postgres struct { + table string +} + +func (p *postgres) CreateTable() string { + q := `CREATE TABLE %s ( + id serial NOT NULL, + version_id bigint NOT NULL, + is_applied boolean NOT NULL, + tstamp timestamp NULL default now(), + PRIMARY KEY(id) + )` + return fmt.Sprintf(q, p.table) +} + +func (p *postgres) InsertVersion() string { + q := `INSERT INTO %s (version_id, is_applied) VALUES ($1, $2)` + return fmt.Sprintf(q, p.table) +} + +func (p *postgres) DeleteVersion() string { + q := `DELETE FROM %s WHERE version_id=$1` + return fmt.Sprintf(q, p.table) +} + +func (p *postgres) GetMigrationByVersion() string { + q := `SELECT tstamp, is_applied FROM %s WHERE version_id=$1 ORDER BY tstamp DESC LIMIT 1` + return fmt.Sprintf(q, p.table) +} + +func (p *postgres) ListMigrations() string { + q := `SELECT version_id, is_applied from %s ORDER BY id DESC` + return fmt.Sprintf(q, p.table) +} diff --git a/internal/dialect/dialectquery/redshift.go b/internal/dialect/dialectquery/redshift.go new file mode 100644 index 000000000..3ba0cd286 --- /dev/null +++ b/internal/dialect/dialectquery/redshift.go @@ -0,0 +1,38 @@ +package dialectquery + +import "fmt" + +type redshift struct { + table string +} + +func (r *redshift) CreateTable() string { + q := `CREATE TABLE %s ( + id integer NOT NULL identity(1, 1), + version_id bigint NOT NULL, + is_applied boolean NOT NULL, + tstamp timestamp NULL default sysdate, + PRIMARY KEY(id) + )` + return fmt.Sprintf(q, r.table) +} + +func (r *redshift) InsertVersion() string { + q := `INSERT INTO %s (version_id, is_applied) VALUES ($1, $2)` + return fmt.Sprintf(q, r.table) +} + +func (r *redshift) DeleteVersion() string { + q := `DELETE FROM %s WHERE version_id=$1` + return fmt.Sprintf(q, r.table) +} + +func (r *redshift) GetMigrationByVersion() string { + q := `SELECT tstamp, is_applied FROM %s WHERE version_id=$1 ORDER BY tstamp DESC LIMIT 1` + return fmt.Sprintf(q, r.table) +} + +func (r *redshift) ListMigrations() string { + q := `SELECT version_id, is_applied from %s ORDER BY id DESC` + return fmt.Sprintf(q, r.table) +} diff --git a/internal/dialect/dialectquery/sqlite3.go b/internal/dialect/dialectquery/sqlite3.go new file mode 100644 index 000000000..ecafb7d00 --- /dev/null +++ b/internal/dialect/dialectquery/sqlite3.go @@ -0,0 +1,37 @@ +package dialectquery + +import "fmt" + +type sqlite3 struct { + table string +} + +func (s *sqlite3) CreateTable() string { + q := `CREATE TABLE %s ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + version_id INTEGER NOT NULL, + is_applied INTEGER NOT NULL, + tstamp TIMESTAMP DEFAULT (datetime('now')) + )` + return fmt.Sprintf(q, s.table) +} + +func (s *sqlite3) InsertVersion() string { + q := `INSERT INTO %s (version_id, is_applied) VALUES (?, ?)` + return fmt.Sprintf(q, s.table) +} + +func (s *sqlite3) DeleteVersion() string { + q := `DELETE FROM %s WHERE version_id=?` + return fmt.Sprintf(q, s.table) +} + +func (s *sqlite3) GetMigrationByVersion() string { + q := `SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1` + return fmt.Sprintf(q, s.table) +} + +func (s *sqlite3) ListMigrations() string { + q := `SELECT version_id, is_applied from %s ORDER BY id DESC` + return fmt.Sprintf(q, s.table) +} diff --git a/internal/dialect/dialectquery/sqlserver.go b/internal/dialect/dialectquery/sqlserver.go new file mode 100644 index 000000000..28d426304 --- /dev/null +++ b/internal/dialect/dialectquery/sqlserver.go @@ -0,0 +1,49 @@ +package dialectquery + +import "fmt" + +type sqlserver struct { + table string +} + +func (s *sqlserver) CreateTable() string { + q := `CREATE TABLE %s ( + id INT NOT NULL IDENTITY(1,1) PRIMARY KEY, + version_id BIGINT NOT NULL, + is_applied BIT NOT NULL, + tstamp DATETIME NULL DEFAULT CURRENT_TIMESTAMP + )` + return fmt.Sprintf(q, s.table) +} + +func (s *sqlserver) InsertVersion() string { + q := `INSERT INTO %s (version_id, is_applied) VALUES (@p1, @p2)` + return fmt.Sprintf(q, s.table) +} + +func (s *sqlserver) DeleteVersion() string { + q := `DELETE FROM %s WHERE version_id=@p1` + return fmt.Sprintf(q, s.table) +} + +func (s *sqlserver) GetMigrationByVersion() string { + q := ` +WITH Migrations AS +( + SELECT tstamp, is_applied, + ROW_NUMBER() OVER (ORDER BY tstamp) AS 'RowNumber' + FROM %s + WHERE version_id=@p1 +) +SELECT tstamp, is_applied +FROM Migrations +WHERE RowNumber BETWEEN 1 AND 2 +ORDER BY tstamp DESC +` + return fmt.Sprintf(q, s.table) +} + +func (s *sqlserver) ListMigrations() string { + q := `SELECT version_id, is_applied FROM %s ORDER BY id DESC` + return fmt.Sprintf(q, s.table) +} diff --git a/internal/dialect/dialectquery/tidb.go b/internal/dialect/dialectquery/tidb.go new file mode 100644 index 000000000..bf96ccc47 --- /dev/null +++ b/internal/dialect/dialectquery/tidb.go @@ -0,0 +1,38 @@ +package dialectquery + +import "fmt" + +type tidb struct { + table string +} + +func (t *tidb) CreateTable() string { + q := `CREATE TABLE %s ( + id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT UNIQUE, + version_id bigint NOT NULL, + is_applied boolean NOT NULL, + tstamp timestamp NULL default now(), + PRIMARY KEY(id) + )` + return fmt.Sprintf(q, t.table) +} + +func (t *tidb) InsertVersion() string { + q := `INSERT INTO %s (version_id, is_applied) VALUES (?, ?)` + return fmt.Sprintf(q, t.table) +} + +func (t *tidb) DeleteVersion() string { + q := `DELETE FROM %s WHERE version_id=?` + return fmt.Sprintf(q, t.table) +} + +func (t *tidb) GetMigrationByVersion() string { + q := `SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1` + return fmt.Sprintf(q, t.table) +} + +func (t *tidb) ListMigrations() string { + q := `SELECT version_id, is_applied from %s ORDER BY id DESC` + return fmt.Sprintf(q, t.table) +} diff --git a/internal/dialect/dialectquery/vertica.go b/internal/dialect/dialectquery/vertica.go new file mode 100644 index 000000000..67a091257 --- /dev/null +++ b/internal/dialect/dialectquery/vertica.go @@ -0,0 +1,38 @@ +package dialectquery + +import "fmt" + +type vertica struct { + table string +} + +func (v *vertica) CreateTable() string { + q := `CREATE TABLE %s ( + id identity(1,1) NOT NULL, + version_id bigint NOT NULL, + is_applied boolean NOT NULL, + tstamp timestamp NULL default now(), + PRIMARY KEY(id) + )` + return fmt.Sprintf(q, v.table) +} + +func (v *vertica) InsertVersion() string { + q := `INSERT INTO %s (version_id, is_applied) VALUES (?, ?)` + return fmt.Sprintf(q, v.table) +} + +func (v *vertica) DeleteVersion() string { + q := `DELETE FROM %s WHERE version_id=?` + return fmt.Sprintf(q, v.table) +} + +func (v *vertica) GetMigrationByVersion() string { + q := `SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1` + return fmt.Sprintf(q, v.table) +} + +func (v *vertica) ListMigrations() string { + q := `SELECT version_id, is_applied from %s ORDER BY id DESC` + return fmt.Sprintf(q, v.table) +} diff --git a/internal/dialect/dialects.go b/internal/dialect/dialects.go new file mode 100644 index 000000000..377140c62 --- /dev/null +++ b/internal/dialect/dialects.go @@ -0,0 +1,15 @@ +package dialect + +// Dialect is the type of database dialect. +type Dialect string + +const ( + Postgres Dialect = "postgres" + Mysql Dialect = "mysql" + Sqlite3 Dialect = "sqlite3" + Sqlserver Dialect = "sqlserver" + Redshift Dialect = "redshift" + Tidb Dialect = "tidb" + Clickhouse Dialect = "clickhouse" + Vertica Dialect = "vertica" +) diff --git a/internal/dialect/store.go b/internal/dialect/store.go new file mode 100644 index 000000000..1b330e3b3 --- /dev/null +++ b/internal/dialect/store.go @@ -0,0 +1,164 @@ +package dialect + +import ( + "context" + "database/sql" + "errors" + "fmt" + "time" + + "github.com/pressly/goose/v3/internal/dialect/dialectquery" +) + +// Store is the interface that wraps the basic methods for a database dialect. +// +// A dialect is a set of SQL statements that are specific to a database. +// +// By defining a store interface, we can support multiple databases +// with a single codebase. +// +// The underlying implementation does not modify the error. It is the callers +// responsibility to assert for the correct error, such as sql.ErrNoRows. +type Store interface { + // CreateVersionTable creates the version table within a transaction. + // This table is used to store goose migrations. + CreateVersionTable(ctx context.Context, tx *sql.Tx) error + + // InsertVersion inserts a version id into the version table within a transaction. + InsertVersion(ctx context.Context, tx *sql.Tx, version int64) error + // InsertVersionNoTx inserts a version id into the version table without a transaction. + InsertVersionNoTx(ctx context.Context, db *sql.DB, version int64) error + + // DeleteVersion deletes a version id from the version table within a transaction. + DeleteVersion(ctx context.Context, tx *sql.Tx, version int64) error + // DeleteVersionNoTx deletes a version id from the version table without a transaction. + DeleteVersionNoTx(ctx context.Context, db *sql.DB, version int64) error + + // GetMigrationRow retrieves a single migration by version id. + // + // Returns the raw sql error if the query fails. It is the callers responsibility + // to assert for the correct error, such as sql.ErrNoRows. + GetMigration(ctx context.Context, db *sql.DB, version int64) (*GetMigrationResult, error) + + // ListMigrations retrieves all migrations sorted in descending order by id. + // + // If there are no migrations, an empty slice is returned with no error. + ListMigrations(ctx context.Context, db *sql.DB) ([]*ListMigrationsResult, error) +} + +// NewStore returns a new Store for the given dialect. +// +// The table name is used to store the goose migrations. +func NewStore(d Dialect, table string) (Store, error) { + if table == "" { + return nil, errors.New("table name cannot be empty") + } + var querier dialectquery.Querier + switch d { + case Postgres: + querier = dialectquery.NewPostgres(table) + case Mysql: + querier = dialectquery.NewMysql(table) + case Sqlite3: + querier = dialectquery.NewSqlite3(table) + case Sqlserver: + querier = dialectquery.NewSqlserver(table) + case Redshift: + querier = dialectquery.NewRedshift(table) + case Tidb: + querier = dialectquery.NewTidb(table) + case Clickhouse: + querier = dialectquery.NewClickhouse(table) + case Vertica: + querier = dialectquery.NewVertica(table) + default: + return nil, fmt.Errorf("unknown querier dialect: %v", d) + } + return &store{querier: querier}, nil +} + +type GetMigrationResult struct { + IsApplied bool + Timestamp time.Time +} + +type ListMigrationsResult struct { + VersionID int64 + IsApplied bool +} + +type store struct { + querier dialectquery.Querier +} + +var _ Store = (*store)(nil) + +func (s *store) CreateVersionTable(ctx context.Context, tx *sql.Tx) error { + q := s.querier.CreateTable() + _, err := tx.ExecContext(ctx, q) + return err +} + +func (s *store) InsertVersion(ctx context.Context, tx *sql.Tx, version int64) error { + q := s.querier.InsertVersion() + _, err := tx.ExecContext(ctx, q, version, true) + return err +} + +func (s *store) InsertVersionNoTx(ctx context.Context, db *sql.DB, version int64) error { + q := s.querier.InsertVersion() + _, err := db.ExecContext(ctx, q, version, true) + return err +} + +func (s *store) DeleteVersion(ctx context.Context, tx *sql.Tx, version int64) error { + q := s.querier.DeleteVersion() + _, err := tx.ExecContext(ctx, q, version) + return err +} + +func (s *store) DeleteVersionNoTx(ctx context.Context, db *sql.DB, version int64) error { + q := s.querier.DeleteVersion() + _, err := db.ExecContext(ctx, q, version) + return err +} + +func (s *store) GetMigration(ctx context.Context, db *sql.DB, version int64) (*GetMigrationResult, error) { + q := s.querier.GetMigrationByVersion() + var timestamp time.Time + var isApplied bool + err := db.QueryRowContext(ctx, q, version).Scan(×tamp, &isApplied) + if err != nil { + return nil, err + } + return &GetMigrationResult{ + IsApplied: isApplied, + Timestamp: timestamp, + }, nil +} + +func (s *store) ListMigrations(ctx context.Context, db *sql.DB) ([]*ListMigrationsResult, error) { + q := s.querier.ListMigrations() + rows, err := db.QueryContext(ctx, q) + if err != nil { + return nil, err + } + defer rows.Close() + + var migrations []*ListMigrationsResult + for rows.Next() { + var version int64 + var isApplied bool + if err := rows.Scan(&version, &isApplied); err != nil { + return nil, err + } + migrations = append(migrations, &ListMigrationsResult{ + VersionID: version, + IsApplied: isApplied, + }) + } + if err := rows.Err(); err != nil { + return nil, err + } + return migrations, nil +} diff --git a/migrate.go b/migrate.go index 02da090df..63faaee17 100644 --- a/migrate.go +++ b/migrate.go @@ -1,6 +1,7 @@ package goose import ( + "context" "database/sql" "errors" "fmt" @@ -294,74 +295,55 @@ func versionFilter(v, current, target int64) bool { // EnsureDBVersion retrieves the current version for this DB. // Create and initialize the DB version table if it doesn't exist. func EnsureDBVersion(db *sql.DB) (int64, error) { - rows, err := GetDialect().dbVersionQuery(db) + ctx := context.Background() + dbMigrations, err := store.ListMigrations(ctx, db) if err != nil { - return 0, createVersionTable(db) + return 0, createVersionTable(ctx, db) } - defer rows.Close() - // The most recent record for each migration specifies // whether it has been applied or rolled back. // The first version we find that has been applied is the current version. - - toSkip := make([]int64, 0) - - for rows.Next() { - var row MigrationRecord - if err = rows.Scan(&row.VersionID, &row.IsApplied); err != nil { - return 0, fmt.Errorf("failed to scan row: %w", err) - } - - // have we already marked this version to be skipped? - skip := false - for _, v := range toSkip { - if v == row.VersionID { - skip = true - break - } - } - - if skip { + // + // TODO(mf): for historic reasons, we continue to use the is_applied column, + // but at some point we need to deprecate this logic and ideally remove + // this column. + // + // For context, see: + // https://github.com/pressly/goose/pull/131#pullrequestreview-178409168 + // + // The dbMigrations list is expected to be ordered by descending ID. But + // in the future we should be able to query the last record only. + skipLookup := make(map[int64]struct{}) + for _, m := range dbMigrations { + // Have we already marked this version to be skipped? + if _, ok := skipLookup[m.VersionID]; ok { continue } - - // if version has been applied we're done - if row.IsApplied { - return row.VersionID, nil + // If version has been applied we are done. + if m.IsApplied { + return m.VersionID, nil } - - // latest version of migration has not been applied. - toSkip = append(toSkip, row.VersionID) - } - if err := rows.Err(); err != nil { - return 0, fmt.Errorf("failed to get next row: %w", err) + // Latest version of migration has not been applied. + skipLookup[m.VersionID] = struct{}{} } - return 0, ErrNoNextVersion } -// Create the db version table -// and insert the initial 0 value into it -func createVersionTable(db *sql.DB) error { +// createVersionTable creates the db version table and inserts the +// initial 0 value into it. +func createVersionTable(ctx context.Context, db *sql.DB) error { txn, err := db.Begin() if err != nil { return err } - - d := GetDialect() - - if _, err := txn.Exec(d.createVersionTableSQL()); err != nil { + if err := store.CreateVersionTable(ctx, txn); err != nil { _ = txn.Rollback() return err } - - version := 0 - applied := true - if _, err := txn.Exec(d.insertVersionSQL(), version, applied); err != nil { + if err := store.InsertVersion(ctx, txn, 0); err != nil { _ = txn.Rollback() return err } - return txn.Commit() } diff --git a/migration.go b/migration.go index 0762d6496..f3727338e 100644 --- a/migration.go +++ b/migration.go @@ -1,6 +1,7 @@ package goose import ( + "context" "database/sql" "errors" "fmt" @@ -38,7 +39,8 @@ func (m *Migration) String() string { // Up runs an up migration. func (m *Migration) Up(db *sql.DB) error { - if err := m.run(db, true); err != nil { + ctx := context.Background() + if err := m.run(ctx, db, true); err != nil { return err } return nil @@ -46,13 +48,14 @@ func (m *Migration) Up(db *sql.DB) error { // Down runs a down migration. func (m *Migration) Down(db *sql.DB) error { - if err := m.run(db, false); err != nil { + ctx := context.Background() + if err := m.run(ctx, db, false); err != nil { return err } return nil } -func (m *Migration) run(db *sql.DB, direction bool) error { +func (m *Migration) run(ctx context.Context, db *sql.DB, direction bool) error { switch filepath.Ext(m.Source) { case ".sql": f, err := baseFS.Open(m.Source) @@ -67,7 +70,7 @@ func (m *Migration) run(db *sql.DB, direction bool) error { } start := time.Now() - if err := runSQLMigration(db, statements, useTx, m.Version, direction, m.noVersioning); err != nil { + if err := runSQLMigration(ctx, db, statements, useTx, m.Version, direction, m.noVersioning); err != nil { return fmt.Errorf("ERROR %v: failed to run SQL migration: %w", filepath.Base(m.Source), err) } finish := truncateDuration(time.Since(start)) @@ -92,6 +95,7 @@ func (m *Migration) run(db *sql.DB, direction bool) error { } empty = (fn == nil) if err := runGoMigration( + ctx, db, fn, m.Version, @@ -108,6 +112,7 @@ func (m *Migration) run(db *sql.DB, direction bool) error { } empty = (fn == nil) if err := runGoMigrationNoTx( + ctx, db, fn, m.Version, @@ -128,6 +133,7 @@ func (m *Migration) run(db *sql.DB, direction bool) error { } func runGoMigrationNoTx( + ctx context.Context, db *sql.DB, fn GoMigrationNoTx, version int64, @@ -141,12 +147,13 @@ func runGoMigrationNoTx( } } if recordVersion { - return insertOrDeleteVersionNoTx(db, version, direction) + return insertOrDeleteVersionNoTx(ctx, db, version, direction) } return nil } func runGoMigration( + ctx context.Context, db *sql.DB, fn GoMigration, version int64, @@ -168,7 +175,7 @@ func runGoMigration( } } if recordVersion { - if err := insertOrDeleteVersion(tx, version, direction); err != nil { + if err := insertOrDeleteVersion(ctx, tx, version, direction); err != nil { _ = tx.Rollback() return fmt.Errorf("failed to update version: %w", err) } @@ -179,22 +186,18 @@ func runGoMigration( return nil } -func insertOrDeleteVersion(tx *sql.Tx, version int64, direction bool) error { +func insertOrDeleteVersion(ctx context.Context, tx *sql.Tx, version int64, direction bool) error { if direction { - _, err := tx.Exec(GetDialect().insertVersionSQL(), version, direction) - return err + return store.InsertVersion(ctx, tx, version) } - _, err := tx.Exec(GetDialect().deleteVersionSQL(), version) - return err + return store.DeleteVersion(ctx, tx, version) } -func insertOrDeleteVersionNoTx(db *sql.DB, version int64, direction bool) error { +func insertOrDeleteVersionNoTx(ctx context.Context, db *sql.DB, version int64, direction bool) error { if direction { - _, err := db.Exec(GetDialect().insertVersionSQL(), version, direction) - return err + return store.InsertVersionNoTx(ctx, db, version) } - _, err := db.Exec(GetDialect().deleteVersionSQL(), version) - return err + return store.DeleteVersionNoTx(ctx, db, version) } // NumericComponent looks for migration scripts with names in the form: diff --git a/migration_sql.go b/migration_sql.go index aea456562..d45fa64b5 100644 --- a/migration_sql.go +++ b/migration_sql.go @@ -1,10 +1,10 @@ package goose import ( + "context" "database/sql" "fmt" "regexp" - "time" ) // Run a migration specified in raw SQL. @@ -14,8 +14,16 @@ import ( // be applied during an Up or Down migration // // All statements following an Up or Down annotation are grouped together -// until another direction is found. -func runSQLMigration(db *sql.DB, statements []string, useTx bool, v int64, direction bool, noVersioning bool) error { +// until another direction annotation is found. +func runSQLMigration( + ctx context.Context, + db *sql.DB, + statements []string, + useTx bool, + v int64, + direction bool, + noVersioning bool, +) error { if useTx { // TRANSACTION. @@ -28,7 +36,7 @@ func runSQLMigration(db *sql.DB, statements []string, useTx bool, v int64, direc for _, query := range statements { verboseInfo("Executing statement: %s\n", clearStatement(query)) - if err = execQuery(tx.Exec, query); err != nil { + if _, err := tx.ExecContext(ctx, query); err != nil { verboseInfo("Rollback transaction") _ = tx.Rollback() return fmt.Errorf("failed to execute SQL query %q: %w", clearStatement(query), err) @@ -37,13 +45,13 @@ func runSQLMigration(db *sql.DB, statements []string, useTx bool, v int64, direc if !noVersioning { if direction { - if err := execQuery(tx.Exec, GetDialect().insertVersionSQL(), v, direction); err != nil { + if err := store.InsertVersion(ctx, tx, v); err != nil { verboseInfo("Rollback transaction") _ = tx.Rollback() return fmt.Errorf("failed to insert new goose version: %w", err) } } else { - if err := execQuery(tx.Exec, GetDialect().deleteVersionSQL(), v); err != nil { + if err := store.DeleteVersion(ctx, tx, v); err != nil { verboseInfo("Rollback transaction") _ = tx.Rollback() return fmt.Errorf("failed to delete goose version: %w", err) @@ -62,17 +70,17 @@ func runSQLMigration(db *sql.DB, statements []string, useTx bool, v int64, direc // NO TRANSACTION. for _, query := range statements { verboseInfo("Executing statement: %s", clearStatement(query)) - if err := execQuery(db.Exec, query); err != nil { + if _, err := db.ExecContext(ctx, query); err != nil { return fmt.Errorf("failed to execute SQL query %q: %w", clearStatement(query), err) } } if !noVersioning { if direction { - if err := execQuery(db.Exec, GetDialect().insertVersionSQL(), v, direction); err != nil { + if err := store.InsertVersionNoTx(ctx, db, v); err != nil { return fmt.Errorf("failed to insert new goose version: %w", err) } } else { - if err := execQuery(db.Exec, GetDialect().deleteVersionSQL(), v); err != nil { + if err := store.DeleteVersionNoTx(ctx, db, v); err != nil { return fmt.Errorf("failed to delete goose version: %w", err) } } @@ -81,32 +89,6 @@ func runSQLMigration(db *sql.DB, statements []string, useTx bool, v int64, direc return nil } -func execQuery(fn func(string, ...interface{}) (sql.Result, error), query string, args ...interface{}) error { - if !verbose { - _, err := fn(query, args...) - return err - } - - ch := make(chan error) - - go func() { - _, err := fn(query, args...) - ch <- err - }() - - t := time.Now() - ticker := time.NewTicker(time.Minute) - defer ticker.Stop() - for { - select { - case err := <-ch: - return err - case <-ticker.C: - verboseInfo("Executing statement still in progress for %v", time.Since(t).Round(time.Second)) - } - } -} - const ( grayColor = "\033[90m" resetColor = "\033[00m" diff --git a/reset.go b/reset.go index 258841fad..7be46179c 100644 --- a/reset.go +++ b/reset.go @@ -1,6 +1,7 @@ package goose import ( + "context" "database/sql" "fmt" "sort" @@ -8,6 +9,7 @@ import ( // Reset rolls back all migrations func Reset(db *sql.DB, dir string, opts ...OptionsFunc) error { + ctx := context.Background() option := &options{} for _, f := range opts { f(option) @@ -20,7 +22,7 @@ func Reset(db *sql.DB, dir string, opts ...OptionsFunc) error { return DownTo(db, dir, minVersion, opts...) } - statuses, err := dbMigrationsStatus(db) + statuses, err := dbMigrationsStatus(ctx, db) if err != nil { return fmt.Errorf("failed to get status of migrations: %w", err) } @@ -38,30 +40,20 @@ func Reset(db *sql.DB, dir string, opts ...OptionsFunc) error { return nil } -func dbMigrationsStatus(db *sql.DB) (map[int64]bool, error) { - rows, err := GetDialect().dbVersionQuery(db) +func dbMigrationsStatus(ctx context.Context, db *sql.DB) (map[int64]bool, error) { + dbMigrations, err := store.ListMigrations(ctx, db) if err != nil { - return map[int64]bool{}, nil + return nil, err } - defer rows.Close() - // The most recent record for each migration specifies // whether it has been applied or rolled back. + results := make(map[int64]bool) - result := make(map[int64]bool) - - for rows.Next() { - var row MigrationRecord - if err = rows.Scan(&row.VersionID, &row.IsApplied); err != nil { - return nil, fmt.Errorf("failed to scan row: %w", err) - } - - if _, ok := result[row.VersionID]; ok { + for _, m := range dbMigrations { + if _, ok := results[m.VersionID]; ok { continue } - - result[row.VersionID] = row.IsApplied + results[m.VersionID] = m.IsApplied } - - return result, nil + return results, nil } diff --git a/status.go b/status.go index f53f1bece..73e73decf 100644 --- a/status.go +++ b/status.go @@ -1,7 +1,9 @@ package goose import ( + "context" "database/sql" + "errors" "fmt" "path/filepath" "time" @@ -9,6 +11,7 @@ import ( // Status prints the status of all migrations. func Status(db *sql.DB, dir string, opts ...OptionsFunc) error { + ctx := context.Background() option := &options{} for _, f := range opts { f(option) @@ -34,7 +37,7 @@ func Status(db *sql.DB, dir string, opts ...OptionsFunc) error { log.Println(" Applied At Migration") log.Println(" =======================================") for _, migration := range migrations { - if err := printMigrationStatus(db, migration.Version, filepath.Base(migration.Source)); err != nil { + if err := printMigrationStatus(ctx, db, migration.Version, filepath.Base(migration.Source)); err != nil { return fmt.Errorf("failed to print status: %w", err) } } @@ -42,23 +45,15 @@ func Status(db *sql.DB, dir string, opts ...OptionsFunc) error { return nil } -func printMigrationStatus(db *sql.DB, version int64, script string) error { - q := GetDialect().migrationSQL() - - var row MigrationRecord - - err := db.QueryRow(q, version).Scan(&row.TStamp, &row.IsApplied) - if err != nil && err != sql.ErrNoRows { +func printMigrationStatus(ctx context.Context, db *sql.DB, version int64, script string) error { + m, err := store.GetMigration(ctx, db, version) + if err != nil && !errors.Is(err, sql.ErrNoRows) { return fmt.Errorf("failed to query the latest migration: %w", err) } - - var appliedAt string - if row.IsApplied { - appliedAt = row.TStamp.Format(time.ANSIC) - } else { - appliedAt = "Pending" + appliedAt := "Pending" + if m != nil && m.IsApplied { + appliedAt = m.Timestamp.Format(time.ANSIC) } - log.Printf(" %-24s -- %v\n", appliedAt, script) return nil } diff --git a/up.go b/up.go index 1d668e38c..bc8ddc789 100644 --- a/up.go +++ b/up.go @@ -1,6 +1,7 @@ package goose import ( + "context" "database/sql" "errors" "fmt" @@ -34,6 +35,7 @@ func withApplyUpByOne() OptionsFunc { // UpTo migrates up to a specific version. func UpTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error { + ctx := context.Background() option := &options{} for _, f := range opts { f(option) @@ -58,7 +60,7 @@ func UpTo(db *sql.DB, dir string, version int64, opts ...OptionsFunc) error { if _, err := EnsureDBVersion(db); err != nil { return err } - dbMigrations, err := listAllDBVersions(db) + dbMigrations, err := listAllDBVersions(ctx, db) if err != nil { return err } @@ -222,28 +224,19 @@ func UpByOne(db *sql.DB, dir string, opts ...OptionsFunc) error { // listAllDBVersions returns a list of all migrations, ordered ascending. // TODO(mf): fairly cheap, but a nice-to-have is pagination support. -func listAllDBVersions(db *sql.DB) (Migrations, error) { - rows, err := GetDialect().dbVersionQuery(db) +func listAllDBVersions(ctx context.Context, db *sql.DB) (Migrations, error) { + dbMigrations, err := store.ListMigrations(ctx, db) if err != nil { - return nil, createVersionTable(db) + return nil, err } - var all Migrations - for rows.Next() { - var versionID int64 - var isApplied bool - if err := rows.Scan(&versionID, &isApplied); err != nil { - return nil, err - } + all := make(Migrations, 0, len(dbMigrations)) + for _, m := range dbMigrations { all = append(all, &Migration{ - Version: versionID, + Version: m.VersionID, }) } - if err := rows.Close(); err != nil { - return nil, err - } - if err := rows.Err(); err != nil { - return nil, err - } + // ListMigrations returns migrations in descending order by id. + // But we want to return them in ascending order by version_id, so we re-sort. sort.SliceStable(all, func(i, j int) bool { return all[i].Version < all[j].Version })