From 54676afce25b139bf4c4de3ac6bd2072b3e08e0b Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Mon, 30 Jan 2023 18:59:07 -0700 Subject: [PATCH] authenticate: always trust the passed in idp (#3917) --- authenticate/handlers.go | 54 ++++++++++---------------------- authenticate/identity_profile.go | 9 ++---- 2 files changed, 20 insertions(+), 43 deletions(-) diff --git a/authenticate/handlers.go b/authenticate/handlers.go index a8730d791f4..d34b4e7154e 100644 --- a/authenticate/handlers.go +++ b/authenticate/handlers.go @@ -118,24 +118,20 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler { defer span.End() state := a.state.Load() - options := a.options.Load() - idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID)) - if err != nil { - return err - } + idpID := r.FormValue(urlutil.QueryIdentityProviderID) sessionState, err := a.getSessionFromCtx(ctx) if err != nil { log.FromRequest(r).Info(). Err(err). - Str("idp_id", idp.GetId()). + Str("idp_id", idpID). Msg("authenticate: session load error") return a.reauthenticateOrFail(w, r, err) } - if sessionState.IdentityProviderID != idp.GetId() { + if sessionState.IdentityProviderID != idpID { log.FromRequest(r).Info(). - Str("idp_id", idp.GetId()). + Str("idp_id", idpID). Str("session_idp_id", sessionState.IdentityProviderID). Str("id", sessionState.ID). Msg("authenticate: session not associated with identity provider") @@ -146,7 +142,7 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler { if err != nil { log.FromRequest(r).Info(). Err(err). - Str("idp_id", idp.GetId()). + Str("idp_id", idpID). Msg("authenticate: identity profile load error") return a.reauthenticateOrFail(w, r, err) } @@ -169,7 +165,6 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error { defer span.End() state := a.state.Load() - options := a.options.Load() if err := r.ParseForm(); err != nil { return httputil.NewError(http.StatusBadRequest, err) @@ -179,10 +174,7 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error { return err } - idp, err := options.GetIdentityProviderForID(requestParams.Get(urlutil.QueryIdentityProviderID)) - if err != nil { - return err - } + idpID := requestParams.Get(urlutil.QueryIdentityProviderID) s, err := a.getSessionFromCtx(ctx) if err != nil { @@ -191,8 +183,8 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error { } // start over if this is a different identity provider - if s == nil || s.IdentityProviderID != idp.GetId() { - s = sessions.NewState(idp.GetId()) + if s == nil || s.IdentityProviderID != idpID { + s = sessions.NewState(idpID) } // re-persist the session, useful when session was evicted from session @@ -240,12 +232,9 @@ func (a *Authenticate) signOutRedirect(w http.ResponseWriter, r *http.Request) e defer span.End() options := a.options.Load() - idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID)) - if err != nil { - return err - } + idpID := r.FormValue(urlutil.QueryIdentityProviderID) - authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId()) + authenticator, err := a.cfg.getIdentityProvider(options, idpID) if err != nil { return err } @@ -300,12 +289,9 @@ func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Reque state := a.state.Load() options := a.options.Load() - idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID)) - if err != nil { - return err - } + idpID := r.FormValue(urlutil.QueryIdentityProviderID) - authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId()) + authenticator, err := a.cfg.getIdentityProvider(options, idpID) if err != nil { return err } @@ -407,12 +393,9 @@ Or contact your administrator. `, redirectURL.String(), redirectURL.String())) } - idp, err := options.GetIdentityProviderForID(redirectURL.Query().Get(urlutil.QueryIdentityProviderID)) - if err != nil { - return nil, err - } + idpID := redirectURL.Query().Get(urlutil.QueryIdentityProviderID) - authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId()) + authenticator, err := a.cfg.getIdentityProvider(options, idpID) if err != nil { return nil, err } @@ -426,7 +409,7 @@ Or contact your administrator. return nil, fmt.Errorf("error redeeming authenticate code: %w", err) } - s := sessions.NewState(idp.GetId()) + s := sessions.NewState(idpID) err = claims.Claims.Claims(&s) if err != nil { return nil, fmt.Errorf("error unmarshaling session state: %w", err) @@ -522,12 +505,9 @@ func (a *Authenticate) revokeSession(ctx context.Context, w http.ResponseWriter, // clear the user's local session no matter what defer state.sessionStore.ClearSession(w, r) - idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID)) - if err != nil { - return "" - } + idpID := r.FormValue(urlutil.QueryIdentityProviderID) - authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId()) + authenticator, err := a.cfg.getIdentityProvider(options, idpID) if err != nil { return "" } diff --git a/authenticate/identity_profile.go b/authenticate/identity_profile.go index fcd9e171fc0..a806f6e67eb 100644 --- a/authenticate/identity_profile.go +++ b/authenticate/identity_profile.go @@ -31,12 +31,9 @@ func (a *Authenticate) buildIdentityProfile( oauthToken *oauth2.Token, ) (*identitypb.Profile, error) { options := a.options.Load() - idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID)) - if err != nil { - return nil, fmt.Errorf("authenticate: error getting identity provider for id: %w", err) - } + idpID := r.FormValue(urlutil.QueryIdentityProviderID) - authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId()) + authenticator, err := a.cfg.getIdentityProvider(options, idpID) if err != nil { return nil, fmt.Errorf("authenticate: error getting identity provider authenticator: %w", err) } @@ -57,7 +54,7 @@ func (a *Authenticate) buildIdentityProfile( } return &identitypb.Profile{ - ProviderId: idp.GetId(), + ProviderId: idpID, IdToken: rawIDToken, OauthToken: rawOAuthToken, Claims: rawClaims,