Skip to content

Commit

Permalink
fix: collect go migrations (#588)
Browse files Browse the repository at this point in the history
  • Loading branch information
mfridman committed Aug 27, 2023
1 parent 958c950 commit 7011525
Show file tree
Hide file tree
Showing 4 changed files with 320 additions and 58 deletions.
2 changes: 1 addition & 1 deletion create.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func CreateWithTemplate(db *sql.DB, dir string, tmpl *template.Template, name, m

if sequential {
// always use DirFS here because it's modifying operation
migrations, err := collectMigrationsFS(osFS{}, dir, minVersion, maxVersion)
migrations, err := collectMigrationsFS(osFS{}, dir, minVersion, maxVersion, registeredGoMigrations)
if err != nil && !errors.Is(err, ErrNoMigrationFiles) {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion fix.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ const seqVersionTemplate = "%05v"

func Fix(dir string) error {
// always use osFS here because it's modifying operation
migrations, err := collectMigrationsFS(osFS{}, dir, minVersion, maxVersion)
migrations, err := collectMigrationsFS(osFS{}, dir, minVersion, maxVersion, registeredGoMigrations)
if err != nil {
return err
}
Expand Down
167 changes: 111 additions & 56 deletions migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,17 +236,19 @@ func register(
return nil
}

func collectMigrationsFS(fsys fs.FS, dirpath string, current, target int64) (Migrations, error) {
func collectMigrationsFS(
fsys fs.FS,
dirpath string,
current, target int64,
registered map[int64]*Migration,
) (Migrations, error) {
if _, err := fs.Stat(fsys, dirpath); err != nil {
if errors.Is(err, fs.ErrNotExist) {
return nil, fmt.Errorf("%s directory does not exist", dirpath)
}

return nil, err
}

var migrations Migrations

// SQL migration files.
sqlMigrationFiles, err := fs.Glob(fsys, path.Join(dirpath, "*.sql"))
if err != nil {
Expand All @@ -258,68 +260,30 @@ func collectMigrationsFS(fsys fs.FS, dirpath string, current, target int64) (Mig
return nil, fmt.Errorf("could not parse SQL migration file %q: %w", file, err)
}
if versionFilter(v, current, target) {
migration := &Migration{Version: v, Next: -1, Previous: -1, Source: file}
migrations = append(migrations, migration)
migrations = append(migrations, &Migration{
Version: v,
Next: -1,
Previous: -1,
Source: file,
})
}
}

// Go migration files
fsGoMigrations := map[int64]*Migration{}
goMigrationFiles, err := fs.Glob(fsys, path.Join(dirpath, "*.go"))
// Go migration files.
goMigrations, err := collectGoMigrations(fsys, dirpath, registered, current, target)
if err != nil {
return nil, err
}
for _, file := range goMigrationFiles {
v, err := NumericComponent(file)
if err != nil {
continue // Skip any files that don't have version prefix.
}

if strings.HasSuffix(file, "_test.go") {
continue // Skip Go test files.
}

if versionFilter(v, current, target) {
migration := &Migration{Version: v, Next: -1, Previous: -1, Source: file, Registered: false}
fsGoMigrations[v] = migration
}
}

// Go migrations registered via goose.AddMigration().
for _, migration := range registeredGoMigrations {
v, err := NumericComponent(migration.Source)
if err != nil {
return nil, fmt.Errorf("could not parse go migration file %q: %w", migration.Source, err)
}
if !versionFilter(v, current, target) {
continue
}
if _, ok := fsGoMigrations[v]; ok {
migrations = append(migrations, migration)
}
}

for _, fsMigration := range fsGoMigrations {
// Skip migrations already existing migrations registered via goose.AddMigration().
if _, ok := registeredGoMigrations[fsMigration.Version]; ok {
continue
}
migrations = append(migrations, fsMigration)
}

migrations = append(migrations, goMigrations...)
if len(migrations) == 0 {
return nil, ErrNoMigrationFiles
}

migrations = sortAndConnectMigrations(migrations)

return migrations, nil
return sortAndConnectMigrations(migrations), nil
}

// CollectMigrations returns all the valid looking migration scripts in the
// migrations folder and go func registry, and key them by version.
func CollectMigrations(dirpath string, current, target int64) (Migrations, error) {
return collectMigrationsFS(baseFS, dirpath, current, target)
return collectMigrationsFS(baseFS, dirpath, current, target, registeredGoMigrations)
}

func sortAndConnectMigrations(migrations Migrations) Migrations {
Expand All @@ -340,15 +304,12 @@ func sortAndConnectMigrations(migrations Migrations) Migrations {
}

func versionFilter(v, current, target int64) bool {

if target > current {
return v > current && v <= target
}

if target < current {
return v <= current && v > target
}

return false
}

Expand Down Expand Up @@ -451,3 +412,97 @@ func withoutContext[T any](fn func(context.Context, T) error) func(T) error {
return fn(context.Background(), t)
}
}

// collectGoMigrations collects Go migrations from the filesystem and merges them with registered
// migrations.
//
// If Go migrations have been registered globally, with [goose.AddNamedMigration...], but there are
// no corresponding .go files in the filesystem, add them to the migrations slice.
//
// If Go migrations have been registered, and there are .go files in the filesystem dirpath, ONLY
// include those in the migrations slices.
//
// Lastly, if there are .go files in the filesystem but they have not been registered, raise an
// error. This is to prevent users from accidentally adding valid looking Go files to the migrations
// folder without registering them.
func collectGoMigrations(
fsys fs.FS,
dirpath string,
registeredGoMigrations map[int64]*Migration,
current, target int64,
) (Migrations, error) {
// Sanity check registered migrations have the correct version prefix.
for _, m := range registeredGoMigrations {
if _, err := NumericComponent(m.Source); err != nil {
return nil, fmt.Errorf("could not parse go migration file %s: %w", m.Source, err)
}
}
goFiles, err := fs.Glob(fsys, path.Join(dirpath, "*.go"))
if err != nil {
return nil, err
}
// If there are no Go files in the filesystem and no registered Go migrations, return early.
if len(goFiles) == 0 && len(registeredGoMigrations) == 0 {
return nil, nil
}
type source struct {
fullpath string
version int64
}
// Find all Go files that have a version prefix and are within the requested range.
var sources []source
for _, fullpath := range goFiles {
v, err := NumericComponent(fullpath)
if err != nil {
continue // Skip any files that don't have version prefix.
}
if strings.HasSuffix(fullpath, "_test.go") {
continue // Skip Go test files.
}
if versionFilter(v, current, target) {
sources = append(sources, source{
fullpath: fullpath,
version: v,
})
}
}
var (
migrations Migrations
)
if len(sources) > 0 {
for _, s := range sources {
migration, ok := registeredGoMigrations[s.version]
if ok {
migrations = append(migrations, migration)
} else {
// TODO(mf): something that bothers me about this implementation is it will be
// lazily evaluated and the error will only be raised if the user tries to run the
// migration. It would be better to raise an error much earlier in the process.
migrations = append(migrations, &Migration{
Version: s.version,
Next: -1,
Previous: -1,
Source: s.fullpath,
Registered: false,
})
}
}
} else {
// Some users may register Go migrations manually via AddNamedMigration_ functions but not
// provide the corresponding .go files in the filesystem. In this case, we include them
// wholesale in the migrations slice.
//
// This is a valid use case because users may want to build a custom binary that only embeds
// the SQL migration files and some other mechanism for registering Go migrations.
for _, migration := range registeredGoMigrations {
v, err := NumericComponent(migration.Source)
if err != nil {
return nil, fmt.Errorf("could not parse go migration file %s: %w", migration.Source, err)
}
if versionFilter(v, current, target) {
migrations = append(migrations, migration)
}
}
}
return migrations, nil
}

0 comments on commit 7011525

Please sign in to comment.