From 93e12046f846cb0607e765f0930c6642808cf0dc Mon Sep 17 00:00:00 2001 From: Mike Fridman Date: Thu, 26 Oct 2023 09:48:01 -0400 Subject: [PATCH 1/8] wip --- database/dialect/dialect.go | 130 +++++++++++++++++ database/dialect/dialect_test.go | 238 +++++++++++++++++++++++++++++++ database/sql_extended.go | 23 +++ database/store.go | 39 +++++ 4 files changed, 430 insertions(+) create mode 100644 database/dialect/dialect.go create mode 100644 database/dialect/dialect_test.go create mode 100644 database/sql_extended.go create mode 100644 database/store.go diff --git a/database/dialect/dialect.go b/database/dialect/dialect.go new file mode 100644 index 000000000..414a66e77 --- /dev/null +++ b/database/dialect/dialect.go @@ -0,0 +1,130 @@ +package dialect + +import ( + "context" + "errors" + "fmt" + + "github.com/pressly/goose/v3/database" + "github.com/pressly/goose/v3/internal/dialect/dialectquery" +) + +// Dialect is the type of database dialect. +type Dialect string + +const ( + ClickHouse Dialect = "clickhouse" + MSSQL Dialect = "mssql" + MySQL Dialect = "mysql" + Postgres Dialect = "postgres" + Redshift Dialect = "redshift" + SQLite3 Dialect = "sqlite3" + TiDB Dialect = "tidb" + Vertica Dialect = "vertica" + YdB Dialect = "ydb" + + // Custom is a special dialect that allows users to provide their own [database.Store] + // implementation when constructing a [goose.Provider]. + Custom Dialect = "custom" +) + +// NewStore returns a new [Store] backed by the given dialect. +func NewStore(dialect Dialect, tablename string) (database.Store, error) { + if tablename == "" { + return nil, errors.New("tablename must not be empty") + } + if dialect == "" { + return nil, errors.New("dialect must not be empty") + } + if dialect == Custom { + return nil, errors.New("dialect must not be custom") + } + lookup := map[Dialect]dialectquery.Querier{ + ClickHouse: &dialectquery.Clickhouse{}, + MSSQL: &dialectquery.Sqlserver{}, + MySQL: &dialectquery.Mysql{}, + Postgres: &dialectquery.Postgres{}, + Redshift: &dialectquery.Redshift{}, + SQLite3: &dialectquery.Sqlite3{}, + TiDB: &dialectquery.Tidb{}, + Vertica: &dialectquery.Vertica{}, + } + querier, ok := lookup[dialect] + if !ok { + return nil, fmt.Errorf("unknown dialect: %q", dialect) + } + return &store{ + tablename: tablename, + querier: querier, + }, nil +} + +type store struct { + tablename string + querier dialectquery.Querier +} + +var _ database.Store = (*store)(nil) + +func (s *store) CreateVersionTable(ctx context.Context, db database.DBTxConn) error { + q := s.querier.CreateTable(s.tablename) + if _, err := db.ExecContext(ctx, q); err != nil { + return fmt.Errorf("failed to create version table %q: %w", s.tablename, err) + } + return nil +} + +func (s *store) InsertOrDelete(ctx context.Context, db database.DBTxConn, direction bool, version int64) error { + if direction { + q := s.querier.InsertVersion(s.tablename) + if _, err := db.ExecContext(ctx, q, version, true); err != nil { + return fmt.Errorf("failed to insert version %d: %w", version, err) + } + return nil + } + q := s.querier.DeleteVersion(s.tablename) + if _, err := db.ExecContext(ctx, q, version); err != nil { + return fmt.Errorf("failed to delete version %d: %w", version, err) + } + return nil +} + +func (s *store) GetMigration( + ctx context.Context, + db database.DBTxConn, + version int64) (*database.GetMigrationResult, error) { + q := s.querier.GetMigrationByVersion(s.tablename) + var result database.GetMigrationResult + if err := db.QueryRowContext(ctx, q, version).Scan( + &result.Timestamp, + &result.IsApplied, + ); err != nil { + return nil, fmt.Errorf("failed to get migration %d: %w", version, err) + } + return &result, nil +} + +func (s *store) ListMigrations( + ctx context.Context, + db database.DBTxConn, +) ([]*database.ListMigrationsResult, error) { + q := s.querier.ListMigrations(s.tablename) + rows, err := db.QueryContext(ctx, q) + if err != nil { + return nil, fmt.Errorf("failed to list migrations: %w", err) + } + defer rows.Close() + + var migrations []*database.ListMigrationsResult + for rows.Next() { + var result database.ListMigrationsResult + if err := rows.Scan(&result.Version, &result.IsApplied); err != nil { + return nil, fmt.Errorf("failed to scan list migrations result: %w", err) + } + migrations = append(migrations, &result) + } + if err := rows.Err(); err != nil { + return nil, err + } + return migrations, nil +} diff --git a/database/dialect/dialect_test.go b/database/dialect/dialect_test.go new file mode 100644 index 000000000..6375a4272 --- /dev/null +++ b/database/dialect/dialect_test.go @@ -0,0 +1,238 @@ +package dialect_test + +import ( + "context" + "database/sql" + "errors" + "path/filepath" + "testing" + + "github.com/jackc/pgx/v5/pgconn" + "github.com/pressly/goose/v3/database/dialect" + "github.com/pressly/goose/v3/internal/check" + "github.com/pressly/goose/v3/internal/testdb" + "go.uber.org/multierr" + "modernc.org/sqlite" +) + +// The goal of this test is to verify the database store package works as expected. This test is not +// meant to be exhaustive or test every possible database dialect. It is meant to verify the Store +// interface works against a real database. + +func TestDialectStore(t *testing.T) { + t.Parallel() + t.Run("invalid", func(t *testing.T) { + // Test empty table name. + _, err := dialect.NewStore(dialect.SQLite3, "") + check.HasError(t, err) + // Test unknown dialect. + _, err = dialect.NewStore("unknown-dialect", "foo") + check.HasError(t, err) + // Test empty dialect. + _, err = dialect.NewStore("", "foo") + check.HasError(t, err) + _, err = dialect.NewStore(dialect.Custom, "foo") + check.HasError(t, err) + }) + t.Run("postgres", func(t *testing.T) { + if testing.Short() { + t.Skip("skip long-running test") + } + // Test postgres specific behavior. + db, cleanup, err := testdb.NewPostgres() + check.NoError(t, err) + t.Cleanup(cleanup) + testStore(context.Background(), t, dialect.Postgres, db, func(t *testing.T, err error) { + var pgErr *pgconn.PgError + ok := errors.As(err, &pgErr) + check.Bool(t, ok, true) + check.Equal(t, pgErr.Code, "42P07") // duplicate_table + }) + }) + // Test generic behavior. + t.Run("sqlite3", func(t *testing.T) { + dir := t.TempDir() + db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db")) + check.NoError(t, err) + testStore(context.Background(), t, dialect.SQLite3, db, func(t *testing.T, err error) { + var sqliteErr *sqlite.Error + ok := errors.As(err, &sqliteErr) + check.Bool(t, ok, true) + check.Equal(t, sqliteErr.Code(), 1) // Generic error (SQLITE_ERROR) + check.Contains(t, sqliteErr.Error(), "table test_goose_db_version already exists") + }) + }) + t.Run("ListMigrations", func(t *testing.T) { + dir := t.TempDir() + db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db")) + check.NoError(t, err) + store, err := dialect.NewStore(dialect.SQLite3, "foo") + check.NoError(t, err) + err = store.CreateVersionTable(context.Background(), db) + check.NoError(t, err) + check.NoError(t, store.InsertOrDelete(context.Background(), db, true, 1)) + check.NoError(t, store.InsertOrDelete(context.Background(), db, true, 3)) + check.NoError(t, store.InsertOrDelete(context.Background(), db, true, 2)) + res, err := store.ListMigrations(context.Background(), db) + check.NoError(t, err) + check.Number(t, len(res), 3) + // Check versions are in descending order: [2, 3, 1] + check.Number(t, res[0].Version, 2) + check.Number(t, res[1].Version, 3) + check.Number(t, res[2].Version, 1) + }) +} + +// testStore tests various store operations. +// +// If alreadyExists is not nil, it will be used to assert the error returned by CreateVersionTable +// when the version table already exists. +func testStore(ctx context.Context, t *testing.T, d dialect.Dialect, db *sql.DB, alreadyExists func(t *testing.T, err error)) { + const ( + tablename = "test_goose_db_version" + ) + store, err := dialect.NewStore(d, tablename) + check.NoError(t, err) + // Create the version table. + err = runTx(ctx, db, func(tx *sql.Tx) error { + return store.CreateVersionTable(ctx, tx) + }) + check.NoError(t, err) + // Create the version table again. This should fail. + err = runTx(ctx, db, func(tx *sql.Tx) error { + return store.CreateVersionTable(ctx, tx) + }) + check.HasError(t, err) + if alreadyExists != nil { + alreadyExists(t, err) + } + + // List migrations. There should be none. + err = runConn(ctx, db, func(conn *sql.Conn) error { + res, err := store.ListMigrations(ctx, conn) + check.NoError(t, err) + check.Number(t, len(res), 0) + return nil + }) + check.NoError(t, err) + + // Insert 5 migrations in addition to the zero migration. + for i := 0; i < 6; i++ { + err = runConn(ctx, db, func(conn *sql.Conn) error { + return store.InsertOrDelete(ctx, conn, true, int64(i)) + }) + check.NoError(t, err) + } + + // List migrations. There should be 6. + err = runConn(ctx, db, func(conn *sql.Conn) error { + res, err := store.ListMigrations(ctx, conn) + check.NoError(t, err) + check.Number(t, len(res), 6) + // Check versions are in descending order. + for i := 0; i < 6; i++ { + check.Number(t, res[i].Version, 5-i) + } + return nil + }) + check.NoError(t, err) + + // Delete 3 migrations backwards + for i := 5; i >= 3; i-- { + err = runConn(ctx, db, func(conn *sql.Conn) error { + return store.InsertOrDelete(ctx, conn, false, int64(i)) + }) + check.NoError(t, err) + } + + // List migrations. There should be 3. + err = runConn(ctx, db, func(conn *sql.Conn) error { + res, err := store.ListMigrations(ctx, conn) + check.NoError(t, err) + check.Number(t, len(res), 3) + // Check that the remaining versions are in descending order. + for i := 0; i < 3; i++ { + check.Number(t, res[i].Version, 2-i) + } + return nil + }) + check.NoError(t, err) + + // Get remaining migrations one by one. + for i := 0; i < 3; i++ { + err = runConn(ctx, db, func(conn *sql.Conn) error { + res, err := store.GetMigration(ctx, conn, int64(i)) + check.NoError(t, err) + check.Equal(t, res.IsApplied, true) + check.Equal(t, res.Timestamp.IsZero(), false) + return nil + }) + check.NoError(t, err) + } + + // Delete remaining migrations one by one and use all 3 connection types: + + // 1. *sql.Tx + err = runTx(ctx, db, func(tx *sql.Tx) error { + return store.InsertOrDelete(ctx, tx, false, 2) + }) + check.NoError(t, err) + // 2. *sql.Conn + err = runConn(ctx, db, func(conn *sql.Conn) error { + return store.InsertOrDelete(ctx, conn, false, 1) + }) + check.NoError(t, err) + // 3. *sql.DB + err = store.InsertOrDelete(ctx, db, false, 0) + check.NoError(t, err) + + // List migrations. There should be none. + err = runConn(ctx, db, func(conn *sql.Conn) error { + res, err := store.ListMigrations(ctx, conn) + check.NoError(t, err) + check.Number(t, len(res), 0) + return nil + }) + check.NoError(t, err) + + // Try to get a migration that does not exist. + err = runConn(ctx, db, func(conn *sql.Conn) error { + _, err := store.GetMigration(ctx, conn, 0) + check.HasError(t, err) + check.Bool(t, errors.Is(err, sql.ErrNoRows), true) + return nil + }) + check.NoError(t, err) +} + +func runTx(ctx context.Context, db *sql.DB, fn func(*sql.Tx) error) (retErr error) { + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return err + } + defer func() { + if retErr != nil { + retErr = multierr.Append(retErr, tx.Rollback()) + } + }() + if err := fn(tx); err != nil { + return err + } + return tx.Commit() +} + +func runConn(ctx context.Context, db *sql.DB, fn func(*sql.Conn) error) (retErr error) { + conn, err := db.Conn(ctx) + if err != nil { + return err + } + defer func() { + if retErr != nil { + retErr = multierr.Append(retErr, conn.Close()) + } + }() + if err := fn(conn); err != nil { + return err + } + return conn.Close() +} diff --git a/database/sql_extended.go b/database/sql_extended.go new file mode 100644 index 000000000..8eaa9399a --- /dev/null +++ b/database/sql_extended.go @@ -0,0 +1,23 @@ +package database + +import ( + "context" + "database/sql" +) + +// DBTxConn is a thin interface for common methods that is satisfied by *sql.DB, *sql.Tx and +// *sql.Conn. +// +// There is a long outstanding issue to formalize a std lib interface, but alas. See: +// https://github.com/golang/go/issues/14468 +type DBTxConn interface { + ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) + QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) + QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row +} + +var ( + _ DBTxConn = (*sql.DB)(nil) + _ DBTxConn = (*sql.Tx)(nil) + _ DBTxConn = (*sql.Conn)(nil) +) diff --git a/database/store.go b/database/store.go new file mode 100644 index 000000000..4e2ff0c2e --- /dev/null +++ b/database/store.go @@ -0,0 +1,39 @@ +package database + +import ( + "context" + "time" +) + +// Store is an interface that defines methods for managing database migrations and versioning. By +// defining a Store interface, we can support multiple databases with consistent functionality. +// +// Each database dialect requires a specific implementation of this interface. A dialect represents +// a set of SQL statements specific to a particular database system. +type Store interface { + // CreateVersionTable creates the version table. This table is used to record applied + // migrations. + CreateVersionTable(ctx context.Context, db DBTxConn) error + + // InsertOrDelete inserts or deletes a version id from the version table. If direction is true, + // insert the version id. If direction is false, delete the version id. + InsertOrDelete(ctx context.Context, db DBTxConn, direction bool, version int64) error + + // GetMigration retrieves a single migration by version id. This method may return the raw sql + // error if the query fails so the caller can assert for errors such as [sql.ErrNoRows]. + GetMigration(ctx context.Context, db DBTxConn, version int64) (*GetMigrationResult, error) + + // ListMigrations retrieves all migrations sorted in descending order by id or timestamp. If + // there are no migrations, return empty slice with no error. + ListMigrations(ctx context.Context, db DBTxConn) ([]*ListMigrationsResult, error) +} + +type GetMigrationResult struct { + Timestamp time.Time + IsApplied bool +} + +type ListMigrationsResult struct { + Version int64 + IsApplied bool +} From a2aac0ebbc68c2377ba915c6b81400717fbfcad9 Mon Sep 17 00:00:00 2001 From: Mike Fridman Date: Thu, 26 Oct 2023 19:33:58 -0400 Subject: [PATCH 2/8] Add database package with Store --- database/{dialect => }/dialect.go | 66 +++---- database/{dialect => }/dialect_test.go | 22 +-- internal/provider/collect_test.go | 6 +- internal/provider/migration.go | 4 +- internal/provider/provider.go | 11 +- internal/provider/provider_test.go | 5 +- internal/provider/run.go | 6 +- internal/provider/run_test.go | 19 +- internal/sqladapter/sqladapter.go | 49 ----- internal/sqladapter/store.go | 111 ------------ internal/sqladapter/store_test.go | 237 ------------------------- internal/sqlextended/sqlextended.go | 23 --- 12 files changed, 69 insertions(+), 490 deletions(-) rename database/{dialect => }/dialect.go (62%) rename database/{dialect => }/dialect_test.go (89%) delete mode 100644 internal/sqladapter/sqladapter.go delete mode 100644 internal/sqladapter/store.go delete mode 100644 internal/sqladapter/store_test.go delete mode 100644 internal/sqlextended/sqlextended.go diff --git a/database/dialect/dialect.go b/database/dialect.go similarity index 62% rename from database/dialect/dialect.go rename to database/dialect.go index 414a66e77..d2bad7ccc 100644 --- a/database/dialect/dialect.go +++ b/database/dialect.go @@ -1,11 +1,10 @@ -package dialect +package database import ( "context" "errors" "fmt" - "github.com/pressly/goose/v3/database" "github.com/pressly/goose/v3/internal/dialect/dialectquery" ) @@ -13,41 +12,41 @@ import ( type Dialect string const ( - ClickHouse Dialect = "clickhouse" - MSSQL Dialect = "mssql" - MySQL Dialect = "mysql" - Postgres Dialect = "postgres" - Redshift Dialect = "redshift" - SQLite3 Dialect = "sqlite3" - TiDB Dialect = "tidb" - Vertica Dialect = "vertica" - YdB Dialect = "ydb" + DialectClickHouse Dialect = "clickhouse" + DialectMSSQL Dialect = "mssql" + DialectMySQL Dialect = "mysql" + DialectPostgres Dialect = "postgres" + DialectRedshift Dialect = "redshift" + DialectSQLite3 Dialect = "sqlite3" + DialectTiDB Dialect = "tidb" + DialectVertica Dialect = "vertica" + DialectYdB Dialect = "ydb" - // Custom is a special dialect that allows users to provide their own [database.Store] + // DialectCustom is a special dialect that allows users to provide their own [Store] // implementation when constructing a [goose.Provider]. - Custom Dialect = "custom" + DialectCustom Dialect = "custom" ) // NewStore returns a new [Store] backed by the given dialect. -func NewStore(dialect Dialect, tablename string) (database.Store, error) { +func NewStore(dialect Dialect, tablename string) (Store, error) { if tablename == "" { return nil, errors.New("tablename must not be empty") } if dialect == "" { return nil, errors.New("dialect must not be empty") } - if dialect == Custom { + if dialect == DialectCustom { return nil, errors.New("dialect must not be custom") } lookup := map[Dialect]dialectquery.Querier{ - ClickHouse: &dialectquery.Clickhouse{}, - MSSQL: &dialectquery.Sqlserver{}, - MySQL: &dialectquery.Mysql{}, - Postgres: &dialectquery.Postgres{}, - Redshift: &dialectquery.Redshift{}, - SQLite3: &dialectquery.Sqlite3{}, - TiDB: &dialectquery.Tidb{}, - Vertica: &dialectquery.Vertica{}, + DialectClickHouse: &dialectquery.Clickhouse{}, + DialectMSSQL: &dialectquery.Sqlserver{}, + DialectMySQL: &dialectquery.Mysql{}, + DialectPostgres: &dialectquery.Postgres{}, + DialectRedshift: &dialectquery.Redshift{}, + DialectSQLite3: &dialectquery.Sqlite3{}, + DialectTiDB: &dialectquery.Tidb{}, + DialectVertica: &dialectquery.Vertica{}, } querier, ok := lookup[dialect] if !ok { @@ -64,9 +63,9 @@ type store struct { querier dialectquery.Querier } -var _ database.Store = (*store)(nil) +var _ Store = (*store)(nil) -func (s *store) CreateVersionTable(ctx context.Context, db database.DBTxConn) error { +func (s *store) CreateVersionTable(ctx context.Context, db DBTxConn) error { q := s.querier.CreateTable(s.tablename) if _, err := db.ExecContext(ctx, q); err != nil { return fmt.Errorf("failed to create version table %q: %w", s.tablename, err) @@ -74,7 +73,7 @@ func (s *store) CreateVersionTable(ctx context.Context, db database.DBTxConn) er return nil } -func (s *store) InsertOrDelete(ctx context.Context, db database.DBTxConn, direction bool, version int64) error { +func (s *store) InsertOrDelete(ctx context.Context, db DBTxConn, direction bool, version int64) error { if direction { q := s.querier.InsertVersion(s.tablename) if _, err := db.ExecContext(ctx, q, version, true); err != nil { @@ -91,10 +90,11 @@ func (s *store) InsertOrDelete(ctx context.Context, db database.DBTxConn, direct func (s *store) GetMigration( ctx context.Context, - db database.DBTxConn, - version int64) (*database.GetMigrationResult, error) { + db DBTxConn, + version int64, +) (*GetMigrationResult, error) { q := s.querier.GetMigrationByVersion(s.tablename) - var result database.GetMigrationResult + var result GetMigrationResult if err := db.QueryRowContext(ctx, q, version).Scan( &result.Timestamp, &result.IsApplied, @@ -106,8 +106,8 @@ func (s *store) GetMigration( func (s *store) ListMigrations( ctx context.Context, - db database.DBTxConn, -) ([]*database.ListMigrationsResult, error) { + db DBTxConn, +) ([]*ListMigrationsResult, error) { q := s.querier.ListMigrations(s.tablename) rows, err := db.QueryContext(ctx, q) if err != nil { @@ -115,9 +115,9 @@ func (s *store) ListMigrations( } defer rows.Close() - var migrations []*database.ListMigrationsResult + var migrations []*ListMigrationsResult for rows.Next() { - var result database.ListMigrationsResult + var result ListMigrationsResult if err := rows.Scan(&result.Version, &result.IsApplied); err != nil { return nil, fmt.Errorf("failed to scan list migrations result: %w", err) } diff --git a/database/dialect/dialect_test.go b/database/dialect_test.go similarity index 89% rename from database/dialect/dialect_test.go rename to database/dialect_test.go index 6375a4272..875c174a6 100644 --- a/database/dialect/dialect_test.go +++ b/database/dialect_test.go @@ -1,4 +1,4 @@ -package dialect_test +package database_test import ( "context" @@ -8,7 +8,7 @@ import ( "testing" "github.com/jackc/pgx/v5/pgconn" - "github.com/pressly/goose/v3/database/dialect" + "github.com/pressly/goose/v3/database" "github.com/pressly/goose/v3/internal/check" "github.com/pressly/goose/v3/internal/testdb" "go.uber.org/multierr" @@ -23,15 +23,15 @@ func TestDialectStore(t *testing.T) { t.Parallel() t.Run("invalid", func(t *testing.T) { // Test empty table name. - _, err := dialect.NewStore(dialect.SQLite3, "") + _, err := database.NewStore(database.DialectSQLite3, "") check.HasError(t, err) // Test unknown dialect. - _, err = dialect.NewStore("unknown-dialect", "foo") + _, err = database.NewStore("unknown-dialect", "foo") check.HasError(t, err) // Test empty dialect. - _, err = dialect.NewStore("", "foo") + _, err = database.NewStore("", "foo") check.HasError(t, err) - _, err = dialect.NewStore(dialect.Custom, "foo") + _, err = database.NewStore(database.DialectCustom, "foo") check.HasError(t, err) }) t.Run("postgres", func(t *testing.T) { @@ -42,7 +42,7 @@ func TestDialectStore(t *testing.T) { db, cleanup, err := testdb.NewPostgres() check.NoError(t, err) t.Cleanup(cleanup) - testStore(context.Background(), t, dialect.Postgres, db, func(t *testing.T, err error) { + testStore(context.Background(), t, database.DialectPostgres, db, func(t *testing.T, err error) { var pgErr *pgconn.PgError ok := errors.As(err, &pgErr) check.Bool(t, ok, true) @@ -54,7 +54,7 @@ func TestDialectStore(t *testing.T) { dir := t.TempDir() db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db")) check.NoError(t, err) - testStore(context.Background(), t, dialect.SQLite3, db, func(t *testing.T, err error) { + testStore(context.Background(), t, database.DialectSQLite3, db, func(t *testing.T, err error) { var sqliteErr *sqlite.Error ok := errors.As(err, &sqliteErr) check.Bool(t, ok, true) @@ -66,7 +66,7 @@ func TestDialectStore(t *testing.T) { dir := t.TempDir() db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db")) check.NoError(t, err) - store, err := dialect.NewStore(dialect.SQLite3, "foo") + store, err := database.NewStore(database.DialectSQLite3, "foo") check.NoError(t, err) err = store.CreateVersionTable(context.Background(), db) check.NoError(t, err) @@ -87,11 +87,11 @@ func TestDialectStore(t *testing.T) { // // If alreadyExists is not nil, it will be used to assert the error returned by CreateVersionTable // when the version table already exists. -func testStore(ctx context.Context, t *testing.T, d dialect.Dialect, db *sql.DB, alreadyExists func(t *testing.T, err error)) { +func testStore(ctx context.Context, t *testing.T, d database.Dialect, db *sql.DB, alreadyExists func(t *testing.T, err error)) { const ( tablename = "test_goose_db_version" ) - store, err := dialect.NewStore(d, tablename) + store, err := database.NewStore(d, tablename) check.NoError(t, err) // Create the version table. err = runTx(ctx, db, func(tx *sql.Tx) error { diff --git a/internal/provider/collect_test.go b/internal/provider/collect_test.go index 8417e8473..b1983d76d 100644 --- a/internal/provider/collect_test.go +++ b/internal/provider/collect_test.go @@ -5,8 +5,8 @@ import ( "testing" "testing/fstest" + "github.com/pressly/goose/v3/database" "github.com/pressly/goose/v3/internal/check" - "github.com/pressly/goose/v3/internal/sqladapter" ) func TestCollectFileSources(t *testing.T) { @@ -294,7 +294,7 @@ func TestFindMissingMigrations(t *testing.T) { // Test case: database has migrations 1, 3, 4, 5, 7 // Missing migrations: 2, 6 // Filesystem has migrations 1, 2, 3, 4, 5, 6, 7, 8 - dbMigrations := []*sqladapter.ListMigrationsResult{ + dbMigrations := []*database.ListMigrationsResult{ {Version: 1}, {Version: 3}, {Version: 4}, @@ -322,7 +322,7 @@ func TestFindMissingMigrations(t *testing.T) { check.Number(t, len(findMissingMigrations(nil, fsMigrations)), 0) }) t.Run("fs_has_max_version", func(t *testing.T) { - dbMigrations := []*sqladapter.ListMigrationsResult{ + dbMigrations := []*database.ListMigrationsResult{ {Version: 1}, {Version: 5}, {Version: 2}, diff --git a/internal/provider/migration.go b/internal/provider/migration.go index 05faf01d9..2ace5f93d 100644 --- a/internal/provider/migration.go +++ b/internal/provider/migration.go @@ -6,7 +6,7 @@ import ( "fmt" "path/filepath" - "github.com/pressly/goose/v3/internal/sqlextended" + "github.com/pressly/goose/v3/database" ) type migration struct { @@ -170,7 +170,7 @@ func (s *sqlMigration) IsEmpty(direction bool) bool { return len(s.DownStatements) == 0 } -func (s *sqlMigration) run(ctx context.Context, db sqlextended.DBTxConn, direction bool) error { +func (s *sqlMigration) run(ctx context.Context, db database.DBTxConn, direction bool) error { var statements []string if direction { statements = s.UpStatements diff --git a/internal/provider/provider.go b/internal/provider/provider.go index 2dd3350be..c2e081c88 100644 --- a/internal/provider/provider.go +++ b/internal/provider/provider.go @@ -9,7 +9,7 @@ import ( "math" "sync" - "github.com/pressly/goose/v3/internal/sqladapter" + "github.com/pressly/goose/v3/database" ) // NewProvider returns a new goose Provider. @@ -28,13 +28,10 @@ import ( // Unless otherwise specified, all methods on Provider are safe for concurrent use. // // Experimental: This API is experimental and may change in the future. -func NewProvider(dialect Dialect, db *sql.DB, fsys fs.FS, opts ...ProviderOption) (*Provider, error) { +func NewProvider(dialect database.Dialect, db *sql.DB, fsys fs.FS, opts ...ProviderOption) (*Provider, error) { if db == nil { return nil, errors.New("db must not be nil") } - if dialect == "" { - return nil, errors.New("dialect must not be empty") - } if fsys == nil { fsys = noopFS{} } @@ -51,7 +48,7 @@ func NewProvider(dialect Dialect, db *sql.DB, fsys fs.FS, opts ...ProviderOption if cfg.tableName == "" { cfg.tableName = DefaultTablename } - store, err := sqladapter.NewStore(string(dialect), cfg.tableName) + store, err := database.NewStore(dialect, cfg.tableName) if err != nil { return nil, err } @@ -138,7 +135,7 @@ type Provider struct { db *sql.DB fsys fs.FS cfg config - store sqladapter.Store + store database.Store // migrations are ordered by version in ascending order. migrations []*migration diff --git a/internal/provider/provider_test.go b/internal/provider/provider_test.go index 6cd7a5f5e..cb5ce10c4 100644 --- a/internal/provider/provider_test.go +++ b/internal/provider/provider_test.go @@ -9,6 +9,7 @@ import ( "testing" "testing/fstest" + "github.com/pressly/goose/v3/database" "github.com/pressly/goose/v3/internal/check" "github.com/pressly/goose/v3/internal/provider" _ "modernc.org/sqlite" @@ -51,7 +52,7 @@ func TestProvider(t *testing.T) { t.Cleanup(provider.ResetGlobalGoMigrations) db := newDB(t) - _, err = provider.NewProvider(provider.DialectSQLite3, db, nil, + _, err = provider.NewProvider(database.DialectSQLite3, db, nil, provider.WithGoMigration(1, nil, nil), ) check.HasError(t, err) @@ -60,7 +61,7 @@ func TestProvider(t *testing.T) { t.Run("empty_go", func(t *testing.T) { db := newDB(t) // explicit - _, err := provider.NewProvider(provider.DialectSQLite3, db, nil, + _, err := provider.NewProvider(database.DialectSQLite3, db, nil, provider.WithGoMigration(1, &provider.GoMigration{Run: nil}, &provider.GoMigration{Run: nil}), ) check.HasError(t, err) diff --git a/internal/provider/run.go b/internal/provider/run.go index 6dc600183..17edff98f 100644 --- a/internal/provider/run.go +++ b/internal/provider/run.go @@ -10,7 +10,7 @@ import ( "strings" "time" - "github.com/pressly/goose/v3/internal/sqladapter" + "github.com/pressly/goose/v3/database" "github.com/pressly/goose/v3/internal/sqlparser" "go.uber.org/multierr" ) @@ -63,7 +63,7 @@ func (p *Provider) up(ctx context.Context, upByOne bool, version int64) (_ []*Mi } func (p *Provider) resolveUpMigrations( - dbVersions []*sqladapter.ListMigrationsResult, + dbVersions []*database.ListMigrationsResult, version int64, ) ([]*migration, error) { var apply []*migration @@ -379,7 +379,7 @@ type missingMigration struct { // findMissingMigrations returns a list of migrations that are missing from the database. A missing // migration is one that has a version less than the max version in the database. func findMissingMigrations( - dbMigrations []*sqladapter.ListMigrationsResult, + dbMigrations []*database.ListMigrationsResult, fsMigrations []*migration, ) []missingMigration { existing := make(map[int64]bool) diff --git a/internal/provider/run_test.go b/internal/provider/run_test.go index 09c71cd2e..8d3e4463f 100644 --- a/internal/provider/run_test.go +++ b/internal/provider/run_test.go @@ -16,6 +16,7 @@ import ( "testing" "testing/fstest" + "github.com/pressly/goose/v3/database" "github.com/pressly/goose/v3/internal/check" "github.com/pressly/goose/v3/internal/provider" "github.com/pressly/goose/v3/internal/testdb" @@ -320,7 +321,7 @@ INSERT INTO owners (owner_name) VALUES ('seed-user-2'); INSERT INTO owners (owner_name) VALUES ('seed-user-3'); `), } - p, err := provider.NewProvider(provider.DialectSQLite3, db, mapFS) + p, err := provider.NewProvider(database.DialectSQLite3, db, mapFS) check.NoError(t, err) _, err = p.Up(ctx) check.HasError(t, err) @@ -485,7 +486,7 @@ func TestNoVersioning(t *testing.T) { // These are owners created by migration files. wantOwnerCount = 4 ) - p, err := provider.NewProvider(provider.DialectSQLite3, db, fsys, + p, err := provider.NewProvider(database.DialectSQLite3, db, fsys, provider.WithVerbose(testing.Verbose()), provider.WithNoVersioning(false), // This is the default. ) @@ -498,7 +499,7 @@ func TestNoVersioning(t *testing.T) { check.Number(t, baseVersion, 3) t.Run("seed-up-down-to-zero", func(t *testing.T) { fsys := os.DirFS(filepath.Join("testdata", "no-versioning", "seed")) - p, err := provider.NewProvider(provider.DialectSQLite3, db, fsys, + p, err := provider.NewProvider(database.DialectSQLite3, db, fsys, provider.WithVerbose(testing.Verbose()), provider.WithNoVersioning(true), // Provider with no versioning. ) @@ -551,7 +552,7 @@ func TestAllowMissing(t *testing.T) { t.Run("missing_now_allowed", func(t *testing.T) { db := newDB(t) - p, err := provider.NewProvider(provider.DialectSQLite3, db, newFsys(), + p, err := provider.NewProvider(database.DialectSQLite3, db, newFsys(), provider.WithAllowMissing(false), ) check.NoError(t, err) @@ -606,7 +607,7 @@ func TestAllowMissing(t *testing.T) { t.Run("missing_allowed", func(t *testing.T) { db := newDB(t) - p, err := provider.NewProvider(provider.DialectSQLite3, db, newFsys(), + p, err := provider.NewProvider(database.DialectSQLite3, db, newFsys(), provider.WithAllowMissing(true), ) check.NoError(t, err) @@ -714,7 +715,7 @@ func TestGoOnly(t *testing.T) { t.Cleanup(provider.ResetGlobalGoMigrations) db := newDB(t) - p, err := provider.NewProvider(provider.DialectSQLite3, db, nil, + p, err := provider.NewProvider(database.DialectSQLite3, db, nil, provider.WithGoMigration( 2, &provider.GoMigration{Run: newTxFn("INSERT INTO users (id) VALUES (1), (2), (3)")}, @@ -763,7 +764,7 @@ func TestGoOnly(t *testing.T) { t.Cleanup(provider.ResetGlobalGoMigrations) db := newDB(t) - p, err := provider.NewProvider(provider.DialectSQLite3, db, nil, + p, err := provider.NewProvider(database.DialectSQLite3, db, nil, provider.WithGoMigration( 2, &provider.GoMigration{RunNoTx: newDBFn("INSERT INTO users (id) VALUES (1), (2), (3)")}, @@ -820,7 +821,7 @@ func TestLockModeAdvisorySession(t *testing.T) { newProvider := func() *provider.Provider { sessionLocker, err := lock.NewPostgresSessionLocker() check.NoError(t, err) - p, err := provider.NewProvider(provider.DialectPostgres, db, os.DirFS("../../testdata/migrations"), + p, err := provider.NewProvider(database.DialectPostgres, db, os.DirFS("../../testdata/migrations"), provider.WithSessionLocker(sessionLocker), // Use advisory session lock mode. provider.WithVerbose(testing.Verbose()), ) @@ -1074,7 +1075,7 @@ func newProviderWithDB(t *testing.T, opts ...provider.ProviderOption) (*provider opts, provider.WithVerbose(testing.Verbose()), ) - p, err := provider.NewProvider(provider.DialectSQLite3, db, newFsys(), opts...) + p, err := provider.NewProvider(database.DialectSQLite3, db, newFsys(), opts...) check.NoError(t, err) return p, db } diff --git a/internal/sqladapter/sqladapter.go b/internal/sqladapter/sqladapter.go deleted file mode 100644 index f6c975dc4..000000000 --- a/internal/sqladapter/sqladapter.go +++ /dev/null @@ -1,49 +0,0 @@ -// Package sqladapter provides an interface for interacting with a SQL database. -// -// All supported database dialects must implement the Store interface. -package sqladapter - -import ( - "context" - "time" - - "github.com/pressly/goose/v3/internal/sqlextended" -) - -// Store is the interface that wraps the basic methods for a database dialect. -// -// A dialect is a set of SQL statements that are specific to a database. -// -// By defining a store interface, we can support multiple databases with a single codebase. -// -// The underlying implementation does not modify the error. It is the callers responsibility to -// assert for the correct error, such as [sql.ErrNoRows]. -type Store interface { - // CreateVersionTable creates the version table within a transaction. This table is used to - // record applied migrations. - CreateVersionTable(ctx context.Context, db sqlextended.DBTxConn) error - - // InsertOrDelete inserts or deletes a version id from the version table. - InsertOrDelete(ctx context.Context, db sqlextended.DBTxConn, direction bool, version int64) error - - // GetMigration retrieves a single migration by version id. - // - // Returns the raw sql error if the query fails. It is the callers responsibility to assert for - // the correct error, such as [sql.ErrNoRows]. - GetMigration(ctx context.Context, db sqlextended.DBTxConn, version int64) (*GetMigrationResult, error) - - // ListMigrations retrieves all migrations sorted in descending order by id. - // - // If there are no migrations, an empty slice is returned with no error. - ListMigrations(ctx context.Context, db sqlextended.DBTxConn) ([]*ListMigrationsResult, error) -} - -type GetMigrationResult struct { - IsApplied bool - Timestamp time.Time -} - -type ListMigrationsResult struct { - Version int64 - IsApplied bool -} diff --git a/internal/sqladapter/store.go b/internal/sqladapter/store.go deleted file mode 100644 index 0ee90ca49..000000000 --- a/internal/sqladapter/store.go +++ /dev/null @@ -1,111 +0,0 @@ -package sqladapter - -import ( - "context" - "errors" - "fmt" - - "github.com/pressly/goose/v3/internal/dialect/dialectquery" - "github.com/pressly/goose/v3/internal/sqlextended" -) - -var _ Store = (*store)(nil) - -type store struct { - tablename string - querier dialectquery.Querier -} - -// NewStore returns a new [Store] backed by the given dialect. -// -// The dialect must match one of the supported dialects defined in dialect.go. -func NewStore(dialect string, table string) (Store, error) { - if table == "" { - return nil, errors.New("table must not be empty") - } - if dialect == "" { - return nil, errors.New("dialect must not be empty") - } - var querier dialectquery.Querier - switch dialect { - case "clickhouse": - querier = &dialectquery.Clickhouse{} - case "mssql": - querier = &dialectquery.Sqlserver{} - case "mysql": - querier = &dialectquery.Mysql{} - case "postgres": - querier = &dialectquery.Postgres{} - case "redshift": - querier = &dialectquery.Redshift{} - case "sqlite3": - querier = &dialectquery.Sqlite3{} - case "tidb": - querier = &dialectquery.Tidb{} - case "vertica": - querier = &dialectquery.Vertica{} - default: - return nil, fmt.Errorf("unknown dialect: %q", dialect) - } - return &store{ - tablename: table, - querier: querier, - }, nil -} - -func (s *store) CreateVersionTable(ctx context.Context, db sqlextended.DBTxConn) error { - q := s.querier.CreateTable(s.tablename) - if _, err := db.ExecContext(ctx, q); err != nil { - return fmt.Errorf("failed to create version table %q: %w", s.tablename, err) - } - return nil -} - -func (s *store) InsertOrDelete(ctx context.Context, db sqlextended.DBTxConn, direction bool, version int64) error { - if direction { - q := s.querier.InsertVersion(s.tablename) - if _, err := db.ExecContext(ctx, q, version, true); err != nil { - return fmt.Errorf("failed to insert version %d: %w", version, err) - } - return nil - } - q := s.querier.DeleteVersion(s.tablename) - if _, err := db.ExecContext(ctx, q, version); err != nil { - return fmt.Errorf("failed to delete version %d: %w", version, err) - } - return nil -} - -func (s *store) GetMigration(ctx context.Context, db sqlextended.DBTxConn, version int64) (*GetMigrationResult, error) { - q := s.querier.GetMigrationByVersion(s.tablename) - var result GetMigrationResult - if err := db.QueryRowContext(ctx, q, version).Scan( - &result.Timestamp, - &result.IsApplied, - ); err != nil { - return nil, fmt.Errorf("failed to get migration %d: %w", version, err) - } - return &result, nil -} - -func (s *store) ListMigrations(ctx context.Context, db sqlextended.DBTxConn) ([]*ListMigrationsResult, error) { - q := s.querier.ListMigrations(s.tablename) - rows, err := db.QueryContext(ctx, q) - if err != nil { - return nil, fmt.Errorf("failed to list migrations: %w", err) - } - defer rows.Close() - - var migrations []*ListMigrationsResult - for rows.Next() { - var result ListMigrationsResult - if err := rows.Scan(&result.Version, &result.IsApplied); err != nil { - return nil, fmt.Errorf("failed to scan list migrations result: %w", err) - } - migrations = append(migrations, &result) - } - if err := rows.Err(); err != nil { - return nil, err - } - return migrations, nil -} diff --git a/internal/sqladapter/store_test.go b/internal/sqladapter/store_test.go deleted file mode 100644 index 69d3d3115..000000000 --- a/internal/sqladapter/store_test.go +++ /dev/null @@ -1,237 +0,0 @@ -package sqladapter_test - -import ( - "context" - "database/sql" - "errors" - "path/filepath" - "testing" - - "github.com/jackc/pgx/v5/pgconn" - "github.com/pressly/goose/v3" - "github.com/pressly/goose/v3/internal/check" - "github.com/pressly/goose/v3/internal/sqladapter" - "github.com/pressly/goose/v3/internal/testdb" - "go.uber.org/multierr" - "modernc.org/sqlite" -) - -// The goal of this test is to verify the sqladapter package works as expected. This test is not -// meant to be exhaustive or test every possible database dialect. It is meant to verify the Store -// interface works against a real database. - -func TestStore(t *testing.T) { - t.Parallel() - t.Run("invalid", func(t *testing.T) { - // Test empty table name. - _, err := sqladapter.NewStore("sqlite3", "") - check.HasError(t, err) - // Test unknown dialect. - _, err = sqladapter.NewStore("unknown-dialect", "foo") - check.HasError(t, err) - // Test empty dialect. - _, err = sqladapter.NewStore("", "foo") - check.HasError(t, err) - }) - t.Run("postgres", func(t *testing.T) { - if testing.Short() { - t.Skip("skip long-running test") - } - // Test postgres specific behavior. - db, cleanup, err := testdb.NewPostgres() - check.NoError(t, err) - t.Cleanup(cleanup) - testStore(context.Background(), t, goose.DialectPostgres, db, func(t *testing.T, err error) { - var pgErr *pgconn.PgError - ok := errors.As(err, &pgErr) - check.Bool(t, ok, true) - check.Equal(t, pgErr.Code, "42P07") // duplicate_table - }) - }) - // Test generic behavior. - t.Run("sqlite3", func(t *testing.T) { - dir := t.TempDir() - db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db")) - check.NoError(t, err) - testStore(context.Background(), t, goose.DialectSQLite3, db, func(t *testing.T, err error) { - var sqliteErr *sqlite.Error - ok := errors.As(err, &sqliteErr) - check.Bool(t, ok, true) - check.Equal(t, sqliteErr.Code(), 1) // Generic error (SQLITE_ERROR) - check.Contains(t, sqliteErr.Error(), "table test_goose_db_version already exists") - }) - }) - t.Run("ListMigrations", func(t *testing.T) { - dir := t.TempDir() - db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db")) - check.NoError(t, err) - store, err := sqladapter.NewStore("sqlite3", "foo") - check.NoError(t, err) - err = store.CreateVersionTable(context.Background(), db) - check.NoError(t, err) - check.NoError(t, store.InsertOrDelete(context.Background(), db, true, 1)) - check.NoError(t, store.InsertOrDelete(context.Background(), db, true, 3)) - check.NoError(t, store.InsertOrDelete(context.Background(), db, true, 2)) - res, err := store.ListMigrations(context.Background(), db) - check.NoError(t, err) - check.Number(t, len(res), 3) - // Check versions are in descending order: [2, 3, 1] - check.Number(t, res[0].Version, 2) - check.Number(t, res[1].Version, 3) - check.Number(t, res[2].Version, 1) - }) -} - -// testStore tests various store operations. -// -// If alreadyExists is not nil, it will be used to assert the error returned by CreateVersionTable -// when the version table already exists. -func testStore(ctx context.Context, t *testing.T, dialect goose.Dialect, db *sql.DB, alreadyExists func(t *testing.T, err error)) { - const ( - tablename = "test_goose_db_version" - ) - store, err := sqladapter.NewStore(string(dialect), tablename) - check.NoError(t, err) - // Create the version table. - err = runTx(ctx, db, func(tx *sql.Tx) error { - return store.CreateVersionTable(ctx, tx) - }) - check.NoError(t, err) - // Create the version table again. This should fail. - err = runTx(ctx, db, func(tx *sql.Tx) error { - return store.CreateVersionTable(ctx, tx) - }) - check.HasError(t, err) - if alreadyExists != nil { - alreadyExists(t, err) - } - - // List migrations. There should be none. - err = runConn(ctx, db, func(conn *sql.Conn) error { - res, err := store.ListMigrations(ctx, conn) - check.NoError(t, err) - check.Number(t, len(res), 0) - return nil - }) - check.NoError(t, err) - - // Insert 5 migrations in addition to the zero migration. - for i := 0; i < 6; i++ { - err = runConn(ctx, db, func(conn *sql.Conn) error { - return store.InsertOrDelete(ctx, conn, true, int64(i)) - }) - check.NoError(t, err) - } - - // List migrations. There should be 6. - err = runConn(ctx, db, func(conn *sql.Conn) error { - res, err := store.ListMigrations(ctx, conn) - check.NoError(t, err) - check.Number(t, len(res), 6) - // Check versions are in descending order. - for i := 0; i < 6; i++ { - check.Number(t, res[i].Version, 5-i) - } - return nil - }) - check.NoError(t, err) - - // Delete 3 migrations backwards - for i := 5; i >= 3; i-- { - err = runConn(ctx, db, func(conn *sql.Conn) error { - return store.InsertOrDelete(ctx, conn, false, int64(i)) - }) - check.NoError(t, err) - } - - // List migrations. There should be 3. - err = runConn(ctx, db, func(conn *sql.Conn) error { - res, err := store.ListMigrations(ctx, conn) - check.NoError(t, err) - check.Number(t, len(res), 3) - // Check that the remaining versions are in descending order. - for i := 0; i < 3; i++ { - check.Number(t, res[i].Version, 2-i) - } - return nil - }) - check.NoError(t, err) - - // Get remaining migrations one by one. - for i := 0; i < 3; i++ { - err = runConn(ctx, db, func(conn *sql.Conn) error { - res, err := store.GetMigration(ctx, conn, int64(i)) - check.NoError(t, err) - check.Equal(t, res.IsApplied, true) - check.Equal(t, res.Timestamp.IsZero(), false) - return nil - }) - check.NoError(t, err) - } - - // Delete remaining migrations one by one and use all 3 connection types: - - // 1. *sql.Tx - err = runTx(ctx, db, func(tx *sql.Tx) error { - return store.InsertOrDelete(ctx, tx, false, 2) - }) - check.NoError(t, err) - // 2. *sql.Conn - err = runConn(ctx, db, func(conn *sql.Conn) error { - return store.InsertOrDelete(ctx, conn, false, 1) - }) - check.NoError(t, err) - // 3. *sql.DB - err = store.InsertOrDelete(ctx, db, false, 0) - check.NoError(t, err) - - // List migrations. There should be none. - err = runConn(ctx, db, func(conn *sql.Conn) error { - res, err := store.ListMigrations(ctx, conn) - check.NoError(t, err) - check.Number(t, len(res), 0) - return nil - }) - check.NoError(t, err) - - // Try to get a migration that does not exist. - err = runConn(ctx, db, func(conn *sql.Conn) error { - _, err := store.GetMigration(ctx, conn, 0) - check.HasError(t, err) - check.Bool(t, errors.Is(err, sql.ErrNoRows), true) - return nil - }) - check.NoError(t, err) -} - -func runTx(ctx context.Context, db *sql.DB, fn func(*sql.Tx) error) (retErr error) { - tx, err := db.BeginTx(ctx, nil) - if err != nil { - return err - } - defer func() { - if retErr != nil { - retErr = multierr.Append(retErr, tx.Rollback()) - } - }() - if err := fn(tx); err != nil { - return err - } - return tx.Commit() -} - -func runConn(ctx context.Context, db *sql.DB, fn func(*sql.Conn) error) (retErr error) { - conn, err := db.Conn(ctx) - if err != nil { - return err - } - defer func() { - if retErr != nil { - retErr = multierr.Append(retErr, conn.Close()) - } - }() - if err := fn(conn); err != nil { - return err - } - return conn.Close() -} diff --git a/internal/sqlextended/sqlextended.go b/internal/sqlextended/sqlextended.go deleted file mode 100644 index 83ca7ae8b..000000000 --- a/internal/sqlextended/sqlextended.go +++ /dev/null @@ -1,23 +0,0 @@ -package sqlextended - -import ( - "context" - "database/sql" -) - -// DBTxConn is a thin interface for common method that is satisfied by *sql.DB, *sql.Tx and -// *sql.Conn. -// -// There is a long outstanding issue to formalize a std lib interface, but alas... See: -// https://github.com/golang/go/issues/14468 -type DBTxConn interface { - ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) - QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) - QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row -} - -var ( - _ DBTxConn = (*sql.DB)(nil) - _ DBTxConn = (*sql.Tx)(nil) - _ DBTxConn = (*sql.Conn)(nil) -) From b780265b3c1e321d490c7020ccf4c33b0d48a5ce Mon Sep 17 00:00:00 2001 From: Mike Fridman Date: Thu, 26 Oct 2023 19:35:39 -0400 Subject: [PATCH 3/8] wip --- database/{dialect_test.go => store_test.go} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename database/{dialect_test.go => store_test.go} (100%) diff --git a/database/dialect_test.go b/database/store_test.go similarity index 100% rename from database/dialect_test.go rename to database/store_test.go From 569e92ae37228f386d7c86f3847eb85b5497c510 Mon Sep 17 00:00:00 2001 From: Mike Fridman Date: Thu, 26 Oct 2023 19:36:55 -0400 Subject: [PATCH 4/8] wip --- database/dialect.go | 1 + 1 file changed, 1 insertion(+) diff --git a/database/dialect.go b/database/dialect.go index d2bad7ccc..d4b352973 100644 --- a/database/dialect.go +++ b/database/dialect.go @@ -47,6 +47,7 @@ func NewStore(dialect Dialect, tablename string) (Store, error) { DialectSQLite3: &dialectquery.Sqlite3{}, DialectTiDB: &dialectquery.Tidb{}, DialectVertica: &dialectquery.Vertica{}, + DialectYdB: &dialectquery.Ydb{}, } querier, ok := lookup[dialect] if !ok { From 3195fc1a6b2a9ea5054f3bf2052d968e6f5f54e7 Mon Sep 17 00:00:00 2001 From: Mike Fridman Date: Thu, 26 Oct 2023 19:38:25 -0400 Subject: [PATCH 5/8] wip --- database/store_test.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/database/store_test.go b/database/store_test.go index 875c174a6..c49a6bf4b 100644 --- a/database/store_test.go +++ b/database/store_test.go @@ -87,7 +87,13 @@ func TestDialectStore(t *testing.T) { // // If alreadyExists is not nil, it will be used to assert the error returned by CreateVersionTable // when the version table already exists. -func testStore(ctx context.Context, t *testing.T, d database.Dialect, db *sql.DB, alreadyExists func(t *testing.T, err error)) { +func testStore( + ctx context.Context, + t *testing.T, + d database.Dialect, + db *sql.DB, + alreadyExists func(t *testing.T, err error), +) { const ( tablename = "test_goose_db_version" ) From beb07388433d5aed980cdba08add9f1c87b85d20 Mon Sep 17 00:00:00 2001 From: Mike Fridman Date: Thu, 26 Oct 2023 19:54:55 -0400 Subject: [PATCH 6/8] wip --- database/doc.go | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 database/doc.go diff --git a/database/doc.go b/database/doc.go new file mode 100644 index 000000000..f8f9ef1e3 --- /dev/null +++ b/database/doc.go @@ -0,0 +1,8 @@ +// Package database provides a Store interface for goose to use when interacting with the database. +// It also provides a an implementation for each supported database dialect. +// +// The Store interface is meant to be generic and not tied to any specific database. +// +// It's possible to implement a custom Store for a database that goose does not support. To do so, +// implement the [Store] interface and pass it to [goose.NewProvider]. +package database From b3afcf46ce3880d53931bdfdcf24a9c14c9db3f8 Mon Sep 17 00:00:00 2001 From: Mike Fridman Date: Thu, 26 Oct 2023 20:15:47 -0400 Subject: [PATCH 7/8] use :memory: --- database/store_test.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/database/store_test.go b/database/store_test.go index c49a6bf4b..1dfc3f97d 100644 --- a/database/store_test.go +++ b/database/store_test.go @@ -51,8 +51,7 @@ func TestDialectStore(t *testing.T) { }) // Test generic behavior. t.Run("sqlite3", func(t *testing.T) { - dir := t.TempDir() - db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db")) + db, err := sql.Open("sqlite", ":memory:") check.NoError(t, err) testStore(context.Background(), t, database.DialectSQLite3, db, func(t *testing.T, err error) { var sqliteErr *sqlite.Error From c24017af50aeb61a81e434d7adbaff13bbee6a74 Mon Sep 17 00:00:00 2001 From: Mike Fridman Date: Thu, 26 Oct 2023 20:38:46 -0400 Subject: [PATCH 8/8] wip --- database/doc.go | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/database/doc.go b/database/doc.go index f8f9ef1e3..4748c72d8 100644 --- a/database/doc.go +++ b/database/doc.go @@ -1,8 +1,14 @@ -// Package database provides a Store interface for goose to use when interacting with the database. -// It also provides a an implementation for each supported database dialect. +// Package database defines a generic [Store] interface for goose to use when interacting with the +// database. It is meant to be generic and not tied to any specific database technology. // -// The Store interface is meant to be generic and not tied to any specific database. +// At a high level, a [Store] is responsible for: +// - Creating a version table +// - Inserting and deleting a version +// - Getting a specific version +// - Listing all applied versions // -// It's possible to implement a custom Store for a database that goose does not support. To do so, -// implement the [Store] interface and pass it to [goose.NewProvider]. +// Use the [NewStore] function to create a [Store] for one of the supported dialects. +// +// For more advanced use cases, it's possible to implement a custom [Store] for a database that +// goose does not support. package database