Skip to content

Commit

Permalink
feat: allow additional SQL migrations (#3587)
Browse files Browse the repository at this point in the history
  • Loading branch information
alnr committed Aug 3, 2023
1 parent dfb129a commit 8900cbb
Show file tree
Hide file tree
Showing 12 changed files with 59 additions and 32 deletions.
2 changes: 1 addition & 1 deletion cmd/cli/handler.go
Expand Up @@ -16,7 +16,7 @@ type Handler struct {

func NewHandler(slOpts []servicelocatorx.Option, dOpts []driver.OptionsModifier, cOpts []configx.OptionModifier) *Handler {
return &Handler{
Migration: newMigrateHandler(),
Migration: newMigrateHandler(slOpts, dOpts, cOpts),
Janitor: NewJanitorHandler(slOpts, dOpts, cOpts),
}
}
28 changes: 18 additions & 10 deletions cmd/cli/handler_migrate.go
Expand Up @@ -34,10 +34,18 @@ import (
"github.com/ory/x/flagx"
)

type MigrateHandler struct{}
type MigrateHandler struct {
slOpts []servicelocatorx.Option
dOpts []driver.OptionsModifier
cOpts []configx.OptionModifier
}

func newMigrateHandler() *MigrateHandler {
return &MigrateHandler{}
func newMigrateHandler(slOpts []servicelocatorx.Option, dOpts []driver.OptionsModifier, cOpts []configx.OptionModifier) *MigrateHandler {
return &MigrateHandler{
slOpts: slOpts,
dOpts: dOpts,
cOpts: cOpts,
}
}

const (
Expand Down Expand Up @@ -262,21 +270,21 @@ func (h *MigrateHandler) MigrateGen(cmd *cobra.Command, args []string) {
os.Exit(0)
}

func makePersister(cmd *cobra.Command, args []string) (p persistence.Persister, err error) {
func (h *MigrateHandler) makePersister(cmd *cobra.Command, args []string) (p persistence.Persister, err error) {
var d driver.Registry

if flagx.MustGetBool(cmd, "read-from-env") {
d, err = driver.New(
cmd.Context(),
servicelocatorx.NewOptions(),
[]driver.OptionsModifier{
append([]driver.OptionsModifier{
driver.WithOptions(
configx.SkipValidation(),
configx.WithFlags(cmd.Flags())),
driver.DisableValidation(),
driver.DisablePreloading(),
driver.SkipNetworkInit(),
})
}, h.dOpts...))
if err != nil {
return nil, err
}
Expand All @@ -292,7 +300,7 @@ func makePersister(cmd *cobra.Command, args []string) (p persistence.Persister,
d, err = driver.New(
cmd.Context(),
servicelocatorx.NewOptions(),
[]driver.OptionsModifier{
append([]driver.OptionsModifier{
driver.WithOptions(
configx.WithFlags(cmd.Flags()),
configx.SkipValidation(),
Expand All @@ -301,7 +309,7 @@ func makePersister(cmd *cobra.Command, args []string) (p persistence.Persister,
driver.DisableValidation(),
driver.DisablePreloading(),
driver.SkipNetworkInit(),
})
}, h.dOpts...))
if err != nil {
return nil, err
}
Expand All @@ -310,7 +318,7 @@ func makePersister(cmd *cobra.Command, args []string) (p persistence.Persister,
}

func (h *MigrateHandler) MigrateSQL(cmd *cobra.Command, args []string) (err error) {
p, err := makePersister(cmd, args)
p, err := h.makePersister(cmd, args)
if err != nil {
return err
}
Expand Down Expand Up @@ -360,7 +368,7 @@ func (h *MigrateHandler) MigrateSQL(cmd *cobra.Command, args []string) (err erro
}

func (h *MigrateHandler) MigrateStatus(cmd *cobra.Command, args []string) error {
p, err := makePersister(cmd, args)
p, err := h.makePersister(cmd, args)
if err != nil {
return err
}
Expand Down
2 changes: 2 additions & 0 deletions cmd/migrate_status.go
Expand Up @@ -4,6 +4,7 @@
package cmd

import (
"github.com/ory/x/cmdx"
"github.com/ory/x/configx"
"github.com/ory/x/servicelocatorx"

Expand All @@ -20,6 +21,7 @@ func NewMigrateStatusCmd(slOpts []servicelocatorx.Option, dOpts []driver.Options
RunE: cli.NewHandler(slOpts, dOpts, cOpts).Migration.MigrateStatus,
}

cmdx.RegisterFormatFlags(cmd.PersistentFlags())
cmd.Flags().BoolP("read-from-env", "e", false, "If set, reads the database connection string from the environment variable DSN or config file key dsn.")
cmd.Flags().Bool("block", false, "Block until all migrations have been applied")

Expand Down
17 changes: 13 additions & 4 deletions driver/factory.go
Expand Up @@ -5,6 +5,7 @@ package driver

import (
"context"
"io/fs"

"github.com/ory/hydra/v2/driver/config"
"github.com/ory/x/configx"
Expand All @@ -22,7 +23,10 @@ type (
// The first default refers to determining the NID at startup; the second default referes to the fact that the Contextualizer may dynamically change the NID.
skipNetworkInit bool
tracerWrapper TracerWrapper
extraMigrations []fs.FS
}
OptionsModifier func(*options)

TracerWrapper func(*otelx.Tracer) *otelx.Tracer
)

Expand All @@ -34,14 +38,12 @@ func newOptions() *options {
}
}

func WithConfig(config *config.DefaultProvider) func(o *options) {
func WithConfig(config *config.DefaultProvider) OptionsModifier {
return func(o *options) {
o.config = config
}
}

type OptionsModifier func(*options)

func WithOptions(opts ...configx.OptionModifier) OptionsModifier {
return func(o *options) {
o.opts = append(o.opts, opts...)
Expand Down Expand Up @@ -77,6 +79,13 @@ func WithTracerWrapper(wrapper TracerWrapper) OptionsModifier {
}
}

// WithExtraMigrations specifies additional database migration.
func WithExtraMigrations(m ...fs.FS) OptionsModifier {
return func(o *options) {
o.extraMigrations = append(o.extraMigrations, m...)
}
}

func New(ctx context.Context, sl *servicelocatorx.Options, opts []OptionsModifier) (Registry, error) {
o := newOptions()
for _, f := range opts {
Expand Down Expand Up @@ -115,7 +124,7 @@ func New(ctx context.Context, sl *servicelocatorx.Options, opts []OptionsModifie
r.WithTracerWrapper(o.tracerWrapper)
}

if err = r.Init(ctx, o.skipNetworkInit, false, ctxter); err != nil {
if err = r.Init(ctx, o.skipNetworkInit, false, ctxter, o.extraMigrations); err != nil {
l.WithError(err).Error("Unable to initialize service registry.")
return nil, err
}
Expand Down
5 changes: 3 additions & 2 deletions driver/registry.go
Expand Up @@ -5,6 +5,7 @@ package driver

import (
"context"
"io/fs"
"net/http"

"go.opentelemetry.io/otel/trace"
Expand Down Expand Up @@ -44,7 +45,7 @@ import (
type Registry interface {
dbal.Driver

Init(ctx context.Context, skipNetworkInit bool, migrate bool, ctxer contextx.Contextualizer) error
Init(ctx context.Context, skipNetworkInit bool, migrate bool, ctxer contextx.Contextualizer, extraMigrations []fs.FS) error

WithBuildInfo(v, h, d string) Registry
WithConfig(c *config.DefaultProvider) Registry
Expand Down Expand Up @@ -89,7 +90,7 @@ func NewRegistryFromDSN(ctx context.Context, c *config.DefaultProvider, l *logru
if err != nil {
return nil, err
}
if err := registry.Init(ctx, skipNetworkInit, migrate, ctxer); err != nil {
if err := registry.Init(ctx, skipNetworkInit, migrate, ctxer, nil); err != nil {
return nil, err
}
return registry, nil
Expand Down
2 changes: 1 addition & 1 deletion driver/registry_base_test.go
Expand Up @@ -67,7 +67,7 @@ func TestRegistryBase_newKeyStrategy_handlesNetworkError(t *testing.T) {
r := registry.(*RegistrySQL)
r.initialPing = failedPing(errors.New("snizzles"))

_ = r.Init(context.Background(), true, false, &contextx.TestContextualizer{})
_ = r.Init(context.Background(), true, false, &contextx.TestContextualizer{}, nil)

registryBase := RegistryBase{r: r, l: l}
registryBase.WithConfig(c)
Expand Down
9 changes: 7 additions & 2 deletions driver/registry_sql.go
Expand Up @@ -5,6 +5,7 @@ package driver

import (
"context"
"io/fs"
"strings"
"time"

Expand Down Expand Up @@ -64,7 +65,11 @@ func NewRegistrySQL() *RegistrySQL {
}

func (m *RegistrySQL) Init(
ctx context.Context, skipNetworkInit bool, migrate bool, ctxer contextx.Contextualizer,
ctx context.Context,
skipNetworkInit bool,
migrate bool,
ctxer contextx.Contextualizer,
extraMigrations []fs.FS,
) error {
if m.persister == nil {
m.WithContextualizer(ctxer)
Expand Down Expand Up @@ -100,7 +105,7 @@ func (m *RegistrySQL) Init(
return errorsx.WithStack(err)
}

p, err := sql.NewPersister(ctx, c, m, m.Config(), m.l)
p, err := sql.NewPersister(ctx, c, m, m.Config(), extraMigrations)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion driver/registry_sql_test.go
Expand Up @@ -31,7 +31,7 @@ func TestDefaultKeyManager_HsmDisabled(t *testing.T) {
reg, err := NewRegistryWithoutInit(c, l)
r := reg.(*RegistrySQL)
r.initialPing = sussessfulPing()
if err := r.Init(context.Background(), true, false, &contextx.Default{}); err != nil {
if err := r.Init(context.Background(), true, false, &contextx.Default{}, nil); err != nil {
t.Fatalf("unable to init registry: %s", err)
}
assert.NoError(t, err)
Expand Down
4 changes: 2 additions & 2 deletions go.mod
Expand Up @@ -40,10 +40,10 @@ require (
github.com/ory/fosite v0.44.1-0.20230704083823-8098e48b2e09
github.com/ory/go-acc v0.2.9-0.20230103102148-6b1c9a70dbbe
github.com/ory/graceful v0.1.3
github.com/ory/herodot v0.10.2
github.com/ory/herodot v0.10.3-0.20230626083119-d7e5192f0d88
github.com/ory/hydra-client-go/v2 v2.1.1
github.com/ory/jsonschema/v3 v3.0.8
github.com/ory/x v0.0.567
github.com/ory/x v0.0.574
github.com/pborman/uuid v1.2.1
github.com/pkg/errors v0.9.1
github.com/prometheus/client_golang v1.13.0
Expand Down
9 changes: 5 additions & 4 deletions go.sum
Expand Up @@ -549,6 +549,7 @@ github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/laher/mergefs v0.1.1 h1:nV2bTS57vrmbMxeR6uvJpI8LyGl3QHj4bLBZO3aUV58=
github.com/leodido/go-urn v1.2.0 h1:hpXL4XnriNwQ/ABnpepYM/1vCLWNDfUNts8dX3xTG6Y=
github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII=
github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
Expand Down Expand Up @@ -653,12 +654,12 @@ github.com/ory/go-convenience v0.1.0 h1:zouLKfF2GoSGnJwGq+PE/nJAE6dj2Zj5QlTgmMTs
github.com/ory/go-convenience v0.1.0/go.mod h1:uEY/a60PL5c12nYz4V5cHY03IBmwIAEm8TWB0yn9KNs=
github.com/ory/graceful v0.1.3 h1:FaeXcHZh168WzS+bqruqWEw/HgXWLdNv2nJ+fbhxbhc=
github.com/ory/graceful v0.1.3/go.mod h1:4zFz687IAF7oNHHiB586U4iL+/4aV09o/PYLE34t2bA=
github.com/ory/herodot v0.10.2 h1:gGvNMHgAwWzdP/eo+roSiT5CGssygHSjDU7MSQNlJ4E=
github.com/ory/herodot v0.10.2/go.mod h1:MMNmY6MG1uB6fnXYFaHoqdV23DTWctlPsmRCeq/2+wc=
github.com/ory/herodot v0.10.3-0.20230626083119-d7e5192f0d88 h1:J0CIFKdpUeqKbVMw7pQ1qLtUnflRM1JWAcOEq7Hp4yg=
github.com/ory/herodot v0.10.3-0.20230626083119-d7e5192f0d88/go.mod h1:MMNmY6MG1uB6fnXYFaHoqdV23DTWctlPsmRCeq/2+wc=
github.com/ory/jsonschema/v3 v3.0.8 h1:Ssdb3eJ4lDZ/+XnGkvQS/te0p+EkolqwTsDOCxr/FmU=
github.com/ory/jsonschema/v3 v3.0.8/go.mod h1:ZPzqjDkwd3QTnb2Z6PAS+OTvBE2x5i6m25wCGx54W/0=
github.com/ory/x v0.0.567 h1:oUj75hIqBv3ESsmIwc/4u8jaD2zSx/HTGzRnfMKUykg=
github.com/ory/x v0.0.567/go.mod h1:g0QdN0Z47vdCYtfrTQkgWJdIOPuez9VGiuQivLxa4lo=
github.com/ory/x v0.0.574 h1:JjdOP6iIh4ngoR1zDxaZL9bsBzIAyvw0aZdqSfJOEVI=
github.com/ory/x v0.0.574/go.mod h1:aeJFTlvDLGYSABzPS3z5SeLcYC52Ek7uGZiuYGcTMSU=
github.com/pborman/uuid v1.2.1 h1:+ZZIw58t/ozdjRaXh/3awHfmWRbzYxJoAdNJxe/3pvw=
github.com/pborman/uuid v1.2.1/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k=
github.com/pelletier/go-toml v1.7.0/go.mod h1:vwGMzjaWMwyfHwgIBhI2YUM4fB6nL6lVAvS1LBMMhTE=
Expand Down
2 changes: 1 addition & 1 deletion hsm/manager_hsm_test.go
Expand Up @@ -52,7 +52,7 @@ func TestDefaultKeyManager_HSMEnabled(t *testing.T) {
reg.WithLogger(l)
reg.WithConfig(c)
reg.WithHsmContext(mockHsmContext)
err := reg.Init(context.Background(), false, true, &contextx.TestContextualizer{})
err := reg.Init(context.Background(), false, true, &contextx.TestContextualizer{}, nil)
assert.NoError(t, err)
assert.IsType(t, &jwk.ManagerStrategy{}, reg.KeyManager())
assert.IsType(t, &sql.Persister{}, reg.SoftwareKeyManager())
Expand Down
9 changes: 5 additions & 4 deletions persistence/sql/persister.go
Expand Up @@ -6,11 +6,11 @@ package sql
import (
"context"
"database/sql"
"io/fs"
"reflect"

"github.com/gobuffalo/pop/v6"
"github.com/gofrs/uuid"

"github.com/pkg/errors"

"github.com/ory/fosite"
Expand All @@ -21,6 +21,7 @@ import (
"github.com/ory/hydra/v2/x"
"github.com/ory/x/contextx"
"github.com/ory/x/errorsx"
"github.com/ory/x/fsx"
"github.com/ory/x/logrusx"
"github.com/ory/x/networkx"
"github.com/ory/x/otelx"
Expand Down Expand Up @@ -104,8 +105,8 @@ func (p *Persister) Rollback(ctx context.Context) (err error) {
return errorsx.WithStack(tx.TX.Rollback())
}

func NewPersister(ctx context.Context, c *pop.Connection, r Dependencies, config *config.DefaultProvider, l *logrusx.Logger) (*Persister, error) {
mb, err := popx.NewMigrationBox(migrations, popx.NewMigrator(c, r.Logger(), r.Tracer(ctx), 0))
func NewPersister(ctx context.Context, c *pop.Connection, r Dependencies, config *config.DefaultProvider, extraMigrations []fs.FS) (*Persister, error) {
mb, err := popx.NewMigrationBox(fsx.Merge(append([]fs.FS{migrations}, extraMigrations...)...), popx.NewMigrator(c, r.Logger(), r.Tracer(ctx), 0))
if err != nil {
return nil, errorsx.WithStack(err)
}
Expand All @@ -115,7 +116,7 @@ func NewPersister(ctx context.Context, c *pop.Connection, r Dependencies, config
mb: mb,
r: r,
config: config,
l: l,
l: r.Logger(),
p: networkx.NewManager(c, r.Logger(), r.Tracer(ctx)),
}, nil
}
Expand Down

0 comments on commit 8900cbb

Please sign in to comment.