Skip to content

Commit

Permalink
Add workround to fix time.Time equal
Browse files Browse the repository at this point in the history
  • Loading branch information
leoleoasd committed Jul 28, 2020
1 parent 1c8633f commit 90f9249
Show file tree
Hide file tree
Showing 2 changed files with 180 additions and 9 deletions.
187 changes: 179 additions & 8 deletions assert/assertions.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"time"
"unicode"
"unicode/utf8"
"unsafe"

"github.com/davecgh/go-spew/spew"
"github.com/pmezard/go-difflib/difflib"
Expand Down Expand Up @@ -60,19 +61,189 @@ func ObjectsAreEqual(expected, actual interface{}) bool {
return expected == actual
}

exp, ok := expected.([]byte)
if !ok {
return reflect.DeepEqual(expected, actual)
switch expected.(type) {
// if more non-struct types needs special care, add them here.
case []byte:
if act, ok := actual.([]byte); ok {
exp := expected.([]byte)
if exp == nil || act == nil {
return exp == nil && act == nil
}
return bytes.Equal(exp, act)
}
return false
}
v1 := reflect.ValueOf(expected)
v2 := reflect.ValueOf(actual)
if v1.Kind() == reflect.Struct {
if !v1.IsValid() || !v2.IsValid() {
return v1.IsValid() == v2.IsValid()
}
if v1.Type() != v2.Type() {
return false
}
if v1.CanAddr() && v2.CanAddr() {
return deepEqual(v1, v2, make(map[deepEqualVisit]bool))
} else {
// Copy them to heap to make them addressable.
copyV1 := reflect.New(v1.Type()).Elem()
copyV1.Set(v1)
copyV2 := reflect.New(v1.Type()).Elem()
copyV2.Set(v2)
return deepEqual(copyV1, copyV2, make(map[deepEqualVisit]bool))
}
}
return reflect.DeepEqual(expected, actual)
}

act, ok := actual.([]byte)
if !ok {
type deepEqualVisit struct {
a1 unsafe.Pointer
a2 unsafe.Pointer
typ reflect.Type
}

func deepEqual(v1, v2 reflect.Value, visited map[deepEqualVisit]bool) bool {
if !v1.IsValid() || !v2.IsValid() {
return v1.IsValid() == v2.IsValid()
}
if v1.Type() != v2.Type() {
return false
}
if exp == nil || act == nil {
return exp == nil && act == nil

// if more struct-type needs special care, add them here.
if tv1, ok := v1.Interface().(time.Time); ok {
tv2 := v2.Interface().(time.Time)
return tv1.Equal(tv2)
}

// From now on, we are basically copying from reflect.DeepCopy

// We want to avoid putting more in the visited map than we need to.
// For any possible reference cycle that might be encountered,
// hard(v1, v2) needs to return true for at least one of the types in the cycle,
// and it's safe and valid to get Value's internal pointer.
hard := func(v1, v2 reflect.Value) bool {
switch v1.Kind() {
case reflect.Map, reflect.Slice, reflect.Ptr, reflect.Interface:
// Nil pointers cannot be cyclic. Avoid putting them in the visited map.
return !v1.IsNil() && !v2.IsNil()
}
return false
}

if hard(v1, v2) {
// Should be addressable.
addr1 := unsafe.Pointer(v1.UnsafeAddr())
addr2 := unsafe.Pointer(v2.UnsafeAddr())
if uintptr(addr1) > uintptr(addr2) {
// Canonicalize order to reduce number of entries in visited.
// Assumes non-moving garbage collector.
addr1, addr2 = addr2, addr1
}

// Short circuit if references are already seen.
typ := v1.Type()
v := deepEqualVisit{addr1, addr2, typ}
if visited[v] {
return true
}

// Remember for later.
visited[v] = true
}

switch v1.Kind() {
case reflect.Array:
for i := 0; i < v1.Len(); i += 1 {
if !deepEqual(v1.Index(i), v2.Index(i), visited) {
return false
}
}
return true
case reflect.Slice:
if v1.IsNil() != v2.IsNil() {
return false
}
if v1.Len() != v2.Len() {
return false
}
if v1.Pointer() == v2.Pointer() {
return true
}
for i := 0; i < v1.Len(); i += 1 {
if !deepEqual(v1.Index(i), v2.Index(i), visited) {
return false
}
}
return true
case reflect.Interface:
if v1.IsNil() || v2.IsNil() {
return v1.IsNil() == v2.IsNil()
}
return deepEqual(v1.Elem(), v2.Elem(), visited)
case reflect.Ptr:
if v1.Pointer() == v2.Pointer() {
return true
}
return deepEqual(v1.Elem(), v2.Elem(), visited)
case reflect.Struct:
// We need to look into the unexported value of the struct.
// I'm using this trick from
// https://stackoverflow.com/questions/42664837/how-to-access-unexported-struct-fields-in-golang.

// they are already on heap because the caller of deepEqual(ObjectsAreEqual) does the copy work.

//if v1.CanAddr() && v2.CanAddr() {
for i, n := 0, v1.NumField(); i < n; i++ {
v1f := reflect.NewAt(v1.Field(i).Type(), unsafe.Pointer(v1.Field(i).UnsafeAddr())).Elem()
v2f := reflect.NewAt(v2.Field(i).Type(), unsafe.Pointer(v2.Field(i).UnsafeAddr())).Elem()
// v1f and v2f are writable and readable even it's unexported.
// to call time.Time.Equal, we need to read these values (call Interface)
// so them must be readable.
if !deepEqual(v1f, v2f, visited) {
return false
}
}
//} else {
// // copy them out to heap
// // this may be slow, but won't effect production performance.
// copyV1 := reflect.New(v1.Type()).Elem()
// copyV1.Set(v1)
// copyV2 := reflect.New(v1.Type()).Elem()
// copyV2.Set(v2)
// for i, n := 0, copyV1.NumField(); i < n; i++ {
// v1f := reflect.NewAt(copyV1.Field(i).Type(), unsafe.Pointer(copyV1.Field(i).UnsafeAddr())).Elem()
// v2f := reflect.NewAt(copyV2.Field(i).Type(), unsafe.Pointer(copyV2.Field(i).UnsafeAddr())).Elem()
// // v1f and v2f are writable and readable even it's unexported.
// // to call time.Time.Equal, we need to read these values (call Interface)
// // so them must be readable.
// if !deepEqual(v1f, v2f, visited) {
// return false
// }
// }
//}
return true
case reflect.Map:
if v1.IsNil() != v2.IsNil() {
return false
}
if v1.Len() != v2.Len() {
return false
}
if v1.Pointer() == v2.Pointer() {
return true
}
for _, k := range v1.MapKeys() {
val1 := v1.MapIndex(k)
val2 := v2.MapIndex(k)
if !val1.IsValid() || !val2.IsValid() || !deepEqual(val1, val2, visited) {
return false
}
}
return true
}
return bytes.Equal(exp, act)
// types we can't handle
return reflect.DeepEqual(v1.Interface(), v2.Interface())
}

// ObjectsAreEqualValues gets whether two objects are equal, or if their
Expand Down
2 changes: 1 addition & 1 deletion assert/assertions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ func TestObjectsAreEqual_StructContainingTime(t *testing.T) {
}
}{
struct{ time.Time }{
a.Inner.UTC(),
a.Inner.UTC().Round(0),
},
}
if !ObjectsAreEqual(a, b) {
Expand Down

0 comments on commit 90f9249

Please sign in to comment.