diff --git a/driver/config/provider.go b/driver/config/provider.go index fc39ee4a96e..658f3525d73 100644 --- a/driver/config/provider.go +++ b/driver/config/provider.go @@ -72,6 +72,7 @@ const ( KeyOAuth2LegacyErrors = "oauth2.include_legacy_error_fields" KeyExcludeNotBeforeClaim = "oauth2.exclude_not_before_claim" KeyAllowedTopLevelClaims = "oauth2.allowed_top_level_claims" + KeyRefreshTokenRotationGracePeriod = "oauth2.refresh_token_rotation.grace_period" // #nosec G101 KeyOAuth2GrantJWTIDOptional = "oauth2.grant.jwt.jti_optional" KeyOAuth2GrantJWTIssuedDateOptional = "oauth2.grant.jwt.iat_optional" KeyOAuth2GrantJWTMaxDuration = "oauth2.grant.jwt.max_ttl" @@ -480,3 +481,12 @@ func (p *Provider) GrantTypeJWTBearerIssuedDateOptional() bool { func (p *Provider) GrantTypeJWTBearerMaxDuration() time.Duration { return p.p.DurationF(KeyOAuth2GrantJWTMaxDuration, time.Hour*24*30) } + +func (p *Provider) RefreshTokenRotationGracePeriod() time.Duration { + var duration = p.p.DurationF(KeyRefreshTokenRotationGracePeriod, 0) + if duration > time.Hour { + return time.Hour + } + + return p.p.DurationF(KeyRefreshTokenRotationGracePeriod, 0) +} diff --git a/driver/config/provider_test.go b/driver/config/provider_test.go index d6b57d288b6..2cc404c903c 100644 --- a/driver/config/provider_test.go +++ b/driver/config/provider_test.go @@ -278,6 +278,13 @@ func TestViperProviderValidates(t *testing.T) { assert.Equal(t, "random_salt", c.SubjectIdentifierAlgorithmSalt()) assert.Equal(t, []string{"whatever"}, c.DefaultClientScope()) + // refresh + assert.Equal(t, time.Duration(0), c.RefreshTokenRotationGracePeriod()) + require.NoError(t, c.Set(KeyRefreshTokenRotationGracePeriod, "1s")) + assert.Equal(t, time.Second, c.RefreshTokenRotationGracePeriod()) + require.NoError(t, c.Set(KeyRefreshTokenRotationGracePeriod, "2h")) + assert.Equal(t, time.Hour, c.RefreshTokenRotationGracePeriod()) + // urls assert.Equal(t, urlx.ParseOrPanic("https://issuer/"), c.IssuerURL()) assert.Equal(t, urlx.ParseOrPanic("https://public/"), c.PublicURL()) diff --git a/internal/config/config.yaml b/internal/config/config.yaml index 868d7918ae4..911cd18bc44 100644 --- a/internal/config/config.yaml +++ b/internal/config/config.yaml @@ -399,6 +399,23 @@ oauth2: session: # store encrypted data in database, default true encrypt_at_rest: true + ## refresh_token_rotation + # + # By default Refresh Tokens are rotated and invalidated with each use. + # See https://datatracker.ietf.org/doc/html/draft-ietf-oauth-security-topics#section-4.13.2 for more details + # + refresh_token_rotation: + # + ## grace_period + # + # Set the grace period for a refresh token to allow it to be used for the duration of this configuration after its first use. New refresh tokens will continue + # to be issued. + # + # Examples: + # - 5s + # - 1m + # - 0s (default; grace period disabled) + grace_period: 0s # The secrets section configures secrets used for encryption and signing of several systems. All secrets can be rotated, # for more information on this topic navigate to: diff --git a/oauth2/fosite_store_helpers.go b/oauth2/fosite_store_helpers.go index a2834cc0500..93234e7b06f 100644 --- a/oauth2/fosite_store_helpers.go +++ b/oauth2/fosite_store_helpers.go @@ -186,6 +186,7 @@ func TestHelperRunner(t *testing.T, store InternalRegistry, k string) { t.Run(fmt.Sprintf("case=testHelperDeleteAccessTokens/db=%s", k), testHelperDeleteAccessTokens(store)) t.Run(fmt.Sprintf("case=testHelperRevokeAccessToken/db=%s", k), testHelperRevokeAccessToken(store)) t.Run(fmt.Sprintf("case=testFositeJWTBearerGrantStorage/db=%s", k), testFositeJWTBearerGrantStorage(store)) + t.Run(fmt.Sprintf("case=testHelperRevokeRefreshTokenMaybeGracePeriod/db=%s", k), testHelperRevokeRefreshTokenMaybeGracePeriod(store)) } func testHelperRequestIDMultiples(m InternalRegistry, _ string) func(t *testing.T) { @@ -414,6 +415,63 @@ func testHelperRevokeAccessToken(x InternalRegistry) func(t *testing.T) { } } +func testHelperRevokeRefreshTokenMaybeGracePeriod(x InternalRegistry) func(t *testing.T) { + + return func(t *testing.T) { + t.Run("Revokes refresh token when grace period not configured", func(t *testing.T) { + // SETUP + m := x.OAuth2Storage() + ctx := context.Background() + + refreshTokenSession := fmt.Sprintf("refresh_token_%d", time.Now().Unix()) + err := m.CreateRefreshTokenSession(ctx, refreshTokenSession, &defaultRequest) + assert.NoError(t, err, "precondition failed: could not create refresh token session") + + // ACT + err = m.RevokeRefreshTokenMaybeGracePeriod(ctx, defaultRequest.GetID(), refreshTokenSession) + assert.NoError(t, err) + + tmpSession := new(fosite.Session) + _, err = m.GetRefreshTokenSession(ctx, refreshTokenSession, *tmpSession) + + // ASSERT + // a revoked refresh token returns an error when getting the token again + assert.Error(t, err) + assert.True(t, errors.Is(err, fosite.ErrInactiveToken)) + }) + + t.Run("refresh token enters grace period when configured,", func(t *testing.T) { + + // SETUP + x.Config().Set("oauth2.refresh_token_rotation.grace_period", "1m") + + // always reset back to the default + defer x.Config().Set("oauth2.refresh_token_rotation.grace_period", "0m") + + ctx := context.Background() + m := x.OAuth2Storage() + + refreshTokenSession := fmt.Sprintf("refresh_token_%d_with_grace_period", time.Now().Unix()) + err := m.CreateRefreshTokenSession(ctx, refreshTokenSession, &defaultRequest) + assert.NoError(t, err, "precondition failed: could not create refresh token session") + + // ACT + err = m.RevokeRefreshTokenMaybeGracePeriod(ctx, defaultRequest.GetID(), refreshTokenSession) + + assert.NoError(t, err) + + tmpSession := new(fosite.Session) + _, err = m.GetRefreshTokenSession(ctx, refreshTokenSession, *tmpSession) + + // ASSERT + // when grace period is configured the refresh token can be obtained within + // the grace period without error + assert.NoError(t, err) + }) + } + +} + func testHelperCreateGetDeletePKCERequestSession(x InternalRegistry) func(t *testing.T) { return func(t *testing.T) { m := x.OAuth2Storage() diff --git a/persistence/sql/migrations/20220214121000000000_add_refresh_token_in_grace_period_flag.down.sql b/persistence/sql/migrations/20220214121000000000_add_refresh_token_in_grace_period_flag.down.sql new file mode 100644 index 00000000000..163e5f5cc05 --- /dev/null +++ b/persistence/sql/migrations/20220214121000000000_add_refresh_token_in_grace_period_flag.down.sql @@ -0,0 +1 @@ +ALTER TABLE hydra_oauth2_refresh DROP COLUMN in_grace_period; diff --git a/persistence/sql/migrations/20220214121000000000_add_refresh_token_in_grace_period_flag.up.sql b/persistence/sql/migrations/20220214121000000000_add_refresh_token_in_grace_period_flag.up.sql new file mode 100644 index 00000000000..3a94b7dc7ce --- /dev/null +++ b/persistence/sql/migrations/20220214121000000000_add_refresh_token_in_grace_period_flag.up.sql @@ -0,0 +1 @@ +ALTER TABLE hydra_oauth2_refresh ADD in_grace_period bool DEFAULT false; diff --git a/persistence/sql/persister_oauth2.go b/persistence/sql/persister_oauth2.go index 5ef95cfa672..2c0c8dc4c7a 100644 --- a/persistence/sql/persister_oauth2.go +++ b/persistence/sql/persister_oauth2.go @@ -57,7 +57,11 @@ const ( ) func (r OAuth2RequestSQL) TableName() string { - return "hydra_oauth2_" + string(r.Table) + return r.Table.TableName() +} + +func (table tableName) TableName() string { + return "hydra_oauth2_" + string(table) } func (p *Persister) sqlSchemaFromRequest(rawSignature string, r fosite.Requester, table tableName) (*OAuth2RequestSQL, error) { @@ -68,19 +72,11 @@ func (p *Persister) sqlSchemaFromRequest(rawSignature string, r fosite.Requester subject = r.GetSession().GetSubject() } - session, err := json.Marshal(r.GetSession()) + session, err := p.marshalSession(r.GetSession()) if err != nil { return nil, errorsx.WithStack(err) } - if p.config.EncryptSessionData() { - ciphertext, err := p.r.KeyCipher().Encrypt(session) - if err != nil { - return nil, errorsx.WithStack(err) - } - session = []byte(ciphertext) - } - var challenge sql.NullString rr, ok := r.GetSession().(*oauth2.Session) if !ok && r.GetSession() != nil { @@ -109,6 +105,24 @@ func (p *Persister) sqlSchemaFromRequest(rawSignature string, r fosite.Requester }, nil } +func (p *Persister) marshalSession(session fosite.Session) ([]byte, error) { + sessionBytes, err := json.Marshal(session) + if err != nil { + return nil, err + } + + if !p.config.EncryptSessionData() { + return sessionBytes, nil + } + + ciphertext, err := p.r.KeyCipher().Encrypt(sessionBytes) + if err != nil { + return nil, err + } + + return []byte(ciphertext), nil +} + func (r *OAuth2RequestSQL) toRequest(ctx context.Context, session fosite.Session, p *Persister) (*fosite.Request, error) { sess := r.Session if !gjson.ValidBytes(sess) { @@ -220,6 +234,30 @@ func (p *Persister) createSession(ctx context.Context, signature string, request return nil } +func (p *Persister) updateRefreshSession(ctx context.Context, requestId string, session fosite.Session, inGracePeriod bool) error { + _, ok := session.(*oauth2.Session) + if !ok && session != nil { + return errors.Errorf("expected session to be of type *oauth2.Session but got: %T", session) + } + sessionBytes, err := p.marshalSession(session) + if err != nil { + return err + } + + updateSql := fmt.Sprintf("UPDATE %s SET session_data = ?, in_grace_period = ? WHERE request_id = ?", + sqlTableRefresh.TableName()) + + return p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error { + err := p.Connection(ctx).RawQuery(updateSql, sessionBytes, inGracePeriod, requestId).Exec() + if errors.Is(err, sql.ErrNoRows) { + return errorsx.WithStack(fosite.ErrNotFound) + } else if err != nil { + return sqlcon.HandleError(err) + } + return nil + }) +} + func (p *Persister) findSessionBySignature(ctx context.Context, rawSignature string, session fosite.Session, table tableName) (fosite.Requester, error) { rawSignature = p.hashSignature(rawSignature, table) @@ -281,13 +319,27 @@ func (p *Persister) deactivateSessionByRequestID(ctx context.Context, id string, return sqlcon.HandleError( p.Connection(ctx). RawQuery( - fmt.Sprintf("UPDATE %s SET active=false WHERE request_id=?", OAuth2RequestSQL{Table: table}.TableName()), + fmt.Sprintf("UPDATE %s SET active=false, in_grace_period=false WHERE request_id=?", OAuth2RequestSQL{Table: table}.TableName()), id, ). Exec(), ) } +func (p *Persister) getRefreshTokenGracePeriodStatusBySignature(ctx context.Context, signature string) (bool, error) { + var inGracePeriod bool + return inGracePeriod, p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error { + query := fmt.Sprintf("SELECT in_grace_period FROM %s WHERE signature = ?", sqlTableRefresh.TableName()) + err := p.Connection(ctx).RawQuery(query, signature).First(&inGracePeriod) + if errors.Is(err, sql.ErrNoRows) { + return errorsx.WithStack(fosite.ErrNotFound) + } else if err != nil { + return sqlcon.HandleError(err) + } + return err + }) +} + func (p *Persister) CreateAuthorizeCodeSession(ctx context.Context, signature string, requester fosite.Requester) (err error) { return p.createSession(ctx, signature, requester, sqlTableCode) } @@ -354,12 +406,41 @@ func (p *Persister) DeletePKCERequestSession(ctx context.Context, signature stri return p.deleteSessionBySignature(ctx, signature, sqlTablePKCE) } -func (p *Persister) RevokeRefreshToken(ctx context.Context, id string) error { - return p.deactivateSessionByRequestID(ctx, id, sqlTableRefresh) +func (p *Persister) RevokeRefreshToken(ctx context.Context, requestId string) error { + return p.deactivateSessionByRequestID(ctx, requestId, sqlTableRefresh) } -func (p *Persister) RevokeRefreshTokenMaybeGracePeriod(ctx context.Context, id string, signature string) error { - return p.deactivateSessionByRequestID(ctx, id, sqlTableRefresh) +func (p *Persister) RevokeRefreshTokenMaybeGracePeriod(ctx context.Context, requestId string, signature string) error { + gracePeriod := p.config.RefreshTokenRotationGracePeriod() + if gracePeriod <= 0 { + return p.RevokeRefreshToken(ctx, requestId) + } + + var requester fosite.Requester + var err error + session := new(oauth2.Session) + if requester, err = p.GetRefreshTokenSession(ctx, signature, session); err != nil { + p.l.Errorf("signature: %s not found. grace period not applied", signature) + return errors.WithStack(err) + } + + var inGracePeriod bool + if inGracePeriod, err = p.getRefreshTokenGracePeriodStatusBySignature(ctx, signature); err != nil { + p.l.Errorf("signature: %s in_grace_period status not found. grace period not applied", signature) + return errors.WithStack(err) + } + + requesterSession := requester.GetSession() + if !inGracePeriod { + requesterSession.SetExpiresAt(fosite.RefreshToken, time.Now().UTC().Add(gracePeriod)) + if err = p.updateRefreshSession(ctx, requestId, requesterSession, true); err != nil { + p.l.Errorf("failed to update session with signature: %s", signature) + return errors.WithStack(err) + } + } else { + p.l.Tracef("request_id: %s is in the grace period", requestId) + } + return nil } func (p *Persister) RevokeAccessToken(ctx context.Context, id string) error { diff --git a/spec/config.json b/spec/config.json index 8775deea372..ce4b55c8877 100644 --- a/spec/config.json +++ b/spec/config.json @@ -941,9 +941,24 @@ "examples": [ "https://my-example.app/token-refresh-hook" ] + }, + "refresh_token_rotation": { + "type": "object", + "properties": { + "grace_period": { + "title": "Refresh Token Rotation Grace Period", + "description": "Configures how long a Refresh Token remains valid after it has been used. The maximum value is one hour.", + "default": "0h", + "allOf": [ + { + "$ref": "#/definitions/duration" + } + ] + } } } - }, + } + }, "secrets": { "type": "object", "additionalProperties": false, diff --git a/x/fosite_storer.go b/x/fosite_storer.go index 4ca3677b1cf..2427693c94d 100644 --- a/x/fosite_storer.go +++ b/x/fosite_storer.go @@ -34,14 +34,11 @@ import ( type FositeStorer interface { fosite.Storage oauth2.CoreStorage + oauth2.TokenRevocationStorage openid.OpenIDConnectRequestStorage pkce.PKCERequestStorage rfc7523.RFC7523KeyStorage - RevokeRefreshToken(ctx context.Context, requestID string) error - - RevokeAccessToken(ctx context.Context, requestID string) error - // flush the access token requests from the database. // no data will be deleted after the 'notAfter' timeframe. FlushInactiveAccessTokens(ctx context.Context, notAfter time.Time, limit int, batchSize int) error