Skip to content

Commit

Permalink
cmdutil: Support errors.Join-based multi-errors
Browse files Browse the repository at this point in the history
cmdutil has some special handling for hashicorp/go-multierror
so that multi-errors are printed cleanly in the form:

    %d errors occurred:
        1) foo
        2) bar
        ...

In Go 1.20, the errors package got a native `errors.Join` function.
This adds support for errors.Join-based multi-errors to this logic.

These errors implement an `Unwrap() []error` method
which can be used to access the full list of errors.
We use that and then implement the same logic for formatting as before.
  • Loading branch information
abhinav committed Aug 29, 2023
1 parent 25901d9 commit fea8fc5
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 8 deletions.
@@ -0,0 +1,4 @@
changes:
- type: chore
scope: sdk/go
description: Support multi-errors built from errors.Join for RunFunc, Exit, and friends.
37 changes: 29 additions & 8 deletions sdk/go/common/util/cmdutil/exit.go
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/spf13/cobra"

"github.com/pulumi/pulumi/sdk/v3/go/common/diag"
"github.com/pulumi/pulumi/sdk/v3/go/common/util/contract"
"github.com/pulumi/pulumi/sdk/v3/go/common/util/logging"
"github.com/pulumi/pulumi/sdk/v3/go/common/util/result"
)
Expand Down Expand Up @@ -160,18 +161,38 @@ func exitErrorCodef(code int, format string, args ...interface{}) {
os.Exit(code)
}

type stdMultiError interface {
Unwrap() []error
}

// errorMessage returns a message, possibly cleaning up the text if appropriate.
func errorMessage(err error) string {
if multi, ok := err.(*multierror.Error); ok {
wr := multi.WrappedErrors()
if len(wr) == 1 {
return errorMessage(wr[0])
}
msg := fmt.Sprintf("%d errors occurred:", len(wr))
for i, werr := range wr {
contract.Requiref(err != nil, "err", "must not be nil")

var underlying []error
switch multi := err.(type) {
case *multierror.Error:
underlying = multi.WrappedErrors()
case stdMultiError:
underlying = multi.Unwrap()
default:
return err.Error()
}

switch len(underlying) {
case 0:
// This should never happen, but just in case.
// Return the original error message.
return err.Error()

case 1:
return errorMessage(underlying[0])

default:
msg := fmt.Sprintf("%d errors occurred:", len(underlying))
for i, werr := range underlying {
msg += fmt.Sprintf("\n %d) %s", i+1, errorMessage(werr))
}
return msg
}
return err.Error()
}
98 changes: 98 additions & 0 deletions sdk/go/common/util/cmdutil/exit_test.go
Expand Up @@ -2,11 +2,13 @@ package cmdutil

import (
"bytes"
"errors"
"io"
"os"
"os/exec"
"testing"

"github.com/hashicorp/go-multierror"
"github.com/pulumi/pulumi/sdk/v3/go/common/testing/iotest"
"github.com/pulumi/pulumi/sdk/v3/go/common/util/result"
"github.com/spf13/cobra"
Expand Down Expand Up @@ -68,3 +70,99 @@ func TestFakeCommand(t *testing.T) {
// Unreachable: RunFunc should have called os.Exit.
assert.Fail(t, "unreachable", "RunFunc should have called os.Exit: %v", err)
}

func TestErrorMessage(t *testing.T) {
t.Parallel()

tests := []struct {
desc string
give error
want string
}{
{
desc: "simple error",
give: errors.New("great sadness"),
want: "great sadness",
},
{
desc: "hashi multi error",
give: multierror.Append(
errors.New("foo"),
errors.New("bar"),
errors.New("baz"),
),
want: "3 errors occurred:" +
"\n 1) foo" +
"\n 2) bar" +
"\n 3) baz",
},
{
desc: "std errors.Join",
give: errors.Join(
errors.New("foo"),
errors.New("bar"),
errors.New("baz"),
),
want: "3 errors occurred:" +
"\n 1) foo" +
"\n 2) bar" +
"\n 3) baz",
},
{
desc: "empty multi error",
// This is technically invalid,
// but we guard against it,
// so let's test it too.
give: &invalidEmptyMultiError{},
want: "invalid empty multi error",
},
{
desc: "single wrapped error",
give: &multierror.Error{
Errors: []error{
errors.New("great sadness"),
},
},
want: "great sadness",
},
{
desc: "multi error inside single wrapped error",
give: &multierror.Error{
Errors: []error{
errors.Join(
errors.New("foo"),
errors.New("bar"),
errors.New("baz"),
),
},
},
want: "3 errors occurred:" +
"\n 1) foo" +
"\n 2) bar" +
"\n 3) baz",
},
}

for _, tt := range tests {
tt := tt
t.Run(tt.desc, func(t *testing.T) {
t.Parallel()

got := errorMessage(tt.give)
assert.Equal(t, tt.want, got)
})
}
}

// invalidEmptyMultiError is an invalid error type
// that implements Unwrap() []error, but returns an empty slice.
// This is invalid per the contract for that method.
type invalidEmptyMultiError struct{}

func (*invalidEmptyMultiError) Error() string {
return "invalid empty multi error"
}

func (*invalidEmptyMultiError) Unwrap() []error {
return []error{} // invalid
}

0 comments on commit fea8fc5

Please sign in to comment.