diff --git a/is.go b/is.go index 2c3c28a..6bedcb2 100644 --- a/is.go +++ b/is.go @@ -68,8 +68,8 @@ func (is *Is) Strict() *Is { // not equal. // // Equal does not respect type differences. If the types are different and -// comparable (eg int32 and int64), but the values are the same, the objects -// are considered equal. +// comparable (eg int32 and int64), they will be compared as though they are +// the same type. func (is *Is) Equal(a interface{}, b interface{}) { result := isEqual(a, b) if !result { @@ -83,8 +83,8 @@ func (is *Is) Equal(a interface{}, b interface{}) { // equal. // // NotEqual does not respect type differences. If the types are different and -// comparable (eg int32 and int64), but the values are different, the objects -// are considered not equal. +// comparable (eg int32 and int64), they will be compared as though they are +// the same type. func (is *Is) NotEqual(a interface{}, b interface{}) { result := isEqual(a, b) if result { @@ -94,6 +94,50 @@ func (is *Is) NotEqual(a interface{}, b interface{}) { } } +// OneOf performs a deep compare of the provided object and an array of +// comparison objects. It fails if the first object is not equal to one of the +// comparison objects. +// +// OneOf does not respect type differences. If the types are different and +// comparable (eg int32 and int64), they will be compared as though they are +// the same type. +func (is *Is) OneOf(a interface{}, b ...interface{}) { + result := false + for _, o := range b { + result = isEqual(a, o) + if result { + break + } + } + if !result { + fail(is, "expected object '%s' to be equal to one of '%s', but got: %v and %v", + objectTypeName(a), + objectTypeNames(b), a, b) + } +} + +// NotOneOf performs a deep compare of the provided object and an array of +// comparison objects. It fails if the first object is equal to one of the +// comparison objects. +// +// NotOneOf does not respect type differences. If the types are different and +// comparable (eg int32 and int64), they will be compared as though they are +// the same type. +func (is *Is) NotOneOf(a interface{}, b ...interface{}) { + result := false + for _, o := range b { + result = isEqual(a, o) + if result { + break + } + } + if result { + fail(is, "expected object '%s' not to be equal to one of '%s', but got: %v and %v", + objectTypeName(a), + objectTypeNames(b), a, b) + } +} + // Err checks the provided error object to determine if an error is present. func (is *Is) Err(e error) { result := isNil(e) diff --git a/is_test.go b/is_test.go index ef88c1f..6c3d0a4 100644 --- a/is_test.go +++ b/is_test.go @@ -146,6 +146,8 @@ func TestIs(t *testing.T) { is.False(false) is.Zero(nil) is.Nil((*testStruct)(nil)) + is.OneOf(1, 2, 3, 1) + is.NotOneOf(1, 2, 3) fail = func(is *Is, format string, args ...interface{}) {} is.Equal((*testStruct)(nil), &testStruct{}) diff --git a/workers.go b/workers.go index 0c68e5d..b213a13 100644 --- a/workers.go +++ b/workers.go @@ -1,6 +1,7 @@ package is import ( + "bytes" "fmt" "reflect" ) @@ -9,6 +10,22 @@ func objectTypeName(o interface{}) string { return fmt.Sprintf("%T", o) } +func objectTypeNames(o []interface{}) string { + if o == nil { + return objectTypeName(o) + } + if len(o) == 1 { + return objectTypeName(o[0]) + } + var b bytes.Buffer + b.WriteString(objectTypeName(o[0])) + for _, e := range o[1:] { + b.WriteString(",") + b.WriteString(objectTypeName(e)) + } + return b.String() +} + func isNil(o interface{}) bool { if o == nil { return true