Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 80 additions & 6 deletions tests/test_session.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import pytest
import concurrent.futures
from datetime import datetime, timezone
from unittest.mock import AsyncMock, Mock, patch

import jwt
from datetime import datetime, timezone
import concurrent.futures
import pytest
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa

from tests.conftest import with_jwks_mock
from workos.session import AsyncSession, Session, _get_jwks_client
Expand All @@ -16,9 +19,6 @@
RefreshWithSessionCookieSuccessResponse,
)

from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa


class SessionFixtures:
@pytest.fixture(autouse=True)
Expand Down Expand Up @@ -48,6 +48,7 @@ def session_constants(self):
"sid": "session_123",
"org_id": "organization_123",
"role": "admin",
"roles": ["admin"],
"permissions": ["read"],
"entitlements": ["feature_1"],
"exp": int(current_datetime.timestamp()) + 3600,
Expand Down Expand Up @@ -215,6 +216,7 @@ def test_authenticate_success(self, session_constants, mock_user_management):
"sid": session_constants["SESSION_ID"],
"org_id": session_constants["ORGANIZATION_ID"],
"role": "admin",
"roles": ["admin"],
"permissions": ["read"],
"entitlements": ["feature_1"],
"exp": int(datetime.now(timezone.utc).timestamp()) + 3600,
Expand All @@ -239,6 +241,7 @@ def test_authenticate_success(self, session_constants, mock_user_management):
"sid": session_constants["SESSION_ID"],
"org_id": session_constants["ORGANIZATION_ID"],
"role": "admin",
"roles": ["admin"],
"permissions": ["read"],
"entitlements": ["feature_1"],
}
Expand All @@ -257,11 +260,80 @@ def test_authenticate_success(self, session_constants, mock_user_management):
assert response.session_id == session_constants["SESSION_ID"]
assert response.organization_id == session_constants["ORGANIZATION_ID"]
assert response.role == "admin"
assert response.roles == ["admin"]
assert response.permissions == ["read"]
assert response.entitlements == ["feature_1"]
assert response.user.id == session_constants["USER_ID"]
assert response.impersonator is None

@with_jwks_mock
def test_authenticate_success_with_roles(
self, session_constants, mock_user_management
):
session = Session(
user_management=mock_user_management,
client_id=session_constants["CLIENT_ID"],
session_data=session_constants["SESSION_DATA"],
cookie_password=session_constants["COOKIE_PASSWORD"],
)

# Mock the session data that would be unsealed
mock_session = {
"access_token": jwt.encode(
{
"sid": session_constants["SESSION_ID"],
"org_id": session_constants["ORGANIZATION_ID"],
"role": "admin",
"roles": ["admin", "member"],
"permissions": ["read", "write"],
"entitlements": ["feature_1"],
"exp": int(datetime.now(timezone.utc).timestamp()) + 3600,
"iat": int(datetime.now(timezone.utc).timestamp()),
},
session_constants["PRIVATE_KEY"],
algorithm="RS256",
),
"user": {
"object": "user",
"id": session_constants["USER_ID"],
"email": "user@example.com",
"email_verified": True,
"created_at": session_constants["CURRENT_TIMESTAMP"],
"updated_at": session_constants["CURRENT_TIMESTAMP"],
},
"impersonator": None,
}

# Mock the JWT payload that would be decoded
mock_jwt_payload = {
"sid": session_constants["SESSION_ID"],
"org_id": session_constants["ORGANIZATION_ID"],
"role": "admin",
"roles": ["admin", "member"],
"permissions": ["read", "write"],
"entitlements": ["feature_1"],
}

with patch.object(Session, "unseal_data", return_value=mock_session), patch(
"jwt.decode", return_value=mock_jwt_payload
), patch.object(
session.jwks,
"get_signing_key_from_jwt",
return_value=Mock(key=session_constants["PUBLIC_KEY"]),
):
response = session.authenticate()

assert isinstance(response, AuthenticateWithSessionCookieSuccessResponse)
assert response.authenticated is True
assert response.session_id == session_constants["SESSION_ID"]
assert response.organization_id == session_constants["ORGANIZATION_ID"]
assert response.role == "admin"
assert response.roles == ["admin", "member"]
assert response.permissions == ["read", "write"]
assert response.entitlements == ["feature_1"]
assert response.user.id == session_constants["USER_ID"]
assert response.impersonator is None

@with_jwks_mock
def test_refresh_invalid_session_cookie(
self, session_constants, mock_user_management
Expand Down Expand Up @@ -335,6 +407,7 @@ def test_refresh_success(self, session_constants, mock_user_management):
"sid": session_constants["SESSION_ID"],
"org_id": session_constants["ORGANIZATION_ID"],
"role": "admin",
"roles": ["admin"],
"permissions": ["read"],
"entitlements": ["feature_1"],
},
Expand Down Expand Up @@ -435,6 +508,7 @@ async def test_refresh_success(self, session_constants, mock_user_management):
"sid": session_constants["SESSION_ID"],
"org_id": session_constants["ORGANIZATION_ID"],
"role": "admin",
"roles": ["admin"],
"permissions": ["read"],
"entitlements": ["feature_1"],
},
Expand Down
3 changes: 3 additions & 0 deletions workos/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def authenticate(
session_id=decoded["sid"],
organization_id=decoded.get("org_id", None),
role=decoded.get("role", None),
roles=decoded.get("roles", None),
permissions=decoded.get("permissions", None),
entitlements=decoded.get("entitlements", None),
user=session["user"],
Expand Down Expand Up @@ -229,6 +230,7 @@ def refresh(
session_id=decoded["sid"],
organization_id=decoded.get("org_id", None),
role=decoded.get("role", None),
roles=decoded.get("roles", None),
permissions=decoded.get("permissions", None),
entitlements=decoded.get("entitlements", None),
user=auth_response.user,
Expand Down Expand Up @@ -319,6 +321,7 @@ async def refresh(
session_id=decoded["sid"],
organization_id=decoded.get("org_id", None),
role=decoded.get("role", None),
roles=decoded.get("roles", None),
permissions=decoded.get("permissions", None),
entitlements=decoded.get("entitlements", None),
user=auth_response.user,
Expand Down
3 changes: 2 additions & 1 deletion workos/types/user_management/organization_membership.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Literal
from typing import Literal, Sequence, Optional
from typing_extensions import TypedDict

from workos.types.workos_model import WorkOSModel
Expand All @@ -19,6 +19,7 @@ class OrganizationMembership(WorkOSModel):
user_id: str
organization_id: str
role: OrganizationMembershipRole
roles: Optional[Sequence[OrganizationMembershipRole]] = None
status: LiteralOrUntyped[OrganizationMembershipStatus]
created_at: str
updated_at: str
5 changes: 4 additions & 1 deletion workos/types/user_management/session.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Optional, Sequence, TypedDict, Union
from enum import Enum
from typing import Optional, Sequence, TypedDict, Union

from typing_extensions import Literal

from workos.types.user_management.impersonator import Impersonator
from workos.types.user_management.user import User
from workos.types.workos_model import WorkOSModel
Expand All @@ -17,6 +19,7 @@ class AuthenticateWithSessionCookieSuccessResponse(WorkOSModel):
session_id: str
organization_id: Optional[str] = None
role: Optional[str] = None
roles: Optional[Sequence[str]] = None
permissions: Optional[Sequence[str]] = None
user: User
impersonator: Optional[Impersonator] = None
Expand Down
63 changes: 51 additions & 12 deletions workos/user_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,30 +245,47 @@ def delete_user(self, user_id: str) -> SyncOrAsync[None]:
...

def create_organization_membership(
self, *, user_id: str, organization_id: str, role_slug: Optional[str] = None
self,
*,
user_id: str,
organization_id: str,
role_slug: Optional[str] = None,
role_slugs: Optional[Sequence[str]] = None,
) -> SyncOrAsync[OrganizationMembership]:
"""Create a new OrganizationMembership for the given Organization and User.

Kwargs:
user_id: The Unique ID of the User.
organization_id: The Unique ID of the Organization to which the user belongs to.
role_slug: The Unique Slug of the Role to which to grant to this membership.
If no slug is passed in, the default role will be granted.(Optional)
user_id: The unique ID of the User.
organization_id: The unique ID of the Organization to which the user belongs to.
role_slug: The unique slug of the role to grant to this membership.(Optional)
role_slugs: The unique slugs of the roles to grant to this membership.(Optional)

Note:
role_slug and role_slugs are mutually exclusive. If neither is provided,
the user will be assigned the organization's default role.

Returns:
OrganizationMembership: Created OrganizationMembership response from WorkOS.
"""
...

def update_organization_membership(
self, *, organization_membership_id: str, role_slug: Optional[str] = None
self,
*,
organization_membership_id: str,
role_slug: Optional[str] = None,
role_slugs: Optional[Sequence[str]] = None,
) -> SyncOrAsync[OrganizationMembership]:
"""Updates an OrganizationMembership for the given id.

Args:
organization_membership_id (str): The unique ID of the Organization Membership.
role_slug: The Unique Slug of the Role to which to grant to this membership.
If no slug is passed in, it will not be changed (Optional)
role_slug: The unique slug of the role to grant to this membership.(Optional)
role_slugs: The unique slugs of the roles to grant to this membership.(Optional)

Note:
role_slug and role_slugs are mutually exclusive. If neither is provided,
the role(s) of the membership will remain unchanged.

Returns:
OrganizationMembership: Updated OrganizationMembership response from WorkOS.
Expand Down Expand Up @@ -988,12 +1005,18 @@ def delete_user(self, user_id: str) -> None:
)

def create_organization_membership(
self, *, user_id: str, organization_id: str, role_slug: Optional[str] = None
self,
*,
user_id: str,
organization_id: str,
role_slug: Optional[str] = None,
role_slugs: Optional[Sequence[str]] = None,
) -> OrganizationMembership:
json = {
"user_id": user_id,
"organization_id": organization_id,
"role_slug": role_slug,
"role_slugs": role_slugs,
}

response = self._http_client.request(
Expand All @@ -1003,10 +1026,15 @@ def create_organization_membership(
return OrganizationMembership.model_validate(response)

def update_organization_membership(
self, *, organization_membership_id: str, role_slug: Optional[str] = None
self,
*,
organization_membership_id: str,
role_slug: Optional[str] = None,
role_slugs: Optional[Sequence[str]] = None,
) -> OrganizationMembership:
json = {
"role_slug": role_slug,
"role_slugs": role_slugs,
}

response = self._http_client.request(
Expand Down Expand Up @@ -1614,12 +1642,18 @@ async def delete_user(self, user_id: str) -> None:
)

async def create_organization_membership(
self, *, user_id: str, organization_id: str, role_slug: Optional[str] = None
self,
*,
user_id: str,
organization_id: str,
role_slug: Optional[str] = None,
role_slugs: Optional[Sequence[str]] = None,
) -> OrganizationMembership:
json = {
"user_id": user_id,
"organization_id": organization_id,
"role_slug": role_slug,
"role_slugs": role_slugs,
}

response = await self._http_client.request(
Expand All @@ -1629,10 +1663,15 @@ async def create_organization_membership(
return OrganizationMembership.model_validate(response)

async def update_organization_membership(
self, *, organization_membership_id: str, role_slug: Optional[str] = None
self,
*,
organization_membership_id: str,
role_slug: Optional[str] = None,
role_slugs: Optional[Sequence[str]] = None,
) -> OrganizationMembership:
json = {
"role_slug": role_slug,
"role_slugs": role_slugs,
}

response = await self._http_client.request(
Expand Down