Skip to content

Commit

Permalink
feat: add “provider id” parameter to kratos session (#3292)
Browse files Browse the repository at this point in the history
Closes #3283
  • Loading branch information
aeneasr committed May 25, 2023
1 parent 34ff1d2 commit 387f5a2
Show file tree
Hide file tree
Showing 7 changed files with 21 additions and 6 deletions.
2 changes: 1 addition & 1 deletion selfservice/flow/registration/handler.go
Expand Up @@ -599,7 +599,7 @@ func (h *Handler) updateRegistrationFlow(w http.ResponseWriter, r *http.Request,
return
}

if err := h.d.RegistrationExecutor().PostRegistrationHook(w, r, s.ID(), f, i); err != nil {
if err := h.d.RegistrationExecutor().PostRegistrationHook(w, r, s.ID(), "", f, i); err != nil {
h.d.RegistrationFlowErrorHandler().WriteFlowError(w, r, f, s.NodeGroup(), err)
return
}
Expand Down
9 changes: 7 additions & 2 deletions selfservice/flow/registration/hook.go
Expand Up @@ -96,7 +96,7 @@ func NewHookExecutor(d executorDependencies) *HookExecutor {
return &HookExecutor{d: d}
}

func (e *HookExecutor) PostRegistrationHook(w http.ResponseWriter, r *http.Request, ct identity.CredentialsType, a *Flow, i *identity.Identity) error {
func (e *HookExecutor) PostRegistrationHook(w http.ResponseWriter, r *http.Request, ct identity.CredentialsType, provider string, a *Flow, i *identity.Identity) error {
e.d.Logger().
WithRequest(r).
WithField("identity_id", i.ID).
Expand Down Expand Up @@ -174,7 +174,12 @@ func (e *HookExecutor) PostRegistrationHook(w http.ResponseWriter, r *http.Reque

trace.SpanFromContext(r.Context()).AddEvent(events.NewIdentityCreated(r.Context(), i.ID))

s, err := session.NewActiveSession(r, i, e.d.Config(), time.Now().UTC(), ct, identity.AuthenticatorAssuranceLevel1)
s := session.NewInactiveSession()
s.CompletedLoginForWithProvider(ct, identity.AuthenticatorAssuranceLevel1, provider)
if err := s.Activate(r, i, c, time.Now().UTC()); err != nil {
return err
}

if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion selfservice/flow/registration/hook_test.go
Expand Up @@ -57,7 +57,7 @@ func TestRegistrationExecutor(t *testing.T) {
a, err := registration.NewFlow(conf, time.Minute, x.FakeCSRFToken, r, ft)
require.NoError(t, err)
a.RequestURL = x.RequestURL(r).String()
_ = handleErr(t, w, r, reg.RegistrationHookExecutor().PostRegistrationHook(w, r, identity.CredentialsType(strategy), a, i))
_ = handleErr(t, w, r, reg.RegistrationHookExecutor().PostRegistrationHook(w, r, identity.CredentialsType(strategy), "", a, i))
})

ts := httptest.NewServer(router)
Expand Down
2 changes: 1 addition & 1 deletion selfservice/strategy/oidc/strategy_login.go
Expand Up @@ -136,7 +136,7 @@ func (s *Strategy) processLogin(w http.ResponseWriter, r *http.Request, a *login
}

sess := session.NewInactiveSession()
sess.CompletedLoginFor(s.ID(), identity.AuthenticatorAssuranceLevel1)
sess.CompletedLoginForWithProvider(s.ID(), identity.AuthenticatorAssuranceLevel1, provider.Config().ID)
for _, c := range o.Providers {
if c.Subject == claims.Subject && c.Provider == provider.Config().ID {
if err = s.d.LoginHookExecutor().PostLoginHook(w, r, node.OpenIDConnectGroup, a, i, sess); err != nil {
Expand Down
2 changes: 1 addition & 1 deletion selfservice/strategy/oidc/strategy_registration.go
Expand Up @@ -282,7 +282,7 @@ func (s *Strategy) processRegistration(w http.ResponseWriter, r *http.Request, r
}

i.SetCredentials(s.ID(), *creds)
if err := s.d.RegistrationExecutor().PostRegistrationHook(w, r, identity.CredentialsTypeOIDC, rf, i); err != nil {
if err := s.d.RegistrationExecutor().PostRegistrationHook(w, r, identity.CredentialsTypeOIDC, provider.Config().ID, rf, i); err != nil {
return nil, s.handleError(w, r, rf, provider.Config().ID, i.Traits, err)
}

Expand Down
3 changes: 3 additions & 0 deletions selfservice/strategy/oidc/strategy_test.go
Expand Up @@ -435,6 +435,7 @@ func TestStrategy(t *testing.T) {
res, body := makeRequest(t, "valid", action, url.Values{})
assertIdentity(t, res, body)
expectTokens(t, "valid", body)
assert.Equal(t, "valid", gjson.GetBytes(body, "authentication_methods.0.provider").String(), "%s", body)
})

t.Run("case=should pass login", func(t *testing.T) {
Expand All @@ -443,6 +444,7 @@ func TestStrategy(t *testing.T) {
res, body := makeRequest(t, "valid", action, url.Values{})
assertIdentity(t, res, body)
expectTokens(t, "valid", body)
assert.Equal(t, "valid", gjson.GetBytes(body, "authentication_methods.0.provider").String(), "%s", body)
})
})

Expand All @@ -455,6 +457,7 @@ func TestStrategy(t *testing.T) {
action := assertFormValues(t, r.ID, "valid")
res, body := makeRequest(t, "valid", action, url.Values{})
assertIdentity(t, res, body)
assert.Equal(t, "valid", gjson.GetBytes(body, "authentication_methods.0.provider").String(), "%s", body)
})
})

Expand Down
7 changes: 7 additions & 0 deletions session/session.go
Expand Up @@ -158,6 +158,10 @@ func (s *Session) CompletedLoginFor(method identity.CredentialsType, aal identit
s.AMR = append(s.AMR, AuthenticationMethod{Method: method, AAL: aal, CompletedAt: time.Now().UTC()})
}

func (s *Session) CompletedLoginForWithProvider(method identity.CredentialsType, aal identity.AuthenticatorAssuranceLevel, providerID string) {
s.AMR = append(s.AMR, AuthenticationMethod{Method: method, AAL: aal, Provider: providerID, CompletedAt: time.Now().UTC()})
}

func (s *Session) AuthenticatedVia(method identity.CredentialsType) bool {
for _, authMethod := range s.AMR {
if authMethod.Method == method {
Expand Down Expand Up @@ -318,6 +322,9 @@ type AuthenticationMethod struct {

// When the authentication challenge was completed.
CompletedAt time.Time `json:"completed_at"`

// OIDC or SAML provider id used for authentication
Provider string `json:"provider,omitempty"`
}

// Scan implements the Scanner interface.
Expand Down

0 comments on commit 387f5a2

Please sign in to comment.