Skip to content

Commit

Permalink
identity: override TokenSource expiry behavior (#4632)
Browse files Browse the repository at this point in the history
The current session refresh loop attempts to refresh access tokens when
they are due to expire in less than one minute. However, the code to
perform the refresh relies on a TokenSource from the x/oauth2 package,
which has its own internal 'expiryDelta' threshold, with a default of
10 seconds. As a result, the first four or five attempts to refresh a
particular access token will not actually refresh the token. The refresh
will happen only when the access token is within 10 seconds of expiring.

Instead, before we obtain a new TokenSource, first clear any existing
access token. This causes the TokenSource to consider the token invalid,
triggering a refresh. This should give the refresh loop more control
over when refreshes happen.

Consolidate this logic in a new Refresh() method in the oidc package.
Add unit tests for this new method.
  • Loading branch information
kenjenkins committed Oct 23, 2023
1 parent 5a73526 commit 39a477c
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 18 deletions.
11 changes: 2 additions & 9 deletions internal/identity/oauth/apple/apple.go
Expand Up @@ -130,16 +130,9 @@ func (p *Provider) Authenticate(ctx context.Context, code string, v identity.Sta

// Refresh renews a user's session.
func (p *Provider) Refresh(ctx context.Context, t *oauth2.Token, v identity.State) (*oauth2.Token, error) {
if t == nil {
return nil, oidc.ErrMissingAccessToken
}
if t.RefreshToken == "" {
return nil, oidc.ErrMissingRefreshToken
}

newToken, err := p.oauth.TokenSource(ctx, t).Token()
newToken, err := oidc.Refresh(ctx, p.oauth, t)
if err != nil {
return nil, fmt.Errorf("identity/apple: refresh failed: %w", err)
return nil, err
}

if rawIDToken, ok := newToken.Extra("id_token").(string); ok {
Expand Down
11 changes: 2 additions & 9 deletions internal/identity/oidc/oidc.go
Expand Up @@ -213,16 +213,9 @@ func (p *Provider) Refresh(ctx context.Context, t *oauth2.Token, v identity.Stat
return nil, err
}

if t == nil {
return nil, ErrMissingAccessToken
}
if t.RefreshToken == "" {
return nil, ErrMissingRefreshToken
}

newToken, err := oa.TokenSource(ctx, t).Token()
newToken, err := Refresh(ctx, oa, t)
if err != nil {
return nil, fmt.Errorf("identity/oidc: refresh failed: %w", err)
return nil, err
}

// Many identity providers _will not_ return `id_token` on refresh
Expand Down
29 changes: 29 additions & 0 deletions internal/identity/oidc/refresh.go
@@ -0,0 +1,29 @@
package oidc

import (
"context"
"fmt"

"golang.org/x/oauth2"
)

// Refresh requests a new oauth2.Token based on an existing Token and the
// provided Config. The existing Token must contain a refresh token.
func Refresh(ctx context.Context, cfg *oauth2.Config, t *oauth2.Token) (*oauth2.Token, error) {
if t == nil || t.RefreshToken == "" {
return nil, ErrMissingRefreshToken
}

// Note: the TokenSource returned by oauth2.Config has its own threshold
// for determining when to attempt a refresh. In order to force a refresh
// we can remove the current AccessToken.
t = &oauth2.Token{
TokenType: t.TokenType,
RefreshToken: t.RefreshToken,
}
newToken, err := cfg.TokenSource(ctx, t).Token()
if err != nil {
return nil, fmt.Errorf("identity/oidc: refresh failed: %w", err)
}
return newToken, nil
}
71 changes: 71 additions & 0 deletions internal/identity/oidc/refresh_test.go
@@ -0,0 +1,71 @@
package oidc

import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
)

func TestRefresh(t *testing.T) {
t.Parallel()

ctx, clearTimeout := context.WithTimeout(context.Background(), 10*time.Second)
t.Cleanup(clearTimeout)

s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{
"access_token": "NEW_TOKEN",
"refresh_token": "NEW_REFRESH_TOKEN",
"expires_in": 3600
}`))
}))
t.Cleanup(s.Close)

cfg := &oauth2.Config{Endpoint: oauth2.Endpoint{TokenURL: s.URL}}

token := &oauth2.Token{
AccessToken: "OLD_TOKEN",
RefreshToken: "OLD_REFRESH_TOKEN",

// Even if a token is not expiring soon, Refresh() should still perform
// the refresh.
Expiry: time.Now().Add(time.Hour),
}
require.True(t, token.Valid())

newToken, err := Refresh(ctx, cfg, token)
require.NoError(t, err)
assert.Equal(t, "NEW_TOKEN", newToken.AccessToken)
assert.Equal(t, "NEW_REFRESH_TOKEN", newToken.RefreshToken)
}

func TestRefresh_errors(t *testing.T) {
t.Parallel()

ctx, clearTimeout := context.WithTimeout(context.Background(), 10*time.Second)
t.Cleanup(clearTimeout)

s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("{}"))
}))
t.Cleanup(s.Close)

cfg := &oauth2.Config{Endpoint: oauth2.Endpoint{TokenURL: s.URL}}

_, err := Refresh(ctx, cfg, nil)
assert.Equal(t, ErrMissingRefreshToken, err)

_, err = Refresh(ctx, cfg, &oauth2.Token{})
assert.Equal(t, ErrMissingRefreshToken, err)

_, err = Refresh(ctx, cfg, &oauth2.Token{RefreshToken: "REFRESH_TOKEN"})
assert.Equal(t, "identity/oidc: refresh failed: oauth2: server response missing access_token",
err.Error())
}

0 comments on commit 39a477c

Please sign in to comment.