Skip to content

Commit

Permalink
fleshed out session auth
Browse files Browse the repository at this point in the history
  • Loading branch information
dantownsend committed Oct 30, 2019
1 parent ae2103a commit 888e94a
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 9 deletions.
2 changes: 2 additions & 0 deletions piccolo_api/csrf/middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
class CSRFMiddleware():
pass
38 changes: 35 additions & 3 deletions piccolo_api/session_auth/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,12 @@
from piccolo.extensions.user import BaseUser
from starlette.exceptions import HTTPException
from starlette.endpoints import HTTPEndpoint, Request
from starlette.responses import HTMLResponse, RedirectResponse
from starlette.responses import (
HTMLResponse,
RedirectResponse,
PlainTextResponse,
)
from starlette.authentication import requires
from starlette.status import HTTP_303_SEE_OTHER
from starlette.templating import Jinja2Templates

Expand All @@ -20,6 +25,23 @@
)


class SessionLogoutEndpoint(HTTPEndpoint):
@abstractproperty
def _session_table(self) -> t.Type[SessionsBase]:
raise NotImplementedError

@requires(scopes=["authenticated"], redirect="login")
async def post(self, request: Request) -> PlainTextResponse:
cookie = request.cookies.get("id", None)
breakpoint()
if not cookie:
raise HTTPException(
status_code=401, detail="The session cookie wasn't found."
)
await self._session_table.remove_session(token=cookie).run()
return PlainTextResponse("Successfully logged out")


class SessionLoginEndpoint(HTTPEndpoint):
@abstractproperty
def _auth_table(self) -> t.Type[BaseUser]:
Expand All @@ -36,7 +58,8 @@ def _expiry(self) -> timedelta:
@abstractproperty
def _redirect_to(self) -> str:
"""
Where to redirect to after login is successful.
Where to redirect to after login is successful. It's the name of a
Starlette route.
"""
raise NotImplementedError

Expand Down Expand Up @@ -85,7 +108,7 @@ async def post(self, request: Request) -> RedirectResponse:
if not self._production:
message = (
"If running sessions in production, make sure 'production' "
"is set to True"
"is set to True, and serve under HTTPS."
)
warnings.warn(message)

Expand Down Expand Up @@ -114,3 +137,12 @@ class _SessionLoginEndpoint(SessionLoginEndpoint):
_production = production

return _SessionLoginEndpoint


def session_logout(
session_table: SessionsBase,
) -> t.Type[SessionLogoutEndpoint]:
class _SessionLogoutEndpoint(SessionLogoutEndpoint):
_session_table = session_table

return _SessionLogoutEndpoint
2 changes: 1 addition & 1 deletion piccolo_api/session_auth/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def display_name(self) -> str:

@property
def identity(self) -> str:
return ""
return str(self.user_id)


class SessionsAuthBackend(AuthenticationBackend):
Expand Down
8 changes: 8 additions & 0 deletions piccolo_api/session_auth/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,11 @@ async def get_user_id(cls, token: str) -> t.Optional[int]:
@classmethod
def get_user_id_sync(cls, token: str) -> t.Optional[int]:
return async_to_sync(cls.get_user_id)(token)

@classmethod
async def remove_session(cls, token: str):
await cls.delete().where(cls.token == token).run()

@classmethod
def remove_session_sync(cls, token: str):
return async_to_sync(cls.remove_session)(token)
15 changes: 10 additions & 5 deletions tests/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from starlette.testclient import TestClient

from piccolo_api.session_auth.tables import SessionsBase
from piccolo_api.session_auth.endpoints import session_login
from piccolo_api.session_auth.endpoints import session_login, session_logout
from piccolo_api.session_auth.middleware import SessionsAuthBackend


Expand Down Expand Up @@ -42,21 +42,26 @@ def get(self, request):


class ProtectedEndpoint(HTTPEndpoint):
@requires("authenticated")
@requires("authenticated", redirect="login")
def get(self, request):
return PlainTextResponse("top secret")


ROUTER = Router(
routes=[
Route("/", HomeEndpoint),
Route("/", HomeEndpoint, name="home"),
Route(
"/login/", session_login(auth_table=User, session_table=Sessions)
"/login/",
session_login(auth_table=User, session_table=Sessions),
name="login",
),
Route(
"/logout/", session_logout(session_table=Sessions), name="login"
),
Mount(
"/secret",
AuthenticationMiddleware(
Route("/", ProtectedEndpoint),
ProtectedEndpoint,
SessionsAuthBackend(auth_table=User, session_table=Sessions),
),
),
Expand Down

0 comments on commit 888e94a

Please sign in to comment.