/
migration.go
330 lines (283 loc) · 10.6 KB
/
migration.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
package tidal
import (
"bytes"
"database/sql"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"regexp"
"sort"
"strconv"
"strings"
"text/template"
"time"
)
// fnamere is used to parse a migration filename's components. The first capture group
// is the numeric revision prefix and the second is the migration name (word characters,
// digits, underscores and hyphens); the two are separated by a single underscore or
// hyphen and the filename must end in ".sql".
var fnamere = regexp.MustCompile(`^(\d+)[_-]([\w\d_-]+)\.sql$`)
// Open reads the migration SQL file at path and parses it into a Migration. The base
// name of the path must match the migration filename pattern; its revision and name
// populate the Migration, and the file contents are loaded into the descriptor. On
// failure the returned error describes which step failed.
func Open(path string) (m Migration, err error) {
	base := filepath.Base(path)
	if !fnamere.MatchString(base) {
		return m, fmt.Errorf("could not parse %q as a migration filename", base)
	}

	// Pull the name and revision number out of the filename
	m.Name, m.Revision, err = parseFilename(base)
	if err != nil {
		return m, err
	}

	// Read the file contents into a descriptor
	src, err := os.Open(path)
	if err != nil {
		return m, err
	}
	defer src.Close()

	m.descriptor, err = NewDescriptor(src, base)
	if err != nil {
		return m, err
	}
	return m, nil
}
// Migration defines how changes to the database are applied (up) or rolled back (down).
// Each migration is defined by two distinct pieces of SQL code, one for up and one for
// down, which are parsed from a single SQL file, delimited by tidal-parseable
// comments. Migrations are generally compiled into a compressed descriptor format that
// can be included with application source code so that the migrations are compiled with
// the binary, rather than sourced from external files.
//
// Each migration can also include status information from the database (collected at
// runtime). This information defines the database's knowledge of the migration, e.g.
// has the migration been applied or not. Status information is stored in the database
// inside of a migrations table that is applied with Revision 0 (an application's first
// revision is Revision 1). The table is updated with migrate and sync commands.
//
// Migrations are identified by a unique revision number that specifies the sequence
// in which migrations must be applied. For now that means that migrations can only be
// applied linearly (and not as a directed acyclic graph with multiple dependencies).
// Future work is required to create a migration DAG structure.
type Migration struct {
	Revision   int        // the unique id of the migration, the numeric prefix of the migration file
	Name       string     // the human readable name of the migration, suffix of the migration file with underscores replaced by spaces
	Active     bool       // if the migration has been applied and is part of the active schema
	Applied    time.Time  // the timestamp the migration was applied
	Created    time.Time  // the timestamp the migration was added to the database
	descriptor Descriptor // contains the gzip compressed data to minimize compile time size
	dbsync     bool       // if the migration has been synchronized to the database
}
// Up applies the migration to the database. A transaction is opened on conn and the
// SQL UP code is executed inside it together with the update to the migrations table
// that records the change in state. If either statement fails the whole transaction is
// rolled back and the failure is returned; otherwise the transaction is committed and
// any commit error is returned instead.
func (m *Migration) Up(conn *sql.DB) (err error) {
	tx, err := conn.Begin()
	if err != nil {
		return fmt.Errorf("could not begin transaction to apply revision %d: %s", m.Revision, err)
	}

	// Settle the transaction once the up step has finished (or panicked).
	defer func() {
		p := recover()
		switch {
		case p != nil:
			// A panic occurred: roll back, then let the panic continue upward
			tx.Rollback()
			panic(p)
		case err != nil:
			// The up step failed; its error is what the caller needs to see, so
			// the rollback error is intentionally not captured
			tx.Rollback()
		default:
			// Everything succeeded; a failed commit becomes the returned error
			err = tx.Commit()
		}
	}()

	// Execute up transaction
	return m.upTx(tx)
}
// upTx runs the migration's UP SQL inside the supplied transaction and, for revisions
// greater than zero, marks the revision as active in the migrations table with the
// current UTC time as the applied timestamp. The transaction is neither committed nor
// rolled back here; that is left to the caller.
func (m *Migration) upTx(tx *sql.Tx) (err error) {
	query, err := m.UpSQL()
	if err != nil {
		return fmt.Errorf("could not parse revision %d up sql: %s", m.Revision, err)
	}

	if _, err = tx.Exec(query); err != nil {
		return fmt.Errorf("could not exec revision %d up: %s", m.Revision, err)
	}

	// Only application migrations (revision > 0) have a status row to update
	if m.Revision <= 0 {
		return nil
	}

	const statusUpdate = "UPDATE migrations SET active=$1, applied=$2 WHERE revision=$3"
	if _, err = tx.Exec(statusUpdate, true, time.Now().UTC(), m.Revision); err != nil {
		return fmt.Errorf("could not update migration status of revision %d: %s", m.Revision, err)
	}
	return nil
}
// UpSQL returns the sql statement defined for applying the migration to the specific
// revision. This requires parsing the underlying descriptor correctly: the call is
// delegated to the descriptor's Up method and its result and error are returned as is.
func (m *Migration) UpSQL() (string, error) {
	return m.descriptor.Up()
}
// Down rolls back the migration from the database. A transaction is opened on conn and
// the SQL DOWN code is executed inside it together with the update to the migrations
// table that records the change in state. If either statement fails the whole
// transaction is rolled back and the failure is returned; otherwise the transaction is
// committed and any commit error is returned instead.
func (m *Migration) Down(conn *sql.DB) (err error) {
	tx, err := conn.Begin()
	if err != nil {
		return fmt.Errorf("could not begin transaction to rollback revision %d: %s", m.Revision, err)
	}

	// Settle the transaction once the down step has finished (or panicked).
	defer func() {
		p := recover()
		switch {
		case p != nil:
			// A panic occurred: roll back, then let the panic continue upward
			tx.Rollback()
			panic(p)
		case err != nil:
			// The down step failed; its error is what the caller needs to see, so
			// the rollback error is intentionally not captured
			tx.Rollback()
		default:
			// Everything succeeded; a failed commit becomes the returned error
			err = tx.Commit()
		}
	}()

	// Execute down transaction
	return m.downTx(tx)
}
// downTx runs the migration's DOWN SQL inside the supplied transaction and, for
// revisions greater than zero, marks the revision as inactive in the migrations table
// and clears its applied timestamp. The transaction is neither committed nor rolled
// back here; that is left to the caller, which decides based on the returned error.
func (m *Migration) downTx(tx *sql.Tx) (err error) {
	var query string
	if query, err = m.DownSQL(); err != nil {
		return fmt.Errorf("could not parse revision %d down sql: %s", m.Revision, err)
	}

	if _, err = tx.Exec(query); err != nil {
		return fmt.Errorf("could not exec revision %d down: %s", m.Revision, err)
	}

	// If this is an application migration, update the migrations status table
	if m.Revision > 0 {
		// Two arguments are bound, so the placeholders must be $1 and $2. The
		// statement previously referenced $3, leaving $2 without a value.
		const statusUpdate = "UPDATE migrations SET active=$1, applied=NULL WHERE revision=$2"
		if _, err = tx.Exec(statusUpdate, false, m.Revision); err != nil {
			return fmt.Errorf("could not update migration status of revision %d: %s", m.Revision, err)
		}
	}

	return nil
}
// DownSQL returns the sql statement defined for rolling back the migration to a state
// before this specific revision. This requires parsing the underlying descriptor
// correctly: the call is delegated to the descriptor's Down method and its result and
// error are returned as is.
func (m *Migration) DownSQL() (string, error) {
	return m.descriptor.Down()
}
// Package returns the parsed package directive from the descriptor if it has one. The
// call is delegated to the descriptor's Package method and its result and error are
// returned as is.
func (m *Migration) Package() (string, error) {
	return m.descriptor.Package()
}
// Synchronized returns true if the migration state has been synchronized with the
// database. It reports the value of the unexported dbsync flag.
func (m *Migration) Synchronized() bool {
	return m.dbsync
}
// Predecessors returns the number of registered migrations that come before this
// migration, i.e. the position of this revision in the package's migrations list. An
// error is returned if the revision is not found in that list. The scan gives up as
// soon as it meets a revision larger than this one.
func (m *Migration) Predecessors() (n int, err error) {
	total := len(migrations)
	for n = 0; n < total; n++ {
		rev := migrations[n].Revision
		if rev == m.Revision {
			// n migrations were passed over before reaching this one
			return n, nil
		}
		if rev > m.Revision {
			// Already past where this revision would be
			break
		}
	}

	// Reached for an empty list, an exhausted list, or a larger revision
	return 0, fmt.Errorf("revision %d was not registered", m.Revision)
}
// Successors returns the number of registered migrations that come after this
// migration. The revision is located in the package's migrations list with a binary
// search on the revision number; an error is returned if it is not found there.
func (m *Migration) Successors() (n int, err error) {
	count := len(migrations)

	// Find the first registered migration whose revision is not below this one
	idx := sort.Search(count, func(j int) bool {
		return migrations[j].Revision >= m.Revision
	})

	if idx >= count || migrations[idx].Revision != m.Revision {
		return 0, fmt.Errorf("revision %d was not registered", m.Revision)
	}

	// Everything after idx is a successor; this is zero for the last migration
	return count - idx - 1, nil
}
// sqldata is the template source for a newly generated migration file. It renders a
// header comment with the revision and timestamp, a package directive that is only
// emitted when PackageName is non-empty, and placeholder up and down sections that are
// delimited by "-- migrate:" comments.
const sqldata = `-- Revision {{ .Revision }} generated on {{ .Timestamp }}{{ if .PackageName }}
-- package: {{ .PackageName }}{{ end }}
-- migrate: up
-- insert up migration sql here
-- migrate: down
-- insert down migration sql here
-- migrate: end
`

// sqldataTemplate is the parsed form of sqldata. template.Must panics at package
// initialization if the template source cannot be parsed.
var sqldataTemplate = template.Must(template.New("").Parse(sqldata))
// sqldataContext is used to populate the sqldata template.
type sqldataContext struct {
	Revision    int    // revision number written into the file header
	Timestamp   string // pre-formatted generation time written into the file header
	PackageName string // optional package directive; the template omits the line when empty
}
// Create a new SQL migration file for code generation. The migration file is an ANSI
// SQL file that uses SQL comments to delineate the up and down migrations. Using the
// code generation tool, these files are directly embedded into your application code
// base using compressed Descriptors, which are registered as migrations at runtime.
// This helper utility adds the next migration sql file revision (based on the latest
// registered revision and the maximum revision number from sibling files) and writes
// out an empty template to the migrations directory.
//
// The new file is named "<revision>_<name>.sql" with the revision zero padded to four
// digits and spaces in name replaced by underscores. If name is empty, a name of the
// form "auto_<timestamp>" is generated. If packageName is non-empty it is written into
// the file as a package directive.
//
// An error is returned if the directory cannot be listed, if it contains a ".sql" file
// whose name is not a valid migration filename, or if the new file cannot be rendered,
// created, written or closed.
func Create(migrationsDirectory, name, packageName string) (err error) {
	// Start from the latest registered revision, if any migrations are registered
	var latestRevision int
	if len(migrations) > 0 {
		latestRevision = migrations[len(migrations)-1].Revision
	}

	var listing []os.FileInfo
	if listing, err = ioutil.ReadDir(migrationsDirectory); err != nil {
		return err
	}

	// Sibling files may carry a higher revision than anything registered so far
	for _, finfo := range listing {
		filename := finfo.Name()
		if !strings.HasSuffix(filename, ".sql") {
			continue
		}

		// parseFilename indexes the regexp submatches directly, so names that do
		// not match the pattern are rejected here rather than causing a panic
		if !fnamere.MatchString(filename) {
			return fmt.Errorf("could not parse %q as a migration filename", filename)
		}

		var revision int
		if _, revision, err = parseFilename(filename); err != nil {
			return err
		}

		if revision > latestRevision {
			latestRevision = revision
		}
	}

	// Create the template context
	now := time.Now().Local()
	ctx := &sqldataContext{
		Revision:    latestRevision + 1,
		Timestamp:   now.Format("2006-01-02 15:04:05 -0700"),
		PackageName: packageName,
	}

	// Execute the template
	builder := &bytes.Buffer{}
	if err = sqldataTemplate.Execute(builder, ctx); err != nil {
		return err
	}

	// Determine the write path
	if name == "" {
		name = fmt.Sprintf("auto_%s", now.Format("200601021504"))
	}
	name = strings.Replace(name, " ", "_", -1)
	outpath := filepath.Join(migrationsDirectory, fmt.Sprintf("%04d_%s.sql", latestRevision+1, name))

	// Create the generated migration template file
	var f *os.File
	if f, err = os.Create(outpath); err != nil {
		return err
	}
	defer func() {
		// A failed close can mean the data never reached disk, so report it
		// unless an earlier error is already being returned
		if cerr := f.Close(); cerr != nil && err == nil {
			err = cerr
		}
	}()

	if _, err = f.Write(builder.Bytes()); err != nil {
		return err
	}
	return nil
}
// parseFilename is a helper that parses a migration filename into Migration metadata.
// The filename must match fnamere as a whole, so pass the base name rather than a full
// path. It returns the migration name (the suffix of the filename with underscores
// replaced by spaces) and the revision number (the numeric prefix). An error is
// returned if the filename does not match the expected pattern or if the prefix cannot
// be converted to an int.
func parseFilename(filename string) (name string, revision int, err error) {
	groups := fnamere.FindStringSubmatch(filename)
	if groups == nil {
		// No match yields a nil slice; indexing it below would panic
		return "", 0, fmt.Errorf("could not parse %q as a migration filename", filename)
	}

	if revision, err = strconv.Atoi(groups[1]); err != nil {
		return "", 0, fmt.Errorf("could not parse %q to revision number: %s", groups[1], err)
	}

	name = strings.Replace(groups[2], "_", " ", -1)
	return name, revision, nil
}