diff --git a/errkit/errors_test.go b/errkit/errors_test.go index e63eb63..f83db4a 100644 --- a/errkit/errors_test.go +++ b/errkit/errors_test.go @@ -5,6 +5,7 @@ import ( "github.com/pkg/errors" errorutil "github.com/projectdiscovery/utils/errors" + "github.com/stretchr/testify/require" "go.uber.org/multierr" stderrors "errors" @@ -77,3 +78,30 @@ func TestErrorUtil(t *testing.T) { t.Fatal("expected 3 errors") } } + +func TestErrKindCheck(t *testing.T) { + x := New("port closed or filtered").SetKind(ErrKindNetworkPermanent) + t.Run("Errkind With Normal Error", func(t *testing.T) { + wrapped := Wrap(x, "this is a wrapped error") + if !IsKind(wrapped, ErrKindNetworkPermanent) { + t.Fatal("expected to be able to find the original error") + } + }) + + // mix of multiple kinds + tmp := New("i/o timeout").SetKind(ErrKindNetworkTemporary) + t.Run("Errkind With Multiple Kinds", func(t *testing.T) { + wrapped := Append(x, tmp) + errx := FromError(wrapped) + val, ok := errx.kind.(*multiKind) + require.True(t, ok, "expected to be able to find the original error") + require.Equal(t, 2, len(val.kinds)) + }) + + // duplicate kinds + t.Run("Errkind With Duplicate Kinds", func(t *testing.T) { + wrapped := Append(x, x) + errx := FromError(wrapped) + require.True(t, errx.kind.Is(ErrKindNetworkPermanent), "expected to be able to find the original error") + }) +} diff --git a/errkit/kind.go b/errkit/kind.go index 3c586ca..9e3cd6c 100644 --- a/errkit/kind.go +++ b/errkit/kind.go @@ -28,6 +28,17 @@ var ( ErrKindUnknown = NewPrimitiveErrKind("unknown-error", "unknown error", nil) ) +var ( + // DefaultErrorKinds is the default error kinds used in classification + // if one intends to add more default error kinds it must be done in init() function + // of that package to avoid race conditions + DefaultErrorKinds = []ErrKind{ + ErrKindNetworkTemporary, + ErrKindNetworkPermanent, + ErrKindDeadline, + } +) + // ErrKind is an interface that represents a kind of error type ErrKind interface { // Is checks if current error kind is same as given error kind @@ -110,6 +121,11 @@ func isNetworkPermanentErr(err *ErrorX) bool { return true case strings.Contains(v, "could not resolve host"): return true + case strings.Contains(v, "port closed or filtered"): + // pd standard error for port closed or filtered + return true + case strings.Contains(v, "connect: connection refused"): + return true } return false } @@ -192,7 +208,7 @@ func CombineErrKinds(kind ...ErrKind) ErrKind { f := &multiKind{} uniq := map[ErrKind]struct{}{} for _, k := range kind { - if k == nil { + if k == nil || k.String() == "" { continue } if val, ok := k.(*multiKind); ok { @@ -206,15 +222,27 @@ func CombineErrKinds(kind ...ErrKind) ErrKind { all := maps.Keys(uniq) for _, k := range all { for u := range uniq { - if k.IsParent(u) || k.Is(u) { + if k.IsParent(u) { + delete(uniq, u) + } + } + } + if len(uniq) > 1 { + // check and remove unknown error kind + for k := range uniq { + if k.Is(ErrKindUnknown) { delete(uniq, k) } } } + f.kinds = maps.Keys(uniq) if len(f.kinds) > MaxErrorDepth { f.kinds = f.kinds[:MaxErrorDepth] } + if len(f.kinds) == 1 { + return f.kinds[0] + } return f } @@ -237,6 +265,12 @@ func GetErrorKind(err error, defs ...ErrKind) ErrKind { return def } } + // check in default error kinds + for _, def := range DefaultErrorKinds { + if def.Represents(x) { + return def + } + } return ErrKindUnknown }