diff --git a/tests/test_session.py b/tests/test_session.py index 3b7ddd78..6409dc7f 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -1,8 +1,11 @@ -import pytest +import concurrent.futures +from datetime import datetime, timezone from unittest.mock import AsyncMock, Mock, patch + import jwt -from datetime import datetime, timezone -import concurrent.futures +import pytest +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import rsa from tests.conftest import with_jwks_mock from workos.session import AsyncSession, Session, _get_jwks_client @@ -16,9 +19,6 @@ RefreshWithSessionCookieSuccessResponse, ) -from cryptography.hazmat.primitives import serialization -from cryptography.hazmat.primitives.asymmetric import rsa - class SessionFixtures: @pytest.fixture(autouse=True) @@ -48,6 +48,7 @@ def session_constants(self): "sid": "session_123", "org_id": "organization_123", "role": "admin", + "roles": ["admin"], "permissions": ["read"], "entitlements": ["feature_1"], "exp": int(current_datetime.timestamp()) + 3600, @@ -215,6 +216,7 @@ def test_authenticate_success(self, session_constants, mock_user_management): "sid": session_constants["SESSION_ID"], "org_id": session_constants["ORGANIZATION_ID"], "role": "admin", + "roles": ["admin"], "permissions": ["read"], "entitlements": ["feature_1"], "exp": int(datetime.now(timezone.utc).timestamp()) + 3600, @@ -239,6 +241,7 @@ def test_authenticate_success(self, session_constants, mock_user_management): "sid": session_constants["SESSION_ID"], "org_id": session_constants["ORGANIZATION_ID"], "role": "admin", + "roles": ["admin"], "permissions": ["read"], "entitlements": ["feature_1"], } @@ -257,11 +260,80 @@ def test_authenticate_success(self, session_constants, mock_user_management): assert response.session_id == session_constants["SESSION_ID"] assert response.organization_id == session_constants["ORGANIZATION_ID"] assert response.role == "admin" + assert response.roles == ["admin"] assert response.permissions == ["read"] assert response.entitlements == ["feature_1"] assert response.user.id == session_constants["USER_ID"] assert response.impersonator is None + @with_jwks_mock + def test_authenticate_success_with_roles( + self, session_constants, mock_user_management + ): + session = Session( + user_management=mock_user_management, + client_id=session_constants["CLIENT_ID"], + session_data=session_constants["SESSION_DATA"], + cookie_password=session_constants["COOKIE_PASSWORD"], + ) + + # Mock the session data that would be unsealed + mock_session = { + "access_token": jwt.encode( + { + "sid": session_constants["SESSION_ID"], + "org_id": session_constants["ORGANIZATION_ID"], + "role": "admin", + "roles": ["admin", "member"], + "permissions": ["read", "write"], + "entitlements": ["feature_1"], + "exp": int(datetime.now(timezone.utc).timestamp()) + 3600, + "iat": int(datetime.now(timezone.utc).timestamp()), + }, + session_constants["PRIVATE_KEY"], + algorithm="RS256", + ), + "user": { + "object": "user", + "id": session_constants["USER_ID"], + "email": "user@example.com", + "email_verified": True, + "created_at": session_constants["CURRENT_TIMESTAMP"], + "updated_at": session_constants["CURRENT_TIMESTAMP"], + }, + "impersonator": None, + } + + # Mock the JWT payload that would be decoded + mock_jwt_payload = { + "sid": session_constants["SESSION_ID"], + "org_id": session_constants["ORGANIZATION_ID"], + "role": "admin", + "roles": ["admin", "member"], + "permissions": ["read", "write"], + "entitlements": ["feature_1"], + } + + with patch.object(Session, "unseal_data", return_value=mock_session), patch( + "jwt.decode", return_value=mock_jwt_payload + ), patch.object( + session.jwks, + "get_signing_key_from_jwt", + return_value=Mock(key=session_constants["PUBLIC_KEY"]), + ): + response = session.authenticate() + + assert isinstance(response, AuthenticateWithSessionCookieSuccessResponse) + assert response.authenticated is True + assert response.session_id == session_constants["SESSION_ID"] + assert response.organization_id == session_constants["ORGANIZATION_ID"] + assert response.role == "admin" + assert response.roles == ["admin", "member"] + assert response.permissions == ["read", "write"] + assert response.entitlements == ["feature_1"] + assert response.user.id == session_constants["USER_ID"] + assert response.impersonator is None + @with_jwks_mock def test_refresh_invalid_session_cookie( self, session_constants, mock_user_management @@ -335,6 +407,7 @@ def test_refresh_success(self, session_constants, mock_user_management): "sid": session_constants["SESSION_ID"], "org_id": session_constants["ORGANIZATION_ID"], "role": "admin", + "roles": ["admin"], "permissions": ["read"], "entitlements": ["feature_1"], }, @@ -435,6 +508,7 @@ async def test_refresh_success(self, session_constants, mock_user_management): "sid": session_constants["SESSION_ID"], "org_id": session_constants["ORGANIZATION_ID"], "role": "admin", + "roles": ["admin"], "permissions": ["read"], "entitlements": ["feature_1"], }, diff --git a/workos/session.py b/workos/session.py index 62aaae36..2052534e 100644 --- a/workos/session.py +++ b/workos/session.py @@ -102,6 +102,7 @@ def authenticate( session_id=decoded["sid"], organization_id=decoded.get("org_id", None), role=decoded.get("role", None), + roles=decoded.get("roles", None), permissions=decoded.get("permissions", None), entitlements=decoded.get("entitlements", None), user=session["user"], @@ -229,6 +230,7 @@ def refresh( session_id=decoded["sid"], organization_id=decoded.get("org_id", None), role=decoded.get("role", None), + roles=decoded.get("roles", None), permissions=decoded.get("permissions", None), entitlements=decoded.get("entitlements", None), user=auth_response.user, @@ -319,6 +321,7 @@ async def refresh( session_id=decoded["sid"], organization_id=decoded.get("org_id", None), role=decoded.get("role", None), + roles=decoded.get("roles", None), permissions=decoded.get("permissions", None), entitlements=decoded.get("entitlements", None), user=auth_response.user, diff --git a/workos/types/user_management/organization_membership.py b/workos/types/user_management/organization_membership.py index 926f6428..0f944070 100644 --- a/workos/types/user_management/organization_membership.py +++ b/workos/types/user_management/organization_membership.py @@ -1,4 +1,4 @@ -from typing import Literal +from typing import Literal, Sequence, Optional from typing_extensions import TypedDict from workos.types.workos_model import WorkOSModel @@ -19,6 +19,7 @@ class OrganizationMembership(WorkOSModel): user_id: str organization_id: str role: OrganizationMembershipRole + roles: Optional[Sequence[OrganizationMembershipRole]] = None status: LiteralOrUntyped[OrganizationMembershipStatus] created_at: str updated_at: str diff --git a/workos/types/user_management/session.py b/workos/types/user_management/session.py index 76739f9d..1be8025f 100644 --- a/workos/types/user_management/session.py +++ b/workos/types/user_management/session.py @@ -1,6 +1,8 @@ -from typing import Optional, Sequence, TypedDict, Union from enum import Enum +from typing import Optional, Sequence, TypedDict, Union + from typing_extensions import Literal + from workos.types.user_management.impersonator import Impersonator from workos.types.user_management.user import User from workos.types.workos_model import WorkOSModel @@ -17,6 +19,7 @@ class AuthenticateWithSessionCookieSuccessResponse(WorkOSModel): session_id: str organization_id: Optional[str] = None role: Optional[str] = None + roles: Optional[Sequence[str]] = None permissions: Optional[Sequence[str]] = None user: User impersonator: Optional[Impersonator] = None diff --git a/workos/user_management.py b/workos/user_management.py index edfa2142..85b8eeb2 100644 --- a/workos/user_management.py +++ b/workos/user_management.py @@ -245,15 +245,24 @@ def delete_user(self, user_id: str) -> SyncOrAsync[None]: ... def create_organization_membership( - self, *, user_id: str, organization_id: str, role_slug: Optional[str] = None + self, + *, + user_id: str, + organization_id: str, + role_slug: Optional[str] = None, + role_slugs: Optional[Sequence[str]] = None, ) -> SyncOrAsync[OrganizationMembership]: """Create a new OrganizationMembership for the given Organization and User. Kwargs: - user_id: The Unique ID of the User. - organization_id: The Unique ID of the Organization to which the user belongs to. - role_slug: The Unique Slug of the Role to which to grant to this membership. - If no slug is passed in, the default role will be granted.(Optional) + user_id: The unique ID of the User. + organization_id: The unique ID of the Organization to which the user belongs to. + role_slug: The unique slug of the role to grant to this membership.(Optional) + role_slugs: The unique slugs of the roles to grant to this membership.(Optional) + + Note: + role_slug and role_slugs are mutually exclusive. If neither is provided, + the user will be assigned the organization's default role. Returns: OrganizationMembership: Created OrganizationMembership response from WorkOS. @@ -261,14 +270,22 @@ def create_organization_membership( ... def update_organization_membership( - self, *, organization_membership_id: str, role_slug: Optional[str] = None + self, + *, + organization_membership_id: str, + role_slug: Optional[str] = None, + role_slugs: Optional[Sequence[str]] = None, ) -> SyncOrAsync[OrganizationMembership]: """Updates an OrganizationMembership for the given id. Args: organization_membership_id (str): The unique ID of the Organization Membership. - role_slug: The Unique Slug of the Role to which to grant to this membership. - If no slug is passed in, it will not be changed (Optional) + role_slug: The unique slug of the role to grant to this membership.(Optional) + role_slugs: The unique slugs of the roles to grant to this membership.(Optional) + + Note: + role_slug and role_slugs are mutually exclusive. If neither is provided, + the role(s) of the membership will remain unchanged. Returns: OrganizationMembership: Updated OrganizationMembership response from WorkOS. @@ -988,12 +1005,18 @@ def delete_user(self, user_id: str) -> None: ) def create_organization_membership( - self, *, user_id: str, organization_id: str, role_slug: Optional[str] = None + self, + *, + user_id: str, + organization_id: str, + role_slug: Optional[str] = None, + role_slugs: Optional[Sequence[str]] = None, ) -> OrganizationMembership: json = { "user_id": user_id, "organization_id": organization_id, "role_slug": role_slug, + "role_slugs": role_slugs, } response = self._http_client.request( @@ -1003,10 +1026,15 @@ def create_organization_membership( return OrganizationMembership.model_validate(response) def update_organization_membership( - self, *, organization_membership_id: str, role_slug: Optional[str] = None + self, + *, + organization_membership_id: str, + role_slug: Optional[str] = None, + role_slugs: Optional[Sequence[str]] = None, ) -> OrganizationMembership: json = { "role_slug": role_slug, + "role_slugs": role_slugs, } response = self._http_client.request( @@ -1614,12 +1642,18 @@ async def delete_user(self, user_id: str) -> None: ) async def create_organization_membership( - self, *, user_id: str, organization_id: str, role_slug: Optional[str] = None + self, + *, + user_id: str, + organization_id: str, + role_slug: Optional[str] = None, + role_slugs: Optional[Sequence[str]] = None, ) -> OrganizationMembership: json = { "user_id": user_id, "organization_id": organization_id, "role_slug": role_slug, + "role_slugs": role_slugs, } response = await self._http_client.request( @@ -1629,10 +1663,15 @@ async def create_organization_membership( return OrganizationMembership.model_validate(response) async def update_organization_membership( - self, *, organization_membership_id: str, role_slug: Optional[str] = None + self, + *, + organization_membership_id: str, + role_slug: Optional[str] = None, + role_slugs: Optional[Sequence[str]] = None, ) -> OrganizationMembership: json = { "role_slug": role_slug, + "role_slugs": role_slugs, } response = await self._http_client.request(