diff --git a/assert/assertions.go b/assert/assertions.go index 0b7570f21..b65ba1f58 100644 --- a/assert/assertions.go +++ b/assert/assertions.go @@ -48,6 +48,29 @@ type ErrorAssertionFunc func(TestingT, error, ...interface{}) bool // Comparison is a custom function that returns true on success and false on failure type Comparison func() (success bool) +// ErrorIsFor returns an [ErrorAssertionFunc] which tests if the error wraps target. +func ErrorIsFor(target error) ErrorAssertionFunc { + return func(t TestingT, err error, msgsAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + + return ErrorIs(t, err, target, msgsAndArgs...) + } +} + +// ErrorAsFor returns an [ErrorAssertionFunc] which tests if the any error in err's tree matches target and if so, assigns it to target. +// The returned function panics if target is not a non-nil pointer to either a type that implements error, or to any interface type. +func ErrorAsFor(target interface{}) ErrorAssertionFunc { + return func(t TestingT, err error, msgsAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + + return ErrorAs(t, err, target, msgsAndArgs...) + } +} + /* Helper functions */ diff --git a/assert/assertions_test.go b/assert/assertions_test.go index 2a6e47234..577f3615a 100644 --- a/assert/assertions_test.go +++ b/assert/assertions_test.go @@ -2750,7 +2750,13 @@ func ExampleErrorAssertionFunc() { t := &testing.T{} // provided by test dumbParseNum := func(input string, v interface{}) error { - return json.Unmarshal([]byte(input), v) + + err := json.Unmarshal([]byte(input), v) + if err != nil { + return testingError{"could not Unmarshal " + input} + } + + return nil } tests := []struct { @@ -2760,8 +2766,9 @@ func ExampleErrorAssertionFunc() { }{ {"1.2 is number", "1.2", NoError}, {"1.2.3 not number", "1.2.3", Error}, - {"true is not number", "true", Error}, + {"true is not number", "true", ErrorAsFor(&testingError{})}, {"3 is number", "3", NoError}, + {"3% is not a valid number", "3%", ErrorIsFor(testingError{"could not Unmarshal 3%"})}, } for _, tt := range tests { @@ -2772,7 +2779,17 @@ func ExampleErrorAssertionFunc() { } } +type testingError struct { + extraInfo string +} + +func (t testingError) Error() string { + return t.extraInfo +} + + func TestErrorAssertionFunc(t *testing.T) { + var testError = errors.New("test error") tests := []struct { name string err error @@ -2780,6 +2797,8 @@ func TestErrorAssertionFunc(t *testing.T) { }{ {"noError", nil, NoError}, {"error", errors.New("whoops"), Error}, + {"errorIs", testError, ErrorIsFor(testError)}, + {"errorAs", testingError{extraInfo: "something"}, ErrorAsFor(&testingError{})}, } for _, tt := range tests { diff --git a/require/requirements.go b/require/requirements.go index 91772dfeb..9fb5e6670 100644 --- a/require/requirements.go +++ b/require/requirements.go @@ -26,4 +26,27 @@ type BoolAssertionFunc func(TestingT, bool, ...interface{}) // for table driven tests. type ErrorAssertionFunc func(TestingT, error, ...interface{}) +// ErrorIsFunc returns an [ErrorAssertionFunc] which tests if the error wraps target. +func ErrorIsFor(expectedError error) ErrorAssertionFunc { + return func(t TestingT, err error, msgsAndArgs ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + + ErrorIs(t, err, expectedError, msgsAndArgs...) + } +} + +// ErrorAsFunc returns an [ErrorAssertionFunc] which tests if the any error in err's tree matches target and if so, assigns it to target. +// The returned function panics if target is not a non-nil pointer to either a type that implements error, or to any interface type. +func ErrorAsFor(expectedInterface interface{}) ErrorAssertionFunc { + return func(t TestingT, err error, msgsAndArgs ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + + ErrorAs(t, err, expectedInterface, msgsAndArgs...) + } +} + //go:generate sh -c "cd ../_codegen && go build && cd - && ../_codegen/_codegen -output-package=require -template=require.go.tmpl -include-format-funcs" diff --git a/require/requirements_test.go b/require/requirements_test.go index febf0c187..ba3e3615e 100644 --- a/require/requirements_test.go +++ b/require/requirements_test.go @@ -665,7 +665,17 @@ func ExampleErrorAssertionFunc() { } } +type testingError struct { + extraInfo string +} + +func (t testingError) Error() string { + return t.extraInfo +} + + func TestErrorAssertionFunc(t *testing.T) { + var testError = errors.New("test error") tests := []struct { name string err error @@ -673,6 +683,8 @@ func TestErrorAssertionFunc(t *testing.T) { }{ {"noError", nil, NoError}, {"error", errors.New("whoops"), Error}, + {"errorIs", testError, ErrorIsFor(testError)}, + {"errorAs", testingError{extraInfo: "something"}, ErrorAsFor(&testingError{})}, } for _, tt := range tests {