-
Notifications
You must be signed in to change notification settings - Fork 4
/
testhelpers.go
227 lines (200 loc) · 7.35 KB
/
testhelpers.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
/*******************************************************************************
*
* Copyright 2017-2019 SAP SE
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You should have received a copy of the License along with this
* program. If not, you may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*******************************************************************************/
package easypg
import (
"database/sql"
"fmt"
"os"
"os/exec"
"path/filepath"
"regexp"
"strings"
"testing"
"github.com/sapcc/go-bits/osext"
)
// ClearTables deletes every row from each of the named tables, processing
// them in the order given. The test is aborted (via t.Fatalf) as soon as one
// of the DELETE statements fails.
func ClearTables(t *testing.T, db *sql.DB, tableNames ...string) {
	t.Helper()
	for i := range tableNames {
		name := tableNames[i]
		query := "DELETE FROM " + name //nolint:gosec // cannot provide tableName as bind parameter
		if _, err := db.Exec(query); err != nil {
			t.Fatalf("while clearing table %s: %s", name, err.Error())
		}
	}
}
// ResetPrimaryKeys restarts the "<table>_id_seq" sequence of each named table.
// It restarts at one past the largest existing "id" value, or at 1 when the
// table has no rows. The test is aborted (via t.Fatalf) if either the lookup
// of the current maximum or the ALTER SEQUENCE statement fails.
func ResetPrimaryKeys(t *testing.T, db *sql.DB, tableNames ...string) {
	t.Helper()
	for _, table := range tableNames {
		selectQuery := "SELECT 1 + COALESCE(MAX(id), 0) FROM " + table //nolint:gosec // cannot provide tableName as bind parameter
		var nextID int64
		if err := db.QueryRow(selectQuery).Scan(&nextID); err != nil {
			t.Fatalf("while checking IDs in table %s: %s", table, err.Error())
		}

		alterQuery := fmt.Sprintf(`ALTER SEQUENCE %s_id_seq RESTART WITH %d`, table, nextID)
		if _, err := db.Exec(alterQuery); err != nil {
			t.Fatalf("while resetting ID sequence on table %s: %s", table, err.Error())
		}
	}
}
// ExecSQLFile loads a file containing SQL statements and executes them all.
// It assumes that every SQL statement is on a single line of its own. The file
// is split at newlines, and each line is trimmed of surrounding whitespace.
// Empty lines and lines starting with "--" are skipped. Every remaining line
// is passed to db.Exec as one statement.
//
// The test is aborted (via t.Fatal/t.Fatalf) if the file cannot be read or if
// a statement fails. In the latter case the error names the file and the
// 1-based line number of the failing statement.
func ExecSQLFile(t *testing.T, db *sql.DB, path string) {
	t.Helper()
	sqlBytes, err := os.ReadFile(path)
	if err != nil {
		t.Fatal(err)
	}

	// split into single statements because db.Exec() will just ignore everything after the first semicolon
	for idx, line := range strings.Split(string(sqlBytes), "\n") {
		line = strings.TrimSpace(line)
		if line == "" || strings.HasPrefix(line, "--") {
			continue
		}
		if _, err := db.Exec(line); err != nil {
			// idx is 0-based, but editors (and humans) count lines starting at 1
			t.Fatalf("error on SQL line %d in %s: %s", idx+1, path, err.Error())
		}
	}
}
// AssertDBContent takes a full dump of the database contents (as produced by
// NewTracker) and compares it to fixtureFile using Assertable.AssertEqualToFile.
// That method runs diff(1) and fails the test when the two differ.
func AssertDBContent(t *testing.T, db *sql.DB, fixtureFile string) {
	t.Helper()
	_, fullDump := NewTracker(t, db)
	fullDump.AssertEqualToFile(fixtureFile)
}
// Tracker keeps a copy of the database contents and allows for checking the
// database contents (or changes made to them) during tests.
//
// Construct it with NewTracker. Each call to DBContent or DBChanges replaces
// the stored snapshot with a fresh one.
type Tracker struct {
	// t is the test this Tracker belongs to; it is handed to newDBSnapshot
	// and to every Assertable the Tracker produces.
	t *testing.T
	// db is the database whose contents are snapshotted.
	db *sql.DB
	// snap is the snapshot taken by the most recent NewTracker, DBContent or
	// DBChanges call; DBChanges diffs the next snapshot against it.
	snap dbSnapshot
}
// NewTracker creates a new Tracker for db that reports to t.
//
// Creating the Tracker requires taking an initial snapshot. As an
// optimization, the full dump of that snapshot is returned as a second value.
// Tests commonly want to assert on the full DB contents right after creating
// the Tracker, and calling Tracker.DBContent() at that point would take a
// redundant second snapshot.
func NewTracker(t *testing.T, db *sql.DB) (*Tracker, Assertable) {
	t.Helper()
	initial := newDBSnapshot(t, db)
	tracker := &Tracker{t: t, db: db, snap: initial}
	return tracker, Assertable{t: t, payload: initial.ToSQL(nil)}
}
// DBContent takes a fresh snapshot of the database and returns it as a full
// dump (a sequence of INSERT statements) to run test assertions on. The fresh
// snapshot also becomes the baseline for the next DBChanges call.
func (tr *Tracker) DBContent() Assertable {
	tr.t.Helper()
	tr.snap = newDBSnapshot(tr.t, tr.db)
	dump := tr.snap.ToSQL(nil)
	return Assertable{t: tr.t, payload: dump}
}
// DBChanges takes a fresh snapshot of the database and diffs it against the
// snapshot from the previous Tracker call. The diff is returned as a sequence
// of INSERT/UPDATE/DELETE statements to run test assertions on. The fresh
// snapshot then becomes the new baseline.
func (tr *Tracker) DBChanges() Assertable {
	tr.t.Helper()
	current := newDBSnapshot(tr.t, tr.db)
	changes := current.ToSQL(tr.snap)
	tr.snap = current
	return Assertable{t: tr.t, payload: changes}
}
// Assertable contains a set of SQL statements. Instances are produced by
// NewTracker and by methods on type Tracker.
type Assertable struct {
	// t is the test that assertion failures are reported to.
	t *testing.T
	// payload is the SQL text produced by dbSnapshot.ToSQL: either a full
	// dump (NewTracker, DBContent) or a diff against an earlier snapshot
	// (DBChanges).
	payload string
}
// AssertEqualToFile compares the set of SQL statements to those in the given
// fixture file by running `diff -u` on them. The test is aborted if the files
// differ or if any step (resolving the path, writing the actual output,
// running diff) fails.
//
// The actual output is written to "<fixture>.actual" and intentionally left on
// disk. When a test is added or changed, it can then be copied over the
// fixture.
func (a Assertable) AssertEqualToFile(fixtureFile string) {
	a.t.Helper()

	fixturePath, err := filepath.Abs(fixtureFile)
	failOnErr(a.t, err)

	actualPath := fixturePath + ".actual"
	err = os.WriteFile(actualPath, []byte(a.payload), 0o666)
	failOnErr(a.t, err)

	diff := exec.Command("diff", "-u", fixturePath, actualPath)
	diff.Stdin, diff.Stdout, diff.Stderr = nil, os.Stdout, os.Stderr
	// diff exits non-zero when the files differ, which fails the test here
	failOnErr(a.t, diff.Run())
}
// whitespaceAtStartOfLineRx matches leading whitespace at the start of a line.
// Since \s also matches "\n", a match that starts on an empty line runs on into
// the following lines. Replacing all matches with "" therefore strips both
// indentation and empty lines.
var whitespaceAtStartOfLineRx = regexp.MustCompile(`(?m)^\s+`)

// emptyLinesRx matches one or more consecutive empty (or whitespace-only)
// lines, including their line terminators. It does not touch the indentation
// of the next non-empty line.
var emptyLinesRx = regexp.MustCompile(`(?m)^\s*\n`)

// AssertEqual compares the set of SQL statements to those in the given string
// literal. A test error is generated in case of differences. This assertion
// is lenient with regards to whitespace to enable callers to format their
// string literals in a way that fits nicely in the surrounding code.
//
// On a mismatch, both sides are diffed using the command in $GOBITS_DIFF_CMD
// (default: "diff -u --color=always"), and the test is aborted. The test fails
// regardless of the exit status of the diff command.
func (a Assertable) AssertEqual(expected string) {
	a.t.Helper()

	// cleanup indentation and empty lines in `expected`
	expected = strings.TrimSpace(expected) + "\n"
	expected = whitespaceAtStartOfLineRx.ReplaceAllString(expected, "")
	// cleanup empty lines in `actual` (runs of any length, not just pairs)
	actual := emptyLinesRx.ReplaceAllString(a.payload, "")

	// quick path: if both are equal, we're fine
	if expected == actual {
		return
	}

	// slow path: show a diff, then fail
	// (t.TempDir is removed automatically when the test finishes)
	tmpDir := a.t.TempDir()
	actualPath := filepath.Join(tmpDir, "actual")
	failOnErr(a.t, os.WriteFile(actualPath, []byte(actual), 0o666))
	expectedPath := filepath.Join(tmpDir, "expected")
	failOnErr(a.t, os.WriteFile(expectedPath, []byte(expected), 0o666))

	diffCmd := osext.GetenvOrDefault("GOBITS_DIFF_CMD", "diff")
	var diffArgs []string
	if diffCmd == "diff" {
		diffArgs = append(diffArgs, "-u", "--color=always")
	}
	diffArgs = append(diffArgs, expectedPath, actualPath)
	cmd := exec.Command(diffCmd, diffArgs...)
	cmd.Stdin = nil
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	failOnErr(a.t, cmd.Run())

	// We already know that the contents differ. A custom diff tool may still
	// exit 0, so the failure must not depend on its exit status.
	a.t.Fatal("actual DB content does not match the expected content (see diff output)")
}
// AssertEqualf formats its arguments with fmt.Sprintf and asserts that the
// SQL statements equal the result. It is shorthand for
// AssertEqual(fmt.Sprintf(format, args...)).
func (a Assertable) AssertEqualf(format string, args ...any) {
	a.t.Helper()
	expected := fmt.Sprintf(format, args...)
	a.AssertEqual(expected)
}
// AssertEmpty asserts that there are no SQL statements at all (apart from
// whitespace). It is shorthand for AssertEqual("").
func (a Assertable) AssertEmpty() {
	a.t.Helper()
	a.AssertEqual(``)
}
// Ignore deliberately does nothing. Calling it, e.g. as
// `tr.DBChanges().Ignore()`, documents that a set of DB changes is not
// asserted on, while still advancing the Tracker's baseline snapshot.
func (Assertable) Ignore() {}
// failOnErr aborts the current test via t.Fatal when err is non-nil; a nil
// error is a no-op.
func failOnErr(t *testing.T, err error) {
	t.Helper()
	if err == nil {
		return
	}
	t.Fatal(err)
}