Skip to content

Commit

Permalink
fix: handle concurrent transactional errors in the refresh token gran…
Browse files Browse the repository at this point in the history
…t handler (#402)

This commit provides the functionality required to address ory/hydra#1719 & ory/hydra#1735 by adding error checking to the RefreshTokenGrantHandler's PopulateTokenEndpointResponse method so it can deal with errors due to concurrent access. This will allow the authorization server to render a better error to the user-agent.

No longer returns fosite.ErrServerError in the event the storage. Instead a wrapped fosite.ErrNotFound is returned when fetching the refresh token fails due to it no longer being present. This scenario is caused when the user sends two or more request to refresh using the same token and one request gets into the handler just after the prior request finished and successfully committed its transaction.

Adds unit test coverage for transaction error handling logic added to the RefreshTokenGrantHandler's PopulateTokenEndpointResponse method
  • Loading branch information
aaslamin committed Mar 25, 2020
1 parent f99bb80 commit b17190b
Show file tree
Hide file tree
Showing 3 changed files with 205 additions and 25 deletions.
5 changes: 4 additions & 1 deletion errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ var (
// ErrInvalidatedAuthorizeCode is an error indicating that an authorization code has been
// used previously.
ErrInvalidatedAuthorizeCode = errors.New("Authorization code has ben invalidated")
ErrUnknownRequest = &RFC6749Error{
// ErrSerializationFailure is an error indicating that the transactional capable storage could not guarantee
// consistency of Update & Delete operations on the same rows between multiple sessions.
ErrSerializationFailure = errors.New("The request could not be completed due to concurrent access")
ErrUnknownRequest = &RFC6749Error{
Name: errUnknownErrorName,
Description: "The handler is not responsible for this request",
Code: http.StatusBadRequest,
Expand Down
56 changes: 32 additions & 24 deletions handler/oauth2/flow_refresh.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,9 @@ import (
"strings"
"time"

"github.com/ory/fosite"
"github.com/ory/fosite/storage"

"github.com/pkg/errors"

"github.com/ory/fosite"
)

type RefreshTokenGrantHandler struct {
Expand Down Expand Up @@ -139,34 +137,22 @@ func (c *RefreshTokenGrantHandler) PopulateTokenEndpointResponse(ctx context.Con

ts, err := c.TokenRevocationStorage.GetRefreshTokenSession(ctx, signature, nil)
if err != nil {
if rollBackTxnErr := storage.MaybeRollbackTx(ctx, c.TokenRevocationStorage); rollBackTxnErr != nil {
err = rollBackTxnErr
}
return errors.WithStack(fosite.ErrServerError.WithDebug(err.Error()))
return handleRefreshTokenEndpointResponseStorageError(ctx, c.TokenRevocationStorage, err)
} else if err := c.TokenRevocationStorage.RevokeAccessToken(ctx, ts.GetID()); err != nil {
if rollBackTxnErr := storage.MaybeRollbackTx(ctx, c.TokenRevocationStorage); rollBackTxnErr != nil {
err = rollBackTxnErr
}
return errors.WithStack(fosite.ErrServerError.WithDebug(err.Error()))
return handleRefreshTokenEndpointResponseStorageError(ctx, c.TokenRevocationStorage, err)
} else if err := c.TokenRevocationStorage.RevokeRefreshToken(ctx, ts.GetID()); err != nil {
if rollBackTxnErr := storage.MaybeRollbackTx(ctx, c.TokenRevocationStorage); rollBackTxnErr != nil {
err = rollBackTxnErr
}
return errors.WithStack(fosite.ErrServerError.WithDebug(err.Error()))
return handleRefreshTokenEndpointResponseStorageError(ctx, c.TokenRevocationStorage, err)
}

storeReq := requester.Sanitize([]string{})
storeReq.SetID(ts.GetID())

if err := c.TokenRevocationStorage.CreateAccessTokenSession(ctx, accessSignature, storeReq); err != nil {
if rollBackTxnErr := storage.MaybeRollbackTx(ctx, c.TokenRevocationStorage); rollBackTxnErr != nil {
err = rollBackTxnErr
}
return errors.WithStack(fosite.ErrServerError.WithDebug(err.Error()))
} else if err := c.TokenRevocationStorage.CreateRefreshTokenSession(ctx, refreshSignature, storeReq); err != nil {
if rollBackTxnErr := storage.MaybeRollbackTx(ctx, c.TokenRevocationStorage); rollBackTxnErr != nil {
err = rollBackTxnErr
}
return errors.WithStack(fosite.ErrServerError.WithDebug(err.Error()))
return handleRefreshTokenEndpointResponseStorageError(ctx, c.TokenRevocationStorage, err)
}

if err := c.TokenRevocationStorage.CreateRefreshTokenSession(ctx, refreshSignature, storeReq); err != nil {
return handleRefreshTokenEndpointResponseStorageError(ctx, c.TokenRevocationStorage, err)
}

responder.SetAccessToken(accessToken)
Expand All @@ -181,3 +167,25 @@ func (c *RefreshTokenGrantHandler) PopulateTokenEndpointResponse(ctx context.Con

return nil
}

func handleRefreshTokenEndpointResponseStorageError(ctx context.Context, store TokenRevocationStorage, storageErr error) (err error) {
defer func() {
if rbErr := storage.MaybeRollbackTx(ctx, store); rbErr != nil {
err = errors.WithStack(fosite.ErrServerError.WithDebug(rbErr.Error()))
}
}()

if errors.Cause(storageErr) == fosite.ErrSerializationFailure {
return errors.WithStack(fosite.ErrInvalidRequest.
WithDebugf(storageErr.Error()).
WithHint("Failed to refresh token because of multiple concurrent requests using the same token which is not allowed."))
}

if errors.Cause(storageErr) == fosite.ErrNotFound {
return errors.WithStack(fosite.ErrInvalidRequest.
WithDebugf(storageErr.Error()).
WithHint("Failed to refresh token because of multiple concurrent requests using the same token which is not allowed."))
}

return errors.WithStack(fosite.ErrServerError.WithDebug(storageErr.Error()))
}
169 changes: 169 additions & 0 deletions handler/oauth2/flow_refresh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,29 @@ func TestRefreshFlowTransactional_PopulateTokenEndpointResponse(t *testing.T) {
},
{
description: "transaction should be rolled back if call to `GetRefreshTokenSession` results in an error",
setup: func() {
request.GrantTypes = fosite.Arguments{"refresh_token"}
mockTransactional.
EXPECT().
BeginTX(propagatedContext).
Return(propagatedContext, nil).
Times(1)
mockRevocationStore.
EXPECT().
GetRefreshTokenSession(propagatedContext, gomock.Any(), nil).
Return(nil, errors.New("Whoops, a nasty database error occurred!")).
Times(1)
mockTransactional.
EXPECT().
Rollback(propagatedContext).
Return(nil).
Times(1)
},
expectError: fosite.ErrServerError,
},
{
description: "should result in a fosite.ErrInvalidRequest if `GetRefreshTokenSession` results in a " +
"fosite.ErrNotFound error",
setup: func() {
request.GrantTypes = fosite.Arguments{"refresh_token"}
mockTransactional.
Expand All @@ -424,6 +447,7 @@ func TestRefreshFlowTransactional_PopulateTokenEndpointResponse(t *testing.T) {
Return(nil).
Times(1)
},
expectError: fosite.ErrInvalidRequest,
},
{
description: "transaction should be rolled back if call to `RevokeAccessToken` results in an error",
Expand All @@ -450,6 +474,35 @@ func TestRefreshFlowTransactional_PopulateTokenEndpointResponse(t *testing.T) {
Return(nil).
Times(1)
},
expectError: fosite.ErrServerError,
},
{
description: "should result in a fosite.ErrInvalidRequest if call to `RevokeAccessToken` results in a " +
"fosite.ErrSerializationFailure error",
setup: func() {
request.GrantTypes = fosite.Arguments{"refresh_token"}
mockTransactional.
EXPECT().
BeginTX(propagatedContext).
Return(propagatedContext, nil).
Times(1)
mockRevocationStore.
EXPECT().
GetRefreshTokenSession(propagatedContext, gomock.Any(), nil).
Return(request, nil).
Times(1)
mockRevocationStore.
EXPECT().
RevokeAccessToken(propagatedContext, gomock.Any()).
Return(fosite.ErrSerializationFailure).
Times(1)
mockTransactional.
EXPECT().
Rollback(propagatedContext).
Return(nil).
Times(1)
},
expectError: fosite.ErrInvalidRequest,
},
{
description: "transaction should be rolled back if call to `RevokeRefreshToken` results in an error",
Expand Down Expand Up @@ -481,6 +534,77 @@ func TestRefreshFlowTransactional_PopulateTokenEndpointResponse(t *testing.T) {
Return(nil).
Times(1)
},
expectError: fosite.ErrServerError,
},
{
description: "should result in a fosite.ErrInvalidRequest if call to `RevokeRefreshToken` results in a " +
"fosite.ErrSerializationFailure error",
setup: func() {
request.GrantTypes = fosite.Arguments{"refresh_token"}
mockTransactional.
EXPECT().
BeginTX(propagatedContext).
Return(propagatedContext, nil).
Times(1)
mockRevocationStore.
EXPECT().
GetRefreshTokenSession(propagatedContext, gomock.Any(), nil).
Return(request, nil).
Times(1)
mockRevocationStore.
EXPECT().
RevokeAccessToken(propagatedContext, gomock.Any()).
Return(nil).
Times(1)
mockRevocationStore.
EXPECT().
RevokeRefreshToken(propagatedContext, gomock.Any()).
Return(fosite.ErrSerializationFailure).
Times(1)
mockTransactional.
EXPECT().
Rollback(propagatedContext).
Return(nil).
Times(1)
},
expectError: fosite.ErrInvalidRequest,
},
{
description: "should result in a fosite.ErrInvalidRequest if call to `CreateAccessTokenSession` results in " +
"a fosite.ErrSerializationFailure error",
setup: func() {
mockTransactional.
EXPECT().
BeginTX(propagatedContext).
Return(propagatedContext, nil).
Times(1)
mockRevocationStore.
EXPECT().
GetRefreshTokenSession(propagatedContext, gomock.Any(), nil).
Return(request, nil).
Times(1)
mockRevocationStore.
EXPECT().
RevokeAccessToken(propagatedContext, gomock.Any()).
Return(nil).
Times(1)
mockRevocationStore.
EXPECT().
RevokeRefreshToken(propagatedContext, gomock.Any()).
Return(nil).
Times(1)
mockRevocationStore.
EXPECT().
CreateAccessTokenSession(propagatedContext, gomock.Any(), gomock.Any()).
Return(fosite.ErrSerializationFailure).
Times(1)
mockTransactional.
EXPECT().
Rollback(propagatedContext).
Return(nil).
Times(1)
},
expectError: fosite.ErrInvalidRequest,
},
{
description: "transaction should be rolled back if call to `CreateAccessTokenSession` results in an error",
Expand Down Expand Up @@ -516,6 +640,7 @@ func TestRefreshFlowTransactional_PopulateTokenEndpointResponse(t *testing.T) {
Return(nil).
Times(1)
},
expectError: fosite.ErrServerError,
},
{
description: "transaction should be rolled back if call to `CreateRefreshTokenSession` results in an error",
Expand Down Expand Up @@ -557,6 +682,50 @@ func TestRefreshFlowTransactional_PopulateTokenEndpointResponse(t *testing.T) {
Return(nil).
Times(1)
},
expectError: fosite.ErrServerError,
},
{
description: "should result in a fosite.ErrInvalidRequest if call to `CreateRefreshTokenSession` results in " +
"a fosite.ErrSerializationFailure error",
setup: func() {
request.GrantTypes = fosite.Arguments{"refresh_token"}
mockTransactional.
EXPECT().
BeginTX(propagatedContext).
Return(propagatedContext, nil).
Times(1)
mockRevocationStore.
EXPECT().
GetRefreshTokenSession(propagatedContext, gomock.Any(), nil).
Return(request, nil).
Times(1)
mockRevocationStore.
EXPECT().
RevokeAccessToken(propagatedContext, gomock.Any()).
Return(nil).
Times(1)
mockRevocationStore.
EXPECT().
RevokeRefreshToken(propagatedContext, gomock.Any()).
Return(nil).
Times(1)
mockRevocationStore.
EXPECT().
CreateAccessTokenSession(propagatedContext, gomock.Any(), gomock.Any()).
Return(nil).
Times(1)
mockRevocationStore.
EXPECT().
CreateRefreshTokenSession(propagatedContext, gomock.Any(), gomock.Any()).
Return(fosite.ErrSerializationFailure).
Times(1)
mockTransactional.
EXPECT().
Rollback(propagatedContext).
Return(nil).
Times(1)
},
expectError: fosite.ErrInvalidRequest,
},
{
description: "should result in a server error if transaction cannot be created",
Expand Down

0 comments on commit b17190b

Please sign in to comment.