Skip to content

Commit

Permalink
allow multiple __order in PiccoloCRUD (#218)
Browse files Browse the repository at this point in the history
* allow multiple order_by in `PiccoloCRUD`

* for tests for `_split_params`

* improve coverage

* update docs
  • Loading branch information
dantownsend committed Mar 17, 2023
1 parent 2c4e2a3 commit c011fd4
Show file tree
Hide file tree
Showing 3 changed files with 244 additions and 104 deletions.
6 changes: 6 additions & 0 deletions docs/source/crud/piccolo_crud.rst
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,12 @@ You can reverse the sort by prepending '-' to the field. For example:
GET /movie/?__order=-duration
Multiple columns can be used for the sort:

.. code-block::
GET /movie/?__order=-duration,name
Visible fields
~~~~~~~~~~~~~~

Expand Down
178 changes: 112 additions & 66 deletions piccolo_api/crud/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ class CustomJSONResponse(Response):

@dataclass
class OrderBy:
ascending: bool = False
property_name: str = "id"
column: Column
ascending: bool


@dataclass
Expand All @@ -87,11 +87,11 @@ class Params:
default_factory=lambda: defaultdict(lambda: MATCH_TYPES[0])
)
fields: t.Dict[str, t.Any] = field(default_factory=dict)
order_by: t.Optional[OrderBy] = None
order_by: t.Optional[t.List[OrderBy]] = None
include_readable: bool = False
page: int = 1
page_size: t.Optional[int] = None
visible_fields: str = field(default="")
visible_fields: t.Optional[t.List[Column]] = None
range_header: bool = False
range_header_name: str = field(default="")

Expand Down Expand Up @@ -135,6 +135,10 @@ def get_visible_fields_options(
return tuple(fields)


class ParamException(Exception):
pass


class PiccoloCRUD(Router):
"""
Wraps a Piccolo table with CRUD methods for use in a REST API.
Expand Down Expand Up @@ -455,7 +459,11 @@ async def get_count(self, request: Request) -> Response:
Returns the total number of rows in the table.
"""
params = self._parse_params(request.query_params)
split_params = self._split_params(params)

try:
split_params = self._split_params(params)
except ParamException as exception:
return Response(str(exception), status_code=400)

try:
query = self._apply_filters(self.table.count(), split_params)
Expand Down Expand Up @@ -523,8 +531,7 @@ async def root(self, request: Request) -> Response:

###########################################################################

@staticmethod
def _split_params(params: t.Dict[str, t.Any]) -> Params:
def _split_params(self, params: t.Dict[str, t.Any]) -> Params:
"""
Some parameters reference fields, and others provide instructions
on how to perform the query (e.g. which operator to use).
Expand Down Expand Up @@ -569,7 +576,9 @@ def _split_params(params: t.Dict[str, t.Any]) -> Params:
# they specify these operators, so set one for them.
response.fields[field_name] = None
else:
logger.info(f"Unrecognised __operator argument - {value}")
raise ParamException(
f"Unrecognised __operator argument - {value}"
)
continue

if key.endswith("__match") and value in MATCH_TYPES:
Expand All @@ -578,20 +587,40 @@ def _split_params(params: t.Dict[str, t.Any]) -> Params:
continue

if key == "__order":
ascending = True
if value.startswith("-"):
ascending = False
value = value[1:]
response.order_by = OrderBy(
ascending=ascending, property_name=value
)
# We allow multiple columns to be specified using a comma
# separated string e.g. 'name,created_on'. The value may
# already be a list if the parameter is passed in multiple
# times for example `?__order=name?__order=created_on`.
order_by: t.List[OrderBy] = []
sub_values: t.List[str]

if isinstance(value, str):
sub_values = value.split(",")
elif isinstance(value, list):
sub_values = value
else:
raise ParamException("Unrecognised __order_by type.")

for sub_value in sub_values:
ascending = True
if sub_value.startswith("-"):
ascending = False
sub_value = sub_value[1:]

column = self._get_column(column_name=sub_value)
order_by.append(
OrderBy(column=column, ascending=ascending)
)
response.order_by = order_by
continue

if key == "__page":
try:
page = int(value)
except ValueError:
logger.info(f"Unrecognised __page argument - {value}")
raise ParamException(
f"Unrecognised __page argument - {value}"
)
else:
response.page = page
continue
Expand All @@ -600,17 +629,39 @@ def _split_params(params: t.Dict[str, t.Any]) -> Params:
try:
page_size = int(value)
except ValueError:
logger.info(f"Unrecognised __page_size argument - {value}")
raise ParamException(
f"Unrecognised __page_size argument - {value}"
)
else:
response.page_size = page_size
continue

if key == "__visible_fields":
response.visible_fields = value
column_names: t.List[str]

if isinstance(value, str):
column_names = value.split(",")
elif isinstance(value, list):
column_names = value
else:
raise ParamException("Unrecognised __visible_fields type")

try:
response.visible_fields = [
self._get_column(column_name=column_name)
for column_name in column_names
]
except ValueError as e:
raise ParamException(str(e))
continue

if key == "__readable" and value in ("true", "True", "1"):
response.include_readable = True
if key == "__readable":
if value in ("t", "true", "True", "1"):
response.include_readable = True
else:
raise ParamException(
f"Unrecognised __readable argument - {value}"
)
continue

if key == "__range_header":
Expand Down Expand Up @@ -693,30 +744,28 @@ async def get_all(
"""
params = self._clean_data(params) if params else {}

split_params = self._split_params(params)
try:
split_params = self._split_params(params)
except ParamException as exception:
return Response(str(exception), status_code=400)

# Visible fields
visible_fields = split_params.visible_fields
nested: t.Union[bool, t.Tuple[Column, ...]]
if visible_fields:
try:
visible_columns = self._parse_visible_fields(visible_fields)
except ValueError as exception:
return Response(str(exception), status_code=400)

nested = tuple(
i for i in visible_columns if len(i._meta.call_chain) > 0
i for i in visible_fields if len(i._meta.call_chain) > 0
)
else:
visible_columns = self.table._meta.columns
visible_fields = self.table._meta.columns
nested = False

# Readable
include_readable = split_params.include_readable
readable_columns = (
[
self.table._get_related_readable(i)
for i in visible_columns
for i in visible_fields
if isinstance(i, ForeignKey)
]
if include_readable
Expand All @@ -725,7 +774,7 @@ async def get_all(

# Build select query, and exclude secrets
query = self.table.select(
*visible_columns,
*visible_fields,
*readable_columns,
exclude_secrets=self.exclude_secrets,
)
Expand All @@ -743,8 +792,10 @@ async def get_all(
# Ordering
order_by = split_params.order_by
if order_by:
column = getattr(self.table, order_by.property_name)
query = query.order_by(column, ascending=order_by.ascending)
for _order_by in order_by:
query = query.order_by(
_order_by.column, ascending=_order_by.ascending
)
else:
query = query.order_by(
self.table._meta.primary_key, ascending=False
Expand Down Expand Up @@ -788,7 +839,7 @@ async def get_all(
# fields.
json = self.pydantic_model_plural(
include_readable=include_readable,
include_columns=tuple(visible_columns),
include_columns=tuple(visible_fields),
nested=nested,
)(rows=rows).json()
return CustomJSONResponse(json, headers=headers)
Expand Down Expand Up @@ -856,7 +907,11 @@ async def delete_all(
Deletes all rows - query parameters are used for filtering.
"""
params = self._clean_data(params) if params else {}
split_params = self._split_params(params)

try:
split_params = self._split_params(params)
except ParamException as exception:
return Response(str(exception), status_code=400)

try:
query = self._apply_filters(
Expand Down Expand Up @@ -937,60 +992,51 @@ async def detail(self, request: Request) -> Response:
else:
return Response(status_code=405)

def _parse_visible_fields(self, visible_fields: str) -> t.List[Column]:
def _get_column(self, column_name: str) -> Column:
"""
Parse the ``visible_fields`` string, and return a list of columns.
Retrieves the Piccolo column based off the colum name, including joins.
:param visible_fields:
A comma separated list of column names, for example ``'id,name'``.
:param column_name:
The presence of a full stop in the name indicates a join, for
example ``'director.name'``.
:raises ValueError:
If the max join depth is exceeded, or the column name isn't
recognised.
"""
column_names: t.List[str] = visible_fields.split(",")
visible_columns: t.List[Column] = []

for column_name in column_names:
try:
column = self.table._meta.get_column_by_name(column_name)
except ValueError as exception:
raise ValueError(
f"{exception} - the column options are "
f"{self.visible_fields_options}."
)

if len(column._meta.call_chain) > self.max_joins:
raise ValueError("Max join depth exceeded")
else:
visible_columns.append(column)
try:
column = self.table._meta.get_column_by_name(column_name)
except ValueError as exception:
raise ValueError(
f"{exception} - the column options are "
f"{self.visible_fields_options}."
)

return visible_columns
if len(column._meta.call_chain) > self.max_joins:
raise ValueError("Max join depth exceeded")
else:
return column

@apply_validators
async def get_single(self, request: Request, row_id: PK_TYPES) -> Response:
"""
Returns a single row.
"""
params = dict(request.query_params)
split_params: Params = self._split_params(params)

try:
split_params = self._split_params(params)
except ParamException as exception:
return Response(str(exception), status_code=400)

# Visible fields
nested: t.Union[bool, t.Tuple[Column, ...]]
visible_fields = split_params.visible_fields
if visible_fields:
try:
visible_columns = self._parse_visible_fields(visible_fields)
except ValueError as exception:
return Response(str(exception), status_code=400)

nested = tuple(
i for i in visible_columns if len(i._meta.call_chain) > 0
i for i in visible_fields if len(i._meta.call_chain) > 0
)
else:
visible_columns = self.table._meta.columns
visible_fields = self.table._meta.columns
nested = False

# Readable
Expand All @@ -1005,7 +1051,7 @@ async def get_single(self, request: Request, row_id: PK_TYPES) -> Response:

query = (
self.table.select(
*visible_columns,
*visible_fields,
*readable_columns,
exclude_secrets=self.exclude_secrets,
)
Expand All @@ -1026,7 +1072,7 @@ async def get_single(self, request: Request, row_id: PK_TYPES) -> Response:
return CustomJSONResponse(
self._pydantic_model_output(
include_readable=split_params.include_readable,
include_columns=tuple(visible_columns),
include_columns=tuple(visible_fields),
nested=nested,
)(**row).json()
)
Expand Down

0 comments on commit c011fd4

Please sign in to comment.