Skip to content

Commit

Permalink
269 Stop multi-dimensional array and email filtering from breaking (#270
Browse files Browse the repository at this point in the history
)

* stop multidimensional arrays from breaking `FastAPIWrapper`

* replace `list[list]` with `list`

* fix linter errors

* wip

* fix `ModelFilters`

* update `pydantic_model_filters`

* use `Array. _get_dimensions` instead of `is_multidimensional_array`

* use latest piccolo

* ignore mypy warning

* add test for filtering email

* add a test for multidimensional arrays
  • Loading branch information
dantownsend committed Mar 26, 2024
1 parent 2df940d commit 9dc8f41
Show file tree
Hide file tree
Showing 10 changed files with 317 additions and 109 deletions.
99 changes: 82 additions & 17 deletions piccolo_api/crud/endpoints.py
Expand Up @@ -25,6 +25,7 @@
from piccolo.query.methods.select import Select
from piccolo.table import Table
from piccolo.utils.encoding import dump_json
from piccolo.utils.pydantic import create_pydantic_model
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from starlette.routing import Route, Router
Expand All @@ -38,7 +39,6 @@
)

from .exceptions import MalformedQuery, db_exception_handler
from .serializers import create_pydantic_model
from .validators import Validators, apply_validators

if t.TYPE_CHECKING: # pragma: no cover
Expand Down Expand Up @@ -257,9 +257,9 @@ def __init__(
table=table, exclude_secrets=exclude_secrets, max_joins=max_joins
)
schema_extra["visible_fields_options"] = self.visible_fields_options
schema_extra[
"primary_key_name"
] = self.table._meta.primary_key._meta.name
schema_extra["primary_key_name"] = (
self.table._meta.primary_key._meta.name
)
self.schema_extra = schema_extra

root_methods = ["GET"]
Expand All @@ -282,9 +282,9 @@ def __init__(
Route(
path="/{row_id:str}/",
endpoint=self.detail,
methods=["GET"]
if read_only
else ["GET", "PUT", "DELETE", "PATCH"],
methods=(
["GET"] if read_only else ["GET", "PUT", "DELETE", "PATCH"]
),
),
]

Expand Down Expand Up @@ -330,8 +330,8 @@ def pydantic_model_output(self) -> t.Type[pydantic.BaseModel]:
@property
def pydantic_model_optional(self) -> t.Type[pydantic.BaseModel]:
"""
All fields are optional, which is useful for serialising filters,
where a user can filter on any number of fields.
All fields are optional, which is useful for PATCH requests, which
may only update some fields.
"""
return create_pydantic_model(
self.table,
Expand All @@ -340,6 +340,63 @@ def pydantic_model_optional(self) -> t.Type[pydantic.BaseModel]:
model_name=f"{self.table.__name__}Optional",
)

@property
def pydantic_model_filters(self) -> t.Type[pydantic.BaseModel]:
"""
Used for serialising query params, which are used for filtering.
A special case is multidimensional arrays - if we have this::
my_column = Array(Array(Varchar()))
Even though the type is ``list[list[str]]``, this isn't allowed as a
query parameter. Instead, we use ``list[str]``.
Also, for ``Email`` columns, we don't want to validate that it's a
correct email address when filtering, as someone may want to filter
by 'gmail', for example.
"""
model_name = f"{self.table.__name__}Filters"

multidimensional_array_columns = [
i
for i in self.table._meta.array_columns
if i._get_dimensions() > 1
]

email_columns = self.table._meta.email_columns

base_model = create_pydantic_model(
self.table,
include_default_columns=True,
exclude_columns=(*multidimensional_array_columns, *email_columns),
all_optional=True,
model_name=model_name,
)

if multidimensional_array_columns or email_columns:
return pydantic.create_model(
__model_name=model_name,
__base__=base_model,
**{
i._meta.name: (
t.Optional[t.List[i._get_inner_value_type()]], # type: ignore # noqa: E501
pydantic.Field(default=None),
)
for i in multidimensional_array_columns
},
**{
i._meta.name: (
t.Optional[str],
pydantic.Field(default=None),
)
for i in email_columns
},
)
else:
return base_model

def pydantic_model_plural(
self,
include_readable=False,
Expand Down Expand Up @@ -716,7 +773,7 @@ def _apply_filters(
"""
fields = params.fields
if fields:
model_dict = self.pydantic_model_optional(**fields).model_dump()
model_dict = self.pydantic_model_filters(**fields).model_dump()
for field_name in fields.keys():
value = model_dict.get(field_name, ...)
if value is ...:
Expand Down Expand Up @@ -860,9 +917,9 @@ async def get_all(
curr_page_len = curr_page_len + offset
count = await self.table.count().run()
curr_page_string = f"{offset}-{curr_page_len}"
headers[
"Content-Range"
] = f"{plural_name} {curr_page_string}/{count}"
headers["Content-Range"] = (
f"{plural_name} {curr_page_string}/{count}"
)

# We need to serialise it ourselves, in case there are datetime
# fields.
Expand Down Expand Up @@ -1155,9 +1212,13 @@ async def patch_single(
cls = self.table

try:
values = {getattr(cls, key): getattr(model, key) for key in data.keys()}
values = {
getattr(cls, key): getattr(model, key) for key in data.keys()
}
except AttributeError:
unrecognised_keys = set(data.keys()) - set(model.model_dump().keys())
unrecognised_keys = set(data.keys()) - set(
model.model_dump().keys()
)
return Response(
f"Unrecognised keys - {unrecognised_keys}.",
status_code=400,
Expand All @@ -1180,15 +1241,19 @@ async def patch_single(
return Response(f"{e}", status_code=400)
values["password"] = cls.hash_password(password)
try:
await cls.update(values).where(cls._meta.primary_key == row_id).run()
await cls.update(values).where(
cls._meta.primary_key == row_id
).run()
new_row = (
await cls.select(exclude_secrets=self.exclude_secrets)
.where(cls._meta.primary_key == row_id)
.first()
.run()
)
assert new_row
return CustomJSONResponse(self.pydantic_model(**new_row).model_dump_json())
return CustomJSONResponse(
self.pydantic_model(**new_row).model_dump_json()
)
except ValueError:
return Response("Unable to save the resource.", status_code=500)

Expand Down
72 changes: 15 additions & 57 deletions piccolo_api/fastapi/endpoints.py
Expand Up @@ -17,19 +17,11 @@
from pydantic.main import BaseModel

from piccolo_api.crud.endpoints import PiccoloCRUD
from piccolo_api.utils.types import get_type

ANNOTATIONS: t.DefaultDict = defaultdict(dict)


try:
# Python 3.10 and above
from types import UnionType # type: ignore
except ImportError:

class UnionType: # type: ignore
...


class HTTPMethod(str, Enum):
get = "GET"
delete = "DELETE"
Expand Down Expand Up @@ -85,41 +77,6 @@ class ReferencesModel(BaseModel):
references: t.List[ReferenceModel]


def _get_type(type_: t.Type) -> t.Type:
"""
Extract the inner type from an optional if necessary, otherwise return
the type as is.
For example::
>>> _get_type(Optional[int])
int
>>> _get_type(int | None)
int
>>> _get_type(int)
int
>>> _get_type(list[str])
list[str]
"""
origin = t.get_origin(type_)

# Note: even if `t.Optional` is passed in, the origin is still a
# `t.Union` or `UnionType` depending on the Python version.
if any(origin is i for i in (t.Union, UnionType)):
union_args = t.get_args(type_)

NoneType = type(None)

if len(union_args) == 2 and NoneType in union_args:
return [i for i in union_args if i is not NoneType][0]

return type_


class FastAPIWrapper:
"""
Wraps ``PiccoloCRUD`` so it can easily be integrated into FastAPI.
Expand Down Expand Up @@ -160,6 +117,7 @@ def __init__(
self.ModelIn = piccolo_crud.pydantic_model
self.ModelOptional = piccolo_crud.pydantic_model_optional
self.ModelPlural = piccolo_crud.pydantic_model_plural()
self.ModelFilters = piccolo_crud.pydantic_model_filters

self.alias = f"{piccolo_crud.table._meta.tablename}__{id(self)}"

Expand All @@ -180,7 +138,7 @@ async def get(request: Request, **kwargs):

self.modify_signature(
endpoint=get,
model=self.ModelOut,
model=self.ModelFilters,
http_method=HTTPMethod.get,
allow_ordering=True,
allow_pagination=True,
Expand Down Expand Up @@ -243,7 +201,7 @@ async def count(request: Request, **kwargs):
return await piccolo_crud.get_count(request=request)

self.modify_signature(
endpoint=count, model=self.ModelOut, http_method=HTTPMethod.get
endpoint=count, model=self.ModelFilters, http_method=HTTPMethod.get
)

fastapi_app.add_api_route(
Expand Down Expand Up @@ -301,7 +259,7 @@ async def delete(request: Request, **kwargs):

self.modify_signature(
endpoint=delete,
model=self.ModelOut,
model=self.ModelFilters,
http_method=HTTPMethod.delete,
)

Expand All @@ -325,9 +283,9 @@ async def post(request: Request, model):
"""
return await piccolo_crud.root(request=request)

post.__annotations__[
"model"
] = f"ANNOTATIONS['{self.alias}']['ModelIn']"
post.__annotations__["model"] = (
f"ANNOTATIONS['{self.alias}']['ModelIn']"
)

fastapi_app.add_api_route(
path=root_url,
Expand Down Expand Up @@ -386,9 +344,9 @@ async def put(row_id: str, request: Request, model):
"""
return await piccolo_crud.detail(request=request)

put.__annotations__[
"model"
] = f"ANNOTATIONS['{self.alias}']['ModelIn']"
put.__annotations__["model"] = (
f"ANNOTATIONS['{self.alias}']['ModelIn']"
)

fastapi_app.add_api_route(
path=self.join_urls(root_url, "/{row_id:str}/"),
Expand All @@ -410,9 +368,9 @@ async def patch(row_id: str, request: Request, model):
"""
return await piccolo_crud.detail(request=request)

patch.__annotations__[
"model"
] = f"ANNOTATIONS['{self.alias}']['ModelOptional']"
patch.__annotations__["model"] = (
f"ANNOTATIONS['{self.alias}']['ModelOptional']"
)

fastapi_app.add_api_route(
path=self.join_urls(root_url, "/{row_id:str}/"),
Expand Down Expand Up @@ -460,7 +418,7 @@ def modify_signature(
for field_name, _field in model.model_fields.items():
annotation = _field.annotation
assert annotation is not None
type_ = _get_type(annotation)
type_ = get_type(annotation)

parameters.append(
Parameter(
Expand Down
Empty file added piccolo_api/utils/__init__.py
Empty file.
53 changes: 53 additions & 0 deletions piccolo_api/utils/types.py
@@ -0,0 +1,53 @@
"""
Utils for extracting information from complex, nested types.
"""

from __future__ import annotations

import typing as t

try:
# Python 3.10 and above
from types import UnionType # type: ignore
except ImportError:

class UnionType: # type: ignore
...


def get_type(type_: t.Type) -> t.Type:
"""
Extract the inner type from an optional if necessary, otherwise return
the type as is.
For example::
>>> get_type(Optional[int])
int
>>> get_type(int | None)
int
>>> get_type(int)
int
>>> _get_type(list[str])
list[str]
"""
origin = t.get_origin(type_)

# Note: even if `t.Optional` is passed in, the origin is still a
# `t.Union` or `UnionType` depending on the Python version.
if any(origin is i for i in (t.Union, UnionType)):
union_args = t.get_args(type_)

NoneType = type(None)

if len(union_args) == 2 and NoneType in union_args:
return [i for i in union_args if i is not NoneType][0]

return type_


__all__ = ("get_type",)
10 changes: 5 additions & 5 deletions requirements/dev-requirements.txt
@@ -1,7 +1,7 @@
black==24.3.0
isort==5.12.0
twine==4.0.2
mypy==1.5.1
isort==5.13.2
twine==5.0.0
mypy==1.9.0
pip-upgrader==1.4.15
wheel==0.41.2
setuptools==68.2.2
wheel==0.43.0
setuptools==69.2.0
2 changes: 1 addition & 1 deletion requirements/requirements.txt
@@ -1,5 +1,5 @@
Jinja2>=2.11.0
piccolo[postgres]>=1.0a3
piccolo[postgres]>=1.5
pydantic[email]>=2.0
python-multipart>=0.0.5
fastapi>=0.100.0
Expand Down

0 comments on commit 9dc8f41

Please sign in to comment.