Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Errors now wrap underlying errors #479

Merged
merged 1 commit into from
Sep 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion access_request_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func (f *Fosite) NewAccessRequest(ctx context.Context, r *http.Request, session
if r.Method != "POST" {
return accessRequest, errors.WithStack(ErrInvalidRequest.WithHintf("HTTP method is \"%s\", expected \"POST\".", r.Method))
} else if err := r.ParseMultipartForm(1 << 20); err != nil && err != http.ErrNotMultipart {
return accessRequest, errors.WithStack(ErrInvalidRequest.WithHint("Unable to parse HTTP body, make sure to send a properly formatted form request body.").WithDebug(err.Error()))
return accessRequest, errors.WithStack(ErrInvalidRequest.WithHint("Unable to parse HTTP body, make sure to send a properly formatted form request body.").WithCause(err).WithDebug(err.Error()))
} else if len(r.PostForm) == 0 {
return accessRequest, errors.WithStack(ErrInvalidRequest.WithHint("The POST body can not be empty."))
}
Expand Down
3 changes: 1 addition & 2 deletions access_response_writer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ import (
"testing"

"github.com/golang/mock/gomock"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

Expand Down Expand Up @@ -96,7 +95,7 @@ func TestNewAccessResponse(t *testing.T) {
ar, err := f.NewAccessResponse(nil, nil)

if c.expectErr != nil {
assert.EqualError(t, errors.Cause(err), c.expectErr.Error())
assert.EqualError(t, err, c.expectErr.Error())
} else {
require.NoError(t, err)
assert.Equal(t, ar, c.expect)
Expand Down
4 changes: 2 additions & 2 deletions audience_strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@ func DefaultAudienceMatchingStrategy(haystack []string, needle []string) error {
for _, n := range needle {
nu, err := url.Parse(n)
if err != nil {
return errors.WithStack(ErrInvalidRequest.WithHintf(`Unable to parse requested audience "%s".`, n).WithDebug(err.Error()))
return errors.WithStack(ErrInvalidRequest.WithHintf(`Unable to parse requested audience "%s".`, n).WithCause(err).WithDebug(err.Error()))
}

var found bool
for _, h := range haystack {
hu, err := url.Parse(h)
if err != nil {
return errors.WithStack(ErrInvalidRequest.WithHintf(`Unable to parse whitelisted audience "%s".`, h).WithDebug(err.Error()))
return errors.WithStack(ErrInvalidRequest.WithHintf(`Unable to parse whitelisted audience "%s".`, h).WithCause(err).WithDebug(err.Error()))
}

allowedPath := strings.TrimRight(hu.Path, "/")
Expand Down
2 changes: 1 addition & 1 deletion authorize_error.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func (f *Fosite) WriteAuthorizeError(rw http.ResponseWriter, ar AuthorizeRequest
query.Add("state", ar.GetState())

var redirectURIString string
if !(len(ar.GetResponseTypes()) == 0 || ar.GetResponseTypes().ExactOne("code")) && errors.Cause(err) != ErrUnsupportedResponseType {
if !(len(ar.GetResponseTypes()) == 0 || ar.GetResponseTypes().ExactOne("code")) && !errors.Is(err, ErrUnsupportedResponseType) {
redirectURIString = redirectURI.String() + "#" + query.Encode()
} else {
for key, values := range redirectURI.Query() {
Expand Down
2 changes: 1 addition & 1 deletion authorize_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func GetRedirectURIFromRequestValues(values url.Values) (string, error) {
// The endpoint URI MAY include an "application/x-www-form-urlencoded" formatted (per Appendix B) query component
redirectURI, err := url.QueryUnescape(values.Get("redirect_uri"))
if err != nil {
return "", errors.WithStack(ErrInvalidRequest.WithHint(`The "redirect_uri" parameter is malformed or missing.`).WithDebug(err.Error()))
return "", errors.WithStack(ErrInvalidRequest.WithHint(`The "redirect_uri" parameter is malformed or missing.`).WithCause(err).WithDebug(err.Error()))
}
return redirectURI, nil
}
Expand Down
11 changes: 6 additions & 5 deletions authorize_request_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,15 +126,16 @@ func (f *Fosite) authorizeRequestParametersFromOpenIDConnectRequest(request *Aut
})
if err != nil {
// Do not re-process already enhanced errors
if e, ok := errors.Cause(err).(*jwt.ValidationError); ok {
var e *jwt.ValidationError
if errors.As(err, &e) {
if e.Inner != nil {
return e.Inner
}
return errors.WithStack(ErrInvalidRequestObject.WithHintf("Unable to verify the request object's signature.").WithDebug(err.Error()))
return errors.WithStack(ErrInvalidRequestObject.WithHintf("Unable to verify the request object's signature.").WithCause(err).WithDebug(err.Error()))
}
return err
} else if err := token.Claims.Valid(); err != nil {
return errors.WithStack(ErrInvalidRequestObject.WithHint("Unable to verify the request object because its claims could not be validated, check if the expiry time is set correctly.").WithDebug(err.Error()))
return errors.WithStack(ErrInvalidRequestObject.WithHint("Unable to verify the request object because its claims could not be validated, check if the expiry time is set correctly.").WithCause(err).WithDebug(err.Error()))
}

claims, ok := token.Claims.(*jwt.MapClaims)
Expand Down Expand Up @@ -222,7 +223,7 @@ func (f *Fosite) NewAuthorizeRequest(ctx context.Context, r *http.Request) (Auth
}

if err := r.ParseMultipartForm(1 << 20); err != nil && err != http.ErrNotMultipart {
return request, errors.WithStack(ErrInvalidRequest.WithHint("Unable to parse HTTP body, make sure to send a properly formatted form request body.").WithDebug(err.Error()))
return request, errors.WithStack(ErrInvalidRequest.WithHint("Unable to parse HTTP body, make sure to send a properly formatted form request body.").WithCause(err).WithDebug(err.Error()))
}

request.Form = r.Form
Expand All @@ -233,7 +234,7 @@ func (f *Fosite) NewAuthorizeRequest(ctx context.Context, r *http.Request) (Auth

client, err := f.Store.GetClient(ctx, request.GetRequestForm().Get("client_id"))
if err != nil {
return request, errors.WithStack(ErrInvalidClient.WithHint("The requested OAuth 2.0 Client does not exist.").WithDebug(err.Error()))
return request, errors.WithStack(ErrInvalidClient.WithHint("The requested OAuth 2.0 Client does not exist.").WithCause(err).WithDebug(err.Error()))
}
request.Client = client

Expand Down
2 changes: 1 addition & 1 deletion authorize_request_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ func TestNewAuthorizeRequest(t *testing.T) {

ar, err := c.conf.NewAuthorizeRequest(context.Background(), c.r)
if c.expectedError != nil {
assert.EqualError(t, errors.Cause(err), c.expectedError.Error())
assert.EqualError(t, err, c.expectedError.Error())
// https://github.com/ory/hydra/issues/1642
AssertObjectKeysEqual(t, &AuthorizeRequest{State: c.query.Get("state")}, ar, "State")
} else {
Expand Down
17 changes: 9 additions & 8 deletions client_authentication.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func (f *Fosite) AuthenticateClient(ctx context.Context, r *http.Request, form u

client, err = f.Store.GetClient(ctx, clientID)
if err != nil {
return nil, errors.WithStack(ErrInvalidClient.WithDebug(err.Error()))
return nil, errors.WithStack(ErrInvalidClient.WithCause(err).WithDebug(err.Error()))
}

oidcClient, ok := client.(OpenIDConnectClient)
Expand Down Expand Up @@ -134,15 +134,16 @@ func (f *Fosite) AuthenticateClient(ctx context.Context, r *http.Request, form u
})
if err != nil {
// Do not re-process already enhanced errors
if e, ok := errors.Cause(err).(*jwt.ValidationError); ok {
var e *jwt.ValidationError
if errors.As(err, &e) {
if e.Inner != nil {
return nil, e.Inner
}
return nil, errors.WithStack(ErrInvalidClient.WithHint("Unable to verify the integrity of the \"client_assertion\" value.").WithDebug(err.Error()))
return nil, errors.WithStack(ErrInvalidClient.WithHint("Unable to verify the integrity of the \"client_assertion\" value.").WithCause(err).WithDebug(err.Error()))
}
return nil, err
} else if err := token.Claims.Valid(); err != nil {
return nil, errors.WithStack(ErrInvalidClient.WithHint("Unable to verify the request object because its claims could not be validated, check if the expiry time is set correctly.").WithDebug(err.Error()))
return nil, errors.WithStack(ErrInvalidClient.WithHint("Unable to verify the request object because its claims could not be validated, check if the expiry time is set correctly.").WithCause(err).WithDebug(err.Error()))
}

claims, ok := token.Claims.(*jwt.MapClaims)
Expand Down Expand Up @@ -212,7 +213,7 @@ func (f *Fosite) AuthenticateClient(ctx context.Context, r *http.Request, form u

client, err := f.Store.GetClient(ctx, clientID)
if err != nil {
return nil, errors.WithStack(ErrInvalidClient.WithDebug(err.Error()))
return nil, errors.WithStack(ErrInvalidClient.WithCause(err).WithDebug(err.Error()))
}

if oidcClient, ok := client.(OpenIDConnectClient); !ok {
Expand All @@ -231,7 +232,7 @@ func (f *Fosite) AuthenticateClient(ctx context.Context, r *http.Request, form u

// Enforce client authentication
if err := f.Hasher.Compare(ctx, client.GetHashedSecret(), []byte(clientSecret)); err != nil {
return nil, errors.WithStack(ErrInvalidClient.WithDebug(err.Error()))
return nil, errors.WithStack(ErrInvalidClient.WithCause(err).WithDebug(err.Error()))
}

return client, nil
Expand Down Expand Up @@ -274,9 +275,9 @@ func clientCredentialsFromRequest(r *http.Request, form url.Values) (clientID, c
if id, secret, ok := r.BasicAuth(); !ok {
return clientCredentialsFromRequestBody(form, true)
} else if clientID, err = url.QueryUnescape(id); err != nil {
return "", "", errors.WithStack(ErrInvalidRequest.WithHint(`The client id in the HTTP authorization header could not be decoded from "application/x-www-form-urlencoded".`).WithDebug(err.Error()))
return "", "", errors.WithStack(ErrInvalidRequest.WithHint(`The client id in the HTTP authorization header could not be decoded from "application/x-www-form-urlencoded".`).WithCause(err).WithDebug(err.Error()))
} else if clientSecret, err = url.QueryUnescape(secret); err != nil {
return "", "", errors.WithStack(ErrInvalidRequest.WithHint(`The client secret in the HTTP authorization header could not be decoded from "application/x-www-form-urlencoded".`).WithDebug(err.Error()))
return "", "", errors.WithStack(ErrInvalidRequest.WithHint(`The client secret in the HTTP authorization header could not be decoded from "application/x-www-form-urlencoded".`).WithCause(err).WithDebug(err.Error()))
}

return clientID, clientSecret, nil
Expand Down
11 changes: 6 additions & 5 deletions client_authentication_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -497,11 +497,12 @@ func TestAuthenticateClient(t *testing.T) {
}

if err != nil {
switch e := errors.Cause(err).(type) {
case *jwt.ValidationError:
t.Logf("Error is: %s", e.Inner)
case *RFC6749Error:
t.Logf("Debug is: %s", e.Debug)
var validationError *jwt.ValidationError
var rfcError *RFC6749Error
if errors.As(err, &validationError) {
t.Logf("Error is: %s", validationError.Inner)
} else if errors.As(err, &rfcError) {
t.Logf("Debug is: %s", rfcError.Debug)
}
}
require.NoError(t, err)
Expand Down
21 changes: 18 additions & 3 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,16 +253,16 @@ const (
)

func ErrorToRFC6749Error(err error) *RFC6749Error {
if e, ok := err.(*RFC6749Error); ok {
return e
} else if e, ok := errors.Cause(err).(*RFC6749Error); ok {
var e *RFC6749Error
if errors.As(err, &e) {
return e
}
return &RFC6749Error{
Name: errUnknownErrorName,
Description: "The error is unrecognizable.",
Debug: err.Error(),
Code: http.StatusInternalServerError,
cause: err,
}
}

Expand All @@ -272,6 +272,7 @@ type RFC6749Error struct {
Hint string
Code int
Debug string
cause error
}

func (e *RFC6749Error) Status() string {
Expand All @@ -294,6 +295,14 @@ func (e *RFC6749Error) StatusCode() int {
return e.Code
}

func (e *RFC6749Error) Cause() error {
return e.cause
}

func (e *RFC6749Error) Unwrap() error {
return e.cause
}

func (e *RFC6749Error) WithHintf(hint string, args ...interface{}) *RFC6749Error {
return e.WithHint(fmt.Sprintf(hint, args...))
}
Expand All @@ -320,6 +329,12 @@ func (e *RFC6749Error) WithDescription(description string) *RFC6749Error {
return &err
}

func (e *RFC6749Error) WithCause(cause error) *RFC6749Error {
err := *e
err.cause = cause
return &err
}

func (e *RFC6749Error) Sanitize() *RFC6749Error {
err := *e
err.Debug = ""
Expand Down
4 changes: 2 additions & 2 deletions handler/oauth2/flow_authorize_code_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,12 @@ func (c *AuthorizeExplicitGrantHandler) HandleAuthorizeEndpointRequest(ctx conte
func (c *AuthorizeExplicitGrantHandler) IssueAuthorizeCode(ctx context.Context, ar fosite.AuthorizeRequester, resp fosite.AuthorizeResponder) error {
code, signature, err := c.AuthorizeCodeStrategy.GenerateAuthorizeCode(ctx, ar)
if err != nil {
return errors.WithStack(fosite.ErrServerError.WithDebug(err.Error()))
return errors.WithStack(fosite.ErrServerError.WithCause(err).WithDebug(err.Error()))
}

ar.GetSession().SetExpiresAt(fosite.AuthorizeCode, time.Now().UTC().Add(c.AuthCodeLifespan))
if err := c.CoreStorage.CreateAuthorizeCodeSession(ctx, signature, ar.Sanitize(c.GetSanitationWhiteList())); err != nil {
return errors.WithStack(fosite.ErrServerError.WithDebug(err.Error()))
return errors.WithStack(fosite.ErrServerError.WithCause(err).WithDebug(err.Error()))
}

resp.AddQuery("code", code)
Expand Down
3 changes: 1 addition & 2 deletions handler/oauth2/flow_authorize_code_auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ import (
"testing"
"time"

"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

Expand Down Expand Up @@ -135,7 +134,7 @@ func TestAuthorizeCode_HandleAuthorizeEndpointRequest(t *testing.T) {
aresp := fosite.NewAuthorizeResponse()
err := h.HandleAuthorizeEndpointRequest(nil, c.areq, aresp)
if c.expectErr != nil {
require.EqualError(t, errors.Cause(err), c.expectErr.Error())
require.EqualError(t, err, c.expectErr.Error())
} else {
require.NoError(t, err)
}
Expand Down
28 changes: 14 additions & 14 deletions handler/oauth2/flow_authorize_code_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func (c *AuthorizeExplicitGrantHandler) HandleTokenEndpointRequest(ctx context.C
code := request.GetRequestForm().Get("code")
signature := c.AuthorizeCodeStrategy.AuthorizeCodeSignature(code)
authorizeRequest, err := c.CoreStorage.GetAuthorizeCodeSession(ctx, signature, request.GetSession())
if errors.Cause(err) == fosite.ErrInvalidatedAuthorizeCode {
if errors.Is(err, fosite.ErrInvalidatedAuthorizeCode) {
if authorizeRequest == nil {
return fosite.ErrServerError.
WithHint("Misconfigured code lead to an error that prohibited the OAuth 2.0 Framework from processing this request.").
Expand All @@ -68,16 +68,16 @@ func (c *AuthorizeExplicitGrantHandler) HandleTokenEndpointRequest(ctx context.C
debug += "Revocation of refresh_token lead to error " + revErr.Error() + "."
}
return errors.WithStack(fosite.ErrInvalidGrant.WithHint(hint).WithDebug(debug))
} else if err != nil && errors.Cause(err).Error() == fosite.ErrNotFound.Error() {
return errors.WithStack(fosite.ErrInvalidGrant.WithDebug(err.Error()))
} else if err != nil && errors.Is(err, fosite.ErrNotFound) {
return errors.WithStack(fosite.ErrInvalidGrant.WithCause(err).WithDebug(err.Error()))
} else if err != nil {
return errors.WithStack(fosite.ErrServerError.WithDebug(err.Error()))
return errors.WithStack(fosite.ErrServerError.WithCause(err).WithDebug(err.Error()))
}

// The authorization server MUST verify that the authorization code is valid
// This needs to happen after store retrieval for the session to be hydrated properly
if err := c.AuthorizeCodeStrategy.ValidateAuthorizeCode(ctx, request, code); err != nil {
return errors.WithStack(fosite.ErrInvalidGrant.WithDebug(err.Error()))
return errors.WithStack(fosite.ErrInvalidGrant.WithCause(err).WithDebug(err.Error()))
}

// Override scopes
Expand Down Expand Up @@ -141,10 +141,10 @@ func (c *AuthorizeExplicitGrantHandler) PopulateTokenEndpointResponse(ctx contex
signature := c.AuthorizeCodeStrategy.AuthorizeCodeSignature(code)
authorizeRequest, err := c.CoreStorage.GetAuthorizeCodeSession(ctx, signature, requester.GetSession())
if err != nil {
return errors.WithStack(fosite.ErrServerError.WithDebug(err.Error()))
return errors.WithStack(fosite.ErrServerError.WithCause(err).WithDebug(err.Error()))
} else if err := c.AuthorizeCodeStrategy.ValidateAuthorizeCode(ctx, requester, code); err != nil {
// This needs to happen after store retrieval for the session to be hydrated properly
return errors.WithStack(fosite.ErrInvalidRequest.WithDebug(err.Error()))
return errors.WithStack(fosite.ErrInvalidRequest.WithCause(err).WithDebug(err.Error()))
}

for _, scope := range authorizeRequest.GetGrantedScopes() {
Expand All @@ -157,38 +157,38 @@ func (c *AuthorizeExplicitGrantHandler) PopulateTokenEndpointResponse(ctx contex

access, accessSignature, err := c.AccessTokenStrategy.GenerateAccessToken(ctx, requester)
if err != nil {
return errors.WithStack(fosite.ErrServerError.WithDebug(err.Error()))
return errors.WithStack(fosite.ErrServerError.WithCause(err).WithDebug(err.Error()))
}

var refresh, refreshSignature string
if canIssueRefreshToken(c, authorizeRequest) {
refresh, refreshSignature, err = c.RefreshTokenStrategy.GenerateRefreshToken(ctx, requester)
if err != nil {
return errors.WithStack(fosite.ErrServerError.WithDebug(err.Error()))
return errors.WithStack(fosite.ErrServerError.WithCause(err).WithDebug(err.Error()))
}
}

ctx, err = storage.MaybeBeginTx(ctx, c.CoreStorage)
if err != nil {
return errors.WithStack(fosite.ErrServerError.WithDebug(err.Error()))
return errors.WithStack(fosite.ErrServerError.WithCause(err).WithDebug(err.Error()))
}

if err := c.CoreStorage.InvalidateAuthorizeCodeSession(ctx, signature); err != nil {
if rollBackTxnErr := storage.MaybeRollbackTx(ctx, c.CoreStorage); rollBackTxnErr != nil {
err = rollBackTxnErr
}
return errors.WithStack(fosite.ErrServerError.WithDebug(err.Error()))
return errors.WithStack(fosite.ErrServerError.WithCause(err).WithDebug(err.Error()))
} else if err := c.CoreStorage.CreateAccessTokenSession(ctx, accessSignature, requester.Sanitize([]string{})); err != nil {
if rollBackTxnErr := storage.MaybeRollbackTx(ctx, c.CoreStorage); rollBackTxnErr != nil {
err = rollBackTxnErr
}
return errors.WithStack(fosite.ErrServerError.WithDebug(err.Error()))
return errors.WithStack(fosite.ErrServerError.WithCause(err).WithDebug(err.Error()))
} else if refreshSignature != "" {
if err := c.CoreStorage.CreateRefreshTokenSession(ctx, refreshSignature, requester.Sanitize([]string{})); err != nil {
if rollBackTxnErr := storage.MaybeRollbackTx(ctx, c.CoreStorage); rollBackTxnErr != nil {
err = rollBackTxnErr
}
return errors.WithStack(fosite.ErrServerError.WithDebug(err.Error()))
return errors.WithStack(fosite.ErrServerError.WithCause(err).WithDebug(err.Error()))
}
}

Expand All @@ -201,7 +201,7 @@ func (c *AuthorizeExplicitGrantHandler) PopulateTokenEndpointResponse(ctx contex
}

if err := storage.MaybeCommitTx(ctx, c.CoreStorage); err != nil {
return errors.WithStack(fosite.ErrServerError.WithDebug(err.Error()))
return errors.WithStack(fosite.ErrServerError.WithCause(err).WithDebug(err.Error()))
}

return nil
Expand Down
6 changes: 3 additions & 3 deletions handler/oauth2/flow_authorize_code_token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ func TestAuthorizeCode_PopulateTokenEndpointResponse(t *testing.T) {
err := h.PopulateTokenEndpointResponse(nil, c.areq, aresp)

if c.expectErr != nil {
require.EqualError(t, errors.Cause(err), c.expectErr.Error(), "%+v", err)
require.EqualError(t, err, c.expectErr.Error(), "%+v", err)
} else {
require.NoError(t, err, "%+v", err)
}
Expand Down Expand Up @@ -443,7 +443,7 @@ func TestAuthorizeCode_HandleTokenEndpointRequest(t *testing.T) {

err := h.HandleTokenEndpointRequest(context.Background(), c.areq)
if c.expectErr != nil {
require.EqualError(t, errors.Cause(err), c.expectErr.Error(), "%+v", err)
require.EqualError(t, err, c.expectErr.Error(), "%+v", err)
} else {
require.NoError(t, err, "%+v", err)
if c.check != nil {
Expand Down Expand Up @@ -675,7 +675,7 @@ func TestAuthorizeCodeTransactional_HandleTokenEndpointRequest(t *testing.T) {
}

if err := handler.PopulateTokenEndpointResponse(propagatedContext, request, response); testCase.expectError != nil {
assert.EqualError(t, errors.Cause(err), testCase.expectError.Error())
assert.EqualError(t, err, testCase.expectError.Error())
}
})
}
Expand Down
Loading