Skip to content

Commit

Permalink
end session based on v2 session ID
Browse files Browse the repository at this point in the history
  • Loading branch information
muhlemmer committed May 8, 2024
1 parent d1c6b50 commit 3b56d60
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 16 deletions.
15 changes: 8 additions & 7 deletions internal/api/oidc/auth_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -363,18 +363,19 @@ func (o *OPStorage) TerminateSessionFromRequest(ctx context.Context, endSessionR
}()

// check for the login client header
// and if not provided, terminate the session using the V1 method
headers, _ := http_utils.HeadersFromCtx(ctx)
if loginClient := headers.Get(LoginClientHeader); loginClient == "" {
return endSessionRequest.RedirectURI, o.TerminateSession(ctx, endSessionRequest.UserID, endSessionRequest.ClientID)
// in case there is no id_token_hint, redirect to the UI and let it decide which session to terminate
if headers.Get(LoginClientHeader) != "" && endSessionRequest.IDTokenHintClaims == nil {
return o.defaultLogoutURLV2 + endSessionRequest.RedirectURI, nil
}

// in case there are not id_token_hint, redirect to the UI and let it decide which session to terminate
if endSessionRequest.IDTokenHintClaims == nil {
return o.defaultLogoutURLV2 + endSessionRequest.RedirectURI, nil
// If there is no login client header and no id_token_hint or the id_token_hint does not have a session ID,
// do a v1 Terminate session.
if endSessionRequest.IDTokenHintClaims == nil || endSessionRequest.IDTokenHintClaims.SessionID == "" {
return endSessionRequest.RedirectURI, o.TerminateSession(ctx, endSessionRequest.UserID, endSessionRequest.ClientID)
}

// terminate the session of the id_token_hint
// terminate the v2 session of the id_token_hint
_, err = o.command.TerminateSessionWithoutTokenCheck(ctx, endSessionRequest.IDTokenHintClaims.SessionID)
if err != nil {
return "", err
Expand Down
5 changes: 3 additions & 2 deletions internal/api/oidc/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func (s *Server) accessTokenResponseFromSession(ctx context.Context, client op.C
}

if slices.Contains(session.Scope, oidc.ScopeOpenID) {
resp.IDToken, _, err = s.createIDToken(ctx, client, getUserInfo, getSigner, resp.AccessToken, session.Audience, session.AuthMethods, session.AuthTime, session.Nonce, session.Actor)
resp.IDToken, _, err = s.createIDToken(ctx, client, getUserInfo, getSigner, session.SessionID, resp.AccessToken, session.Audience, session.AuthMethods, session.AuthTime, session.Nonce, session.Actor)
}
return resp, err
}
Expand Down Expand Up @@ -112,7 +112,7 @@ func (s *Server) getUserInfoOnce(userID, projectID string, projectRoleAssertion
}
}

func (*Server) createIDToken(ctx context.Context, client op.Client, getUserInfo userInfoFunc, getSigningKey signerFunc, accessToken string, audience []string, authMethods []domain.UserAuthMethodType, authTime time.Time, nonce string, actor *domain.TokenActor) (idToken string, exp uint64, err error) {
func (*Server) createIDToken(ctx context.Context, client op.Client, getUserInfo userInfoFunc, getSigningKey signerFunc, sessionID, accessToken string, audience []string, authMethods []domain.UserAuthMethodType, authTime time.Time, nonce string, actor *domain.TokenActor) (idToken string, exp uint64, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()

Expand All @@ -139,6 +139,7 @@ func (*Server) createIDToken(ctx context.Context, client op.Client, getUserInfo
client.GetID(),
client.ClockSkew(),
)
claims.SessionID = sessionID
claims.Actor = actorDomainToClaims(actor)
claims.SetUserInfo(userInfo)
if accessToken != "" {
Expand Down
15 changes: 8 additions & 7 deletions internal/api/oidc/token_exchange.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,9 +223,10 @@ func (s *Server) createExchangeTokens(ctx context.Context, tokenType oidc.TokenT
actor = actorToken.nestedActor()
}

var sessionID string
switch tokenType {
case oidc.AccessTokenType, "":
resp.AccessToken, resp.RefreshToken, resp.ExpiresIn, err = s.createExchangeAccessToken(ctx, client, subjectToken.userID, subjectToken.resourceOwner, audience, scopes, actorToken.authMethods, actorToken.authTime, reason, actor)
resp.AccessToken, resp.RefreshToken, sessionID, resp.ExpiresIn, err = s.createExchangeAccessToken(ctx, client, subjectToken.userID, subjectToken.resourceOwner, audience, scopes, actorToken.authMethods, actorToken.authTime, reason, actor)
resp.TokenType = oidc.BearerToken
resp.IssuedTokenType = oidc.AccessTokenType

Expand All @@ -235,7 +236,7 @@ func (s *Server) createExchangeTokens(ctx context.Context, tokenType oidc.TokenT
resp.IssuedTokenType = oidc.JWTTokenType

case oidc.IDTokenType:
resp.AccessToken, resp.ExpiresIn, err = s.createIDToken(ctx, client, getUserInfo, getSigner, resp.AccessToken, audience, actorToken.authMethods, actorToken.authTime, "", actor)
resp.AccessToken, resp.ExpiresIn, err = s.createIDToken(ctx, client, getUserInfo, getSigner, "", resp.AccessToken, audience, actorToken.authMethods, actorToken.authTime, "", actor)
resp.TokenType = TokenTypeNA
resp.IssuedTokenType = oidc.IDTokenType

Expand All @@ -249,7 +250,7 @@ func (s *Server) createExchangeTokens(ctx context.Context, tokenType oidc.TokenT
}

if slices.Contains(scopes, oidc.ScopeOpenID) && tokenType != oidc.IDTokenType {
resp.IDToken, _, err = s.createIDToken(ctx, client, getUserInfo, getSigner, resp.AccessToken, audience, actorToken.authMethods, actorToken.authTime, "", actor)
resp.IDToken, _, err = s.createIDToken(ctx, client, getUserInfo, getSigner, sessionID, resp.AccessToken, audience, actorToken.authMethods, actorToken.authTime, "", actor)
if err != nil {
return nil, err
}
Expand All @@ -258,7 +259,7 @@ func (s *Server) createExchangeTokens(ctx context.Context, tokenType oidc.TokenT
return resp, nil
}

func (s *Server) createExchangeAccessToken(ctx context.Context, client *Client, userID, resourceOwner string, audience, scope []string, authMethods []domain.UserAuthMethodType, authTime time.Time, reason domain.TokenReason, actor *domain.TokenActor) (accessToken string, refreshToken string, exp uint64, err error) {
func (s *Server) createExchangeAccessToken(ctx context.Context, client *Client, userID, resourceOwner string, audience, scope []string, authMethods []domain.UserAuthMethodType, authTime time.Time, reason domain.TokenReason, actor *domain.TokenActor) (accessToken, refreshToken, sessionID string, exp uint64, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()

Expand All @@ -277,13 +278,13 @@ func (s *Server) createExchangeAccessToken(ctx context.Context, client *Client,
slices.Contains(scope, oidc.ScopeOfflineAccess),
)
if err != nil {
return "", "", 0, err
return "", "", "", 0, err
}
accessToken, err = op.CreateBearerToken(session.TokenID, userID, s.opCrypto)
if err != nil {
return "", "", 0, err
return "", "", "", 0, err
}
return accessToken, session.RefreshToken, timeToOIDCExpiresIn(session.Expiration), nil
return accessToken, session.RefreshToken, session.SessionID, timeToOIDCExpiresIn(session.Expiration), nil
}

func (s *Server) createExchangeJWT(ctx context.Context, client *Client, getUserInfo userInfoFunc, getSigner signerFunc, userID, resourceOwner string, audience, scope []string, authMethods []domain.UserAuthMethodType, authTime time.Time, reason domain.TokenReason, actor *domain.TokenActor) (accessToken string, refreshToken string, exp uint64, err error) {
Expand Down
2 changes: 2 additions & 0 deletions internal/command/oidc_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ const (
)

type OIDCSession struct {
SessionID string
TokenID string
ClientID string
UserID string
Expand Down Expand Up @@ -416,6 +417,7 @@ func (c *OIDCSessionEvents) PushEvents(ctx context.Context) (*OIDCSession, error
return nil, err
}
session := &OIDCSession{
SessionID: c.oidcSessionWriteModel.SessionID,
ClientID: c.oidcSessionWriteModel.ClientID,
UserID: c.oidcSessionWriteModel.UserID,
Audience: c.oidcSessionWriteModel.Audience,
Expand Down

0 comments on commit 3b56d60

Please sign in to comment.