Skip to content

Commit

Permalink
Added is_null and not_null operators (#212)
Browse files Browse the repository at this point in the history
* added is_null  / not_null operators

* update operator docs

* make null / not null operators take precedence over `match`

* add new operators to FastAPI docs

* make sure operators are shown for date / time fields

* fix bug with `match` not being shown for `Varchar` columns

* more tests

* fix linter errors

* bumping minimum piccolo version

We added much better type annotations in that version of Piccolo, which effects the type annotations in Piccolo API.
  • Loading branch information
dantownsend committed Feb 9, 2023
1 parent 0bd38e4 commit d443647
Show file tree
Hide file tree
Showing 9 changed files with 149 additions and 43 deletions.
30 changes: 23 additions & 7 deletions docs/source/crud/piccolo_crud.rst
Original file line number Diff line number Diff line change
Expand Up @@ -127,23 +127,39 @@ Get all movies with 'star wars' in the name:
Operators
~~~~~~~~~

As shown above you can specify which operator to use. The allowed operators are:
As shown above you can specify which operator to use. For numeric, and date /
time fields the following operators are allowed:

* lt: Less Than
* lte: Less Equal Than
* gt: Greater Than
* gte: Greater Equal Than
* e: Equal (default)
* ``lt``: Less Than
* ``lte``: Less Than or Equal
* ``gt``: Greater Than
* ``gte``: Greater Than or Equal
* ``e``: Equal (default)

To specify which operator to use, pass a query parameter like ``field__operator=operator_name``.
For example ``duration__operator=gte``.

A query which fetches all movies lasting more than 200 minutes:
Here's a query which fetches all movies lasting more than 200 minutes:

.. code-block::
GET /movie/?duration=200&duration__operator=gte
``is_null`` / ``not_null``
^^^^^^^^^^^^^^^^^^^^^^^^^^

All field types also support the ``is_null`` and ``not_null`` operators.

For example:

.. code-block::
# Get all rows with a null duration
GET /movie/duration__operator=is_null
# Get all rows without a null duration
GET /movie/duration__operator=not_null
Match type
~~~~~~~~~~

Expand Down
62 changes: 43 additions & 19 deletions piccolo_api/crud/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
Equal,
GreaterEqualThan,
GreaterThan,
IsNotNull,
IsNull,
LessEqualThan,
LessThan,
)
Expand Down Expand Up @@ -56,6 +58,8 @@
"gt": GreaterThan,
"gte": GreaterEqualThan,
"e": Equal,
"is_null": IsNull,
"not_null": IsNotNull,
}


Expand Down Expand Up @@ -555,9 +559,17 @@ def _split_params(params: t.Dict[str, t.Any]) -> Params:
response = Params()

for key, value in params.items():
if key.endswith("__operator") and value in OPERATOR_MAP.keys():
field_name = key.split("__operator")[0]
response.operators[field_name] = OPERATOR_MAP[value]
if key.endswith("__operator"):
if value in OPERATOR_MAP.keys():
field_name = key.split("__operator")[0]
operator = OPERATOR_MAP[value]
response.operators[field_name] = operator
if operator in (IsNull, IsNotNull):
# We don't require the user to pass in a value if
# they specify these operators, so set one for them.
response.fields[field_name] = None
else:
logger.info(f"Unrecognised __operator argument - {value}")
continue

if key.endswith("__match") and value in MATCH_TYPES:
Expand Down Expand Up @@ -639,26 +651,37 @@ def _apply_filters(
values = value if isinstance(value, list) else [value]

for value in values:
if isinstance(column, (Varchar, Text)):
match_type = params.match_types[field_name]
if match_type == "exact":
clause = column.__eq__(value)
elif match_type == "starts":
clause = column.ilike(f"{value}%")
elif match_type == "ends":
clause = column.ilike(f"%{value}")
else:
clause = column.ilike(f"%{value}%")
query = query.where(clause)
elif isinstance(column, Array):
query = query.where(column.any(value))
else:
operator = params.operators[field_name]
operator = params.operators[field_name]
if operator in (IsNull, IsNotNull):
query = query.where(
Where(
column=column, value=value, operator=operator
column=column,
operator=operator,
)
)
else:
if isinstance(column, (Varchar, Text)):
match_type = params.match_types[field_name]
if match_type == "exact":
clause = column.__eq__(value)
elif match_type == "starts":
clause = column.ilike(f"{value}%")
elif match_type == "ends":
clause = column.ilike(f"%{value}")
else:
clause = column.ilike(f"%{value}%")
query = query.where(clause)
elif isinstance(column, Array):
query = query.where(column.any(value))
else:
query = query.where(
Where(
column=column,
value=value,
operator=operator,
)
)

return query

@apply_validators
Expand Down Expand Up @@ -1100,6 +1123,7 @@ async def patch_single(
.first()
.run()
)
assert new_row
return CustomJSONResponse(
self.pydantic_model(**new_row).json()
)
Expand Down
34 changes: 30 additions & 4 deletions piccolo_api/fastapi/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@

from __future__ import annotations

import datetime
import typing as t
from collections import defaultdict
from decimal import Decimal
from enum import Enum
from inspect import Parameter, Signature
from inspect import Parameter, Signature, isclass

from fastapi import APIRouter, FastAPI, Request
from fastapi.params import Query
Expand Down Expand Up @@ -422,7 +423,15 @@ def modify_signature(
),
)

if type_ in (int, float, Decimal):
if type_ in (
int,
float,
Decimal,
datetime.date,
datetime.datetime,
datetime.time,
datetime.timedelta,
):
parameters.append(
Parameter(
name=f"{field_name}__operator",
Expand All @@ -432,13 +441,30 @@ def modify_signature(
description=(
f"Which operator to use for `{field_name}`. "
"The options are `e` (equals - default) `lt`, "
"`lte`, `gt`, and `gte`."
"`lte`, `gt`, `gte`, `is_null`, and "
"`not_null`."
),
),
)
)
else:
parameters.append(
Parameter(
name=f"{field_name}__operator",
kind=Parameter.POSITIONAL_OR_KEYWORD,
default=Query(
default=None,
description=(
f"Which operator to use for `{field_name}`. "
"The options are `is_null`, and `not_null`."
),
),
)
)

if type_ is str:
# We have to check if it's a subclass of `str` for Varchar, which
# uses Pydantics `constr` (constrained string).
if type_ is str or (isclass(type_) and issubclass(type_, str)):
parameters.append(
Parameter(
name=f"{field_name}__match",
Expand Down
4 changes: 1 addition & 3 deletions piccolo_api/session_auth/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,7 @@ async def get_user_id(
happens. The ``max_expiry_date`` remains the same, so there's a
hard limit on how long a session can be used for.
"""
session: SessionsBase = (
await cls.objects().where(cls.token == token).first().run()
)
session = await cls.objects().where(cls.token == token).first().run()

if not session:
return None
Expand Down
7 changes: 3 additions & 4 deletions piccolo_api/token_auth/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,17 +65,16 @@ async def get_user(self, token: str) -> User:
raise AuthenticationError()

user = (
await self.auth_table.select(self.auth_table.username)
await self.auth_table.objects()
.where(self.auth_table._meta.primary_key == user_id)
.first()
.run()
)

if not user:
if user is None:
raise AuthenticationError()

user = User(user=user)
return user
return User(user=user)


DEFAULT_PROVIDER = PiccoloTokenAuthProvider()
Expand Down
4 changes: 2 additions & 2 deletions piccolo_api/token_auth/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from piccolo.utils.sync import run_sync

if t.TYPE_CHECKING: # pragma: no cover
from piccolo.query import Select
from piccolo.query.methods.select import First


def generate_token() -> str:
Expand Down Expand Up @@ -56,7 +56,7 @@ def create_token_sync(cls, user_id: int) -> str:
return run_sync(cls.create_token(user_id))

@classmethod
async def authenticate(cls, token: str) -> Select:
async def authenticate(cls, token: str) -> First:
return cls.select(cls.user.id).where(cls.token == token).first()

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion requirements/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
Jinja2>=2.11.0
piccolo[postgres]>=0.89.0
piccolo[postgres]>=0.104.0
pydantic[email]>=1.6
python-multipart>=0.0.5
fastapi>=0.87.0
Expand Down
46 changes: 44 additions & 2 deletions tests/crud/test_crud_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,9 +803,9 @@ def test_reverse_order(self):
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json(), {"rows": rows})

def test_operator(self):
def test_operator_gt(self):
"""
Test filters - greater than.
Test operator - greater than.
"""
client = TestClient(PiccoloCRUD(table=Movie, read_only=False))
response = client.get(
Expand All @@ -818,6 +818,48 @@ def test_operator(self):
{"rows": [{"id": 1, "name": "Star Wars", "rating": 93}]},
)

def test_operator_null(self):
"""
Test operators - `is_null` / `not_null`.
"""
# Create a role with a null foreign key value.
Role(name="Joe Bloggs").save().run_sync()

client = TestClient(PiccoloCRUD(table=Role, read_only=False))

# Null
response = client.get(
"/",
params={"movie__operator": "is_null"},
)
self.assertEqual(response.status_code, 200)
self.assertEqual(
response.json(),
{"rows": [{"id": 2, "movie": None, "name": "Joe Bloggs"}]},
)

# Not Null
response = client.get(
"/",
params={"movie__operator": "not_null"},
)
self.assertEqual(response.status_code, 200)
self.assertEqual(
response.json(),
{"rows": [{"id": 1, "movie": 1, "name": "Luke Skywalker"}]},
)

# Make sure the null operator takes precedence
response = client.get(
"/",
params={"movie": 2, "movie__operator": "not_null"},
)
self.assertEqual(response.status_code, 200)
self.assertEqual(
response.json(),
{"rows": [{"id": 1, "movie": 1, "name": "Luke Skywalker"}]},
)

def test_match(self):
client = TestClient(PiccoloCRUD(table=Movie, read_only=False))

Expand Down
3 changes: 2 additions & 1 deletion tests/crud/test_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ async def remove_spaces(row_id: int, values: dict):

async def look_up_existing(row_id: int, values: dict):
row = await Movie.objects().get(Movie._meta.primary_key == row_id).run()
values["name"] = row.name
if row is not None:
values["name"] = row.name
return values


Expand Down

0 comments on commit d443647

Please sign in to comment.