Skip to content

Commit

Permalink
fix: modernized JWT stateless introspection (#519)
Browse files Browse the repository at this point in the history
  • Loading branch information
mitar committed Oct 29, 2020
1 parent c747e64 commit a6bfb92
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 59 deletions.
5 changes: 3 additions & 2 deletions compose/compose_oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ package compose

import (
"github.com/ory/fosite/handler/oauth2"
"github.com/ory/fosite/token/jwt"
)

// OAuth2AuthorizeExplicitFactory creates an OAuth2 authorize code grant ("authorize explicit flow") handler and registers
Expand Down Expand Up @@ -132,7 +133,7 @@ func OAuth2TokenIntrospectionFactory(config *Config, storage interface{}, strate
// If you need revocation, you can validate JWTs statefully, using the other factories.
func OAuth2StatelessJWTIntrospectionFactory(config *Config, storage interface{}, strategy interface{}) interface{} {
return &oauth2.StatelessJWTValidator{
JWTAccessTokenStrategy: strategy.(oauth2.JWTAccessTokenStrategy),
ScopeStrategy: config.GetScopeStrategy(),
JWTStrategy: strategy.(jwt.JWTStrategy),
ScopeStrategy: config.GetScopeStrategy(),
}
}
78 changes: 62 additions & 16 deletions handler/oauth2/introspector_jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,38 +23,84 @@ package oauth2

import (
"context"
"time"

"github.com/pkg/errors"
jwtx "github.com/dgrijalva/jwt-go"

"github.com/ory/fosite"
"github.com/ory/fosite/token/jwt"
)

type JWTAccessTokenStrategy interface {
AccessTokenStrategy
JWTStrategy
}

type StatelessJWTValidator struct {
JWTAccessTokenStrategy
jwt.JWTStrategy
ScopeStrategy fosite.ScopeStrategy
}

// AccessTokenJWTToRequest tries to reconstruct fosite.Request from a JWT.
func AccessTokenJWTToRequest(token *jwtx.Token) fosite.Requester {
mapClaims := token.Claims.(jwtx.MapClaims)
claims := jwt.JWTClaims{}
claims.FromMapClaims(mapClaims)

requestedAt := claims.IssuedAt
requestedAtClaim, ok := mapClaims["rat"]
if ok {
switch requestedAtClaim.(type) {
case float64:
requestedAt = time.Unix(int64(requestedAtClaim.(float64)), 0).UTC()
case int64:
requestedAt = time.Unix(requestedAtClaim.(int64), 0).UTC()
}
}

clientId := ""
clientIdClaim, ok := mapClaims["client_id"]
if ok {
switch clientIdClaim.(type) {
case string:
clientId = clientIdClaim.(string)
}
}

return &fosite.Request{
RequestedAt: requestedAt,
Client: &fosite.DefaultClient{
ID: clientId,
},
// We do not really know which scopes were requested, so we set them to granted.
RequestedScope: claims.Scope,
GrantedScope: claims.Scope,
Session: &JWTSession{
JWTClaims: &claims,
JWTHeader: &jwt.Headers{
Extra: token.Header,
},
ExpiresAt: map[fosite.TokenType]time.Time{
fosite.AccessToken: claims.ExpiresAt,
},
Subject: claims.Subject,
},
// We do not really know which audiences were requested, so we set them to granted.
RequestedAudience: claims.Audience,
GrantedAudience: claims.Audience,
}
}

func (v *StatelessJWTValidator) IntrospectToken(ctx context.Context, token string, tokenUse fosite.TokenUse, accessRequest fosite.AccessRequester, scopes []string) (fosite.TokenUse, error) {
or, err := v.JWTAccessTokenStrategy.ValidateJWT(ctx, fosite.AccessToken, token)
t, err := validate(ctx, v.JWTStrategy, token)
if err != nil {
return "", err
}

for _, scope := range scopes {
if scope == "" {
continue
}
// TODO: From here we assume it is an access token, but how do we know it is really and that is not an ID token?

if !v.ScopeStrategy(or.GetGrantedScopes(), scope) {
return "", errors.WithStack(fosite.ErrInvalidScope)
}
requester := AccessTokenJWTToRequest(t)

if err := matchScopes(v.ScopeStrategy, requester.GetGrantedScopes(), scopes); err != nil {
return fosite.AccessToken, err
}

accessRequest.Merge(or)
accessRequest.Merge(requester)

return fosite.AccessToken, nil
}
6 changes: 3 additions & 3 deletions handler/oauth2/introspector_jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ func TestIntrospectJWT(t *testing.T) {
}

v := &StatelessJWTValidator{
JWTAccessTokenStrategy: strat,
ScopeStrategy: fosite.HierarchicScopeStrategy,
JWTStrategy: strat,
ScopeStrategy: fosite.HierarchicScopeStrategy,
}

for k, c := range []struct {
Expand Down Expand Up @@ -137,7 +137,7 @@ func BenchmarkIntrospectJWT(b *testing.B) {
}

v := &StatelessJWTValidator{
JWTAccessTokenStrategy: strat,
JWTStrategy: strat,
}

jwt := jwtValidCase(fosite.AccessToken)
Expand Down
4 changes: 0 additions & 4 deletions handler/oauth2/strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,6 @@ type CoreStrategy interface {
AuthorizeCodeStrategy
}

type JWTStrategy interface {
ValidateJWT(ctx context.Context, tokenType fosite.TokenType, token string) (requester fosite.Requester, err error)
}

type AccessTokenStrategy interface {
AccessTokenSignature(token string) string
GenerateAccessToken(ctx context.Context, requester fosite.Requester) (token string, signature string, err error)
Expand Down
37 changes: 3 additions & 34 deletions handler/oauth2/strategy_jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,41 +69,10 @@ func (h *DefaultJWTStrategy) GenerateAccessToken(ctx context.Context, requester
}

func (h *DefaultJWTStrategy) ValidateAccessToken(ctx context.Context, _ fosite.Requester, token string) error {
_, err := h.validate(ctx, token)
_, err := validate(ctx, h.JWTStrategy, token)
return err
}

func (h *DefaultJWTStrategy) ValidateJWT(ctx context.Context, tokenType fosite.TokenType, token string) (requester fosite.Requester, err error) {
t, err := h.validate(ctx, token)
if err != nil {
return nil, err
}

claims := jwt.JWTClaims{
ScopeField: h.ScopeField,
}
claims.FromMapClaims(t.Claims.(jwtx.MapClaims))

requester = &fosite.Request{
Client: &fosite.DefaultClient{},
RequestedAt: claims.IssuedAt,
Session: &JWTSession{
JWTClaims: &claims,
JWTHeader: &jwt.Headers{
Extra: make(map[string]interface{}),
},
ExpiresAt: map[fosite.TokenType]time.Time{
tokenType: claims.ExpiresAt,
},
Subject: claims.Subject,
},
RequestedScope: claims.Scope,
GrantedScope: claims.Scope,
}

return
}

func (h DefaultJWTStrategy) RefreshTokenSignature(token string) string {
return h.HMACSHAStrategy.RefreshTokenSignature(token)
}
Expand All @@ -128,8 +97,8 @@ func (h *DefaultJWTStrategy) ValidateAuthorizeCode(ctx context.Context, req fosi
return h.HMACSHAStrategy.ValidateAuthorizeCode(ctx, req, token)
}

func (h *DefaultJWTStrategy) validate(ctx context.Context, token string) (t *jwtx.Token, err error) {
t, err = h.JWTStrategy.Decode(ctx, token)
func validate(ctx context.Context, jwtStrategy jwt.JWTStrategy, token string) (t *jwtx.Token, err error) {
t, err = jwtStrategy.Decode(ctx, token)

if err == nil {
err = t.Claims.Valid()
Expand Down

0 comments on commit a6bfb92

Please sign in to comment.