diff --git a/scalekit/client.py b/scalekit/client.py index b51c8ba..45e0175 100644 --- a/scalekit/client.py +++ b/scalekit/client.py @@ -23,6 +23,7 @@ GrantType, IdpInitiatedLoginClaims, LogoutUrlOptions, + TokenValidationOptions, ) from scalekit.constants.user import id_token_claim_to_user_map from scalekit.common.exceptions import (WebhookVerificationError, @@ -139,7 +140,7 @@ def authenticate_with_code( access_token = response["access_token"] refresh_token = response.get("refresh_token") # Validate id_token - claims = self.__validate_token(id_token, {"verify_aud": False}) + claims = self.__validate_token(id_token, options=None) user = {} amr_claims = claims.get('amr', []) connection_id = amr_claims[0] if amr_claims else None @@ -162,18 +163,20 @@ def authenticate_with_code( except Exception as exp: raise exp - def validate_access_token(self, token: str, options: Optional[Dict] = None, audience = None) -> bool: + def validate_access_token(self, token: str, options: Optional[TokenValidationOptions] = None, audience = None) -> bool: """ Method to validate access token :param token : access token :type : ``` str ``` + :param options : Optional validation options for issuer, audience, and scopes + :type : ``` TokenValidationOptions ``` + :param audience : audience for validation (deprecated, use options.audience instead) + :type : ``` str ``` :returns: bool """ - options = options if options else {} - options["verify_aud"] = False if not audience else True try: self.__validate_token(token, options=options, audience=audience) return True @@ -206,19 +209,20 @@ def generate_client_token(self, client_id: str, client_secret: str) -> str: except Exception as exp: raise exp - def validate_access_token_and_get_claims(self, token: str, options: Optional[Dict] = None, audience = None) -> Dict[str, Any]: + def validate_access_token_and_get_claims(self, token: str, options: Optional[TokenValidationOptions] = None, audience = None) -> Dict[str, Any]: """ Method to validate access token and get claims :param token : access token :type : ``` str ``` + :param options : Optional validation options for issuer, audience, and scopes + :type : ``` TokenValidationOptions ``` + :param audience : audience for validation (deprecated, use options.audience instead) + :type : ``` str ``` :returns: claims """ - - options = options if options else {} - options["verify_aud"] = False if not audience else True try: claims = self.__validate_token(token, options=options, audience=audience) @@ -228,18 +232,22 @@ def validate_access_token_and_get_claims(self, token: str, options: Optional[Dic except Exception as exp: raise exp - def get_idp_initiated_login_claims(self, idp_initiated_login_token: str) -> IdpInitiatedLoginClaims: + def get_idp_initiated_login_claims(self, idp_initiated_login_token: str, options: Optional[TokenValidationOptions] = None, audience = None) -> IdpInitiatedLoginClaims: """ Method to get IDP initiated login claims :param idp_initiated_login_token : IDP initiated login token :type : ``` str ``` + :param options : Optional validation options for issuer and audience + :type : ``` TokenValidationOptions ``` + :param audience : audience for validation (deprecated, use options.audience instead) + :type : ``` str ``` :returns: ``` IdpInitiatedLoginClaims ``` """ try: - claims = self.__validate_token(idp_initiated_login_token, {"verify_aud": False}) + claims = self.__validate_token(idp_initiated_login_token, options=options, audience=audience) return claims except jwt.exceptions.InvalidTokenError as exp: raise ScalekitValidateTokenFailureException(exp) @@ -247,22 +255,85 @@ def get_idp_initiated_login_claims(self, idp_initiated_login_token: str) -> IdpI raise exp def __validate_token( - self, token: str, options: Optional[Dict] = None, audience: Optional[str] = None + self, token: str, options: Optional[TokenValidationOptions] = None, audience: Optional[str] = None ) -> Dict[str, Any]: """ Method to validate token :param token : token :type : ``` str ``` + :param options : validation options for issuer, audience, and scopes + :type : ``` TokenValidationOptions ``` + :param audience : audience for validation + :type : ``` str ``` :returns: payload """ + # Convert TokenValidationOptions to jwt decode options if provided + jwt_options = {} + if options: + if options.issuer: + jwt_options["issuer"] = options.issuer + if options.audience: + jwt_options["audience"] = options.audience + jwt_options["verify_aud"] = True + elif audience is not None: + jwt_options["audience"] = [audience] + jwt_options["verify_aud"] = True + else: + jwt_options["verify_aud"] = False + self.core_client.get_jwks() kid = jwt.get_unverified_header(token)["kid"] key = self.core_client.keys[kid] - return jwt.decode(token, key=key, algorithms="RS256", options=options, audience=audience) + payload = jwt.decode(token, key=key, algorithms="RS256", options=jwt_options) + + # Validate scopes if required + if options and options.required_scopes: + self.verify_scopes(token, options.required_scopes) + + return payload + + + + def verify_scopes(self, token: str, required_scopes: list[str]) -> bool: + """ + Verify that the token contains the required scopes + + :param token : The token to verify + :type : ``` str ``` + :param required_scopes : The scopes that must be present in the token + :type : ``` list[str] ``` + + :returns: + bool: Returns True if all required scopes are present + :raises: + Error: If required scopes are missing, with details about which scopes are missing + """ + payload = jwt.decode(token, options={"verify_signature": False}) + scopes = self.__extract_scopes_from_payload(payload) + + missing_scopes = [scope for scope in required_scopes if scope not in scopes] + + if missing_scopes: + raise ScalekitValidateTokenFailureException(f"Token missing required scopes: {', '.join(missing_scopes)}") + + return True + + def __extract_scopes_from_payload(self, payload: Dict[str, Any]) -> list[str]: + """ + Extract scopes from token payload + + :param payload : The token payload + :type : ``` Dict[str, Any] ``` + + :returns: + list[str]: Array of scopes found in the token + """ + scopes = payload.get("scopes", []) + return [scope for scope in scopes if scope and scope.strip()] def verify_webhook_payload(self, secret: str, headers: Dict[str, str], payload: [str | bytes]) -> bool: """ diff --git a/scalekit/common/scalekit.py b/scalekit/common/scalekit.py index ae3d1bd..a1abfc2 100644 --- a/scalekit/common/scalekit.py +++ b/scalekit/common/scalekit.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Optional +from typing import Optional, List class GrantType(Enum): @@ -72,3 +72,16 @@ def __init__( self.id_token_hint = id_token_hint self.post_logout_redirect_uri = post_logout_redirect_uri self.state = state + + +class TokenValidationOptions: + """Options for token validation including issuer, audience, and scope validation""" + def __init__( + self, + issuer: Optional[str] = None, + audience: Optional[List[str]] = None, + required_scopes: Optional[List[str]] = None + ): + self.issuer = issuer + self.audience = audience + self.required_scopes = required_scopes diff --git a/scalekit/core.py b/scalekit/core.py index 31ee633..b51b09a 100644 --- a/scalekit/core.py +++ b/scalekit/core.py @@ -26,7 +26,7 @@ def __call__(self, request: TRequest, metadata: TMetadata) -> TResponse: ... class CoreClient: """Class definition for Core Client""" - sdk_version = "Scalekit-Python/2.2.2" + sdk_version = "Scalekit-Python/2.3.0" api_version = "20250718" user_agent = f"{sdk_version} Python/{platform.python_version()} ({platform.system()}; {platform.architecture()}" diff --git a/setup.py b/setup.py index eaaad31..30606fc 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name="scalekit-sdk-python", - version="2.2.2", + version="2.3.0", packages=find_packages(), install_requires=[ "grpcio>=1.64.1",