Skip to content

Commit

Permalink
authenticate: always trust the passed in idp
Browse files Browse the repository at this point in the history
  • Loading branch information
calebdoxsey committed Jan 27, 2023
1 parent 447e38f commit d4571fb
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 43 deletions.
54 changes: 17 additions & 37 deletions authenticate/handlers.go
Expand Up @@ -118,24 +118,20 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler {
defer span.End()

state := a.state.Load()
options := a.options.Load()
idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID))
if err != nil {
return err
}
idpID := r.FormValue(urlutil.QueryIdentityProviderID)

sessionState, err := a.getSessionFromCtx(ctx)
if err != nil {
log.FromRequest(r).Info().
Err(err).
Str("idp_id", idp.GetId()).
Str("idp_id", idpID).
Msg("authenticate: session load error")
return a.reauthenticateOrFail(w, r, err)
}

if sessionState.IdentityProviderID != idp.GetId() {
if sessionState.IdentityProviderID != idpID {
log.FromRequest(r).Info().
Str("idp_id", idp.GetId()).
Str("idp_id", idpID).
Str("session_idp_id", sessionState.IdentityProviderID).
Str("id", sessionState.ID).
Msg("authenticate: session not associated with identity provider")
Expand All @@ -146,7 +142,7 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler {
if err != nil {
log.FromRequest(r).Info().
Err(err).
Str("idp_id", idp.GetId()).
Str("idp_id", idpID).
Msg("authenticate: identity profile load error")
return a.reauthenticateOrFail(w, r, err)
}
Expand All @@ -169,7 +165,6 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
defer span.End()

state := a.state.Load()
options := a.options.Load()

if err := r.ParseForm(); err != nil {
return httputil.NewError(http.StatusBadRequest, err)
Expand All @@ -179,10 +174,7 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
return err
}

idp, err := options.GetIdentityProviderForID(requestParams.Get(urlutil.QueryIdentityProviderID))
if err != nil {
return err
}
idpID := requestParams.Get(urlutil.QueryIdentityProviderID)

s, err := a.getSessionFromCtx(ctx)
if err != nil {
Expand All @@ -191,8 +183,8 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
}

// start over if this is a different identity provider
if s == nil || s.IdentityProviderID != idp.GetId() {
s = sessions.NewState(idp.GetId())
if s == nil || s.IdentityProviderID != idpID {
s = sessions.NewState(idpID)
}

// re-persist the session, useful when session was evicted from session
Expand Down Expand Up @@ -240,12 +232,9 @@ func (a *Authenticate) signOutRedirect(w http.ResponseWriter, r *http.Request) e
defer span.End()

options := a.options.Load()
idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID))
if err != nil {
return err
}
idpID := r.FormValue(urlutil.QueryIdentityProviderID)

authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId())
authenticator, err := a.cfg.getIdentityProvider(options, idpID)
if err != nil {
return err
}
Expand Down Expand Up @@ -300,12 +289,9 @@ func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Reque

state := a.state.Load()
options := a.options.Load()
idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID))
if err != nil {
return err
}
idpID := r.FormValue(urlutil.QueryIdentityProviderID)

authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId())
authenticator, err := a.cfg.getIdentityProvider(options, idpID)
if err != nil {
return err
}
Expand Down Expand Up @@ -407,12 +393,9 @@ Or contact your administrator.
`, redirectURL.String(), redirectURL.String()))
}

idp, err := options.GetIdentityProviderForID(redirectURL.Query().Get(urlutil.QueryIdentityProviderID))
if err != nil {
return nil, err
}
idpID := redirectURL.Query().Get(urlutil.QueryIdentityProviderID)

authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId())
authenticator, err := a.cfg.getIdentityProvider(options, idpID)
if err != nil {
return nil, err
}
Expand All @@ -426,7 +409,7 @@ Or contact your administrator.
return nil, fmt.Errorf("error redeeming authenticate code: %w", err)
}

s := sessions.NewState(idp.GetId())
s := sessions.NewState(idpID)
err = claims.Claims.Claims(&s)
if err != nil {
return nil, fmt.Errorf("error unmarshaling session state: %w", err)
Expand Down Expand Up @@ -522,12 +505,9 @@ func (a *Authenticate) revokeSession(ctx context.Context, w http.ResponseWriter,
// clear the user's local session no matter what
defer state.sessionStore.ClearSession(w, r)

idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID))
if err != nil {
return ""
}
idpID := r.FormValue(urlutil.QueryIdentityProviderID)

authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId())
authenticator, err := a.cfg.getIdentityProvider(options, idpID)
if err != nil {
return ""
}
Expand Down
9 changes: 3 additions & 6 deletions authenticate/identity_profile.go
Expand Up @@ -31,12 +31,9 @@ func (a *Authenticate) buildIdentityProfile(
oauthToken *oauth2.Token,
) (*identitypb.Profile, error) {
options := a.options.Load()
idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID))
if err != nil {
return nil, fmt.Errorf("authenticate: error getting identity provider for id: %w", err)
}
idpID := r.FormValue(urlutil.QueryIdentityProviderID)

authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId())
authenticator, err := a.cfg.getIdentityProvider(options, idpID)
if err != nil {
return nil, fmt.Errorf("authenticate: error getting identity provider authenticator: %w", err)
}
Expand All @@ -57,7 +54,7 @@ func (a *Authenticate) buildIdentityProfile(
}

return &identitypb.Profile{
ProviderId: idp.GetId(),
ProviderId: idpID,
IdToken: rawIDToken,
OauthToken: rawOAuthToken,
Claims: rawClaims,
Expand Down

0 comments on commit d4571fb

Please sign in to comment.