Skip to content

Commit

Permalink
added expiry to JWT
Browse files Browse the repository at this point in the history
  • Loading branch information
dantownsend committed Aug 9, 2019
1 parent b108439 commit 374b296
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 10 deletions.
23 changes: 22 additions & 1 deletion docs/source/introduction.rst
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,29 @@ You have to pass in two arguments:
authenticate the user.
* secret - this is used for signing the JWT.

expiry
~~~~~~

An optional argument, which allows you to control when a token expires. By
default it's set to 1 day.

.. code-block:: python
from datetime import timedelta
jwt_login(
auth_table=User,
secret=SECRET,
expiry=timedelta(minutes=10)
)
JWTMiddleware
-------------

This wraps an ASGI app, and ensures a valid token is passed in the header.
Otherwise a 403 error is returned. If the token is valid, the corresponding
``user_id`` is added to the ``scope``.

blacklist
~~~~~~~~~

Expand All @@ -129,7 +149,8 @@ anywhere else.
return token in BLACKLISTED_TOKENS
jwt_login(
asgi_app = JWTMiddleware(
my_endpoint,
auth_table=User,
secret=SECRET,
blacklist=MyBlacklist()
Expand Down
19 changes: 17 additions & 2 deletions piccolo_api/endpoints/auth.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from abc import abstractproperty
from datetime import datetime, timedelta
import typing as t

import jwt
Expand All @@ -20,6 +21,10 @@ def _auth_table(self) -> t.Type[BaseUser]:
def _secret(self) -> str:
raise NotImplementedError

@abstractproperty
def _expiry(self) -> timedelta:
raise NotImplementedError

async def post(self, request: Request) -> JSONResponse:
body = await request.json()
username = body.get('username', None)
Expand All @@ -36,7 +41,15 @@ async def post(self, request: Request) -> JSONResponse:
detail="Login failed"
)

payload = jwt.encode({'user_id': user_id}, self._secret).decode()
expiry = datetime.now() + self._expiry

payload = jwt.encode(
{
'user_id': user_id,
'exp': expiry
},
self._secret
).decode()

return JSONResponse({
'token': payload
Expand All @@ -45,11 +58,13 @@ async def post(self, request: Request) -> JSONResponse:

def jwt_login(
auth_table: BaseUser,
secret: str
secret: str,
expiry: timedelta = timedelta(days=1)
) -> t.Type[JWTLoginBase]:

class JWTLogin(JWTLoginBase):
_auth_table = auth_table
_secret = secret
_expiry = expiry

return JWTLogin
35 changes: 28 additions & 7 deletions piccolo_api/middleware/auth.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import typing as t
import time

from starlette.exceptions import HTTPException
import jwt
Expand Down Expand Up @@ -45,12 +46,14 @@ def get_token(self, headers: dict) -> t.Optional[str]:
return None
return auth_str.split(' ')[1]

async def get_user_id(self, token: str) -> t.Optional[int]:
async def get_user_id(
self,
token_dict: t.Dict[str, t.Any]
) -> t.Optional[int]:
"""
Extract the user_id from the token, and check it's valid.
"""
message = jwt.decode(token, self.secret)
user_id = message.get('user_id', None)
user_id = token_dict.get('user_id', None)

if not user_id:
return None
Expand All @@ -59,11 +62,24 @@ async def get_user_id(self, token: str) -> t.Optional[int]:
self.auth_table.id == user_id
).run()

if exists == True:
if exists is True:
return user_id
else:
return None

def has_expired(self, token_dict: t.Dict[str, t.Any]) -> bool:
"""
Work out if the token has expired.
"""
expiry = token_dict.get('exp', None)

if not expiry:
# A token doesn't need to have an expiry.
return True
else:
# The value is a timestamp, based on Unix time.
return expiry < time.time()

async def __call__(self, scope, receive, send):
"""
Add the user_id to the scope if a JWT token is available, and the user
Expand All @@ -72,12 +88,17 @@ async def __call__(self, scope, receive, send):
headers = dict(scope['headers'])
token = self.get_token(headers)
if not token:
raise HTTPException(status_code=403)
raise HTTPException(status_code=403, detail="Token not found")

if await self.blacklist.in_blacklist(token):
raise HTTPException(status_code=403)
raise HTTPException(status_code=403, detail="Token revoked")

token_dict = jwt.decode(token, self.secret)

if self.has_expired(token_dict):
raise HTTPException(status_code=403, detail="Token has expired")

user_id = await self.get_user_id(token)
user_id = await self.get_user_id(token_dict)
if not user_id:
raise HTTPException(status_code=403)

Expand Down

0 comments on commit 374b296

Please sign in to comment.