Skip to content

Commit

Permalink
added LoginHooks (#138)
Browse files Browse the repository at this point in the history
* added `LoginHooks`

* Update doc-requirements.txt

* add `html_short_title` to conf.py

* add docs for `LoginHooks`

* add methods for running hooks

* call hooks

* add suggested changes, and change login_success signature

Was passing `BaseUser` to `login_success` hook, but could be wasteful because it requires an extra SQL query, and the end user might not even need it (username and user_id might be sufficient).

* fix example
  • Loading branch information
dantownsend committed Jun 3, 2022
1 parent 7e24d16 commit 0329afd
Show file tree
Hide file tree
Showing 7 changed files with 281 additions and 14 deletions.
9 changes: 9 additions & 0 deletions docs/source/api_reference/index.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
API Reference
=============

Auth
----

.. currentmodule:: piccolo_api.shared.auth.hooks

.. autoclass:: LoginHooks
9 changes: 7 additions & 2 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,12 @@
autodoc_typehints = "signature"
autodoc_typehints_format = "short"
autoclass_content = "both"
autodoc_type_aliases = {"ASGIApp": "ASGIApp"}
html_short_title = "Piccolo API"
autodoc_type_aliases = {
"ASGIApp": "ASGIApp",
"PreLoginHook": "PreLoginHook",
"LoginSuccessHook": "LoginSuccessHook",
"LoginFailureHook": "LoginFailureHook",
}

# -- Intersphinx -------------------------------------------------------------

Expand All @@ -60,3 +64,4 @@
# a list of builtin themes.
#
html_theme = "piccolo_theme"
html_short_title = "Piccolo API"
6 changes: 6 additions & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ ASGI app, covering authentication, security, and more.

./contributing/index

.. toctree::
:caption: API Reference
:maxdepth: 1

./api_reference/index

.. toctree::
:caption: Changes
:maxdepth: 1
Expand Down
79 changes: 70 additions & 9 deletions piccolo_api/session_auth/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from starlette.status import HTTP_303_SEE_OTHER

from piccolo_api.session_auth.tables import SessionsBase
from piccolo_api.shared.auth.hooks import LoginHooks

if t.TYPE_CHECKING: # pragma: no cover
from jinja2 import Template
Expand Down Expand Up @@ -51,7 +52,7 @@ def _redirect_to(self) -> t.Optional[str]:
def _logout_template(self) -> Template:
raise NotImplementedError

def render_template(
def _render_template(
self, request: Request, template_context: t.Dict[str, t.Any] = {}
) -> HTMLResponse:
# If CSRF middleware is present, we have to include a form field with
Expand All @@ -71,7 +72,7 @@ def render_template(
)

async def get(self, request: Request) -> HTMLResponse:
return self.render_template(request)
return self._render_template(request)

async def post(self, request: Request) -> Response:
cookie = request.cookies.get(self._cookie_name, None)
Expand Down Expand Up @@ -131,8 +132,15 @@ def _production(self) -> bool:
def _login_template(self) -> Template:
raise NotImplementedError

def render_template(
self, request: Request, template_context: t.Dict[str, t.Any] = {}
@abstractproperty
def _hooks(self) -> t.Optional[LoginHooks]:
raise NotImplementedError

def _render_template(
self,
request: Request,
template_context: t.Dict[str, t.Any] = {},
status_code=200,
) -> HTMLResponse:
# If CSRF middleware is present, we have to include a form field with
# the CSRF token. It only works if CSRFMiddleware has
Expand All @@ -147,11 +155,24 @@ def render_template(
csrf_cookie_name=csrf_cookie_name,
request=request,
**template_context,
)
),
status_code=status_code,
)

def _get_error_response(
self, request, error: str, response_format: t.Literal["html", "plain"]
) -> Response:
if response_format == "html":
return self._render_template(
request, template_context={"error": error}, status_code=401
)
else:
return PlainTextResponse(
status_code=401, content=f"Login failed: {error}"
)

async def get(self, request: Request) -> HTMLResponse:
return self.render_template(request)
return self._render_template(request)

async def post(self, request: Request) -> Response:
# Some middleware (for example CSRF) has already awaited the request
Expand All @@ -166,19 +187,54 @@ async def post(self, request: Request) -> Response:

username = body.get("username", None)
password = body.get("password", None)
return_html = body.get("format") == "html"

if (not username) or (not password):
raise HTTPException(
status_code=401, detail="Missing username or password"
)

# Run pre_login hooks
if self._hooks and self._hooks.pre_login:
hooks_response = await self._hooks.run_pre_login(username=username)
if isinstance(hooks_response, str):
return self._get_error_response(
request=request,
error=hooks_response,
response_format="html" if return_html else "plain",
)

user_id = await self._auth_table.login(
username=username, password=password
)

if not user_id:
if body.get("format") == "html":
return self.render_template(
if user_id:
# Run login_success hooks
if self._hooks and self._hooks.login_success:
hooks_response = await self._hooks.run_login_success(
username=username, user_id=user_id
)
if isinstance(hooks_response, str):
return self._get_error_response(
request=request,
error=hooks_response,
response_format="html" if return_html else "plain",
)
else:
# Run login_failure hooks
if self._hooks and self._hooks.login_failure:
hooks_response = await self._hooks.run_login_failure(
username=username
)
if isinstance(hooks_response, str):
return self._get_error_response(
request=request,
error=hooks_response,
response_format="html" if return_html else "plain",
)

if return_html:
return self._render_template(
request,
template_context={
"error": "The username or password is incorrect."
Expand Down Expand Up @@ -235,6 +291,7 @@ def session_login(
production: bool = False,
cookie_name: str = "id",
template_path: t.Optional[str] = None,
hooks: t.Optional[LoginHooks] = None,
) -> t.Type[SessionLoginEndpoint]:
"""
An endpoint for creating a user session.
Expand Down Expand Up @@ -266,6 +323,9 @@ def session_login(
``'/some_directory/login.html'``. Refer to the default template at
``piccolo_api/templates/session_login.html`` as a basis for your
custom template.
:param hooks:
Allows you to run custom logic at various points in the login process.
See :class:`LoginHooks <piccolo_api.shared.auth.hooks.LoginHooks>`.
""" # noqa: E501
template_path = (
Expand All @@ -285,6 +345,7 @@ class _SessionLoginEndpoint(SessionLoginEndpoint):
_production = production
_cookie_name = cookie_name
_login_template = login_template
_hooks = hooks

return _SessionLoginEndpoint

Expand Down
135 changes: 135 additions & 0 deletions piccolo_api/shared/auth/hooks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
from __future__ import annotations

import dataclasses
import inspect
import logging
import typing as t

PreLoginHook = t.Union[
t.Callable[[str], t.Optional[str]],
t.Callable[[str], t.Awaitable[t.Optional[str]]],
]
LoginSuccessHook = t.Union[
t.Callable[[str, int], t.Optional[str]],
t.Callable[[str, int], t.Awaitable[t.Optional[str]]],
]
LoginFailureHook = t.Union[
t.Callable[[str], t.Optional[str]],
t.Callable[[str], t.Awaitable[t.Optional[str]]],
]


logger = logging.getLogger(__file__)


@dataclasses.dataclass
class LoginHooks:
"""
Allows you to run custom logic during login. A hook can be a function or
coroutine.
Here's an example using :class:`session_login <piccolo_api.session_auth.endpoints.session_login>`:
.. code-block:: python
def check_ban_list(username: str, **kwargs):
'''
An example `pre_login` hook.
'''
if username in ('nuisance', 'pest'):
return 'This account has been temporarily suspended'.
async def log_success(username: str, user_id: int, **kwargs):
'''
An example `login_success` hook.
'''
await my_logging_service.record(
f'{username} just logged in'
)
async def log_failure(username: str, **kwargs):
'''
An example `login_failure` hook.
'''
await my_logging_service.record(f'{username} could not login')
return (
'To reset your password go <a href="/password-reset/">here</a>.'
)
login_endpoint = session_login(
hooks=LoginHooks(
pre_login=[check_ban_list],
login_success=[log_success],
login_failure=[log_failure],
)
)
If any of the hooks return a string, the login process is aborted, and the
login template is shown again, containing the string as a warning message.
The string can contain HTML such as links, and it will be rendered
correctly.
All of the example hooks above accept ``**kwargs`` - this is recommended
just in case more data is passed to the hooks in future Piccolo API
versions.
:param pre_login:
A list of function and / or coroutines, which accept the username as a
string.
:param login_success:
A list of function and / or coroutines, which accept the username as a
string, and the user ID as an integer. If a string is returned, the
login process stops before a session is created.
:param login_failure:
A list of function and / or coroutines, which accept the username as a
string.
""" # noqa: E501

pre_login: t.Optional[t.List[PreLoginHook]] = None
login_success: t.Optional[t.List[LoginSuccessHook]] = None
login_failure: t.Optional[t.List[LoginFailureHook]] = None

async def run_pre_login(self, username: str) -> t.Optional[str]:
if self.pre_login:
for hook in self.pre_login:
response = hook(username)
if inspect.isawaitable(response):
response = t.cast(t.Awaitable, response)
response = await response

if isinstance(response, str):
return response

return None

async def run_login_success(
self, username: str, user_id: int
) -> t.Optional[str]:
if self.login_success:
for hook in self.login_success:
response = hook(username, user_id)
if inspect.isawaitable(response):
response = t.cast(t.Awaitable, response)
response = await response

if isinstance(response, str):
return response

return None

async def run_login_failure(self, username: str) -> t.Optional[str]:
if self.login_failure:
for hook in self.login_failure:
response = hook(username)
if inspect.isawaitable(response):
response = t.cast(t.Awaitable, response)
response = await response

if isinstance(response, str):
return response

return None
5 changes: 2 additions & 3 deletions requirements/doc-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
Sphinx==4.4.0
piccolo-theme>=0.5.0
sphinx-rtd-theme==1.0.0
Sphinx==4.5.0
piccolo-theme>=0.9.0
livereload==2.6.3

0 comments on commit 0329afd

Please sign in to comment.