Skip to content

Commit

Permalink
authenticate: get/set identity provider id for all sessions (#3608)
Browse files Browse the repository at this point in the history
authenticate: get/set identity provider id for all sessions (#3597)

Co-authored-by: Caleb Doxsey <cdoxsey@pomerium.com>
  • Loading branch information
backport-actions-token[bot] and calebdoxsey committed Sep 7, 2022
1 parent c3ef43c commit c0a8870
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 31 deletions.
81 changes: 56 additions & 25 deletions authenticate/handlers.go
Expand Up @@ -138,20 +138,25 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler {
defer span.End()

state := a.state.Load()
idpID := r.FormValue(urlutil.QueryIdentityProviderID)
options := a.options.Load()
idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID))
if err != nil {
return err
}

sessionState, err := a.getSessionFromCtx(ctx)
if err != nil {
log.FromRequest(r).Info().
Err(err).
Str("idp_id", idpID).
Str("idp_id", idp.GetId()).
Msg("authenticate: session load error")
return a.reauthenticateOrFail(w, r, err)
}

if sessionState.IdentityProviderID != idpID {
if sessionState.IdentityProviderID != idp.GetId() {
log.FromRequest(r).Info().
Str("idp_id", idpID).
Str("idp_id", idp.GetId()).
Str("session_idp_id", sessionState.IdentityProviderID).
Str("id", sessionState.ID).
Msg("authenticate: session not associated with identity provider")
return a.reauthenticateOrFail(w, r, err)
Expand All @@ -163,7 +168,7 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler {
if _, err = session.Get(ctx, state.dataBrokerClient, sessionState.ID); err != nil {
log.FromRequest(r).Info().
Err(err).
Str("idp_id", idpID).
Str("idp_id", idp.GetId()).
Str("id", sessionState.ID).
Msg("authenticate: session not found in databroker")
return a.reauthenticateOrFail(w, r, err)
Expand All @@ -187,6 +192,11 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
defer span.End()

state := a.state.Load()
options := a.options.Load()
idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID))
if err != nil {
return err
}

redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI))
if err != nil {
Expand Down Expand Up @@ -216,8 +226,8 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
}

// start over if this is a different identity provider
if s == nil || s.IdentityProviderID != r.FormValue(urlutil.QueryIdentityProviderID) {
s = sessions.NewState(urlutil.QueryIdentityProviderID)
if s == nil || s.IdentityProviderID != idp.GetId() {
s = sessions.NewState(idp.GetId())
}

newSession := s.WithNewIssuer(state.redirectURL.Host, jwtAudience)
Expand Down Expand Up @@ -276,16 +286,20 @@ func (a *Authenticate) signOutRedirect(w http.ResponseWriter, r *http.Request) e
defer span.End()

options := a.options.Load()
idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID))
if err != nil {
return err
}

idp, err := a.cfg.getIdentityProvider(options, r.FormValue(urlutil.QueryIdentityProviderID))
authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId())
if err != nil {
return err
}

rawIDToken := a.revokeSession(ctx, w, r)

redirectString := ""
signOutURL, err := a.options.Load().GetSignOutRedirectURL()
signOutURL, err := options.GetSignOutRedirectURL()
if err != nil {
return err
}
Expand All @@ -296,14 +310,14 @@ func (a *Authenticate) signOutRedirect(w http.ResponseWriter, r *http.Request) e
redirectString = uri
}

endSessionURL, err := idp.LogOut()
endSessionURL, err := authenticator.LogOut()
if err == nil && redirectString != "" {
params := url.Values{}
params.Add("id_token_hint", rawIDToken)
params.Add("post_logout_redirect_uri", redirectString)
endSessionURL.RawQuery = params.Encode()
redirectString = endSessionURL.String()
} else if !errors.Is(err, oidc.ErrSignoutNotImplemented) {
} else if err != nil && !errors.Is(err, oidc.ErrSignoutNotImplemented) {
log.Warn(r.Context()).Err(err).Msg("authenticate.SignOut: failed getting session")
}
if redirectString != "" {
Expand All @@ -330,10 +344,14 @@ func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Reque
return httputil.NewError(http.StatusUnauthorized, err)
}

options := a.options.Load()
state := a.state.Load()
options := a.options.Load()
idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID))
if err != nil {
return err
}

idp, err := a.cfg.getIdentityProvider(options, r.FormValue(urlutil.QueryIdentityProviderID))
authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId())
if err != nil {
return err
}
Expand All @@ -346,7 +364,7 @@ func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Reque
enc := cryptutil.Encrypt(state.cookieCipher, []byte(redirectURL.String()), b)
b = append(b, enc...)
encodedState := base64.URLEncoding.EncodeToString(b)
signinURL, err := idp.GetSignInURL(encodedState)
signinURL, err := authenticator.GetSignInURL(encodedState)
if err != nil {
return httputil.NewError(http.StatusInternalServerError,
fmt.Errorf("failed to get sign in url: %w", err))
Expand Down Expand Up @@ -381,8 +399,8 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
ctx, span := trace.StartSpan(r.Context(), "authenticate.getOAuthCallback")
defer span.End()

options := a.options.Load()
state := a.state.Load()
options := a.options.Load()

// Error Authentication Response: rfc6749#section-4.1.2.1 & OIDC#3.1.2.6
//
Expand Down Expand Up @@ -428,9 +446,13 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
if err != nil {
return nil, httputil.NewError(http.StatusBadRequest, err)
}
idpID := redirectURL.Query().Get(urlutil.QueryIdentityProviderID)

idp, err := a.cfg.getIdentityProvider(options, idpID)
idp, err := options.GetIdentityProviderForID(redirectURL.Query().Get(urlutil.QueryIdentityProviderID))
if err != nil {
return nil, err
}

authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId())
if err != nil {
return nil, err
}
Expand All @@ -439,12 +461,12 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
//
// Exchange the supplied Authorization Code for a valid user session.
var claims identity.SessionClaims
accessToken, err := idp.Authenticate(ctx, code, &claims)
accessToken, err := authenticator.Authenticate(ctx, code, &claims)
if err != nil {
return nil, fmt.Errorf("error redeeming authenticate code: %w", err)
}

s := sessions.NewState(idpID)
s := sessions.NewState(idp.GetId())
err = claims.Claims.Claims(&s)
if err != nil {
return nil, fmt.Errorf("error unmarshaling session state: %w", err)
Expand Down Expand Up @@ -582,8 +604,12 @@ func (a *Authenticate) saveSessionToDataBroker(
) error {
state := a.state.Load()
options := a.options.Load()
idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID))
if err != nil {
return err
}

idp, err := a.cfg.getIdentityProvider(options, r.FormValue(urlutil.QueryIdentityProviderID))
authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId())
if err != nil {
return err
}
Expand All @@ -593,7 +619,7 @@ func (a *Authenticate) saveSessionToDataBroker(

s := &session.Session{
Id: sessionState.ID,
UserId: sessionState.UserID(idp.Name()),
UserId: sessionState.UserID(authenticator.Name()),
IssuedAt: timestamppb.Now(),
AccessedAt: timestamppb.Now(),
ExpiresAt: sessionExpiry,
Expand All @@ -617,7 +643,7 @@ func (a *Authenticate) saveSessionToDataBroker(
Id: s.GetUserId(),
}
}
err = idp.UpdateUserInfo(ctx, accessToken, &managerUser)
err = authenticator.UpdateUserInfo(ctx, accessToken, &managerUser)
if err != nil {
return fmt.Errorf("authenticate: error retrieving user info: %w", err)
}
Expand Down Expand Up @@ -648,13 +674,18 @@ func (a *Authenticate) saveSessionToDataBroker(
// databroker. If successful, it returns the original `id_token` of the session, if failed, returns
// and empty string.
func (a *Authenticate) revokeSession(ctx context.Context, w http.ResponseWriter, r *http.Request) string {
options := a.options.Load()
state := a.state.Load()
options := a.options.Load()

// clear the user's local session no matter what
defer state.sessionStore.ClearSession(w, r)

idp, err := a.cfg.getIdentityProvider(options, r.FormValue(urlutil.QueryIdentityProviderID))
idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID))
if err != nil {
return ""
}

authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId())
if err != nil {
return ""
}
Expand All @@ -667,7 +698,7 @@ func (a *Authenticate) revokeSession(ctx context.Context, w http.ResponseWriter,

if s, _ := session.Get(ctx, state.dataBrokerClient, sessionState.ID); s != nil && s.OauthToken != nil {
rawIDToken = s.GetIdToken().GetRaw()
if err := idp.Revoke(ctx, manager.FromOAuthToken(s.OauthToken)); err != nil {
if err := authenticator.Revoke(ctx, manager.FromOAuthToken(s.OauthToken)); err != nil {
log.Ctx(ctx).Warn().Err(err).Msg("authenticate: failed to revoke access token")
}
}
Expand Down
14 changes: 8 additions & 6 deletions authenticate/handlers_test.go
Expand Up @@ -478,6 +478,8 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
w.WriteHeader(http.StatusOK)
})

idp, _ := new(config.Options).GetIdentityProviderForID("")

tests := []struct {
name string
headers map[string]string
Expand All @@ -491,47 +493,47 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
{
"good",
nil,
&mstore.Store{Session: &sessions.State{ID: "xyz"}},
&mstore.Store{Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}},
nil,
identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}},
http.StatusOK,
},
{
"invalid session",
nil,
&mstore.Store{Session: &sessions.State{ID: "xyz"}},
&mstore.Store{Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}},
errors.New("hi"),
identity.MockProvider{},
http.StatusFound,
},
{
"good refresh expired",
nil,
&mstore.Store{Session: &sessions.State{ID: "xyz"}},
&mstore.Store{Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}},
nil,
identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}},
http.StatusOK,
},
{
"expired,refresh error",
nil,
&mstore.Store{Session: &sessions.State{ID: "xyz"}},
&mstore.Store{Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}},
sessions.ErrExpired,
identity.MockProvider{RefreshError: errors.New("error")},
http.StatusFound,
},
{
"expired,save error",
nil,
&mstore.Store{SaveError: errors.New("error"), Session: &sessions.State{ID: "xyz"}},
&mstore.Store{SaveError: errors.New("error"), Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}},
sessions.ErrExpired,
identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}},
http.StatusFound,
},
{
"expired XHR,refresh error",
map[string]string{"X-Requested-With": "XmlHttpRequest"},
&mstore.Store{Session: &sessions.State{ID: "xyz"}},
&mstore.Store{Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}},
sessions.ErrExpired,
identity.MockProvider{RefreshError: errors.New("error")},
http.StatusUnauthorized,
Expand Down

0 comments on commit c0a8870

Please sign in to comment.