-
Notifications
You must be signed in to change notification settings - Fork 458
/
assertions.go
170 lines (146 loc) · 4.67 KB
/
assertions.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
package spiretest
import (
"fmt"
"reflect"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/testing/protocmp"
)
// protoMessageType is the reflect.Type of the proto.Message interface itself
// (obtained from a typed nil pointer to the interface). AssertProtoListEqual
// uses it to check that slice element types implement proto.Message.
var protoMessageType = reflect.TypeOf((*proto.Message)(nil)).Elem()
// RequireErrorContains is the fatal variant of AssertErrorContains: when the
// assertion fails it calls tb.FailNow, ending the current test.
func RequireErrorContains(tb testing.TB, err error, contains string) {
	tb.Helper()
	if ok := AssertErrorContains(tb, err, contains); ok {
		return
	}
	tb.FailNow()
}
// AssertErrorContains checks that err is non-nil and that err.Error()
// contains the given substring. Failures are recorded on tb via testify's
// assert package. It returns true only if both checks pass.
func AssertErrorContains(tb testing.TB, err error, contains string) bool {
	tb.Helper()
	// The && short-circuits, so err.Error() is never called on a nil err.
	return assert.Error(tb, err) && assert.Contains(tb, err.Error(), contains)
}
// RequireGRPCStatus is the fatal variant of AssertGRPCStatus: when the
// status code or message does not match it calls tb.FailNow.
func RequireGRPCStatus(tb testing.TB, err error, code codes.Code, message string) {
	tb.Helper()
	if AssertGRPCStatus(tb, err, code, message) {
		return
	}
	tb.FailNow()
}
// AssertGRPCStatus converts err with status.Convert and checks that the
// resulting code and message both equal the expected values exactly. On a
// mismatch it records a failure describing actual and expected values and
// returns false.
func AssertGRPCStatus(tb testing.TB, err error, code codes.Code, message string) bool {
	tb.Helper()
	st := status.Convert(err)
	gotCode, gotMsg := st.Code(), st.Message()
	if gotCode == code && gotMsg == message {
		return true
	}
	return assert.Fail(tb, fmt.Sprintf("Status code=%q msg=%q does not match code=%q msg=%q", gotCode, gotMsg, code, message))
}
// RequireGRPCStatusContains is the fatal variant of AssertGRPCStatusContains.
// msgAndArgs are passed through unchanged to annotate a failure.
func RequireGRPCStatusContains(tb testing.TB, err error, code codes.Code, contains string, msgAndArgs ...any) {
	tb.Helper()
	if ok := AssertGRPCStatusContains(tb, err, code, contains, msgAndArgs...); !ok {
		tb.FailNow()
	}
}
// AssertGRPCStatusContains converts err with status.Convert and checks that
// the resulting code equals code and the message contains the substring
// contains. Optional msgAndArgs are forwarded to testify to annotate a
// failure. It returns false on failure.
//
// When code is codes.OK, a non-empty contains is rejected as a misuse of the
// helper. With an empty contains, the check is delegated to AssertGRPCStatus
// with an empty expected message. msgAndArgs are not forwarded on that path
// because AssertGRPCStatus does not accept them.
func AssertGRPCStatusContains(tb testing.TB, err error, code codes.Code, contains string, msgAndArgs ...any) bool {
	tb.Helper()
	if code == codes.OK {
		if contains != "" {
			// assert.Fail does not use msgAndArgs to format failureMessage, so
			// the %q verb must be expanded here with fmt.Sprintf.
			return assert.Fail(tb, fmt.Sprintf("cannot assert that an OK status has message %q", contains), msgAndArgs...)
		}
		return AssertGRPCStatus(tb, err, code, "")
	}
	st := status.Convert(err)
	if code != st.Code() || !strings.Contains(st.Message(), contains) {
		return assert.Fail(tb, fmt.Sprintf("Status code=%q msg=%q does not match code=%q with message containing %q", st.Code(), st.Message(), code, contains), msgAndArgs...)
	}
	return true
}
// RequireGRPCStatusHasPrefix is the fatal variant of
// AssertGRPCStatusHasPrefix: on failure it calls tb.FailNow.
func RequireGRPCStatusHasPrefix(tb testing.TB, err error, code codes.Code, prefix string) {
	tb.Helper()
	if AssertGRPCStatusHasPrefix(tb, err, code, prefix) {
		return
	}
	tb.FailNow()
}
// AssertGRPCStatusHasPrefix converts err with status.Convert and checks that
// the resulting code equals code and the message starts with prefix. On a
// mismatch it records a failure and returns false.
func AssertGRPCStatusHasPrefix(tb testing.TB, err error, code codes.Code, prefix string) bool {
	tb.Helper()
	st := status.Convert(err)
	gotCode, gotMsg := st.Code(), st.Message()
	if gotCode == code && strings.HasPrefix(gotMsg, prefix) {
		return true
	}
	return assert.Fail(tb, fmt.Sprintf("Status code=%q msg=%q does not match code=%q with message prefix %q", gotCode, gotMsg, code, prefix))
}
// RequireProtoListEqual is the fatal variant of AssertProtoListEqual: when
// the lists differ it calls tb.FailNow.
func RequireProtoListEqual(tb testing.TB, expected, actual any) {
	tb.Helper()
	if ok := AssertProtoListEqual(tb, expected, actual); !ok {
		tb.FailNow()
	}
}
// AssertProtoListEqual checks that expected and actual are both slices whose
// element types implement proto.Message. It then checks that they have the
// same length and that the elements at each index are equal according to
// AssertProtoEqual. It stops and returns false at the first failing check.
//
// The parameters are typed as any so that slices of any message element type
// can be passed without conversion. A non-slice argument, including an
// untyped nil, is reported as a failure rather than causing a panic.
func AssertProtoListEqual(tb testing.TB, expected, actual any) bool {
	tb.Helper()
	ev := reflect.ValueOf(expected)
	av := reflect.ValueOf(actual)

	// Kind is read from the Value rather than from Value.Type(): an untyped
	// nil produces the zero Value, whose Kind is Invalid but whose Type()
	// panics.
	if ev.Kind() != reflect.Slice {
		return assert.Fail(tb, "expected value is not a slice")
	}
	if !ev.Type().Elem().Implements(protoMessageType) {
		return assert.Fail(tb, "expected value is not a slice of elements that implement proto.Message")
	}
	if av.Kind() != reflect.Slice {
		return assert.Fail(tb, "actual value is not a slice")
	}
	if !av.Type().Elem().Implements(protoMessageType) {
		return assert.Fail(tb, "actual value is not a slice of elements that implement proto.Message")
	}
	if !assert.Equal(tb, ev.Len(), av.Len(), "expected %d elements in list; got %d", ev.Len(), av.Len()) {
		return false
	}
	for i := 0; i < ev.Len(); i++ {
		// Comma-ok form: when the element type is an interface (for example
		// []proto.Message), a nil element comes back from Interface() as a
		// nil interface value. A plain type assertion on that value panics.
		e, _ := ev.Index(i).Interface().(proto.Message)
		a, _ := av.Index(i).Interface().(proto.Message)
		if !AssertProtoEqual(tb, e, a, "proto %d in list is not equal", i) {
			return false
		}
	}
	return true
}
// RequireProtoEqual is the fatal variant of AssertProtoEqual. msgAndArgs are
// forwarded unchanged, and tb.FailNow is called if the messages differ.
func RequireProtoEqual(tb testing.TB, expected, actual proto.Message, msgAndArgs ...any) {
	tb.Helper()
	if AssertProtoEqual(tb, expected, actual, msgAndArgs...) {
		return
	}
	tb.FailNow()
}
// AssertProtoEqual compares expected and actual with cmp.Diff using the
// protocmp.Transform() option. It records a failure, including the non-empty
// diff, when they differ. msgAndArgs annotate the failure. It returns true
// when the diff is empty.
func AssertProtoEqual(tb testing.TB, expected, actual proto.Message, msgAndArgs ...any) bool {
	tb.Helper()
	diff := cmp.Diff(expected, actual, protocmp.Transform())
	return assert.Empty(tb, diff, msgAndArgs...)
}
// RequireErrorPrefix is the fatal variant of AssertErrorPrefix: it calls
// tb.FailNow when err is nil or lacks the prefix.
func RequireErrorPrefix(tb testing.TB, err error, prefix string) {
	tb.Helper()
	if ok := AssertErrorPrefix(tb, err, prefix); !ok {
		tb.FailNow()
	}
}
// AssertErrorPrefix checks that err is non-nil and that err.Error() begins
// with prefix. A nil err is treated as a failure. It returns false after
// recording the failure on tb.
func AssertErrorPrefix(tb testing.TB, err error, prefix string) bool {
	tb.Helper()
	if err != nil && strings.HasPrefix(err.Error(), prefix) {
		return true
	}
	return assert.Fail(tb, fmt.Sprintf("error %v does not have prefix %q", err, prefix))
}
// AssertHasPrefix checks that msg begins with prefix. On failure it records a
// message quoting both strings and returns false.
func AssertHasPrefix(tb testing.TB, msg string, prefix string) bool {
	tb.Helper()
	if strings.HasPrefix(msg, prefix) {
		return true
	}
	return assert.Fail(tb, fmt.Sprintf("string %q does not have prefix %q", msg, prefix))
}