Skip to content

Commit

Permalink
Move object equality functions to another package for clarity
Browse files Browse the repository at this point in the history
  • Loading branch information
mminklet committed Oct 10, 2022
1 parent 181cea6 commit ae64ce6
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 103 deletions.
81 changes: 28 additions & 53 deletions assert/assertions.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (

"github.com/davecgh/go-spew/spew"
"github.com/pmezard/go-difflib/difflib"
"github.com/stretchr/testify/assert/equal"
yaml "gopkg.in/yaml.v3"
)

Expand Down Expand Up @@ -53,49 +54,6 @@ type Comparison func() (success bool)
Helper functions
*/

// ObjectsAreEqual determines if two objects are considered equal.
//
// This function does no assertion of any kind.
func ObjectsAreEqual(expected, actual interface{}) bool {
if expected == nil || actual == nil {
return expected == actual
}

exp, ok := expected.([]byte)
if !ok {
return reflect.DeepEqual(expected, actual)
}

act, ok := actual.([]byte)
if !ok {
return false
}
if exp == nil || act == nil {
return exp == nil && act == nil
}
return bytes.Equal(exp, act)
}

// ObjectsAreEqualValues gets whether two objects are equal, or if their
// values are equal.
func ObjectsAreEqualValues(expected, actual interface{}) bool {
if ObjectsAreEqual(expected, actual) {
return true
}

actualType := reflect.TypeOf(actual)
if actualType == nil {
return false
}
expectedValue := reflect.ValueOf(expected)
if expectedValue.IsValid() && expectedValue.Type().ConvertibleTo(actualType) {
// Attempt comparison after type conversion
return reflect.DeepEqual(expectedValue.Convert(actualType).Interface(), actual)
}

return false
}

/* CallerInfo is necessary because the assert functions use the testing object
internally, causing it to print the file:line of the assert method, rather than where
the problem actually occurred in calling code.*/
Expand Down Expand Up @@ -319,13 +277,30 @@ func IsType(t TestingT, expectedType interface{}, object interface{}, msgAndArgs
h.Helper()
}

if !ObjectsAreEqual(reflect.TypeOf(object), reflect.TypeOf(expectedType)) {
if !equal.ObjectsAreEqual(reflect.TypeOf(object), reflect.TypeOf(expectedType)) {
return Fail(t, fmt.Sprintf("Object expected to be of type %v, but was %v", reflect.TypeOf(expectedType), reflect.TypeOf(object)), msgAndArgs...)
}

return true
}

// ObjectsAreEqual determines if two objects are considered equal.
//
// Deprecated: Use equal.ObjectsAreEqual, which does the exact same thing,
// but is clearer that it does no assertions of any kind.
func ObjectsAreEqual(expected, actual interface{}) bool {
return equal.ObjectsAreEqual(expected, actual)
}

// ObjectsAreEqualValues gets whether two objects are equal, or if their
// values are equal.

// Deprecated: Use equal.ObjectsAreEqualValues, which does the exact same thing,
// but is clearer that it does no assertions of any kind.
func ObjectsAreEqualValues(expected, actual interface{}) bool {
return equal.ObjectsAreEqualValues(expected, actual)
}

// Equal asserts that two objects are equal.
//
// assert.Equal(t, 123, 123)
Expand All @@ -342,7 +317,7 @@ func Equal(t TestingT, expected, actual interface{}, msgAndArgs ...interface{})
expected, actual, err), msgAndArgs...)
}

if !ObjectsAreEqual(expected, actual) {
if !equal.ObjectsAreEqual(expected, actual) {
diff := diff(expected, actual)
expected, actual = formatUnequalValues(expected, actual)
return Fail(t, fmt.Sprintf("Not equal: \n"+
Expand Down Expand Up @@ -463,7 +438,7 @@ func EqualValues(t TestingT, expected, actual interface{}, msgAndArgs ...interfa
h.Helper()
}

if !ObjectsAreEqualValues(expected, actual) {
if !equal.ObjectsAreEqualValues(expected, actual) {
diff := diff(expected, actual)
expected, actual = formatUnequalValues(expected, actual)
return Fail(t, fmt.Sprintf("Not equal: \n"+
Expand Down Expand Up @@ -694,7 +669,7 @@ func NotEqual(t TestingT, expected, actual interface{}, msgAndArgs ...interface{
expected, actual, err), msgAndArgs...)
}

if ObjectsAreEqual(expected, actual) {
if equal.ObjectsAreEqual(expected, actual) {
return Fail(t, fmt.Sprintf("Should not be: %#v\n", actual), msgAndArgs...)
}

Expand All @@ -710,7 +685,7 @@ func NotEqualValues(t TestingT, expected, actual interface{}, msgAndArgs ...inte
h.Helper()
}

if ObjectsAreEqualValues(expected, actual) {
if equal.ObjectsAreEqualValues(expected, actual) {
return Fail(t, fmt.Sprintf("Should not be: %#v\n", actual), msgAndArgs...)
}

Expand Down Expand Up @@ -744,15 +719,15 @@ func containsElement(list interface{}, element interface{}) (ok, found bool) {
if listKind == reflect.Map {
mapKeys := listValue.MapKeys()
for i := 0; i < len(mapKeys); i++ {
if ObjectsAreEqual(mapKeys[i].Interface(), element) {
if equal.ObjectsAreEqual(mapKeys[i].Interface(), element) {
return true, true
}
}
return true, false
}

for i := 0; i < listValue.Len(); i++ {
if ObjectsAreEqual(listValue.Index(i).Interface(), element) {
if equal.ObjectsAreEqual(listValue.Index(i).Interface(), element) {
return true, true
}
}
Expand Down Expand Up @@ -845,7 +820,7 @@ func Subset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok
subsetElement := subsetValue.MapIndex(subsetKey).Interface()
listElement := listValue.MapIndex(subsetKey).Interface()

if !ObjectsAreEqual(subsetElement, listElement) {
if !equal.ObjectsAreEqual(subsetElement, listElement) {
return Fail(t, fmt.Sprintf("\"%s\" does not contain \"%s\"", list, subsetElement), msgAndArgs...)
}
}
Expand Down Expand Up @@ -906,7 +881,7 @@ func NotSubset(t TestingT, list, subset interface{}, msgAndArgs ...interface{})
subsetElement := subsetValue.MapIndex(subsetKey).Interface()
listElement := listValue.MapIndex(subsetKey).Interface()

if !ObjectsAreEqual(subsetElement, listElement) {
if !equal.ObjectsAreEqual(subsetElement, listElement) {
return true
}
}
Expand Down Expand Up @@ -983,7 +958,7 @@ func diffLists(listA, listB interface{}) (extraA, extraB []interface{}) {
if visited[j] {
continue
}
if ObjectsAreEqual(bValue.Index(j).Interface(), element) {
if equal.ObjectsAreEqual(bValue.Index(j).Interface(), element) {
visited[j] = true
found = true
break
Expand Down
52 changes: 3 additions & 49 deletions assert/assertions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import (
"strings"
"testing"
"time"

"github.com/stretchr/testify/assert/equal"
)

var (
Expand Down Expand Up @@ -100,54 +102,6 @@ func (a *AssertionTesterConformingObject) TestMethod() {
type AssertionTesterNonConformingObject struct {
}

func TestObjectsAreEqual(t *testing.T) {
cases := []struct {
expected interface{}
actual interface{}
result bool
}{
// cases that are expected to be equal
{"Hello World", "Hello World", true},
{123, 123, true},
{123.5, 123.5, true},
{[]byte("Hello World"), []byte("Hello World"), true},
{nil, nil, true},

// cases that are expected not to be equal
{map[int]int{5: 10}, map[int]int{10: 20}, false},
{'x', "x", false},
{"x", 'x', false},
{0, 0.1, false},
{0.1, 0, false},
{time.Now, time.Now, false},
{func() {}, func() {}, false},
{uint32(10), int32(10), false},
}

for _, c := range cases {
t.Run(fmt.Sprintf("ObjectsAreEqual(%#v, %#v)", c.expected, c.actual), func(t *testing.T) {
res := ObjectsAreEqual(c.expected, c.actual)

if res != c.result {
t.Errorf("ObjectsAreEqual(%#v, %#v) should return %#v", c.expected, c.actual, c.result)
}

})
}

// Cases where type differ but values are equal
if !ObjectsAreEqualValues(uint32(10), int32(10)) {
t.Error("ObjectsAreEqualValues should return true")
}
if ObjectsAreEqualValues(0, nil) {
t.Fail()
}
if ObjectsAreEqualValues(nil, 0) {
t.Fail()
}

}

func TestImplements(t *testing.T) {

mockT := new(testing.T)
Expand Down Expand Up @@ -2183,7 +2137,7 @@ func TestBytesEqual(t *testing.T) {
{nil, make([]byte, 0)},
}
for i, c := range cases {
Equal(t, reflect.DeepEqual(c.a, c.b), ObjectsAreEqual(c.a, c.b), "case %d failed", i+1)
Equal(t, reflect.DeepEqual(c.a, c.b), equal.ObjectsAreEqual(c.a, c.b), "case %d failed", i+1)
}
}

Expand Down
51 changes: 51 additions & 0 deletions assert/equal/equal.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package equal

import (
"bytes"
"reflect"
)

// ObjectsAreEqual determines if two objects are considered equal.
//
// This function does no assertion of any kind.
func ObjectsAreEqual(expected, actual interface{}) bool {
if expected == nil || actual == nil {
return expected == actual
}

exp, ok := expected.([]byte)
if !ok {
return reflect.DeepEqual(expected, actual)
}

act, ok := actual.([]byte)
if !ok {
return false
}
if exp == nil || act == nil {
return exp == nil && act == nil
}
return bytes.Equal(exp, act)
}

// ObjectsAreEqualValues gets whether two objects are equal, or if their
// values are equal.
//
// This function does no assertion of any kind.
func ObjectsAreEqualValues(expected, actual interface{}) bool {
if ObjectsAreEqual(expected, actual) {
return true
}

actualType := reflect.TypeOf(actual)
if actualType == nil {
return false
}
expectedValue := reflect.ValueOf(expected)
if expectedValue.IsValid() && expectedValue.Type().ConvertibleTo(actualType) {
// Attempt comparison after type conversion
return reflect.DeepEqual(expectedValue.Convert(actualType).Interface(), actual)
}

return false
}
55 changes: 55 additions & 0 deletions assert/equal/equal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package equal

import (
"fmt"
"testing"
"time"
)

func TestObjectsAreEqual(t *testing.T) {
cases := []struct {
expected interface{}
actual interface{}
result bool
}{
// cases that are expected to be equal
{"Hello World", "Hello World", true},
{123, 123, true},
{123.5, 123.5, true},
{[]byte("Hello World"), []byte("Hello World"), true},
{nil, nil, true},

// cases that are expected not to be equal
{map[int]int{5: 10}, map[int]int{10: 20}, false},
{'x', "x", false},
{"x", 'x', false},
{0, 0.1, false},
{0.1, 0, false},
{time.Now, time.Now, false},
{func() {}, func() {}, false},
{uint32(10), int32(10), false},
}

for _, c := range cases {
t.Run(fmt.Sprintf("ObjectsAreEqual(%#v, %#v)", c.expected, c.actual), func(t *testing.T) {
res := ObjectsAreEqual(c.expected, c.actual)

if res != c.result {
t.Errorf("ObjectsAreEqual(%#v, %#v) should return %#v", c.expected, c.actual, c.result)
}

})
}

// Cases where type differ but values are equal
if !ObjectsAreEqualValues(uint32(10), int32(10)) {
t.Error("ObjectsAreEqualValues should return true")
}
if ObjectsAreEqualValues(0, nil) {
t.Fail()
}
if ObjectsAreEqualValues(nil, 0) {
t.Fail()
}

}
3 changes: 2 additions & 1 deletion mock/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/pmezard/go-difflib/difflib"
"github.com/stretchr/objx"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/assert/equal"
)

// TestingT is an interface wrapper around *testing.T
Expand Down Expand Up @@ -923,7 +924,7 @@ func (args Arguments) Diff(objects []interface{}) (string, int) {
} else {
// normal checking

if assert.ObjectsAreEqual(expected, Anything) || assert.ObjectsAreEqual(actual, Anything) || assert.ObjectsAreEqual(actual, expected) {
if equal.ObjectsAreEqual(expected, Anything) || equal.ObjectsAreEqual(actual, Anything) || equal.ObjectsAreEqual(actual, expected) {
// match
output = fmt.Sprintf("%s\t%d: PASS: %s == %s\n", output, i, actualFmt, expectedFmt)
} else {
Expand Down

0 comments on commit ae64ce6

Please sign in to comment.