-
Notifications
You must be signed in to change notification settings - Fork 90
/
testingutil.go
199 lines (167 loc) · 4.18 KB
/
testingutil.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
package testingutil
import (
"database/sql"
"flag"
"fmt"
"io"
"os"
"path/filepath"
"strings"
"testing"
"time"
"github.com/mattn/go-sqlite3"
)
func init() {
	// persistWAL is installed as the driver's ConnectHook. It sets the
	// SQLITE_FCNTL_PERSIST_WAL file control to 1 on the "main" database of
	// the connection it is handed.
	persistWAL := func(conn *sqlite3.SQLiteConn) error {
		err := conn.SetFileControlInt("main", sqlite3.SQLITE_FCNTL_PERSIST_WAL, 1)
		if err != nil {
			return fmt.Errorf("cannot set file control: %w", err)
		}
		return nil
	}

	// Register a test-only driver, "sqlite3-persist-wal", that tests can use
	// when they need the WAL to persist after DB.Close().
	sql.Register("sqlite3-persist-wal", &sqlite3.SQLiteDriver{ConnectHook: persistWAL})
}
// Command-line flags that control how the helpers in this package open and
// configure test databases.
var (
	// journalMode is passed to "PRAGMA journal_mode" when a database is opened.
	journalMode = flag.String("journal-mode", "delete", "SQLite journal mode for test databases, e.g. delete or wal")

	// pageSize is passed to "PRAGMA page_size" when non-zero; zero leaves the
	// page size unset.
	pageSize = flag.Int("page-size", 0, "SQLite page size in bytes for test databases; 0 leaves the page size unset")

	// noCompress turns LTX compression off when set.
	noCompress = flag.Bool("no-compress", false, "disable ltx compression")
)
// IsWALMode reports whether -journal-mode is set to "wal". The comparison is
// made against the value returned by JournalMode.
func IsWALMode() bool {
	mode := JournalMode()
	return mode == "wal"
}
// JournalMode reports the value of the -journal-mode flag, converted to
// lower case.
func JournalMode() string {
	mode := *journalMode
	return strings.ToLower(mode)
}
// PageSize reports the value of the -page-size flag, substituting 4096 when
// the flag is left at zero.
func PageSize() int {
	if size := *pageSize; size != 0 {
		return size
	}
	return 4096
}
// Compress reports whether LTX compression is enabled, which is the case
// unless the -no-compress flag is set.
func Compress() bool {
	disabled := *noCompress
	return !disabled
}
// OpenSQLDB opens a connection to a SQLite database through the "sqlite3"
// driver and configures it from the test flags: page_size (only when
// -page-size is non-zero), a busy_timeout of 5000, and the journal mode given
// by -journal-mode.
//
// Any failure aborts the test via tb.Fatal. The returned handle is closed
// through tb.Cleanup, so callers do not have to close it themselves.
func OpenSQLDB(tb testing.TB, dsn string) *sql.DB {
	tb.Helper()

	db, err := sql.Open("sqlite3", dsn)
	if err != nil {
		tb.Fatal(err)
	}

	// Register the close before running any PRAGMA so the handle is released
	// even when one of the statements below fails the test.
	tb.Cleanup(func() {
		if err := db.Close(); err != nil {
			tb.Fatal(err)
		}
	})

	// Only override the page size when the flag was set explicitly.
	if *pageSize != 0 {
		if _, err := db.Exec(fmt.Sprintf(`PRAGMA page_size = %d`, *pageSize)); err != nil {
			tb.Fatal(err)
		}
	}

	if _, err := db.Exec(`PRAGMA busy_timeout = 5000`); err != nil {
		tb.Fatal(err)
	}

	if _, err := db.Exec(`PRAGMA journal_mode = ` + *journalMode); err != nil {
		tb.Fatal(err)
	}

	return db
}
// ReopenSQLDB closes the handle that db points at and replaces it with a
// fresh connection opened by OpenSQLDB using dsn. A failure to close the
// existing handle aborts the test.
func ReopenSQLDB(tb testing.TB, db **sql.DB, dsn string) {
	tb.Helper()

	prev := *db
	if closeErr := prev.Close(); closeErr != nil {
		tb.Fatal(closeErr)
	}

	reopened := OpenSQLDB(tb, dsn)
	*db = reopened
}
// WithTx opens a dedicated connection with driverName and dsn, begins a
// transaction on it, and hands that transaction to fn. The transaction is
// committed once fn returns; any error along the way aborts the test.
func WithTx(tb testing.TB, driverName, dsn string, fn func(tx *sql.Tx)) {
	tb.Helper()

	conn, openErr := sql.Open(driverName, dsn)
	if openErr != nil {
		tb.Fatal(openErr)
	}
	defer func() { _ = conn.Close() }()

	// Configure the connection: busy timeout first, then the journal mode
	// from the -journal-mode flag. The second statement only runs when the
	// first one succeeded.
	_, pragmaErr := conn.Exec(`PRAGMA busy_timeout = 5000`)
	if pragmaErr == nil {
		_, pragmaErr = conn.Exec(`PRAGMA journal_mode = ` + *journalMode)
	}
	if pragmaErr != nil {
		tb.Fatal(pragmaErr)
	}

	tx, beginErr := conn.Begin()
	if beginErr != nil {
		tb.Fatal(beginErr)
	}
	// Best-effort rollback for the paths that never reach a successful
	// Commit; its error is intentionally discarded.
	defer func() { _ = tx.Rollback() }()

	fn(tx)

	if commitErr := tx.Commit(); commitErr != nil {
		tb.Fatal(commitErr)
	}
}
// RetryUntil calls fn once per interval tick until it returns nil or timeout
// elapses. The first call happens after the first tick, not immediately.
//
// On timeout the test is failed with tb.Fatalf, reporting the error from the
// most recent call to fn. If the timeout elapses before fn was ever called,
// the failure message says so instead of formatting a nil error.
func RetryUntil(tb testing.TB, interval, timeout time.Duration, fn func() error) {
	tb.Helper()

	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	timer := time.NewTimer(timeout)
	defer timer.Stop()

	// err holds the result of the most recent fn call; it stays nil until fn
	// has run at least once.
	var err error
	for {
		select {
		case <-ticker.C:
			if err = fn(); err == nil {
				return
			}
		case <-timer.C:
			if err == nil {
				// A nil fn result returns above, so a nil err here means the
				// timeout fired before the first tick.
				tb.Fatalf("timeout: fn was not called within %s", timeout)
				return
			}
			tb.Fatalf("timeout: %s", err)
			// Never keep looping after a failure has been reported, even if
			// the TB implementation returns from Fatalf.
			return
		}
	}
}
// MustCopyDir recursively copies the contents of the src directory into the
// dst directory, creating dst (and any missing parents) with mode 0755.
//
// Entries are inspected with os.Stat, so a symlink is treated according to
// what it points at. Files are written with os.Create, so an existing
// destination file is truncated and source permissions are not carried over.
// Any error aborts the test via tb.Fatal.
func MustCopyDir(tb testing.TB, src, dst string) {
	tb.Helper()

	if err := os.MkdirAll(dst, 0755); err != nil {
		tb.Fatal(err)
	}

	ents, err := os.ReadDir(src)
	if err != nil {
		tb.Fatal(err)
	}

	for _, ent := range ents {
		srcPath := filepath.Join(src, ent.Name())
		dstPath := filepath.Join(dst, ent.Name())

		fi, err := os.Stat(srcPath)
		if err != nil {
			tb.Fatal(err)
		}

		// If it's a directory, copy recursively.
		if fi.IsDir() {
			MustCopyDir(tb, srcPath, dstPath)
			continue
		}

		// Copy through a helper so both file handles are released before the
		// next entry is processed rather than when the whole directory is done.
		if err := copyFileContents(srcPath, dstPath); err != nil {
			tb.Fatal(err)
		}
	}
}

// copyFileContents copies the bytes of the file at src into a newly created
// (or truncated) file at dst and closes both files before returning. The
// first error encountered, including one from closing either file, is
// returned.
func copyFileContents(src, dst string) (err error) {
	r, err := os.Open(src)
	if err != nil {
		return err
	}
	defer func() {
		if cerr := r.Close(); err == nil {
			err = cerr
		}
	}()

	w, err := os.Create(dst)
	if err != nil {
		return err
	}
	defer func() {
		// A failed close on the destination can mean lost data, so report it
		// unless an earlier error is already being returned.
		if cerr := w.Close(); err == nil {
			err = cerr
		}
	}()

	_, err = io.Copy(w, r)
	return err
}