Skip to content

Commit

Permalink
Use go-cmp.Equal instead of reflect.DeepEqual
Browse files Browse the repository at this point in the history
  • Loading branch information
posener committed Mar 21, 2018
1 parent 20dae58 commit fc29410
Show file tree
Hide file tree
Showing 3 changed files with 181 additions and 94 deletions.
69 changes: 7 additions & 62 deletions assert/assertions.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@ import (
"time"
"unicode"
"unicode/utf8"

"github.com/davecgh/go-spew/spew"
"github.com/pmezard/go-difflib/difflib"
)

//go:generate go run ../_codegen/main.go -output-package=assert -template=assertion_format.go.tmpl
Expand Down Expand Up @@ -51,7 +48,7 @@ func ObjectsAreEqual(expected, actual interface{}) bool {
}
return bytes.Equal(exp, act)
}
return reflect.DeepEqual(expected, actual)
return compare(expected, actual)

}

Expand All @@ -69,7 +66,7 @@ func ObjectsAreEqualValues(expected, actual interface{}) bool {
expectedValue := reflect.ValueOf(expected)
if expectedValue.IsValid() && expectedValue.Type().ConvertibleTo(actualType) {
// Attempt comparison after type conversion
return reflect.DeepEqual(expectedValue.Convert(actualType).Interface(), actual)
return compare(expectedValue.Convert(actualType).Interface(), actual)
}

return false
Expand All @@ -90,7 +87,7 @@ func CallerInfo() []string {
ok := false
name := ""

callers := []string{}
var callers []string
for i := 0; ; i++ {
pc, file, line, ok = runtime.Caller(i)
if !ok {
Expand Down Expand Up @@ -449,7 +446,7 @@ func isEmpty(object interface{}) bool {
// for all other types, compare against the zero value
default:
zero := reflect.Zero(objValue.Type())
return reflect.DeepEqual(object, zero.Interface())
return compare(object, zero.Interface())
}
}

Expand Down Expand Up @@ -1179,7 +1176,7 @@ func matchRegexp(rx interface{}, str interface{}) bool {
r = regexp.MustCompile(fmt.Sprint(rx))
}

return (r.FindStringIndex(fmt.Sprint(str)) != nil)
return r.FindStringIndex(fmt.Sprint(str)) != nil

}

Expand Down Expand Up @@ -1224,7 +1221,7 @@ func Zero(t TestingT, i interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if i != nil && !reflect.DeepEqual(i, reflect.Zero(reflect.TypeOf(i)).Interface()) {
if i != nil && !compare(i, reflect.Zero(reflect.TypeOf(i)).Interface()) {
return Fail(t, fmt.Sprintf("Should be zero, but was %v", i), msgAndArgs...)
}
return true
Expand All @@ -1235,7 +1232,7 @@ func NotZero(t TestingT, i interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if i == nil || reflect.DeepEqual(i, reflect.Zero(reflect.TypeOf(i)).Interface()) {
if i == nil || compare(i, reflect.Zero(reflect.TypeOf(i)).Interface()) {
return Fail(t, fmt.Sprintf("Should not be zero, but was %v", i), msgAndArgs...)
}
return true
Expand Down Expand Up @@ -1297,51 +1294,6 @@ func JSONEq(t TestingT, expected string, actual string, msgAndArgs ...interface{
return Equal(t, expectedJSONAsInterface, actualJSONAsInterface, msgAndArgs...)
}

func typeAndKind(v interface{}) (reflect.Type, reflect.Kind) {
t := reflect.TypeOf(v)
k := t.Kind()

if k == reflect.Ptr {
t = t.Elem()
k = t.Kind()
}
return t, k
}

// diff returns a diff of both values as long as both are of the same type and
// are a struct, map, slice or array. Otherwise it returns an empty string.
func diff(expected interface{}, actual interface{}) string {
if expected == nil || actual == nil {
return ""
}

et, ek := typeAndKind(expected)
at, _ := typeAndKind(actual)

if et != at {
return ""
}

if ek != reflect.Struct && ek != reflect.Map && ek != reflect.Slice && ek != reflect.Array {
return ""
}

e := spewConfig.Sdump(expected)
a := spewConfig.Sdump(actual)

diff, _ := difflib.GetUnifiedDiffString(difflib.UnifiedDiff{
A: difflib.SplitLines(e),
B: difflib.SplitLines(a),
FromFile: "Expected",
FromDate: "",
ToFile: "Actual",
ToDate: "",
Context: 1,
})

return "\n\nDiff:\n" + diff
}

// validateEqualArgs checks whether provided arguments can be safely used in the
// Equal/NotEqual functions.
func validateEqualArgs(expected, actual interface{}) error {
Expand All @@ -1358,13 +1310,6 @@ func isFunction(arg interface{}) bool {
return reflect.TypeOf(arg).Kind() == reflect.Func
}

var spewConfig = spew.ConfigState{
Indent: " ",
DisablePointerAddresses: true,
DisableCapacities: true,
SortKeys: true,
}

type tHelper interface {
Helper()
}
55 changes: 23 additions & 32 deletions assert/assertions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ func TestEqualFormatting(t *testing.T) {
want string
}{
{equalWant: "want", equalGot: "got", want: "\tassertions.go:[0-9]+: \n\t\t\tError Trace:\t\n\t\t\tError: \tNot equal: \n\t\t\t \texpected: \"want\"\n\t\t\t \tactual : \"got\"\n"},
{equalWant: "want", equalGot: "got", msgAndArgs: []interface{}{"hello, %v!", "world"}, want: "\tassertions.go:[0-9]+: \n\t\t\tError Trace:\t\n\t\t\tError: \tNot equal: \n\t\t\t \texpected: \"want\"\n\t\t\t \tactual : \"got\"\n\t\t\tMessages: \thello, world!\n"},
{equalWant: "want", equalGot: "got", msgAndArgs: []interface{}{"hello, %v!", "world"}, want: `\tassertions.go:\d+: \n\t\t\tError Trace:\t\n\t\t\tError: \tNot equal: \n\t\t\t \texpected: "want"\n\t\t\t \tactual : "got"\n\t\t\tMessages: \thello, world!\n`},
} {
mockT := &bufferT{}
Equal(mockT, currCase.equalWant, currCase.equalGot, currCase.msgAndArgs...)
Expand Down Expand Up @@ -1400,11 +1400,9 @@ func TestDiff(t *testing.T) {
Diff:
--- Expected
+++ Actual
@@ -1,3 +1,3 @@
(struct { foo string }) {
- foo: (string) (len=5) "hello"
+ foo: (string) (len=3) "bar"
}
root.foo:
-: "hello"
+: "bar"
`
actual := diff(
struct{ foo string }{"hello"},
Expand All @@ -1417,14 +1415,9 @@ Diff:
Diff:
--- Expected
+++ Actual
@@ -2,5 +2,5 @@
(int) 1,
- (int) 2,
(int) 3,
- (int) 4
+ (int) 5,
+ (int) 7
}
{[]int}:
-: []int{1, 2, 3, 4}
+: []int{1, 3, 5, 7}
`
actual = diff(
[]int{1, 2, 3, 4},
Expand All @@ -1437,13 +1430,9 @@ Diff:
Diff:
--- Expected
+++ Actual
@@ -2,4 +2,4 @@
(int) 1,
- (int) 2,
- (int) 3
+ (int) 3,
+ (int) 5
}
{[]int}:
-: []int{1, 2, 3}
+: []int{1, 3, 5}
`
actual = diff(
[]int{1, 2, 3, 4}[0:3],
Expand All @@ -1456,16 +1445,18 @@ Diff:
Diff:
--- Expected
+++ Actual
@@ -1,6 +1,6 @@
(map[string]int) (len=4) {
- (string) (len=4) "four": (int) 4,
+ (string) (len=4) "five": (int) 5,
(string) (len=3) "one": (int) 1,
- (string) (len=5) "three": (int) 3,
- (string) (len=3) "two": (int) 2
+ (string) (len=5) "seven": (int) 7,
+ (string) (len=5) "three": (int) 3
}
{map[string]int}["five"]:
-: <non-existent>
+: 5
{map[string]int}["four"]:
-: 4
+: <non-existent>
{map[string]int}["seven"]:
-: <non-existent>
+: 7
{map[string]int}["two"]:
-: 2
+: <non-existent>
`

actual = diff(
Expand Down Expand Up @@ -1555,7 +1546,7 @@ func TestBytesEqual(t *testing.T) {
{nil, make([]byte, 0)},
}
for i, c := range cases {
Equal(t, reflect.DeepEqual(c.a, c.b), ObjectsAreEqual(c.a, c.b), "case %d failed", i+1)
Equal(t, compare(c.a, c.b), ObjectsAreEqual(c.a, c.b), "case %d failed", i+1)
}
}

Expand Down
151 changes: 151 additions & 0 deletions assert/compare.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
package assert

import (
"reflect"

"github.com/google/go-cmp/cmp"
)

// compare compares two objects
func compare(expected, actual interface{}) bool {
return cmp.Equal(expected, actual, compareOptions(expected, actual)...)
}

// diff returns a diff of both values as long as both are of the same type and
// are a struct, map, slice or array. Otherwise it returns an empty string.
func diff(expected, actual interface{}) string {
if expected == nil || actual == nil {
return ""
}

et, ek := typeAndKind(expected)
at, _ := typeAndKind(actual)

if et != at {
return ""
}

if ek != reflect.Struct && ek != reflect.Map && ek != reflect.Slice && ek != reflect.Array {
return ""
}

diff := cmp.Diff(expected, actual, compareOptions(expected, actual)...)
if diff != "" {
diff = "\n\nDiff:\n--- Expected\n+++ Actual\n" + diff
}
return diff
}

// compareOptions are cmp.Options used for cmp.Equal and cmp.Diff to compare
// two general objects for testing purposes
func compareOptions(expected, actual interface{}) cmp.Options {
return cmp.Options{
deepAllowUnexported(expected, actual),
compareIdenticalPointers,
}
}

// deepAllowUnexported returns option for cmp.Equal or cmp.Diff in which
// all unexported fields in the two compared types (recursively) are
// allowed.
// Code from https://github.com/google/go-cmp/issues/40 with modification
// to work with cyclic struct
func deepAllowUnexported(vs ...interface{}) cmp.Option {
var (
// allUnexported is a set of types to be added to the unexported list
allUnexported = make(map[reflect.Type]bool)
// visited are list of pointer which are visited during the recursive collection
// of the referenced types.
// It is used to detect cycles and prevent infinite recursion.
visited = make(map[uintptr]bool)
)

// Collect all types from all given objects
for _, v := range vs {
structTypes(reflect.ValueOf(v), allUnexported, visited)
}

// Collect the referenced types
var types []interface{}
for t := range allUnexported {
types = append(types, reflect.New(t).Elem().Interface())
}

// Return cmp option which allows all unexported fields in all the collected types
return cmp.AllowUnexported(types...)
}

// structTypes is a recursive search for all referenced types from a given object.
// It searches recursively in all the given object fields and references, and put the
// collected type in the `m` set.
// It uses the `visited` set to detect cycles and prevent infinite recursion
func structTypes(v reflect.Value, m map[reflect.Type]bool, visited map[uintptr]bool) {
if !v.IsValid() {
return
}

// dive in according to the kind of the given object
switch v.Kind() {
case reflect.Ptr:
if v.IsNil() {
return
}
// prevent infinite recursion
if visited[v.Elem().UnsafeAddr()] {
return
}
// remember jumping to a pointed address
visited[v.Elem().UnsafeAddr()] = true
structTypes(v.Elem(), m, visited)
case reflect.Interface:
if v.IsNil() {
return
}
// search into the object that implement the interface
structTypes(v.Elem(), m, visited)
case reflect.Slice, reflect.Array:
// recursively search in all the slice/array objects
for i := 0; i < v.Len(); i++ {
structTypes(v.Index(i), m, visited)
}
case reflect.Map:
// recursively search in all the map values
for _, k := range v.MapKeys() {
structTypes(v.MapIndex(k), m, visited)
}
case reflect.Struct:
// add the type to the collected types.
m[v.Type()] = true
// recursively search in all the struct fields
for i := 0; i < v.NumField(); i++ {
structTypes(v.Field(i), m, visited)
}
}
}

// compareIdenticalPointers is a cmp option that returns true if the two compared
// objects are pointers and are pointing on the same thing.
var compareIdenticalPointers = cmp.FilterPath(func(p cmp.Path) bool {
// Filter for pointer kinds only.
t := p.Last().Type()
return t != nil && t.Kind() == reflect.Ptr
}, cmp.FilterValues(func(x, y interface{}) bool {
// Filter for pointer values that are identical.
vx := reflect.ValueOf(x)
vy := reflect.ValueOf(y)
return vx.IsValid() && vy.IsValid() && vx.Pointer() == vy.Pointer()
}, cmp.Comparer(func(_, _ interface{}) bool {
// Consider them equal no matter what.
return true
})))

func typeAndKind(v interface{}) (reflect.Type, reflect.Kind) {
t := reflect.TypeOf(v)
k := t.Kind()

if k == reflect.Ptr {
t = t.Elem()
k = t.Kind()
}
return t, k
}

0 comments on commit fc29410

Please sign in to comment.