Skip to content

Commit

Permalink
add db_exception_handler (#187)
Browse files Browse the repository at this point in the history
* add `db_exception_handler`

* add link to original GitHub issue

* add functools.wraps

* add logging

* added tests
  • Loading branch information
dantownsend committed Sep 2, 2022
1 parent 10d9834 commit 4b14eba
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 4 deletions.
5 changes: 4 additions & 1 deletion piccolo_api/crud/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
execute_post_hooks,
)

from .exceptions import MalformedQuery
from .exceptions import MalformedQuery, db_exception_handler
from .serializers import Config, create_pydantic_model
from .validators import Validators, apply_validators

Expand Down Expand Up @@ -794,6 +794,7 @@ def _clean_data(self, data: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
return cleaned_data

@apply_validators
@db_exception_handler
async def post_single(
self, request: Request, data: t.Dict[str, t.Any]
) -> Response:
Expand Down Expand Up @@ -1005,6 +1006,7 @@ async def get_single(self, request: Request, row_id: PK_TYPES) -> Response:
)

@apply_validators
@db_exception_handler
async def put_single(
self, request: Request, row_id: PK_TYPES, data: t.Dict[str, t.Any]
) -> Response:
Expand Down Expand Up @@ -1034,6 +1036,7 @@ async def put_single(
return Response("Unable to save the resource.", status_code=500)

@apply_validators
@db_exception_handler
async def patch_single(
self, request: Request, row_id: PK_TYPES, data: t.Dict[str, t.Any]
) -> Response:
Expand Down
60 changes: 60 additions & 0 deletions piccolo_api/crud/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,23 @@
import functools
import logging
import typing as t
from sqlite3 import IntegrityError

from starlette.responses import JSONResponse

try:
# We can't be sure that asyncpg is installed, hence why it's in a
# try / except.
from asyncpg.exceptions import UniqueViolationError
except ImportError:

class UniqueViolationError(Exception): # type: ignore
pass


logger = logging.getLogger(__file__)


class MalformedQuery(Exception):
"""
Raised when the query is malformed - for example, the column names are
Expand All @@ -6,3 +26,43 @@ class MalformedQuery(Exception):
"""

pass


def db_exception_handler(func: t.Callable[..., t.Coroutine]):
"""
A decorator which wraps an endpoint, and converts database exceptions
into HTTP responses.
Eventually we will add generic database exceptions to Piccolo, so each
database adapter raises the same exceptions.
For now though, we handle the exceptions from each database adapter.
It's very important that we catch unique constraint errors, as these are
very commmon, and make a poor user experience if the user just sees a
generic 500 error instead of a useful message like 'Field X is not unique'.
https://github.com/piccolo-orm/piccolo_admin/issues/167
"""

@functools.wraps(func)
async def inner(*args, **kwargs):
try:
return await func(*args, **kwargs)
except IntegrityError as exception:
logger.exception("SQLite integrity error")
return JSONResponse(
{"db_error": exception.__str__()},
status_code=422,
)
except UniqueViolationError as exception:
logger.exception("Asyncpg unique violation")
return JSONResponse(
{
"db_error": exception.message,
},
status_code=422,
)

return inner
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ line_length = 79
[[tool.mypy.overrides]]
module = [
"asyncpg.pgproto.pgproto",
"asyncpg.exceptions",
"jinja2",
"uvicorn",
"jwt",
Expand Down
102 changes: 99 additions & 3 deletions tests/crud/test_crud_endpoints.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from enum import Enum
from unittest import TestCase

from piccolo.columns import Email, ForeignKey, Integer, Secret, Varchar
from piccolo.columns import Email, ForeignKey, Integer, Secret, Text, Varchar
from piccolo.columns.readable import Readable
from piccolo.table import Table
from starlette.datastructures import QueryParams
Expand Down Expand Up @@ -39,6 +39,11 @@ class Studio(Table):
booking_email = Email(default="booking@studio.com")


class Cinema(Table):
name = Varchar()
address = Text(unique=True)


class TestGetVisibleFieldsOptions(TestCase):
def test_without_joins(self):
response = get_visible_fields_options(table=Role, max_joins=0)
Expand Down Expand Up @@ -880,7 +885,7 @@ def setUp(self):
def tearDown(self):
Movie.alter().drop_table().run_sync()

def test_post(self):
def test_success(self):
"""
Make sure a post can create rows successfully.
"""
Expand All @@ -897,7 +902,7 @@ def test_post(self):
self.assertTrue(movie.name == json["name"])
self.assertTrue(movie.rating == json["rating"])

def test_post_error(self):
def test_validation_error(self):
"""
Make sure a post returns a validation error with incorrect or missing
data.
Expand All @@ -911,6 +916,97 @@ def test_post_error(self):
self.assertTrue(Movie.count().run_sync() == 0)


class TestDBExceptionHandler(TestCase):
"""
Make sure that if a unique constraint fails, we get a useful message
back, and not a 500 error.
"""

def setUp(self):
Cinema.create_table(if_not_exists=True).run_sync()

self.cinema_1 = (
Cinema.objects()
.create(
name="Odeon",
address="Leicester Square, London",
)
.run_sync()
)

self.cinema_2 = (
Cinema.objects()
.create(
name="Grauman's Chinese Theatre",
address="6925 Hollywood Boulevard, Hollywood",
)
.run_sync()
)

def tearDown(self):
Cinema.alter().drop_table().run_sync()

def test_post(self):
client = TestClient(PiccoloCRUD(table=Cinema, read_only=False))

# Test error
response = client.post(
"/",
json={"name": "Odeon 2", "address": self.cinema_1.address},
)
self.assertEqual(response.status_code, 422)
self.assertTrue("db_error" in response.json())

# Test success
response = client.post(
"/",
json={"name": "Odeon 2", "address": "A new address"},
)
self.assertEqual(response.status_code, 201)

def test_patch(self):
client = TestClient(PiccoloCRUD(table=Cinema, read_only=False))

# Test error
response = client.patch(
f"/{self.cinema_1.id}/",
json={"address": self.cinema_2.address},
)
self.assertEqual(response.status_code, 422)
self.assertTrue("db_error" in response.json())

# Test success
response = client.patch(
f"/{self.cinema_1.id}/",
json={"address": "A new address"},
)
self.assertEqual(response.status_code, 200)

def test_put(self):
client = TestClient(PiccoloCRUD(table=Cinema, read_only=False))

# Test error
response = client.put(
f"/{self.cinema_1.id}/",
json={
"name": self.cinema_1.name,
"address": self.cinema_2.address,
},
)
self.assertEqual(response.status_code, 422)
self.assertTrue("db_error" in response.json())

# Test success
response = client.put(
f"/{self.cinema_1.id}/",
json={
"name": "New cinema",
"address": "A new address",
},
)
self.assertEqual(response.status_code, 204)


class TestGet(TestCase):
def setUp(self):
for table in (Movie, Role):
Expand Down

0 comments on commit 4b14eba

Please sign in to comment.