Skip to content

Commit

Permalink
add mutex with concurrency test
Browse files Browse the repository at this point in the history
  • Loading branch information
mfridman committed May 19, 2023
1 parent 209973b commit 6e2416e
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 1 deletion.
11 changes: 10 additions & 1 deletion run.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,11 @@ func (p *Provider) beginTx(
}

func (p *Provider) initialize(ctx context.Context) (*sql.Conn, func() error, error) {
p.mu.Lock()

conn, err := p.db.Conn(ctx)
if err != nil {
p.mu.Unlock()
return nil, nil, err
}
var (
Expand All @@ -217,14 +220,20 @@ func (p *Provider) initialize(ctx context.Context) (*sql.Conn, func() error, err
switch p.opt.LockMode {
case LockModeAdvisorySession:
if err := p.store.LockSession(ctx, conn); err != nil {
p.mu.Unlock()
return nil, nil, err
}
cleanup = func() error {
defer p.mu.Unlock()
return errors.Join(p.store.UnlockSession(ctx, conn), conn.Close())
}
case LockModeNone:
cleanup = conn.Close
cleanup = func() error {
defer p.mu.Unlock()
return conn.Close()
}
default:
p.mu.Unlock()
return nil, nil, fmt.Errorf("invalid lock mode: %d", p.opt.LockMode)
}
// If versioning is enabled, ensure the version table exists.
Expand Down
48 changes: 48 additions & 0 deletions tests/e2e/postgres/concurrency_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package postgres_test

import (
"context"
"sync"
"testing"

"github.com/pressly/goose/v4/internal/check"
)

func TestConcurrentProvider(t *testing.T) {
t.Parallel()
ctx := context.Background()
te := newTestEnv(t, migrationsDir, nil)

expected := 7

ch := make(chan int64)
var wg sync.WaitGroup
for i := 0; i < expected; i++ {
wg.Add(1)

go func() {
defer wg.Done()
res, err := te.provider.UpByOne(ctx)
if err != nil {
t.Error(err)
return
}
ch <- res.Version
}()
}
go func() {
wg.Wait()
close(ch)
}()
var versions []int64
for version := range ch {
versions = append(versions, version)
}
check.Number(t, len(versions), expected)
for i := 0; i < expected; i++ {
check.Number(t, versions[i], int64(i+1))
}
version, err := te.provider.GetDBVersion(ctx)
check.NoError(t, err)
check.Number(t, version, expected)
}

0 comments on commit 6e2416e

Please sign in to comment.