Skip to content

Commit

Permalink
support async validators in PiccoloCRUD (#203)
Browse files Browse the repository at this point in the history
* support async validators

* fix tests, and increase minimum fastapi / starlette version

* ignore httpx missing imports mypy error

* add explicit httpx dependency back

Starlette has this is an optional dependency, so still need to explicitly specify it.
  • Loading branch information
dantownsend committed Nov 15, 2022
1 parent 5c3569e commit deaf311
Show file tree
Hide file tree
Showing 7 changed files with 172 additions and 76 deletions.
63 changes: 47 additions & 16 deletions piccolo_api/crud/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,17 @@
import inspect
import typing as t

from piccolo.utils.sync import run_sync
from starlette.exceptions import HTTPException
from starlette.requests import Request

if t.TYPE_CHECKING: # pragma: no cover
from .endpoints import PiccoloCRUD


ValidatorFunction = t.Callable[["PiccoloCRUD", Request], None]
ValidatorFunction = t.Callable[
["PiccoloCRUD", Request], t.Union[t.Coroutine, None]
]


class Validators:
Expand All @@ -23,6 +26,24 @@ class Validators:
Starlette ``Request`` instance, and should raise a Starlette
``HTTPException`` if there is a problem.
Async functions are also supported. Here are some examples:
.. code-block:: python
def validator_1(piccolo_crud: PiccoloCRUD, request: Request):
if not request.user.user.superuser:
raise HTTPException(
status_code=403,
"Only a superuser can do this"
)
async def validator_2(piccolo_crud: PiccoloCRUD, request: Request):
if not await my_check_user_function(request.user.user):
raise HTTPException(
status_code=403,
"The user can't do this."
)
"""

def __init__(
Expand Down Expand Up @@ -64,7 +85,7 @@ def apply_validators(function):
:class:`PiccoloCRUD`.
"""

def run_validators(*args, **kwargs):
async def run_validators(*args, **kwargs):
piccolo_crud: PiccoloCRUD = args[0]
validators = piccolo_crud.validators

Expand All @@ -81,29 +102,39 @@ def run_validators(*args, **kwargs):
if validator_functions and request:
for validator_function in validator_functions:
try:
validator_function(
request=request,
piccolo_crud=piccolo_crud,
**validators.extra_context,
)
if inspect.iscoroutinefunction(validator_function):
await validator_function(
request=request,
piccolo_crud=piccolo_crud,
**validators.extra_context,
)
else:
validator_function(
request=request,
piccolo_crud=piccolo_crud,
**validators.extra_context,
)
except HTTPException as exception:
raise exception
except Exception:
raise HTTPException(
status_code=400, detail="Validation error"
)

@functools.wraps(function)
async def inner_coroutine_function(*args, **kwargs):
run_validators(*args, **kwargs)
return await function(*args, **kwargs)
if inspect.iscoroutinefunction(function):

@functools.wraps(function)
def inner_function(*args, **kwargs):
run_validators(*args, **kwargs)
return function(*args, **kwargs)
@functools.wraps(function)
async def inner_coroutine_function(*args, **kwargs):
await run_validators(*args, **kwargs)
return await function(*args, **kwargs)

if inspect.iscoroutinefunction(function):
return inner_coroutine_function

else:

@functools.wraps(function)
def inner_function(*args, **kwargs):
run_sync(run_validators(*args, **kwargs))
return function(*args, **kwargs)

return inner_function
2 changes: 1 addition & 1 deletion piccolo_api/csp/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ async def wrapped_send(message: Message):
+ b"; report-uri "
+ self.config.report_uri
)
headers.append([b"Content-Security-Policy", header_value])
headers.append([b"content-security-policy", header_value])
message["headers"] = headers

await send(message)
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ module = [
"moto",
"botocore",
"botocore.config",
"httpx"
]
ignore_missing_imports = true

Expand Down
2 changes: 1 addition & 1 deletion requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ Jinja2>=2.11.0
piccolo[postgres]>=0.89.0
pydantic[email]>=1.6
python-multipart>=0.0.5
fastapi>=0.65.2
fastapi>=0.87.0
PyJWT>=2.0.0
httpx>=0.20.0
55 changes: 41 additions & 14 deletions tests/crud/test_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,18 +48,37 @@ def validator_1(*args, **kwargs):
def validator_2(*args, **kwargs):
raise HTTPException(status_code=401, detail="Denied!")

for scenario in [
Scenario(
validators=[validator_1],
status_code=400,
content=b"Validation error",
),
Scenario(
validators=[validator_2],
status_code=401,
content=b"Denied!",
),
]:
async def validator_3(*args, **kwargs):
raise ValueError("Error!")

async def validator_4(*args, **kwargs):
raise HTTPException(status_code=401, detail="Async denied!")

for index, scenario in enumerate(
[
Scenario(
validators=[validator_1],
status_code=400,
content=b"Validation error",
),
Scenario(
validators=[validator_2],
status_code=401,
content=b"Denied!",
),
Scenario(
validators=[validator_3],
status_code=400,
content=b"Validation error",
),
Scenario(
validators=[validator_4],
status_code=401,
content=b"Async denied!",
),
],
start=1,
):
client = TestClient(
ExceptionMiddleware(
PiccoloCRUD(
Expand All @@ -71,5 +90,13 @@ def validator_2(*args, **kwargs):
)

response = client.get("/")
self.assertEqual(response.status_code, scenario.status_code)
self.assertEqual(response.content, scenario.content)
self.assertEqual(
response.status_code,
scenario.status_code,
msg=f"Scenario {index} failed!",
)
self.assertEqual(
response.content,
scenario.content,
msg=f"Scenario {index} failed!",
)
8 changes: 5 additions & 3 deletions tests/csp/test_csp.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,13 @@ def test_headers(self):
client = TestClient(wrapped_app)
response = client.request("GET", "/")

header_names = response.headers.keys()

# Make sure the headers got added:
self.assertTrue("Content-Security-Policy" in response.headers.keys())
self.assertTrue("content-security-policy" in header_names)

# Make sure the original headers are still intact:
self.assertTrue("content-type" in response.headers.keys())
self.assertTrue("content-type" in header_names)

def test_report_uri(self):
wrapped_app = CSPMiddleware(
Expand All @@ -44,5 +46,5 @@ def test_report_uri(self):
client = TestClient(wrapped_app)
response = client.request("GET", "/")

header = response.headers["Content-Security-Policy"]
header = response.headers["content-security-policy"]
self.assertTrue("report-uri" in header)

0 comments on commit deaf311

Please sign in to comment.