diff --git a/internal/identity/oauth/apple/apple.go b/internal/identity/oauth/apple/apple.go index 53878606f7f..bed70fa9496 100644 --- a/internal/identity/oauth/apple/apple.go +++ b/internal/identity/oauth/apple/apple.go @@ -130,16 +130,9 @@ func (p *Provider) Authenticate(ctx context.Context, code string, v identity.Sta // Refresh renews a user's session. func (p *Provider) Refresh(ctx context.Context, t *oauth2.Token, v identity.State) (*oauth2.Token, error) { - if t == nil { - return nil, oidc.ErrMissingAccessToken - } - if t.RefreshToken == "" { - return nil, oidc.ErrMissingRefreshToken - } - - newToken, err := p.oauth.TokenSource(ctx, t).Token() + newToken, err := oidc.Refresh(ctx, p.oauth, t) if err != nil { - return nil, fmt.Errorf("identity/apple: refresh failed: %w", err) + return nil, err } if rawIDToken, ok := newToken.Extra("id_token").(string); ok { diff --git a/internal/identity/oidc/oidc.go b/internal/identity/oidc/oidc.go index 832a62a8518..d13e3634f0d 100644 --- a/internal/identity/oidc/oidc.go +++ b/internal/identity/oidc/oidc.go @@ -213,16 +213,9 @@ func (p *Provider) Refresh(ctx context.Context, t *oauth2.Token, v identity.Stat return nil, err } - if t == nil { - return nil, ErrMissingAccessToken - } - if t.RefreshToken == "" { - return nil, ErrMissingRefreshToken - } - - newToken, err := oa.TokenSource(ctx, t).Token() + newToken, err := Refresh(ctx, oa, t) if err != nil { - return nil, fmt.Errorf("identity/oidc: refresh failed: %w", err) + return nil, err } // Many identity providers _will not_ return `id_token` on refresh diff --git a/internal/identity/oidc/refresh.go b/internal/identity/oidc/refresh.go new file mode 100644 index 00000000000..a2c1a216c26 --- /dev/null +++ b/internal/identity/oidc/refresh.go @@ -0,0 +1,29 @@ +package oidc + +import ( + "context" + "fmt" + + "golang.org/x/oauth2" +) + +// Refresh requests a new oauth2.Token based on an existing Token and the +// provided Config. The existing Token must contain a refresh token. +func Refresh(ctx context.Context, cfg *oauth2.Config, t *oauth2.Token) (*oauth2.Token, error) { + if t == nil || t.RefreshToken == "" { + return nil, ErrMissingRefreshToken + } + + // Note: the TokenSource returned by oauth2.Config has its own threshold + // for determining when to attempt a refresh. In order to force a refresh + // we can remove the current AccessToken. + t = &oauth2.Token{ + TokenType: t.TokenType, + RefreshToken: t.RefreshToken, + } + newToken, err := cfg.TokenSource(ctx, t).Token() + if err != nil { + return nil, fmt.Errorf("identity/oidc: refresh failed: %w", err) + } + return newToken, nil +} diff --git a/internal/identity/oidc/refresh_test.go b/internal/identity/oidc/refresh_test.go new file mode 100644 index 00000000000..6ac7f067bf2 --- /dev/null +++ b/internal/identity/oidc/refresh_test.go @@ -0,0 +1,71 @@ +package oidc + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/oauth2" +) + +func TestRefresh(t *testing.T) { + t.Parallel() + + ctx, clearTimeout := context.WithTimeout(context.Background(), 10*time.Second) + t.Cleanup(clearTimeout) + + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{ + "access_token": "NEW_TOKEN", + "refresh_token": "NEW_REFRESH_TOKEN", + "expires_in": 3600 + }`)) + })) + t.Cleanup(s.Close) + + cfg := &oauth2.Config{Endpoint: oauth2.Endpoint{TokenURL: s.URL}} + + token := &oauth2.Token{ + AccessToken: "OLD_TOKEN", + RefreshToken: "OLD_REFRESH_TOKEN", + + // Even if a token is not expiring soon, Refresh() should still perform + // the refresh. + Expiry: time.Now().Add(time.Hour), + } + require.True(t, token.Valid()) + + newToken, err := Refresh(ctx, cfg, token) + require.NoError(t, err) + assert.Equal(t, "NEW_TOKEN", newToken.AccessToken) + assert.Equal(t, "NEW_REFRESH_TOKEN", newToken.RefreshToken) +} + +func TestRefresh_errors(t *testing.T) { + t.Parallel() + + ctx, clearTimeout := context.WithTimeout(context.Background(), 10*time.Second) + t.Cleanup(clearTimeout) + + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("{}")) + })) + t.Cleanup(s.Close) + + cfg := &oauth2.Config{Endpoint: oauth2.Endpoint{TokenURL: s.URL}} + + _, err := Refresh(ctx, cfg, nil) + assert.Equal(t, ErrMissingRefreshToken, err) + + _, err = Refresh(ctx, cfg, &oauth2.Token{}) + assert.Equal(t, ErrMissingRefreshToken, err) + + _, err = Refresh(ctx, cfg, &oauth2.Token{RefreshToken: "REFRESH_TOKEN"}) + assert.Equal(t, "identity/oidc: refresh failed: oauth2: server response missing access_token", + err.Error()) +}