Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

authenticate: always trust the passed in idp #3917

Merged
merged 1 commit into from Jan 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
54 changes: 17 additions & 37 deletions authenticate/handlers.go
Expand Up @@ -118,24 +118,20 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler {
defer span.End()

state := a.state.Load()
options := a.options.Load()
idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID))
if err != nil {
return err
}
idpID := r.FormValue(urlutil.QueryIdentityProviderID)

sessionState, err := a.getSessionFromCtx(ctx)
if err != nil {
log.FromRequest(r).Info().
Err(err).
Str("idp_id", idp.GetId()).
Str("idp_id", idpID).
Msg("authenticate: session load error")
return a.reauthenticateOrFail(w, r, err)
}

if sessionState.IdentityProviderID != idp.GetId() {
if sessionState.IdentityProviderID != idpID {
log.FromRequest(r).Info().
Str("idp_id", idp.GetId()).
Str("idp_id", idpID).
Str("session_idp_id", sessionState.IdentityProviderID).
Str("id", sessionState.ID).
Msg("authenticate: session not associated with identity provider")
Expand All @@ -146,7 +142,7 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler {
if err != nil {
log.FromRequest(r).Info().
Err(err).
Str("idp_id", idp.GetId()).
Str("idp_id", idpID).
Msg("authenticate: identity profile load error")
return a.reauthenticateOrFail(w, r, err)
}
Expand All @@ -169,7 +165,6 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
defer span.End()

state := a.state.Load()
options := a.options.Load()

if err := r.ParseForm(); err != nil {
return httputil.NewError(http.StatusBadRequest, err)
Expand All @@ -179,10 +174,7 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
return err
}

idp, err := options.GetIdentityProviderForID(requestParams.Get(urlutil.QueryIdentityProviderID))
if err != nil {
return err
}
idpID := requestParams.Get(urlutil.QueryIdentityProviderID)

s, err := a.getSessionFromCtx(ctx)
if err != nil {
Expand All @@ -191,8 +183,8 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
}

// start over if this is a different identity provider
if s == nil || s.IdentityProviderID != idp.GetId() {
s = sessions.NewState(idp.GetId())
if s == nil || s.IdentityProviderID != idpID {
s = sessions.NewState(idpID)
}

// re-persist the session, useful when session was evicted from session
Expand Down Expand Up @@ -240,12 +232,9 @@ func (a *Authenticate) signOutRedirect(w http.ResponseWriter, r *http.Request) e
defer span.End()

options := a.options.Load()
idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID))
if err != nil {
return err
}
idpID := r.FormValue(urlutil.QueryIdentityProviderID)

authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId())
authenticator, err := a.cfg.getIdentityProvider(options, idpID)
if err != nil {
return err
}
Expand Down Expand Up @@ -300,12 +289,9 @@ func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Reque

state := a.state.Load()
options := a.options.Load()
idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID))
if err != nil {
return err
}
idpID := r.FormValue(urlutil.QueryIdentityProviderID)

authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId())
authenticator, err := a.cfg.getIdentityProvider(options, idpID)
if err != nil {
return err
}
Expand Down Expand Up @@ -407,12 +393,9 @@ Or contact your administrator.
`, redirectURL.String(), redirectURL.String()))
}

idp, err := options.GetIdentityProviderForID(redirectURL.Query().Get(urlutil.QueryIdentityProviderID))
if err != nil {
return nil, err
}
idpID := redirectURL.Query().Get(urlutil.QueryIdentityProviderID)

authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId())
authenticator, err := a.cfg.getIdentityProvider(options, idpID)
if err != nil {
return nil, err
}
Expand All @@ -426,7 +409,7 @@ Or contact your administrator.
return nil, fmt.Errorf("error redeeming authenticate code: %w", err)
}

s := sessions.NewState(idp.GetId())
s := sessions.NewState(idpID)
err = claims.Claims.Claims(&s)
if err != nil {
return nil, fmt.Errorf("error unmarshaling session state: %w", err)
Expand Down Expand Up @@ -522,12 +505,9 @@ func (a *Authenticate) revokeSession(ctx context.Context, w http.ResponseWriter,
// clear the user's local session no matter what
defer state.sessionStore.ClearSession(w, r)

idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID))
if err != nil {
return ""
}
idpID := r.FormValue(urlutil.QueryIdentityProviderID)

authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId())
authenticator, err := a.cfg.getIdentityProvider(options, idpID)
if err != nil {
return ""
}
Expand Down
9 changes: 3 additions & 6 deletions authenticate/identity_profile.go
Expand Up @@ -31,12 +31,9 @@ func (a *Authenticate) buildIdentityProfile(
oauthToken *oauth2.Token,
) (*identitypb.Profile, error) {
options := a.options.Load()
idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID))
if err != nil {
return nil, fmt.Errorf("authenticate: error getting identity provider for id: %w", err)
}
idpID := r.FormValue(urlutil.QueryIdentityProviderID)

authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId())
authenticator, err := a.cfg.getIdentityProvider(options, idpID)
if err != nil {
return nil, fmt.Errorf("authenticate: error getting identity provider authenticator: %w", err)
}
Expand All @@ -57,7 +54,7 @@ func (a *Authenticate) buildIdentityProfile(
}

return &identitypb.Profile{
ProviderId: idp.GetId(),
ProviderId: idpID,
IdToken: rawIDToken,
OauthToken: rawOAuthToken,
Claims: rawClaims,
Expand Down