Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move object equality functions to another package for clarity #1279

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 73 additions & 51 deletions assert/assertion_format.go

Large diffs are not rendered by default.

248 changes: 146 additions & 102 deletions assert/assertion_forward.go

Large diffs are not rendered by default.

231 changes: 83 additions & 148 deletions assert/assertions.go

Large diffs are not rendered by default.

97 changes: 3 additions & 94 deletions assert/assertions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import (
"strings"
"testing"
"time"

"github.com/stretchr/testify/internal/equal"
"unsafe"
)

Expand Down Expand Up @@ -101,99 +103,6 @@ func (a *AssertionTesterConformingObject) TestMethod() {
type AssertionTesterNonConformingObject struct {
}

func TestObjectsAreEqual(t *testing.T) {
cases := []struct {
expected interface{}
actual interface{}
result bool
}{
// cases that are expected to be equal
{"Hello World", "Hello World", true},
{123, 123, true},
{123.5, 123.5, true},
{[]byte("Hello World"), []byte("Hello World"), true},
{nil, nil, true},

// cases that are expected not to be equal
{map[int]int{5: 10}, map[int]int{10: 20}, false},
{'x', "x", false},
{"x", 'x', false},
{0, 0.1, false},
{0.1, 0, false},
{time.Now, time.Now, false},
{func() {}, func() {}, false},
{uint32(10), int32(10), false},
}

for _, c := range cases {
t.Run(fmt.Sprintf("ObjectsAreEqual(%#v, %#v)", c.expected, c.actual), func(t *testing.T) {
res := ObjectsAreEqual(c.expected, c.actual)

if res != c.result {
t.Errorf("ObjectsAreEqual(%#v, %#v) should return %#v", c.expected, c.actual, c.result)
}

})
}

// Cases where type differ but values are equal
if !ObjectsAreEqualValues(uint32(10), int32(10)) {
t.Error("ObjectsAreEqualValues should return true")
}
if ObjectsAreEqualValues(0, nil) {
t.Fail()
}
if ObjectsAreEqualValues(nil, 0) {
t.Fail()
}

}

func TestObjectsExportedFieldsAreEqual(t *testing.T) {
type Nested struct {
Exported interface{}
notExported interface{}
}

type S struct {
Exported1 interface{}
Exported2 Nested
notExported1 interface{}
notExported2 Nested
}

type S2 struct {
foo interface{}
}

cases := []struct {
expected interface{}
actual interface{}
result bool
}{
{S{1, Nested{2, 3}, 4, Nested{5, 6}}, S{1, Nested{2, 3}, 4, Nested{5, 6}}, true},
{S{1, Nested{2, 3}, 4, Nested{5, 6}}, S{1, Nested{2, 3}, "a", Nested{5, 6}}, true},
{S{1, Nested{2, 3}, 4, Nested{5, 6}}, S{1, Nested{2, 3}, 4, Nested{5, "a"}}, true},
{S{1, Nested{2, 3}, 4, Nested{5, 6}}, S{1, Nested{2, 3}, 4, Nested{"a", "a"}}, true},
{S{1, Nested{2, 3}, 4, Nested{5, 6}}, S{1, Nested{2, "a"}, 4, Nested{5, 6}}, true},
{S{1, Nested{2, 3}, 4, Nested{5, 6}}, S{"a", Nested{2, 3}, 4, Nested{5, 6}}, false},
{S{1, Nested{2, 3}, 4, Nested{5, 6}}, S{1, Nested{"a", 3}, 4, Nested{5, 6}}, false},
{S{1, Nested{2, 3}, 4, Nested{5, 6}}, S2{1}, false},
{1, S{1, Nested{2, 3}, 4, Nested{5, 6}}, false},
}

for _, c := range cases {
t.Run(fmt.Sprintf("ObjectsExportedFieldsAreEqual(%#v, %#v)", c.expected, c.actual), func(t *testing.T) {
res := ObjectsExportedFieldsAreEqual(c.expected, c.actual)

if res != c.result {
t.Errorf("ObjectsExportedFieldsAreEqual(%#v, %#v) should return %#v", c.expected, c.actual, c.result)
}

})
}
}

func TestImplements(t *testing.T) {

mockT := new(testing.T)
Expand Down Expand Up @@ -2243,7 +2152,7 @@ func TestBytesEqual(t *testing.T) {
{nil, make([]byte, 0)},
}
for i, c := range cases {
Equal(t, reflect.DeepEqual(c.a, c.b), ObjectsAreEqual(c.a, c.b), "case %d failed", i+1)
Equal(t, reflect.DeepEqual(c.a, c.b), equal.ObjectsAreEqual(c.a, c.b), "case %d failed", i+1)
}
}

Expand Down
91 changes: 91 additions & 0 deletions internal/equal/equal.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package equal

import (
"bytes"
"reflect"
)

// ObjectsAreEqual determines if two objects are considered equal.
//
// This function does no assertion of any kind.
func ObjectsAreEqual(expected, actual interface{}) bool {
if expected == nil || actual == nil {
return expected == actual
}

exp, ok := expected.([]byte)
if !ok {
return reflect.DeepEqual(expected, actual)
}

act, ok := actual.([]byte)
if !ok {
return false
}
if exp == nil || act == nil {
return exp == nil && act == nil
}
return bytes.Equal(exp, act)
}

// ObjectsExportedFieldsAreEqual determines if the exported (public) fields of two structs are considered equal.
// If the two objects are not of the same type, or if either of them are not a struct, they are not considered equal.
//
// This function does no assertion of any kind.
func ObjectsExportedFieldsAreEqual(expected, actual interface{}) bool {
if expected == nil || actual == nil {
return expected == actual
}

expectedType := reflect.TypeOf(expected)
actualType := reflect.TypeOf(actual)

if expectedType != actualType {
return false
}

if expectedType.Kind() != reflect.Struct || actualType.Kind() != reflect.Struct {
return false
}

expectedValue := reflect.ValueOf(expected)
actualValue := reflect.ValueOf(actual)

for i := 0; i < expectedType.NumField(); i++ {
field := expectedType.Field(i)
isExported := field.PkgPath == "" // should use field.IsExported() but it's not available in Go 1.16.5
if isExported {
var equal bool
if field.Type.Kind() == reflect.Struct {
equal = ObjectsExportedFieldsAreEqual(expectedValue.Field(i).Interface(), actualValue.Field(i).Interface())
} else {
equal = ObjectsAreEqualValues(expectedValue.Field(i).Interface(), actualValue.Field(i).Interface())
}

if !equal {
return false
}
}
}
return true
}

// ObjectsAreEqualValues gets whether two objects are equal, or if their
// values are equal.
func ObjectsAreEqualValues(expected, actual interface{}) bool {
if ObjectsAreEqual(expected, actual) {
return true
}

actualType := reflect.TypeOf(actual)
if actualType == nil {
return false
}
expectedValue := reflect.ValueOf(expected)
if expectedValue.IsValid() && expectedValue.Type().ConvertibleTo(actualType) {
// Attempt comparison after type conversion
return reflect.DeepEqual(expectedValue.Convert(actualType).Interface(), actual)
}

return false
}
100 changes: 100 additions & 0 deletions internal/equal/equal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package equal

import (
"fmt"
"testing"
"time"
)

func TestObjectsAreEqual(t *testing.T) {
cases := []struct {
expected interface{}
actual interface{}
result bool
}{
// cases that are expected to be equal
{"Hello World", "Hello World", true},
{123, 123, true},
{123.5, 123.5, true},
{[]byte("Hello World"), []byte("Hello World"), true},
{nil, nil, true},

// cases that are expected not to be equal
{map[int]int{5: 10}, map[int]int{10: 20}, false},
{'x', "x", false},
{"x", 'x', false},
{0, 0.1, false},
{0.1, 0, false},
{time.Now, time.Now, false},
{func() {}, func() {}, false},
{uint32(10), int32(10), false},
}

for _, c := range cases {
t.Run(fmt.Sprintf("ObjectsAreEqual(%#v, %#v)", c.expected, c.actual), func(t *testing.T) {
res := ObjectsAreEqual(c.expected, c.actual)

if res != c.result {
t.Errorf("ObjectsAreEqual(%#v, %#v) should return %#v", c.expected, c.actual, c.result)
}

})
}

// Cases where type differ but values are equal
if !ObjectsAreEqualValues(uint32(10), int32(10)) {
t.Error("ObjectsAreEqualValues should return true")
}
if ObjectsAreEqualValues(0, nil) {
t.Fail()
}
if ObjectsAreEqualValues(nil, 0) {
t.Fail()
}

}

func TestObjectsExportedFieldsAreEqual(t *testing.T) {
type Nested struct {
Exported interface{}
notExported interface{}
}

type S struct {
Exported1 interface{}
Exported2 Nested
notExported1 interface{}
notExported2 Nested
}

type S2 struct {
foo interface{}
}

cases := []struct {
expected interface{}
actual interface{}
result bool
}{
{S{1, Nested{2, 3}, 4, Nested{5, 6}}, S{1, Nested{2, 3}, 4, Nested{5, 6}}, true},
{S{1, Nested{2, 3}, 4, Nested{5, 6}}, S{1, Nested{2, 3}, "a", Nested{5, 6}}, true},
{S{1, Nested{2, 3}, 4, Nested{5, 6}}, S{1, Nested{2, 3}, 4, Nested{5, "a"}}, true},
{S{1, Nested{2, 3}, 4, Nested{5, 6}}, S{1, Nested{2, 3}, 4, Nested{"a", "a"}}, true},
{S{1, Nested{2, 3}, 4, Nested{5, 6}}, S{1, Nested{2, "a"}, 4, Nested{5, 6}}, true},
{S{1, Nested{2, 3}, 4, Nested{5, 6}}, S{"a", Nested{2, 3}, 4, Nested{5, 6}}, false},
{S{1, Nested{2, 3}, 4, Nested{5, 6}}, S{1, Nested{"a", 3}, 4, Nested{5, 6}}, false},
{S{1, Nested{2, 3}, 4, Nested{5, 6}}, S2{1}, false},
{1, S{1, Nested{2, 3}, 4, Nested{5, 6}}, false},
}

for _, c := range cases {
t.Run(fmt.Sprintf("ObjectsExportedFieldsAreEqual(%#v, %#v)", c.expected, c.actual), func(t *testing.T) {
res := ObjectsExportedFieldsAreEqual(c.expected, c.actual)

if res != c.result {
t.Errorf("ObjectsExportedFieldsAreEqual(%#v, %#v) should return %#v", c.expected, c.actual, c.result)
}

})
}
}
3 changes: 2 additions & 1 deletion mock/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/pmezard/go-difflib/difflib"
"github.com/stretchr/objx"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/internal/equal"
)

// TestingT is an interface wrapper around *testing.T
Expand Down Expand Up @@ -929,7 +930,7 @@ func (args Arguments) Diff(objects []interface{}) (string, int) {
} else {
// normal checking

if assert.ObjectsAreEqual(expected, Anything) || assert.ObjectsAreEqual(actual, Anything) || assert.ObjectsAreEqual(actual, expected) {
if equal.ObjectsAreEqual(expected, Anything) || equal.ObjectsAreEqual(actual, Anything) || equal.ObjectsAreEqual(actual, expected) {
// match
output = fmt.Sprintf("%s\t%d: PASS: %s == %s\n", output, i, actualFmt, expectedFmt)
} else {
Expand Down
Loading