From 00c047b11426a6a1837fe9e9aac17af3f31d1cfa Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Thu, 23 Feb 2023 08:30:50 -0700 Subject: [PATCH] authenticate: fix identity provider id in encrypted query string (#4006) Co-authored-by: Denis Mishin --- authenticate/handlers.go | 29 +++++++++++++++++++++++++---- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/authenticate/handlers.go b/authenticate/handlers.go index 63198f62601..a370279e4b3 100644 --- a/authenticate/handlers.go +++ b/authenticate/handlers.go @@ -133,7 +133,7 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler { defer span.End() state := a.state.Load() - idpID := r.FormValue(urlutil.QueryIdentityProviderID) + idpID := a.getIdentityProviderIDForRequest(r) sessionState, err := a.getSessionFromCtx(ctx) if err != nil { @@ -247,7 +247,7 @@ func (a *Authenticate) signOutRedirect(w http.ResponseWriter, r *http.Request) e defer span.End() options := a.options.Load() - idpID := r.FormValue(urlutil.QueryIdentityProviderID) + idpID := a.getIdentityProviderIDForRequest(r) authenticator, err := a.cfg.getIdentityProvider(options, idpID) if err != nil { @@ -304,7 +304,7 @@ func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Reque state := a.state.Load() options := a.options.Load() - idpID := r.FormValue(urlutil.QueryIdentityProviderID) + idpID := a.getIdentityProviderIDForRequest(r) authenticator, err := a.cfg.getIdentityProvider(options, idpID) if err != nil { @@ -408,7 +408,7 @@ Or contact your administrator. `, redirectURL.String(), redirectURL.String())) } - idpID := redirectURL.Query().Get(urlutil.QueryIdentityProviderID) + idpID := a.getIdentityProviderIDForURLValues(redirectURL.Query()) authenticator, err := a.cfg.getIdentityProvider(options, idpID) if err != nil { @@ -588,3 +588,24 @@ func (a *Authenticate) saveCallbackSession(w http.ResponseWriter, r *http.Reques } return rawJWT, nil } + +func (a *Authenticate) getIdentityProviderIDForRequest(r *http.Request) string { + if err := r.ParseForm(); err != nil { + return "" + } + return a.getIdentityProviderIDForURLValues(r.Form) +} + +func (a *Authenticate) getIdentityProviderIDForURLValues(vs url.Values) string { + state := a.state.Load() + idpID := "" + if _, requestParams, err := hpke.DecryptURLValues(state.hpkePrivateKey, vs); err == nil { + if idpID == "" { + idpID = requestParams.Get(urlutil.QueryIdentityProviderID) + } + } + if idpID == "" { + idpID = vs.Get(urlutil.QueryIdentityProviderID) + } + return idpID +}