Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(experimental): add package database with Store interface #623

Merged
merged 8 commits into from
Oct 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 56 additions & 36 deletions internal/sqladapter/store.go → database/dialect.go
Original file line number Diff line number Diff line change
@@ -1,67 +1,80 @@
package sqladapter
package database

import (
"context"
"errors"
"fmt"

"github.com/pressly/goose/v3/internal/dialect/dialectquery"
"github.com/pressly/goose/v3/internal/sqlextended"
)

var _ Store = (*store)(nil)
// Dialect is the type of database dialect.
type Dialect string

type store struct {
tablename string
querier dialectquery.Querier
}
const (
DialectClickHouse Dialect = "clickhouse"
DialectMSSQL Dialect = "mssql"
DialectMySQL Dialect = "mysql"
DialectPostgres Dialect = "postgres"
DialectRedshift Dialect = "redshift"
DialectSQLite3 Dialect = "sqlite3"
DialectTiDB Dialect = "tidb"
DialectVertica Dialect = "vertica"
DialectYdB Dialect = "ydb"

// DialectCustom is a special dialect that allows users to provide their own [Store]
// implementation when constructing a [goose.Provider].
DialectCustom Dialect = "custom"
)

// NewStore returns a new [Store] backed by the given dialect.
//
// The dialect must match one of the supported dialects defined in dialect.go.
func NewStore(dialect string, table string) (Store, error) {
if table == "" {
return nil, errors.New("table must not be empty")
func NewStore(dialect Dialect, tablename string) (Store, error) {
if tablename == "" {
return nil, errors.New("tablename must not be empty")
}
if dialect == "" {
return nil, errors.New("dialect must not be empty")
}
var querier dialectquery.Querier
switch dialect {
case "clickhouse":
querier = &dialectquery.Clickhouse{}
case "mssql":
querier = &dialectquery.Sqlserver{}
case "mysql":
querier = &dialectquery.Mysql{}
case "postgres":
querier = &dialectquery.Postgres{}
case "redshift":
querier = &dialectquery.Redshift{}
case "sqlite3":
querier = &dialectquery.Sqlite3{}
case "tidb":
querier = &dialectquery.Tidb{}
case "vertica":
querier = &dialectquery.Vertica{}
default:
if dialect == DialectCustom {
return nil, errors.New("dialect must not be custom")
}
lookup := map[Dialect]dialectquery.Querier{
DialectClickHouse: &dialectquery.Clickhouse{},
DialectMSSQL: &dialectquery.Sqlserver{},
DialectMySQL: &dialectquery.Mysql{},
DialectPostgres: &dialectquery.Postgres{},
DialectRedshift: &dialectquery.Redshift{},
DialectSQLite3: &dialectquery.Sqlite3{},
DialectTiDB: &dialectquery.Tidb{},
DialectVertica: &dialectquery.Vertica{},
DialectYdB: &dialectquery.Ydb{},
}
querier, ok := lookup[dialect]
if !ok {
return nil, fmt.Errorf("unknown dialect: %q", dialect)
}
return &store{
tablename: table,
tablename: tablename,
querier: querier,
}, nil
}

func (s *store) CreateVersionTable(ctx context.Context, db sqlextended.DBTxConn) error {
type store struct {
tablename string
querier dialectquery.Querier
}

var _ Store = (*store)(nil)

func (s *store) CreateVersionTable(ctx context.Context, db DBTxConn) error {
q := s.querier.CreateTable(s.tablename)
if _, err := db.ExecContext(ctx, q); err != nil {
return fmt.Errorf("failed to create version table %q: %w", s.tablename, err)
}
return nil
}

func (s *store) InsertOrDelete(ctx context.Context, db sqlextended.DBTxConn, direction bool, version int64) error {
func (s *store) InsertOrDelete(ctx context.Context, db DBTxConn, direction bool, version int64) error {
if direction {
q := s.querier.InsertVersion(s.tablename)
if _, err := db.ExecContext(ctx, q, version, true); err != nil {
Expand All @@ -76,7 +89,11 @@ func (s *store) InsertOrDelete(ctx context.Context, db sqlextended.DBTxConn, dir
return nil
}

func (s *store) GetMigration(ctx context.Context, db sqlextended.DBTxConn, version int64) (*GetMigrationResult, error) {
func (s *store) GetMigration(
ctx context.Context,
db DBTxConn,
version int64,
) (*GetMigrationResult, error) {
q := s.querier.GetMigrationByVersion(s.tablename)
var result GetMigrationResult
if err := db.QueryRowContext(ctx, q, version).Scan(
Expand All @@ -88,7 +105,10 @@ func (s *store) GetMigration(ctx context.Context, db sqlextended.DBTxConn, versi
return &result, nil
}

func (s *store) ListMigrations(ctx context.Context, db sqlextended.DBTxConn) ([]*ListMigrationsResult, error) {
func (s *store) ListMigrations(
ctx context.Context,
db DBTxConn,
) ([]*ListMigrationsResult, error) {
q := s.querier.ListMigrations(s.tablename)
rows, err := db.QueryContext(ctx, q)
if err != nil {
Expand Down
14 changes: 14 additions & 0 deletions database/doc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// Package database defines a generic [Store] interface for goose to use when interacting with the
// database. It is meant to be generic and not tied to any specific database technology.
//
// At a high level, a [Store] is responsible for:
// - Creating a version table
// - Inserting and deleting a version
// - Getting a specific version
// - Listing all applied versions
//
// Use the [NewStore] function to create a [Store] for one of the supported dialects.
//
// For more advanced use cases, it's possible to implement a custom [Store] for a database that
// goose does not support.
package database
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
package sqlextended
package database

import (
"context"
"database/sql"
)

// DBTxConn is a thin interface for common method that is satisfied by *sql.DB, *sql.Tx and
// DBTxConn is a thin interface for common methods that is satisfied by *sql.DB, *sql.Tx and
// *sql.Conn.
//
// There is a long outstanding issue to formalize a std lib interface, but alas... See:
// There is a long outstanding issue to formalize a std lib interface, but alas. See:
// https://github.com/golang/go/issues/14468
type DBTxConn interface {
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
Expand Down
39 changes: 39 additions & 0 deletions database/store.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package database

import (
"context"
"time"
)

// Store is an interface that defines methods for managing database migrations and versioning. By
// defining a Store interface, we can support multiple databases with consistent functionality.
//
// Each database dialect requires a specific implementation of this interface. A dialect represents
// a set of SQL statements specific to a particular database system.
type Store interface {
// CreateVersionTable creates the version table. This table is used to record applied
// migrations.
CreateVersionTable(ctx context.Context, db DBTxConn) error

// InsertOrDelete inserts or deletes a version id from the version table. If direction is true,
// insert the version id. If direction is false, delete the version id.
InsertOrDelete(ctx context.Context, db DBTxConn, direction bool, version int64) error

// GetMigration retrieves a single migration by version id. This method may return the raw sql
// error if the query fails so the caller can assert for errors such as [sql.ErrNoRows].
GetMigration(ctx context.Context, db DBTxConn, version int64) (*GetMigrationResult, error)

// ListMigrations retrieves all migrations sorted in descending order by id or timestamp. If
// there are no migrations, return empty slice with no error.
ListMigrations(ctx context.Context, db DBTxConn) ([]*ListMigrationsResult, error)
}

type GetMigrationResult struct {
Timestamp time.Time
IsApplied bool
}

type ListMigrationsResult struct {
Version int64
IsApplied bool
}
36 changes: 21 additions & 15 deletions internal/sqladapter/store_test.go → database/store_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package sqladapter_test
package database_test

import (
"context"
Expand All @@ -8,29 +8,30 @@ import (
"testing"

"github.com/jackc/pgx/v5/pgconn"
"github.com/pressly/goose/v3"
"github.com/pressly/goose/v3/database"
"github.com/pressly/goose/v3/internal/check"
"github.com/pressly/goose/v3/internal/sqladapter"
"github.com/pressly/goose/v3/internal/testdb"
"go.uber.org/multierr"
"modernc.org/sqlite"
)

// The goal of this test is to verify the sqladapter package works as expected. This test is not
// The goal of this test is to verify the database store package works as expected. This test is not
// meant to be exhaustive or test every possible database dialect. It is meant to verify the Store
// interface works against a real database.

func TestStore(t *testing.T) {
func TestDialectStore(t *testing.T) {
t.Parallel()
t.Run("invalid", func(t *testing.T) {
// Test empty table name.
_, err := sqladapter.NewStore("sqlite3", "")
_, err := database.NewStore(database.DialectSQLite3, "")
check.HasError(t, err)
// Test unknown dialect.
_, err = sqladapter.NewStore("unknown-dialect", "foo")
_, err = database.NewStore("unknown-dialect", "foo")
check.HasError(t, err)
// Test empty dialect.
_, err = sqladapter.NewStore("", "foo")
_, err = database.NewStore("", "foo")
check.HasError(t, err)
_, err = database.NewStore(database.DialectCustom, "foo")
check.HasError(t, err)
})
t.Run("postgres", func(t *testing.T) {
Expand All @@ -41,7 +42,7 @@ func TestStore(t *testing.T) {
db, cleanup, err := testdb.NewPostgres()
check.NoError(t, err)
t.Cleanup(cleanup)
testStore(context.Background(), t, goose.DialectPostgres, db, func(t *testing.T, err error) {
testStore(context.Background(), t, database.DialectPostgres, db, func(t *testing.T, err error) {
var pgErr *pgconn.PgError
ok := errors.As(err, &pgErr)
check.Bool(t, ok, true)
Expand All @@ -50,10 +51,9 @@ func TestStore(t *testing.T) {
})
// Test generic behavior.
t.Run("sqlite3", func(t *testing.T) {
dir := t.TempDir()
db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db"))
db, err := sql.Open("sqlite", ":memory:")
check.NoError(t, err)
testStore(context.Background(), t, goose.DialectSQLite3, db, func(t *testing.T, err error) {
testStore(context.Background(), t, database.DialectSQLite3, db, func(t *testing.T, err error) {
var sqliteErr *sqlite.Error
ok := errors.As(err, &sqliteErr)
check.Bool(t, ok, true)
Expand All @@ -65,7 +65,7 @@ func TestStore(t *testing.T) {
dir := t.TempDir()
db, err := sql.Open("sqlite", filepath.Join(dir, "sql_embed.db"))
check.NoError(t, err)
store, err := sqladapter.NewStore("sqlite3", "foo")
store, err := database.NewStore(database.DialectSQLite3, "foo")
check.NoError(t, err)
err = store.CreateVersionTable(context.Background(), db)
check.NoError(t, err)
Expand All @@ -86,11 +86,17 @@ func TestStore(t *testing.T) {
//
// If alreadyExists is not nil, it will be used to assert the error returned by CreateVersionTable
// when the version table already exists.
func testStore(ctx context.Context, t *testing.T, dialect goose.Dialect, db *sql.DB, alreadyExists func(t *testing.T, err error)) {
func testStore(
ctx context.Context,
t *testing.T,
d database.Dialect,
db *sql.DB,
alreadyExists func(t *testing.T, err error),
) {
const (
tablename = "test_goose_db_version"
)
store, err := sqladapter.NewStore(string(dialect), tablename)
store, err := database.NewStore(d, tablename)
check.NoError(t, err)
// Create the version table.
err = runTx(ctx, db, func(tx *sql.Tx) error {
Expand Down
6 changes: 3 additions & 3 deletions internal/provider/collect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ import (
"testing"
"testing/fstest"

"github.com/pressly/goose/v3/database"
"github.com/pressly/goose/v3/internal/check"
"github.com/pressly/goose/v3/internal/sqladapter"
)

func TestCollectFileSources(t *testing.T) {
Expand Down Expand Up @@ -294,7 +294,7 @@ func TestFindMissingMigrations(t *testing.T) {
// Test case: database has migrations 1, 3, 4, 5, 7
// Missing migrations: 2, 6
// Filesystem has migrations 1, 2, 3, 4, 5, 6, 7, 8
dbMigrations := []*sqladapter.ListMigrationsResult{
dbMigrations := []*database.ListMigrationsResult{
{Version: 1},
{Version: 3},
{Version: 4},
Expand Down Expand Up @@ -322,7 +322,7 @@ func TestFindMissingMigrations(t *testing.T) {
check.Number(t, len(findMissingMigrations(nil, fsMigrations)), 0)
})
t.Run("fs_has_max_version", func(t *testing.T) {
dbMigrations := []*sqladapter.ListMigrationsResult{
dbMigrations := []*database.ListMigrationsResult{
{Version: 1},
{Version: 5},
{Version: 2},
Expand Down
4 changes: 2 additions & 2 deletions internal/provider/migration.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"fmt"
"path/filepath"

"github.com/pressly/goose/v3/internal/sqlextended"
"github.com/pressly/goose/v3/database"
)

type migration struct {
Expand Down Expand Up @@ -170,7 +170,7 @@ func (s *sqlMigration) IsEmpty(direction bool) bool {
return len(s.DownStatements) == 0
}

func (s *sqlMigration) run(ctx context.Context, db sqlextended.DBTxConn, direction bool) error {
func (s *sqlMigration) run(ctx context.Context, db database.DBTxConn, direction bool) error {
var statements []string
if direction {
statements = s.UpStatements
Expand Down
11 changes: 4 additions & 7 deletions internal/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"math"
"sync"

"github.com/pressly/goose/v3/internal/sqladapter"
"github.com/pressly/goose/v3/database"
)

// NewProvider returns a new goose Provider.
Expand All @@ -28,13 +28,10 @@ import (
// Unless otherwise specified, all methods on Provider are safe for concurrent use.
//
// Experimental: This API is experimental and may change in the future.
func NewProvider(dialect Dialect, db *sql.DB, fsys fs.FS, opts ...ProviderOption) (*Provider, error) {
func NewProvider(dialect database.Dialect, db *sql.DB, fsys fs.FS, opts ...ProviderOption) (*Provider, error) {
if db == nil {
return nil, errors.New("db must not be nil")
}
if dialect == "" {
return nil, errors.New("dialect must not be empty")
}
if fsys == nil {
fsys = noopFS{}
}
Expand All @@ -51,7 +48,7 @@ func NewProvider(dialect Dialect, db *sql.DB, fsys fs.FS, opts ...ProviderOption
if cfg.tableName == "" {
cfg.tableName = DefaultTablename
}
store, err := sqladapter.NewStore(string(dialect), cfg.tableName)
store, err := database.NewStore(dialect, cfg.tableName)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -138,7 +135,7 @@ type Provider struct {
db *sql.DB
fsys fs.FS
cfg config
store sqladapter.Store
store database.Store

// migrations are ordered by version in ascending order.
migrations []*migration
Expand Down
Loading
Loading