/
base.py
218 lines (184 loc) · 7.01 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
from copy import deepcopy
from typing import List, Dict, Optional, Any, Type

import requests
from fastapi import Depends, HTTPException
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from jose import jwk, jwt
from jose.backends.base import Key
from jose.exceptions import JWTError
from jose.utils import base64url_decode
from pydantic import BaseModel
from pydantic.error_wrappers import ValidationError
from starlette import status

# `detail` messages attached to the HTTP 403 errors raised by the verifiers
# defined in this module.

# The token could not be tied to a signing key (e.g. no `kid` in its header).
NOT_AUTHENTICATED: str = "Not authenticated"
# The token's `kid` is not present in the loaded key set.
NO_PUBLICKEY: str = "JWK public Attribute for authorization token not found"
# The token's signature did not verify against the selected public key.
NOT_VERIFIED: str = "Not verified"
# The required scope is absent from the token's scope claim.
SCOPE_NOT_MATCHED: str = "Scope not matched"
# The token's claims could not be parsed into the user-info model.
NOT_VALIDATED_CLAIMS: str = "Validation Error for Claims"
class JWKS:
    """Set of public keys indexed by key ID (``kid``)."""

    # kid -> python-jose key object used to verify token signatures.
    keys: Dict[str, Key]

    def __init__(self, keys: Dict[str, Key]):
        self.keys = keys

    @classmethod
    def fromurl(cls, url: str, timeout: Optional[float] = 10.0) -> "JWKS":
        """Fetch a JWKS document and build a ``JWKS`` from it.

        The endpoint is expected to serve JSON of the form
        ``{"keys": [{"kid": ..., ...}, ...]}``, for example
        ``https://xxx/.well-known/jwks.json``. Each entry is turned into a key
        with ``jwk.construct``; entries without a ``kid`` are skipped because
        keys are only ever looked up by ``kid``.

        Args:
            url: URL of the JWKS endpoint.
            timeout: seconds to wait for the endpoint (passed to
                ``requests.get``); ``None`` waits indefinitely.

        Raises:
            requests.HTTPError: the endpoint answered with a 4xx/5xx status.
        """
        response = requests.get(url, timeout=timeout)
        # Fail loudly rather than silently building an empty key set, which
        # would make every token fail later with NO_PUBLICKEY.
        response.raise_for_status()
        document = response.json()
        keys = {
            _jwk["kid"]: jwk.construct(_jwk)
            for _jwk in document.get("keys", [])
            if "kid" in _jwk
        }
        return cls(keys=keys)

    @classmethod
    def firebase(cls, url: str, timeout: Optional[float] = 10.0) -> "JWKS":
        """Fetch Firebase public keys and build a ``JWKS`` from them.

        The endpoint is expected to serve a JSON object mapping each ``kid``
        to its public key material; every entry is constructed as an RS256
        key.

        Args:
            url: URL of the Firebase public-key endpoint.
            timeout: seconds to wait for the endpoint (passed to
                ``requests.get``); ``None`` waits indefinitely.

        Raises:
            requests.HTTPError: the endpoint answered with a 4xx/5xx status.
        """
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        certs = response.json()
        keys = {
            kid: jwk.construct(publickey, algorithm="RS256")
            for kid, publickey in certs.items()
        }
        return cls(keys=keys)
class BaseTokenVerifier:
    """Verify the signature of a bearer JWT against a set of public keys."""

    def __init__(self, jwks: JWKS, auto_error: bool = True, *args, **kwargs):
        """
        Args:
            jwks: public keys; the one matching the token header's ``kid``
                is used for verification.
            auto_error: if True (default), a failed check raises
                ``HTTPException(403)``. If False, ``get_publickey`` returns
                ``None`` and ``verify_token`` returns ``False`` instead.
            *args, **kwargs: accepted and ignored here.
        """
        self.jwks_to_key = jwks.keys
        # Required scope; consulted by subclasses. None means "not set".
        self.scope_name: Optional[str] = None
        self.auto_error = auto_error

    def clone(self):
        """Return a deep copy of this verifier that shares the same key mapping.

        In some cases the key mapping cannot be pickled (deep-copied), so it
        is detached for the duration of the copy and the clone is given a
        reference to the original mapping.
        """
        jwks_to_key = self.jwks_to_key
        self.jwks_to_key = {}
        try:
            clone = deepcopy(self)
        finally:
            # Restore the original instance even if deepcopy fails; otherwise
            # it would be left with an empty key set and reject every token.
            self.jwks_to_key = jwks_to_key
        clone.jwks_to_key = jwks_to_key
        return clone

    def get_publickey(
        self, http_auth: HTTPAuthorizationCredentials
    ) -> Optional[Key]:
        """Return the public key matching the ``kid`` in the token header.

        When the header cannot be decoded, has no ``kid``, or names a ``kid``
        absent from the key set, raises ``HTTPException(403)`` if
        ``auto_error`` is set, otherwise returns ``None``.
        """
        token = http_auth.credentials
        try:
            header = jwt.get_unverified_header(token)
        except JWTError:
            # Malformed token: handle it like a missing kid instead of letting
            # the exception propagate unhandled (typically a 500 response).
            header = {}
        kid = header.get("kid")
        if not kid:
            if self.auto_error:
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN, detail=NOT_AUTHENTICATED
                )
            return None
        publickey = self.jwks_to_key.get(kid)
        if publickey is None:
            if self.auto_error:
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN, detail=NO_PUBLICKEY,
                )
            return None
        return publickey

    def verify_token(self, http_auth: HTTPAuthorizationCredentials) -> bool:
        """Check the token signature with the key selected by its ``kid``.

        Returns True when the signature verifies. Otherwise raises
        ``HTTPException(403)`` if ``auto_error`` is set, else returns False.
        """
        # NOTE(review): only the signature is checked here; claims such as
        # `exp`, `aud` or `iss` are not validated in this class — confirm
        # whether callers/subclasses are expected to do so.
        public_key = self.get_publickey(http_auth)
        if not public_key:
            # get_publickey has already raised when auto_error is set.
            return False
        # The signature is the last "."-separated segment; everything before
        # it is the signed message. (The header decode above already proved
        # the token splits this way.)
        message, encoded_sig = http_auth.credentials.rsplit(".", 1)
        decoded_sig = base64url_decode(encoded_sig.encode())
        is_verified = public_key.verify(message.encode(), decoded_sig)
        if not is_verified and self.auto_error:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN, detail=NOT_VERIFIED
            )
        return is_verified
class TokenVerifier(BaseTokenVerifier):
    """
    Verify `Access token` and authorize it based on scope (or groups)
    """

    # Name of the claim holding the token's scopes/groups. Subclasses must set
    # it before `scope()` can be used.
    scope_key: Optional[str] = None

    def scope(self, scope_name: str):
        """User-SCOPE verification shortcut to pass into dependencies.

        Returns a clone of this verifier that additionally requires
        ``scope_name``. Use as (`auth` is this instance and `app` is a
        fastapi.FastAPI instance):
        ```
        from fastapi import Depends
        @app.get("/", dependencies=[Depends(auth.scope("allowed scope"))])
        def api():
            return "hello"
        ```

        Raises:
            AttributeError: ``scope_key`` is not set.
        """
        # Validate first so a misconfigured verifier fails before paying for
        # a deep copy (the clone would inherit the same scope_key anyway).
        if not self.scope_key:
            raise AttributeError("declaire scope_key to set scope")
        clone = self.clone()
        clone.scope_name = scope_name
        return clone

    def verify_scope(self, http_auth: HTTPAuthorizationCredentials) -> bool:
        """Check that ``scope_name`` appears in the token's ``scope_key`` claim.

        A string claim is split on whitespace. A list/tuple/set claim is used
        as-is. A missing claim or any other type counts as having no scopes.
        Expected to be called after ``verify_token`` (as ``__call__`` does);
        this method only reads the claims.

        Returns True on a match. Otherwise raises ``HTTPException(403)`` if
        ``auto_error`` is set, else returns False.
        """
        claims = jwt.get_unverified_claims(http_auth.credentials)
        scopes = claims.get(self.scope_key)
        if isinstance(scopes, str):
            # str.split() with no argument already drops surrounding whitespace.
            scopes = set(scopes.split())
        elif not isinstance(scopes, (list, tuple, set, frozenset)):
            # None, numbers, objects...: `in` would raise or be meaningless.
            scopes = ()
        if self.scope_name not in scopes:
            if self.auto_error:
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN, detail=SCOPE_NOT_MATCHED,
                )
            return False
        return True

    async def __call__(
        self, http_auth: HTTPAuthorizationCredentials = Depends(HTTPBearer())
    ) -> Optional[bool]:
        """User access-token verification shortcut to pass into dependencies.

        Verifies the signature and, if ``scope_name`` is set, the scope.
        Returns True on success and None on failure when ``auto_error`` is
        False. Use as (`auth` is this instance and `app` is a
        fastapi.FastAPI instance):
        ```
        from fastapi import Depends
        @app.get("/", dependencies=[Depends(auth)])
        def api():
            return "hello"
        ```
        """
        is_verified = self.verify_token(http_auth)
        if not is_verified:
            return None
        if self.scope_name:
            is_verified_scope = self.verify_scope(http_auth)
            if not is_verified_scope:
                return None
        return True
class TokenUserInfoGetter(BaseTokenVerifier):
    """
    Verify `ID token` and extract user information
    """

    # Pydantic model the token claims are parsed into. Subclasses must assign
    # it; instantiation fails otherwise.
    user_info: Optional[Type[BaseModel]] = None

    def __init__(self, *args, **kwargs):
        """
        Raises:
            AttributeError: the subclass did not assign ``user_info``.
        """
        if not self.user_info:
            raise AttributeError(
                "must assign custom pydantic.BaseModel into class attributes `user_info`"
            )
        super().__init__(*args, **kwargs)

    async def __call__(
        self, http_auth: HTTPAuthorizationCredentials = Depends(HTTPBearer())
    ) -> Optional[BaseModel]:
        """Get the current user after verifying the ID token.

        Returns an instance of ``user_info`` built from the token claims, or
        None on failure when ``auto_error`` is False (raises
        ``HTTPException(403)`` otherwise). Use as (`auth` is an instance of
        this subclass, `UserInfo` is the model assigned to its `user_info`,
        and `app` is a fastapi.FastAPI instance):
        ```
        from fastapi import Depends
        @app.get("/")
        def api(current_user: UserInfo = Depends(auth)):
            return current_user
        ```
        """
        is_verified = self.verify_token(http_auth)
        if not is_verified:
            return None
        claims = jwt.get_unverified_claims(http_auth.credentials)
        try:
            current_user = self.user_info.parse_obj(claims)
        except ValidationError:
            if self.auto_error:
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN, detail=NOT_VALIDATED_CLAIMS,
                )
            return None
        return current_user