diff --git a/internal/sqladapter/store.go b/database/dialect.go similarity index 51% rename from internal/sqladapter/store.go rename to database/dialect.go index 0ee90ca49..d4b352973 100644 --- a/internal/sqladapter/store.go +++ b/database/dialect.go @@ -1,4 +1,4 @@ -package sqladapter +package database import ( "context" @@ -6,54 +6,67 @@ import ( "fmt" "github.com/pressly/goose/v3/internal/dialect/dialectquery" - "github.com/pressly/goose/v3/internal/sqlextended" ) -var _ Store = (*store)(nil) +// Dialect is the type of database dialect. +type Dialect string -type store struct { - tablename string - querier dialectquery.Querier -} +const ( + DialectClickHouse Dialect = "clickhouse" + DialectMSSQL Dialect = "mssql" + DialectMySQL Dialect = "mysql" + DialectPostgres Dialect = "postgres" + DialectRedshift Dialect = "redshift" + DialectSQLite3 Dialect = "sqlite3" + DialectTiDB Dialect = "tidb" + DialectVertica Dialect = "vertica" + DialectYdB Dialect = "ydb" + + // DialectCustom is a special dialect that allows users to provide their own [Store] + // implementation when constructing a [goose.Provider]. + DialectCustom Dialect = "custom" +) // NewStore returns a new [Store] backed by the given dialect. -// -// The dialect must match one of the supported dialects defined in dialect.go. -func NewStore(dialect string, table string) (Store, error) { - if table == "" { - return nil, errors.New("table must not be empty") +func NewStore(dialect Dialect, tablename string) (Store, error) { + if tablename == "" { + return nil, errors.New("tablename must not be empty") } if dialect == "" { return nil, errors.New("dialect must not be empty") } - var querier dialectquery.Querier - switch dialect { - case "clickhouse": - querier = &dialectquery.Clickhouse{} - case "mssql": - querier = &dialectquery.Sqlserver{} - case "mysql": - querier = &dialectquery.Mysql{} - case "postgres": - querier = &dialectquery.Postgres{} - case "redshift": - querier = &dialectquery.Redshift{} - case "sqlite3": - querier = &dialectquery.Sqlite3{} - case "tidb": - querier = &dialectquery.Tidb{} - case "vertica": - querier = &dialectquery.Vertica{} - default: + if dialect == DialectCustom { + return nil, errors.New("dialect must not be custom") + } + lookup := map[Dialect]dialectquery.Querier{ + DialectClickHouse: &dialectquery.Clickhouse{}, + DialectMSSQL: &dialectquery.Sqlserver{}, + DialectMySQL: &dialectquery.Mysql{}, + DialectPostgres: &dialectquery.Postgres{}, + DialectRedshift: &dialectquery.Redshift{}, + DialectSQLite3: &dialectquery.Sqlite3{}, + DialectTiDB: &dialectquery.Tidb{}, + DialectVertica: &dialectquery.Vertica{}, + DialectYdB: &dialectquery.Ydb{}, + } + querier, ok := lookup[dialect] + if !ok { return nil, fmt.Errorf("unknown dialect: %q", dialect) } return &store{ - tablename: table, + tablename: tablename, querier: querier, }, nil } -func (s *store) CreateVersionTable(ctx context.Context, db sqlextended.DBTxConn) error { +type store struct { + tablename string + querier dialectquery.Querier +} + +var _ Store = (*store)(nil) + +func (s *store) CreateVersionTable(ctx context.Context, db DBTxConn) error { q := s.querier.CreateTable(s.tablename) if _, err := db.ExecContext(ctx, q); err != nil { return fmt.Errorf("failed to create version table %q: %w", s.tablename, err) @@ -61,7 +74,7 @@ func (s *store) CreateVersionTable(ctx context.Context, db sqlextended.DBTxConn) return nil } -func (s *store) InsertOrDelete(ctx context.Context, db sqlextended.DBTxConn, direction bool, version int64) error { +func (s *store) InsertOrDelete(ctx context.Context, db DBTxConn, direction bool, version int64) error { if direction { q := s.querier.InsertVersion(s.tablename) if _, err := db.ExecContext(ctx, q, version, true); err != nil { @@ -76,7 +89,11 @@ func (s *store) InsertOrDelete(ctx context.Context, db sqlextended.DBTxConn, dir return nil } -func (s *store) GetMigration(ctx context.Context, db sqlextended.DBTxConn, version int64) (*GetMigrationResult, error) { +func (s *store) GetMigration( + ctx context.Context, + db DBTxConn, + version int64, +) (*GetMigrationResult, error) { q := s.querier.GetMigrationByVersion(s.tablename) var result GetMigrationResult if err := db.QueryRowContext(ctx, q, version).Scan( @@ -88,7 +105,10 @@ func (s *store) GetMigration(ctx context.Context, db sqlextended.DBTxConn, versi return &result, nil } -func (s *store) ListMigrations(ctx context.Context, db sqlextended.DBTxConn) ([]*ListMigrationsResult, error) { +func (s *store) ListMigrations( + ctx context.Context, + db DBTxConn, +) ([]*ListMigrationsResult, error) { q := s.querier.ListMigrations(s.tablename) rows, err := db.QueryContext(ctx, q) if err != nil { diff --git a/database/doc.go b/database/doc.go new file mode 100644 index 000000000..4748c72d8 --- /dev/null +++ b/database/doc.go @@ -0,0 +1,14 @@ +// Package database defines a generic [Store] interface for goose to use when interacting with the +// database. It is meant to be generic and not tied to any specific database technology. +// +// At a high level, a [Store] is responsible for: +// - Creating a version table +// - Inserting and deleting a version +// - Getting a specific version +// - Listing all applied versions +// +// Use the [NewStore] function to create a [Store] for one of the supported dialects. +// +// For more advanced use cases, it's possible to implement a custom [Store] for a database that +// goose does not support. +package database diff --git a/internal/sqlextended/sqlextended.go b/database/sql_extended.go similarity index 79% rename from internal/sqlextended/sqlextended.go rename to database/sql_extended.go index 83ca7ae8b..8eaa9399a 100644 --- a/internal/sqlextended/sqlextended.go +++ b/database/sql_extended.go @@ -1,14 +1,14 @@ -package sqlextended +package database import ( "context" "database/sql" ) -// DBTxConn is a thin interface for common method that is satisfied by *sql.DB, *sql.Tx and +// DBTxConn is a thin interface for common methods that is satisfied by *sql.DB, *sql.Tx and // *sql.Conn. // -// There is a long outstanding issue to formalize a std lib interface, but alas... See: +// There is a long outstanding issue to formalize a std lib interface, but alas. See: // https://github.com/golang/go/issues/14468 type DBTxConn interface { ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) diff --git a/database/store.go b/database/store.go new file mode 100644 index 000000000..4e2ff0c2e --- /dev/null +++ b/database/store.go @@ -0,0 +1,39 @@ +package database + +import ( + "context" + "time" +) + +// Store is an interface that defines methods for managing database migrations and versioning. By +// defining a Store interface, we can support multiple databases with consistent functionality. +// +// Each database dialect requires a specific implementation of this interface. A dialect represents +// a set of SQL statements specific to a particular database system. +type Store interface { + // CreateVersionTable creates the version table. This table is used to record applied + // migrations. + CreateVersionTable(ctx context.Context, db DBTxConn) error + + // InsertOrDelete inserts or deletes a version id from the version table. If direction is true, + // insert the version id. If direction is false, delete the version id. + InsertOrDelete(ctx context.Context, db DBTxConn, direction bool, version int64) error + + // GetMigration retrieves a single migration by version id. This method may return the raw sql + // error if the query fails so the caller can assert for errors such as [sql.ErrNoRows]. + GetMigration(ctx context.Context, db DBTxConn, version int64) (*GetMigrationResult, error) + + // ListMigrations retrieves all migrations sorted in descending order by id or timestamp. If + // there are no migrations, return empty slice with no error. + ListMigrations(ctx context.Context, db DBTxConn) ([]*ListMigrationsResult, error) +} + +type GetMigrationResult struct { + Timestamp time.Time + IsApplied bool +} + +type ListMigrationsResult struct { + Version int64 + IsApplied bool +} diff --git a/internal/sqladapter/store_test.go b/database/store_test.go similarity index 86% rename from internal/sqladapter/store_test.go rename to database/store_test.go index 69d3d3115..1dfc3f97d 100644 --- a/internal/sqladapter/store_test.go +++ b/database/store_test.go @@ -1,4 +1,4 @@ -package sqladapter_test +package database_test import ( "context" @@ -8,29 +8,30 @@ import ( "testing" "github.com/jackc/pgx/v5/pgconn" - "github.com/pressly/goose/v3" + "github.com/pressly/goose/v3/database" "github.com/pressly/goose/v3/internal/check" - "github.com/pressly/goose/v3/internal/sqladapter" "github.com/pressly/goose/v3/internal/testdb" "go.uber.org/multierr" "modernc.org/sqlite" ) -// The goal of this test is to verify the sqladapter package works as expected. This test is not +// The goal of this test is to verify the database store package works as expected. This test is not // meant to be exhaustive or test every possible database dialect. It is meant to verify the Store // interface works against a real database. -func TestStore(t *testing.T) { +func TestDialectStore(t *testing.T) { t.Parallel() t.Run("invalid", func(t *testing.T) { // Test empty table name. - _, err := sqladapter.NewStore("sqlite3", "") + _, err := database.NewStore(database.DialectSQLite3, "") check.HasError(t, err) // Test unknown dialect. - _, err = sqladapter.NewStore("unknown-dialect", "foo") + _, err = database.NewStore("unknown-dialect", "foo") check.HasError(t, err) // Test empty dialect. - _, err = sqladapter.NewStore("", "foo") + _, err = database.NewStore("", "foo") + check.HasError(t, err) + _, err = database.NewStore(database.DialectCustom, "foo") check.HasError(t, err) }) t.Run("postgres", func(t *testing.T) { @@ -41,7 +42,7 @@ func TestStore(t *testing.T) { db, cleanup, err := testdb.NewPostgres() check.NoError(t, err) t.Cleanup(cleanup) - testStore(context.Background(), t, goose.DialectPostgres, db, func(t *testing.T, err error) { + testStore(context.Background(), t, database.DialectPostgres, db, func(t *testing.T, err error) { var pgErr *pgconn.PgError ok := errors.As(err, &pgErr) check.Bool(t, ok, true) @@ -50,10 +51,9 @@ func TestStore(t *testing.T) { }) // Test generic behavior. t.Run("sqlite3", func(t *testing.T) { - dir := t.TempDir() - db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db")) + db, err := sql.Open("sqlite", ":memory:") check.NoError(t, err) - testStore(context.Background(), t, goose.DialectSQLite3, db, func(t *testing.T, err error) { + testStore(context.Background(), t, database.DialectSQLite3, db, func(t *testing.T, err error) { var sqliteErr *sqlite.Error ok := errors.As(err, &sqliteErr) check.Bool(t, ok, true) @@ -65,7 +65,7 @@ func TestStore(t *testing.T) { dir := t.TempDir() db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db")) check.NoError(t, err) - store, err := sqladapter.NewStore("sqlite3", "foo") + store, err := database.NewStore(database.DialectSQLite3, "foo") check.NoError(t, err) err = store.CreateVersionTable(context.Background(), db) check.NoError(t, err) @@ -86,11 +86,17 @@ func TestStore(t *testing.T) { // // If alreadyExists is not nil, it will be used to assert the error returned by CreateVersionTable // when the version table already exists. -func testStore(ctx context.Context, t *testing.T, dialect goose.Dialect, db *sql.DB, alreadyExists func(t *testing.T, err error)) { +func testStore( + ctx context.Context, + t *testing.T, + d database.Dialect, + db *sql.DB, + alreadyExists func(t *testing.T, err error), +) { const ( tablename = "test_goose_db_version" ) - store, err := sqladapter.NewStore(string(dialect), tablename) + store, err := database.NewStore(d, tablename) check.NoError(t, err) // Create the version table. err = runTx(ctx, db, func(tx *sql.Tx) error { diff --git a/internal/provider/collect_test.go b/internal/provider/collect_test.go index 8417e8473..b1983d76d 100644 --- a/internal/provider/collect_test.go +++ b/internal/provider/collect_test.go @@ -5,8 +5,8 @@ import ( "testing" "testing/fstest" + "github.com/pressly/goose/v3/database" "github.com/pressly/goose/v3/internal/check" - "github.com/pressly/goose/v3/internal/sqladapter" ) func TestCollectFileSources(t *testing.T) { @@ -294,7 +294,7 @@ func TestFindMissingMigrations(t *testing.T) { // Test case: database has migrations 1, 3, 4, 5, 7 // Missing migrations: 2, 6 // Filesystem has migrations 1, 2, 3, 4, 5, 6, 7, 8 - dbMigrations := []*sqladapter.ListMigrationsResult{ + dbMigrations := []*database.ListMigrationsResult{ {Version: 1}, {Version: 3}, {Version: 4}, @@ -322,7 +322,7 @@ func TestFindMissingMigrations(t *testing.T) { check.Number(t, len(findMissingMigrations(nil, fsMigrations)), 0) }) t.Run("fs_has_max_version", func(t *testing.T) { - dbMigrations := []*sqladapter.ListMigrationsResult{ + dbMigrations := []*database.ListMigrationsResult{ {Version: 1}, {Version: 5}, {Version: 2}, diff --git a/internal/provider/migration.go b/internal/provider/migration.go index 05faf01d9..2ace5f93d 100644 --- a/internal/provider/migration.go +++ b/internal/provider/migration.go @@ -6,7 +6,7 @@ import ( "fmt" "path/filepath" - "github.com/pressly/goose/v3/internal/sqlextended" + "github.com/pressly/goose/v3/database" ) type migration struct { @@ -170,7 +170,7 @@ func (s *sqlMigration) IsEmpty(direction bool) bool { return len(s.DownStatements) == 0 } -func (s *sqlMigration) run(ctx context.Context, db sqlextended.DBTxConn, direction bool) error { +func (s *sqlMigration) run(ctx context.Context, db database.DBTxConn, direction bool) error { var statements []string if direction { statements = s.UpStatements diff --git a/internal/provider/provider.go b/internal/provider/provider.go index 2dd3350be..c2e081c88 100644 --- a/internal/provider/provider.go +++ b/internal/provider/provider.go @@ -9,7 +9,7 @@ import ( "math" "sync" - "github.com/pressly/goose/v3/internal/sqladapter" + "github.com/pressly/goose/v3/database" ) // NewProvider returns a new goose Provider. @@ -28,13 +28,10 @@ import ( // Unless otherwise specified, all methods on Provider are safe for concurrent use. // // Experimental: This API is experimental and may change in the future. -func NewProvider(dialect Dialect, db *sql.DB, fsys fs.FS, opts ...ProviderOption) (*Provider, error) { +func NewProvider(dialect database.Dialect, db *sql.DB, fsys fs.FS, opts ...ProviderOption) (*Provider, error) { if db == nil { return nil, errors.New("db must not be nil") } - if dialect == "" { - return nil, errors.New("dialect must not be empty") - } if fsys == nil { fsys = noopFS{} } @@ -51,7 +48,7 @@ func NewProvider(dialect Dialect, db *sql.DB, fsys fs.FS, opts ...ProviderOption if cfg.tableName == "" { cfg.tableName = DefaultTablename } - store, err := sqladapter.NewStore(string(dialect), cfg.tableName) + store, err := database.NewStore(dialect, cfg.tableName) if err != nil { return nil, err } @@ -138,7 +135,7 @@ type Provider struct { db *sql.DB fsys fs.FS cfg config - store sqladapter.Store + store database.Store // migrations are ordered by version in ascending order. migrations []*migration diff --git a/internal/provider/provider_test.go b/internal/provider/provider_test.go index 6cd7a5f5e..cb5ce10c4 100644 --- a/internal/provider/provider_test.go +++ b/internal/provider/provider_test.go @@ -9,6 +9,7 @@ import ( "testing" "testing/fstest" + "github.com/pressly/goose/v3/database" "github.com/pressly/goose/v3/internal/check" "github.com/pressly/goose/v3/internal/provider" _ "modernc.org/sqlite" @@ -51,7 +52,7 @@ func TestProvider(t *testing.T) { t.Cleanup(provider.ResetGlobalGoMigrations) db := newDB(t) - _, err = provider.NewProvider(provider.DialectSQLite3, db, nil, + _, err = provider.NewProvider(database.DialectSQLite3, db, nil, provider.WithGoMigration(1, nil, nil), ) check.HasError(t, err) @@ -60,7 +61,7 @@ func TestProvider(t *testing.T) { t.Run("empty_go", func(t *testing.T) { db := newDB(t) // explicit - _, err := provider.NewProvider(provider.DialectSQLite3, db, nil, + _, err := provider.NewProvider(database.DialectSQLite3, db, nil, provider.WithGoMigration(1, &provider.GoMigration{Run: nil}, &provider.GoMigration{Run: nil}), ) check.HasError(t, err) diff --git a/internal/provider/run.go b/internal/provider/run.go index 6dc600183..17edff98f 100644 --- a/internal/provider/run.go +++ b/internal/provider/run.go @@ -10,7 +10,7 @@ import ( "strings" "time" - "github.com/pressly/goose/v3/internal/sqladapter" + "github.com/pressly/goose/v3/database" "github.com/pressly/goose/v3/internal/sqlparser" "go.uber.org/multierr" ) @@ -63,7 +63,7 @@ func (p *Provider) up(ctx context.Context, upByOne bool, version int64) (_ []*Mi } func (p *Provider) resolveUpMigrations( - dbVersions []*sqladapter.ListMigrationsResult, + dbVersions []*database.ListMigrationsResult, version int64, ) ([]*migration, error) { var apply []*migration @@ -379,7 +379,7 @@ type missingMigration struct { // findMissingMigrations returns a list of migrations that are missing from the database. A missing // migration is one that has a version less than the max version in the database. func findMissingMigrations( - dbMigrations []*sqladapter.ListMigrationsResult, + dbMigrations []*database.ListMigrationsResult, fsMigrations []*migration, ) []missingMigration { existing := make(map[int64]bool) diff --git a/internal/provider/run_test.go b/internal/provider/run_test.go index 09c71cd2e..8d3e4463f 100644 --- a/internal/provider/run_test.go +++ b/internal/provider/run_test.go @@ -16,6 +16,7 @@ import ( "testing" "testing/fstest" + "github.com/pressly/goose/v3/database" "github.com/pressly/goose/v3/internal/check" "github.com/pressly/goose/v3/internal/provider" "github.com/pressly/goose/v3/internal/testdb" @@ -320,7 +321,7 @@ INSERT INTO owners (owner_name) VALUES ('seed-user-2'); INSERT INTO owners (owner_name) VALUES ('seed-user-3'); `), } - p, err := provider.NewProvider(provider.DialectSQLite3, db, mapFS) + p, err := provider.NewProvider(database.DialectSQLite3, db, mapFS) check.NoError(t, err) _, err = p.Up(ctx) check.HasError(t, err) @@ -485,7 +486,7 @@ func TestNoVersioning(t *testing.T) { // These are owners created by migration files. wantOwnerCount = 4 ) - p, err := provider.NewProvider(provider.DialectSQLite3, db, fsys, + p, err := provider.NewProvider(database.DialectSQLite3, db, fsys, provider.WithVerbose(testing.Verbose()), provider.WithNoVersioning(false), // This is the default. ) @@ -498,7 +499,7 @@ func TestNoVersioning(t *testing.T) { check.Number(t, baseVersion, 3) t.Run("seed-up-down-to-zero", func(t *testing.T) { fsys := os.DirFS(filepath.Join("testdata", "no-versioning", "seed")) - p, err := provider.NewProvider(provider.DialectSQLite3, db, fsys, + p, err := provider.NewProvider(database.DialectSQLite3, db, fsys, provider.WithVerbose(testing.Verbose()), provider.WithNoVersioning(true), // Provider with no versioning. ) @@ -551,7 +552,7 @@ func TestAllowMissing(t *testing.T) { t.Run("missing_now_allowed", func(t *testing.T) { db := newDB(t) - p, err := provider.NewProvider(provider.DialectSQLite3, db, newFsys(), + p, err := provider.NewProvider(database.DialectSQLite3, db, newFsys(), provider.WithAllowMissing(false), ) check.NoError(t, err) @@ -606,7 +607,7 @@ func TestAllowMissing(t *testing.T) { t.Run("missing_allowed", func(t *testing.T) { db := newDB(t) - p, err := provider.NewProvider(provider.DialectSQLite3, db, newFsys(), + p, err := provider.NewProvider(database.DialectSQLite3, db, newFsys(), provider.WithAllowMissing(true), ) check.NoError(t, err) @@ -714,7 +715,7 @@ func TestGoOnly(t *testing.T) { t.Cleanup(provider.ResetGlobalGoMigrations) db := newDB(t) - p, err := provider.NewProvider(provider.DialectSQLite3, db, nil, + p, err := provider.NewProvider(database.DialectSQLite3, db, nil, provider.WithGoMigration( 2, &provider.GoMigration{Run: newTxFn("INSERT INTO users (id) VALUES (1), (2), (3)")}, @@ -763,7 +764,7 @@ func TestGoOnly(t *testing.T) { t.Cleanup(provider.ResetGlobalGoMigrations) db := newDB(t) - p, err := provider.NewProvider(provider.DialectSQLite3, db, nil, + p, err := provider.NewProvider(database.DialectSQLite3, db, nil, provider.WithGoMigration( 2, &provider.GoMigration{RunNoTx: newDBFn("INSERT INTO users (id) VALUES (1), (2), (3)")}, @@ -820,7 +821,7 @@ func TestLockModeAdvisorySession(t *testing.T) { newProvider := func() *provider.Provider { sessionLocker, err := lock.NewPostgresSessionLocker() check.NoError(t, err) - p, err := provider.NewProvider(provider.DialectPostgres, db, os.DirFS("../../testdata/migrations"), + p, err := provider.NewProvider(database.DialectPostgres, db, os.DirFS("../../testdata/migrations"), provider.WithSessionLocker(sessionLocker), // Use advisory session lock mode. provider.WithVerbose(testing.Verbose()), ) @@ -1074,7 +1075,7 @@ func newProviderWithDB(t *testing.T, opts ...provider.ProviderOption) (*provider opts, provider.WithVerbose(testing.Verbose()), ) - p, err := provider.NewProvider(provider.DialectSQLite3, db, newFsys(), opts...) + p, err := provider.NewProvider(database.DialectSQLite3, db, newFsys(), opts...) check.NoError(t, err) return p, db } diff --git a/internal/sqladapter/sqladapter.go b/internal/sqladapter/sqladapter.go deleted file mode 100644 index f6c975dc4..000000000 --- a/internal/sqladapter/sqladapter.go +++ /dev/null @@ -1,49 +0,0 @@ -// Package sqladapter provides an interface for interacting with a SQL database. -// -// All supported database dialects must implement the Store interface. -package sqladapter - -import ( - "context" - "time" - - "github.com/pressly/goose/v3/internal/sqlextended" -) - -// Store is the interface that wraps the basic methods for a database dialect. -// -// A dialect is a set of SQL statements that are specific to a database. -// -// By defining a store interface, we can support multiple databases with a single codebase. -// -// The underlying implementation does not modify the error. It is the callers responsibility to -// assert for the correct error, such as [sql.ErrNoRows]. -type Store interface { - // CreateVersionTable creates the version table within a transaction. This table is used to - // record applied migrations. - CreateVersionTable(ctx context.Context, db sqlextended.DBTxConn) error - - // InsertOrDelete inserts or deletes a version id from the version table. - InsertOrDelete(ctx context.Context, db sqlextended.DBTxConn, direction bool, version int64) error - - // GetMigration retrieves a single migration by version id. - // - // Returns the raw sql error if the query fails. It is the callers responsibility to assert for - // the correct error, such as [sql.ErrNoRows]. - GetMigration(ctx context.Context, db sqlextended.DBTxConn, version int64) (*GetMigrationResult, error) - - // ListMigrations retrieves all migrations sorted in descending order by id. - // - // If there are no migrations, an empty slice is returned with no error. - ListMigrations(ctx context.Context, db sqlextended.DBTxConn) ([]*ListMigrationsResult, error) -} - -type GetMigrationResult struct { - IsApplied bool - Timestamp time.Time -} - -type ListMigrationsResult struct { - Version int64 - IsApplied bool -}