Skip to content

Commit

Permalink
changed BaseUser import path
Browse files Browse the repository at this point in the history
  • Loading branch information
dantownsend committed Nov 15, 2019
1 parent 07a0668 commit 43b3ea4
Show file tree
Hide file tree
Showing 9 changed files with 31 additions and 28 deletions.
3 changes: 0 additions & 3 deletions docs/source/jwt/endpoints.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,13 @@ This creates an endpoint for logging in, and getting a JSON Web Token (JWT).
from starlette.routing import Route, Router
from piccolo_api.jwt_auth.endpoints import jwt_login
from .tables import User
from settings import SECRET
asgi_app = Router([
Route(
path="/login/",
endpoint=jwt_login(
auth_table=User,
secret=SECRET
)
),
Expand Down Expand Up @@ -45,7 +43,6 @@ default it's set to 1 day.
from datetime import timedelta
jwt_login(
auth_table=User,
secret=SECRET,
expiry=timedelta(minutes=10)
)
Expand Down
6 changes: 4 additions & 2 deletions piccolo_api/jwt_auth/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from starlette.responses import JSONResponse
from starlette.requests import Request

from piccolo.extensions.user import BaseUser
from piccolo.extensions.user.tables import BaseUser


class JWTLoginBase(HTTPEndpoint):
Expand Down Expand Up @@ -44,7 +44,9 @@ async def post(self, request: Request) -> JSONResponse:


def jwt_login(
auth_table: BaseUser, secret: str, expiry: timedelta = timedelta(days=1)
secret: str,
auth_table: BaseUser = BaseUser,
expiry: timedelta = timedelta(days=1),
) -> t.Type[JWTLoginBase]:
class JWTLogin(JWTLoginBase):
_auth_table = auth_table
Expand Down
37 changes: 19 additions & 18 deletions piccolo_api/jwt_auth/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,30 @@
from starlette.exceptions import HTTPException
import jwt

from piccolo.extensions.user import BaseUser
from piccolo.extensions.user.tables import BaseUser


class JWTBlacklist():

class JWTBlacklist:
async def in_blacklist(self, token: str) -> bool:
"""
Checks whether the token is in the blacklist.
"""
return False


class JWTMiddleware():
class JWTMiddleware:
"""
Protects an endpoint - only allows access if a JWT token is presented.
"""

auth_table: BaseUser = None

def __init__(
self,
asgi,
auth_table: BaseUser,
secret: str,
blacklist: JWTBlacklist = JWTBlacklist()
auth_table: t.Type[BaseUser] = BaseUser,
blacklist: JWTBlacklist = JWTBlacklist(),
) -> None:
self.asgi = asgi
self.secret = secret
Expand All @@ -38,29 +38,30 @@ def get_token(self, headers: dict) -> t.Optional[str]:
"""
Try and extract the JWT token from the request headers.
"""
auth_token = headers.get(b'authorization', None)
auth_token = headers.get(b"authorization", None)
if not auth_token:
return None
auth_str = auth_token.decode()
if not auth_str.startswith('Bearer '):
if not auth_str.startswith("Bearer "):
return None
return auth_str.split(' ')[1]
return auth_str.split(" ")[1]

async def get_user_id(
self,
token_dict: t.Dict[str, t.Any]
self, token_dict: t.Dict[str, t.Any]
) -> t.Optional[int]:
"""
Extract the user_id from the token, and check it's valid.
"""
user_id = token_dict.get('user_id', None)
user_id = token_dict.get("user_id", None)

if not user_id:
return None

exists = await self.auth_table.exists().where(
self.auth_table.id == user_id
).run()
exists = (
await self.auth_table.exists()
.where(self.auth_table.id == user_id)
.run()
)

if exists is True:
return user_id
Expand All @@ -71,7 +72,7 @@ def has_expired(self, token_dict: t.Dict[str, t.Any]) -> bool:
"""
Work out if the token has expired.
"""
expiry = token_dict.get('exp', None)
expiry = token_dict.get("exp", None)

if not expiry:
# A token doesn't need to have an expiry.
Expand All @@ -85,7 +86,7 @@ async def __call__(self, scope, receive, send):
Add the user_id to the scope if a JWT token is available, and the user
is recognised, otherwise raise a 403 HTTP error.
"""
headers = dict(scope['headers'])
headers = dict(scope["headers"])
token = self.get_token(headers)
if not token:
raise HTTPException(status_code=403, detail="Token not found")
Expand All @@ -103,6 +104,6 @@ async def __call__(self, scope, receive, send):
raise HTTPException(status_code=403)

new_scope = dict(scope)
new_scope['user_id'] = user_id
new_scope["user_id"] = user_id

await self.asgi(new_scope, receive, send)
2 changes: 1 addition & 1 deletion piccolo_api/session_auth/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import typing as t
import warnings

from piccolo.extensions.user import BaseUser
from piccolo.extensions.user.tables import BaseUser
from starlette.exceptions import HTTPException
from starlette.endpoints import HTTPEndpoint, Request
from starlette.responses import (
Expand Down
2 changes: 1 addition & 1 deletion piccolo_api/session_auth/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import typing as t

from piccolo.extensions.user import BaseUser as PiccoloBaseUser
from piccolo.extensions.user.tables import BaseUser as PiccoloBaseUser
from piccolo_api.session_auth.tables import SessionsBase
from starlette.authentication import (
AuthenticationBackend,
Expand Down
3 changes: 3 additions & 0 deletions piccolo_api/session_auth/migrations/config.py
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
NAME = "session_auth"


DEPENDENCIES = ["piccolo.extensions.user.migrations.config"]
2 changes: 1 addition & 1 deletion piccolo_api/token_auth/tables/token_auth.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import secrets

from piccolo.columns.column_types import Varchar, ForeignKey
from piccolo.extensions.user import BaseUser
from piccolo.extensions.user.tables import BaseUser
from piccolo.table import Table


Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
asgiref==3.2.1
Jinja2==2.10.1
piccolo>=0.5.2
piccolo>=0.6.0
pydantic==1.0
python-multipart==0.0.5
starlette>=0.12.13
Expand Down
2 changes: 1 addition & 1 deletion tests/test_session.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
from unittest import TestCase

from piccolo.extensions.user import BaseUser
from piccolo.extensions.user.tables import BaseUser
from piccolo.engine.sqlite import SQLiteEngine
from starlette.authentication import requires
from starlette.endpoints import HTTPEndpoint
Expand Down

0 comments on commit 43b3ea4

Please sign in to comment.