Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions assert/assert.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
package assert

import (
"errors"
"fmt"
"reflect"
"regexp"
"testing"

"github.com/sergi/go-diff/diffmatchpatch"
Expand All @@ -23,6 +25,89 @@ func Equal[V comparable](t TestingT, actual, expected V) bool {
return false
}

// ErrEqual checks if the actual error matches the expectation.
//
// - If `expected` is nil, the actual error must be nil.
// - If `expected` is of type error, the actual error must be exactly equal to it, or contain it in the sense of errors.Is().
// - If `expected` is of type string, the actual error message must be exactly equal to it.
// - If `expected` is of type *regexp.Regexp, that regexp must match the actual error message.
func ErrEqual(t TestingT, actual error, expectedErrorOrMessageOrRegexp any) bool {
// NOTE 1: We cannot enumerate the possible types of `expected` as a type argument of the form
// func ErrEqual[T interface{ error | string | *regexp.Regexp }](...)
// because unions of interface types (error) and concrete types (string etc.) are not permitted.
// The risk of accepting an `any` value is acceptable here because the panic from
// using an unexpected type can only occur in tests, and thus will be difficult to overlook.
//
// NOTE 2: The verbose name of the last argument is intended to help users
// who see only the function signature in their IDE autocomplete.
t.Helper()

switch expected := expectedErrorOrMessageOrRegexp.(type) {
case nil:
if actual == nil {
return true
}
t.Errorf("expected success, but got error: %s", actual.Error())
return false

case error:
if actual == nil {
if expected == nil {
// defense in depth: this should have been covered by the previous case branch
return true
}
t.Errorf("expected error stack to contain %q, but got no error", expected.Error())
return false
}
switch {
case expected == nil:
// defense in depth: this should have been covered by the previous case branch
t.Errorf("expected success, but got error: %s", actual.Error())
return false
case errors.Is(actual, expected):
return true
default:
t.Errorf("expected error stack to contain %q, but got error: %s", expected.Error(), actual.Error())
return false
}

case string:
if actual == nil {
if expected == "" {
return true
}
t.Errorf("expected error with message %q, but got no error", expected)
return false
}
msg := actual.Error()
switch expected {
case "":
t.Errorf("expected success, but got error: %s", msg)
return false
case msg:
return true
default:
t.Errorf("expected error with message %q, but got error: %s", expected, msg)
return false
}

case *regexp.Regexp:
if actual == nil {
t.Errorf("expected error with message matching /%s/, but got no error", expected.String())
return false
}
msg := actual.Error()
if expected.MatchString(msg) {
return true
}
t.Errorf("expected error with message matching /%s/, but got error: %s", expected.String(), msg)
return false

default:
panic(fmt.Sprintf("assert.ErrEqual() cannot match against an expectation of type %T", expected))
}
}

// DeepEqual checks if the actual and expected value are equal as
// determined by reflect.DeepEqual(), and t.Error()s otherwise.
func DeepEqual[V any](t *testing.T, variable string, actual, expected V) bool {
Expand Down
74 changes: 74 additions & 0 deletions assert/assert_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// SPDX-FileCopyrightText: 2025 SAP SE or an SAP affiliate company
// SPDX-License-Identifier: Apache-2.0

package assert_test

import (
"errors"
"fmt"
"regexp"
"testing"

"github.com/sapcc/go-bits/assert"
"github.com/sapcc/go-bits/internal/testutil"
)

func TestErrEqual(t *testing.T) {
var (
actual error
mock = &testutil.MockT{}
)
checkPasses := func(expected any) {
t.Helper()
ok := assert.ErrEqual(mock, actual, expected)
assert.Equal(t, ok, true)
mock.ExpectNoErrors(t)
}
checkFails := func(expected any, message string) {
t.Helper()
ok := assert.ErrEqual(mock, actual, expected)
assert.Equal(t, ok, false)
mock.ExpectErrors(t, message)
}

// some helper errors for below
errFoo := errors.New("wrong foo supplied")
errBar := fmt.Errorf("could not connect to bar: %w", errFoo)
errQux := errors.New("found no relation from qux to foo/bar")

// check assertions for when the actual error is nil
actual = nil
checkPasses(nil)
checkPasses(error(nil))
checkFails(errFoo, `expected error stack to contain "wrong foo supplied", but got no error`)
checkPasses("")
checkFails("datacenter on fire", `expected error with message "datacenter on fire", but got no error`)
checkFails(regexp.MustCompile(`.*`), `expected error with message matching /.*/, but got no error`)

// check assertions with a simple error
actual = errFoo
checkFails(nil, `expected success, but got error: wrong foo supplied`)
checkFails(error(nil), `expected success, but got error: wrong foo supplied`)
checkPasses(errFoo)
checkFails(errBar, `expected error stack to contain "could not connect to bar: wrong foo supplied", but got error: wrong foo supplied`)
checkFails(errQux, `expected error stack to contain "found no relation from qux to foo/bar", but got error: wrong foo supplied`)
checkFails("", `expected success, but got error: wrong foo supplied`)
checkPasses("wrong foo supplied")
checkFails("datacenter on fire", `expected error with message "datacenter on fire", but got error: wrong foo supplied`)
checkPasses(regexp.MustCompile(`wrong fo* supplied`))
checkFails(regexp.MustCompile(`connect to bar`), `expected error with message matching /connect to bar/, but got error: wrong foo supplied`)

// check assertions with an error stack
actual = errBar
checkFails(nil, `expected success, but got error: could not connect to bar: wrong foo supplied`)
checkFails(error(nil), `expected success, but got error: could not connect to bar: wrong foo supplied`)
checkPasses(errFoo) // both with the contained error...
checkPasses(errBar) // ...as well as with the full error
checkFails(errQux, `expected error stack to contain "found no relation from qux to foo/bar", but got error: could not connect to bar: wrong foo supplied`)
checkFails("", `expected success, but got error: could not connect to bar: wrong foo supplied`)
checkFails("wrong foo supplied", `expected error with message "wrong foo supplied", but got error: could not connect to bar: wrong foo supplied`)
checkPasses("could not connect to bar: wrong foo supplied")
checkFails("datacenter on fire", `expected error with message "datacenter on fire", but got error: could not connect to bar: wrong foo supplied`)
checkPasses(regexp.MustCompile(`wrong fo* supplied`))
checkPasses(regexp.MustCompile(`connect to bar`))
}
37 changes: 8 additions & 29 deletions httptest/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ package httptest_test

import (
"bytes"
"fmt"
"io"
"net/http"
"os"
Expand All @@ -17,6 +16,7 @@ import (

"github.com/sapcc/go-bits/assert"
"github.com/sapcc/go-bits/httptest"
"github.com/sapcc/go-bits/internal/testutil"
"github.com/sapcc/go-bits/must"
)

Expand Down Expand Up @@ -45,7 +45,7 @@ var exampleHandler http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r

func TestRespondTo(t *testing.T) {
h := httptest.NewHandler(exampleHandler)
ctx := t.Context() // TODO: use t.Context() in Go 1.24+
ctx := t.Context()

// most basic invocation
resp := h.RespondTo(ctx, "POST /reflect")
Expand Down Expand Up @@ -147,22 +147,18 @@ func TestRespondTo(t *testing.T) {
})

// check how ExpectJSON() reports an unexpected status code
mock := &mockTestingT{}
mock := &testutil.MockT{}
h.RespondTo(ctx, "POST /reflect",
httptest.WithBody(strings.NewReader(`{"foo":23,"bar":42}`)),
).ExpectJSON(mock, http.StatusNotFound, jsonmatch.Object{})
assert.DeepEqual(t, "collected errors", mock.Errors, []string{
`expected HTTP status 404, but got 200 (body was "{\"foo\":23,\"bar\":42}")`,
})
mock.ExpectErrors(t, `expected HTTP status 404, but got 200 (body was "{\"foo\":23,\"bar\":42}")`)

// check how ExpectJSON() reports diffs without Pointer
mock.Errors = nil
h.RespondTo(ctx, "POST /reflect",
httptest.WithBody(strings.NewReader(`{"foo":23,"bar":42}`)),
).ExpectJSON(mock, http.StatusOK, jsonmatch.Scalar(true))
assert.DeepEqual(t, "collected errors", mock.Errors, []string{
`type mismatch: expected true, but got {"bar":42,"foo":23}`,
})
mock.ExpectErrors(t, `type mismatch: expected true, but got {"bar":42,"foo":23}`)

// check how ExpectJSON() reports diffs with Pointer
mock.Errors = nil
Expand All @@ -172,9 +168,7 @@ func TestRespondTo(t *testing.T) {
"foo": 23,
"bar": 45,
})
assert.DeepEqual(t, "collected errors", mock.Errors, []string{
`value mismatch at /bar: expected 45, but got 42`,
})
mock.ExpectErrors(t, `value mismatch at /bar: expected 45, but got 42`)

// check ExpectText()
h.RespondTo(ctx, "POST /reflect",
Expand All @@ -186,32 +180,17 @@ func TestRespondTo(t *testing.T) {
h.RespondTo(ctx, "POST /reflect",
httptest.WithBody(strings.NewReader("hello")),
).ExpectText(mock, http.StatusNotFound, "hello")
assert.DeepEqual(t, "collected errors", mock.Errors, []string{
`expected HTTP status 404, but got 200 (body was "hello")`,
})
mock.ExpectErrors(t, `expected HTTP status 404, but got 200 (body was "hello")`)

// check how ExpectText() reports an unexpected response body
mock.Errors = nil
h.RespondTo(ctx, "POST /reflect",
httptest.WithBody(strings.NewReader("hello")),
).ExpectText(mock, http.StatusOK, "world")
assert.DeepEqual(t, "collected errors", mock.Errors, []string{
`expected "world", but got "hello"`,
})
mock.ExpectErrors(t, `expected "world", but got "hello"`)

// check ExpectBodyAsInFixture()
h.RespondTo(ctx, "POST /reflect",
httptest.WithBody(bytes.NewReader(must.Return(os.ReadFile("fixtures/example.txt")))),
).ExpectBodyAsInFixture(t, http.StatusOK, "fixtures/example.txt")
}

// A mock for *testing.T.
type mockTestingT struct {
Errors []string
}

func (t *mockTestingT) Helper() {}

func (t *mockTestingT) Errorf(msg string, args ...any) {
t.Errors = append(t.Errors, fmt.Sprintf(msg, args...))
}
36 changes: 36 additions & 0 deletions internal/testutil/mockt.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// SPDX-FileCopyrightText: 2025 SAP SE or an SAP affiliate company
// SPDX-License-Identifier: Apache-2.0

package testutil

import (
"fmt"
"testing"

"github.com/sapcc/go-bits/assert"
)

// A mock for *testing.T that implements the assert.TestingT interface.
type MockT struct {
Errors []string
}

func (mt *MockT) Helper() {}

func (mt *MockT) Errorf(msg string, args ...any) {
mt.Errors = append(mt.Errors, fmt.Sprintf(msg, args...))
}

// ExpectErrors asserts on the errors collected so far,
// and then clears out the list of collected errors for the next subtest.
func (mt *MockT) ExpectErrors(t *testing.T, expected ...string) {
t.Helper()
assert.DeepEqual(t, "collected errors", mt.Errors, expected)
mt.Errors = nil
}

// ExpectErrors asserts that no errors were collected so far.
func (mt *MockT) ExpectNoErrors(t *testing.T) {
t.Helper()
assert.DeepEqual(t, "collected errors", mt.Errors, []string(nil))
}
40 changes: 38 additions & 2 deletions must/must.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
// errors without the need for excessive "if err != nil".
package must

import "github.com/sapcc/go-bits/logg"
import (
"testing"

"github.com/sapcc/go-bits/logg"
)

// Succeed logs a fatal error and terminates the program if the given error is
// non-nil. For example, the following:
Expand All @@ -26,6 +30,14 @@ func Succeed(err error) {
}
}

// SucceedT is a variant of Succeed() for use in unit tests.
// Instead of exiting the program, any non-nil errors are reported with t.Fatal().
func SucceedT(t *testing.T, err error) {
if err != nil {
t.Fatal(err.Error())
}
}

// Return is like Succeed(), except that it propagates a result value on success.
// This can be chained with functions returning a pair of result value and error
// if errors are considered fatal. For example, the following:
Expand All @@ -38,7 +50,31 @@ func Succeed(err error) {
// can be shortened to:
//
// buf := must.Return(os.ReadFile("config.ini"))
func Return[T any](val T, err error) T {
func Return[V any](val V, err error) V {
Succeed(err)
return val
}

// ReturnT is a variant of Return() for use in unit tests.
// Instead of exiting the program, any non-nil errors are reported with t.Fatal().
// For example:
//
// buf := must.ReturnT(os.ReadFile("config.ini"))(t)
func ReturnT[V any](val V, err error) func(*testing.T) V {
// NOTE: This is the only function signature that works. We cannot do something like
//
// myMust := must.WithT(t)
// buf := myMust.Return(os.ReadFile("config.ini"))
//
// because then the type argument V would have to be introduced within a method of typeof(myMust),
// but Go generics do not allow introducing new type arguments in methods. We also cannot do something like
//
// buf := must.ReturnT(t, os.ReadFile("config.ini"))
//
// because filling multiple arguments using a call expression with multiple return values
// is only allowed when there are no other arguments.
return func(t *testing.T) V {
SucceedT(t, err)
return val
}
}
6 changes: 3 additions & 3 deletions osext/env_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func TestGetenv(t *testing.T) {

str, err := osext.NeedGetenv(KEY)
assert.Equal(t, str, VAL)
assert.Equal(t, err, nil)
assert.ErrEqual(t, err, nil)

str = osext.GetenvOrDefault(KEY, DEFAULT)
assert.Equal(t, str, VAL)
Expand All @@ -35,7 +35,7 @@ func TestGetenv(t *testing.T) {
t.Setenv(KEY, "")

_, err = osext.NeedGetenv(KEY)
assert.Equal(t, err, error(osext.MissingEnvError{Key: KEY}))
assert.ErrEqual(t, err, osext.MissingEnvError{Key: KEY})

str = osext.GetenvOrDefault(KEY, DEFAULT)
assert.Equal(t, str, DEFAULT)
Expand All @@ -47,7 +47,7 @@ func TestGetenv(t *testing.T) {
os.Unsetenv(KEY)

_, err = osext.NeedGetenv(KEY)
assert.Equal(t, err, error(osext.MissingEnvError{Key: KEY}))
assert.ErrEqual(t, err, osext.MissingEnvError{Key: KEY})

str = osext.GetenvOrDefault(KEY, DEFAULT)
assert.Equal(t, str, DEFAULT)
Expand Down
Loading