/
migrate.go
156 lines (131 loc) · 3.88 KB
/
migrate.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
package duckdb
import (
"context"
"embed"
"fmt"
"path"
"strconv"
"strings"
"github.com/jmoiron/sqlx"
)
// migrationsFS holds the SQL files matched by the migrations/*.sql pattern,
// embedded in the binary via the go:embed directive below.
//
//go:embed migrations/*.sql
var migrationsFS embed.FS
// migrationVersionTable is the name of the table that tracks migrations.
// It is interpolated directly into SQL statements with fmt.Sprintf.
var migrationVersionTable = "rill.migration_version"
// Migrate implements drivers.Connection.
// Migrate for DuckDB may not be safe for concurrent use.
func (c *connection) Migrate(ctx context.Context) (err error) {
conn, release, err := c.acquireMetaConn(ctx)
if err != nil {
return err
}
defer func() { _ = release() }()
// Create rill schema if it doesn't exist
_, err = conn.ExecContext(ctx, "create schema if not exists rill")
if err != nil {
return c.checkErr(err)
}
// Create migrationVersionTable if it doesn't exist
_, err = conn.ExecContext(ctx, fmt.Sprintf("create table if not exists %s(version integer not null)", migrationVersionTable))
if err != nil {
return c.checkErr(err)
}
// Set the version to 0 if table is empty
_, err = conn.ExecContext(ctx, fmt.Sprintf("insert into %s(version) select 0 where 0=(select count(*) from %s)", migrationVersionTable, migrationVersionTable))
if err != nil {
return c.checkErr(err)
}
// Get version of latest migration
var currentVersion int
err = conn.QueryRowContext(ctx, fmt.Sprintf("select version from %s", migrationVersionTable)).Scan(¤tVersion)
if err != nil {
return c.checkErr(err)
}
// Iterate over migrations (sorted by filename)
files, err := migrationsFS.ReadDir("migrations")
if err != nil {
return err
}
for _, file := range files {
// Extract version number from filename
version, err := migrationFilenameToVersion(file.Name())
if err != nil {
return fmt.Errorf("unexpected migration filename: %s", file.Name())
}
// Skip migrations below current version
if version <= currentVersion {
continue
}
// Read SQL
sql, err := migrationsFS.ReadFile(path.Join("migrations", file.Name()))
if err != nil {
return err
}
err = c.migrateSingle(ctx, conn, file.Name(), sql, version)
if err != nil {
return err
}
}
return nil
}
// migrateSingle applies one migration on conn: it executes the migration SQL
// and updates the stored version inside a single transaction, commits, and
// then issues a CHECKPOINT on the connection.
//
// name is used only in the error message when the migration SQL fails; sql is
// the raw migration script; version is the value written to the version table.
func (c *connection) migrateSingle(ctx context.Context, conn *sqlx.Conn, name string, sql []byte, version int) (err error) {
	// Start a transaction
	txn, err := conn.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	// Rollback is attempted on every exit path, including after a successful
	// Commit; its error is handed to checkErr and the result is discarded.
	defer func() {
		rollbackErr := txn.Rollback()
		_ = c.checkErr(rollbackErr)
	}()

	// Run migration
	if _, err = txn.ExecContext(ctx, string(sql)); err != nil {
		return fmt.Errorf("failed to run migration '%s': %w", name, c.checkErr(err))
	}

	// Update migration version
	updateStmt := fmt.Sprintf("UPDATE %s SET version=?", migrationVersionTable)
	if _, err = txn.ExecContext(ctx, updateStmt, version); err != nil {
		return c.checkErr(err)
	}

	// Commit migration
	if err = txn.Commit(); err != nil {
		return c.checkErr(err)
	}

	// Force DuckDB to merge WAL into .db file
	if _, err = conn.ExecContext(ctx, "CHECKPOINT;"); err != nil {
		return c.checkErr(err)
	}

	return nil
}
// MigrationStatus implements drivers.Connection.
func (c *connection) MigrationStatus(ctx context.Context) (current, desired int, err error) {
conn, release, err := c.acquireMetaConn(ctx)
if err != nil {
return 0, 0, err
}
defer func() { _ = release() }()
// Get current version
err = conn.QueryRowxContext(ctx, fmt.Sprintf("select version from %s", migrationVersionTable)).Scan(¤t)
if err != nil {
return 0, 0, c.checkErr(err)
}
// Set desired to version number of last migration file
files, err := migrationsFS.ReadDir("migrations")
if err != nil {
return 0, 0, err
}
if len(files) > 0 {
file := files[len(files)-1]
version, err := migrationFilenameToVersion(file.Name())
if err != nil {
return 0, 0, fmt.Errorf("unexpected migration filename: %s", file.Name())
}
desired = version
}
return current, desired, nil
}
// migrationFilenameToVersion extracts the version number from a migration
// file name. A trailing ".sql" is removed if present and the remainder is
// parsed with strconv.Atoi; the parsed value and any Atoi error are returned
// unchanged.
func migrationFilenameToVersion(name string) (int, error) {
	stem := strings.TrimSuffix(name, ".sql")
	version, err := strconv.Atoi(stem)
	return version, err
}