Skip to content

Commit

Permalink
feat(experimental): add package database with Store interface (#623)
Browse files Browse the repository at this point in the history
  • Loading branch information
mfridman authored Oct 27, 2023
1 parent a9da750 commit 4ec43df
Show file tree
Hide file tree
Showing 12 changed files with 158 additions and 129 deletions.
92 changes: 56 additions & 36 deletions internal/sqladapter/store.go → database/dialect.go
Original file line number Diff line number Diff line change
@@ -1,67 +1,80 @@
package sqladapter
package database

import (
"context"
"errors"
"fmt"

"github.com/pressly/goose/v3/internal/dialect/dialectquery"
"github.com/pressly/goose/v3/internal/sqlextended"
)

var _ Store = (*store)(nil)
// Dialect is the type of database dialect.
type Dialect string

type store struct {
tablename string
querier dialectquery.Querier
}
const (
DialectClickHouse Dialect = "clickhouse"
DialectMSSQL Dialect = "mssql"
DialectMySQL Dialect = "mysql"
DialectPostgres Dialect = "postgres"
DialectRedshift Dialect = "redshift"
DialectSQLite3 Dialect = "sqlite3"
DialectTiDB Dialect = "tidb"
DialectVertica Dialect = "vertica"
DialectYdB Dialect = "ydb"

// DialectCustom is a special dialect that allows users to provide their own [Store]
// implementation when constructing a [goose.Provider].
DialectCustom Dialect = "custom"
)

// NewStore returns a new [Store] backed by the given dialect.
//
// The dialect must match one of the supported dialects defined in dialect.go.
func NewStore(dialect string, table string) (Store, error) {
if table == "" {
return nil, errors.New("table must not be empty")
func NewStore(dialect Dialect, tablename string) (Store, error) {
if tablename == "" {
return nil, errors.New("tablename must not be empty")
}
if dialect == "" {
return nil, errors.New("dialect must not be empty")
}
var querier dialectquery.Querier
switch dialect {
case "clickhouse":
querier = &dialectquery.Clickhouse{}
case "mssql":
querier = &dialectquery.Sqlserver{}
case "mysql":
querier = &dialectquery.Mysql{}
case "postgres":
querier = &dialectquery.Postgres{}
case "redshift":
querier = &dialectquery.Redshift{}
case "sqlite3":
querier = &dialectquery.Sqlite3{}
case "tidb":
querier = &dialectquery.Tidb{}
case "vertica":
querier = &dialectquery.Vertica{}
default:
if dialect == DialectCustom {
return nil, errors.New("dialect must not be custom")
}
lookup := map[Dialect]dialectquery.Querier{
DialectClickHouse: &dialectquery.Clickhouse{},
DialectMSSQL: &dialectquery.Sqlserver{},
DialectMySQL: &dialectquery.Mysql{},
DialectPostgres: &dialectquery.Postgres{},
DialectRedshift: &dialectquery.Redshift{},
DialectSQLite3: &dialectquery.Sqlite3{},
DialectTiDB: &dialectquery.Tidb{},
DialectVertica: &dialectquery.Vertica{},
DialectYdB: &dialectquery.Ydb{},
}
querier, ok := lookup[dialect]
if !ok {
return nil, fmt.Errorf("unknown dialect: %q", dialect)
}
return &store{
tablename: table,
tablename: tablename,
querier: querier,
}, nil
}

func (s *store) CreateVersionTable(ctx context.Context, db sqlextended.DBTxConn) error {
type store struct {
tablename string
querier dialectquery.Querier
}

var _ Store = (*store)(nil)

func (s *store) CreateVersionTable(ctx context.Context, db DBTxConn) error {
q := s.querier.CreateTable(s.tablename)
if _, err := db.ExecContext(ctx, q); err != nil {
return fmt.Errorf("failed to create version table %q: %w", s.tablename, err)
}
return nil
}

func (s *store) InsertOrDelete(ctx context.Context, db sqlextended.DBTxConn, direction bool, version int64) error {
func (s *store) InsertOrDelete(ctx context.Context, db DBTxConn, direction bool, version int64) error {
if direction {
q := s.querier.InsertVersion(s.tablename)
if _, err := db.ExecContext(ctx, q, version, true); err != nil {
Expand All @@ -76,7 +89,11 @@ func (s *store) InsertOrDelete(ctx context.Context, db sqlextended.DBTxConn, dir
return nil
}

func (s *store) GetMigration(ctx context.Context, db sqlextended.DBTxConn, version int64) (*GetMigrationResult, error) {
func (s *store) GetMigration(
ctx context.Context,
db DBTxConn,
version int64,
) (*GetMigrationResult, error) {
q := s.querier.GetMigrationByVersion(s.tablename)
var result GetMigrationResult
if err := db.QueryRowContext(ctx, q, version).Scan(
Expand All @@ -88,7 +105,10 @@ func (s *store) GetMigration(ctx context.Context, db sqlextended.DBTxConn, versi
return &result, nil
}

func (s *store) ListMigrations(ctx context.Context, db sqlextended.DBTxConn) ([]*ListMigrationsResult, error) {
func (s *store) ListMigrations(
ctx context.Context,
db DBTxConn,
) ([]*ListMigrationsResult, error) {
q := s.querier.ListMigrations(s.tablename)
rows, err := db.QueryContext(ctx, q)
if err != nil {
Expand Down
14 changes: 14 additions & 0 deletions database/doc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// Package database defines a generic [Store] interface for goose to use when interacting with the
// database. It is meant to be generic and not tied to any specific database technology.
//
// At a high level, a [Store] is responsible for:
// - Creating a version table
// - Inserting and deleting a version
// - Getting a specific version
// - Listing all applied versions
//
// Use the [NewStore] function to create a [Store] for one of the supported dialects.
//
// For more advanced use cases, it's possible to implement a custom [Store] for a database that
// goose does not support.
package database
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
package sqlextended
package database

import (
"context"
"database/sql"
)

// DBTxConn is a thin interface for common method that is satisfied by *sql.DB, *sql.Tx and
// DBTxConn is a thin interface for common methods that is satisfied by *sql.DB, *sql.Tx and
// *sql.Conn.
//
// There is a long outstanding issue to formalize a std lib interface, but alas... See:
// There is a long outstanding issue to formalize a std lib interface, but alas. See:
// https://github.com/golang/go/issues/14468
type DBTxConn interface {
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
Expand Down
39 changes: 39 additions & 0 deletions database/store.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package database

import (
"context"
"time"
)

// Store is an interface that defines methods for managing database migrations and versioning. By
// defining a Store interface, we can support multiple databases with consistent functionality.
//
// Each database dialect requires a specific implementation of this interface. A dialect represents
// a set of SQL statements specific to a particular database system.
type Store interface {
// CreateVersionTable creates the version table. This table is used to record applied
// migrations.
CreateVersionTable(ctx context.Context, db DBTxConn) error

// InsertOrDelete inserts or deletes a version id from the version table. If direction is true,
// insert the version id. If direction is false, delete the version id.
InsertOrDelete(ctx context.Context, db DBTxConn, direction bool, version int64) error

// GetMigration retrieves a single migration by version id. This method may return the raw sql
// error if the query fails so the caller can assert for errors such as [sql.ErrNoRows].
GetMigration(ctx context.Context, db DBTxConn, version int64) (*GetMigrationResult, error)

// ListMigrations retrieves all migrations sorted in descending order by id or timestamp. If
// there are no migrations, return empty slice with no error.
ListMigrations(ctx context.Context, db DBTxConn) ([]*ListMigrationsResult, error)
}

type GetMigrationResult struct {
Timestamp time.Time
IsApplied bool
}

type ListMigrationsResult struct {
Version int64
IsApplied bool
}
36 changes: 21 additions & 15 deletions internal/sqladapter/store_test.go → database/store_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package sqladapter_test
package database_test

import (
"context"
Expand All @@ -8,29 +8,30 @@ import (
"testing"

"github.com/jackc/pgx/v5/pgconn"
"github.com/pressly/goose/v3"
"github.com/pressly/goose/v3/database"
"github.com/pressly/goose/v3/internal/check"
"github.com/pressly/goose/v3/internal/sqladapter"
"github.com/pressly/goose/v3/internal/testdb"
"go.uber.org/multierr"
"modernc.org/sqlite"
)

// The goal of this test is to verify the sqladapter package works as expected. This test is not
// The goal of this test is to verify the database store package works as expected. This test is not
// meant to be exhaustive or test every possible database dialect. It is meant to verify the Store
// interface works against a real database.

func TestStore(t *testing.T) {
func TestDialectStore(t *testing.T) {
t.Parallel()
t.Run("invalid", func(t *testing.T) {
// Test empty table name.
_, err := sqladapter.NewStore("sqlite3", "")
_, err := database.NewStore(database.DialectSQLite3, "")
check.HasError(t, err)
// Test unknown dialect.
_, err = sqladapter.NewStore("unknown-dialect", "foo")
_, err = database.NewStore("unknown-dialect", "foo")
check.HasError(t, err)
// Test empty dialect.
_, err = sqladapter.NewStore("", "foo")
_, err = database.NewStore("", "foo")
check.HasError(t, err)
_, err = database.NewStore(database.DialectCustom, "foo")
check.HasError(t, err)
})
t.Run("postgres", func(t *testing.T) {
Expand All @@ -41,7 +42,7 @@ func TestStore(t *testing.T) {
db, cleanup, err := testdb.NewPostgres()
check.NoError(t, err)
t.Cleanup(cleanup)
testStore(context.Background(), t, goose.DialectPostgres, db, func(t *testing.T, err error) {
testStore(context.Background(), t, database.DialectPostgres, db, func(t *testing.T, err error) {
var pgErr *pgconn.PgError
ok := errors.As(err, &pgErr)
check.Bool(t, ok, true)
Expand All @@ -50,10 +51,9 @@ func TestStore(t *testing.T) {
})
// Test generic behavior.
t.Run("sqlite3", func(t *testing.T) {
dir := t.TempDir()
db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db"))
db, err := sql.Open("sqlite", ":memory:")
check.NoError(t, err)
testStore(context.Background(), t, goose.DialectSQLite3, db, func(t *testing.T, err error) {
testStore(context.Background(), t, database.DialectSQLite3, db, func(t *testing.T, err error) {
var sqliteErr *sqlite.Error
ok := errors.As(err, &sqliteErr)
check.Bool(t, ok, true)
Expand All @@ -65,7 +65,7 @@ func TestStore(t *testing.T) {
dir := t.TempDir()
db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db"))
check.NoError(t, err)
store, err := sqladapter.NewStore("sqlite3", "foo")
store, err := database.NewStore(database.DialectSQLite3, "foo")
check.NoError(t, err)
err = store.CreateVersionTable(context.Background(), db)
check.NoError(t, err)
Expand All @@ -86,11 +86,17 @@ func TestStore(t *testing.T) {
//
// If alreadyExists is not nil, it will be used to assert the error returned by CreateVersionTable
// when the version table already exists.
func testStore(ctx context.Context, t *testing.T, dialect goose.Dialect, db *sql.DB, alreadyExists func(t *testing.T, err error)) {
func testStore(
ctx context.Context,
t *testing.T,
d database.Dialect,
db *sql.DB,
alreadyExists func(t *testing.T, err error),
) {
const (
tablename = "test_goose_db_version"
)
store, err := sqladapter.NewStore(string(dialect), tablename)
store, err := database.NewStore(d, tablename)
check.NoError(t, err)
// Create the version table.
err = runTx(ctx, db, func(tx *sql.Tx) error {
Expand Down
6 changes: 3 additions & 3 deletions internal/provider/collect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ import (
"testing"
"testing/fstest"

"github.com/pressly/goose/v3/database"
"github.com/pressly/goose/v3/internal/check"
"github.com/pressly/goose/v3/internal/sqladapter"
)

func TestCollectFileSources(t *testing.T) {
Expand Down Expand Up @@ -294,7 +294,7 @@ func TestFindMissingMigrations(t *testing.T) {
// Test case: database has migrations 1, 3, 4, 5, 7
// Missing migrations: 2, 6
// Filesystem has migrations 1, 2, 3, 4, 5, 6, 7, 8
dbMigrations := []*sqladapter.ListMigrationsResult{
dbMigrations := []*database.ListMigrationsResult{
{Version: 1},
{Version: 3},
{Version: 4},
Expand Down Expand Up @@ -322,7 +322,7 @@ func TestFindMissingMigrations(t *testing.T) {
check.Number(t, len(findMissingMigrations(nil, fsMigrations)), 0)
})
t.Run("fs_has_max_version", func(t *testing.T) {
dbMigrations := []*sqladapter.ListMigrationsResult{
dbMigrations := []*database.ListMigrationsResult{
{Version: 1},
{Version: 5},
{Version: 2},
Expand Down
4 changes: 2 additions & 2 deletions internal/provider/migration.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"fmt"
"path/filepath"

"github.com/pressly/goose/v3/internal/sqlextended"
"github.com/pressly/goose/v3/database"
)

type migration struct {
Expand Down Expand Up @@ -170,7 +170,7 @@ func (s *sqlMigration) IsEmpty(direction bool) bool {
return len(s.DownStatements) == 0
}

func (s *sqlMigration) run(ctx context.Context, db sqlextended.DBTxConn, direction bool) error {
func (s *sqlMigration) run(ctx context.Context, db database.DBTxConn, direction bool) error {
var statements []string
if direction {
statements = s.UpStatements
Expand Down
11 changes: 4 additions & 7 deletions internal/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"math"
"sync"

"github.com/pressly/goose/v3/internal/sqladapter"
"github.com/pressly/goose/v3/database"
)

// NewProvider returns a new goose Provider.
Expand All @@ -28,13 +28,10 @@ import (
// Unless otherwise specified, all methods on Provider are safe for concurrent use.
//
// Experimental: This API is experimental and may change in the future.
func NewProvider(dialect Dialect, db *sql.DB, fsys fs.FS, opts ...ProviderOption) (*Provider, error) {
func NewProvider(dialect database.Dialect, db *sql.DB, fsys fs.FS, opts ...ProviderOption) (*Provider, error) {
if db == nil {
return nil, errors.New("db must not be nil")
}
if dialect == "" {
return nil, errors.New("dialect must not be empty")
}
if fsys == nil {
fsys = noopFS{}
}
Expand All @@ -51,7 +48,7 @@ func NewProvider(dialect Dialect, db *sql.DB, fsys fs.FS, opts ...ProviderOption
if cfg.tableName == "" {
cfg.tableName = DefaultTablename
}
store, err := sqladapter.NewStore(string(dialect), cfg.tableName)
store, err := database.NewStore(dialect, cfg.tableName)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -138,7 +135,7 @@ type Provider struct {
db *sql.DB
fsys fs.FS
cfg config
store sqladapter.Store
store database.Store

// migrations are ordered by version in ascending order.
migrations []*migration
Expand Down
Loading

0 comments on commit 4ec43df

Please sign in to comment.