/
assertions.go
287 lines (258 loc) · 9.23 KB
/
assertions.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
package assertions
import (
"errors"
"fmt"
"path/filepath"
"reflect"
"runtime"
"strings"
"github.com/d4l3k/messagediff"
"github.com/prysmaticlabs/prysm/v3/encoding/ssz/equality"
"github.com/sirupsen/logrus/hooks/test"
"google.golang.org/protobuf/proto"
)
type (
	// AssertionTestingTB is the subset of testing.TB the assertion helpers need:
	// a non-fatal Errorf reporter and a fatal Fatalf reporter.
	AssertionTestingTB interface {
		Fatalf(format string, args ...interface{})
		Errorf(format string, args ...interface{})
	}

	// assertionLoggerFn is the failure-reporting callback every assertion takes.
	// Its signature matches AssertionTestingTB's methods, so either method value
	// (Errorf or Fatalf) can be passed directly.
	assertionLoggerFn func(string, ...interface{})
)
// Equal reports a failure through loggerFn when expected and actual differ under
// Go's == operator. Because both operands are interface{} values, the comparison
// panics at runtime if they hold the same non-comparable dynamic type (e.g. slices).
func Equal(loggerFn assertionLoggerFn, expected, actual interface{}, msg ...interface{}) {
	if expected == actual {
		return
	}
	// Skip this frame and its direct caller so the location points further up the stack.
	_, file, line, _ := runtime.Caller(2)
	errMsg := parseMsg("Values are not equal", msg...)
	loggerFn("%s:%d %s, want: %[4]v (%[4]T), got: %[5]v (%[5]T)", filepath.Base(file), line, errMsg, expected, actual)
}
// NotEqual reports a failure through loggerFn when expected and actual are equal
// under Go's == operator (the inverse of Equal).
func NotEqual(loggerFn assertionLoggerFn, expected, actual interface{}, msg ...interface{}) {
	if expected != actual {
		return
	}
	// Skip this frame and its direct caller so the location points further up the stack.
	_, file, line, _ := runtime.Caller(2)
	errMsg := parseMsg("Values are equal", msg...)
	loggerFn("%s:%d %s, both values are equal: %[4]v (%[4]T)", filepath.Base(file), line, errMsg, expected)
}
// DeepEqual reports a failure through loggerFn unless isDeepEqual considers
// expected and actual equal. The failure message includes both values in Go
// syntax plus a messagediff pretty diff.
func DeepEqual(loggerFn assertionLoggerFn, expected, actual interface{}, msg ...interface{}) {
	if isDeepEqual(expected, actual) {
		return
	}
	_, file, line, _ := runtime.Caller(2)
	// Only the textual diff is used; PrettyDiff's second result is ignored.
	diff, _ := messagediff.PrettyDiff(expected, actual)
	errMsg := parseMsg("Values are not equal", msg...)
	loggerFn("%s:%d %s, want: %#v, got: %#v, diff: %s", filepath.Base(file), line, errMsg, expected, actual, diff)
}
// DeepNotEqual reports a failure through loggerFn when isDeepEqual considers
// expected and actual equal (the inverse of DeepEqual).
func DeepNotEqual(loggerFn assertionLoggerFn, expected, actual interface{}, msg ...interface{}) {
	if !isDeepEqual(expected, actual) {
		return
	}
	_, file, line, _ := runtime.Caller(2)
	errMsg := parseMsg("Values are equal", msg...)
	loggerFn("%s:%d %s, want: %#v, got: %#v", filepath.Base(file), line, errMsg, expected, actual)
}
// DeepSSZEqual reports a failure through loggerFn unless equality.DeepEqual
// (from the encoding/ssz/equality package) considers expected and actual equal.
// The failure message includes a messagediff pretty diff.
func DeepSSZEqual(loggerFn assertionLoggerFn, expected, actual interface{}, msg ...interface{}) {
	if equality.DeepEqual(expected, actual) {
		return
	}
	_, file, line, _ := runtime.Caller(2)
	// Only the textual diff is used; PrettyDiff's second result is ignored.
	diff, _ := messagediff.PrettyDiff(expected, actual)
	errMsg := parseMsg("Values are not equal", msg...)
	loggerFn("%s:%d %s, want: %#v, got: %#v, diff: %s", filepath.Base(file), line, errMsg, expected, actual, diff)
}
// DeepNotSSZEqual reports a failure through loggerFn when equality.DeepEqual
// considers expected and actual equal (the inverse of DeepSSZEqual).
func DeepNotSSZEqual(loggerFn assertionLoggerFn, expected, actual interface{}, msg ...interface{}) {
	if !equality.DeepEqual(expected, actual) {
		return
	}
	_, file, line, _ := runtime.Caller(2)
	errMsg := parseMsg("Values are equal", msg...)
	loggerFn("%s:%d %s, want: %#v, got: %#v", filepath.Base(file), line, errMsg, expected, actual)
}
// StringContains checks substring presence. With flag=true it reports a failure
// when expected is NOT a substring of actual; with flag=false it reports a
// failure when expected IS a substring of actual.
func StringContains(loggerFn assertionLoggerFn, expected, actual string, flag bool, msg ...interface{}) {
	found := strings.Contains(actual, expected)
	switch {
	case flag && !found:
		errMsg := parseMsg("Expected substring is not found", msg...)
		_, file, line, _ := runtime.Caller(2)
		loggerFn("%s:%d %s, got: %v, want: %s", filepath.Base(file), line, errMsg, actual, expected)
	case !flag && found:
		errMsg := parseMsg("Unexpected substring is found", msg...)
		_, file, line, _ := runtime.Caller(2)
		loggerFn("%s:%d %s, got: %v, not want: %s", filepath.Base(file), line, errMsg, actual, expected)
	}
}
// NoError reports a failure through loggerFn, including the error text, when
// err is non-nil.
func NoError(loggerFn assertionLoggerFn, err error, msg ...interface{}) {
	if err == nil {
		return
	}
	_, file, line, _ := runtime.Caller(2)
	loggerFn("%s:%d %s: %v", filepath.Base(file), line, parseMsg("Unexpected error", msg...), err)
}
// ErrorIs reports a failure through loggerFn unless errors.Is finds target in
// err's wrap chain. By default the failure text names the missing target.
func ErrorIs(loggerFn assertionLoggerFn, err, target error, msg ...interface{}) {
	if errors.Is(err, target) {
		return
	}
	defaultMsg := fmt.Sprintf("error %s", target)
	errMsg := parseMsg(defaultMsg, msg...)
	_, file, line, _ := runtime.Caller(2)
	loggerFn("%s:%d %s: %v", filepath.Base(file), line, errMsg, err)
}
// ErrorContains asserts that err is non-nil and that its message contains want.
//
// An empty want is rejected up front, because every string contains "" and the
// check would otherwise pass for any non-nil error. The function returns right
// after reporting that misuse, so a non-fatal loggerFn does not also emit a
// second, misleading "Expected error not returned" failure.
func ErrorContains(loggerFn assertionLoggerFn, want string, err error, msg ...interface{}) {
	if want == "" {
		_, file, line, _ := runtime.Caller(2)
		loggerFn("%s:%d Want string can't be empty", filepath.Base(file), line)
		return
	}
	if err == nil || !strings.Contains(err.Error(), want) {
		errMsg := parseMsg("Expected error not returned", msg...)
		_, file, line, _ := runtime.Caller(2)
		loggerFn("%s:%d %s, got: %v, want: %s", filepath.Base(file), line, errMsg, err, want)
	}
}
// NotNil reports a failure through loggerFn when obj is nil according to isNil.
// isNil also treats an interface wrapping a nil pointer, map, slice, chan, func,
// interface or unsafe pointer as nil.
func NotNil(loggerFn assertionLoggerFn, obj interface{}, msg ...interface{}) {
	if !isNil(obj) {
		return
	}
	_, file, line, _ := runtime.Caller(2)
	loggerFn("%s:%d %s", filepath.Base(file), line, parseMsg("Unexpected nil value", msg...))
}
// isNil reports whether obj is an untyped nil, or an interface holding a nil
// value of a nillable kind (chan, func, interface, map, pointer, slice or
// unsafe pointer). Values of any other kind are never nil.
func isNil(obj interface{}) bool {
	if obj == nil {
		return true
	}
	rv := reflect.ValueOf(obj)
	nillable := false
	switch rv.Kind() {
	case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice, reflect.UnsafePointer:
		nillable = true
	}
	// IsNil panics on non-nillable kinds, so it is only consulted when nillable.
	return nillable && rv.IsNil()
}
// LogsContain checks whether want appears in any log entry captured by hook,
// either in the entry's formatted string or in one of its string-valued fields.
// With flag=true a missing match is reported; with flag=false a present match is
// reported. A failure is sent to loggerFn along with every formatted entry searched.
// msg optionally overrides the default failure text (see parseMsg).
func LogsContain(loggerFn assertionLoggerFn, hook *test.Hook, want string, flag bool, msg ...interface{}) {
	_, file, line, _ := runtime.Caller(2)
	entries := hook.AllEntries()
	logs := make([]string, 0, len(entries))
	match := false
	for _, e := range entries {
		// Named entryStr so it does not shadow the variadic msg parameter.
		entryStr, err := e.String()
		if err != nil {
			loggerFn("%s:%d Failed to format log entry to string: %v", filepath.Base(file), line, err)
			return
		}
		if strings.Contains(entryStr, want) {
			match = true
		}
		for _, field := range e.Data {
			if fieldStr, ok := field.(string); ok && strings.Contains(fieldStr, want) {
				match = true
			}
		}
		logs = append(logs, entryStr)
	}
	// Decide on failure from flag and match alone, not from the message text, so a
	// custom msg that formats to an empty string cannot suppress a genuine failure.
	var errMsg string
	switch {
	case flag && !match:
		errMsg = parseMsg("Expected log not found", msg...)
	case !flag && match:
		errMsg = parseMsg("Unexpected log found", msg...)
	default:
		return
	}
	loggerFn("%s:%d %s: %v\nSearched logs:\n%v", filepath.Base(file), line, errMsg, want, logs)
}
// parseMsg builds an assertion's failure text. If msg is empty, or its first
// element is not a string, defaultMsg is returned. Otherwise msg[0] is used as
// a fmt format string and the remaining elements as its arguments.
func parseMsg(defaultMsg string, msg ...interface{}) string {
	if len(msg) == 0 {
		return defaultMsg
	}
	format, isString := msg[0].(string)
	if !isString {
		return defaultMsg
	}
	return fmt.Sprintf(format, msg[1:]...)
}
// isDeepEqual reports whether expected and actual are deeply equal. When both
// values implement proto.Message they are compared with proto.Equal. Every other
// combination falls back to reflect.DeepEqual.
//
// Both operands are checked with comma-ok assertions, so comparing a
// proto.Message with a non-proto value (including an untyped nil) returns false
// instead of panicking on a failed type assertion.
func isDeepEqual(expected, actual interface{}) bool {
	expectedMsg, expectedIsProto := expected.(proto.Message)
	actualMsg, actualIsProto := actual.(proto.Message)
	if expectedIsProto && actualIsProto {
		return proto.Equal(expectedMsg, actualMsg)
	}
	return reflect.DeepEqual(expected, actual)
}
// NotEmpty reports, through loggerFn, fields of obj that are empty/zero. It
// descends recursively into pointer, struct and slice fields (see notEmpty).
// For proto.Message values, struct fields that carry no tag are skipped.
func NotEmpty(loggerFn assertionLoggerFn, obj interface{}, msg ...interface{}) {
	skipUntagged := false
	if _, isProto := obj.(proto.Message); isProto {
		skipUntagged = true
	}
	// notEmpty must be called directly from here: its reported caller location is
	// computed as a fixed number of frames above this one. A nil field path and a
	// zero stack depth mark the root call.
	notEmpty(loggerFn, obj, skipUntagged, nil, 0, msg...)
}
// notEmpty checks that obj is not zero. For structs, it checks that no field is
// zero: it follows pointer fields into the values they reference and checks every
// element of slice fields. Each empty/zero value is reported through loggerFn with
// its dot-joined field path.
//
// ignoreFieldsWithoutTags skips struct fields that have no struct tag, which is
// helpful for checking protobuf messages that have internal fields. fields is the
// path walked so far (empty on the root call). stackSize counts the recursive frames
// so the reported file:line is taken from the same stack position at every depth.
//
// A nil pointer (or an untyped nil obj) yields an invalid reflect.Value. It is
// reported as empty rather than passed to reflect.Value.IsZero or Type, both of
// which panic on an invalid Value.
func notEmpty(loggerFn assertionLoggerFn, obj interface{}, ignoreFieldsWithoutTags bool, fields []string, stackSize int, msg ...interface{}) {
	var v reflect.Value
	if vo, ok := obj.(reflect.Value); ok {
		// Recursive calls pass field/element values as reflect.Value.
		v = reflect.Indirect(vo)
	} else {
		v = reflect.Indirect(reflect.ValueOf(obj))
	}
	if len(fields) == 0 {
		if v.IsValid() {
			fields = []string{v.Type().Name()}
		} else {
			// No underlying value to name; fall back to obj's dynamic type.
			fields = []string{fmt.Sprintf("%T", obj)}
		}
	}
	fail := func(fields []string) {
		m := parseMsg("", msg...)
		errMsg := fmt.Sprintf("empty/zero field: %s", strings.Join(fields, "."))
		if len(m) > 0 {
			m = strings.Join([]string{m, errMsg}, ": ")
		} else {
			m = errMsg
		}
		// Frames: this closure, stackSize+1 notEmpty frames, NotEmpty, then two more
		// callers. Presumably that lands on the test calling an assert/require
		// wrapper — TODO confirm against the wrappers.
		_, file, line, _ := runtime.Caller(4 + stackSize)
		loggerFn("%s:%d %s", filepath.Base(file), line, m)
	}
	if !v.IsValid() {
		fail(fields)
		return
	}
	if v.Kind() != reflect.Struct {
		if v.IsZero() {
			fail(fields)
		}
		return
	}
	for i := 0; i < v.NumField(); i++ {
		if ignoreFieldsWithoutTags && len(v.Type().Field(i).Tag) == 0 {
			continue
		}
		fields := append(fields, v.Type().Field(i).Name)
		switch k := v.Field(i).Kind(); k {
		case reflect.Ptr:
			notEmpty(loggerFn, v.Field(i), ignoreFieldsWithoutTags, fields, stackSize+1, msg...)
		case reflect.Slice:
			f := v.Field(i)
			if f.Len() == 0 {
				fail(fields)
			}
			for j := 0; j < f.Len(); j++ {
				notEmpty(loggerFn, f.Index(j), ignoreFieldsWithoutTags, fields, stackSize+1, msg...)
			}
		default:
			if v.Field(i).IsZero() {
				fail(fields)
			}
		}
	}
}
// TBMock is a test double with the same Errorf/Fatalf methods as
// AssertionTestingTB. It records the most recent formatted message of each kind
// instead of failing a test.
type TBMock struct {
	ErrorfMsg string
	FatalfMsg string
}

// Errorf formats format/args with fmt.Sprintf and stores the result in
// ErrorfMsg, overwriting any previous value.
func (m *TBMock) Errorf(format string, args ...interface{}) {
	formatted := fmt.Sprintf(format, args...)
	m.ErrorfMsg = formatted
}

// Fatalf formats format/args with fmt.Sprintf and stores the result in
// FatalfMsg, overwriting any previous value. It does not stop execution.
func (m *TBMock) Fatalf(format string, args ...interface{}) {
	formatted := fmt.Sprintf(format, args...)
	m.FatalfMsg = formatted
}