-
Notifications
You must be signed in to change notification settings - Fork 111
/
migrate.go
149 lines (127 loc) · 3.99 KB
/
migrate.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
package postgres
import (
"context"
"embed"
"fmt"
"path"
"strconv"
"strings"
)
// migrationsFS holds the SQL migration files, embedded in the binary at build
// time. Only files matching migrations/*.sql are included. Migrate and
// MigrationStatus list them with ReadDir and expect every file name (minus the
// ".sql" suffix) to parse as an integer version.
//
//go:embed migrations/*.sql
var migrationsFS embed.FS
// Settings used by the migration code in this file.
var (
	// migrationLockNumber is the fixed key of the Postgres advisory lock taken by
	// Migrate so that concurrent migrations are serialized. The value itself is an
	// arbitrary (random) number.
	migrationLockNumber int64 = 8294273491672920

	// migrationVersionTable is the name of the table that records the version of
	// the most recently applied migration.
	migrationVersionTable = "runtime_migration_version"
)
// Migrate implements drivers.Connection.
// Migrate for Postgres is safe for concurrent invocations.
// Adapted from: https://github.com/jackc/tern
func (c *connection) Migrate(ctx context.Context) (err error) {
// Acquire advisory lock
_, err = c.db.ExecContext(ctx, "select pg_advisory_lock($1)", migrationLockNumber)
if err != nil {
return err
}
defer func() {
// Release advisory lock when this function returns
_, unlockErr := c.db.ExecContext(ctx, "select pg_advisory_unlock($1)", migrationLockNumber)
if err == nil && unlockErr != nil {
err = unlockErr
}
}()
// Check if migrationVersionTable exists
// (Not doing "create table if not exists" to prevent unnecessary privileges)
var exists int
err = c.db.QueryRowContext(ctx, "select count(*) from pg_catalog.pg_class where relname=$1 and relkind='r' and pg_table_is_visible(oid)", migrationVersionTable).Scan(&exists)
if err != nil {
return err
}
// Create migrationVersionTable if it doesn't exist
if exists == 0 {
_, err = c.db.ExecContext(ctx, fmt.Sprintf("create table if not exists %s(version int4 not null)", migrationVersionTable))
if err != nil {
return err
}
// Set the version to 0 if table is empty (note: defensive coding, table should always be empty)
_, err = c.db.ExecContext(ctx, fmt.Sprintf("insert into %s(version) select 0 where 0=(select count(*) from %s)", migrationVersionTable, migrationVersionTable))
if err != nil {
return err
}
}
// Get version of latest migration
var currentVersion int
err = c.db.QueryRowContext(ctx, fmt.Sprintf("select version from %s", migrationVersionTable)).Scan(¤tVersion)
if err != nil {
return err
}
// Iterate over migrations (sorted by filename)
files, err := migrationsFS.ReadDir("migrations")
if err != nil {
return err
}
for _, file := range files {
// Extract version number from filename
version, err := migrationFilenameToVersion(file.Name())
if err != nil {
return fmt.Errorf("unexpected migration filename: %s", file.Name())
}
// Skip migrations below current version
if version <= currentVersion {
continue
}
// Read SQL
sql, err := migrationsFS.ReadFile(path.Join("migrations", file.Name()))
if err != nil {
return err
}
// Start a transaction
tx, err := c.db.BeginTx(ctx, nil)
if err != nil {
return err
}
defer tx.Rollback()
// Run migration
_, err = tx.ExecContext(ctx, string(sql))
if err != nil {
return fmt.Errorf("failed to run migration '%s': %s", file.Name(), err.Error())
}
// Update migration version
_, err = tx.ExecContext(ctx, fmt.Sprintf("UPDATE %s SET version=$1", migrationVersionTable), version)
if err != nil {
return err
}
// Commit migration
err = tx.Commit()
if err != nil {
return err
}
}
return nil
}
// MigrationStatus implements drivers.Connection
func (c *connection) MigrationStatus(ctx context.Context) (current int, desired int, err error) {
// Get current version
err = c.db.QueryRowxContext(ctx, fmt.Sprintf("select version from %s", migrationVersionTable)).Scan(¤t)
if err != nil {
return 0, 0, err
}
// Set desired to version number of last migration file
files, err := migrationsFS.ReadDir("migrations")
if err != nil {
return 0, 0, err
}
if len(files) > 0 {
file := files[len(files)-1]
version, err := migrationFilenameToVersion(file.Name())
if err != nil {
return 0, 0, fmt.Errorf("unexpected migration filename: %s", file.Name())
}
desired = version
}
return current, desired, nil
}
// migrationFilenameToVersion parses a migration file name such as "0003.sql"
// into its integer version (3). Only a trailing ".sql" is stripped. Whatever
// remains is passed to strconv.Atoi, and Atoi's result and error are returned
// unchanged.
func migrationFilenameToVersion(name string) (int, error) {
	stem := strings.TrimSuffix(name, ".sql")
	version, err := strconv.Atoi(stem)
	return version, err
}