diff --git a/Gopkg.lock b/Gopkg.lock index d64ebe40ed..78420f9867 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -351,6 +351,7 @@ ".", "ext", "log", + "mocktracer", ] pruneopts = "" revision = "1949ddbfd147afd4d964a9f00b24eb291e0e7c38" @@ -391,7 +392,7 @@ version = "v3.3.0" [[projects]] - digest = "1:ae58fa3a26065fc3408e6c4e3899b62453e19bb5aa5e79725484bfb534422ca4" + digest = "1:0a0c89d174f4b76bfe9a10d18c92e5a1eb3f445d4073884889b8a0a389f0ea80" name = "github.com/ory/fosite" packages = [ ".", @@ -404,8 +405,8 @@ "token/jwt", ] pruneopts = "" - revision = "1ad9cd36069f61b2ace0fec097fe4bdc92e9f6c6" - version = "v0.21.5" + revision = "514fdbd20393c2175c66f3a69eb7bb849b3d5dfa" + version = "v0.22.0" [[projects]] digest = "1:88233ef02f3da33b9d4cf4f6c514c206ce4efec67f455c5be6dd3aa0fdf3bd32" @@ -467,12 +468,13 @@ version = "v0.0.1" [[projects]] - digest = "1:fff9b8b263350edc9041d38d9cc17c7b492f3bde282f6c10f0a648152e96e20d" + digest = "1:80e8a237b37d36ff222fce1c2dd34899479ef23b18bb123319b50611023785e3" name = "github.com/ory/sqlcon" packages = [ ".", "dockertest", ] + pruneopts = "" revision = "068c69998749cdb876c4adf179a8ed702864ad2b" version = "v0.0.7" @@ -869,6 +871,7 @@ "github.com/oleiade/reflections", "github.com/opentracing/opentracing-go", "github.com/opentracing/opentracing-go/ext", + "github.com/opentracing/opentracing-go/mocktracer", "github.com/ory/fosite", "github.com/ory/fosite/compose", "github.com/ory/fosite/handler/oauth2", diff --git a/Gopkg.toml b/Gopkg.toml index d6a88d084a..923e0dfee0 100644 --- a/Gopkg.toml +++ b/Gopkg.toml @@ -75,7 +75,7 @@ [[constraint]] name = "github.com/ory/fosite" - version = "0.21.5" + version = "0.22.0" [[constraint]] name = "github.com/ory/graceful" diff --git a/consent/strategy_default.go b/consent/strategy_default.go index 07f24ada8f..8f630ac1bd 100644 --- a/consent/strategy_default.go +++ b/consent/strategy_default.go @@ -28,8 +28,6 @@ import ( "strings" "time" - "context" - jwtgo "github.com/dgrijalva/jwt-go" "github.com/gorilla/sessions" "github.com/ory/fosite" @@ -119,7 +117,7 @@ func (s *DefaultStrategy) requestAuthentication(w http.ResponseWriter, r *http.R return s.forwardAuthenticationRequest(w, r, ar, "", time.Time{}, nil) } - session, err := s.M.GetAuthenticationSession(context.TODO(), sessionID) + session, err := s.M.GetAuthenticationSession(r.Context(), sessionID) if errors.Cause(err) == pkg.ErrNotFound { return s.forwardAuthenticationRequest(w, r, ar, "", time.Time{}, nil) } else if err != nil { @@ -147,7 +145,7 @@ func (s *DefaultStrategy) requestAuthentication(w http.ResponseWriter, r *http.R return s.forwardAuthenticationRequest(w, r, ar, session.Subject, session.AuthenticatedAt, session) } - token, err := s.JWTStrategy.Decode(idTokenHint) + token, err := s.JWTStrategy.Decode(r.Context(), idTokenHint) if ve, ok := errors.Cause(err).(*jwtgo.ValidationError); err == nil || (ok && ve.Errors == jwtgo.ValidationErrorExpired) { } else { return err @@ -162,7 +160,7 @@ func (s *DefaultStrategy) requestAuthentication(w http.ResponseWriter, r *http.R return err } - if s, err := s.M.GetForcedObfuscatedAuthenticationSession(context.TODO(), ar.GetClient().GetID(), hintSub); errors.Cause(err) == pkg.ErrNotFound { + if s, err := s.M.GetForcedObfuscatedAuthenticationSession(r.Context(), ar.GetClient().GetID(), hintSub); errors.Cause(err) == pkg.ErrNotFound { // do nothing } else if err != nil { return err @@ -208,7 +206,7 @@ func (s *DefaultStrategy) forwardAuthenticationRequest(w http.ResponseWriter, r var idTokenHintClaims jwtgo.MapClaims if idTokenHint := ar.GetRequestForm().Get("id_token_hint"); len(idTokenHint) > 0 { - token, err := s.JWTStrategy.Decode(idTokenHint) + token, err := s.JWTStrategy.Decode(r.Context(), idTokenHint) if ve, ok := errors.Cause(err).(*jwtgo.ValidationError); err == nil || (ok && ve.Errors == jwtgo.ValidationErrorExpired) { if hintClaims, ok := token.Claims.(jwtgo.MapClaims); ok { idTokenHintClaims = hintClaims @@ -223,7 +221,7 @@ func (s *DefaultStrategy) forwardAuthenticationRequest(w http.ResponseWriter, r // Set the session if err := s.M.CreateAuthenticationRequest( - context.TODO(), + r.Context(), &AuthenticationRequest{ Challenge: challenge, Verifier: verifier, @@ -277,7 +275,7 @@ func (s *DefaultStrategy) revokeAuthenticationSession(w http.ResponseWriter, r * return nil } - return s.M.DeleteAuthenticationSession(context.TODO(), sid) + return s.M.DeleteAuthenticationSession(r.Context(), sid) } func revokeAuthenticationCookie(w http.ResponseWriter, r *http.Request, s sessions.Store) (string, error) { @@ -313,7 +311,7 @@ func (s *DefaultStrategy) obfuscateSubjectIdentifier(subject string, req fosite. } func (s *DefaultStrategy) verifyAuthentication(w http.ResponseWriter, r *http.Request, req fosite.AuthorizeRequester, verifier string) (*HandledAuthenticationRequest, error) { - session, err := s.M.VerifyAndInvalidateAuthenticationRequest(context.TODO(), verifier) + session, err := s.M.VerifyAndInvalidateAuthenticationRequest(r.Context(), verifier) if errors.Cause(err) == pkg.ErrNotFound { return nil, errors.WithStack(fosite.ErrAccessDenied.WithDebug("The login verifier has already been used, has not been granted, or is invalid.")) } else if err != nil { @@ -386,7 +384,7 @@ func (s *DefaultStrategy) verifyAuthentication(w http.ResponseWriter, r *http.Re } if session.ForceSubjectIdentifier != "" { - if err := s.M.CreateForcedObfuscatedAuthenticationSession(context.TODO(), &ForcedObfuscatedAuthenticationSession{ + if err := s.M.CreateForcedObfuscatedAuthenticationSession(r.Context(), &ForcedObfuscatedAuthenticationSession{ Subject: session.Subject, ClientID: req.GetClient().GetID(), SubjectObfuscated: session.ForceSubjectIdentifier, @@ -410,7 +408,7 @@ func (s *DefaultStrategy) verifyAuthentication(w http.ResponseWriter, r *http.Re cookie, _ := s.CookieStore.Get(r, cookieAuthenticationName) sid := uuid.New() - if err := s.M.CreateAuthenticationSession(context.TODO(), &AuthenticationSession{ + if err := s.M.CreateAuthenticationSession(r.Context(), &AuthenticationSession{ ID: sid, Subject: session.Subject, AuthenticatedAt: session.AuthenticatedAt, @@ -472,7 +470,7 @@ func (s *DefaultStrategy) requestConsent(w http.ResponseWriter, r *http.Request, // return s.forwardConsentRequest(w, r, ar, authenticationSession, nil) // } - consentSessions, err := s.M.FindPreviouslyGrantedConsentRequests(context.TODO(), ar.GetClient().GetID(), authenticationSession.Subject) + consentSessions, err := s.M.FindPreviouslyGrantedConsentRequests(r.Context(), ar.GetClient().GetID(), authenticationSession.Subject) if errors.Cause(err) == ErrNoPreviousConsentFound { return s.forwardConsentRequest(w, r, ar, authenticationSession, nil) } else if err != nil { @@ -503,7 +501,7 @@ func (s *DefaultStrategy) forwardConsentRequest(w http.ResponseWriter, r *http.R csrf := strings.Replace(uuid.New(), "-", "", -1) if err := s.M.CreateConsentRequest( - context.TODO(), + r.Context(), &ConsentRequest{ Challenge: challenge, Verifier: verifier, @@ -544,7 +542,7 @@ func (s *DefaultStrategy) forwardConsentRequest(w http.ResponseWriter, r *http.R } func (s *DefaultStrategy) verifyConsent(w http.ResponseWriter, r *http.Request, req fosite.AuthorizeRequester, verifier string) (*HandledConsentRequest, error) { - session, err := s.M.VerifyAndInvalidateConsentRequest(context.TODO(), verifier) + session, err := s.M.VerifyAndInvalidateConsentRequest(r.Context(), verifier) if errors.Cause(err) == pkg.ErrNotFound { return nil, errors.WithStack(fosite.ErrAccessDenied.WithDebug("The consent verifier has already been used, has not been granted, or is invalid.")) } else if err != nil { diff --git a/consent/strategy_default_test.go b/consent/strategy_default_test.go index 9da129efa9..fc43281a1d 100644 --- a/consent/strategy_default_test.go +++ b/consent/strategy_default_test.go @@ -34,6 +34,8 @@ import ( "testing" "time" + "context" + "github.com/gorilla/securecookie" "github.com/gorilla/sessions" "github.com/julienschmidt/httprouter" @@ -90,32 +92,32 @@ func TestStrategy(t *testing.T) { PrivateKey: pkg.MustINSECURELOWENTROPYRSAKEYFORTEST(), } - fooUserIDToken, _, err := jwts.Generate((jwt.IDTokenClaims{ + fooUserIDToken, _, err := jwts.Generate(context.TODO(), jwt.IDTokenClaims{ Subject: "foouser", ExpiresAt: time.Now().Add(time.Hour), IssuedAt: time.Now(), - }).ToMapClaims(), jwt.NewHeaders()) + }.ToMapClaims(), jwt.NewHeaders()) require.NoError(t, err) - forcedAuthUserIDToken, _, err := jwts.Generate((jwt.IDTokenClaims{ + forcedAuthUserIDToken, _, err := jwts.Generate(context.TODO(), jwt.IDTokenClaims{ Subject: "forced-auth-user", ExpiresAt: time.Now().Add(time.Hour), IssuedAt: time.Now(), - }).ToMapClaims(), jwt.NewHeaders()) + }.ToMapClaims(), jwt.NewHeaders()) require.NoError(t, err) - pairwiseIDToken, _, err := jwts.Generate((jwt.IDTokenClaims{ + pairwiseIDToken, _, err := jwts.Generate(context.TODO(), jwt.IDTokenClaims{ Subject: "c737d5e1fec8896d096d49f6b1a73eb45ac7becb87de9ac3f0a350bad2a9c9fd", ExpiresAt: time.Now().Add(time.Hour), IssuedAt: time.Now(), - }).ToMapClaims(), jwt.NewHeaders()) + }.ToMapClaims(), jwt.NewHeaders()) require.NoError(t, err) - expiredAuthUserToken, _, err := jwts.Generate((jwt.IDTokenClaims{ + expiredAuthUserToken, _, err := jwts.Generate(context.TODO(), jwt.IDTokenClaims{ Subject: "user", ExpiresAt: time.Now().Add(-time.Hour), IssuedAt: time.Now(), - }).ToMapClaims(), jwt.NewHeaders()) + }.ToMapClaims(), jwt.NewHeaders()) require.NoError(t, err) cs := sessions.NewCookieStore([]byte("dummy-secret-yay")) diff --git a/jwk/jwt_strategy.go b/jwk/jwt_strategy.go index f1f59beea8..4cca8ac8f6 100644 --- a/jwk/jwt_strategy.go +++ b/jwk/jwt_strategy.go @@ -32,7 +32,7 @@ import ( ) type JWTStrategy interface { - GetPublicKeyID() (string, error) + GetPublicKeyID(ctx context.Context) (string, error) jwt.JWTStrategy } @@ -43,7 +43,7 @@ func NewRS256JWTStrategy(m Manager, set string) (*RS256JWTStrategy, error) { RS256JWTStrategy: &jwt.RS256JWTStrategy{}, Set: set, } - if err := j.refresh(); err != nil { + if err := j.refresh(context.TODO()); err != nil { return nil, err } return j, nil @@ -60,8 +60,8 @@ type RS256JWTStrategy struct { privateKeyID string } -func (j *RS256JWTStrategy) Hash(in []byte) ([]byte, error) { - return j.RS256JWTStrategy.Hash(in) +func (j *RS256JWTStrategy) Hash(ctx context.Context, in []byte) ([]byte, error) { + return j.RS256JWTStrategy.Hash(ctx, in) } // GetSigningMethodLength will return the length of the signing method @@ -69,44 +69,44 @@ func (j *RS256JWTStrategy) GetSigningMethodLength() int { return j.RS256JWTStrategy.GetSigningMethodLength() } -func (j *RS256JWTStrategy) GetSignature(token string) (string, error) { - return j.RS256JWTStrategy.GetSignature(token) +func (j *RS256JWTStrategy) GetSignature(ctx context.Context, token string) (string, error) { + return j.RS256JWTStrategy.GetSignature(ctx, token) } -func (j *RS256JWTStrategy) Generate(claims jwt2.Claims, header jwt.Mapper) (string, string, error) { - if err := j.refresh(); err != nil { +func (j *RS256JWTStrategy) Generate(ctx context.Context, claims jwt2.Claims, header jwt.Mapper) (string, string, error) { + if err := j.refresh(ctx); err != nil { return "", "", err } - return j.RS256JWTStrategy.Generate(claims, header) + return j.RS256JWTStrategy.Generate(ctx, claims, header) } -func (j *RS256JWTStrategy) Validate(token string) (string, error) { - if err := j.refresh(); err != nil { +func (j *RS256JWTStrategy) Validate(ctx context.Context, token string) (string, error) { + if err := j.refresh(ctx); err != nil { return "", err } - return j.RS256JWTStrategy.Validate(token) + return j.RS256JWTStrategy.Validate(ctx, token) } -func (j *RS256JWTStrategy) Decode(token string) (*jwt2.Token, error) { - if err := j.refresh(); err != nil { +func (j *RS256JWTStrategy) Decode(ctx context.Context, token string) (*jwt2.Token, error) { + if err := j.refresh(ctx); err != nil { return nil, err } - return j.RS256JWTStrategy.Decode(token) + return j.RS256JWTStrategy.Decode(ctx, token) } -func (j *RS256JWTStrategy) GetPublicKeyID() (string, error) { - if err := j.refresh(); err != nil { +func (j *RS256JWTStrategy) GetPublicKeyID(ctx context.Context) (string, error) { + if err := j.refresh(ctx); err != nil { return "", err } return j.publicKeyID, nil } -func (j *RS256JWTStrategy) refresh() error { - keys, err := j.Manager.GetKeySet(context.TODO(), j.Set) +func (j *RS256JWTStrategy) refresh(ctx context.Context) error { + keys, err := j.Manager.GetKeySet(ctx, j.Set) if err != nil { return err } diff --git a/jwk/jwt_strategy_test.go b/jwk/jwt_strategy_test.go index bb059eda7c..709e1fc56a 100644 --- a/jwk/jwt_strategy_test.go +++ b/jwk/jwt_strategy_test.go @@ -45,15 +45,15 @@ func TestRS256JWTStrategy(t *testing.T) { s, err := NewRS256JWTStrategy(m, "foo-set") require.NoError(t, err) - a, b, err := s.Generate(jwt2.MapClaims{"foo": "bar"}, &jwt.Headers{}) + a, b, err := s.Generate(context.TODO(), jwt2.MapClaims{"foo": "bar"}, &jwt.Headers{}) require.NoError(t, err) assert.NotEmpty(t, a) assert.NotEmpty(t, b) - _, err = s.Validate(a) + _, err = s.Validate(context.TODO(), a) require.NoError(t, err) - kid, err := s.GetPublicKeyID() + kid, err := s.GetPublicKeyID(context.TODO()) assert.NoError(t, err) assert.Equal(t, "public:foo", kid) @@ -61,15 +61,15 @@ func TestRS256JWTStrategy(t *testing.T) { require.NoError(t, err) require.NoError(t, m.AddKeySet(context.TODO(), "foo-set", ks)) - a, b, err = s.Generate(jwt2.MapClaims{"foo": "bar"}, &jwt.Headers{}) + a, b, err = s.Generate(context.TODO(), jwt2.MapClaims{"foo": "bar"}, &jwt.Headers{}) require.NoError(t, err) assert.NotEmpty(t, a) assert.NotEmpty(t, b) - _, err = s.Validate(a) + _, err = s.Validate(context.TODO(), a) require.NoError(t, err) - kid, err = s.GetPublicKeyID() + kid, err = s.GetPublicKeyID(context.TODO()) assert.NoError(t, err) assert.Equal(t, "public:bar", kid) } diff --git a/oauth2/handler.go b/oauth2/handler.go index 6213bceceb..36756ca1b8 100644 --- a/oauth2/handler.go +++ b/oauth2/handler.go @@ -281,13 +281,13 @@ func (h *Handler) UserinfoHandler(w http.ResponseWriter, r *http.Request) { delete(interim, "exp") delete(interim, "jti") - keyID, err := h.OpenIDJWTStrategy.GetPublicKeyID() + keyID, err := h.OpenIDJWTStrategy.GetPublicKeyID(r.Context()) if err != nil { h.H.WriteError(w, r, err) return } - token, _, err := h.OpenIDJWTStrategy.Generate(jwt2.MapClaims(interim), &jwt.Headers{ + token, _, err := h.OpenIDJWTStrategy.Generate(r.Context(), jwt2.MapClaims(interim), &jwt.Headers{ Extra: map[string]interface{}{ "kid": keyID, }, @@ -526,7 +526,7 @@ func (h *Handler) TokenHandler(w http.ResponseWriter, r *http.Request) { if accessRequest.GetGrantTypes().Exact("client_credentials") { var accessTokenKeyID string if h.AccessTokenStrategy == "jwt" { - accessTokenKeyID, err = h.AccessTokenJWTStrategy.GetPublicKeyID() + accessTokenKeyID, err = h.AccessTokenJWTStrategy.GetPublicKeyID(r.Context()) if err != nil { pkg.LogError(err, h.L) h.OAuth2.WriteAccessError(w, accessRequest, err) @@ -599,7 +599,7 @@ func (h *Handler) AuthHandler(w http.ResponseWriter, r *http.Request, _ httprout authorizeRequest.GrantScope(scope) } - openIDKeyID, err := h.OpenIDJWTStrategy.GetPublicKeyID() + openIDKeyID, err := h.OpenIDJWTStrategy.GetPublicKeyID(r.Context()) if err != nil { pkg.LogError(err, h.L) h.writeAuthorizeError(w, authorizeRequest, err) @@ -608,7 +608,7 @@ func (h *Handler) AuthHandler(w http.ResponseWriter, r *http.Request, _ httprout var accessTokenKeyID string if h.AccessTokenStrategy == "jwt" { - accessTokenKeyID, err = h.AccessTokenJWTStrategy.GetPublicKeyID() + accessTokenKeyID, err = h.AccessTokenJWTStrategy.GetPublicKeyID(r.Context()) if err != nil { pkg.LogError(err, h.L) h.writeAuthorizeError(w, authorizeRequest, err) diff --git a/oauth2/oauth2_auth_code_test.go b/oauth2/oauth2_auth_code_test.go index 5339fa1b93..6fb5e32fcd 100644 --- a/oauth2/oauth2_auth_code_test.go +++ b/oauth2/oauth2_auth_code_test.go @@ -124,11 +124,11 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { WorkFactor: 4, } - fooUserIDToken, _, err := jwts.Generate((jwt.IDTokenClaims{ + fooUserIDToken, _, err := jwts.Generate(context.TODO(), jwt.IDTokenClaims{ Subject: "foouser", ExpiresAt: time.Now().Add(time.Hour), IssuedAt: time.Now(), - }).ToMapClaims(), jwt.NewHeaders()) + }.ToMapClaims(), jwt.NewHeaders()) require.NoError(t, err) // we create a new fositeStore here because the old one