/
upgrade.go
296 lines (283 loc) · 8.93 KB
/
upgrade.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
package sql_store_upgrade
import (
"database/sql"
"errors"
"fmt"
"strings"
)
// upgradeFunc is one schema migration step. It receives the transaction to run
// its statements in and the SQL dialect name, so it can issue dialect-specific DDL.
type upgradeFunc func(*sql.Tx, string) error

// ErrUnknownDialect is returned (wrapped together with the offending dialect
// name) by migrations that have no implementation for the requested dialect.
// Use errors.Is to detect it.
var ErrUnknownDialect = errors.New("unknown dialect")

// Upgrades holds the crypto store schema migrations in order: Upgrades[i]
// brings the schema from version i to version i+1. Migrations that differ per
// dialect support "postgres" and "sqlite3" and return ErrUnknownDialect for
// any other dialect.
var Upgrades = [...]upgradeFunc{
	// v0 -> v1: create the initial crypto store tables.
	func(tx *sql.Tx, _ string) error {
		return execAll(tx,
			`CREATE TABLE IF NOT EXISTS crypto_account (
				device_id  VARCHAR(255) PRIMARY KEY,
				shared     BOOLEAN      NOT NULL,
				sync_token TEXT         NOT NULL,
				account    bytea        NOT NULL
			)`,
			`CREATE TABLE IF NOT EXISTS crypto_message_index (
				sender_key CHAR(43),
				session_id CHAR(43),
				"index"    INTEGER,
				event_id   VARCHAR(255) NOT NULL,
				timestamp  BIGINT       NOT NULL,
				PRIMARY KEY (sender_key, session_id, "index")
			)`,
			`CREATE TABLE IF NOT EXISTS crypto_tracked_user (
				user_id VARCHAR(255) PRIMARY KEY
			)`,
			`CREATE TABLE IF NOT EXISTS crypto_device (
				user_id      VARCHAR(255),
				device_id    VARCHAR(255),
				identity_key CHAR(43)     NOT NULL,
				signing_key  CHAR(43)     NOT NULL,
				trust        SMALLINT     NOT NULL,
				deleted      BOOLEAN      NOT NULL,
				name         VARCHAR(255) NOT NULL,
				PRIMARY KEY (user_id, device_id)
			)`,
			`CREATE TABLE IF NOT EXISTS crypto_olm_session (
				session_id CHAR(43)  PRIMARY KEY,
				sender_key CHAR(43)  NOT NULL,
				session    bytea     NOT NULL,
				created_at timestamp NOT NULL,
				last_used  timestamp NOT NULL
			)`,
			`CREATE TABLE IF NOT EXISTS crypto_megolm_inbound_session (
				session_id        CHAR(43)     PRIMARY KEY,
				sender_key        CHAR(43)     NOT NULL,
				signing_key       CHAR(43)     NOT NULL,
				room_id           VARCHAR(255) NOT NULL,
				session           bytea        NOT NULL,
				forwarding_chains bytea        NOT NULL
			)`,
			`CREATE TABLE IF NOT EXISTS crypto_megolm_outbound_session (
				room_id       VARCHAR(255) PRIMARY KEY,
				session_id    CHAR(43)     NOT NULL UNIQUE,
				session       bytea        NOT NULL,
				shared        BOOLEAN      NOT NULL,
				max_messages  INTEGER      NOT NULL,
				message_count INTEGER      NOT NULL,
				max_age       BIGINT       NOT NULL,
				created_at    timestamp    NOT NULL,
				last_used     timestamp    NOT NULL
			)`,
		)
	},
	// v1 -> v2: add an account_id column to the per-account tables and make it
	// part of their primary keys. Existing rows get an empty account_id.
	func(tx *sql.Tx, dialect string) error {
		switch dialect {
		case "postgres":
			tablesToPkeys := map[string][]string{
				"crypto_account":                 {},
				"crypto_olm_session":             {"session_id"},
				"crypto_megolm_inbound_session":  {"session_id"},
				"crypto_megolm_outbound_session": {"room_id"},
			}
			for tableName, pkeys := range tablesToPkeys {
				// The new primary key is the old one plus account_id.
				pkeyStr := strings.Join(append(pkeys, "account_id"), ", ")
				err := execAll(tx,
					fmt.Sprintf("ALTER TABLE %s ADD COLUMN account_id VARCHAR(255)", tableName),
					fmt.Sprintf("UPDATE %s SET account_id=''", tableName),
					fmt.Sprintf("ALTER TABLE %s ALTER COLUMN account_id SET NOT NULL", tableName),
					// Relies on PostgreSQL's default "<table>_pkey" constraint name.
					fmt.Sprintf("ALTER TABLE %s DROP CONSTRAINT %s_pkey", tableName, tableName),
					fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s_pkey PRIMARY KEY (%s)", tableName, tableName, pkeyStr),
				)
				if err != nil {
					return err
				}
			}
		case "sqlite3":
			// SQLite cannot change a table's primary key in place, so each table is
			// re-created with account_id as its FIRST column and the old rows are
			// copied with '' prepended. The positional "SELECT '', *" only works
			// because the remaining columns keep the v1 column order.
			tableCols := map[string]string{
				"crypto_account": `
					account_id VARCHAR(255) NOT NULL,
					device_id  VARCHAR(255) NOT NULL,
					shared     BOOLEAN      NOT NULL,
					sync_token TEXT         NOT NULL,
					account    BLOB         NOT NULL,
					PRIMARY KEY (account_id)
				`,
				"crypto_olm_session": `
					account_id VARCHAR(255) NOT NULL,
					session_id CHAR(43)     NOT NULL,
					sender_key CHAR(43)     NOT NULL,
					session    BLOB         NOT NULL,
					created_at timestamp    NOT NULL,
					last_used  timestamp    NOT NULL,
					PRIMARY KEY (account_id, session_id)
				`,
				"crypto_megolm_inbound_session": `
					account_id        VARCHAR(255) NOT NULL,
					session_id        CHAR(43)     NOT NULL,
					sender_key        CHAR(43)     NOT NULL,
					signing_key       CHAR(43)     NOT NULL,
					room_id           VARCHAR(255) NOT NULL,
					session           BLOB         NOT NULL,
					forwarding_chains BLOB         NOT NULL,
					PRIMARY KEY (account_id, session_id)
				`,
				"crypto_megolm_outbound_session": `
					account_id    VARCHAR(255) NOT NULL,
					room_id       VARCHAR(255) NOT NULL,
					session_id    CHAR(43)     NOT NULL UNIQUE,
					session       BLOB         NOT NULL,
					shared        BOOLEAN      NOT NULL,
					max_messages  INTEGER      NOT NULL,
					message_count INTEGER      NOT NULL,
					max_age       BIGINT       NOT NULL,
					created_at    timestamp    NOT NULL,
					last_used     timestamp    NOT NULL,
					PRIMARY KEY (account_id, room_id)
				`,
			}
			for tableName, cols := range tableCols {
				err := execAll(tx,
					fmt.Sprintf("ALTER TABLE %s RENAME TO old_%s", tableName, tableName),
					fmt.Sprintf("CREATE TABLE %s (%s)", tableName, cols),
					fmt.Sprintf("INSERT INTO %s SELECT '', * FROM old_%s", tableName, tableName),
					fmt.Sprintf("DROP TABLE old_%s", tableName),
				)
				if err != nil {
					return err
				}
			}
		default:
			return fmt.Errorf("%w (%s)", ErrUnknownDialect, dialect)
		}
		return nil
	},
	// v2 -> v3: add withheld_code/withheld_reason to
	// crypto_megolm_inbound_session and make signing_key, session and
	// forwarding_chains nullable.
	func(tx *sql.Tx, dialect string) error {
		switch dialect {
		case "postgres":
			alters := []string{
				"ADD COLUMN withheld_code VARCHAR(255)",
				"ADD COLUMN withheld_reason TEXT",
				"ALTER COLUMN signing_key DROP NOT NULL",
				"ALTER COLUMN session DROP NOT NULL",
				"ALTER COLUMN forwarding_chains DROP NOT NULL",
			}
			for _, alter := range alters {
				if _, err := tx.Exec(fmt.Sprintf("ALTER TABLE crypto_megolm_inbound_session %s", alter)); err != nil {
					return err
				}
			}
		case "sqlite3":
			// SQLite's ALTER TABLE cannot drop NOT NULL constraints, so the table
			// is re-created and its rows copied over.
			return execAll(tx,
				"ALTER TABLE crypto_megolm_inbound_session RENAME TO old_crypto_megolm_inbound_session",
				`CREATE TABLE crypto_megolm_inbound_session (
					account_id        VARCHAR(255) NOT NULL,
					session_id        CHAR(43)     NOT NULL,
					sender_key        CHAR(43)     NOT NULL,
					signing_key       CHAR(43),
					room_id           VARCHAR(255) NOT NULL,
					session           BLOB,
					forwarding_chains BLOB,
					withheld_code     VARCHAR(255),
					withheld_reason   TEXT,
					PRIMARY KEY (account_id, session_id)
				)`,
				// Copy columns by name. After the v1 -> v2 SQLite migration,
				// account_id is the FIRST column of the old table, so a positional
				// "SELECT *" would shift every value into the wrong column.
				`INSERT INTO crypto_megolm_inbound_session
					(account_id, session_id, sender_key, signing_key, room_id, session, forwarding_chains)
				SELECT account_id, session_id, sender_key, signing_key, room_id, session, forwarding_chains
				FROM old_crypto_megolm_inbound_session`,
				"DROP TABLE old_crypto_megolm_inbound_session",
			)
		default:
			return fmt.Errorf("%w (%s)", ErrUnknownDialect, dialect)
		}
		return nil
	},
	// v3 -> v4: add the cross-signing key and signature tables.
	func(tx *sql.Tx, _ string) error {
		return execAll(tx,
			`CREATE TABLE IF NOT EXISTS crypto_cross_signing_keys (
				user_id VARCHAR(255) NOT NULL,
				usage   VARCHAR(20)  NOT NULL,
				key     CHAR(43)     NOT NULL,
				PRIMARY KEY (user_id, usage)
			)`,
			`CREATE TABLE IF NOT EXISTS crypto_cross_signing_signatures (
				signed_user_id VARCHAR(255) NOT NULL,
				signed_key     VARCHAR(255) NOT NULL,
				signer_user_id VARCHAR(255) NOT NULL,
				signer_key     VARCHAR(255) NOT NULL,
				signature      CHAR(88)     NOT NULL,
				PRIMARY KEY (signed_user_id, signed_key, signer_user_id, signer_key)
			)`,
		)
	},
}

// execAll executes queries in order inside tx and returns the first error,
// without running the remaining queries.
func execAll(tx *sql.Tx, queries ...string) error {
	for _, query := range queries {
		if _, err := tx.Exec(query); err != nil {
			return err
		}
	}
	return nil
}
// GetVersion returns the current version of the DB schema.
//
// It creates the crypto_version table if it does not exist yet. An empty table
// means no migration has been applied, and 0 is returned. Any other failure is
// returned as an error together with -1.
func GetVersion(db *sql.DB) (int, error) {
	if _, err := db.Exec("CREATE TABLE IF NOT EXISTS crypto_version (version INTEGER)"); err != nil {
		return -1, err
	}
	version := 0
	// Only "no rows" means a fresh database. Other errors (connection problems,
	// etc.) must not be mistaken for version 0, or Upgrade would re-apply every
	// migration on top of an existing schema.
	err := db.QueryRow("SELECT version FROM crypto_version LIMIT 1").Scan(&version)
	if err != nil && !errors.Is(err, sql.ErrNoRows) {
		return -1, err
	}
	return version, nil
}
// SetVersion records version as the schema version using the given, already
// open transaction. Any existing rows in crypto_version are deleted first, so
// the table ends up holding exactly the one row inserted here. If the delete
// fails, the insert is not attempted.
func SetVersion(tx *sql.Tx, version int) error {
	if _, err := tx.Exec("DELETE FROM crypto_version"); err != nil {
		return err
	}
	if _, err := tx.Exec("INSERT INTO crypto_version (version) VALUES ($1)", version); err != nil {
		return err
	}
	return nil
}
// Upgrade upgrades the database from the current to the latest version available.
//
// The current version is read with GetVersion. Each pending migration in
// Upgrades then runs in its own transaction, together with the SetVersion call
// that records its new version. A failing step is rolled back, leaving the
// database at the last committed version. dialect is passed unchanged to every
// migration.
func Upgrade(db *sql.DB, dialect string) error {
	version, err := GetVersion(db)
	if err != nil {
		return err
	}
	// Perform migrations starting with Upgrades[version].
	for ; version < len(Upgrades); version++ {
		if err = applyUpgrade(db, dialect, version); err != nil {
			return err
		}
	}
	return nil
}

// applyUpgrade runs Upgrades[version] and sets the stored schema version to
// version+1 in a single transaction. The transaction is rolled back if any step
// fails, so it is never left open.
func applyUpgrade(db *sql.DB, dialect string, version int) (err error) {
	tx, err := db.Begin()
	if err != nil {
		return err
	}
	defer func() {
		if err != nil {
			// Best effort: the original error is more useful than a rollback
			// failure. After a failed Commit this just returns sql.ErrTxDone.
			_ = tx.Rollback()
		}
	}()
	if err = Upgrades[version](tx, dialect); err != nil {
		return err
	}
	// Also update the version in this tx, so schema and version commit together.
	if err = SetVersion(tx, version+1); err != nil {
		return err
	}
	return tx.Commit()
}