Skip to content

Commit

Permalink
Add WithRawSQLURL to configure an optional URl for raw sql operatio…
Browse files Browse the repository at this point in the history
…ns (#315)

This setting comes handy when you need to use a different connection
string for raw SQL operations, as they may require some more security
checks.

---------

Co-authored-by: Andrew Farries <andyrb@gmail.com>
  • Loading branch information
exekias and andrew-farries committed Mar 13, 2024
1 parent f5d09d5 commit 8ed5799
Show file tree
Hide file tree
Showing 5 changed files with 171 additions and 20 deletions.
19 changes: 16 additions & 3 deletions pkg/roll/execute.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package roll

import (
"context"
"database/sql"
"errors"
"fmt"
"strings"
Expand Down Expand Up @@ -66,7 +67,7 @@ func (m *Roll) StartDDLOperations(ctx context.Context, migration *migrations.Mig
// execute operations
var tablesToBackfill []*schema.Table
for _, op := range migration.Operations {
table, err := op.Start(ctx, m.pgConn, m.state.Schema(), newSchema, cbs...)
table, err := op.Start(ctx, m.connForOp(op), m.state.Schema(), newSchema, cbs...)
if err != nil {
errRollback := m.Rollback(ctx)

Expand Down Expand Up @@ -154,7 +155,7 @@ func (m *Roll) Complete(ctx context.Context) error {
// execute operations
refreshViews := false
for _, op := range migration.Operations {
err := op.Complete(ctx, m.pgConn, schema)
err := op.Complete(ctx, m.connForOp(op), schema)
if err != nil {
return fmt.Errorf("unable to execute complete operation: %w", err)
}
Expand Down Expand Up @@ -204,7 +205,7 @@ func (m *Roll) Rollback(ctx context.Context) error {

// execute operations
for _, op := range migration.Operations {
err := op.Rollback(ctx, m.pgConn)
err := op.Rollback(ctx, m.connForOp(op))
if err != nil {
return fmt.Errorf("unable to execute rollback operation: %w", err)
}
Expand All @@ -219,6 +220,18 @@ func (m *Roll) Rollback(ctx context.Context) error {
return nil
}

// connForOp returns the connection to use for the given operation.
// If the operation is a raw SQL operation, it will use the rawSQLConn (if set)
// otherwise it will use the regular pgConn
func (m *Roll) connForOp(op migrations.Operation) *sql.DB {
if m.pgRawSQLConn != nil {
if _, ok := op.(*migrations.OpRawSQL); ok {
return m.pgRawSQLConn
}
}
return m.pgConn
}

// create view creates a view for the new version of the schema
func (m *Roll) ensureView(ctx context.Context, version, name string, table schema.Table) error {
columns := make([]string, 0, len(table.Columns))
Expand Down
75 changes: 75 additions & 0 deletions pkg/roll/execute_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"database/sql"
"errors"
"fmt"
"net/url"
"testing"

"github.com/lib/pq"
Expand Down Expand Up @@ -586,6 +587,63 @@ func TestMigrationHooksAreInvoked(t *testing.T) {
})
}

func TestRawSQLURLOption(t *testing.T) {
t.Parallel()

testutils.WithConnectionString(t, schema, func(st *state.State, connStr string) {
ctx := context.Background()

db, err := sql.Open("postgres", connStr)
if err != nil {
t.Fatal(err)
}

// create a user for rawSQLURL
_, err = db.Exec(`
CREATE USER rawsql WITH PASSWORD 'rawsql';
GRANT ALL PRIVILEGES ON SCHEMA public TO rawsql;
`)
assert.NoError(t, err)

// init pgroll with a rawSQLURL
rawSQL, err := url.Parse(connStr)
assert.NoError(t, err)
rawSQL.User = url.UserPassword("rawsql", "rawsql")

mig, err := roll.New(ctx, connStr, schema, st, roll.WithRawSQLURL(rawSQL.String()))
assert.NoError(t, err)

t.Cleanup(func() {
if err := mig.Close(); err != nil {
t.Fatal(err)
}
})

// Start a migration with raw and regular SQL operations
err = mig.Start(ctx, &migrations.Migration{
Name: "01_create_table",
Operations: migrations.Operations{
createTableOp("table1"),
&migrations.OpRawSQL{
Up: "CREATE TABLE raw_sql_table (id integer)",
OnComplete: true,
},
},
})
assert.NoError(t, err)

// Complete the migration
err = mig.Complete(ctx)
assert.NoError(t, err)

// Ensure both tables were created
assert.True(t, tableExists(t, db, "public", "table1"))
assert.True(t, tableExists(t, db, "public", "raw_sql_table"))
assert.Equal(t, "rawsql", tableOwner(t, db, "public", "raw_sql_table"))
assert.Equal(t, "postgres", tableOwner(t, db, "public", "table1"))
})
}

func TestRollSchemaMethodReturnsCorrectSchema(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -708,6 +766,23 @@ func tableExists(t *testing.T, db *sql.DB, schema, table string) bool {
return exists
}

func tableOwner(t *testing.T, db *sql.DB, schema, table string) string {
t.Helper()

var owner string
err := db.QueryRow(`
SELECT tableowner
FROM pg_catalog.pg_tables
WHERE schemaname = $1
AND tablename = $2`,
schema, table).Scan(&owner)
if err != nil {
t.Fatal(err)
}

return owner
}

func ptr[T any](v T) *T {
return &v
}
13 changes: 13 additions & 0 deletions pkg/roll/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ type options struct {
// optional role to set before executing migrations
role string

// optional rawSQLURL to use for raw SQL operations
rawSQLURL string

// disable pgroll version schemas creation and deletion
disableVersionSchemas bool
migrationHooks MigrationHooks
Expand Down Expand Up @@ -55,3 +58,13 @@ func WithMigrationHooks(hooks MigrationHooks) Option {
o.migrationHooks = hooks
}
}

// WithRawSQLURL sets the postgres URL to use for raw SQL operations
// This is useful when the raw SQL operations need to be executed against
// a different endpoint than the main migration operations (ie with a different user or
// more security checks)
func WithRawSQLURL(rawSQLURL string) Option {
return func(o *options) {
o.rawSQLURL = rawSQLURL
}
}
61 changes: 44 additions & 17 deletions pkg/roll/roll.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ type PGVersion int
const PGVersion15 PGVersion = 15

type Roll struct {
pgConn *sql.DB // TODO abstract sql connection
pgConn *sql.DB

pgRawSQLConn *sql.DB

// schema we are acting on
schema string
Expand All @@ -31,11 +33,42 @@ type Roll struct {
}

func New(ctx context.Context, pgURL, schema string, state *state.State, opts ...Option) (*Roll, error) {
options := &options{}
rollOpts := &options{}
for _, o := range opts {
o(options)
o(rollOpts)
}

conn, err := setupConn(ctx, pgURL, schema, *rollOpts)
if err != nil {
return nil, err
}

var rawSQLConn *sql.DB
if rollOpts.rawSQLURL != "" {
rawSQLConn, err = setupConn(ctx, rollOpts.rawSQLURL, schema, options{})
if err != nil {
return nil, err
}
}

var pgMajorVersion PGVersion
err = conn.QueryRowContext(ctx, "SELECT split_part(split_part(version(), ' ', 2), '.', 1)").Scan(&pgMajorVersion)
if err != nil {
return nil, fmt.Errorf("unable to retrieve postgres version: %w", err)
}

return &Roll{
pgConn: conn,
pgRawSQLConn: rawSQLConn,
schema: schema,
state: state,
pgVersion: PGVersion(pgMajorVersion),
disableVersionSchemas: rollOpts.disableVersionSchemas,
migrationHooks: rollOpts.migrationHooks,
}, nil
}

func setupConn(ctx context.Context, pgURL, schema string, options options) (*sql.DB, error) {
dsn, err := pq.ParseURL(pgURL)
if err != nil {
dsn = pgURL
Expand Down Expand Up @@ -71,20 +104,7 @@ func New(ctx context.Context, pgURL, schema string, state *state.State, opts ...
}
}

var pgMajorVersion PGVersion
err = conn.QueryRowContext(ctx, "SELECT split_part(split_part(version(), ' ', 2), '.', 1)").Scan(&pgMajorVersion)
if err != nil {
return nil, fmt.Errorf("unable to retrieve postgres version: %w", err)
}

return &Roll{
pgConn: conn,
schema: schema,
state: state,
pgVersion: PGVersion(pgMajorVersion),
disableVersionSchemas: options.disableVersionSchemas,
migrationHooks: options.migrationHooks,
}, nil
return conn, nil
}

func (m *Roll) Init(ctx context.Context) error {
Expand Down Expand Up @@ -113,5 +133,12 @@ func (m *Roll) Close() error {
return err
}

if m.pgRawSQLConn != nil {
err = m.pgRawSQLConn.Close()
if err != nil {
return err
}
}

return m.pgConn.Close()
}
23 changes: 23 additions & 0 deletions pkg/testutils/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,29 @@ func TestSchema() string {
return "public"
}

func WithConnectionString(t *testing.T, schema string, fn func(st *state.State, connStr string)) {
t.Helper()
_, connStr, _ := setupTestDatabase(t)

ctx := context.Background()
st, err := state.New(ctx, connStr, schema)
if err != nil {
t.Fatal(err)
}

if err := st.Init(ctx); err != nil {
t.Fatal(err)
}

t.Cleanup(func() {
if err := st.Close(); err != nil {
t.Fatal(err)
}
})

fn(st, connStr)
}

func WithStateInSchemaAndConnectionToContainer(t *testing.T, schema string, fn func(*state.State, *sql.DB)) {
t.Helper()
ctx := context.Background()
Expand Down

0 comments on commit 8ed5799

Please sign in to comment.