Skip to content

Commit

Permalink
Request Context Inside Hook Callbacks (#167)
Browse files Browse the repository at this point in the history
* Found it could be useful to have access to the request context inside of hooks.
Here is an implementation that gives that option. Wrote tests for it. Hope this is suitable.

* Passing request context to hook callbacks. Here is an implementation that gives that option. Wrote tests for it. Hope this is suitable.
Ran lint script

* Fixed the helper function name in test_hooks

* fix mypy error

* add docs for dependency injection

Co-authored-by: Daniel Townsend <dan@dantownsend.co.uk>
  • Loading branch information
AnthonyArmour and dantownsend committed Aug 8, 2022
1 parent c79baf6 commit b22a618
Show file tree
Hide file tree
Showing 4 changed files with 168 additions and 10 deletions.
13 changes: 12 additions & 1 deletion docs/source/crud/hooks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ pre_save
~~~~~~~~

This hook runs during POST requests, prior to inserting data into the database.
It takes a single parameter, ``row``, and should return the same:
It takes a single parameter, ``row``, and should return the row:

.. code-block:: python
Expand Down Expand Up @@ -131,6 +131,17 @@ It takes one parameter, ``row_id`` which is the id of the row to be deleted.
]
)
Dependency injection
~~~~~~~~~~~~~~~~~~~~

Each hook can optionally receive the ``Starlette`` request object. Just
add ``request`` as an argument in your hook, and it'll be injected automatically.

.. code-block:: python
async def set_movie_rating_10(row: Movie, request: Request):
...
-------------------------------------------------------------------------------

Source
Expand Down
7 changes: 6 additions & 1 deletion piccolo_api/crud/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,7 +810,10 @@ async def post_single(
row = self.table(**model.dict())
if self._hook_map:
row = await execute_post_hooks(
hooks=self._hook_map, hook_type=HookType.pre_save, row=row
hooks=self._hook_map,
hook_type=HookType.pre_save,
row=row,
request=request,
)
response = await row.save().run()
json = dump_json(response)
Expand Down Expand Up @@ -1054,6 +1057,7 @@ async def patch_single(
hook_type=HookType.pre_patch,
row_id=row_id,
values=values,
request=request,
)

try:
Expand Down Expand Up @@ -1083,6 +1087,7 @@ async def delete_single(
hooks=self._hook_map,
hook_type=HookType.pre_delete,
row_id=row_id,
request=request,
)

try:
Expand Down
45 changes: 37 additions & 8 deletions piccolo_api/crud/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from enum import Enum

from piccolo.table import Table
from starlette.requests import Request


class HookType(Enum):
Expand All @@ -22,13 +23,23 @@ def __init__(self, hook_type: HookType, callable: t.Callable) -> None:


async def execute_post_hooks(
hooks: t.Dict[HookType, t.List[Hook]], hook_type: HookType, row: Table
hooks: t.Dict[HookType, t.List[Hook]],
hook_type: HookType,
row: Table,
request: Request,
):
for hook in hooks.get(hook_type, []):
signature = inspect.signature(hook.callable)
kwargs: t.Dict[str, t.Any] = dict(row=row)
# Include request in hook call arguments if possible
if {i for i in signature.parameters.keys()}.intersection(
{"kwargs", "request"}
):
kwargs.update(request=request)
if inspect.iscoroutinefunction(hook.callable):
row = await hook.callable(row)
row = await hook.callable(**kwargs)
else:
row = hook.callable(row)
row = hook.callable(**kwargs)
return row


Expand All @@ -37,20 +48,38 @@ async def execute_patch_hooks(
hook_type: HookType,
row_id: t.Any,
values: t.Dict[t.Any, t.Any],
request: Request,
) -> t.Dict[t.Any, t.Any]:
for hook in hooks.get(hook_type, []):
signature = inspect.signature(hook.callable)
kwargs = dict(row_id=row_id, values=values)
# Include request in hook call arguments if possible
if {i for i in signature.parameters.keys()}.intersection(
{"kwargs", "request"}
):
kwargs.update(request=request)
if inspect.iscoroutinefunction(hook.callable):
values = await hook.callable(row_id=row_id, values=values)
values = await hook.callable(**kwargs)
else:
values = hook.callable(row_id=row_id, values=values)
values = hook.callable(**kwargs)
return values


async def execute_delete_hooks(
hooks: t.Dict[HookType, t.List[Hook]], hook_type: HookType, row_id: t.Any
hooks: t.Dict[HookType, t.List[Hook]],
hook_type: HookType,
row_id: t.Any,
request: Request,
):
for hook in hooks.get(hook_type, []):
signature = inspect.signature(hook.callable)
kwargs = dict(row_id=row_id)
# Include request in hook call arguments if possible
if {i for i in signature.parameters.keys()}.intersection(
{"kwargs", "request"}
):
kwargs.update(request=request)
if inspect.iscoroutinefunction(hook.callable):
await hook.callable(row_id=row_id)
await hook.callable(**kwargs)
else:
hook.callable(row_id=row_id)
hook.callable(**kwargs)
113 changes: 113 additions & 0 deletions tests/crud/test_hooks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from unittest import TestCase

from fastapi import Request
from piccolo.columns import Integer, Varchar
from piccolo.columns.readable import Readable
from piccolo.table import Table
Expand Down Expand Up @@ -39,17 +40,62 @@ async def look_up_existing(row_id: int, values: dict):
return values


async def add_additional_name_details(
row_id: int, values: dict, request: Request
):
director = request.query_params.get("director_name", "")
values["name"] = values["name"] + f" ({director})"
return values


async def additional_name_details(row: Movie, request: Request):
director = request.query_params.get("director_name", "")
row["name"] = f"{row.name} ({director})"
return row


async def raises_exception(row_id: int, request: Request):
if request.query_params.get("director_name", False):
raise Exception("Test Passed")


async def failing_hook(row_id: int):
raise Exception("hook failed")


# TODO - add test for a non-async hook.
class TestPostHooks(TestCase):
def setUp(self):
Movie.create_table(if_not_exists=True).run_sync()

def tearDown(self):
Movie.alter().drop_table().run_sync()

def test_request_context_passed_to_post_hook(self):
"""
Make sure request context can be passed to post hook
callable
"""
client = TestClient(
PiccoloCRUD(
table=Movie,
read_only=False,
hooks=[
Hook(
hook_type=HookType.pre_save,
callable=additional_name_details,
)
],
)
)
json_req = {
"name": "Star Wars",
"rating": 93,
}
_ = client.post("/", json=json_req, params={"director_name": "George"})
movie = Movie.objects().first().run_sync()
self.assertEqual(movie.name, "Star Wars (George)")

def test_single_pre_post_hook(self):
"""
Make sure single hook executes
Expand Down Expand Up @@ -96,6 +142,47 @@ def test_multi_pre_post_hooks(self):
movie = Movie.objects().first().run_sync()
self.assertEqual(movie.rating, 20)

def test_request_context_passed_to_patch_hook(self):
"""
Make sure request context can be passed to patch hook
callable
"""
client = TestClient(
PiccoloCRUD(
table=Movie,
read_only=False,
hooks=[
Hook(
hook_type=HookType.pre_patch,
callable=add_additional_name_details,
)
],
)
)

movie = Movie(name="Star Wars", rating=93)
movie.save().run_sync()

new_name = "Star Wars: A New Hope"
new_name_modified = new_name + " (George)"

json_req = {
"name": new_name,
}

response = client.patch(
f"/{movie.id}/", json=json_req, params={"director_name": "George"}
)
self.assertTrue(response.status_code == 200)

# Make sure the row is returned:
response_json = response.json()
self.assertTrue(response_json["name"] == new_name_modified)

# Make sure the underlying database row was changed:
movies = Movie.select().run_sync()
self.assertTrue(movies[0]["name"] == new_name_modified)

def test_pre_patch_hook(self):
"""
Make sure pre_patch hook executes successfully
Expand Down Expand Up @@ -159,6 +246,32 @@ def test_pre_patch_hook_db_lookup(self):
movies = Movie.select().run_sync()
self.assertTrue(movies[0]["name"] == original_name)

def test_request_context_passed_to_delete_hook(self):
"""
Make sure request context can be passed to patch hook
callable
"""
client = TestClient(
PiccoloCRUD(
table=Movie,
read_only=False,
hooks=[
Hook(
hook_type=HookType.pre_delete,
callable=raises_exception,
)
],
)
)

movie = Movie(name="Star Wars", rating=10)
movie.save().run_sync()

with self.assertRaises(Exception, msg="Test Passed"):
_ = client.delete(
f"/{movie.id}/", params={"director_name": "George"}
)

def test_delete_hook_fails(self):
"""
Make sure failing pre_delete hook bubbles up
Expand Down

0 comments on commit b22a618

Please sign in to comment.