diff --git a/assert/assert.go b/assert/assert.go index a5dc958..e468f66 100644 --- a/assert/assert.go +++ b/assert/assert.go @@ -4,8 +4,10 @@ package assert import ( + "errors" "fmt" "reflect" + "regexp" "testing" "github.com/sergi/go-diff/diffmatchpatch" @@ -23,6 +25,89 @@ func Equal[V comparable](t TestingT, actual, expected V) bool { return false } +// ErrEqual checks if the actual error matches the expectation. +// +// - If `expected` is nil, the actual error must be nil. +// - If `expected` is of type error, the actual error must be exactly equal to it, or contain it in the sense of errors.Is(). +// - If `expected` is of type string, the actual error message must be exactly equal to it. +// - If `expected` is of type *regexp.Regexp, that regexp must match the actual error message. +func ErrEqual(t TestingT, actual error, expectedErrorOrMessageOrRegexp any) bool { + // NOTE 1: We cannot enumerate the possible types of `expected` as a type argument of the form + // func ErrEqual[T interface{ error | string | *regexp.Regexp }](...) + // because unions of interface types (error) and concrete types (string etc.) are not permitted. + // The risk of accepting an `any` value is acceptable here because the panic from + // using an unexpected type can only occur in tests, and thus will be difficult to overlook. + // + // NOTE 2: The verbose name of the last argument is intended to help users + // who see only the function signature in their IDE autocomplete. + t.Helper() + + switch expected := expectedErrorOrMessageOrRegexp.(type) { + case nil: + if actual == nil { + return true + } + t.Errorf("expected success, but got error: %s", actual.Error()) + return false + + case error: + if actual == nil { + if expected == nil { + // defense in depth: this should have been covered by the previous case branch + return true + } + t.Errorf("expected error stack to contain %q, but got no error", expected.Error()) + return false + } + switch { + case expected == nil: + // defense in depth: this should have been covered by the previous case branch + t.Errorf("expected success, but got error: %s", actual.Error()) + return false + case errors.Is(actual, expected): + return true + default: + t.Errorf("expected error stack to contain %q, but got error: %s", expected.Error(), actual.Error()) + return false + } + + case string: + if actual == nil { + if expected == "" { + return true + } + t.Errorf("expected error with message %q, but got no error", expected) + return false + } + msg := actual.Error() + switch expected { + case "": + t.Errorf("expected success, but got error: %s", msg) + return false + case msg: + return true + default: + t.Errorf("expected error with message %q, but got error: %s", expected, msg) + return false + } + + case *regexp.Regexp: + if actual == nil { + t.Errorf("expected error with message matching /%s/, but got no error", expected.String()) + return false + } + msg := actual.Error() + if expected.MatchString(msg) { + return true + } + t.Errorf("expected error with message matching /%s/, but got error: %s", expected.String(), msg) + return false + + default: + panic(fmt.Sprintf("assert.ErrEqual() cannot match against an expectation of type %T", expected)) + } +} + // DeepEqual checks if the actual and expected value are equal as // determined by reflect.DeepEqual(), and t.Error()s otherwise. func DeepEqual[V any](t *testing.T, variable string, actual, expected V) bool { diff --git a/assert/assert_test.go b/assert/assert_test.go new file mode 100644 index 0000000..bd56639 --- /dev/null +++ b/assert/assert_test.go @@ -0,0 +1,74 @@ +// SPDX-FileCopyrightText: 2025 SAP SE or an SAP affiliate company +// SPDX-License-Identifier: Apache-2.0 + +package assert_test + +import ( + "errors" + "fmt" + "regexp" + "testing" + + "github.com/sapcc/go-bits/assert" + "github.com/sapcc/go-bits/internal/testutil" +) + +func TestErrEqual(t *testing.T) { + var ( + actual error + mock = &testutil.MockT{} + ) + checkPasses := func(expected any) { + t.Helper() + ok := assert.ErrEqual(mock, actual, expected) + assert.Equal(t, ok, true) + mock.ExpectNoErrors(t) + } + checkFails := func(expected any, message string) { + t.Helper() + ok := assert.ErrEqual(mock, actual, expected) + assert.Equal(t, ok, false) + mock.ExpectErrors(t, message) + } + + // some helper errors for below + errFoo := errors.New("wrong foo supplied") + errBar := fmt.Errorf("could not connect to bar: %w", errFoo) + errQux := errors.New("found no relation from qux to foo/bar") + + // check assertions for when the actual error is nil + actual = nil + checkPasses(nil) + checkPasses(error(nil)) + checkFails(errFoo, `expected error stack to contain "wrong foo supplied", but got no error`) + checkPasses("") + checkFails("datacenter on fire", `expected error with message "datacenter on fire", but got no error`) + checkFails(regexp.MustCompile(`.*`), `expected error with message matching /.*/, but got no error`) + + // check assertions with a simple error + actual = errFoo + checkFails(nil, `expected success, but got error: wrong foo supplied`) + checkFails(error(nil), `expected success, but got error: wrong foo supplied`) + checkPasses(errFoo) + checkFails(errBar, `expected error stack to contain "could not connect to bar: wrong foo supplied", but got error: wrong foo supplied`) + checkFails(errQux, `expected error stack to contain "found no relation from qux to foo/bar", but got error: wrong foo supplied`) + checkFails("", `expected success, but got error: wrong foo supplied`) + checkPasses("wrong foo supplied") + checkFails("datacenter on fire", `expected error with message "datacenter on fire", but got error: wrong foo supplied`) + checkPasses(regexp.MustCompile(`wrong fo* supplied`)) + checkFails(regexp.MustCompile(`connect to bar`), `expected error with message matching /connect to bar/, but got error: wrong foo supplied`) + + // check assertions with an error stack + actual = errBar + checkFails(nil, `expected success, but got error: could not connect to bar: wrong foo supplied`) + checkFails(error(nil), `expected success, but got error: could not connect to bar: wrong foo supplied`) + checkPasses(errFoo) // both with the contained error... + checkPasses(errBar) // ...as well as with the full error + checkFails(errQux, `expected error stack to contain "found no relation from qux to foo/bar", but got error: could not connect to bar: wrong foo supplied`) + checkFails("", `expected success, but got error: could not connect to bar: wrong foo supplied`) + checkFails("wrong foo supplied", `expected error with message "wrong foo supplied", but got error: could not connect to bar: wrong foo supplied`) + checkPasses("could not connect to bar: wrong foo supplied") + checkFails("datacenter on fire", `expected error with message "datacenter on fire", but got error: could not connect to bar: wrong foo supplied`) + checkPasses(regexp.MustCompile(`wrong fo* supplied`)) + checkPasses(regexp.MustCompile(`connect to bar`)) +} diff --git a/httptest/handler_test.go b/httptest/handler_test.go index 4ad9c30..8e2d926 100644 --- a/httptest/handler_test.go +++ b/httptest/handler_test.go @@ -5,7 +5,6 @@ package httptest_test import ( "bytes" - "fmt" "io" "net/http" "os" @@ -17,6 +16,7 @@ import ( "github.com/sapcc/go-bits/assert" "github.com/sapcc/go-bits/httptest" + "github.com/sapcc/go-bits/internal/testutil" "github.com/sapcc/go-bits/must" ) @@ -45,7 +45,7 @@ var exampleHandler http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r func TestRespondTo(t *testing.T) { h := httptest.NewHandler(exampleHandler) - ctx := t.Context() // TODO: use t.Context() in Go 1.24+ + ctx := t.Context() // most basic invocation resp := h.RespondTo(ctx, "POST /reflect") @@ -147,22 +147,18 @@ func TestRespondTo(t *testing.T) { }) // check how ExpectJSON() reports an unexpected status code - mock := &mockTestingT{} + mock := &testutil.MockT{} h.RespondTo(ctx, "POST /reflect", httptest.WithBody(strings.NewReader(`{"foo":23,"bar":42}`)), ).ExpectJSON(mock, http.StatusNotFound, jsonmatch.Object{}) - assert.DeepEqual(t, "collected errors", mock.Errors, []string{ - `expected HTTP status 404, but got 200 (body was "{\"foo\":23,\"bar\":42}")`, - }) + mock.ExpectErrors(t, `expected HTTP status 404, but got 200 (body was "{\"foo\":23,\"bar\":42}")`) // check how ExpectJSON() reports diffs without Pointer mock.Errors = nil h.RespondTo(ctx, "POST /reflect", httptest.WithBody(strings.NewReader(`{"foo":23,"bar":42}`)), ).ExpectJSON(mock, http.StatusOK, jsonmatch.Scalar(true)) - assert.DeepEqual(t, "collected errors", mock.Errors, []string{ - `type mismatch: expected true, but got {"bar":42,"foo":23}`, - }) + mock.ExpectErrors(t, `type mismatch: expected true, but got {"bar":42,"foo":23}`) // check how ExpectJSON() reports diffs with Pointer mock.Errors = nil @@ -172,9 +168,7 @@ func TestRespondTo(t *testing.T) { "foo": 23, "bar": 45, }) - assert.DeepEqual(t, "collected errors", mock.Errors, []string{ - `value mismatch at /bar: expected 45, but got 42`, - }) + mock.ExpectErrors(t, `value mismatch at /bar: expected 45, but got 42`) // check ExpectText() h.RespondTo(ctx, "POST /reflect", @@ -186,32 +180,17 @@ func TestRespondTo(t *testing.T) { h.RespondTo(ctx, "POST /reflect", httptest.WithBody(strings.NewReader("hello")), ).ExpectText(mock, http.StatusNotFound, "hello") - assert.DeepEqual(t, "collected errors", mock.Errors, []string{ - `expected HTTP status 404, but got 200 (body was "hello")`, - }) + mock.ExpectErrors(t, `expected HTTP status 404, but got 200 (body was "hello")`) // check how ExpectText() reports an unexpected response body mock.Errors = nil h.RespondTo(ctx, "POST /reflect", httptest.WithBody(strings.NewReader("hello")), ).ExpectText(mock, http.StatusOK, "world") - assert.DeepEqual(t, "collected errors", mock.Errors, []string{ - `expected "world", but got "hello"`, - }) + mock.ExpectErrors(t, `expected "world", but got "hello"`) // check ExpectBodyAsInFixture() h.RespondTo(ctx, "POST /reflect", httptest.WithBody(bytes.NewReader(must.Return(os.ReadFile("fixtures/example.txt")))), ).ExpectBodyAsInFixture(t, http.StatusOK, "fixtures/example.txt") } - -// A mock for *testing.T. -type mockTestingT struct { - Errors []string -} - -func (t *mockTestingT) Helper() {} - -func (t *mockTestingT) Errorf(msg string, args ...any) { - t.Errors = append(t.Errors, fmt.Sprintf(msg, args...)) -} diff --git a/internal/testutil/mockt.go b/internal/testutil/mockt.go new file mode 100644 index 0000000..51366c4 --- /dev/null +++ b/internal/testutil/mockt.go @@ -0,0 +1,36 @@ +// SPDX-FileCopyrightText: 2025 SAP SE or an SAP affiliate company +// SPDX-License-Identifier: Apache-2.0 + +package testutil + +import ( + "fmt" + "testing" + + "github.com/sapcc/go-bits/assert" +) + +// A mock for *testing.T that implements the assert.TestingT interface. +type MockT struct { + Errors []string +} + +func (mt *MockT) Helper() {} + +func (mt *MockT) Errorf(msg string, args ...any) { + mt.Errors = append(mt.Errors, fmt.Sprintf(msg, args...)) +} + +// ExpectErrors asserts on the errors collected so far, +// and then clears out the list of collected errors for the next subtest. +func (mt *MockT) ExpectErrors(t *testing.T, expected ...string) { + t.Helper() + assert.DeepEqual(t, "collected errors", mt.Errors, expected) + mt.Errors = nil +} + +// ExpectErrors asserts that no errors were collected so far. +func (mt *MockT) ExpectNoErrors(t *testing.T) { + t.Helper() + assert.DeepEqual(t, "collected errors", mt.Errors, []string(nil)) +} diff --git a/must/must.go b/must/must.go index 109e75c..b84eda9 100644 --- a/must/must.go +++ b/must/must.go @@ -5,7 +5,11 @@ // errors without the need for excessive "if err != nil". package must -import "github.com/sapcc/go-bits/logg" +import ( + "testing" + + "github.com/sapcc/go-bits/logg" +) // Succeed logs a fatal error and terminates the program if the given error is // non-nil. For example, the following: @@ -26,6 +30,14 @@ func Succeed(err error) { } } +// SucceedT is a variant of Succeed() for use in unit tests. +// Instead of exiting the program, any non-nil errors are reported with t.Fatal(). +func SucceedT(t *testing.T, err error) { + if err != nil { + t.Fatal(err.Error()) + } +} + // Return is like Succeed(), except that it propagates a result value on success. // This can be chained with functions returning a pair of result value and error // if errors are considered fatal. For example, the following: @@ -38,7 +50,31 @@ func Succeed(err error) { // can be shortened to: // // buf := must.Return(os.ReadFile("config.ini")) -func Return[T any](val T, err error) T { +func Return[V any](val V, err error) V { Succeed(err) return val } + +// ReturnT is a variant of Return() for use in unit tests. +// Instead of exiting the program, any non-nil errors are reported with t.Fatal(). +// For example: +// +// buf := must.ReturnT(os.ReadFile("config.ini"))(t) +func ReturnT[V any](val V, err error) func(*testing.T) V { + // NOTE: This is the only function signature that works. We cannot do something like + // + // myMust := must.WithT(t) + // buf := myMust.Return(os.ReadFile("config.ini")) + // + // because then the type argument V would have to be introduced within a method of typeof(myMust), + // but Go generics do not allow introducing new type arguments in methods. We also cannot do something like + // + // buf := must.ReturnT(t, os.ReadFile("config.ini")) + // + // because filling multiple arguments using a call expression with multiple return values + // is only allowed when there are no other arguments. + return func(t *testing.T) V { + SucceedT(t, err) + return val + } +} diff --git a/osext/env_test.go b/osext/env_test.go index ecaaef8..dd35e3b 100644 --- a/osext/env_test.go +++ b/osext/env_test.go @@ -23,7 +23,7 @@ func TestGetenv(t *testing.T) { str, err := osext.NeedGetenv(KEY) assert.Equal(t, str, VAL) - assert.Equal(t, err, nil) + assert.ErrEqual(t, err, nil) str = osext.GetenvOrDefault(KEY, DEFAULT) assert.Equal(t, str, VAL) @@ -35,7 +35,7 @@ func TestGetenv(t *testing.T) { t.Setenv(KEY, "") _, err = osext.NeedGetenv(KEY) - assert.Equal(t, err, error(osext.MissingEnvError{Key: KEY})) + assert.ErrEqual(t, err, osext.MissingEnvError{Key: KEY}) str = osext.GetenvOrDefault(KEY, DEFAULT) assert.Equal(t, str, DEFAULT) @@ -47,7 +47,7 @@ func TestGetenv(t *testing.T) { os.Unsetenv(KEY) _, err = osext.NeedGetenv(KEY) - assert.Equal(t, err, error(osext.MissingEnvError{Key: KEY})) + assert.ErrEqual(t, err, osext.MissingEnvError{Key: KEY}) str = osext.GetenvOrDefault(KEY, DEFAULT) assert.Equal(t, str, DEFAULT)