diff --git a/piccolo_api/crud/endpoints.py b/piccolo_api/crud/endpoints.py index 5a763b6..13d32f5 100644 --- a/piccolo_api/crud/endpoints.py +++ b/piccolo_api/crud/endpoints.py @@ -25,6 +25,7 @@ from piccolo.query.methods.select import Select from piccolo.table import Table from piccolo.utils.encoding import dump_json +from piccolo.utils.pydantic import create_pydantic_model from starlette.requests import Request from starlette.responses import JSONResponse, Response from starlette.routing import Route, Router @@ -38,7 +39,6 @@ ) from .exceptions import MalformedQuery, db_exception_handler -from .serializers import create_pydantic_model from .validators import Validators, apply_validators if t.TYPE_CHECKING: # pragma: no cover @@ -257,9 +257,9 @@ def __init__( table=table, exclude_secrets=exclude_secrets, max_joins=max_joins ) schema_extra["visible_fields_options"] = self.visible_fields_options - schema_extra[ - "primary_key_name" - ] = self.table._meta.primary_key._meta.name + schema_extra["primary_key_name"] = ( + self.table._meta.primary_key._meta.name + ) self.schema_extra = schema_extra root_methods = ["GET"] @@ -282,9 +282,9 @@ def __init__( Route( path="/{row_id:str}/", endpoint=self.detail, - methods=["GET"] - if read_only - else ["GET", "PUT", "DELETE", "PATCH"], + methods=( + ["GET"] if read_only else ["GET", "PUT", "DELETE", "PATCH"] + ), ), ] @@ -330,8 +330,8 @@ def pydantic_model_output(self) -> t.Type[pydantic.BaseModel]: @property def pydantic_model_optional(self) -> t.Type[pydantic.BaseModel]: """ - All fields are optional, which is useful for serialising filters, - where a user can filter on any number of fields. + All fields are optional, which is useful for PATCH requests, which + may only update some fields. """ return create_pydantic_model( self.table, @@ -340,6 +340,63 @@ def pydantic_model_optional(self) -> t.Type[pydantic.BaseModel]: model_name=f"{self.table.__name__}Optional", ) + @property + def pydantic_model_filters(self) -> t.Type[pydantic.BaseModel]: + """ + Used for serialising query params, which are used for filtering. + + A special case is multidimensional arrays - if we have this:: + + my_column = Array(Array(Varchar())) + + Even though the type is ``list[list[str]]``, this isn't allowed as a + query parameter. Instead, we use ``list[str]``. + + Also, for ``Email`` columns, we don't want to validate that it's a + correct email address when filtering, as someone may want to filter + by 'gmail', for example. + + """ + model_name = f"{self.table.__name__}Filters" + + multidimensional_array_columns = [ + i + for i in self.table._meta.array_columns + if i._get_dimensions() > 1 + ] + + email_columns = self.table._meta.email_columns + + base_model = create_pydantic_model( + self.table, + include_default_columns=True, + exclude_columns=(*multidimensional_array_columns, *email_columns), + all_optional=True, + model_name=model_name, + ) + + if multidimensional_array_columns or email_columns: + return pydantic.create_model( + __model_name=model_name, + __base__=base_model, + **{ + i._meta.name: ( + t.Optional[t.List[i._get_inner_value_type()]], # type: ignore # noqa: E501 + pydantic.Field(default=None), + ) + for i in multidimensional_array_columns + }, + **{ + i._meta.name: ( + t.Optional[str], + pydantic.Field(default=None), + ) + for i in email_columns + }, + ) + else: + return base_model + def pydantic_model_plural( self, include_readable=False, @@ -716,7 +773,7 @@ def _apply_filters( """ fields = params.fields if fields: - model_dict = self.pydantic_model_optional(**fields).model_dump() + model_dict = self.pydantic_model_filters(**fields).model_dump() for field_name in fields.keys(): value = model_dict.get(field_name, ...) if value is ...: @@ -860,9 +917,9 @@ async def get_all( curr_page_len = curr_page_len + offset count = await self.table.count().run() curr_page_string = f"{offset}-{curr_page_len}" - headers[ - "Content-Range" - ] = f"{plural_name} {curr_page_string}/{count}" + headers["Content-Range"] = ( + f"{plural_name} {curr_page_string}/{count}" + ) # We need to serialise it ourselves, in case there are datetime # fields. @@ -1155,9 +1212,13 @@ async def patch_single( cls = self.table try: - values = {getattr(cls, key): getattr(model, key) for key in data.keys()} + values = { + getattr(cls, key): getattr(model, key) for key in data.keys() + } except AttributeError: - unrecognised_keys = set(data.keys()) - set(model.model_dump().keys()) + unrecognised_keys = set(data.keys()) - set( + model.model_dump().keys() + ) return Response( f"Unrecognised keys - {unrecognised_keys}.", status_code=400, @@ -1180,7 +1241,9 @@ async def patch_single( return Response(f"{e}", status_code=400) values["password"] = cls.hash_password(password) try: - await cls.update(values).where(cls._meta.primary_key == row_id).run() + await cls.update(values).where( + cls._meta.primary_key == row_id + ).run() new_row = ( await cls.select(exclude_secrets=self.exclude_secrets) .where(cls._meta.primary_key == row_id) @@ -1188,7 +1251,9 @@ async def patch_single( .run() ) assert new_row - return CustomJSONResponse(self.pydantic_model(**new_row).model_dump_json()) + return CustomJSONResponse( + self.pydantic_model(**new_row).model_dump_json() + ) except ValueError: return Response("Unable to save the resource.", status_code=500) diff --git a/piccolo_api/fastapi/endpoints.py b/piccolo_api/fastapi/endpoints.py index 4ab4b9b..b1cf311 100644 --- a/piccolo_api/fastapi/endpoints.py +++ b/piccolo_api/fastapi/endpoints.py @@ -17,19 +17,11 @@ from pydantic.main import BaseModel from piccolo_api.crud.endpoints import PiccoloCRUD +from piccolo_api.utils.types import get_type ANNOTATIONS: t.DefaultDict = defaultdict(dict) -try: - # Python 3.10 and above - from types import UnionType # type: ignore -except ImportError: - - class UnionType: # type: ignore - ... - - class HTTPMethod(str, Enum): get = "GET" delete = "DELETE" @@ -85,41 +77,6 @@ class ReferencesModel(BaseModel): references: t.List[ReferenceModel] -def _get_type(type_: t.Type) -> t.Type: - """ - Extract the inner type from an optional if necessary, otherwise return - the type as is. - - For example:: - - >>> _get_type(Optional[int]) - int - - >>> _get_type(int | None) - int - - >>> _get_type(int) - int - - >>> _get_type(list[str]) - list[str] - - """ - origin = t.get_origin(type_) - - # Note: even if `t.Optional` is passed in, the origin is still a - # `t.Union` or `UnionType` depending on the Python version. - if any(origin is i for i in (t.Union, UnionType)): - union_args = t.get_args(type_) - - NoneType = type(None) - - if len(union_args) == 2 and NoneType in union_args: - return [i for i in union_args if i is not NoneType][0] - - return type_ - - class FastAPIWrapper: """ Wraps ``PiccoloCRUD`` so it can easily be integrated into FastAPI. @@ -160,6 +117,7 @@ def __init__( self.ModelIn = piccolo_crud.pydantic_model self.ModelOptional = piccolo_crud.pydantic_model_optional self.ModelPlural = piccolo_crud.pydantic_model_plural() + self.ModelFilters = piccolo_crud.pydantic_model_filters self.alias = f"{piccolo_crud.table._meta.tablename}__{id(self)}" @@ -180,7 +138,7 @@ async def get(request: Request, **kwargs): self.modify_signature( endpoint=get, - model=self.ModelOut, + model=self.ModelFilters, http_method=HTTPMethod.get, allow_ordering=True, allow_pagination=True, @@ -243,7 +201,7 @@ async def count(request: Request, **kwargs): return await piccolo_crud.get_count(request=request) self.modify_signature( - endpoint=count, model=self.ModelOut, http_method=HTTPMethod.get + endpoint=count, model=self.ModelFilters, http_method=HTTPMethod.get ) fastapi_app.add_api_route( @@ -301,7 +259,7 @@ async def delete(request: Request, **kwargs): self.modify_signature( endpoint=delete, - model=self.ModelOut, + model=self.ModelFilters, http_method=HTTPMethod.delete, ) @@ -325,9 +283,9 @@ async def post(request: Request, model): """ return await piccolo_crud.root(request=request) - post.__annotations__[ - "model" - ] = f"ANNOTATIONS['{self.alias}']['ModelIn']" + post.__annotations__["model"] = ( + f"ANNOTATIONS['{self.alias}']['ModelIn']" + ) fastapi_app.add_api_route( path=root_url, @@ -386,9 +344,9 @@ async def put(row_id: str, request: Request, model): """ return await piccolo_crud.detail(request=request) - put.__annotations__[ - "model" - ] = f"ANNOTATIONS['{self.alias}']['ModelIn']" + put.__annotations__["model"] = ( + f"ANNOTATIONS['{self.alias}']['ModelIn']" + ) fastapi_app.add_api_route( path=self.join_urls(root_url, "/{row_id:str}/"), @@ -410,9 +368,9 @@ async def patch(row_id: str, request: Request, model): """ return await piccolo_crud.detail(request=request) - patch.__annotations__[ - "model" - ] = f"ANNOTATIONS['{self.alias}']['ModelOptional']" + patch.__annotations__["model"] = ( + f"ANNOTATIONS['{self.alias}']['ModelOptional']" + ) fastapi_app.add_api_route( path=self.join_urls(root_url, "/{row_id:str}/"), @@ -460,7 +418,7 @@ def modify_signature( for field_name, _field in model.model_fields.items(): annotation = _field.annotation assert annotation is not None - type_ = _get_type(annotation) + type_ = get_type(annotation) parameters.append( Parameter( diff --git a/piccolo_api/utils/__init__.py b/piccolo_api/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/piccolo_api/utils/types.py b/piccolo_api/utils/types.py new file mode 100644 index 0000000..63d96f1 --- /dev/null +++ b/piccolo_api/utils/types.py @@ -0,0 +1,53 @@ +""" +Utils for extracting information from complex, nested types. +""" + +from __future__ import annotations + +import typing as t + +try: + # Python 3.10 and above + from types import UnionType # type: ignore +except ImportError: + + class UnionType: # type: ignore + ... + + +def get_type(type_: t.Type) -> t.Type: + """ + Extract the inner type from an optional if necessary, otherwise return + the type as is. + + For example:: + + >>> get_type(Optional[int]) + int + + >>> get_type(int | None) + int + + >>> get_type(int) + int + + >>> _get_type(list[str]) + list[str] + + """ + origin = t.get_origin(type_) + + # Note: even if `t.Optional` is passed in, the origin is still a + # `t.Union` or `UnionType` depending on the Python version. + if any(origin is i for i in (t.Union, UnionType)): + union_args = t.get_args(type_) + + NoneType = type(None) + + if len(union_args) == 2 and NoneType in union_args: + return [i for i in union_args if i is not NoneType][0] + + return type_ + + +__all__ = ("get_type",) diff --git a/requirements/dev-requirements.txt b/requirements/dev-requirements.txt index 908c2aa..fa7c03a 100644 --- a/requirements/dev-requirements.txt +++ b/requirements/dev-requirements.txt @@ -1,7 +1,7 @@ black==24.3.0 -isort==5.12.0 -twine==4.0.2 -mypy==1.5.1 +isort==5.13.2 +twine==5.0.0 +mypy==1.9.0 pip-upgrader==1.4.15 -wheel==0.41.2 -setuptools==68.2.2 +wheel==0.43.0 +setuptools==69.2.0 diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 35a7f56..c18d9c9 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -1,5 +1,5 @@ Jinja2>=2.11.0 -piccolo[postgres]>=1.0a3 +piccolo[postgres]>=1.5 pydantic[email]>=2.0 python-multipart>=0.0.5 fastapi>=0.100.0 diff --git a/tests/crud/test_crud_endpoints.py b/tests/crud/test_crud_endpoints.py index 98a6622..6daa8d0 100644 --- a/tests/crud/test_crud_endpoints.py +++ b/tests/crud/test_crud_endpoints.py @@ -2,7 +2,15 @@ from unittest import TestCase from piccolo.apps.user.tables import BaseUser -from piccolo.columns import Email, ForeignKey, Integer, Secret, Text, Varchar +from piccolo.columns import ( + Array, + Email, + ForeignKey, + Integer, + Secret, + Text, + Varchar, +) from piccolo.columns.column_types import OnDelete from piccolo.columns.readable import Readable from piccolo.table import Table, create_db_tables_sync, drop_db_tables_sync @@ -43,6 +51,10 @@ class Studio(Table): booking_email = Email(default="booking@studio.com") +class Seats(Table): + arrangement = Array(Array(Varchar())) + + class Cinema(Table): name = Varchar() address = Text(unique=True) @@ -473,6 +485,7 @@ def test_get_schema(self): "help_text": None, "nullable": False, "secret": False, + "unique": False, }, "maxLength": 100, "title": "Name", @@ -486,6 +499,7 @@ def test_get_schema(self): "help_text": None, "nullable": False, "secret": False, + "unique": False, }, "title": "Rating", }, @@ -541,6 +555,7 @@ class Rating(Enum): "help_text": None, "nullable": False, "secret": False, + "unique": False, }, "title": "Score", } @@ -590,6 +605,7 @@ def test_get_schema_with_joins(self): "help_text": None, "nullable": True, "secret": False, + "unique": False, }, "title": "Movie", }, @@ -604,6 +620,7 @@ def test_get_schema_with_joins(self): "help_text": None, "nullable": False, "secret": False, + "unique": False, }, "title": "Name", }, @@ -1124,6 +1141,111 @@ def test_match(self): ) +class TestFilterEmail(TestCase): + """ + Make sure that ``Email`` columns can be filtered - i.e. we can pass in + partial emails like ``google.com``. + """ + + def setUp(self): + Studio.create_table(if_not_exists=True).run_sync() + + def tearDown(self): + Studio.alter().drop_table().run_sync() + + def test_filter_email(self): + client = TestClient(PiccoloCRUD(table=Studio)) + + Studio.insert( + Studio( + { + Studio.name: "Studio 1", + Studio.booking_email: "booking_1@gmail.com", + Studio.contact_email: "contact_1@gmail.com", + } + ), + Studio( + { + Studio.name: "Studio 2", + Studio.booking_email: "booking_2@gmail.com", + Studio.contact_email: "contact_2@gmail.com", + } + ), + ).run_sync() + + response = client.get("/?booking_email=booking_1") + self.assertEqual(response.status_code, 200) + + self.assertEqual( + response.json(), + { + "rows": [ + { + "booking_email": "booking_1@gmail.com", + "contact_email": "contact_1@gmail.com", + "id": 1, + "name": "Studio 1", + } + ] + }, + ) + + +class TestFilterMultidimensionalArray(TestCase): + """ + Make sure that multidimensional ``Array`` columns can be filtered. + """ + + def setUp(self): + Seats.create_table(if_not_exists=True).run_sync() + + def tearDown(self): + Seats.alter().drop_table().run_sync() + + def test_filter_multidimensional_array(self): + client = TestClient(PiccoloCRUD(table=Seats)) + + Seats.insert( + Seats( + { + Seats.arrangement: [ + ["A1", "A2", "A3"], + ["B1", "B2", "B3"], + ["C1", "C2", "C3"], + ], + } + ), + Seats( + { + Seats.arrangement: [ + ["D1", "D2", "D3"], + ["E1", "E2", "E3"], + ["F1", "F2", "F3"], + ], + } + ), + ).run_sync() + + response = client.get("/?arrangement=A1") + self.assertEqual(response.status_code, 200) + + self.assertEqual( + response.json(), + { + "rows": [ + { + "id": 1, + "arrangement": [ + ["A1", "A2", "A3"], + ["B1", "B2", "B3"], + ["C1", "C2", "C3"], + ], + } + ] + }, + ) + + class TestExcludeSecrets(TestCase): """ Make sure that if ``exclude_secrets`` is ``True``, then values for diff --git a/tests/fastapi/test_fastapi_endpoints.py b/tests/fastapi/test_fastapi_endpoints.py index 8ca1842..b5cd0d5 100644 --- a/tests/fastapi/test_fastapi_endpoints.py +++ b/tests/fastapi/test_fastapi_endpoints.py @@ -1,8 +1,5 @@ -import sys -import typing as t from unittest import TestCase -import pytest from fastapi import FastAPI from piccolo.columns import ForeignKey, Integer, Varchar from piccolo.columns.readable import Readable @@ -10,7 +7,7 @@ from starlette.testclient import TestClient from piccolo_api.crud.endpoints import PiccoloCRUD -from piccolo_api.fastapi.endpoints import FastAPIWrapper, _get_type +from piccolo_api.fastapi.endpoints import FastAPIWrapper class Movie(Table): @@ -120,6 +117,7 @@ def test_schema(self): "help_text": None, "nullable": False, "secret": False, + "unique": False, }, "title": "Name", }, @@ -131,6 +129,7 @@ def test_schema(self): "help_text": None, "nullable": False, "secret": False, + "unique": False, }, "title": "Rating", }, @@ -172,6 +171,7 @@ def test_schema_joins(self): "help_text": None, "nullable": True, "secret": False, + "unique": False, }, "title": "Movie", }, @@ -186,6 +186,7 @@ def test_schema_joins(self): "help_text": None, "nullable": False, "secret": False, + "unique": False, }, "title": "Name", }, @@ -258,27 +259,3 @@ def test_patch(self): self.assertEqual( response.json(), {"id": 1, "name": "Star Wars", "rating": 90} ) - - -class TestGetType(TestCase): - def test_get_type(self): - """ - If we pass in an optional type, it should return the non-optional type. - """ - # Should return the underlying type, as they're all optional: - self.assertIs(_get_type(t.Optional[str]), str) - self.assertIs(_get_type(t.Optional[t.List[str]]), t.List[str]) - self.assertIs(_get_type(t.Union[str, None]), str) - - # Should be returned as is, because it's not optional: - self.assertIs(_get_type(t.List[str]), t.List[str]) - - @pytest.mark.skipif( - sys.version_info < (3, 10), reason="Union syntax not available" - ) - def test_new_union_syntax(self): - """ - Make sure it works with the new syntax added in Python 3.10. - """ - self.assertIs(_get_type(str | None), str) # type: ignore - self.assertIs(_get_type(None | str), str) # type: ignore diff --git a/tests/serve.py b/tests/serve.py index 55bf82a..5b9cf60 100644 --- a/tests/serve.py +++ b/tests/serve.py @@ -3,6 +3,7 @@ Run it from the root of the project using `python -m tests.serve`. """ + import os import uvicorn diff --git a/tests/utils/test_types.py b/tests/utils/test_types.py new file mode 100644 index 0000000..b5a705d --- /dev/null +++ b/tests/utils/test_types.py @@ -0,0 +1,32 @@ +import sys +import typing as t +from unittest import TestCase + +import pytest + +from piccolo_api.utils.types import get_type + + +class TestGetType(TestCase): + + def test_get_type(self): + """ + If we pass in an optional type, it should return the non-optional type. + """ + # Should return the underlying type, as they're all optional: + self.assertIs(get_type(t.Optional[str]), str) + self.assertIs(get_type(t.Optional[t.List[str]]), t.List[str]) + self.assertIs(get_type(t.Union[str, None]), str) + + # Should be returned as is, because it's not optional: + self.assertIs(get_type(t.List[str]), t.List[str]) + + @pytest.mark.skipif( + sys.version_info < (3, 10), reason="Union syntax not available" + ) + def test_new_union_syntax(self): + """ + Make sure it works with the new syntax added in Python 3.10. + """ + self.assertIs(get_type(str | None), str) # type: ignore + self.assertIs(get_type(None | str), str) # type: ignore