Skip to content
This repository has been archived by the owner on Apr 2, 2024. It is now read-only.

Commit

Permalink
Add tests for migrations
Browse files Browse the repository at this point in the history
  • Loading branch information
cevian committed Aug 7, 2020
1 parent 98b9f0d commit b13cf55
Show file tree
Hide file tree
Showing 14 changed files with 471 additions and 72 deletions.
105 changes: 105 additions & 0 deletions pkg/pgmodel/end_to_end_tests/migrate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@ package end_to_end_tests

import (
"context"
"reflect"
"testing"

"github.com/blang/semver/v4"
"github.com/jackc/pgx/v4/pgxpool"
"github.com/timescale/timescale-prometheus/pkg/internal/testhelpers"
"github.com/timescale/timescale-prometheus/pkg/pgmodel"
"github.com/timescale/timescale-prometheus/pkg/pgmodel/test_migrations"
)

const (
Expand Down Expand Up @@ -51,3 +54,105 @@ func TestMigrateTwice(t *testing.T) {
}
})
}

func verifyLogs(t testing.TB, db *pgxpool.Pool, expected []string) {
rows, err := db.Query(context.Background(), "SELECT msg FROM log ORDER BY id")
if err != nil {
t.Fatal(err)
}

found := make([]string, 0)
for rows.Next() {
var value string
err = rows.Scan(&value)
if err != nil {
t.Fatal(err)
}
found = append(found, value)
}
if !reflect.DeepEqual(expected, found) {
t.Errorf("wrong values in DB\nexpected:\n\t%v\ngot:\n\t%v", expected, found)
}
}

func TestMigrationLib(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test")
}
testhelpers.WithDB(t, *testDatabase, testhelpers.NoSuperuser, func(db *pgxpool.Pool, t testing.TB, connectURL string) {
testTOC := map[string][]string{
"idempotent": {
"2-toc-run_first.sql",
"1-toc-run_second.sql",
},
"versions/0.1.0": {
"1-migration.sql",
},
"versions/0.2.0": {
"1-migration.sql",
},
"versions/0.10.0": {
"2-toc_migration.sql",
"1-toc_migration.sql",
},
}

expected := []string{
"setup",
"idempotent 1",
"idempotent 2",
"migration 0.2.0",
"idempotent 1",
"idempotent 2",
"idempotent 1",
"idempotent 2",
"migration 0.10.0=1",
"migration 0.10.0=2",
"idempotent 1",
"idempotent 2",
}

mig := pgmodel.NewMigrator(db, test_migrations.MigrationFiles, testTOC)

err := mig.Migrate(semver.MustParse("0.1.1"))
if err != nil {
t.Fatal(err)
}

verifyLogs(t, db, expected[0:3])

//does nothing
err = mig.Migrate(semver.MustParse("0.1.1"))
if err != nil {
t.Fatal(err)
}

verifyLogs(t, db, expected[0:3])

err = mig.Migrate(semver.MustParse("0.2.0"))
if err != nil {
t.Fatal(err)
}
verifyLogs(t, db, expected[0:6])

//does nothing
err = mig.Migrate(semver.MustParse("0.2.0"))
if err != nil {
t.Fatal(err)
}
verifyLogs(t, db, expected[0:6])

//even if no version upgrades, idempotent files apply
err = mig.Migrate(semver.MustParse("0.9.0"))
if err != nil {
t.Fatal(err)
}
verifyLogs(t, db, expected[0:8])

err = mig.Migrate(semver.MustParse("0.10.0"))
if err != nil {
t.Fatal(err)
}
verifyLogs(t, db, expected[0:12])
})
}
77 changes: 44 additions & 33 deletions pkg/pgmodel/migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,12 @@ const (
preinstallScripts = "preinstall"
versionScripts = "versions"
idempotentScripts = "idempotent"

// TODO: once we have a schema upgrade version for migration scripts,
// add tests and remove this flag
upgradeUntested = true
)

var (
ExtensionIsInstalled = false

toc = map[string][]string{
tableOfContets = map[string][]string{
"idempotent": {
"base.sql",
"matcher-functions.sql",
Expand Down Expand Up @@ -106,20 +102,43 @@ func Migrate(db *pgxpool.Pool, versionInfo VersionInfo) (err error) {
if err != nil {
return fmt.Errorf("app version is not semver format, aborting migration")
}
if err := ensureVersionTable(db); err != nil {

mig := NewMigrator(db, migrations.MigrationFiles, tableOfContets)

err = mig.Migrate(appVersion)
if err != nil {
return fmt.Errorf("Error encountered during migration: %w", err)
}

installExtension(conn)

metadataUpdate(db, ExtensionIsInstalled, "version", versionInfo.Version)
metadataUpdate(db, ExtensionIsInstalled, "commit_hash", versionInfo.CommitHash)
return nil
}

type Migrator struct {
db *pgxpool.Pool
sqlFiles http.FileSystem
toc map[string][]string
}

func NewMigrator(db *pgxpool.Pool, sqlFiles http.FileSystem, toc map[string][]string) *Migrator {
return &Migrator{db: db, sqlFiles: sqlFiles, toc: toc}
}

func (t *Migrator) Migrate(appVersion semver.Version) error {
if err := ensureVersionTable(t.db); err != nil {
return fmt.Errorf("error ensuring version table: %w", err)
}
dbVersion, err := getDBVersion(db)

dbVersion, err := getDBVersion(t.db)
if err != nil {
return fmt.Errorf("failed to get the version from database: %w", err)
}

// If already at correct version, nothing to migrate.
if dbVersion.Compare(appVersion) == 0 {
installExtension(conn)

metadataUpdate(db, ExtensionIsInstalled, "version", versionInfo.Version)
metadataUpdate(db, ExtensionIsInstalled, "commit_hash", versionInfo.CommitHash)
return nil
}

Expand All @@ -128,7 +147,7 @@ func Migrate(db *pgxpool.Pool, versionInfo VersionInfo) (err error) {
return fmt.Errorf("schema version is above the application version, cannot migrate")
}

tx, err := db.Begin(context.Background())
tx, err := t.db.Begin(context.Background())
if err != nil {
return fmt.Errorf("unable to start transaction: %w", err)
}
Expand All @@ -143,13 +162,13 @@ func Migrate(db *pgxpool.Pool, versionInfo VersionInfo) (err error) {

// No version in DB.
if dbVersion.Compare(semver.Version{}) == 0 {
if err = execMigrationFiles(tx, preinstallScripts); err != nil {
if err = t.execMigrationFiles(tx, preinstallScripts); err != nil {
return err
}
} else if err = upgradeVersion(tx, dbVersion, appVersion); err != nil {
} else if err = t.upgradeVersion(tx, dbVersion, appVersion); err != nil {
return err
}
if err = execMigrationFiles(tx, idempotentScripts); err != nil {
if err = t.execMigrationFiles(tx, idempotentScripts); err != nil {
return err
}
if err = setDBVersion(tx, &appVersion); err != nil {
Expand All @@ -160,11 +179,6 @@ func Migrate(db *pgxpool.Pool, versionInfo VersionInfo) (err error) {
return fmt.Errorf("unable to commit migration transaction: %w", err)
}

installExtension(conn)

metadataUpdate(db, ExtensionIsInstalled, "version", versionInfo.Version)
metadataUpdate(db, ExtensionIsInstalled, "commit_hash", versionInfo.CommitHash)

return nil
}

Expand All @@ -182,24 +196,24 @@ func getDBVersion(db *pgxpool.Pool) (semver.Version, error) {
res, err := db.Query(context.Background(), getVersion)

if err != nil {
return version, err
return version, fmt.Errorf("Error getting DB version: %w", err)
}

for res.Next() {
err = res.Scan(&version)
}

if err != nil {
return version, err
return version, fmt.Errorf("Error getting DB version: %w", err)
}

return version, nil
}

// execMigrationFiles finds all the migration files in a directory, orders them
// (either by ToC or by their numerical prefix) and executes them in a transaction.
func execMigrationFiles(tx pgx.Tx, dirName string) error {
f, err := migrations.MigrationFiles.Open(dirName)
func (t *Migrator) execMigrationFiles(tx pgx.Tx, dirName string) error {
f, err := t.sqlFiles.Open(dirName)
if err != nil {
return fmt.Errorf("unable to get migration scripts: name %s, err %w", dirName, err)
}
Expand All @@ -210,12 +224,12 @@ func execMigrationFiles(tx pgx.Tx, dirName string) error {
file http.File
)

if myToC, ok := toc[dirName]; ok {
if myToC, ok := t.toc[dirName]; ok {
// If exists, use ToC to order the migration files before executing them.
entries = make([]string, 0, len(myToC))
for _, fileName := range myToC {
fullName := filepath.Join(dirName, fileName)
file, err = migrations.MigrationFiles.Open(fullName)
file, err = t.sqlFiles.Open(fullName)
if err != nil {
return fmt.Errorf("unable to get migration script from toc: name %s, err %w", fullName, err)
}
Expand Down Expand Up @@ -244,7 +258,7 @@ func execMigrationFiles(tx pgx.Tx, dirName string) error {

for _, e := range entries {
fileName := filepath.Join(dirName, e)
f, err := migrations.MigrationFiles.Open(fileName)
f, err := t.sqlFiles.Open(fileName)
if err != nil {
return fmt.Errorf("unable to get migration script: name %s, err %w", fileName, err)
}
Expand Down Expand Up @@ -314,11 +328,8 @@ func replaceSchemaNames(r io.ReadCloser) (string, error) {

// upgradeVersion finds all the versions between `from` and `to`, sorts them
// using semantic version ordering and applies them sequentially in the supplied transaction.
func upgradeVersion(tx pgx.Tx, from, to semver.Version) error {
if upgradeUntested {
return fmt.Errorf("unsupported version upgrade reached, aborting migration")
}
f, err := migrations.MigrationFiles.Open(versionScripts)
func (t *Migrator) upgradeVersion(tx pgx.Tx, from, to semver.Version) error {
f, err := t.sqlFiles.Open(versionScripts)
if err != nil {
return fmt.Errorf("unable to open migration scripts: %w", err)
}
Expand Down Expand Up @@ -352,7 +363,7 @@ func upgradeVersion(tx pgx.Tx, from, to semver.Version) error {
}

if from.Compare(version) < 0 && to.Compare(version) >= 0 {
if err = execMigrationFiles(tx, filepath.Join("versions", e.Name())); err != nil {
if err = t.execMigrationFiles(tx, filepath.Join("versions", e.Name())); err != nil {
return err
}
}
Expand Down
41 changes: 2 additions & 39 deletions pkg/pgmodel/migrations/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,49 +11,12 @@ package main
import (
"log"
"net/http"
"os"
"time"

"github.com/shurcooL/vfsgen"
"github.com/timescale/timescale-prometheus/pkg/pgmodel/migrations"
)

// modTimeFS is an http.FileSystem wrapper that modifies
// underlying fs such that all of its file mod times are set to zero.
type modTimeFS struct {
fs http.FileSystem
}

func (fs modTimeFS) Open(name string) (http.File, error) {
f, err := fs.fs.Open(name)
if err != nil {
return nil, err
}
return modTimeFile{f}, nil
}

type modTimeFile struct {
http.File
}

func (f modTimeFile) Stat() (os.FileInfo, error) {
fi, err := f.File.Stat()
if err != nil {
return nil, err
}
return modTimeFileInfo{fi}, nil
}

type modTimeFileInfo struct {
os.FileInfo
}

func (modTimeFileInfo) ModTime() time.Time {
return time.Time{}
}

var Assets http.FileSystem = modTimeFS{
fs: http.Dir("sql"),
}
var Assets http.FileSystem = migrations.NewModTimeFs(http.Dir("sql"))

func main() {
err := vfsgen.Generate(Assets, vfsgen.Options{
Expand Down

0 comments on commit b13cf55

Please sign in to comment.