diff --git a/selfservice/flow/registration/handler.go b/selfservice/flow/registration/handler.go index 281821bc36d..8a76d02521e 100644 --- a/selfservice/flow/registration/handler.go +++ b/selfservice/flow/registration/handler.go @@ -599,7 +599,7 @@ func (h *Handler) updateRegistrationFlow(w http.ResponseWriter, r *http.Request, return } - if err := h.d.RegistrationExecutor().PostRegistrationHook(w, r, s.ID(), f, i); err != nil { + if err := h.d.RegistrationExecutor().PostRegistrationHook(w, r, s.ID(), "", f, i); err != nil { h.d.RegistrationFlowErrorHandler().WriteFlowError(w, r, f, s.NodeGroup(), err) return } diff --git a/selfservice/flow/registration/hook.go b/selfservice/flow/registration/hook.go index 8a41b18ae08..344c3507545 100644 --- a/selfservice/flow/registration/hook.go +++ b/selfservice/flow/registration/hook.go @@ -96,7 +96,7 @@ func NewHookExecutor(d executorDependencies) *HookExecutor { return &HookExecutor{d: d} } -func (e *HookExecutor) PostRegistrationHook(w http.ResponseWriter, r *http.Request, ct identity.CredentialsType, a *Flow, i *identity.Identity) error { +func (e *HookExecutor) PostRegistrationHook(w http.ResponseWriter, r *http.Request, ct identity.CredentialsType, provider string, a *Flow, i *identity.Identity) error { e.d.Logger(). WithRequest(r). WithField("identity_id", i.ID). @@ -174,7 +174,12 @@ func (e *HookExecutor) PostRegistrationHook(w http.ResponseWriter, r *http.Reque trace.SpanFromContext(r.Context()).AddEvent(events.NewIdentityCreated(r.Context(), i.ID)) - s, err := session.NewActiveSession(r, i, e.d.Config(), time.Now().UTC(), ct, identity.AuthenticatorAssuranceLevel1) + s := session.NewInactiveSession() + s.CompletedLoginForWithProvider(ct, identity.AuthenticatorAssuranceLevel1, provider) + if err := s.Activate(r, i, c, time.Now().UTC()); err != nil { + return err + } + if err != nil { return err } diff --git a/selfservice/flow/registration/hook_test.go b/selfservice/flow/registration/hook_test.go index bb1f5f1537a..49719a7f10c 100644 --- a/selfservice/flow/registration/hook_test.go +++ b/selfservice/flow/registration/hook_test.go @@ -57,7 +57,7 @@ func TestRegistrationExecutor(t *testing.T) { a, err := registration.NewFlow(conf, time.Minute, x.FakeCSRFToken, r, ft) require.NoError(t, err) a.RequestURL = x.RequestURL(r).String() - _ = handleErr(t, w, r, reg.RegistrationHookExecutor().PostRegistrationHook(w, r, identity.CredentialsType(strategy), a, i)) + _ = handleErr(t, w, r, reg.RegistrationHookExecutor().PostRegistrationHook(w, r, identity.CredentialsType(strategy), "", a, i)) }) ts := httptest.NewServer(router) diff --git a/selfservice/strategy/oidc/strategy_login.go b/selfservice/strategy/oidc/strategy_login.go index 2c014f2ef75..1aa79376104 100644 --- a/selfservice/strategy/oidc/strategy_login.go +++ b/selfservice/strategy/oidc/strategy_login.go @@ -136,7 +136,7 @@ func (s *Strategy) processLogin(w http.ResponseWriter, r *http.Request, a *login } sess := session.NewInactiveSession() - sess.CompletedLoginFor(s.ID(), identity.AuthenticatorAssuranceLevel1) + sess.CompletedLoginForWithProvider(s.ID(), identity.AuthenticatorAssuranceLevel1, provider.Config().ID) for _, c := range o.Providers { if c.Subject == claims.Subject && c.Provider == provider.Config().ID { if err = s.d.LoginHookExecutor().PostLoginHook(w, r, node.OpenIDConnectGroup, a, i, sess); err != nil { diff --git a/selfservice/strategy/oidc/strategy_registration.go b/selfservice/strategy/oidc/strategy_registration.go index ad8d8c49591..1e303c50df3 100644 --- a/selfservice/strategy/oidc/strategy_registration.go +++ b/selfservice/strategy/oidc/strategy_registration.go @@ -282,7 +282,7 @@ func (s *Strategy) processRegistration(w http.ResponseWriter, r *http.Request, r } i.SetCredentials(s.ID(), *creds) - if err := s.d.RegistrationExecutor().PostRegistrationHook(w, r, identity.CredentialsTypeOIDC, rf, i); err != nil { + if err := s.d.RegistrationExecutor().PostRegistrationHook(w, r, identity.CredentialsTypeOIDC, provider.Config().ID, rf, i); err != nil { return nil, s.handleError(w, r, rf, provider.Config().ID, i.Traits, err) } diff --git a/selfservice/strategy/oidc/strategy_test.go b/selfservice/strategy/oidc/strategy_test.go index 5fb693a0f82..7205ee20d77 100644 --- a/selfservice/strategy/oidc/strategy_test.go +++ b/selfservice/strategy/oidc/strategy_test.go @@ -435,6 +435,7 @@ func TestStrategy(t *testing.T) { res, body := makeRequest(t, "valid", action, url.Values{}) assertIdentity(t, res, body) expectTokens(t, "valid", body) + assert.Equal(t, "valid", gjson.GetBytes(body, "authentication_methods.0.provider").String(), "%s", body) }) t.Run("case=should pass login", func(t *testing.T) { @@ -443,6 +444,7 @@ func TestStrategy(t *testing.T) { res, body := makeRequest(t, "valid", action, url.Values{}) assertIdentity(t, res, body) expectTokens(t, "valid", body) + assert.Equal(t, "valid", gjson.GetBytes(body, "authentication_methods.0.provider").String(), "%s", body) }) }) @@ -455,6 +457,7 @@ func TestStrategy(t *testing.T) { action := assertFormValues(t, r.ID, "valid") res, body := makeRequest(t, "valid", action, url.Values{}) assertIdentity(t, res, body) + assert.Equal(t, "valid", gjson.GetBytes(body, "authentication_methods.0.provider").String(), "%s", body) }) }) diff --git a/session/session.go b/session/session.go index 294f1bec6db..c409323c296 100644 --- a/session/session.go +++ b/session/session.go @@ -158,6 +158,10 @@ func (s *Session) CompletedLoginFor(method identity.CredentialsType, aal identit s.AMR = append(s.AMR, AuthenticationMethod{Method: method, AAL: aal, CompletedAt: time.Now().UTC()}) } +func (s *Session) CompletedLoginForWithProvider(method identity.CredentialsType, aal identity.AuthenticatorAssuranceLevel, providerID string) { + s.AMR = append(s.AMR, AuthenticationMethod{Method: method, AAL: aal, Provider: providerID, CompletedAt: time.Now().UTC()}) +} + func (s *Session) AuthenticatedVia(method identity.CredentialsType) bool { for _, authMethod := range s.AMR { if authMethod.Method == method { @@ -318,6 +322,9 @@ type AuthenticationMethod struct { // When the authentication challenge was completed. CompletedAt time.Time `json:"completed_at"` + + // OIDC or SAML provider id used for authentication + Provider string `json:"provider,omitempty"` } // Scan implements the Scanner interface.