Skip to content

Commit

Permalink
Use go-cmp.Equal instead of reflect.DeepEqual
Browse files Browse the repository at this point in the history
  • Loading branch information
posener committed Mar 14, 2020
1 parent f6cbfc0 commit 64d5d85
Show file tree
Hide file tree
Showing 5 changed files with 200 additions and 119 deletions.
83 changes: 9 additions & 74 deletions assert/assertions.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import (
"unicode/utf8"

"github.com/davecgh/go-spew/spew"
"github.com/pmezard/go-difflib/difflib"
yaml "gopkg.in/yaml.v2"
)

Expand Down Expand Up @@ -59,20 +58,7 @@ func ObjectsAreEqual(expected, actual interface{}) bool {
if expected == nil || actual == nil {
return expected == actual
}

exp, ok := expected.([]byte)
if !ok {
return reflect.DeepEqual(expected, actual)
}

act, ok := actual.([]byte)
if !ok {
return false
}
if exp == nil || act == nil {
return exp == nil && act == nil
}
return bytes.Equal(exp, act)
return cmpEqual(expected, actual)
}

// ObjectsAreEqualValues gets whether two objects are equal, or if their
Expand All @@ -89,7 +75,7 @@ func ObjectsAreEqualValues(expected, actual interface{}) bool {
expectedValue := reflect.ValueOf(expected)
if expectedValue.IsValid() && expectedValue.Type().ConvertibleTo(actualType) {
// Attempt comparison after type conversion
return reflect.DeepEqual(expectedValue.Convert(actualType).Interface(), actual)
return cmpEqual(expectedValue.Convert(actualType).Interface(), actual)
}

return false
Expand All @@ -110,7 +96,7 @@ func CallerInfo() []string {
var line int
var name string

callers := []string{}
var callers []string
for i := 0; ; i++ {
pc, file, line, ok = runtime.Caller(i)
if !ok {
Expand Down Expand Up @@ -341,7 +327,7 @@ func Equal(t TestingT, expected, actual interface{}, msgAndArgs ...interface{})
}

if !ObjectsAreEqual(expected, actual) {
diff := diff(expected, actual)
diff := cmpDiff(expected, actual)
expected, actual = formatUnequalValues(expected, actual)
return Fail(t, fmt.Sprintf("Not equal: \n"+
"expected: %s\n"+
Expand Down Expand Up @@ -462,7 +448,7 @@ func EqualValues(t TestingT, expected, actual interface{}, msgAndArgs ...interfa
}

if !ObjectsAreEqualValues(expected, actual) {
diff := diff(expected, actual)
diff := cmpDiff(expected, actual)
expected, actual = formatUnequalValues(expected, actual)
return Fail(t, fmt.Sprintf("Not equal: \n"+
"expected: %s\n"+
Expand Down Expand Up @@ -575,7 +561,7 @@ func isEmpty(object interface{}) bool {
// for all other types, compare against the zero value
default:
zero := reflect.Zero(objValue.Type())
return reflect.DeepEqual(object, zero.Interface())
return cmpEqual(object, zero.Interface())
}
}

Expand Down Expand Up @@ -1369,7 +1355,7 @@ func matchRegexp(rx interface{}, str interface{}) bool {
r = regexp.MustCompile(fmt.Sprint(rx))
}

return (r.FindStringIndex(fmt.Sprint(str)) != nil)
return r.FindStringIndex(fmt.Sprint(str)) != nil

}

Expand Down Expand Up @@ -1414,7 +1400,7 @@ func Zero(t TestingT, i interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if i != nil && !reflect.DeepEqual(i, reflect.Zero(reflect.TypeOf(i)).Interface()) {
if i != nil && !cmpEqual(i, reflect.Zero(reflect.TypeOf(i)).Interface()) {
return Fail(t, fmt.Sprintf("Should be zero, but was %v", i), msgAndArgs...)
}
return true
Expand All @@ -1425,7 +1411,7 @@ func NotZero(t TestingT, i interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if i == nil || reflect.DeepEqual(i, reflect.Zero(reflect.TypeOf(i)).Interface()) {
if i == nil || cmpEqual(i, reflect.Zero(reflect.TypeOf(i)).Interface()) {
return Fail(t, fmt.Sprintf("Should not be zero, but was %v", i), msgAndArgs...)
}
return true
Expand Down Expand Up @@ -1542,57 +1528,6 @@ func YAMLEq(t TestingT, expected string, actual string, msgAndArgs ...interface{
return Equal(t, expectedYAMLAsInterface, actualYAMLAsInterface, msgAndArgs...)
}

func typeAndKind(v interface{}) (reflect.Type, reflect.Kind) {
t := reflect.TypeOf(v)
k := t.Kind()

if k == reflect.Ptr {
t = t.Elem()
k = t.Kind()
}
return t, k
}

// diff returns a diff of both values as long as both are of the same type and
// are a struct, map, slice, array or string. Otherwise it returns an empty string.
func diff(expected interface{}, actual interface{}) string {
if expected == nil || actual == nil {
return ""
}

et, ek := typeAndKind(expected)
at, _ := typeAndKind(actual)

if et != at {
return ""
}

if ek != reflect.Struct && ek != reflect.Map && ek != reflect.Slice && ek != reflect.Array && ek != reflect.String {
return ""
}

var e, a string
if et != reflect.TypeOf("") {
e = spewConfig.Sdump(expected)
a = spewConfig.Sdump(actual)
} else {
e = reflect.ValueOf(expected).String()
a = reflect.ValueOf(actual).String()
}

diff, _ := difflib.GetUnifiedDiffString(difflib.UnifiedDiff{
A: difflib.SplitLines(e),
B: difflib.SplitLines(a),
FromFile: "Expected",
FromDate: "",
ToFile: "Actual",
ToDate: "",
Context: 1,
})

return "\n\nDiff:\n" + diff
}

func isFunction(arg interface{}) bool {
if arg == nil {
return false
Expand Down
81 changes: 36 additions & 45 deletions assert/assertions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1419,7 +1419,7 @@ func TestInDeltaMapValues(t *testing.T) {
f: False,
},
} {
tc.f(t, InDeltaMapValues(mockT, tc.expect, tc.actual, tc.delta), tc.title+"\n"+diff(tc.expect, tc.actual))
tc.f(t, InDeltaMapValues(mockT, tc.expect, tc.actual, tc.delta), tc.title+"\n"+cmpDiff(tc.expect, tc.actual))
}
}

Expand Down Expand Up @@ -1855,13 +1855,11 @@ func TestDiff(t *testing.T) {
Diff:
--- Expected
+++ Actual
@@ -1,3 +1,3 @@
(struct { foo string }) {
- foo: (string) (len=5) "hello"
+ foo: (string) (len=3) "bar"
}
root.foo:
-: "hello"
+: "bar"
`
actual := diff(
actual := cmpDiff(
struct{ foo string }{"hello"},
struct{ foo string }{"bar"},
)
Expand All @@ -1872,16 +1870,11 @@ Diff:
Diff:
--- Expected
+++ Actual
@@ -2,5 +2,5 @@
(int) 1,
- (int) 2,
(int) 3,
- (int) 4
+ (int) 5,
+ (int) 7
}
{[]int}:
-: []int{1, 2, 3, 4}
+: []int{1, 3, 5, 7}
`
actual = diff(
actual = cmpDiff(
[]int{1, 2, 3, 4},
[]int{1, 3, 5, 7},
)
Expand All @@ -1892,15 +1885,11 @@ Diff:
Diff:
--- Expected
+++ Actual
@@ -2,4 +2,4 @@
(int) 1,
- (int) 2,
- (int) 3
+ (int) 3,
+ (int) 5
}
{[]int}:
-: []int{1, 2, 3}
+: []int{1, 3, 5}
`
actual = diff(
actual = cmpDiff(
[]int{1, 2, 3, 4}[0:3],
[]int{1, 3, 5, 7}[0:3],
)
Expand All @@ -1911,19 +1900,21 @@ Diff:
Diff:
--- Expected
+++ Actual
@@ -1,6 +1,6 @@
(map[string]int) (len=4) {
- (string) (len=4) "four": (int) 4,
+ (string) (len=4) "five": (int) 5,
(string) (len=3) "one": (int) 1,
- (string) (len=5) "three": (int) 3,
- (string) (len=3) "two": (int) 2
+ (string) (len=5) "seven": (int) 7,
+ (string) (len=5) "three": (int) 3
}
{map[string]int}["five"]:
-: <non-existent>
+: 5
{map[string]int}["four"]:
-: 4
+: <non-existent>
{map[string]int}["seven"]:
-: <non-existent>
+: 7
{map[string]int}["two"]:
-: 2
+: <non-existent>
`

actual = diff(
actual = cmpDiff(
map[string]int{"one": 1, "two": 2, "three": 3, "four": 4},
map[string]int{"one": 1, "three": 3, "five": 5, "seven": 7},
)
Expand All @@ -1941,7 +1932,7 @@ Diff:
})
`

actual = diff(
actual = cmpDiff(
errors.New("some expected error"),
errors.New("actual error"),
)
Expand All @@ -1959,7 +1950,7 @@ Diff:
}
`

actual = diff(
actual = cmpDiff(
diffTestingStruct{A: "some string", B: 10},
diffTestingStruct{A: "some string", B: 15},
)
Expand All @@ -1976,12 +1967,12 @@ func TestTimeEqualityErrorFormatting(t *testing.T) {
}

func TestDiffEmptyCases(t *testing.T) {
Equal(t, "", diff(nil, nil))
Equal(t, "", diff(struct{ foo string }{}, nil))
Equal(t, "", diff(nil, struct{ foo string }{}))
Equal(t, "", diff(1, 2))
Equal(t, "", diff(1, 2))
Equal(t, "", diff([]int{1}, []bool{true}))
Equal(t, "", cmpDiff(nil, nil))
Equal(t, "", cmpDiff(struct{ foo string }{}, nil))
Equal(t, "", cmpDiff(nil, struct{ foo string }{}))
Equal(t, "", cmpDiff(1, 2))
Equal(t, "", cmpDiff(1, 2))
Equal(t, "", cmpDiff([]int{1}, []bool{true}))
}

// Ensure there are no data races
Expand All @@ -2007,7 +1998,7 @@ func TestDiffRace(t *testing.T) {
rChans[idx] = make(chan string)
go func(ch chan string) {
defer close(ch)
ch <- diff(expected, actual)
ch <- cmpDiff(expected, actual)
}(rChans[idx])
}

Expand Down Expand Up @@ -2064,7 +2055,7 @@ func TestBytesEqual(t *testing.T) {
{nil, make([]byte, 0)},
}
for i, c := range cases {
Equal(t, reflect.DeepEqual(c.a, c.b), ObjectsAreEqual(c.a, c.b), "case %d failed", i+1)
Equal(t, cmpEqual(c.a, c.b), ObjectsAreEqual(c.a, c.b), "case %d failed", i+1)
}
}

Expand Down
Loading

0 comments on commit 64d5d85

Please sign in to comment.