-
Notifications
You must be signed in to change notification settings - Fork 104
/
migrate.go
139 lines (119 loc) · 3.64 KB
/
migrate.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
package postgres
import (
"context"
"embed"
"fmt"
"path"
"strconv"
"strings"
)
// migrationsFS holds the SQL migration scripts from the migrations directory, embedded
// in the binary at build time. Each file is expected to be named <version>.sql, where
// <version> is an integer that strconv.Atoi can parse.
//
//go:embed migrations/*.sql
var migrationsFS embed.FS
// Migration bookkeeping settings.
var (
	// migrationLockNumber is the fixed key of the Postgres advisory lock taken while
	// migrations run, so concurrent migration runs don't interleave. The value itself is
	// an arbitrary random number.
	migrationLockNumber int64 = 5103805673824918

	// migrationVersionTable is the name of the table that records the version of the
	// latest applied migration.
	migrationVersionTable = "admin_migration_version"
)
// FindMigrationVersion returns the version of the latest applied migration, or 0 if
// migrations have never been run (the version table doesn't exist yet or has no row).
func (c *connection) FindMigrationVersion(ctx context.Context) (int, error) {
	// Look the table up in the catalog (with the same search_path visibility rules as the
	// unqualified name used below) instead of matching the error text of a failed select:
	// Postgres error messages are localized via lc_messages, and a substring match on
	// "does not exist" would also turn unrelated failures, such as a missing column, into
	// version 0.
	var exists int
	err := c.db.QueryRowxContext(ctx, "select count(*) from pg_catalog.pg_class where relname=$1 and relkind='r' and pg_table_is_visible(oid)", migrationVersionTable).Scan(&exists)
	if err != nil {
		return 0, err
	}
	if exists == 0 {
		return 0, nil
	}

	// coalesce(max(...)) always yields exactly one row, so a table that exists but has not
	// been seeded yet reads as version 0 instead of failing with a "no rows" error.
	var version int
	err = c.db.QueryRowxContext(ctx, fmt.Sprintf("select coalesce(max(version), 0) from %s", migrationVersionTable)).Scan(&version)
	if err != nil {
		return 0, err
	}
	return version, nil
}
// Migrate runs migrations. It's safe for concurrent invocations.
// Adapted from: https://github.com/jackc/tern
func (c *connection) Migrate(ctx context.Context) (err error) {
// Acquire advisory lock
_, err = c.db.ExecContext(ctx, "select pg_advisory_lock($1)", migrationLockNumber)
if err != nil {
return err
}
defer func() {
// Release advisory lock when this function returns
_, unlockErr := c.db.ExecContext(ctx, "select pg_advisory_unlock($1)", migrationLockNumber)
if err == nil && unlockErr != nil {
err = unlockErr
}
}()
// Check if migrationVersionTable exists
var exists int
err = c.db.QueryRowContext(ctx, "select count(*) from pg_catalog.pg_class where relname=$1 and relkind='r' and pg_table_is_visible(oid)", migrationVersionTable).Scan(&exists)
if err != nil {
return err
}
// Create migrationVersionTable if it doesn't exist
if exists == 0 {
_, err = c.db.ExecContext(ctx, fmt.Sprintf("create table if not exists %s(version int4 not null)", migrationVersionTable))
if err != nil {
return err
}
// Set the version to 0 if table is empty (note: defensive coding, table should always be empty)
_, err = c.db.ExecContext(ctx, fmt.Sprintf("insert into %s(version) select 0 where 0=(select count(*) from %s)", migrationVersionTable, migrationVersionTable))
if err != nil {
return err
}
}
// Get version of latest migration
var currentVersion int
err = c.db.QueryRowContext(ctx, fmt.Sprintf("select version from %s", migrationVersionTable)).Scan(¤tVersion)
if err != nil {
return err
}
// Iterate over migrations (sorted by filename)
files, err := migrationsFS.ReadDir("migrations")
if err != nil {
return err
}
for _, file := range files {
// Extract version number from filename
version, err := strconv.Atoi(strings.TrimSuffix(file.Name(), ".sql"))
if err != nil {
return fmt.Errorf("unexpected migration filename: %s", file.Name())
}
// Skip migrations below current version
if version <= currentVersion {
continue
}
// Read SQL
sql, err := migrationsFS.ReadFile(path.Join("migrations", file.Name()))
if err != nil {
return err
}
err = c.migrateSingle(ctx, file.Name(), sql, version)
if err != nil {
return err
}
}
return nil
}
func (c *connection) migrateSingle(ctx context.Context, name string, sql []byte, version int) (err error) {
// Start a transaction
tx, err := c.db.BeginTx(ctx, nil)
if err != nil {
return err
}
defer func() { _ = tx.Rollback() }()
// Run migration
_, err = tx.ExecContext(ctx, string(sql))
if err != nil {
return fmt.Errorf("failed to run migration '%s': %w", name, err)
}
// Update migration version
_, err = tx.ExecContext(ctx, fmt.Sprintf("UPDATE %s SET version=$1", migrationVersionTable), version)
if err != nil {
return err
}
// Commit migration
err = tx.Commit()
if err != nil {
return err
}
return nil
}