Skip to content

Commit 3722094

Browse files
authored
Cache JWKS clients per URL (#466)
* Cache JWKS clients per URL * Make JWKS client cache threadsafe
1 parent 84f7966 commit 3722094

File tree

2 files changed

+57
-4
lines changed

2 files changed

+57
-4
lines changed

tests/test_session.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22
from unittest.mock import AsyncMock, Mock, patch
33
import jwt
44
from datetime import datetime, timezone
5+
import concurrent.futures
56

67
from tests.conftest import with_jwks_mock
7-
from workos.session import AsyncSession, Session
8+
from workos.session import AsyncSession, Session, _get_jwks_client
89
from workos.types.user_management.authentication_response import (
910
RefreshTokenAuthenticationResponse,
1011
)
@@ -20,6 +21,12 @@
2021

2122

2223
class SessionFixtures:
24+
@pytest.fixture(autouse=True)
25+
def clear_jwks_cache(self):
26+
_get_jwks_client.cache_clear()
27+
yield
28+
_get_jwks_client.cache_clear()
29+
2330
@pytest.fixture
2431
def session_constants(self):
2532
# Generate RSA key pair for testing
@@ -491,3 +498,43 @@ async def test_refresh_success_with_aud_claim(
491498
response = await session.refresh()
492499

493500
assert isinstance(response, RefreshWithSessionCookieSuccessResponse)
501+
502+
503+
class TestJWKSCaching:
504+
def test_jwks_client_caching_same_url(self):
505+
url = "https://api.workos.com/sso/jwks/test"
506+
507+
client1 = _get_jwks_client(url)
508+
client2 = _get_jwks_client(url)
509+
510+
# Should be the exact same instance
511+
assert client1 is client2
512+
assert id(client1) == id(client2)
513+
514+
def test_jwks_client_caching_different_urls(self):
515+
url1 = "https://api.workos.com/sso/jwks/client1"
516+
url2 = "https://api.workos.com/sso/jwks/client2"
517+
518+
client1 = _get_jwks_client(url1)
519+
client2 = _get_jwks_client(url2)
520+
521+
# Should be different instances
522+
assert client1 is not client2
523+
assert id(client1) != id(client2)
524+
525+
def test_jwks_cache_thread_safety(self):
526+
url = "https://api.workos.com/sso/jwks/thread_test"
527+
clients = []
528+
529+
def get_client():
530+
return _get_jwks_client(url)
531+
532+
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
533+
futures = [executor.submit(get_client) for _ in range(10)]
534+
clients = [future.result() for future in futures]
535+
536+
first_client = clients[0]
537+
for client in clients[1:]:
538+
assert (
539+
client is first_client
540+
), "All concurrent calls should return the same instance"

workos/session.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22
from typing import TYPE_CHECKING, List, Protocol
33

4+
from functools import lru_cache
45
import json
56
from typing import Any, Dict, Optional, Union, cast
67
import jwt
@@ -21,6 +22,11 @@
2122
from workos.user_management import AsyncUserManagement, UserManagement
2223

2324

25+
@lru_cache(maxsize=None)
26+
def _get_jwks_client(jwks_url: str) -> PyJWKClient:
27+
return PyJWKClient(jwks_url)
28+
29+
2430
class SessionModule(Protocol):
2531
user_management: "UserManagementModule"
2632
client_id: str
@@ -46,7 +52,7 @@ def __init__(
4652
self.session_data = session_data
4753
self.cookie_password = cookie_password
4854

49-
self.jwks = PyJWKClient(self.user_management.get_jwks_url())
55+
self.jwks = _get_jwks_client(self.user_management.get_jwks_url())
5056

5157
# Algorithms are hardcoded for security reasons. See https://pyjwt.readthedocs.io/en/stable/algorithms.html#specifying-an-algorithm
5258
self.jwk_algorithms = ["RS256"]
@@ -164,7 +170,7 @@ def __init__(
164170
self.session_data = session_data
165171
self.cookie_password = cookie_password
166172

167-
self.jwks = PyJWKClient(self.user_management.get_jwks_url())
173+
self.jwks = _get_jwks_client(self.user_management.get_jwks_url())
168174

169175
# Algorithms are hardcoded for security reasons. See https://pyjwt.readthedocs.io/en/stable/algorithms.html#specifying-an-algorithm
170176
self.jwk_algorithms = ["RS256"]
@@ -254,7 +260,7 @@ def __init__(
254260
self.session_data = session_data
255261
self.cookie_password = cookie_password
256262

257-
self.jwks = PyJWKClient(self.user_management.get_jwks_url())
263+
self.jwks = _get_jwks_client(self.user_management.get_jwks_url())
258264

259265
# Algorithms are hardcoded for security reasons. See https://pyjwt.readthedocs.io/en/stable/algorithms.html#specifying-an-algorithm
260266
self.jwk_algorithms = ["RS256"]

0 commit comments

Comments
 (0)