|
2 | 2 | from unittest.mock import AsyncMock, Mock, patch |
3 | 3 | import jwt |
4 | 4 | from datetime import datetime, timezone |
| 5 | +import concurrent.futures |
5 | 6 |
|
6 | 7 | from tests.conftest import with_jwks_mock |
7 | | -from workos.session import AsyncSession, Session |
| 8 | +from workos.session import AsyncSession, Session, _get_jwks_client |
8 | 9 | from workos.types.user_management.authentication_response import ( |
9 | 10 | RefreshTokenAuthenticationResponse, |
10 | 11 | ) |
|
20 | 21 |
|
21 | 22 |
|
22 | 23 | class SessionFixtures: |
| 24 | + @pytest.fixture(autouse=True) |
| 25 | + def clear_jwks_cache(self): |
| 26 | + _get_jwks_client.cache_clear() |
| 27 | + yield |
| 28 | + _get_jwks_client.cache_clear() |
| 29 | + |
23 | 30 | @pytest.fixture |
24 | 31 | def session_constants(self): |
25 | 32 | # Generate RSA key pair for testing |
@@ -491,3 +498,43 @@ async def test_refresh_success_with_aud_claim( |
491 | 498 | response = await session.refresh() |
492 | 499 |
|
493 | 500 | assert isinstance(response, RefreshWithSessionCookieSuccessResponse) |
| 501 | + |
| 502 | + |
| 503 | +class TestJWKSCaching: |
| 504 | + def test_jwks_client_caching_same_url(self): |
| 505 | + url = "https://api.workos.com/sso/jwks/test" |
| 506 | + |
| 507 | + client1 = _get_jwks_client(url) |
| 508 | + client2 = _get_jwks_client(url) |
| 509 | + |
| 510 | + # Should be the exact same instance |
| 511 | + assert client1 is client2 |
| 512 | + assert id(client1) == id(client2) |
| 513 | + |
| 514 | + def test_jwks_client_caching_different_urls(self): |
| 515 | + url1 = "https://api.workos.com/sso/jwks/client1" |
| 516 | + url2 = "https://api.workos.com/sso/jwks/client2" |
| 517 | + |
| 518 | + client1 = _get_jwks_client(url1) |
| 519 | + client2 = _get_jwks_client(url2) |
| 520 | + |
| 521 | + # Should be different instances |
| 522 | + assert client1 is not client2 |
| 523 | + assert id(client1) != id(client2) |
| 524 | + |
| 525 | + def test_jwks_cache_thread_safety(self): |
| 526 | + url = "https://api.workos.com/sso/jwks/thread_test" |
| 527 | + clients = [] |
| 528 | + |
| 529 | + def get_client(): |
| 530 | + return _get_jwks_client(url) |
| 531 | + |
| 532 | + with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: |
| 533 | + futures = [executor.submit(get_client) for _ in range(10)] |
| 534 | + clients = [future.result() for future in futures] |
| 535 | + |
| 536 | + first_client = clients[0] |
| 537 | + for client in clients[1:]: |
| 538 | + assert ( |
| 539 | + client is first_client |
| 540 | + ), "All concurrent calls should return the same instance" |
0 commit comments