Skip to content

Commit

Permalink
Fixes for custom pk (#130)
Browse files Browse the repository at this point in the history
* make ids endpoint work with UUID primary key columns

* more custom pk fixes

The detail endpoints now work if non-integers are passed in

* more tests

* more tests

* make sure custom primary keys are excluded from pydantic model

* add a test for invalid ID
  • Loading branch information
dantownsend committed Jan 11, 2022
1 parent 26d0adc commit 83fe185
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 20 deletions.
39 changes: 27 additions & 12 deletions piccolo_api/crud/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import itertools
import logging
import typing as t
import uuid
from collections import defaultdict
from dataclasses import dataclass, field

Expand Down Expand Up @@ -59,6 +60,8 @@

MATCH_TYPES = ("contains", "exact", "starts", "ends")

PK_TYPES = t.Union[str, uuid.UUID, int]


class CustomJSONResponse(Response):
media_type = "application/json"
Expand Down Expand Up @@ -229,13 +232,6 @@ def __init__(

routes: t.List[BaseRoute] = [
Route(path="/", endpoint=self.root, methods=root_methods),
Route(
path="/{row_id:int}/",
endpoint=self.detail,
methods=["GET"]
if read_only
else ["GET", "PUT", "DELETE", "PATCH"],
),
Route(path="/schema/", endpoint=self.get_schema, methods=["GET"]),
Route(path="/ids/", endpoint=self.get_ids, methods=["GET"]),
Route(path="/count/", endpoint=self.get_count, methods=["GET"]),
Expand All @@ -250,6 +246,13 @@ def __init__(
endpoint=self.update_password,
methods=["PUT"],
),
Route(
path="/{row_id:str}/",
endpoint=self.detail,
methods=["GET"]
if read_only
else ["GET", "PUT", "DELETE", "PATCH"],
),
]

super().__init__(routes=routes)
Expand All @@ -264,6 +267,7 @@ def pydantic_model(self) -> t.Type[pydantic.BaseModel]:
return create_pydantic_model(
self.table,
model_name=f"{self.table.__name__}In",
exclude_columns=(self.table._meta.primary_key,),
**self.schema_extra,
)

Expand Down Expand Up @@ -415,7 +419,11 @@ async def get_ids(self, request: Request) -> Response:
query = query.limit(limit).offset(offset)

values = await query.run()
return JSONResponse({i["id"]: i["readable"] for i in values})

if self.table._meta.primary_key.value_type not in (int, str):
return JSONResponse({str(i["id"]): i["readable"] for i in values})
else:
return JSONResponse({i["id"]: i["readable"] for i in values})

###########################################################################

Expand Down Expand Up @@ -852,6 +860,11 @@ async def detail(self, request: Request) -> Response:
if row_id is None:
return Response("Missing ID parameter.", status_code=404)

try:
row_id = self.table._meta.primary_key.value_type(row_id)
except ValueError:
return Response("The ID is invalid", status_code=400)

if (
not await self.table.exists()
.where(self.table._meta.primary_key == row_id)
Expand Down Expand Up @@ -910,7 +923,7 @@ def _parse_visible_fields(self, visible_fields: str) -> t.List[Column]:
return visible_columns

@apply_validators
async def get_single(self, request: Request, row_id: int) -> Response:
async def get_single(self, request: Request, row_id: PK_TYPES) -> Response:
"""
Returns a single row.
"""
Expand Down Expand Up @@ -973,7 +986,7 @@ async def get_single(self, request: Request, row_id: int) -> Response:

@apply_validators
async def put_single(
self, request: Request, row_id: int, data: t.Dict[str, t.Any]
self, request: Request, row_id: PK_TYPES, data: t.Dict[str, t.Any]
) -> Response:
"""
Replaces an existing row. We don't allow new resources to be created.
Expand Down Expand Up @@ -1002,7 +1015,7 @@ async def put_single(

@apply_validators
async def patch_single(
self, request: Request, row_id: int, data: t.Dict[str, t.Any]
self, request: Request, row_id: PK_TYPES, data: t.Dict[str, t.Any]
) -> Response:
"""
Patch a single row.
Expand Down Expand Up @@ -1049,7 +1062,9 @@ async def patch_single(
return Response("Unable to save the resource.", status_code=500)

@apply_validators
async def delete_single(self, request: Request, row_id: int) -> Response:
async def delete_single(
self, request: Request, row_id: PK_TYPES
) -> Response:
"""
Deletes a single row.
"""
Expand Down
16 changes: 8 additions & 8 deletions piccolo_api/fastapi/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,14 +283,14 @@ async def post(request: Request, model):
#######################################################################
# Detail - GET

async def get_single(row_id: int, request: Request):
async def get_single(row_id: str, request: Request):
"""
Retrieve a single row from the table.
"""
return await piccolo_crud.detail(request=request)

fastapi_app.add_api_route(
path=self.join_urls(root_url, "/{row_id:int}/"),
path=self.join_urls(root_url, "/{row_id:str}/"),
endpoint=get_single,
response_model=self.ModelOut,
methods=["GET"],
Expand All @@ -302,14 +302,14 @@ async def get_single(row_id: int, request: Request):

if not piccolo_crud.read_only:

async def delete_single(row_id: int, request: Request):
async def delete_single(row_id: str, request: Request):
"""
Delete a single row from the table.
"""
return await piccolo_crud.detail(request=request)

fastapi_app.add_api_route(
path=self.join_urls(root_url, "/{row_id:int}/"),
path=self.join_urls(root_url, "/{row_id:str}/"),
endpoint=delete_single,
response_model=None,
methods=["DELETE"],
Expand All @@ -321,7 +321,7 @@ async def delete_single(row_id: int, request: Request):

if not piccolo_crud.read_only:

async def put(row_id: int, request: Request, model):
async def put(row_id: str, request: Request, model):
"""
Insert or update a single row.
"""
Expand All @@ -332,7 +332,7 @@ async def put(row_id: int, request: Request, model):
] = f"ANNOTATIONS['{self.alias}']['ModelIn']"

fastapi_app.add_api_route(
path=self.join_urls(root_url, "/{row_id:int}/"),
path=self.join_urls(root_url, "/{row_id:str}/"),
endpoint=put,
response_model=self.ModelOut,
methods=["PUT"],
Expand All @@ -344,7 +344,7 @@ async def put(row_id: int, request: Request, model):

if not piccolo_crud.read_only:

async def patch(row_id: int, request: Request, model):
async def patch(row_id: str, request: Request, model):
"""
Update a single row.
"""
Expand All @@ -355,7 +355,7 @@ async def patch(row_id: int, request: Request, model):
] = f"ANNOTATIONS['{self.alias}']['ModelOptional']"

fastapi_app.add_api_route(
path=self.join_urls(root_url, "/{row_id:int}/"),
path=self.join_urls(root_url, "/{row_id:str}/"),
endpoint=patch,
response_model=self.ModelOut,
methods=["PATCH"],
Expand Down
92 changes: 92 additions & 0 deletions tests/crud/test_custom_pk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from unittest import TestCase

from piccolo.columns.column_types import UUID, Integer, Varchar
from piccolo.table import Table
from starlette.testclient import TestClient

from piccolo_api.crud.endpoints import PiccoloCRUD


class Movie(Table):
id = UUID(primary_key=True)
name = Varchar(length=100, required=True)
rating = Integer()


class TestCustomPK(TestCase):
"""
Make sure PiccoloCRUD works with Tables with a custom primary key column.
"""

def setUp(self):
Movie.create_table(if_not_exists=True).run_sync()
self.movie = Movie.objects().create(name="Star Wars").run_sync()
self.client = TestClient(PiccoloCRUD(table=Movie, read_only=False))

def tearDown(self):
Movie.alter().drop_table().run_sync()

def test_get_ids(self):
response = self.client.get("/ids/")
self.assertEqual(response.status_code, 200)
self.assertEqual(
response.json(), {str(self.movie.id): str(self.movie.id)}
)

def test_get_list(self):
response = self.client.get("/")
self.assertEqual(response.status_code, 200)
self.assertEqual(
response.json(),
{
"rows": [
{
"id": str(self.movie.id),
"name": self.movie.name,
"rating": self.movie.rating,
}
]
},
)

def test_get_single(self):
response = self.client.get(f"/{str(self.movie.id)}")
self.assertEqual(response.status_code, 200)
self.assertEqual(
response.json(),
{
"id": str(self.movie.id),
"name": self.movie.name,
"rating": self.movie.rating,
},
)

def test_post(self):
Movie.delete(force=True).run_sync()
response = self.client.post(
"/", json={"name": "Lord of the Rings", "rating": 1000}
)
self.assertEqual(response.status_code, 201)

movie = Movie.select(Movie.name, Movie.rating).first().run_sync()
self.assertEqual(movie, {"name": "Lord of the Rings", "rating": 1000})

def test_delete(self):
response = self.client.delete(f"/{self.movie.id}/")
self.assertEqual(response.status_code, 204)
self.assertEqual(Movie.count().run_sync(), 0)

def test_patch(self):
response = self.client.patch(
f"/{self.movie.id}/", json={"rating": 2000}
)
self.assertEqual(response.status_code, 200)
movie = Movie.select().first().run_sync()
self.assertEqual(
movie, {"id": self.movie.id, "name": "Star Wars", "rating": 2000}
)

def test_invalid_id(self):
response = self.client.get("/abc123/")
self.assertEqual(response.status_code, 400)
self.assertEqual(response.content, b"The ID is invalid")

0 comments on commit 83fe185

Please sign in to comment.