Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bump fosite version and integrate breaking changes #1042

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
11 changes: 7 additions & 4 deletions Gopkg.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Gopkg.toml
Expand Up @@ -75,7 +75,7 @@

[[constraint]]
name = "github.com/ory/fosite"
version = "0.21.5"
version = "0.22.0"

[[constraint]]
name = "github.com/ory/graceful"
Expand Down
26 changes: 12 additions & 14 deletions consent/strategy_default.go
Expand Up @@ -28,8 +28,6 @@ import (
"strings"
"time"

"context"

jwtgo "github.com/dgrijalva/jwt-go"
"github.com/gorilla/sessions"
"github.com/ory/fosite"
Expand Down Expand Up @@ -119,7 +117,7 @@ func (s *DefaultStrategy) requestAuthentication(w http.ResponseWriter, r *http.R
return s.forwardAuthenticationRequest(w, r, ar, "", time.Time{}, nil)
}

session, err := s.M.GetAuthenticationSession(context.TODO(), sessionID)
session, err := s.M.GetAuthenticationSession(r.Context(), sessionID)
if errors.Cause(err) == pkg.ErrNotFound {
return s.forwardAuthenticationRequest(w, r, ar, "", time.Time{}, nil)
} else if err != nil {
Expand Down Expand Up @@ -147,7 +145,7 @@ func (s *DefaultStrategy) requestAuthentication(w http.ResponseWriter, r *http.R
return s.forwardAuthenticationRequest(w, r, ar, session.Subject, session.AuthenticatedAt, session)
}

token, err := s.JWTStrategy.Decode(idTokenHint)
token, err := s.JWTStrategy.Decode(r.Context(), idTokenHint)
if ve, ok := errors.Cause(err).(*jwtgo.ValidationError); err == nil || (ok && ve.Errors == jwtgo.ValidationErrorExpired) {
} else {
return err
Expand All @@ -162,7 +160,7 @@ func (s *DefaultStrategy) requestAuthentication(w http.ResponseWriter, r *http.R
return err
}

if s, err := s.M.GetForcedObfuscatedAuthenticationSession(context.TODO(), ar.GetClient().GetID(), hintSub); errors.Cause(err) == pkg.ErrNotFound {
if s, err := s.M.GetForcedObfuscatedAuthenticationSession(r.Context(), ar.GetClient().GetID(), hintSub); errors.Cause(err) == pkg.ErrNotFound {
// do nothing
} else if err != nil {
return err
Expand Down Expand Up @@ -208,7 +206,7 @@ func (s *DefaultStrategy) forwardAuthenticationRequest(w http.ResponseWriter, r

var idTokenHintClaims jwtgo.MapClaims
if idTokenHint := ar.GetRequestForm().Get("id_token_hint"); len(idTokenHint) > 0 {
token, err := s.JWTStrategy.Decode(idTokenHint)
token, err := s.JWTStrategy.Decode(r.Context(), idTokenHint)
if ve, ok := errors.Cause(err).(*jwtgo.ValidationError); err == nil || (ok && ve.Errors == jwtgo.ValidationErrorExpired) {
if hintClaims, ok := token.Claims.(jwtgo.MapClaims); ok {
idTokenHintClaims = hintClaims
Expand All @@ -223,7 +221,7 @@ func (s *DefaultStrategy) forwardAuthenticationRequest(w http.ResponseWriter, r

// Set the session
if err := s.M.CreateAuthenticationRequest(
context.TODO(),
r.Context(),
&AuthenticationRequest{
Challenge: challenge,
Verifier: verifier,
Expand Down Expand Up @@ -277,7 +275,7 @@ func (s *DefaultStrategy) revokeAuthenticationSession(w http.ResponseWriter, r *
return nil
}

return s.M.DeleteAuthenticationSession(context.TODO(), sid)
return s.M.DeleteAuthenticationSession(r.Context(), sid)
}

func revokeAuthenticationCookie(w http.ResponseWriter, r *http.Request, s sessions.Store) (string, error) {
Expand Down Expand Up @@ -313,7 +311,7 @@ func (s *DefaultStrategy) obfuscateSubjectIdentifier(subject string, req fosite.
}

func (s *DefaultStrategy) verifyAuthentication(w http.ResponseWriter, r *http.Request, req fosite.AuthorizeRequester, verifier string) (*HandledAuthenticationRequest, error) {
session, err := s.M.VerifyAndInvalidateAuthenticationRequest(context.TODO(), verifier)
session, err := s.M.VerifyAndInvalidateAuthenticationRequest(r.Context(), verifier)
if errors.Cause(err) == pkg.ErrNotFound {
return nil, errors.WithStack(fosite.ErrAccessDenied.WithDebug("The login verifier has already been used, has not been granted, or is invalid."))
} else if err != nil {
Expand Down Expand Up @@ -386,7 +384,7 @@ func (s *DefaultStrategy) verifyAuthentication(w http.ResponseWriter, r *http.Re
}

if session.ForceSubjectIdentifier != "" {
if err := s.M.CreateForcedObfuscatedAuthenticationSession(context.TODO(), &ForcedObfuscatedAuthenticationSession{
if err := s.M.CreateForcedObfuscatedAuthenticationSession(r.Context(), &ForcedObfuscatedAuthenticationSession{
Subject: session.Subject,
ClientID: req.GetClient().GetID(),
SubjectObfuscated: session.ForceSubjectIdentifier,
Expand All @@ -410,7 +408,7 @@ func (s *DefaultStrategy) verifyAuthentication(w http.ResponseWriter, r *http.Re
cookie, _ := s.CookieStore.Get(r, cookieAuthenticationName)
sid := uuid.New()

if err := s.M.CreateAuthenticationSession(context.TODO(), &AuthenticationSession{
if err := s.M.CreateAuthenticationSession(r.Context(), &AuthenticationSession{
ID: sid,
Subject: session.Subject,
AuthenticatedAt: session.AuthenticatedAt,
Expand Down Expand Up @@ -472,7 +470,7 @@ func (s *DefaultStrategy) requestConsent(w http.ResponseWriter, r *http.Request,
// return s.forwardConsentRequest(w, r, ar, authenticationSession, nil)
// }

consentSessions, err := s.M.FindPreviouslyGrantedConsentRequests(context.TODO(), ar.GetClient().GetID(), authenticationSession.Subject)
consentSessions, err := s.M.FindPreviouslyGrantedConsentRequests(r.Context(), ar.GetClient().GetID(), authenticationSession.Subject)
if errors.Cause(err) == ErrNoPreviousConsentFound {
return s.forwardConsentRequest(w, r, ar, authenticationSession, nil)
} else if err != nil {
Expand Down Expand Up @@ -503,7 +501,7 @@ func (s *DefaultStrategy) forwardConsentRequest(w http.ResponseWriter, r *http.R
csrf := strings.Replace(uuid.New(), "-", "", -1)

if err := s.M.CreateConsentRequest(
context.TODO(),
r.Context(),
&ConsentRequest{
Challenge: challenge,
Verifier: verifier,
Expand Down Expand Up @@ -544,7 +542,7 @@ func (s *DefaultStrategy) forwardConsentRequest(w http.ResponseWriter, r *http.R
}

func (s *DefaultStrategy) verifyConsent(w http.ResponseWriter, r *http.Request, req fosite.AuthorizeRequester, verifier string) (*HandledConsentRequest, error) {
session, err := s.M.VerifyAndInvalidateConsentRequest(context.TODO(), verifier)
session, err := s.M.VerifyAndInvalidateConsentRequest(r.Context(), verifier)
if errors.Cause(err) == pkg.ErrNotFound {
return nil, errors.WithStack(fosite.ErrAccessDenied.WithDebug("The consent verifier has already been used, has not been granted, or is invalid."))
} else if err != nil {
Expand Down
18 changes: 10 additions & 8 deletions consent/strategy_default_test.go
Expand Up @@ -34,6 +34,8 @@ import (
"testing"
"time"

"context"

"github.com/gorilla/securecookie"
"github.com/gorilla/sessions"
"github.com/julienschmidt/httprouter"
Expand Down Expand Up @@ -90,32 +92,32 @@ func TestStrategy(t *testing.T) {
PrivateKey: pkg.MustINSECURELOWENTROPYRSAKEYFORTEST(),
}

fooUserIDToken, _, err := jwts.Generate((jwt.IDTokenClaims{
fooUserIDToken, _, err := jwts.Generate(context.TODO(), jwt.IDTokenClaims{
Subject: "foouser",
ExpiresAt: time.Now().Add(time.Hour),
IssuedAt: time.Now(),
}).ToMapClaims(), jwt.NewHeaders())
}.ToMapClaims(), jwt.NewHeaders())
require.NoError(t, err)

forcedAuthUserIDToken, _, err := jwts.Generate((jwt.IDTokenClaims{
forcedAuthUserIDToken, _, err := jwts.Generate(context.TODO(), jwt.IDTokenClaims{
Subject: "forced-auth-user",
ExpiresAt: time.Now().Add(time.Hour),
IssuedAt: time.Now(),
}).ToMapClaims(), jwt.NewHeaders())
}.ToMapClaims(), jwt.NewHeaders())
require.NoError(t, err)

pairwiseIDToken, _, err := jwts.Generate((jwt.IDTokenClaims{
pairwiseIDToken, _, err := jwts.Generate(context.TODO(), jwt.IDTokenClaims{
Subject: "c737d5e1fec8896d096d49f6b1a73eb45ac7becb87de9ac3f0a350bad2a9c9fd",
ExpiresAt: time.Now().Add(time.Hour),
IssuedAt: time.Now(),
}).ToMapClaims(), jwt.NewHeaders())
}.ToMapClaims(), jwt.NewHeaders())
require.NoError(t, err)

expiredAuthUserToken, _, err := jwts.Generate((jwt.IDTokenClaims{
expiredAuthUserToken, _, err := jwts.Generate(context.TODO(), jwt.IDTokenClaims{
Subject: "user",
ExpiresAt: time.Now().Add(-time.Hour),
IssuedAt: time.Now(),
}).ToMapClaims(), jwt.NewHeaders())
}.ToMapClaims(), jwt.NewHeaders())
require.NoError(t, err)

cs := sessions.NewCookieStore([]byte("dummy-secret-yay"))
Expand Down
38 changes: 19 additions & 19 deletions jwk/jwt_strategy.go
Expand Up @@ -32,7 +32,7 @@ import (
)

type JWTStrategy interface {
GetPublicKeyID() (string, error)
GetPublicKeyID(ctx context.Context) (string, error)

jwt.JWTStrategy
}
Expand All @@ -43,7 +43,7 @@ func NewRS256JWTStrategy(m Manager, set string) (*RS256JWTStrategy, error) {
RS256JWTStrategy: &jwt.RS256JWTStrategy{},
Set: set,
}
if err := j.refresh(); err != nil {
if err := j.refresh(context.TODO()); err != nil {
return nil, err
}
return j, nil
Expand All @@ -60,53 +60,53 @@ type RS256JWTStrategy struct {
privateKeyID string
}

func (j *RS256JWTStrategy) Hash(in []byte) ([]byte, error) {
return j.RS256JWTStrategy.Hash(in)
func (j *RS256JWTStrategy) Hash(ctx context.Context, in []byte) ([]byte, error) {
return j.RS256JWTStrategy.Hash(ctx, in)
}

// GetSigningMethodLength will return the length of the signing method
func (j *RS256JWTStrategy) GetSigningMethodLength() int {
return j.RS256JWTStrategy.GetSigningMethodLength()
}

func (j *RS256JWTStrategy) GetSignature(token string) (string, error) {
return j.RS256JWTStrategy.GetSignature(token)
func (j *RS256JWTStrategy) GetSignature(ctx context.Context, token string) (string, error) {
return j.RS256JWTStrategy.GetSignature(ctx, token)
}

func (j *RS256JWTStrategy) Generate(claims jwt2.Claims, header jwt.Mapper) (string, string, error) {
if err := j.refresh(); err != nil {
func (j *RS256JWTStrategy) Generate(ctx context.Context, claims jwt2.Claims, header jwt.Mapper) (string, string, error) {
if err := j.refresh(ctx); err != nil {
return "", "", err
}

return j.RS256JWTStrategy.Generate(claims, header)
return j.RS256JWTStrategy.Generate(ctx, claims, header)
}

func (j *RS256JWTStrategy) Validate(token string) (string, error) {
if err := j.refresh(); err != nil {
func (j *RS256JWTStrategy) Validate(ctx context.Context, token string) (string, error) {
if err := j.refresh(ctx); err != nil {
return "", err
}

return j.RS256JWTStrategy.Validate(token)
return j.RS256JWTStrategy.Validate(ctx, token)
}

func (j *RS256JWTStrategy) Decode(token string) (*jwt2.Token, error) {
if err := j.refresh(); err != nil {
func (j *RS256JWTStrategy) Decode(ctx context.Context, token string) (*jwt2.Token, error) {
if err := j.refresh(ctx); err != nil {
return nil, err
}

return j.RS256JWTStrategy.Decode(token)
return j.RS256JWTStrategy.Decode(ctx, token)
}

func (j *RS256JWTStrategy) GetPublicKeyID() (string, error) {
if err := j.refresh(); err != nil {
func (j *RS256JWTStrategy) GetPublicKeyID(ctx context.Context) (string, error) {
if err := j.refresh(ctx); err != nil {
return "", err
}

return j.publicKeyID, nil
}

func (j *RS256JWTStrategy) refresh() error {
keys, err := j.Manager.GetKeySet(context.TODO(), j.Set)
func (j *RS256JWTStrategy) refresh(ctx context.Context) error {
keys, err := j.Manager.GetKeySet(ctx, j.Set)
if err != nil {
return err
}
Expand Down
12 changes: 6 additions & 6 deletions jwk/jwt_strategy_test.go
Expand Up @@ -45,31 +45,31 @@ func TestRS256JWTStrategy(t *testing.T) {

s, err := NewRS256JWTStrategy(m, "foo-set")
require.NoError(t, err)
a, b, err := s.Generate(jwt2.MapClaims{"foo": "bar"}, &jwt.Headers{})
a, b, err := s.Generate(context.TODO(), jwt2.MapClaims{"foo": "bar"}, &jwt.Headers{})
require.NoError(t, err)
assert.NotEmpty(t, a)
assert.NotEmpty(t, b)

_, err = s.Validate(a)
_, err = s.Validate(context.TODO(), a)
require.NoError(t, err)

kid, err := s.GetPublicKeyID()
kid, err := s.GetPublicKeyID(context.TODO())
assert.NoError(t, err)
assert.Equal(t, "public:foo", kid)

ks, err = testGenerator.Generate("bar", "sig")
require.NoError(t, err)
require.NoError(t, m.AddKeySet(context.TODO(), "foo-set", ks))

a, b, err = s.Generate(jwt2.MapClaims{"foo": "bar"}, &jwt.Headers{})
a, b, err = s.Generate(context.TODO(), jwt2.MapClaims{"foo": "bar"}, &jwt.Headers{})
require.NoError(t, err)
assert.NotEmpty(t, a)
assert.NotEmpty(t, b)

_, err = s.Validate(a)
_, err = s.Validate(context.TODO(), a)
require.NoError(t, err)

kid, err = s.GetPublicKeyID()
kid, err = s.GetPublicKeyID(context.TODO())
assert.NoError(t, err)
assert.Equal(t, "public:bar", kid)
}
10 changes: 5 additions & 5 deletions oauth2/handler.go
Expand Up @@ -281,13 +281,13 @@ func (h *Handler) UserinfoHandler(w http.ResponseWriter, r *http.Request) {
delete(interim, "exp")
delete(interim, "jti")

keyID, err := h.OpenIDJWTStrategy.GetPublicKeyID()
keyID, err := h.OpenIDJWTStrategy.GetPublicKeyID(r.Context())
if err != nil {
h.H.WriteError(w, r, err)
return
}

token, _, err := h.OpenIDJWTStrategy.Generate(jwt2.MapClaims(interim), &jwt.Headers{
token, _, err := h.OpenIDJWTStrategy.Generate(r.Context(), jwt2.MapClaims(interim), &jwt.Headers{
Extra: map[string]interface{}{
"kid": keyID,
},
Expand Down Expand Up @@ -526,7 +526,7 @@ func (h *Handler) TokenHandler(w http.ResponseWriter, r *http.Request) {
if accessRequest.GetGrantTypes().Exact("client_credentials") {
var accessTokenKeyID string
if h.AccessTokenStrategy == "jwt" {
accessTokenKeyID, err = h.AccessTokenJWTStrategy.GetPublicKeyID()
accessTokenKeyID, err = h.AccessTokenJWTStrategy.GetPublicKeyID(r.Context())
if err != nil {
pkg.LogError(err, h.L)
h.OAuth2.WriteAccessError(w, accessRequest, err)
Expand Down Expand Up @@ -599,7 +599,7 @@ func (h *Handler) AuthHandler(w http.ResponseWriter, r *http.Request, _ httprout
authorizeRequest.GrantScope(scope)
}

openIDKeyID, err := h.OpenIDJWTStrategy.GetPublicKeyID()
openIDKeyID, err := h.OpenIDJWTStrategy.GetPublicKeyID(r.Context())
if err != nil {
pkg.LogError(err, h.L)
h.writeAuthorizeError(w, authorizeRequest, err)
Expand All @@ -608,7 +608,7 @@ func (h *Handler) AuthHandler(w http.ResponseWriter, r *http.Request, _ httprout

var accessTokenKeyID string
if h.AccessTokenStrategy == "jwt" {
accessTokenKeyID, err = h.AccessTokenJWTStrategy.GetPublicKeyID()
accessTokenKeyID, err = h.AccessTokenJWTStrategy.GetPublicKeyID(r.Context())
if err != nil {
pkg.LogError(err, h.L)
h.writeAuthorizeError(w, authorizeRequest, err)
Expand Down