Skip to content

Commit

Permalink
0.3.1 (#98)
Browse files Browse the repository at this point in the history
* all: better error handling - closes #100 
* oauth2/implicit: bad HTML encoding of the scope parameter - closes #95
* oauth2: state parameter is missing when response_type=id_token - closes #96
* oauth2: id token hashes are not base64 url encoded - closes #97
* openid: hybrid flow using `token+code+id_token` returns multiple tokens of the same type - closes #99
  • Loading branch information
arekkas committed Sep 22, 2016
1 parent cb328ca commit b16e3fc
Show file tree
Hide file tree
Showing 12 changed files with 150 additions and 28 deletions.
89 changes: 86 additions & 3 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,16 @@ var (
ErrInsufficientEntropy = errors.Errorf("The request used a security parameter (e.g., anti-replay, anti-csrf) with insufficient entropy (minimum of %d characters)", MinParameterEntropy)
ErrMisconfiguration = errors.New("The request failed because of an internal error that is probably caused by misconfiguration")
ErrNotFound = errors.New("Could not find the requested resource(s)")
ErrInvalidTokenFormat = errors.New("Invalid token format")
ErrTokenSignatureMismatch = errors.New("Token signature mismatch")
ErrTokenExpired = errors.New("Token expired")
ErrScopeNotGranted = errors.New("The token was not granted the requested scope")
ErrTokenClaim = errors.New("The token failed validation due to a claim mismatch")
)

const (
errRequestUnauthorized = "request_unauthorized"
errRequestForbidden = "request_forbidden"
errInvalidRequestName = "invalid_request"
errUnauthorizedClientName = "unauthorized_client"
errAccessDeniedName = "acccess_denied"
Expand All @@ -36,10 +43,16 @@ const (
errUnsupportedGrantTypeName = "unsupported_grant_type"
errInvalidGrantName = "invalid_grant"
errInvalidClientName = "invalid_client"
errInvalidError = "invalid_error"
UnknownErrorName = "unknown_error"
errNotFound = "not_found"
errInvalidState = "invalid_state"
errMisconfiguration = "misconfiguration"
errInsufficientEntropy = "insufficient_entropy"
errInvalidTokenFormat = "invalid_token"
errTokenSignatureMismatch = "token_signature_mismatch"
errTokenExpired = "token_expired"
errScopeNotGranted = "scope_not_granted"
errTokenClaim = "token_claim"
)

type RFC6749Error struct {
Expand All @@ -52,6 +65,69 @@ type RFC6749Error struct {

func ErrorToRFC6749Error(err error) *RFC6749Error {
switch errors.Cause(err) {
case ErrTokenClaim: {
return &RFC6749Error{
Name: errTokenClaim,
Description: ErrTokenClaim.Error(),
Debug: err.Error(),
Hint: "One or more token claims failed validation.",
StatusCode: http.StatusUnauthorized,
}
}
case ErrScopeNotGranted: {
return &RFC6749Error{
Name: errScopeNotGranted,
Description: ErrScopeNotGranted.Error(),
Debug: err.Error(),
Hint: "The resource owner did not grant the requested scope.",
StatusCode: http.StatusForbidden,
}
}
case ErrTokenExpired: {
return &RFC6749Error{
Name: errTokenExpired,
Description: ErrTokenExpired.Error(),
Debug: err.Error(),
Hint: "The token expired.",
StatusCode: http.StatusUnauthorized,
}
}
case ErrInvalidTokenFormat: {
return &RFC6749Error{
Name: errInvalidTokenFormat,
Description: ErrInvalidTokenFormat.Error(),
Debug: err.Error(),
Hint: "Check that you provided a valid token in the right format.",
StatusCode: http.StatusBadRequest,
}
}
case ErrTokenSignatureMismatch: {
return &RFC6749Error{
Name: errTokenSignatureMismatch,
Description: ErrTokenSignatureMismatch.Error(),
Debug: err.Error(),
Hint: "Check that you provided a valid token in the right format.",
StatusCode: http.StatusBadRequest,
}
}
case ErrRequestUnauthorized: {
return &RFC6749Error{
Name: errRequestUnauthorized,
Description: ErrRequestUnauthorized.Error(),
Debug: err.Error(),
Hint: "Check that you provided valid credentials in the right format.",
StatusCode: http.StatusUnauthorized,
}
}
case ErrRequestForbidden: {
return &RFC6749Error{
Name: errRequestForbidden,
Description: ErrRequestForbidden.Error(),
Debug: err.Error(),
Hint: "You are not allowed to perform this action.",
StatusCode: http.StatusForbidden,
}
}
case ErrInvalidRequest:
return &RFC6749Error{
Name: errInvalidRequestName,
Expand Down Expand Up @@ -123,7 +199,7 @@ func ErrorToRFC6749Error(err error) *RFC6749Error {
Name: errInvalidClientName,
Description: ErrInvalidClient.Error(),
Debug: err.Error(),
StatusCode: http.StatusBadRequest,
StatusCode: http.StatusUnauthorized,
}
case ErrInvalidState:
return &RFC6749Error{
Expand All @@ -146,9 +222,16 @@ func ErrorToRFC6749Error(err error) *RFC6749Error {
Debug: err.Error(),
StatusCode: http.StatusInternalServerError,
}
case ErrNotFound:
return &RFC6749Error{
Name: errNotFound,
Description: ErrNotFound.Error(),
Debug: err.Error(),
StatusCode: http.StatusNotFound,
}
default:
return &RFC6749Error{
Name: errInvalidError,
Name: UnknownErrorName,
Description: "The error is unrecognizable.",
Debug: err.Error(),
StatusCode: http.StatusInternalServerError,
Expand Down
4 changes: 2 additions & 2 deletions errors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ import (
)

func TestErrorToRFC6749(t *testing.T) {
assert.Equal(t, errInvalidError, ErrorToRFC6749Error(errors.New("")).Name)
assert.Equal(t, errInvalidError, ErrorToRFC6749Error(native.New("")).Name)
assert.Equal(t, UnknownErrorName, ErrorToRFC6749Error(errors.New("")).Name)
assert.Equal(t, UnknownErrorName, ErrorToRFC6749Error(native.New("")).Name)

assert.Equal(t, errInvalidRequestName, ErrorToRFC6749Error(errors.Wrap(ErrInvalidRequest, "")).Name)
assert.Equal(t, errUnauthorizedClientName, ErrorToRFC6749Error(errors.Wrap(ErrUnauthorizedClient, "")).Name)
Expand Down
2 changes: 1 addition & 1 deletion handler/oauth2/flow_authorize_implicit.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func (c *AuthorizeImplicitGrantTypeHandler) IssueImplicitAccessToken(ctx context
resp.AddFragment("expires_in", strconv.Itoa(int(c.AccessTokenLifespan/time.Second)))
resp.AddFragment("token_type", "bearer")
resp.AddFragment("state", ar.GetState())
resp.AddFragment("scope", strings.Join(ar.GetGrantedScopes(), "+"))
resp.AddFragment("scope", strings.Join(ar.GetGrantedScopes(), " "))
ar.SetResponseTypeHandled("token")

return nil
Expand Down
9 changes: 5 additions & 4 deletions handler/oauth2/strategy_hmacsha.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
enigma "github.com/ory-am/fosite/token/hmac"
"github.com/pkg/errors"
"golang.org/x/net/context"
"fmt"
)

type HMACSHAStrategy struct {
Expand All @@ -32,9 +33,9 @@ func (h HMACSHAStrategy) GenerateAccessToken(_ context.Context, _ fosite.Request

func (h HMACSHAStrategy) ValidateAccessToken(_ context.Context, r fosite.Requester, token string) (err error) {
if session, ok := r.GetSession().(HMACSessionContainer); !ok {
return errors.Errorf("Session must be of type HMACSessionContainer, got: %s", reflect.TypeOf(r.GetSession()))
return errors.Wrap(fosite.ErrMisconfiguration, fmt.Sprintf("Session must be of type HMACSessionContainer, got: %s", reflect.TypeOf(r.GetSession())))
} else if session.AccessTokenExpiresAt(r.GetRequestedAt().Add(h.AccessTokenLifespan)).Before(time.Now()) {
return errors.Errorf("Access token expired at %s", session.AccessTokenExpiresAt(r.GetRequestedAt().Add(h.AccessTokenLifespan)))
return errors.Wrap(fosite.ErrTokenExpired, fmt.Sprintf("Access token expired at %s", session.AccessTokenExpiresAt(r.GetRequestedAt().Add(h.AccessTokenLifespan))))
}
return h.Enigma.Validate(token)
}
Expand All @@ -53,9 +54,9 @@ func (h HMACSHAStrategy) GenerateAuthorizeCode(_ context.Context, _ fosite.Reque

func (h HMACSHAStrategy) ValidateAuthorizeCode(_ context.Context, r fosite.Requester, token string) (err error) {
if session, ok := r.GetSession().(HMACSessionContainer); !ok {
return errors.Errorf("Session must be of type HMACSessionContainer, got: %s", reflect.TypeOf(r.GetSession()))
return errors.Wrap(fosite.ErrMisconfiguration, fmt.Sprintf("Session must be of type HMACSessionContainer, got: %s", reflect.TypeOf(r.GetSession())))
} else if session.AuthorizeCodeExpiresAt(r.GetRequestedAt().Add(h.AuthorizeCodeLifespan)).Before(time.Now()) {
return errors.Errorf("Authorize code expired at %s", session.AuthorizeCodeExpiresAt(r.GetRequestedAt().Add(h.AccessTokenLifespan)))
return errors.Wrap(fosite.ErrTokenExpired, fmt.Sprintf("Authorize code expired at %s", session.AuthorizeCodeExpiresAt(r.GetRequestedAt().Add(h.AccessTokenLifespan))))
}

return h.Enigma.Validate(token)
Expand Down
27 changes: 26 additions & 1 deletion handler/oauth2/strategy_jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (

"github.com/ory-am/fosite"
"github.com/ory-am/fosite/token/jwt"
jwtx "github.com/dgrijalva/jwt-go"
"github.com/pkg/errors"
"golang.org/x/net/context"
)
Expand Down Expand Up @@ -67,7 +68,31 @@ func (h *RS256JWTStrategy) validate(token string) error {

// validate the token
if err = t.Claims.Valid(); err != nil {
return errors.New("Token claims did not validate")
if e, ok := err.(*jwtx.ValidationError); ok {
switch e.Errors {
case jwtx.ValidationErrorMalformed:
return errors.Wrap(fosite.ErrInvalidTokenFormat, err.Error())
case jwtx.ValidationErrorUnverifiable:
return errors.Wrap(fosite.ErrTokenSignatureMismatch, err.Error())
case jwtx.ValidationErrorSignatureInvalid:
return errors.Wrap(fosite.ErrTokenSignatureMismatch, err.Error())
case jwtx.ValidationErrorAudience:
return errors.Wrap(fosite.ErrTokenClaim, err.Error())
case jwtx.ValidationErrorExpired:
return errors.Wrap(fosite.ErrTokenExpired, err.Error())
case jwtx.ValidationErrorIssuedAt:
return errors.Wrap(fosite.ErrTokenClaim, err.Error())
case jwtx.ValidationErrorIssuer :
return errors.Wrap(fosite.ErrTokenClaim, err.Error())
case jwtx.ValidationErrorNotValidYet:
return errors.Wrap(fosite.ErrTokenClaim, err.Error())
case jwtx.ValidationErrorId:
return errors.Wrap(fosite.ErrTokenClaim, err.Error())
case jwtx.ValidationErrorClaimsInvalid :
return errors.Wrap(fosite.ErrTokenClaim, err.Error())
}
return errors.Wrap(fosite.ErrRequestUnauthorized, err.Error())
}
}

return nil
Expand Down
8 changes: 4 additions & 4 deletions handler/oauth2/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func (c *CoreValidator) validateAccessToken(ctx context.Context, token string, a
if err != nil {
return errors.Wrap(fosite.ErrRequestUnauthorized, err.Error())
} else if err := c.CoreStrategy.ValidateAccessToken(ctx, or, token); err != nil {
return errors.Wrap(fosite.ErrRequestUnauthorized, err.Error())
return err
}

for _, scope := range scopes {
Expand All @@ -49,7 +49,7 @@ func (c *CoreValidator) validateRefreshToken(ctx context.Context, token string,
if or, err := c.CoreStorage.GetAccessTokenSession(ctx, sig, accessRequest.GetSession()); err != nil {
return errors.Wrap(fosite.ErrRequestUnauthorized, err.Error())
} else if err := c.CoreStrategy.ValidateAccessToken(ctx, or, token); err != nil {
return errors.Wrap(fosite.ErrRequestUnauthorized, err.Error())
return err
} else {
accessRequest.Merge(or)
}
Expand All @@ -60,9 +60,9 @@ func (c *CoreValidator) validateRefreshToken(ctx context.Context, token string,
func (c *CoreValidator) validateAuthorizeCode(ctx context.Context, token string, accessRequest fosite.AccessRequester) error {
sig := c.CoreStrategy.AccessTokenSignature(token)
if or, err := c.CoreStorage.GetAccessTokenSession(ctx, sig, accessRequest.GetSession()); err != nil {
return errors.Wrap(fosite.ErrRequestUnauthorized, err.Error())
return errors.Wrap(err, fosite.ErrRequestUnauthorized.Error())
} else if err := c.CoreStrategy.ValidateAccessToken(ctx, or, token); err != nil {
return errors.Wrap(fosite.ErrRequestUnauthorized, err.Error())
return err
} else {
accessRequest.Merge(or)
}
Expand Down
4 changes: 2 additions & 2 deletions handler/oauth2/validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ func TestValidateToken(t *testing.T) {
description: "should fail because validation fails",
setup: func() {
store.EXPECT().GetAccessTokenSession(nil, "asdf", nil).AnyTimes().Return(areq, nil)
chgen.EXPECT().ValidateAccessToken(nil, areq, "1234").Return(errors.New(""))
chgen.EXPECT().ValidateAccessToken(nil, areq, "1234").Return(errors.Wrap(fosite.ErrTokenExpired, ""))
},
expectErr: fosite.ErrRequestUnauthorized,
expectErr: fosite.ErrTokenExpired,
},
{
description: "should pass",
Expand Down
12 changes: 7 additions & 5 deletions handler/openid/flow_hybrid.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/ory-am/fosite/token/jwt"
"github.com/pkg/errors"
"golang.org/x/net/context"
"encoding/base64"
)

type OpenIDConnectHybridHandler struct {
Expand Down Expand Up @@ -62,14 +63,13 @@ func (c *OpenIDConnectHybridHandler) HandleAuthorizeEndpointRequest(ctx context.
}

resp.AddFragment("code", code)
resp.AddFragment("state", ar.GetState())
ar.SetResponseTypeHandled("code")

hash, err := c.Enigma.Hash([]byte(resp.GetFragment().Get("code")))
if err != nil {
return err
}
claims.CodeHash = hash[:c.Enigma.GetSigningMethodLength()/2]
claims.CodeHash = []byte(base64.URLEncoding.EncodeToString([]byte(hash[:c.Enigma.GetSigningMethodLength()/2])))
}

if ar.GetResponseTypes().Has("token") {
Expand All @@ -84,7 +84,7 @@ func (c *OpenIDConnectHybridHandler) HandleAuthorizeEndpointRequest(ctx context.
if err != nil {
return err
}
claims.AccessTokenHash = hash[:c.Enigma.GetSigningMethodLength()/2]
claims.AccessTokenHash = []byte(base64.URLEncoding.EncodeToString([]byte(hash[:c.Enigma.GetSigningMethodLength()/2])))
}

if !ar.GetGrantedScopes().Has("openid") {
Expand All @@ -93,13 +93,15 @@ func (c *OpenIDConnectHybridHandler) HandleAuthorizeEndpointRequest(ctx context.

if err := c.IDTokenHandleHelper.IssueImplicitIDToken(ctx, req, ar, resp); err != nil {
return errors.Wrap(err, err.Error())
} else if err := c.IDTokenHandleHelper.IssueImplicitIDToken(ctx, req, ar, resp); err != nil {
return errors.Wrap(err, err.Error())
}

// there is no need to check for https, because implicit flow does not require https
// https://tools.ietf.org/html/rfc6819#section-4.4.2

if resp.GetFragment().Get("state") == "" {
resp.AddFragment("state", ar.GetState())
}

ar.SetResponseTypeHandled("id_token")
return nil
}
8 changes: 7 additions & 1 deletion handler/openid/flow_implicit.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/ory-am/fosite/token/jwt"
"github.com/pkg/errors"
"golang.org/x/net/context"
"encoding/base64"
)

type OpenIDConnectImplicitHandler struct {
Expand All @@ -23,6 +24,9 @@ type OpenIDConnectImplicitHandler struct {
func (c *OpenIDConnectImplicitHandler) HandleAuthorizeEndpointRequest(ctx context.Context, req *http.Request, ar fosite.AuthorizeRequester, resp fosite.AuthorizeResponder) error {
if !(ar.GetGrantedScopes().Has("openid") && (ar.GetResponseTypes().Has("token", "id_token") || ar.GetResponseTypes().Exact("id_token"))) {
return nil
} else if ar.GetResponseTypes().Has("code") {
// hybrid flow
return nil
}

if !ar.GetClient().GetGrantTypes().Has("implicit") {
Expand Down Expand Up @@ -59,7 +63,9 @@ func (c *OpenIDConnectImplicitHandler) HandleAuthorizeEndpointRequest(ctx contex
return err
}

claims.AccessTokenHash = hash[:c.RS256JWTStrategy.GetSigningMethodLength()/2]
claims.AccessTokenHash = []byte(base64.URLEncoding.EncodeToString([]byte(hash[:c.RS256JWTStrategy.GetSigningMethodLength()/2])))
} else {
resp.AddFragment("state", ar.GetState())
}

if err := c.IssueImplicitIDToken(ctx, req, ar, resp); err != nil {
Expand Down
4 changes: 4 additions & 0 deletions handler/openid/flow_implicit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ func TestImplicit_HandleAuthorizeEndpointRequest(t *testing.T) {
description: "should not do anything because request requirements are not met",
setup: func() {
areq.ResponseTypes = fosite.Arguments{"id_token"}
areq.State = "foostate"
},
},
{
Expand Down Expand Up @@ -121,6 +122,7 @@ func TestImplicit_HandleAuthorizeEndpointRequest(t *testing.T) {
},
check: func() {
assert.NotEmpty(t, aresp.GetFragment().Get("id_token"))
assert.NotEmpty(t, aresp.GetFragment().Get("state"))
assert.Empty(t, aresp.GetFragment().Get("access_token"))
},
},
Expand All @@ -131,6 +133,7 @@ func TestImplicit_HandleAuthorizeEndpointRequest(t *testing.T) {
},
check: func() {
assert.NotEmpty(t, aresp.GetFragment().Get("id_token"))
assert.NotEmpty(t, aresp.GetFragment().Get("state"))
assert.NotEmpty(t, aresp.GetFragment().Get("access_token"))
},
},
Expand All @@ -142,6 +145,7 @@ func TestImplicit_HandleAuthorizeEndpointRequest(t *testing.T) {
},
check: func() {
assert.NotEmpty(t, aresp.GetFragment().Get("id_token"))
assert.NotEmpty(t, aresp.GetFragment().Get("state"))
assert.NotEmpty(t, aresp.GetFragment().Get("access_token"))
},
},
Expand Down
Loading

0 comments on commit b16e3fc

Please sign in to comment.