Skip to content

Commit

Permalink
feat: refresh token hook to update claims (#2649)
Browse files Browse the repository at this point in the history
This patch adds a new feature to Ory Hydra which allows the updating of access and ID tokens during the refresh flow. To set it up, use the `oauth2.refresh_token_hook` configuration to set up a HTTP(S) endpoint which receives a POST request when a refresh token is about to be issued.

Closes #2570

Co-authored-by: hackerman <3372410+aeneasr@users.noreply.github.com>
  • Loading branch information
svrakitin and aeneasr committed Sep 1, 2021
1 parent 72a5dc3 commit 1a7dcd1
Show file tree
Hide file tree
Showing 9 changed files with 321 additions and 0 deletions.
52 changes: 52 additions & 0 deletions docs/docs/advanced.md
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,58 @@ compatibility):
}
```

#### Updating claims at token refresh

Hydra can be configured to retrieve updated token claims from an endpoint at
token refresh, which provides updated claims for a given subject and scopes.
This is similar to accepting consent request, where the application provides the
session data by calling Hydra Admin API.

:::note

This endpoint is called _before_ any logic in Ory Hydra is executed. If the hook, for example, returns an error, the refresh token will remain unused!

::note

You can configure `oauth2.refresh_token_hook` config key:

```yaml
oauth2:
refresh_token_hook: https://my-example.app/token-refresh-hook
```

Hydra makes a `POST` request to this hook with the following payload:

```json
{
"subject": "foo",
"client_id": "bar",
"granted_scopes": ["openid", "offline"],
"granted_audience": []
}
```

Hook has to respond with `200 OK` and updated session data (i.e. "extra" claims)
for a token refresh to continue:

```json
{
"session": {
"access_token": {
"foo": "bar"
},
"id_token": {
"bar": "baz"
}
}
}
```

This will overwrite existing session data from the original consent request.

Hydra will gracefully deny refresh requests if the hook responds with
`403 Forbidden`. Any other response from the hook will fail refresh requests.

### OAuth 2.0 Client Authentication with private/public keypairs

ORY Hydra supports OAuth 2.0 Client Authentication with RSA and ECDSA
Expand Down
5 changes: 5 additions & 0 deletions driver/config/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ const (
KeyOAuth2LegacyErrors = "oauth2.include_legacy_error_fields"
KeyExcludeNotBeforeClaim = "oauth2.exclude_not_before_claim"
KeyAllowedTopLevelClaims = "oauth2.allowed_top_level_claims"
KeyRefreshTokenHookURL = "oauth2.refresh_token_hook" // #nosec G101
)

const DSNMemory = "memory"
Expand Down Expand Up @@ -375,6 +376,10 @@ func (p *Provider) JWKSURL() *url.URL {
return p.p.RequestURIF(KeyJWKSURL, urlx.AppendPaths(p.IssuerURL(), "/.well-known/jwks.json"))
}

func (p *Provider) TokenRefreshHookURL() *url.URL {
return p.p.URIF(KeyRefreshTokenHookURL, nil)
}

func (p *Provider) AccessTokenStrategy() string {
return strings.ToLower(p.p.StringF(KeyAccessTokenStrategy, "opaque"))
}
Expand Down
10 changes: 10 additions & 0 deletions driver/registry_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ type RegistryBase struct {
pmm *prometheus.MetricsManager
oa2mw func(h http.Handler) http.Handler
o2mc *foauth2.HMACSHAStrategy
arhs []oauth2.AccessRequestHook
buildVersion string
buildHash string
buildDate string
Expand Down Expand Up @@ -465,3 +466,12 @@ func (m *RegistryBase) WithOAuth2Provider(f fosite.OAuth2Provider) {
func (m *RegistryBase) WithConsentStrategy(c consent.Strategy) {
m.cos = c
}

func (m *RegistryBase) AccessRequestHooks() []oauth2.AccessRequestHook {
if m.arhs == nil {
m.arhs = []oauth2.AccessRequestHook{
oauth2.RefreshTokenHook(m.C),
}
}
return m.arhs
}
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ require (
github.com/gorilla/securecookie v1.1.1
github.com/gorilla/sessions v1.2.0
github.com/gtank/cryptopasta v0.0.0-20170601214702-1f550f6f2f69
github.com/hashicorp/go-cleanhttp v0.5.1
github.com/jackc/pgx/v4 v4.11.0
github.com/jmoiron/sqlx v1.3.3
github.com/julienschmidt/httprouter v1.3.0
Expand Down
8 changes: 8 additions & 0 deletions oauth2/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,14 @@ func (h *Handler) TokenHandler(w http.ResponseWriter, r *http.Request) {
}
}

for _, hook := range h.r.AccessRequestHooks() {
if err := hook(ctx, accessRequest); err != nil {
h.logOrAudit(err, r)
h.r.OAuth2Provider().WriteAccessError(w, accessRequest, err)
return
}
}

accessResponse, err := h.r.OAuth2Provider().NewAccessResponse(ctx, accessRequest)

if err != nil {
Expand Down
126 changes: 126 additions & 0 deletions oauth2/hook.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
package oauth2

import (
"bytes"
"context"
"encoding/json"
"net/http"

"github.com/hashicorp/go-cleanhttp"

"github.com/ory/fosite"
"github.com/ory/hydra/consent"
"github.com/ory/hydra/driver/config"
"github.com/ory/x/errorsx"
)

// AccessRequestHook is called when an access token is being refreshed.
type AccessRequestHook func(ctx context.Context, requester fosite.AccessRequester) error

// RefreshTokenHookRequest is the request body sent to the refresh token hook.
//
// swagger:model refreshTokenHookRequest
type RefreshTokenHookRequest struct {
// Subject is the identifier of the authenticated end-user.
Subject string `json:"subject"`
// ClientID is the identifier of the OAuth 2.0 client.
ClientID string `json:"client_id"`
// GrantedScopes is the list of scopes granted to the OAuth 2.0 client.
GrantedScopes []string `json:"granted_scopes"`
// GrantedAudience is the list of audiences granted to the OAuth 2.0 client.
GrantedAudience []string `json:"granted_audience"`
}

// RefreshTokenHookResponse is the response body received from the refresh token hook.
//
// swagger:model refreshTokenHookResponse
type RefreshTokenHookResponse struct {
// Session is the session data returned by the hook.
Session consent.ConsentRequestSessionData `json:"session"`
}

// RefreshTokenHook is an AccessRequestHook called for `refresh_token` grant type.
func RefreshTokenHook(config *config.Provider) AccessRequestHook {
client := cleanhttp.DefaultPooledClient()

return func(ctx context.Context, requester fosite.AccessRequester) error {
hookURL := config.TokenRefreshHookURL()
if hookURL == nil {
return nil
}

if !requester.GetGrantTypes().ExactOne("refresh_token") {
return nil
}

session, ok := requester.GetSession().(*Session)
if !ok {
return nil
}

reqBody := RefreshTokenHookRequest{
Subject: session.GetSubject(),
ClientID: requester.GetClient().GetID(),
GrantedScopes: requester.GetGrantedScopes(),
GrantedAudience: requester.GetGrantedAudience(),
}
reqBodyBytes, err := json.Marshal(&reqBody)
if err != nil {
return errorsx.WithStack(
fosite.ErrServerError.
WithWrap(err).
WithDebug("refresh token hook: marshal request body"),
)
}

req, err := http.NewRequestWithContext(ctx, http.MethodPost, hookURL.String(), bytes.NewReader(reqBodyBytes))
if err != nil {
return errorsx.WithStack(
fosite.ErrServerError.
WithWrap(err).
WithDebug("refresh token hook: new http request"),
)
}

resp, err := client.Do(req)
if err != nil {
return errorsx.WithStack(
fosite.ErrServerError.
WithWrap(err).
WithDebug("refresh token hook: do http request"),
)
}
defer resp.Body.Close()

switch resp.StatusCode {
case http.StatusOK:
// We only accept '200 OK' here. Any other status code is considered an error.
case http.StatusForbidden:
return errorsx.WithStack(
fosite.ErrAccessDenied.
WithDebugf("refresh token hook: %s", resp.Status),
)
default:
return errorsx.WithStack(
fosite.ErrServerError.
WithDebugf("refresh token hook: %s", resp.Status),
)
}

var respBody RefreshTokenHookResponse
if err := json.NewDecoder(resp.Body).Decode(&respBody); err != nil {
return errorsx.WithStack(
fosite.ErrServerError.
WithWrap(err).
WithDebugf("refresh token hook: unmarshal response body"),
)
}

// Overwrite existing session data (extra claims).
session.Extra = respBody.Session.AccessToken
idTokenClaims := session.IDTokenClaims()
idTokenClaims.Extra = respBody.Session.IDToken

return nil
}
}
109 changes: 109 additions & 0 deletions oauth2/oauth2_auth_code_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import (
"github.com/tidwall/gjson"

"github.com/ory/hydra/client"
"github.com/ory/hydra/consent"
"github.com/ory/hydra/internal/testhelpers"

"github.com/julienschmidt/httprouter"
Expand All @@ -53,6 +54,7 @@ import (
hydra "github.com/ory/hydra/internal/httpclient/client"
"github.com/ory/hydra/internal/httpclient/client/admin"
"github.com/ory/hydra/internal/httpclient/models"
hydraoauth2 "github.com/ory/hydra/oauth2"
"github.com/ory/hydra/x"
"github.com/ory/x/pointerx"
"github.com/ory/x/urlx"
Expand Down Expand Up @@ -908,6 +910,113 @@ func TestAuthCodeWithMockStrategy(t *testing.T) {
require.NoError(t, json.Unmarshal(body, &refreshedToken))
})

t.Run("should call refresh token hook if configured", func(t *testing.T) {
hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var hookReq hydraoauth2.RefreshTokenHookRequest
require.NoError(t, json.NewDecoder(r.Body).Decode(&hookReq))
require.Equal(t, hookReq.Subject, "foo")
require.ElementsMatch(t, hookReq.GrantedScopes, []string{"openid", "offline", "hydra.*"})
require.ElementsMatch(t, hookReq.GrantedAudience, []string{})
require.Equal(t, hookReq.ClientID, oauthConfig.ClientID)

claims := map[string]interface{}{
"hooked": true,
}

hookResp := hydraoauth2.RefreshTokenHookResponse{
Session: consent.ConsentRequestSessionData{
AccessToken: claims,
IDToken: claims,
},
}

w.WriteHeader(http.StatusOK)
require.NoError(t, json.NewEncoder(w).Encode(&hookResp))
}))
defer hs.Close()

conf.MustSet(config.KeyRefreshTokenHookURL, hs.URL)
defer conf.MustSet(config.KeyRefreshTokenHookURL, nil)

res, err := testRefresh(t, &refreshedToken, ts.URL, false)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)

body, err := ioutil.ReadAll(res.Body)
require.NoError(t, err)
require.NoError(t, json.Unmarshal(body, &refreshedToken))

accessTokenClaims := testhelpers.IntrospectToken(t, oauthConfig, &refreshedToken, ts)
require.True(t, accessTokenClaims.Get("ext.hooked").Bool())

idTokenBody, err := x.DecodeSegment(
strings.Split(
gjson.GetBytes(body, "id_token").String(),
".",
)[1],
)
require.NoError(t, err)

require.True(t, gjson.GetBytes(idTokenBody, "hooked").Bool())
})

t.Run("should fail token refresh with `server_error` if hook fails", func(t *testing.T) {
hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
defer hs.Close()

conf.MustSet(config.KeyRefreshTokenHookURL, hs.URL)
defer conf.MustSet(config.KeyRefreshTokenHookURL, nil)

res, err := testRefresh(t, &refreshedToken, ts.URL, false)
require.NoError(t, err)
assert.Equal(t, http.StatusInternalServerError, res.StatusCode)

var errBody fosite.RFC6749ErrorJson
require.NoError(t, json.NewDecoder(res.Body).Decode(&errBody))
require.Equal(t, fosite.ErrServerError.Error(), errBody.Name)
require.Equal(t, fosite.ErrServerError.GetDescription(), errBody.Description)
})

t.Run("should fail token refresh with `access_denied` if hook denied the request", func(t *testing.T) {
hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusForbidden)
}))
defer hs.Close()

conf.MustSet(config.KeyRefreshTokenHookURL, hs.URL)
defer conf.MustSet(config.KeyRefreshTokenHookURL, nil)

res, err := testRefresh(t, &refreshedToken, ts.URL, false)
require.NoError(t, err)
assert.Equal(t, http.StatusForbidden, res.StatusCode)

var errBody fosite.RFC6749ErrorJson
require.NoError(t, json.NewDecoder(res.Body).Decode(&errBody))
require.Equal(t, fosite.ErrAccessDenied.Error(), errBody.Name)
require.Equal(t, fosite.ErrAccessDenied.GetDescription(), errBody.Description)
})

t.Run("should fail token refresh with `server_error` if hook response is malformed", func(t *testing.T) {
hs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer hs.Close()

conf.MustSet(config.KeyRefreshTokenHookURL, hs.URL)
defer conf.MustSet(config.KeyRefreshTokenHookURL, nil)

res, err := testRefresh(t, &refreshedToken, ts.URL, false)
require.NoError(t, err)
assert.Equal(t, http.StatusInternalServerError, res.StatusCode)

var errBody fosite.RFC6749ErrorJson
require.NoError(t, json.NewDecoder(res.Body).Decode(&errBody))
require.Equal(t, fosite.ErrServerError.Error(), errBody.Name)
require.Equal(t, fosite.ErrServerError.GetDescription(), errBody.Description)
})

t.Run("refreshing old token should no longer work", func(t *testing.T) {
res, err := testRefresh(t, token, ts.URL, false)
require.NoError(t, err)
Expand Down
2 changes: 2 additions & 0 deletions oauth2/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,6 @@ type Registry interface {
OpenIDJWTStrategy() jwk.JWTStrategy

OpenIDConnectRequestValidator() *openid.OpenIDConnectRequestValidator

AccessRequestHooks() []AccessRequestHook
}

0 comments on commit 1a7dcd1

Please sign in to comment.