Skip to content

Commit

Permalink
feat: Add provider HasPending method (#751)
Browse files Browse the repository at this point in the history
  • Loading branch information
mfridman committed Apr 21, 2024
1 parent 7e96a22 commit 1ad801c
Show file tree
Hide file tree
Showing 5 changed files with 326 additions and 44 deletions.
8 changes: 8 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,14 @@ test-packages:
test-packages-short:
go test -test.short $(GO_TEST_FLAGS) $$(go list ./... | grep -v -e /tests -e /bin -e /cmd -e /examples)

coverage-short:
go test ./ -test.short $(GO_TEST_FLAGS) -cover -coverprofile=coverage.out
go tool cover -html=coverage.out

coverage:
go test ./ $(GO_TEST_FLAGS) -cover -coverprofile=coverage.out
go tool cover -html=coverage.out

#
# Integration-related targets
#
Expand Down
118 changes: 118 additions & 0 deletions internal/testing/integration/postgres_locking_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,18 @@ import (
"context"
"database/sql"
"errors"
"fmt"
"hash/crc64"
"math/rand"
"os"
"sort"
"sync"
"testing"
"testing/fstest"
"time"

"github.com/pressly/goose/v3"
"github.com/pressly/goose/v3/internal/check"
"github.com/pressly/goose/v3/internal/testing/testdb"
"github.com/pressly/goose/v3/lock"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -406,6 +410,120 @@ func TestPostgresProviderLocking(t *testing.T) {
})
}

func TestPostgresHasPending(t *testing.T) {
t.Parallel()
if testing.Short() {
t.Skip("skipping test in short mode.")
}

db, cleanup, err := testdb.NewPostgres()
require.NoError(t, err)
t.Cleanup(cleanup)

workers := 15

run := func(want bool) {
var g errgroup.Group
boolCh := make(chan bool, workers)
for i := 0; i < workers; i++ {
g.Go(func() error {
p, err := goose.NewProvider(goose.DialectPostgres, db, os.DirFS("testdata/migrations/postgres"))
check.NoError(t, err)
hasPending, err := p.HasPending(context.Background())
if err != nil {
return err
}
boolCh <- hasPending
return nil

})
}
check.NoError(t, g.Wait())
close(boolCh)
// expect all values to be true
for hasPending := range boolCh {
check.Bool(t, hasPending, want)
}
}
t.Run("concurrent_has_pending", func(t *testing.T) {
run(true)
})

// apply all migrations
p, err := goose.NewProvider(goose.DialectPostgres, db, os.DirFS("testdata/migrations/postgres"))
check.NoError(t, err)
_, err = p.Up(context.Background())
check.NoError(t, err)

t.Run("concurrent_no_pending", func(t *testing.T) {
run(false)
})

// Add a new migration file
last := p.ListSources()[len(p.ListSources())-1]
newVersion := fmt.Sprintf("%d_new_migration.sql", last.Version+1)
fsys := fstest.MapFS{
newVersion: &fstest.MapFile{Data: []byte(`
-- +goose Up
SELECT pg_sleep_for('4 seconds');
`)},
}
lockID := int64(crc64.Checksum([]byte(t.Name()), crc64.MakeTable(crc64.ECMA)))
// Create a new provider with the new migration file
sessionLocker, err := lock.NewPostgresSessionLocker(lock.WithLockTimeout(1, 10), lock.WithLockID(lockID)) // Timeout 5min. Try every 1s up to 10 times.
require.NoError(t, err)
newProvider, err := goose.NewProvider(goose.DialectPostgres, db, fsys, goose.WithSessionLocker(sessionLocker))
check.NoError(t, err)
check.Number(t, len(newProvider.ListSources()), 1)
oldProvider := p
check.Number(t, len(oldProvider.ListSources()), 6)

var g errgroup.Group
g.Go(func() error {
hasPending, err := newProvider.HasPending(context.Background())
if err != nil {
return err
}
check.Bool(t, hasPending, true)
return nil
})
g.Go(func() error {
hasPending, err := oldProvider.HasPending(context.Background())
if err != nil {
return err
}
check.Bool(t, hasPending, false)
return nil
})
check.NoError(t, g.Wait())

// A new provider is running in the background with a session lock to simulate a long running
// migration. If older instances come up, they should not have any pending migrations and not be
// affected by the long running migration. Test the following scenario:
// https://github.com/pressly/goose/pull/507#discussion_r1266498077
g.Go(func() error {
_, err := newProvider.Up(context.Background())
return err
})
time.Sleep(1 * time.Second)
isLocked, err := existsPgLock(context.Background(), db, lockID)
check.NoError(t, err)
check.Bool(t, isLocked, true)
hasPending, err := oldProvider.HasPending(context.Background())
check.NoError(t, err)
check.Bool(t, hasPending, false)
// Wait for the long running migration to finish
check.NoError(t, g.Wait())
// Check that the new migration was applied
hasPending, err = newProvider.HasPending(context.Background())
check.NoError(t, err)
check.Bool(t, hasPending, false)
// The max version should be the new migration
currentVersion, err := newProvider.GetDBVersion(context.Background())
check.NoError(t, err)
check.Number(t, currentVersion, last.Version+1)
}

func existsPgLock(ctx context.Context, db *sql.DB, lockID int64) (bool, error) {
q := `SELECT EXISTS(SELECT 1 FROM pg_locks WHERE locktype='advisory' AND ((classid::bigint<<32)|objid::bigint)=$1)`
row := db.QueryRowContext(ctx, q, lockID)
Expand Down
98 changes: 87 additions & 11 deletions provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,15 @@ type Provider struct {
// database.
mu sync.Mutex

db *sql.DB
store database.Store
db *sql.DB
store database.Store
versionTableOnce sync.Once

fsys fs.FS
cfg config

// migrations are ordered by version in ascending order.
// migrations are ordered by version in ascending order. This list will never be empty and
// contains all migrations known to the provider.
migrations []*Migration
}

Expand All @@ -49,8 +51,6 @@ type Provider struct {
// See [ProviderOption] for more information on configuring the provider.
//
// Unless otherwise specified, all methods on Provider are safe for concurrent use.
//
// Experimental: This API is experimental and may change in the future.
func NewProvider(dialect Dialect, db *sql.DB, fsys fs.FS, opts ...ProviderOption) (*Provider, error) {
if db == nil {
return nil, errors.New("db must not be nil")
Expand Down Expand Up @@ -154,6 +154,14 @@ func (p *Provider) Status(ctx context.Context) ([]*MigrationStatus, error) {
return p.status(ctx)
}

// HasPending returns true if there are pending migrations to apply, otherwise, it returns false.
//
// Note, this method will not use a SessionLocker if one is configured. This allows callers to check
// for pending migrations without blocking or being blocked by other operations.
func (p *Provider) HasPending(ctx context.Context) (bool, error) {
return p.hasPending(ctx)
}

// GetDBVersion returns the highest version recorded in the database, regardless of the order in
// which migrations were applied. For example, if migrations were applied out of order (1,4,2,3),
// this method returns 4. If no migrations have been applied, it returns 0.
Expand Down Expand Up @@ -214,12 +222,26 @@ func (p *Provider) ApplyVersion(ctx context.Context, version int64, direction bo
// Up applies all pending migrations. If there are no new migrations to apply, this method returns
// empty list and nil error.
func (p *Provider) Up(ctx context.Context) ([]*MigrationResult, error) {
hasPending, err := p.HasPending(ctx)
if err != nil {
return nil, err
}
if !hasPending {
return nil, nil
}
return p.up(ctx, false, math.MaxInt64)
}

// UpByOne applies the next pending migration. If there is no next migration to apply, this method
// returns [ErrNoNextVersion]. The returned list will always have exactly one migration result.
// returns [ErrNoNextVersion].
func (p *Provider) UpByOne(ctx context.Context) (*MigrationResult, error) {
hasPending, err := p.HasPending(ctx)
if err != nil {
return nil, err
}
if !hasPending {
return nil, ErrNoNextVersion
}
res, err := p.up(ctx, true, math.MaxInt64)
if err != nil {
return nil, err
Expand Down Expand Up @@ -247,6 +269,13 @@ func (p *Provider) UpByOne(ctx context.Context) (*MigrationResult, error) {
// For example, if there are three new migrations (9,10,11) and the current database version is 8
// with a requested version of 10, only versions 9,10 will be applied.
func (p *Provider) UpTo(ctx context.Context, version int64) ([]*MigrationResult, error) {
hasPending, err := p.HasPending(ctx)
if err != nil {
return nil, err
}
if !hasPending {
return nil, nil
}
return p.up(ctx, false, version)
}

Expand Down Expand Up @@ -303,7 +332,7 @@ func (p *Provider) up(
if version < 1 {
return nil, errInvalidVersion
}
conn, cleanup, err := p.initialize(ctx)
conn, cleanup, err := p.initialize(ctx, true)
if err != nil {
return nil, fmt.Errorf("failed to initialize: %w", err)
}
Expand Down Expand Up @@ -345,7 +374,7 @@ func (p *Provider) down(
byOne bool,
version int64,
) (_ []*MigrationResult, retErr error) {
conn, cleanup, err := p.initialize(ctx)
conn, cleanup, err := p.initialize(ctx, true)
if err != nil {
return nil, fmt.Errorf("failed to initialize: %w", err)
}
Expand Down Expand Up @@ -404,7 +433,7 @@ func (p *Provider) apply(
if err != nil {
return nil, err
}
conn, cleanup, err := p.initialize(ctx)
conn, cleanup, err := p.initialize(ctx, true)
if err != nil {
return nil, fmt.Errorf("failed to initialize: %w", err)
}
Expand Down Expand Up @@ -436,8 +465,55 @@ func (p *Provider) apply(
return p.runMigrations(ctx, conn, []*Migration{m}, d, true)
}

func (p *Provider) hasPending(ctx context.Context) (_ bool, retErr error) {
conn, cleanup, err := p.initialize(ctx, false)
if err != nil {
return false, fmt.Errorf("failed to initialize: %w", err)
}
defer func() {
retErr = multierr.Append(retErr, cleanup())
}()

// If versioning is disabled, we always have pending migrations.
if p.cfg.disableVersioning {
return true, nil
}
if p.cfg.allowMissing {
// List all migrations from the database.
dbMigrations, err := p.store.ListMigrations(ctx, conn)
if err != nil {
return false, err
}
// If there are no migrations in the database, we have pending migrations.
if len(dbMigrations) == 0 {
return true, nil
}
applied := make(map[int64]bool, len(dbMigrations))
for _, m := range dbMigrations {
applied[m.Version] = true
}
// Iterate over all migrations and check if any are missing.
for _, m := range p.migrations {
if !applied[m.Version] {
return true, nil
}
}
return false, nil
}
// If out-of-order migrations are not allowed, we can optimize this by only checking whether the
// last migration the provider knows about is applied.
last := p.migrations[len(p.migrations)-1]
if _, err := p.store.GetMigration(ctx, conn, last.Version); err != nil {
if errors.Is(err, database.ErrVersionNotFound) {
return true, nil
}
return false, err
}
return false, nil
}

func (p *Provider) status(ctx context.Context) (_ []*MigrationStatus, retErr error) {
conn, cleanup, err := p.initialize(ctx)
conn, cleanup, err := p.initialize(ctx, true)
if err != nil {
return nil, fmt.Errorf("failed to initialize: %w", err)
}
Expand Down Expand Up @@ -478,7 +554,7 @@ func (p *Provider) getDBMaxVersion(ctx context.Context, conn *sql.Conn) (_ int64
if conn == nil {
var cleanup func() error
var err error
conn, cleanup, err = p.initialize(ctx)
conn, cleanup, err = p.initialize(ctx, true)
if err != nil {
return 0, err
}
Expand Down

0 comments on commit 1ad801c

Please sign in to comment.