Skip to content

Commit

Permalink
core/authenticate: refactor idp sign out (#4582)
Browse files Browse the repository at this point in the history
  • Loading branch information
calebdoxsey committed Sep 28, 2023
1 parent 7211a8d commit a0c9289
Show file tree
Hide file tree
Showing 16 changed files with 318 additions and 93 deletions.
43 changes: 23 additions & 20 deletions authenticate/handlers.go
Expand Up @@ -89,6 +89,7 @@ func (a *Authenticate) mountDashboard(r *mux.Router) {

// routes that don't need a session:
sr.Path("/sign_out").Handler(httputil.HandlerFunc(a.SignOut))
sr.Path("/signed_out").Handler(handlers.SignedOut(handlers.SignedOutData{})).Methods(http.MethodGet)

// routes that need a session:
sr = sr.NewRoute().Subrouter()
Expand Down Expand Up @@ -266,33 +267,35 @@ func (a *Authenticate) signOutRedirect(w http.ResponseWriter, r *http.Request) e

rawIDToken := a.revokeSession(ctx, w, r)

redirectString := ""
signOutURL, err := options.GetSignOutRedirectURL()
authenticateURL, err := options.GetAuthenticateURL()
if err != nil {
return err
return fmt.Errorf("error getting authenticate url: %w", err)
}
if signOutURL != nil {
redirectString = signOutURL.String()

signOutRedirectURL, err := options.GetSignOutRedirectURL()
if err != nil {
return err
}

var signOutURL string
if uri := r.FormValue(urlutil.QueryRedirectURI); uri != "" {
redirectString = uri
signOutURL = uri
} else if signOutRedirectURL != nil {
signOutURL = signOutRedirectURL.String()
} else {
signOutURL = authenticateURL.ResolveReference(&url.URL{
Path: "/.pomerium/signed_out",
}).String()
}

endSessionURL, err := authenticator.LogOut()
if err == nil && redirectString != "" {
params := endSessionURL.Query()
params.Add("id_token_hint", rawIDToken)
params.Add("post_logout_redirect_uri", redirectString)
endSessionURL.RawQuery = params.Encode()
redirectString = endSessionURL.String()
} else if err != nil && !errors.Is(err, oidc.ErrSignoutNotImplemented) {
log.Warn(r.Context()).Err(err).Msg("authenticate.SignOut: failed getting session")
if idpSignOutURL, err := authenticator.GetSignOutURL(rawIDToken, signOutURL); err == nil {
signOutURL = idpSignOutURL
} else if !errors.Is(err, oidc.ErrSignoutNotImplemented) {
log.Warn(r.Context()).Err(err).Msg("authenticate: failed to get sign out url for authenticator")
}
if redirectString != "" {
httputil.Redirect(w, r, redirectString, http.StatusFound)
return nil
}
return httputil.NewError(http.StatusOK, errors.New("user logged out"))

httputil.Redirect(w, r, signOutURL, http.StatusFound)
return nil
}

// reauthenticateOrFail starts the authenticate process by redirecting the
Expand Down
25 changes: 6 additions & 19 deletions authenticate/handlers_test.go
Expand Up @@ -135,7 +135,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
"",
"sig",
"ts",
identity.MockProvider{LogOutResponse: (*uriParseHelper("https://microsoft.com"))},
identity.MockProvider{GetSignOutURLResponse: "https://microsoft.com"},
&mstore.Store{Encrypted: true, Session: &sessions.State{}},
http.StatusFound,
"",
Expand All @@ -148,7 +148,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
"https://signout-redirect-url.example.com",
"sig",
"ts",
identity.MockProvider{LogOutResponse: (*uriParseHelper("https://microsoft.com"))},
identity.MockProvider{GetSignOutURLError: oidc.ErrSignoutNotImplemented},
&mstore.Store{Encrypted: true, Session: &sessions.State{}},
http.StatusFound,
"",
Expand All @@ -161,7 +161,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
"",
"sig",
"ts",
identity.MockProvider{RevokeError: errors.New("OH NO")},
identity.MockProvider{GetSignOutURLError: oidc.ErrSignoutNotImplemented, RevokeError: errors.New("OH NO")},
&mstore.Store{Encrypted: true, Session: &sessions.State{}},
http.StatusFound,
"",
Expand All @@ -174,7 +174,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
"",
"sig",
"ts",
identity.MockProvider{RevokeError: errors.New("OH NO")},
identity.MockProvider{GetSignOutURLError: oidc.ErrSignoutNotImplemented, RevokeError: errors.New("OH NO")},
&mstore.Store{Encrypted: true, Session: &sessions.State{}},
http.StatusFound,
"",
Expand All @@ -187,24 +187,11 @@ func TestAuthenticate_SignOut(t *testing.T) {
"",
"sig",
"ts",
identity.MockProvider{LogOutError: oidc.ErrSignoutNotImplemented},
identity.MockProvider{GetSignOutURLError: oidc.ErrSignoutNotImplemented},
&mstore.Store{Encrypted: true, Session: &sessions.State{}},
http.StatusFound,
"",
},
{
"no redirect uri",
http.MethodPost,
nil,
"",
"",
"sig",
"ts",
identity.MockProvider{LogOutResponse: (*uriParseHelper("https://microsoft.com"))},
&mstore.Store{Encrypted: true, Session: &sessions.State{}},
http.StatusOK,
"{\"Status\":200}\n",
},
}
for _, tt := range tests {
tt := tt
Expand Down Expand Up @@ -253,7 +240,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
}
if tt.signoutRedirectURL != "" {
loc := w.Header().Get("Location")
assert.Contains(t, loc, url.QueryEscape(tt.signoutRedirectURL))
assert.Contains(t, loc, tt.signoutRedirectURL)
}
})
}
Expand Down
23 changes: 23 additions & 0 deletions internal/handlers/signedout.go
@@ -0,0 +1,23 @@
package handlers

import (
"net/http"

"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/ui"
)

// SignedOutData is the data for the SignedOut page.
type SignedOutData struct{}

// ToJSON converts the data into a JSON map.
func (data SignedOutData) ToJSON() map[string]interface{} {
return map[string]interface{}{}
}

// SignedOut returns a handler that renders the signed out page.
func SignedOut(data SignedOutData) http.Handler {
return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
return ui.ServePage(w, r, "SignedOut", data.ToJSON())
})
}
25 changes: 13 additions & 12 deletions internal/identity/mock_provider.go
Expand Up @@ -2,7 +2,6 @@ package identity

import (
"context"
"net/url"

"golang.org/x/oauth2"

Expand All @@ -11,15 +10,15 @@ import (

// MockProvider provides a mocked implementation of the providers interface.
type MockProvider struct {
AuthenticateResponse oauth2.Token
AuthenticateError error
RefreshResponse oauth2.Token
RefreshError error
RevokeError error
GetSignInURLResponse string
LogOutResponse url.URL
LogOutError error
UpdateUserInfoError error
AuthenticateResponse oauth2.Token
AuthenticateError error
RefreshResponse oauth2.Token
RefreshError error
RevokeError error
GetSignInURLResponse string
GetSignOutURLResponse string
GetSignOutURLError error
UpdateUserInfoError error
}

// Authenticate is a mocked providers function.
Expand All @@ -40,8 +39,10 @@ func (mp MockProvider) Revoke(_ context.Context, _ *oauth2.Token) error {
// GetSignInURL is a mocked providers function.
func (mp MockProvider) GetSignInURL(_ string) (string, error) { return mp.GetSignInURLResponse, nil }

// LogOut is a mocked providers function.
func (mp MockProvider) LogOut() (*url.URL, error) { return &mp.LogOutResponse, mp.LogOutError }
// GetSignOutURL is a mocked providers function.
func (mp MockProvider) GetSignOutURL(_, _ string) (string, error) {
return mp.GetSignOutURLResponse, mp.GetSignOutURLError
}

// UpdateUserInfo is a mocked providers function.
func (mp MockProvider) UpdateUserInfo(_ context.Context, _ *oauth2.Token, _ interface{}) error {
Expand Down
10 changes: 5 additions & 5 deletions internal/identity/oauth/apple/apple.go
Expand Up @@ -103,6 +103,11 @@ func (p *Provider) GetSignInURL(state string) (string, error) {
return authURL, nil
}

// GetSignOutURL is not implemented.
func (p *Provider) GetSignOutURL(_, _ string) (string, error) {
return "", oidc.ErrSignoutNotImplemented
}

// Authenticate converts an authorization code returned from the identity
// provider into a token which is then converted into a user session.
func (p *Provider) Authenticate(ctx context.Context, code string, v identity.State) (*oauth2.Token, error) {
Expand All @@ -123,11 +128,6 @@ func (p *Provider) Authenticate(ctx context.Context, code string, v identity.Sta
return oauth2Token, nil
}

// LogOut is not implemented by Apple.
func (p *Provider) LogOut() (*url.URL, error) {
return nil, oidc.ErrSignoutNotImplemented
}

// Refresh renews a user's session.
func (p *Provider) Refresh(ctx context.Context, t *oauth2.Token, v identity.State) (*oauth2.Token, error) {
if t == nil {
Expand Down
6 changes: 3 additions & 3 deletions internal/identity/oauth/github/github.go
Expand Up @@ -245,9 +245,9 @@ func (p *Provider) GetSignInURL(state string) (string, error) {
return p.Oauth.AuthCodeURL(state, oauth2.AccessTypeOffline), nil
}

// LogOut is not implemented by github.
func (p *Provider) LogOut() (*url.URL, error) {
return nil, oidc.ErrSignoutNotImplemented
// GetSignOutURL is not implemented.
func (p *Provider) GetSignOutURL(_, _ string) (string, error) {
return "", oidc.ErrSignoutNotImplemented
}

// Name returns the provider name.
Expand Down
27 changes: 27 additions & 0 deletions internal/identity/oidc/auth0/auth0.go
Expand Up @@ -6,10 +6,12 @@ package auth0
import (
"context"
"fmt"
"net/url"
"strings"

"github.com/pomerium/pomerium/internal/identity/oauth"
pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc"
"github.com/pomerium/pomerium/internal/urlutil"
)

const (
Expand Down Expand Up @@ -47,3 +49,28 @@ func New(ctx context.Context, o *oauth.Options) (*Provider, error) {
func (p *Provider) Name() string {
return Name
}

// GetSignOutURL implements logout as described in https://auth0.com/docs/api/authentication#logout.
func (p *Provider) GetSignOutURL(_, redirectToURL string) (string, error) {
oa, err := p.GetOauthConfig()
if err != nil {
return "", fmt.Errorf("error getting auth0 oauth config: %w", err)
}

authURL, err := urlutil.ParseAndValidateURL(oa.Endpoint.AuthURL)
if err != nil {
return "", fmt.Errorf("error parsing auth0 endpoint auth url: %w", err)
}

logoutQuery := url.Values{
"client_id": {oa.ClientID},
}
if redirectToURL != "" {
logoutQuery.Set("returnTo", redirectToURL)
}
logoutURL := authURL.ResolveReference(&url.URL{
Path: "/v2/logout",
RawQuery: logoutQuery.Encode(),
})
return logoutURL.String(), nil
}
61 changes: 61 additions & 0 deletions internal/identity/oidc/auth0/auth0_test.go
@@ -0,0 +1,61 @@
package auth0

import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/pomerium/pomerium/internal/identity/oauth"
)

func TestProvider(t *testing.T) {
t.Parallel()

ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
t.Cleanup(clearTimeout)

var srv *httptest.Server
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
baseURL, err := url.Parse(srv.URL + "/")
require.NoError(t, err)

w.Header().Set("Content-Type", "application/json")
switch r.URL.Path {
case "/.well-known/openid-configuration":
json.NewEncoder(w).Encode(map[string]any{
"issuer": baseURL.String(),
"authorization_endpoint": srv.URL + "/authorize",
})

default:
assert.Failf(t, "unexpected http request", "url: %s", r.URL.String())
}
})
srv = httptest.NewServer(handler)
t.Cleanup(srv.Close)

redirectURL, err := url.Parse(srv.URL)
require.NoError(t, err)

p, err := New(ctx, &oauth.Options{
ProviderURL: srv.URL,
RedirectURL: redirectURL,
ClientID: "CLIENT_ID",
ClientSecret: "CLIENT_SECRET",
})
require.NoError(t, err)
require.NotNil(t, p)

t.Run("GetSignOutURL", func(t *testing.T) {
signOutURL, err := p.GetSignOutURL("", "https://www.example.com?a=b")
assert.NoError(t, err)
assert.Equal(t, srv.URL+"/v2/logout?client_id=CLIENT_ID&returnTo=https%3A%2F%2Fwww.example.com%3Fa%3Db", signOutURL)
})
}
34 changes: 25 additions & 9 deletions internal/identity/oidc/cognito/cognito.go
Expand Up @@ -48,14 +48,30 @@ func New(ctx context.Context, opts *oauth.Options) (*Provider, error) {
// https://docs.aws.amazon.com/cognito/latest/developerguide/revocation-endpoint.html
p.RevocationURL = cognitoURL.ResolveReference(&url.URL{Path: "/oauth2/revoke"}).String()

// https://docs.aws.amazon.com/cognito/latest/developerguide/logout-endpoint.html
p.EndSessionURL = cognitoURL.ResolveReference(&url.URL{
Path: "/logout",
RawQuery: url.Values{
"client_id": []string{opts.ClientID},
"logout_uri": []string{opts.RedirectURL.ResolveReference(&url.URL{Path: "/"}).String()},
}.Encode(),
}).String()

return &p, nil
}

// GetSignOutURL gets the sign out URL according to https://docs.aws.amazon.com/cognito/latest/developerguide/logout-endpoint.html.
func (p *Provider) GetSignOutURL(_, returnToURL string) (string, error) {
oa, err := p.GetOauthConfig()
if err != nil {
return "", fmt.Errorf("error getting cognito oauth config: %w", err)
}

authURL, err := urlutil.ParseAndValidateURL(oa.Endpoint.AuthURL)
if err != nil {
return "", fmt.Errorf("error getting cognito endpoint auth url: %w", err)
}

logOutQuery := url.Values{
"client_id": []string{oa.ClientID},
}
if returnToURL != "" {
logOutQuery.Set("logout_uri", returnToURL)
}
logOutURL := authURL.ResolveReference(&url.URL{
Path: "/logout",
RawQuery: logOutQuery.Encode(),
})
return logOutURL.String(), nil
}

0 comments on commit a0c9289

Please sign in to comment.