forked from jpadilla/pyjwt
/
api_jwk.py
127 lines (99 loc) · 3.68 KB
/
api_jwk.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import json
from datetime import datetime, timezone
from .algorithms import get_default_algorithms
from .exceptions import InvalidKeyError, PyJWKError, PyJWKSetError
class PyJWK:
    """A single JSON Web Key together with the algorithm used to load it.

    The algorithm name is taken, in order of preference, from the
    ``algorithm`` argument, the key's own ``alg`` member, or a default
    chosen from ``kty`` (and ``crv`` for EC / OKP keys).
    """

    def __init__(self, jwk_data, algorithm=None):
        """Resolve an algorithm for *jwk_data* and load the key through it.

        Args:
            jwk_data: The JWK members; accessed through ``.get``.
            algorithm: Optional algorithm name overriding ``alg`` and the
                ``kty``-based default.

        Raises:
            InvalidKeyError: ``kty`` is missing/empty, or no default
                algorithm exists for the given ``kty`` / ``crv``.
            PyJWKError: The resolved name is not present in the mapping
                returned by ``get_default_algorithms()``.
        """
        self._algorithms = get_default_algorithms()
        self._jwk_data = jwk_data

        key_type = self._jwk_data.get("kty", None)
        if not key_type:
            raise InvalidKeyError(f"kty is not found: {self._jwk_data}")

        # NOTE(review): the isinstance check comes after ``.get`` has already
        # been called on the data above, so it looks redundant in practice.
        chosen = algorithm
        if not chosen and isinstance(self._jwk_data, dict):
            chosen = self._jwk_data.get("alg", None)
        if not chosen:
            chosen = self._default_algorithm(key_type)

        # The attribute is assigned before validation, as callers may
        # inspect it even when the lookup came back empty.
        self.Algorithm = self._algorithms.get(chosen)
        if not self.Algorithm:
            raise PyJWKError(f"Unable to find a algorithm for key: {self._jwk_data}")

        self.key = self.Algorithm.from_jwk(self._jwk_data)

    def _default_algorithm(self, kty):
        """Return the default algorithm name for *kty*, using ``crv`` if needed.

        Raises:
            InvalidKeyError: Unsupported ``kty``, unsupported ``crv``, or a
                missing ``crv`` on an OKP key.
        """
        # Determine alg with kty (and crv).
        crv = self._jwk_data.get("crv", None)

        if kty == "EC":
            # An EC key without a curve falls back to P-256 / ES256.
            if crv == "P-256" or not crv:
                return "ES256"
            for curve, name in (
                ("P-384", "ES384"),
                ("P-521", "ES512"),
                ("secp256k1", "ES256K"),
            ):
                if crv == curve:
                    return name
            raise InvalidKeyError(f"Unsupported crv: {crv}")

        if kty == "RSA":
            return "RS256"

        if kty == "oct":
            return "HS256"

        if kty == "OKP":
            # Unlike EC, an OKP key has no fallback curve.
            if not crv:
                raise InvalidKeyError(f"crv is not found: {self._jwk_data}")
            if crv == "Ed25519":
                return "EdDSA"
            raise InvalidKeyError(f"Unsupported crv: {crv}")

        raise InvalidKeyError(f"Unsupported kty: {kty}")

    @staticmethod
    def from_dict(obj, algorithm=None):
        """Build a ``PyJWK`` from already-parsed JWK data."""
        return PyJWK(obj, algorithm)

    @staticmethod
    def from_json(data, algorithm=None):
        """Parse *data* as JSON and build a ``PyJWK`` from the result."""
        parsed = json.loads(data)
        return PyJWK.from_dict(parsed, algorithm)

    @property
    def key_type(self):
        """The key's ``kty`` member, or ``None`` when absent."""
        return self._jwk_data.get("kty", None)

    @property
    def key_id(self):
        """The key's ``kid`` member, or ``None`` when absent."""
        return self._jwk_data.get("kid", None)

    @property
    def public_key_use(self):
        """The key's ``use`` member, or ``None`` when absent."""
        return self._jwk_data.get("use", None)
class PyJWKSet:
    """A collection of ``PyJWK`` objects built from a JWK Set's ``keys`` list.

    Entries for which ``PyJWK`` raises ``PyJWKError`` are skipped; the set
    must end up with at least one usable key.
    """

    def __init__(self, keys):
        """Wrap every usable entry of *keys* in a ``PyJWK``.

        Args:
            keys: A non-empty list of JWK data objects.

        Raises:
            PyJWKSetError: *keys* is not a list, is an empty list, or
                contains no entry that could be loaded.
        """
        self.keys = []

        # Check the type first so that an empty list reaches its own, more
        # specific error below instead of being reported as an invalid value.
        if not isinstance(keys, list):
            raise PyJWKSetError("Invalid JWK Set value")

        if not keys:
            raise PyJWKSetError("The JWK Set did not contain any keys")

        for key in keys:
            try:
                self.keys.append(PyJWK(key))
            except PyJWKError:
                # skip unusable keys
                # NOTE(review): PyJWK.__init__ raises InvalidKeyError for a
                # missing/unsupported kty or crv; whether that is a subclass
                # of PyJWKError is not visible from this file. If it is not,
                # such keys propagate instead of being skipped - confirm
                # against the exceptions module.
                continue

        if not self.keys:
            raise PyJWKSetError("The JWK Set did not contain any usable keys")

    @staticmethod
    def from_dict(obj):
        """Build a ``PyJWKSet`` from a parsed JWK Set (its ``keys`` member).

        A missing ``keys`` member is treated as an empty list and therefore
        rejected by the constructor.
        """
        keys = obj.get("keys", [])
        return PyJWKSet(keys)

    @staticmethod
    def from_json(data):
        """Parse *data* as JSON and build a ``PyJWKSet`` from the result."""
        obj = json.loads(data)
        return PyJWKSet.from_dict(obj)

    def __getitem__(self, kid):
        """Return the first key whose ``key_id`` equals *kid*.

        Raises:
            KeyError: No key in the set has that ``kid``.
        """
        for key in self.keys:
            if key.key_id == kid:
                return key
        raise KeyError(f"keyset has no key for kid: {kid}")
class PyJWTSetWithTimestamp:
    """Pairs a ``PyJWKSet`` with the moment it was recorded."""

    def __init__(self, jwt_set: PyJWKSet, timestamp: datetime = None):
        """Store *jwt_set* along with *timestamp*.

        When *timestamp* is ``None`` the current time in UTC is used.
        """
        self.jwt_set = jwt_set
        self.timestamp = (
            datetime.now(timezone.utc) if timestamp is None else timestamp
        )

    def get_jwk_set(self):
        """Return the stored key set."""
        return self.jwt_set

    def get_timestamp(self):
        """Return the stored timestamp."""
        return self.timestamp