diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index d71d78e7e..136f86680 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -13,7 +13,7 @@ on: jobs: linters: runs-on: ubuntu-latest - timeout-minutes: 30 + timeout-minutes: 60 strategy: matrix: python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] @@ -35,7 +35,7 @@ jobs: integration: runs-on: ubuntu-latest - timeout-minutes: 30 + timeout-minutes: 60 strategy: matrix: # These tests are slow, so we only run on the latest Python @@ -82,7 +82,7 @@ jobs: postgres: runs-on: ubuntu-latest - timeout-minutes: 30 + timeout-minutes: 60 strategy: matrix: python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] @@ -138,7 +138,7 @@ jobs: cockroach: runs-on: ubuntu-latest - timeout-minutes: 30 + timeout-minutes: 60 strategy: matrix: python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] @@ -172,7 +172,7 @@ jobs: sqlite: runs-on: ubuntu-latest - timeout-minutes: 30 + timeout-minutes: 60 strategy: matrix: python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] diff --git a/docs/src/piccolo/query_clauses/group_by.rst b/docs/src/piccolo/query_clauses/group_by.rst index d3eb5af87..516a4ac00 100644 --- a/docs/src/piccolo/query_clauses/group_by.rst +++ b/docs/src/piccolo/query_clauses/group_by.rst @@ -19,13 +19,13 @@ In the following query, we get a count of the number of bands per manager: .. code-block:: python - >>> from piccolo.query.methods.select import Count + >>> from piccolo.query.functions.aggregate import Count >>> await Band.select( ... Band.manager.name.as_alias('manager_name'), ... Count(alias='band_count') ... ).group_by( - ... Band.manager + ... Band.manager.name ... ) [ diff --git a/docs/src/piccolo/query_types/count.rst b/docs/src/piccolo/query_types/count.rst index 033667c18..794d125bc 100644 --- a/docs/src/piccolo/query_types/count.rst +++ b/docs/src/piccolo/query_types/count.rst @@ -15,7 +15,7 @@ It's equivalent to this ``select`` query: .. code-block:: python - from piccolo.query.methods.select import Count + from piccolo.query.functions.aggregate import Count >>> response = await Band.select(Count()) >>> response[0]['count'] diff --git a/docs/src/piccolo/query_types/select.rst b/docs/src/piccolo/query_types/select.rst index 1591e3580..092291e4c 100644 --- a/docs/src/piccolo/query_types/select.rst +++ b/docs/src/piccolo/query_types/select.rst @@ -165,6 +165,31 @@ convenient. ------------------------------------------------------------------------------- +String functions +---------------- + +Piccolo has lots of string functions built-in. See +``piccolo/query/functions/string.py``. Here's an example using ``Upper``, to +convert values to uppercase: + +.. code-block:: python + + from piccolo.query.functions.string import Upper + + >> await Band.select(Upper(Band.name, alias='name')) + [{'name': 'PYTHONISTAS'}, ...] + +You can also use these within where clauses: + +.. code-block:: python + + from piccolo.query.functions.string import Upper + + >> await Band.select(Band.name).where(Upper(Band.manager.name) == 'GUIDO') + [{'name': 'Pythonistas'}] + +------------------------------------------------------------------------------- + .. _AggregateFunctions: Aggregate functions @@ -182,7 +207,7 @@ Returns the number of matching rows. .. code-block:: python - from piccolo.query.methods.select import Count + from piccolo.query.functions.aggregate import Count >> await Band.select(Count()).where(Band.popularity > 100) [{'count': 3}] @@ -196,7 +221,7 @@ Returns the average for a given column: .. code-block:: python - >>> from piccolo.query import Avg + >>> from piccolo.query.functions.aggregate import Avg >>> response = await Band.select(Avg(Band.popularity)).first() >>> response["avg"] 750.0 @@ -208,7 +233,7 @@ Returns the sum for a given column: .. code-block:: python - >>> from piccolo.query import Sum + >>> from piccolo.query.functions.aggregate import Sum >>> response = await Band.select(Sum(Band.popularity)).first() >>> response["sum"] 1500 @@ -220,7 +245,7 @@ Returns the maximum for a given column: .. code-block:: python - >>> from piccolo.query import Max + >>> from piccolo.query.functions.aggregate import Max >>> response = await Band.select(Max(Band.popularity)).first() >>> response["max"] 1000 @@ -232,7 +257,7 @@ Returns the minimum for a given column: .. code-block:: python - >>> from piccolo.query import Min + >>> from piccolo.query.functions.aggregate import Min >>> response = await Band.select(Min(Band.popularity)).first() >>> response["min"] 500 @@ -244,7 +269,7 @@ You also can have multiple different aggregate functions in one query: .. code-block:: python - >>> from piccolo.query import Avg, Sum + >>> from piccolo.query.functions.aggregate import Avg, Sum >>> response = await Band.select( ... Avg(Band.popularity), ... Sum(Band.popularity) diff --git a/piccolo/columns/base.py b/piccolo/columns/base.py index 886a0ee48..ce452260c 100644 --- a/piccolo/columns/base.py +++ b/piccolo/columns/base.py @@ -6,7 +6,6 @@ import inspect import typing as t import uuid -from abc import ABCMeta, abstractmethod from dataclasses import dataclass, field, fields from enum import Enum @@ -32,6 +31,7 @@ NotLike, ) from piccolo.columns.reference import LazyTableReference +from piccolo.querystring import QueryString, Selectable from piccolo.utils.warnings import colored_warning if t.TYPE_CHECKING: # pragma: no cover @@ -205,7 +205,6 @@ def table(self) -> t.Type[Table]: # Used by Foreign Keys: call_chain: t.List["ForeignKey"] = field(default_factory=list) - table_alias: t.Optional[str] = None ########################################################################### @@ -260,7 +259,7 @@ def _get_path(self, include_quotes: bool = False): column_name = self.db_column_name if self.call_chain: - table_alias = self.call_chain[-1]._meta.table_alias + table_alias = self.call_chain[-1].table_alias if include_quotes: return f'"{table_alias}"."{column_name}"' else: @@ -272,7 +271,9 @@ def _get_path(self, include_quotes: bool = False): return f"{self.table._meta.tablename}.{column_name}" def get_full_name( - self, with_alias: bool = True, include_quotes: bool = True + self, + with_alias: bool = True, + include_quotes: bool = True, ) -> str: """ Returns the full column name, taking into account joins. @@ -302,11 +303,10 @@ def get_full_name( >>> column._meta.get_full_name(include_quotes=False) 'my_table_name.my_column_name' - """ full_name = self._get_path(include_quotes=include_quotes) - if with_alias and self.call_chain: + if with_alias: alias = self.get_default_alias() if include_quotes: full_name += f' AS "{alias}"' @@ -346,32 +346,6 @@ def __deepcopy__(self, memo) -> ColumnMeta: return self.copy() -class Selectable(metaclass=ABCMeta): - """ - Anything which inherits from this can be used in a select query. - """ - - _alias: t.Optional[str] - - @abstractmethod - def get_select_string( - self, engine_type: str, with_alias: bool = True - ) -> str: - """ - In a query, what to output after the select statement - could be a - column name, a sub query, a function etc. For a column it will be the - column name. - """ - raise NotImplementedError() - - def as_alias(self, alias: str) -> Selectable: - """ - Allows column names to be changed in the result of a select. - """ - self._alias = alias - return self - - class Column(Selectable): """ All other columns inherit from ``Column``. Don't use it directly. @@ -822,25 +796,32 @@ def get_default_value(self) -> t.Any: def get_select_string( self, engine_type: str, with_alias: bool = True - ) -> str: + ) -> QueryString: """ How to refer to this column in a SQL query, taking account of any joins and aliases. """ + if with_alias: if self._alias: original_name = self._meta.get_full_name( with_alias=False, ) - return f'{original_name} AS "{self._alias}"' + return QueryString(f'{original_name} AS "{self._alias}"') else: - return self._meta.get_full_name( - with_alias=True, + return QueryString( + self._meta.get_full_name( + with_alias=True, + ) ) - return self._meta.get_full_name(with_alias=False) + return QueryString( + self._meta.get_full_name( + with_alias=False, + ) + ) - def get_where_string(self, engine_type: str) -> str: + def get_where_string(self, engine_type: str) -> QueryString: return self.get_select_string( engine_type=engine_type, with_alias=False ) @@ -902,6 +883,13 @@ def get_sql_value(self, value: t.Any) -> t.Any: def column_type(self): return self.__class__.__name__.upper() + @property + def table_alias(self) -> str: + return "$".join( + f"{_key._meta.table._meta.tablename}${_key._meta.name}" + for _key in [*self._meta.call_chain, self] + ) + @property def ddl(self) -> str: """ @@ -945,8 +933,8 @@ def ddl(self) -> str: return query - def copy(self) -> Column: - column: Column = copy.copy(self) + def copy(self: Self) -> Self: + column = copy.copy(self) column._meta = self._meta.copy() return column @@ -971,3 +959,6 @@ def __repr__(self): f"{table_class_name}.{self._meta.name} - " f"{self.__class__.__name__}" ) + + +Self = t.TypeVar("Self", bound=Column) diff --git a/piccolo/columns/column_types.py b/piccolo/columns/column_types.py index 2afcfb741..add0c6f5c 100644 --- a/piccolo/columns/column_types.py +++ b/piccolo/columns/column_types.py @@ -60,7 +60,7 @@ class Band(Table): from piccolo.columns.operators.comparison import ArrayAll, ArrayAny from piccolo.columns.operators.string import Concat from piccolo.columns.reference import LazyTableReference -from piccolo.querystring import QueryString, Unquoted +from piccolo.querystring import QueryString from piccolo.utils.encoding import dump_json from piccolo.utils.warnings import colored_warning @@ -752,8 +752,8 @@ def __set__(self, obj, value: t.Union[int, None]): ############################################################################### -DEFAULT = Unquoted("DEFAULT") -NULL = Unquoted("null") +DEFAULT = QueryString("DEFAULT") +NULL = QueryString("null") class Serial(Column): @@ -778,7 +778,7 @@ def default(self): if engine_type == "postgres": return DEFAULT elif engine_type == "cockroach": - return Unquoted("unique_rowid()") + return QueryString("unique_rowid()") elif engine_type == "sqlite": return NULL raise Exception("Unrecognized engine type") @@ -2194,6 +2194,7 @@ def __getattribute__(self, name: str) -> t.Union[Column, t.Any]: column_meta: ColumnMeta = object.__getattribute__(self, "_meta") new_column._meta.call_chain = column_meta.call_chain.copy() + new_column._meta.call_chain.append(self) return new_column else: @@ -2311,7 +2312,7 @@ def arrow(self, key: str) -> JSONB: def get_select_string( self, engine_type: str, with_alias: bool = True - ) -> str: + ) -> QueryString: select_string = self._meta.get_full_name(with_alias=False) if self.json_operator is not None: @@ -2321,7 +2322,7 @@ def get_select_string( alias = self._alias or self._meta.get_default_alias() select_string += f' AS "{alias}"' - return select_string + return QueryString(select_string) def eq(self, value) -> Where: """ @@ -2616,7 +2617,9 @@ def __getitem__(self, value: int) -> Array: else: raise ValueError("Only integers can be used for indexing.") - def get_select_string(self, engine_type: str, with_alias=True) -> str: + def get_select_string( + self, engine_type: str, with_alias=True + ) -> QueryString: select_string = self._meta.get_full_name(with_alias=False) if isinstance(self.index, int): @@ -2626,7 +2629,7 @@ def get_select_string(self, engine_type: str, with_alias=True) -> str: alias = self._alias or self._meta.get_default_alias() select_string += f' AS "{alias}"' - return select_string + return QueryString(select_string) def any(self, value: t.Any) -> Where: """ diff --git a/piccolo/columns/m2m.py b/piccolo/columns/m2m.py index 0eefd22e7..90469fc1f 100644 --- a/piccolo/columns/m2m.py +++ b/piccolo/columns/m2m.py @@ -4,7 +4,6 @@ import typing as t from dataclasses import dataclass -from piccolo.columns.base import Selectable from piccolo.columns.column_types import ( JSON, JSONB, @@ -12,6 +11,7 @@ ForeignKey, LazyTableReference, ) +from piccolo.querystring import QueryString, Selectable from piccolo.utils.list import flatten from piccolo.utils.sync import run_sync @@ -56,7 +56,9 @@ def __init__( for column in columns ) - def get_select_string(self, engine_type: str, with_alias=True) -> str: + def get_select_string( + self, engine_type: str, with_alias=True + ) -> QueryString: m2m_table_name_with_schema = ( self.m2m._meta.resolved_joining_table._meta.get_formatted_tablename() # noqa: E501 ) # noqa: E501 @@ -90,28 +92,33 @@ def get_select_string(self, engine_type: str, with_alias=True) -> str: if engine_type in ("postgres", "cockroach"): if self.as_list: column_name = self.columns[0]._meta.db_column_name - return f""" + return QueryString( + f""" ARRAY( SELECT "inner_{table_2_name}"."{column_name}" FROM {inner_select} ) AS "{m2m_relationship_name}" """ + ) elif not self.serialisation_safe: column_name = table_2_pk_name - return f""" + return QueryString( + f""" ARRAY( SELECT "inner_{table_2_name}"."{column_name}" FROM {inner_select} ) AS "{m2m_relationship_name}" """ + ) else: column_names = ", ".join( f'"inner_{table_2_name}"."{column._meta.db_column_name}"' for column in self.columns ) - return f""" + return QueryString( + f""" ( SELECT JSON_AGG({m2m_relationship_name}_results) FROM ( @@ -119,13 +126,15 @@ def get_select_string(self, engine_type: str, with_alias=True) -> str: ) AS "{m2m_relationship_name}_results" ) AS "{m2m_relationship_name}" """ + ) elif engine_type == "sqlite": if len(self.columns) > 1 or not self.serialisation_safe: column_name = table_2_pk_name else: column_name = self.columns[0]._meta.db_column_name - return f""" + return QueryString( + f""" ( SELECT group_concat( "inner_{table_2_name}"."{column_name}" @@ -134,6 +143,7 @@ def get_select_string(self, engine_type: str, with_alias=True) -> str: ) AS "{m2m_relationship_name} [M2M]" """ + ) else: raise ValueError(f"{engine_type} is an unrecognised engine type") diff --git a/piccolo/columns/readable.py b/piccolo/columns/readable.py index 2748648d8..ebd32bf51 100644 --- a/piccolo/columns/readable.py +++ b/piccolo/columns/readable.py @@ -3,7 +3,7 @@ import typing as t from dataclasses import dataclass -from piccolo.columns.base import Selectable +from piccolo.querystring import QueryString, Selectable if t.TYPE_CHECKING: # pragma: no cover from piccolo.columns.base import Column @@ -27,25 +27,27 @@ def _columns_string(self) -> str: i._meta.get_full_name(with_alias=False) for i in self.columns ) - def _get_string(self, operator: str) -> str: - return ( + def _get_string(self, operator: str) -> QueryString: + return QueryString( f"{operator}('{self.template}', {self._columns_string}) AS " f"{self.output_name}" ) @property - def sqlite_string(self) -> str: + def sqlite_string(self) -> QueryString: return self._get_string(operator="PRINTF") @property - def postgres_string(self) -> str: + def postgres_string(self) -> QueryString: return self._get_string(operator="FORMAT") @property - def cockroach_string(self) -> str: + def cockroach_string(self) -> QueryString: return self._get_string(operator="FORMAT") - def get_select_string(self, engine_type: str, with_alias=True) -> str: + def get_select_string( + self, engine_type: str, with_alias=True + ) -> QueryString: try: return getattr(self, f"{engine_type}_string") except AttributeError as e: diff --git a/piccolo/query/__init__.py b/piccolo/query/__init__.py index 000a47e76..2fcc2df7e 100644 --- a/piccolo/query/__init__.py +++ b/piccolo/query/__init__.py @@ -1,9 +1,9 @@ from piccolo.columns.combination import WhereRaw from .base import Query +from .functions.aggregate import Avg, Max, Min, Sum from .methods import ( Alter, - Avg, Count, Create, CreateIndex, @@ -11,12 +11,9 @@ DropIndex, Exists, Insert, - Max, - Min, Objects, Raw, Select, - Sum, TableExists, Update, ) diff --git a/piccolo/query/functions/__init__.py b/piccolo/query/functions/__init__.py new file mode 100644 index 000000000..d0195cc40 --- /dev/null +++ b/piccolo/query/functions/__init__.py @@ -0,0 +1,16 @@ +from .aggregate import Avg, Count, Max, Min, Sum +from .string import Length, Lower, Ltrim, Reverse, Rtrim, Upper + +__all__ = ( + "Avg", + "Count", + "Length", + "Lower", + "Ltrim", + "Max", + "Min", + "Reverse", + "Rtrim", + "Sum", + "Upper", +) diff --git a/piccolo/query/functions/aggregate.py b/piccolo/query/functions/aggregate.py new file mode 100644 index 000000000..61dd36a46 --- /dev/null +++ b/piccolo/query/functions/aggregate.py @@ -0,0 +1,179 @@ +import typing as t + +from piccolo.columns.base import Column +from piccolo.querystring import QueryString + +from .base import Function + + +class Avg(Function): + """ + ``AVG()`` SQL function. Column type must be numeric to run the query. + + .. code-block:: python + + await Band.select(Avg(Band.popularity)).run() + + # We can use an alias. These two are equivalent: + + await Band.select( + Avg(Band.popularity, alias="popularity_avg") + ).run() + + await Band.select( + Avg(Band.popularity).as_alias("popularity_avg") + ).run() + + """ + + function_name = "AVG" + + +class Count(QueryString): + """ + Used in ``Select`` queries, usually in conjunction with the ``group_by`` + clause:: + + >>> await Band.select( + ... Band.manager.name.as_alias('manager_name'), + ... Count(alias='band_count') + ... ).group_by(Band.manager) + [{'manager_name': 'Guido', 'count': 1}, ...] + + It can also be used without the ``group_by`` clause (though you may prefer + to the :meth:`Table.count ` method instead, as + it's more convenient):: + + >>> await Band.select(Count()) + [{'count': 3}] + + """ + + def __init__( + self, + column: t.Optional[Column] = None, + distinct: t.Optional[t.Sequence[Column]] = None, + alias: str = "count", + ): + """ + :param column: + If specified, the count is for non-null values in that column. + :param distinct: + If specified, the count is for distinct values in those columns. + :param alias: + The name of the value in the response:: + + # These two are equivalent: + + await Band.select( + Band.name, Count(alias="total") + ).group_by(Band.name) + + await Band.select( + Band.name, + Count().as_alias("total") + ).group_by(Band.name) + + """ + if distinct and column: + raise ValueError("Only specify `column` or `distinct`") + + if distinct: + engine_type = distinct[0]._meta.engine_type + if engine_type == "sqlite": + # SQLite doesn't allow us to specify multiple columns, so + # instead we concatenate the values. + column_names = " || ".join("{}" for _ in distinct) + else: + column_names = ", ".join("{}" for _ in distinct) + + return super().__init__( + f"COUNT(DISTINCT({column_names}))", *distinct, alias=alias + ) + else: + if column: + return super().__init__("COUNT({})", column, alias=alias) + else: + return super().__init__("COUNT(*)", alias=alias) + + +class Min(Function): + """ + ``MIN()`` SQL function. + + .. code-block:: python + + await Band.select(Min(Band.popularity)).run() + + # We can use an alias. These two are equivalent: + + await Band.select( + Min(Band.popularity, alias="popularity_min") + ).run() + + await Band.select( + Min(Band.popularity).as_alias("popularity_min") + ).run() + + """ + + function_name = "MIN" + + +class Max(Function): + """ + ``MAX()`` SQL function. + + .. code-block:: python + + await Band.select( + Max(Band.popularity) + ).run() + + # We can use an alias. These two are equivalent: + + await Band.select( + Max(Band.popularity, alias="popularity_max") + ).run() + + await Band.select( + Max(Band.popularity).as_alias("popularity_max") + ).run() + + """ + + function_name = "MAX" + + +class Sum(Function): + """ + ``SUM()`` SQL function. Column type must be numeric to run the query. + + .. code-block:: python + + await Band.select( + Sum(Band.popularity) + ).run() + + # We can use an alias. These two are equivalent: + + await Band.select( + Sum(Band.popularity, alias="popularity_sum") + ).run() + + await Band.select( + Sum(Band.popularity).as_alias("popularity_sum") + ).run() + + """ + + function_name = "SUM" + + +__all__ = ( + "Avg", + "Count", + "Min", + "Max", + "Sum", +) diff --git a/piccolo/query/functions/base.py b/piccolo/query/functions/base.py new file mode 100644 index 000000000..c4181aca6 --- /dev/null +++ b/piccolo/query/functions/base.py @@ -0,0 +1,21 @@ +import typing as t + +from piccolo.columns.base import Column +from piccolo.querystring import QueryString + + +class Function(QueryString): + function_name: str + + def __init__( + self, + identifier: t.Union[Column, QueryString, str], + alias: t.Optional[str] = None, + ): + alias = alias or self.__class__.__name__.lower() + + super().__init__( + f"{self.function_name}({{}})", + identifier, + alias=alias, + ) diff --git a/piccolo/query/functions/string.py b/piccolo/query/functions/string.py new file mode 100644 index 000000000..556817a12 --- /dev/null +++ b/piccolo/query/functions/string.py @@ -0,0 +1,73 @@ +""" +These functions mirror their counterparts in the Postgresql docs: + +https://www.postgresql.org/docs/current/functions-string.html + +""" + +from .base import Function + + +class Length(Function): + """ + Returns the number of characters in the string. + """ + + function_name = "LENGTH" + + +class Lower(Function): + """ + Converts the string to all lower case, according to the rules of the + database's locale. + """ + + function_name = "LOWER" + + +class Ltrim(Function): + """ + Removes the longest string containing only characters in characters (a + space by default) from the start of string. + """ + + function_name = "LTRIM" + + +class Reverse(Function): + """ + Return reversed string. + + Not supported in SQLite. + + """ + + function_name = "REVERSE" + + +class Rtrim(Function): + """ + Removes the longest string containing only characters in characters (a + space by default) from the end of string. + """ + + function_name = "RTRIM" + + +class Upper(Function): + """ + Converts the string to all upper case, according to the rules of the + database's locale. + """ + + function_name = "UPPER" + + +__all__ = ( + "Length", + "Lower", + "Ltrim", + "Reverse", + "Rtrim", + "Upper", +) diff --git a/piccolo/query/methods/__init__.py b/piccolo/query/methods/__init__.py index 6c1854381..f4b9a59f1 100644 --- a/piccolo/query/methods/__init__.py +++ b/piccolo/query/methods/__init__.py @@ -9,6 +9,23 @@ from .objects import Objects from .raw import Raw from .refresh import Refresh -from .select import Avg, Max, Min, Select, Sum +from .select import Select from .table_exists import TableExists from .update import Update + +__all__ = ( + "Alter", + "Count", + "Create", + "CreateIndex", + "Delete", + "DropIndex", + "Exists", + "Insert", + "Objects", + "Raw", + "Refresh", + "Select", + "TableExists", + "Update", +) diff --git a/piccolo/query/methods/count.py b/piccolo/query/methods/count.py index fdd0972cf..99d46c39b 100644 --- a/piccolo/query/methods/count.py +++ b/piccolo/query/methods/count.py @@ -4,7 +4,7 @@ from piccolo.custom_types import Combinable from piccolo.query.base import Query -from piccolo.query.methods.select import Count as SelectCount +from piccolo.query.functions.aggregate import Count as CountFunction from piccolo.query.mixins import WhereDelegate from piccolo.querystring import QueryString @@ -32,7 +32,7 @@ def __init__( ########################################################################### # Clauses - def where(self: Self, *where: Combinable) -> Self: + def where(self: Self, *where: t.Union[Combinable, QueryString]) -> Self: self.where_delegate.where(*where) return self @@ -50,7 +50,7 @@ def default_querystrings(self) -> t.Sequence[QueryString]: table: t.Type[Table] = self.table query = table.select( - SelectCount(column=self.column, distinct=self._distinct) + CountFunction(column=self.column, distinct=self._distinct) ) query.where_delegate._where = self.where_delegate._where diff --git a/piccolo/query/methods/delete.py b/piccolo/query/methods/delete.py index bc0746063..628b89b8e 100644 --- a/piccolo/query/methods/delete.py +++ b/piccolo/query/methods/delete.py @@ -30,7 +30,7 @@ def __init__(self, table: t.Type[Table], force: bool = False, **kwargs): self.returning_delegate = ReturningDelegate() self.where_delegate = WhereDelegate() - def where(self: Self, *where: Combinable) -> Self: + def where(self: Self, *where: t.Union[Combinable, QueryString]) -> Self: self.where_delegate.where(*where) return self diff --git a/piccolo/query/methods/exists.py b/piccolo/query/methods/exists.py index 26d25e03e..7fac83a75 100644 --- a/piccolo/query/methods/exists.py +++ b/piccolo/query/methods/exists.py @@ -16,7 +16,7 @@ def __init__(self, table: t.Type[TableInstance], **kwargs): super().__init__(table, **kwargs) self.where_delegate = WhereDelegate() - def where(self: Self, *where: Combinable) -> Self: + def where(self: Self, *where: t.Union[Combinable, QueryString]) -> Self: self.where_delegate.where(*where) return self diff --git a/piccolo/query/methods/objects.py b/piccolo/query/methods/objects.py index 7b8c3ad43..f11f78e8e 100644 --- a/piccolo/query/methods/objects.py +++ b/piccolo/query/methods/objects.py @@ -262,7 +262,7 @@ def order_by( self.order_by_delegate.order_by(*_columns, ascending=ascending) return self - def where(self: Self, *where: Combinable) -> Self: + def where(self: Self, *where: t.Union[Combinable, QueryString]) -> Self: self.where_delegate.where(*where) return self diff --git a/piccolo/query/methods/select.py b/piccolo/query/methods/select.py index a2a77b155..fdb929f8a 100644 --- a/piccolo/query/methods/select.py +++ b/piccolo/query/methods/select.py @@ -1,6 +1,5 @@ from __future__ import annotations -import decimal import itertools import typing as t from collections import OrderedDict @@ -36,9 +35,8 @@ from piccolo.custom_types import Combinable from piccolo.table import Table # noqa - -def is_numeric_column(column: Column) -> bool: - return column.value_type in (int, decimal.Decimal, float) +# Here to avoid breaking changes - will be removed in the future. +from piccolo.query.functions.aggregate import Count # noqa: F401 class SelectRaw(Selectable): @@ -59,224 +57,8 @@ def __init__(self, sql: str, *args: t.Any) -> None: def get_select_string( self, engine_type: str, with_alias: bool = True - ) -> str: - return self.querystring.__str__() - - -class Avg(Selectable): - """ - ``AVG()`` SQL function. Column type must be numeric to run the query. - - .. code-block:: python - - await Band.select(Avg(Band.popularity)).run() - - # We can use an alias. These two are equivalent: - - await Band.select( - Avg(Band.popularity, alias="popularity_avg") - ).run() - - await Band.select( - Avg(Band.popularity).as_alias("popularity_avg") - ).run() - - """ - - def __init__(self, column: Column, alias: str = "avg"): - if is_numeric_column(column): - self.column = column - else: - raise ValueError("Column type must be numeric to run the query.") - self._alias = alias - - def get_select_string( - self, engine_type: str, with_alias: bool = True - ) -> str: - column_name = self.column._meta.get_full_name(with_alias=False) - return f'AVG({column_name}) AS "{self._alias}"' - - -class Count(Selectable): - """ - Used in ``Select`` queries, usually in conjunction with the ``group_by`` - clause:: - - >>> await Band.select( - ... Band.manager.name.as_alias('manager_name'), - ... Count(alias='band_count') - ... ).group_by(Band.manager) - [{'manager_name': 'Guido', 'count': 1}, ...] - - It can also be used without the ``group_by`` clause (though you may prefer - to the :meth:`Table.count ` method instead, as - it's more convenient):: - - >>> await Band.select(Count()) - [{'count': 3}] - - """ - - def __init__( - self, - column: t.Optional[Column] = None, - distinct: t.Optional[t.Sequence[Column]] = None, - alias: str = "count", - ): - """ - :param column: - If specified, the count is for non-null values in that column. - :param distinct: - If specified, the count is for distinct values in those columns. - :param alias: - The name of the value in the response:: - - # These two are equivalent: - - await Band.select( - Band.name, Count(alias="total") - ).group_by(Band.name) - - await Band.select( - Band.name, - Count().as_alias("total") - ).group_by(Band.name) - - """ - if distinct and column: - raise ValueError("Only specify `column` or `distinct`") - - self.column = column - self.distinct = distinct - self._alias = alias - - def get_select_string( - self, engine_type: str, with_alias: bool = True - ) -> str: - expression: str - - if self.distinct: - if engine_type == "sqlite": - # SQLite doesn't allow us to specify multiple columns, so - # instead we concatenate the values. - column_names = " || ".join( - i._meta.get_full_name(with_alias=False) - for i in self.distinct - ) - else: - column_names = ", ".join( - i._meta.get_full_name(with_alias=False) - for i in self.distinct - ) - - expression = f"DISTINCT ({column_names})" - else: - if self.column: - expression = self.column._meta.get_full_name(with_alias=False) - else: - expression = "*" - - return f'COUNT({expression}) AS "{self._alias}"' - - -class Max(Selectable): - """ - ``MAX()`` SQL function. - - .. code-block:: python - - await Band.select( - Max(Band.popularity) - ).run() - - # We can use an alias. These two are equivalent: - - await Band.select( - Max(Band.popularity, alias="popularity_max") - ).run() - - await Band.select( - Max(Band.popularity).as_alias("popularity_max") - ).run() - - """ - - def __init__(self, column: Column, alias: str = "max"): - self.column = column - self._alias = alias - - def get_select_string( - self, engine_type: str, with_alias: bool = True - ) -> str: - column_name = self.column._meta.get_full_name(with_alias=False) - return f'MAX({column_name}) AS "{self._alias}"' - - -class Min(Selectable): - """ - ``MIN()`` SQL function. - - .. code-block:: python - - await Band.select(Min(Band.popularity)).run() - - # We can use an alias. These two are equivalent: - - await Band.select( - Min(Band.popularity, alias="popularity_min") - ).run() - - await Band.select( - Min(Band.popularity).as_alias("popularity_min") - ).run() - - """ - - def __init__(self, column: Column, alias: str = "min"): - self.column = column - self._alias = alias - - def get_select_string( - self, engine_type: str, with_alias: bool = True - ) -> str: - column_name = self.column._meta.get_full_name(with_alias=False) - return f'MIN({column_name}) AS "{self._alias}"' - - -class Sum(Selectable): - """ - ``SUM()`` SQL function. Column type must be numeric to run the query. - - .. code-block:: python - - await Band.select( - Sum(Band.popularity) - ).run() - - # We can use an alias. These two are equivalent: - - await Band.select( - Sum(Band.popularity, alias="popularity_sum") - ).run() - - await Band.select( - Sum(Band.popularity).as_alias("popularity_sum") - ).run() - - """ - - def __init__(self, column: Column, alias: str = "sum"): - if is_numeric_column(column): - self.column = column - else: - raise ValueError("Column type must be numeric to run the query.") - self._alias = alias - - def get_select_string( - self, engine_type: str, with_alias: bool = True - ) -> str: - column_name = self.column._meta.get_full_name(with_alias=False) - return f'SUM({column_name}) AS "{self._alias}"' + ) -> QueryString: + return self.querystring OptionalDict = t.Optional[t.Dict[str, t.Any]] @@ -645,7 +427,7 @@ def callback( self.callback_delegate.callback(callbacks, on=on) return self - def where(self: Self, *where: Combinable) -> Self: + def where(self: Self, *where: t.Union[Combinable, QueryString]) -> Self: self.where_delegate.where(*where) return self @@ -678,23 +460,25 @@ def _get_joins(self, columns: t.Sequence[Selectable]) -> t.List[str]: for readable in readables: columns += readable.columns + querystrings: t.List[QueryString] = [ + i for i in columns if isinstance(i, QueryString) + ] + for querystring in querystrings: + if querystring_columns := getattr(querystring, "columns", []): + columns += querystring_columns + for column in columns: if not isinstance(column, Column): continue _joins: t.List[str] = [] for index, key in enumerate(column._meta.call_chain, 0): - table_alias = "$".join( - f"{_key._meta.table._meta.tablename}${_key._meta.name}" - for _key in column._meta.call_chain[: index + 1] - ) - - key._meta.table_alias = table_alias + table_alias = key.table_alias if index > 0: left_tablename = column._meta.call_chain[ index - 1 - ]._meta.table_alias + ].table_alias else: left_tablename = ( key._meta.table._meta.get_formatted_tablename() @@ -761,11 +545,10 @@ def default_querystrings(self) -> t.Sequence[QueryString]: engine_type = self.table._meta.db.engine_type - select_strings: t.List[str] = [ + select_strings: t.List[QueryString] = [ c.get_select_string(engine_type=engine_type) for c in self.columns_delegate.selected_columns ] - columns_str = ", ".join(select_strings) ####################################################################### @@ -779,7 +562,9 @@ def default_querystrings(self) -> t.Sequence[QueryString]: query += "{}" args.append(distinct.querystring) + columns_str = ", ".join("{}" for i in select_strings) query += f" {columns_str} FROM {self.table._meta.get_formatted_tablename()}" # noqa: E501 + args.extend(select_strings) for join in joins: query += f" {join}" diff --git a/piccolo/query/methods/update.py b/piccolo/query/methods/update.py index ff6a10589..f75854c43 100644 --- a/piccolo/query/methods/update.py +++ b/piccolo/query/methods/update.py @@ -50,7 +50,7 @@ def values( self.values_delegate.values(values) return self - def where(self, *where: Combinable) -> Update: + def where(self, *where: t.Union[Combinable, QueryString]) -> Update: self.where_delegate.where(*where) return self diff --git a/piccolo/query/mixins.py b/piccolo/query/mixins.py index 8d7c6a4a9..214d1b8d7 100644 --- a/piccolo/query/mixins.py +++ b/piccolo/query/mixins.py @@ -9,13 +9,14 @@ from piccolo.columns import And, Column, Or, Where from piccolo.columns.column_types import ForeignKey +from piccolo.columns.combination import WhereRaw from piccolo.custom_types import Combinable from piccolo.querystring import QueryString from piccolo.utils.list import flatten from piccolo.utils.sql_values import convert_to_sql_value if t.TYPE_CHECKING: # pragma: no cover - from piccolo.columns.base import Selectable + from piccolo.querystring import Selectable from piccolo.table import Table # noqa @@ -254,8 +255,10 @@ def _extract_columns(self, combinable: Combinable): elif isinstance(combinable, (And, Or)): self._extract_columns(combinable.first) self._extract_columns(combinable.second) + elif isinstance(combinable, WhereRaw): + self._where_columns.extend(combinable.querystring.columns) - def where(self, *where: Combinable): + def where(self, *where: t.Union[Combinable, QueryString]): for arg in where: if isinstance(arg, bool): raise ValueError( @@ -265,6 +268,10 @@ def where(self, *where: Combinable): "`.where(MyTable.some_column.is_null())`." ) + if isinstance(arg, QueryString): + # If a raw QueryString is passed in. + arg = WhereRaw(arg.template, *arg.args) + self._where = And(self._where, arg) if self._where else arg diff --git a/piccolo/querystring.py b/piccolo/querystring.py index 3c23d86dc..7f3f3e42a 100644 --- a/piccolo/querystring.py +++ b/piccolo/querystring.py @@ -1,6 +1,7 @@ from __future__ import annotations import typing as t +from abc import ABCMeta, abstractmethod from dataclasses import dataclass from datetime import datetime from importlib.util import find_spec @@ -8,6 +9,7 @@ if t.TYPE_CHECKING: # pragma: no cover from piccolo.table import Table + from piccolo.columns import Column from uuid import UUID @@ -17,22 +19,32 @@ apgUUID = UUID -@dataclass -class Unquoted: +class Selectable(metaclass=ABCMeta): """ - Used when we want the value to be unquoted because it's a Postgres - keyword - for example DEFAULT. + Anything which inherits from this can be used in a select query. """ - __slots__ = ("value",) + __slots__ = ("_alias",) - value: str + _alias: t.Optional[str] - def __repr__(self): - return f"{self.value}" + @abstractmethod + def get_select_string( + self, engine_type: str, with_alias: bool = True + ) -> QueryString: + """ + In a query, what to output after the select statement - could be a + column name, a sub query, a function etc. For a column it will be the + column name. + """ + raise NotImplementedError() - def __str__(self): - return f"{self.value}" + def as_alias(self, alias: str) -> Selectable: + """ + Allows column names to be changed in the result of a select. + """ + self._alias = alias + return self @dataclass @@ -42,7 +54,7 @@ class Fragment: no_arg: bool = False -class QueryString: +class QueryString(Selectable): """ When we're composing complex queries, we're combining QueryStrings, rather than concatenating strings directly. The reason for this is QueryStrings @@ -56,6 +68,7 @@ class QueryString: "query_type", "table", "_frozen_compiled_strings", + "columns", ) def __init__( @@ -64,6 +77,7 @@ def __init__( *args: t.Any, query_type: str = "generic", table: t.Optional[t.Type[Table]] = None, + alias: t.Optional[str] = None, ) -> None: """ :param template: @@ -83,12 +97,42 @@ def __init__( """ self.template = template - self.args = args self.query_type = query_type self.table = table self._frozen_compiled_strings: t.Optional[ t.Tuple[str, t.List[t.Any]] ] = None + self._alias = alias + self.args, self.columns = self.process_args(args) + + def process_args( + self, args: t.Sequence[t.Any] + ) -> t.Tuple[t.Sequence[t.Any], t.Sequence[Column]]: + """ + If a Column is passed in, we convert it to the name of the column + (including joins). + """ + from piccolo.columns import Column + + processed_args = [] + columns = [] + + for arg in args: + if isinstance(arg, Column): + columns.append(arg) + arg = QueryString( + f"{arg._meta.get_full_name(with_alias=False)}" + ) + elif isinstance(arg, QueryString): + columns.extend(arg.columns) + + processed_args.append(arg) + + return (processed_args, columns) + + def as_alias(self, alias: str) -> QueryString: + self._alias = alias + return self def __str__(self): """ @@ -143,7 +187,7 @@ def bundle( fragment.no_arg = True bundled.append(fragment) else: - if isinstance(value, self.__class__): + if isinstance(value, QueryString): fragment.no_arg = True bundled.append(fragment) @@ -195,3 +239,47 @@ def freeze(self, engine_type: str = "postgres"): self._frozen_compiled_strings = self.compile_string( engine_type=engine_type ) + + ########################################################################### + + def get_select_string( + self, engine_type: str, with_alias: bool = True + ) -> QueryString: + if with_alias and self._alias: + return QueryString("{} AS " + self._alias, self) + else: + return self + + def get_where_string(self, engine_type: str) -> QueryString: + return self.get_select_string( + engine_type=engine_type, with_alias=False + ) + + ########################################################################### + # Basic logic + + def __eq__(self, value) -> QueryString: # type: ignore[override] + return QueryString("{} = {}", self, value) + + def __ne__(self, value) -> QueryString: # type: ignore[override] + return QueryString("{} != {}", self, value) + + def __add__(self, value) -> QueryString: + return QueryString("{} + {}", self, value) + + def __sub__(self, value) -> QueryString: + return QueryString("{} - {}", self, value) + + def is_in(self, value) -> QueryString: + return QueryString("{} IN {}", self, value) + + def not_in(self, value) -> QueryString: + return QueryString("{} NOT IN {}", self, value) + + +class Unquoted(QueryString): + """ + This is deprecated - just use QueryString directly. + """ + + pass diff --git a/piccolo/table.py b/piccolo/table.py index b4fcbf942..7882db95e 100644 --- a/piccolo/table.py +++ b/piccolo/table.py @@ -48,7 +48,7 @@ from piccolo.query.methods.indexes import Indexes from piccolo.query.methods.objects import First from piccolo.query.methods.refresh import Refresh -from piccolo.querystring import QueryString, Unquoted +from piccolo.querystring import QueryString from piccolo.utils import _camel_to_snake from piccolo.utils.graphlib import TopologicalSorter from piccolo.utils.sql_values import convert_to_sql_value @@ -56,7 +56,7 @@ from piccolo.utils.warnings import colored_warning if t.TYPE_CHECKING: # pragma: no cover - from piccolo.columns import Selectable + from piccolo.querystring import Selectable PROTECTED_TABLENAMES = ("user",) TABLENAME_WARNING = ( @@ -796,30 +796,14 @@ def querystring(self) -> QueryString: """ Used when inserting rows. """ - args_dict = {} - for col in self._meta.columns: - column_name = col._meta.name - value = convert_to_sql_value(value=self[column_name], column=col) - args_dict[column_name] = value - - def is_unquoted(arg): - return isinstance(arg, Unquoted) - - # Strip out any args which are unquoted. - filtered_args = [i for i in args_dict.values() if not is_unquoted(i)] + args = [ + convert_to_sql_value(value=self[column._meta.name], column=column) + for column in self._meta.columns + ] # If unquoted, dump it straight into the query. - query = ",".join( - [ - ( - args_dict[column._meta.name].value - if is_unquoted(args_dict[column._meta.name]) - else "{}" - ) - for column in self._meta.columns - ] - ) - return QueryString(f"({query})", *filtered_args) + query = ",".join(["{}" for _ in args]) + return QueryString(f"({query})", *args) def __str__(self) -> str: return self.querystring.__str__() diff --git a/tests/query/test_functions.py b/tests/query/test_functions.py new file mode 100644 index 000000000..abe9a5f01 --- /dev/null +++ b/tests/query/test_functions.py @@ -0,0 +1,102 @@ +from unittest import TestCase + +from piccolo.query.functions.string import Reverse, Upper +from piccolo.querystring import QueryString +from piccolo.table import create_db_tables_sync, drop_db_tables_sync +from tests.base import engines_skip +from tests.example_apps.music.tables import Band, Manager + + +class FunctionTest(TestCase): + tables = (Band, Manager) + + def setUp(self) -> None: + create_db_tables_sync(*self.tables) + + manager = Manager({Manager.name: "Guido"}) + manager.save().run_sync() + + band = Band({Band.name: "Pythonistas", Band.manager: manager}) + band.save().run_sync() + + def tearDown(self) -> None: + drop_db_tables_sync(*self.tables) + + +class TestUpperFunction(FunctionTest): + + def test_column(self): + """ + Make sure we can uppercase a column's value. + """ + response = Band.select(Upper(Band.name)).run_sync() + self.assertListEqual(response, [{"upper": "PYTHONISTAS"}]) + + def test_alias(self): + response = Band.select(Upper(Band.name, alias="name")).run_sync() + self.assertListEqual(response, [{"name": "PYTHONISTAS"}]) + + def test_joined_column(self): + """ + Make sure we can uppercase a column's value from a joined table. + """ + response = Band.select(Upper(Band.manager._.name)).run_sync() + self.assertListEqual(response, [{"upper": "GUIDO"}]) + + +@engines_skip("sqlite") +class TestNested(FunctionTest): + """ + Skip the the test for SQLite, as it doesn't support ``Reverse``. + """ + + def test_nested(self): + """ + Make sure we can nest functions. + """ + response = Band.select(Upper(Reverse(Band.name))).run_sync() + self.assertListEqual(response, [{"upper": "SATSINOHTYP"}]) + + def test_nested_with_joined_column(self): + """ + Make sure nested functions can be used on a column from a joined table. + """ + response = Band.select(Upper(Reverse(Band.manager._.name))).run_sync() + self.assertListEqual(response, [{"upper": "ODIUG"}]) + + def test_nested_within_querystring(self): + """ + If we wrap a function in a custom QueryString - make sure the columns + are still accessible, so joins are successful. + """ + response = Band.select( + QueryString("CONCAT({}, '!')", Upper(Band.manager._.name)), + ).run_sync() + + self.assertListEqual(response, [{"concat": "GUIDO!"}]) + + +class TestWhereClause(FunctionTest): + + def test_where(self): + """ + Make sure where clauses work with functions. + """ + response = ( + Band.select(Band.name) + .where(Upper(Band.name) == "PYTHONISTAS") + .run_sync() + ) + self.assertListEqual(response, [{"name": "Pythonistas"}]) + + def test_where_with_joined_column(self): + """ + Make sure where clauses work with functions, when a joined column is + used. + """ + response = ( + Band.select(Band.name) + .where(Upper(Band.manager._.name) == "GUIDO") + .run_sync() + ) + self.assertListEqual(response, [{"name": "Pythonistas"}]) diff --git a/tests/table/test_select.py b/tests/table/test_select.py index a2bb86981..ebf2c3ff8 100644 --- a/tests/table/test_select.py +++ b/tests/table/test_select.py @@ -7,7 +7,8 @@ from piccolo.columns import Date, Varchar from piccolo.columns.combination import WhereRaw from piccolo.query import OrderByRaw -from piccolo.query.methods.select import Avg, Count, Max, Min, SelectRaw, Sum +from piccolo.query.functions.aggregate import Avg, Count, Max, Min, Sum +from piccolo.query.methods.select import SelectRaw from piccolo.query.mixins import DistinctOnError from piccolo.table import Table, create_db_tables_sync, drop_db_tables_sync from tests.base import ( @@ -927,14 +928,6 @@ def test_chain_different_functions_alias(self): self.assertEqual(float(response["popularity_avg"]), 1003.3333333333334) self.assertEqual(response["popularity_sum"], 3010) - def test_avg_validation(self): - with self.assertRaises(ValueError): - Band.select(Avg(Band.name)).run_sync() - - def test_sum_validation(self): - with self.assertRaises(ValueError): - Band.select(Sum(Band.name)).run_sync() - def test_columns(self): """ Make sure the colums method can be used to specify which columns to