Skip to content

Commit

Permalink
Initial implementation of ttl jwk set cache
Browse files Browse the repository at this point in the history
  • Loading branch information
wuhaoyujerry committed Jun 29, 2022
1 parent a863a73 commit 479a7c1
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 5 deletions.
17 changes: 17 additions & 0 deletions jwt/api_jwk.py
@@ -1,4 +1,5 @@
import json
from datetime import datetime, timezone

from .algorithms import get_default_algorithms
from .exceptions import InvalidKeyError, PyJWKError, PyJWKSetError
Expand Down Expand Up @@ -108,3 +109,19 @@ def __getitem__(self, kid):
if key.key_id == kid:
return key
raise KeyError(f"keyset has no key for kid: {kid}")


class PyJWTSetWithTimestamp:
def __init__(self, jwt_set: PyJWKSet, timestamp: datetime = None):
self.jwt_set = jwt_set

if timestamp is None:
self.timestamp = datetime.now(timezone.utc)
else:
self.timestamp = timestamp

def get_jwk_set(self):
return self.jwt_set

def get_timestamp(self):
return self.timestamp
29 changes: 29 additions & 0 deletions jwt/jwk_set_cache.py
@@ -0,0 +1,29 @@
from typing import Optional
from datetime import datetime, timezone

from .api_jwk import PyJWKSet, PyJWTSetWithTimestamp


class JWKSetCache:
def __init__(self, lifespan: int):
self.jwk_set_with_timestamp = None
self.lifespan = lifespan

def put(self, jwk_set: PyJWKSet):
if jwk_set is not None:
self.jwk_set_with_timestamp = PyJWTSetWithTimestamp(jwk_set)
else:
# clear cache
self.jwk_set_with_timestamp = None

def get(self) -> Optional:
if self.jwk_set_with_timestamp is None or self.is_expired():
return None

return self.jwk_set_with_timestamp.get_jwk_set()

def is_expired(self) -> bool:
return self.jwk_set_with_timestamp is not None \
and self.lifespan > -1 \
and datetime.now(timezone.utc) > self.jwk_set_with_timestamp.get_timestamp() + self.lifespan

32 changes: 27 additions & 5 deletions jwt/jwks_client.py
@@ -1,27 +1,49 @@
import json
import urllib.request
from urllib.request import urlopen
from urllib.error import URLError
from functools import lru_cache
from typing import Any, List

from .api_jwk import PyJWK, PyJWKSet
from .api_jwt import decode_complete as decode_token
from .jwk_set_cache import JWKSetCache
from .exceptions import PyJWKClientError


class PyJWKClient:
def __init__(self, uri: str, cache_keys: bool = True, max_cached_keys: int = 16):
def __init__(self, uri: str, cache_keys: bool = True, max_cached_keys: int = 16,
cache_jwk_set: bool = True, lifespan: int = 5):
self.uri = uri

if cache_jwk_set:
# Init jwt set cache with default or given lifespan.
self.jwk_set_cache = JWKSetCache(lifespan)

if cache_keys:
# Cache signing keys
# Ignore mypy (https://github.com/python/mypy/issues/2427)
self.get_signing_key = lru_cache(maxsize=max_cached_keys)(self.get_signing_key) # type: ignore

def fetch_data(self) -> Any:
with urllib.request.urlopen(self.uri) as response:
return json.load(response)
try:
with urlopen(self.uri) as response:
jwk_set = json.load(response)
except URLError as e:
raise PyJWKClientError(f'Fail to fetch data from the url, err: "{e}"')

if self.jwk_set_cache is not None:
self.jwk_set_cache.put(jwk_set)

return jwk_set

def get_jwk_set(self) -> PyJWKSet:
data = self.fetch_data()
data = None
if self.jwk_set_cache is not None:
data = self.jwk_set_cache.get()

if data is None:
data = self.fetch_data()

return PyJWKSet.from_dict(data)

def get_signing_keys(self) -> List[PyJWK]:
Expand Down

0 comments on commit 479a7c1

Please sign in to comment.