diff --git a/pypika/clickhouse/array.py b/pypika/clickhouse/array.py index 40de67c5..20c422f6 100644 --- a/pypika/clickhouse/array.py +++ b/pypika/clickhouse/array.py @@ -1,5 +1,5 @@ import abc -from typing import Union +from typing import Union, Optional, TYPE_CHECKING from pypika.terms import ( Field, @@ -8,9 +8,14 @@ ) from pypika.utils import format_alias_sql +if TYPE_CHECKING: + from pypika.queries import Schema + class Array(Term): - def __init__(self, values: list, converter_cls=None, converter_options: dict = None, alias: str = None): + def __init__( + self, values: list, converter_cls=None, converter_options: Optional[dict] = None, alias: Optional[str] = None + ): super().__init__(alias) self._values = values self._converter_cls = converter_cls @@ -35,14 +40,14 @@ def __init__( self, left_array: Union[Array, Field], right_array: Union[Array, Field], - alias: str = None, - schema: str = None, + alias: Optional[str] = None, + schema: Optional["Schema"] = None, ): self._left_array = left_array self._right_array = right_array self.alias = alias self.schema = schema - self.args = tuple() + self.args = [] self.name = "hasAny" def get_sql(self, with_alias=False, with_namespace=False, quote_char=None, dialect=None, **kwargs): @@ -57,7 +62,7 @@ def get_sql(self, with_alias=False, with_namespace=False, quote_char=None, diale class _AbstractArrayFunction(Function, metaclass=abc.ABCMeta): - def __init__(self, array: Union[Array, Field], alias: str = None, schema: str = None): + def __init__(self, array: Union[Array, Field], alias: Optional[str] = None, schema: Optional["Schema"] = None): self.schema = schema self.alias = alias self.name = self.clickhouse_function() diff --git a/pypika/clickhouse/search_string.py b/pypika/clickhouse/search_string.py index 22a03027..5445f1b7 100644 --- a/pypika/clickhouse/search_string.py +++ b/pypika/clickhouse/search_string.py @@ -1,11 +1,13 @@ import abc +from typing import Optional + from pypika.terms import Function from pypika.utils import format_alias_sql class _AbstractSearchString(Function, metaclass=abc.ABCMeta): - def __init__(self, name, pattern: str, alias: str = None): + def __init__(self, name, pattern: str, alias: Optional[str] = None): super(_AbstractSearchString, self).__init__(self.clickhouse_function(), name, alias=alias) self._pattern = pattern @@ -50,7 +52,7 @@ def clickhouse_function(cls) -> str: class _AbstractMultiSearchString(Function, metaclass=abc.ABCMeta): - def __init__(self, name, patterns: list, alias: str = None): + def __init__(self, name, patterns: list, alias: Optional[str] = None): super(_AbstractMultiSearchString, self).__init__(self.clickhouse_function(), name, alias=alias) self._patterns = patterns diff --git a/pypika/clickhouse/type_conversion.py b/pypika/clickhouse/type_conversion.py index 80229b7e..a42ea6c6 100644 --- a/pypika/clickhouse/type_conversion.py +++ b/pypika/clickhouse/type_conversion.py @@ -2,22 +2,25 @@ Field, Function, ) +from pypika.queries import Schema from pypika.utils import format_alias_sql +from typing import Optional + class ToString(Function): - def __init__(self, name, alias: str = None): + def __init__(self, name, alias: Optional[str] = None): super(ToString, self).__init__("toString", name, alias=alias) class ToFixedString(Function): - def __init__(self, field, length: int, alias: str = None, schema: str = None): + def __init__(self, field, length: int, alias: Optional[str] = None, schema: Optional[Schema] = None): self._length = length self._field = field self.alias = alias self.name = "toFixedString" self.schema = schema - self.args = () + self.args = [] def get_sql(self, with_alias=False, with_namespace=False, quote_char=None, dialect=None, **kwargs): sql = "{name}({field},{length})".format( @@ -29,60 +32,60 @@ def get_sql(self, with_alias=False, with_namespace=False, quote_char=None, diale class ToInt8(Function): - def __init__(self, name, alias: str = None): + def __init__(self, name, alias: Optional[str] = None): super(ToInt8, self).__init__("toInt8", name, alias=alias) class ToInt16(Function): - def __init__(self, name, alias: str = None): + def __init__(self, name, alias: Optional[str] = None): super(ToInt16, self).__init__("toInt16", name, alias=alias) class ToInt32(Function): - def __init__(self, name, alias: str = None): + def __init__(self, name, alias: Optional[str] = None): super(ToInt32, self).__init__("toInt32", name, alias=alias) class ToInt64(Function): - def __init__(self, name, alias: str = None): + def __init__(self, name, alias: Optional[str] = None): super(ToInt64, self).__init__("toInt64", name, alias=alias) class ToUInt8(Function): - def __init__(self, name, alias: str = None): + def __init__(self, name, alias: Optional[str] = None): super(ToUInt8, self).__init__("toUInt8", name, alias=alias) class ToUInt16(Function): - def __init__(self, name, alias: str = None): + def __init__(self, name, alias: Optional[str] = None): super(ToUInt16, self).__init__("toUInt16", name, alias=alias) class ToUInt32(Function): - def __init__(self, name, alias: str = None): + def __init__(self, name, alias: Optional[str] = None): super(ToUInt32, self).__init__("toUInt32", name, alias=alias) class ToUInt64(Function): - def __init__(self, name, alias: str = None): + def __init__(self, name, alias: Optional[str] = None): super(ToUInt64, self).__init__("toUInt64", name, alias=alias) class ToFloat32(Function): - def __init__(self, name, alias: str = None): + def __init__(self, name, alias: Optional[str] = None): super(ToFloat32, self).__init__("toFloat32", name, alias=alias) class ToFloat64(Function): - def __init__(self, name, alias: str = None): + def __init__(self, name, alias: Optional[str] = None): super(ToFloat64, self).__init__("toFloat64", name, alias=alias) class ToDate(Function): - def __init__(self, name, alias: str = None): + def __init__(self, name, alias: Optional[str] = None): super(ToDate, self).__init__("toDate", name, alias=alias) class ToDateTime(Function): - def __init__(self, name, alias: str = None): + def __init__(self, name, alias: Optional[str] = None): super(ToDateTime, self).__init__("toDateTime", name, alias=alias) diff --git a/pypika/dialects.py b/pypika/dialects.py index ab3f1f20..5a2fb4e7 100644 --- a/pypika/dialects.py +++ b/pypika/dialects.py @@ -1,6 +1,9 @@ import itertools from copy import copy -from typing import Any, Iterable, List, NoReturn, Optional, Set, Union, Tuple as TypedTuple, cast +from typing import Any, Iterable, List, Optional, Set, Union, Tuple as TypedTuple, cast, TYPE_CHECKING + +if TYPE_CHECKING: + from typing_extensions import Self, NoReturn from pypika.enums import Dialects from pypika.queries import ( @@ -9,15 +12,24 @@ DropQueryBuilder, Selectable, Table, - Query, + BaseQuery, QueryBuilder, JoinOn, ) -from pypika.terms import ArithmeticExpression, Criterion, EmptyCriterion, Field, Function, Star, Term, ValueWrapper +from pypika.terms import ( + ArithmeticExpression, + Criterion, + EmptyCriterion, + Field, + Function, + Star, + Term, + ValueWrapper, +) from pypika.utils import QueryException, builder, format_quotes -class SnowflakeQuery(Query): +class SnowflakeQuery(BaseQuery["SnowflakeQueryBuilder"]): """ Defines a query class for use with Snowflake. """ @@ -61,7 +73,7 @@ def __init__(self) -> None: super().__init__(dialect=Dialects.SNOWFLAKE) -class MySQLQuery(Query): +class MySQLQuery(BaseQuery["MySQLQueryBuilder"]): """ Defines a query class for use with MySQL. """ @@ -97,8 +109,8 @@ def __init__(self, **kwargs: Any) -> None: self._for_update_skip_locked = False self._for_update_of: Set[str] = set() - def __copy__(self) -> "MySQLQueryBuilder": - newone = cast(MySQLQueryBuilder, super().__copy__()) + def __copy__(self) -> "Self": + newone = super().__copy__() newone._duplicate_updates = copy(self._duplicate_updates) newone._ignore_duplicates = copy(self._ignore_duplicates) return newone @@ -228,7 +240,7 @@ class MySQLDropQueryBuilder(DropQueryBuilder): QUOTE_CHAR = "`" -class VerticaQuery(Query): +class VerticaQuery(BaseQuery["VerticaQueryBuilder"]): """ Defines a query class for use with Vertica. """ @@ -350,7 +362,7 @@ def __str__(self) -> str: return self.get_sql() -class OracleQuery(Query): +class OracleQuery(BaseQuery["OracleQueryBuilder"]): """ Defines a query class for use with Oracle. """ @@ -374,7 +386,7 @@ def get_sql(self, *args: Any, **kwargs: Any) -> str: return super().get_sql(*args, **kwargs) -class PostgreSQLQuery(Query): +class PostgreSQLQuery(BaseQuery["PostgreSQLQueryBuilder"]): """ Defines a query class for use with PostgreSQL. """ @@ -406,8 +418,8 @@ def __init__(self, **kwargs: Any) -> None: self._for_update_skip_locked = False self._for_update_of: Set[str] = set() - def __copy__(self) -> "PostgreSQLQueryBuilder": - newone = cast(PostgreSQLQueryBuilder, super().__copy__()) + def __copy__(self) -> "Self": + newone = super().__copy__() newone._returns = copy(self._returns) newone._on_conflict_do_updates = copy(self._on_conflict_do_updates) return newone @@ -428,7 +440,7 @@ def for_update(self, nowait: bool = False, skip_locked: bool = False, of: TypedT self._for_update_of = set(of) @builder - def on_conflict(self, *target_fields: Union[str, Term]) -> None: + def on_conflict(self, *target_fields: Union[str, Term, None]) -> None: if not self._insert_table: raise QueryException("On conflict only applies to insert query") @@ -655,7 +667,7 @@ def get_sql(self, with_alias: bool = False, subquery: bool = False, **kwargs: An return querystring -class RedshiftQuery(Query): +class RedshiftQuery(BaseQuery["RedShiftQueryBuilder"]): """ Defines a query class for use with Amazon Redshift. """ @@ -669,7 +681,7 @@ class RedShiftQueryBuilder(QueryBuilder): QUERY_CLS = RedshiftQuery -class MSSQLQuery(Query): +class MSSQLQuery(BaseQuery["MSSQLQueryBuilder"]): """ Defines a query class for use with Microsoft SQL Server. """ @@ -751,7 +763,7 @@ def _select_sql(self, **kwargs: Any) -> str: ) -class ClickHouseQuery(Query): +class ClickHouseQuery(BaseQuery["ClickHouseQueryBuilder"]): """ Defines a query class for use with Yandex ClickHouse. """ @@ -767,23 +779,23 @@ def drop_database(cls, database: Union[Database, str]) -> "ClickHouseDropQueryBu return ClickHouseDropQueryBuilder().drop_database(database) @classmethod - def drop_table(self, table: Union[Table, str]) -> "ClickHouseDropQueryBuilder": + def drop_table(cls, table: Union[Table, str]) -> "ClickHouseDropQueryBuilder": return ClickHouseDropQueryBuilder().drop_table(table) @classmethod - def drop_dictionary(self, dictionary: str) -> "ClickHouseDropQueryBuilder": + def drop_dictionary(cls, dictionary: str) -> "ClickHouseDropQueryBuilder": return ClickHouseDropQueryBuilder().drop_dictionary(dictionary) @classmethod - def drop_quota(self, quota: str) -> "ClickHouseDropQueryBuilder": + def drop_quota(cls, quota: str) -> "ClickHouseDropQueryBuilder": return ClickHouseDropQueryBuilder().drop_quota(quota) @classmethod - def drop_user(self, user: str) -> "ClickHouseDropQueryBuilder": + def drop_user(cls, user: str) -> "ClickHouseDropQueryBuilder": return ClickHouseDropQueryBuilder().drop_user(user) @classmethod - def drop_view(self, view: str) -> "ClickHouseDropQueryBuilder": + def drop_view(cls, view: str) -> "ClickHouseDropQueryBuilder": return ClickHouseDropQueryBuilder().drop_view(view) @@ -799,7 +811,7 @@ def _update_sql(self, **kwargs: Any) -> str: return "ALTER TABLE {table}".format(table=self._update_table.get_sql(**kwargs)) def _from_sql(self, with_namespace: bool = False, **kwargs: Any) -> str: - def _error_none(v) -> NoReturn: + def _error_none(v) -> "NoReturn": raise TypeError("expect Selectable or QueryBuilder, got {}".format(type(v).__name__)) selectable = ",".join( @@ -858,7 +870,7 @@ def get_value_sql(self, **kwargs: Any) -> str: return super().get_value_sql(**kwargs) -class SQLLiteQuery(Query): +class SQLLiteQuery(BaseQuery["SQLLiteQueryBuilder"]): """ Defines a query class for use with Microsoft SQL Server. """ diff --git a/pypika/queries.py b/pypika/queries.py index 81d235c2..71673117 100644 --- a/pypika/queries.py +++ b/pypika/queries.py @@ -2,6 +2,7 @@ from functools import reduce from itertools import chain import operator +import builtins from typing import ( Any, Callable, @@ -16,8 +17,13 @@ Set, cast, TypeVar, + overload, + TYPE_CHECKING, ) +if TYPE_CHECKING: + from typing_extensions import Self + from pypika.enums import Dialects, JoinType, ReferenceOption, SetOperation, Order from pypika.terms import ( ArithmeticExpression, @@ -26,14 +32,15 @@ Field, Function, Index, - Node, Rollup, Star, + Node, Term, Tuple, ValueWrapper, Criterion, PeriodCriterion, + WrappedConstantValue, WrappedConstant, ) from pypika.utils import ( @@ -45,7 +52,6 @@ format_alias_sql, format_quotes, ignore_copy, - SQLPart, ) __author__ = "Timothy Heys" @@ -53,6 +59,8 @@ _T = TypeVar("_T") +SchemaT = TypeVar("SchemaT", bound="Schema") +QueryBuilderType = TypeVar("QueryBuilderType", bound="QueryBuilder", covariant=True) class Selectable(Node): @@ -79,12 +87,15 @@ def __getitem__(self, name: str) -> Field: return self.field(name) def get_table_name(self) -> str: - if not self.alias: + if self.alias is None: raise TypeError("expect str, got None") return self.alias + def get_sql(self, **kwargs) -> str: + raise NotImplementedError -class AliasedQuery(Selectable, SQLPart): + +class AliasedQuery(Selectable): def __init__(self, name: str, query: Optional[Selectable] = None) -> None: super().__init__(alias=name) self.name = name @@ -102,7 +113,7 @@ def __hash__(self) -> int: return hash(str(self.name)) -class Schema(SQLPart): +class Schema: def __init__(self, name: str, parent: Optional["Schema"] = None) -> None: self._name = name self._parent = parent @@ -132,11 +143,193 @@ def get_sql(self, quote_char: Optional[str] = None, **kwargs: Any) -> str: class Database(Schema): @ignore_copy - def __getattr__(self, item: str) -> Schema: + def __getattr__(self, item: str) -> Schema: # type: ignore return Schema(item, parent=self) -class Table(Selectable): +class BaseQuery(Generic[QueryBuilderType]): + """ + Query is the primary class and entry point in pypika. It is used to build queries iteratively using the builder + design + pattern. + + This class is the generic base class for Query. + """ + + @classmethod + def _builder(cls, **kwargs: Any) -> "QueryBuilderType": + raise NotImplementedError + + @classmethod + def from_(cls, table: Union[Selectable, str], **kwargs: Any) -> "QueryBuilderType": + """ + Query builder entry point. Initializes query building and sets the table to select from. When using this + function, the query becomes a SELECT query. + + :param table: + Type: Table or str + + An instance of a Table object or a string table name. + + :returns QueryBuilder + """ + return cls._builder(**kwargs).from_(table) + + @classmethod + def create_table(cls, table: Union[str, "Table"]) -> "CreateQueryBuilder": + """ + Query builder entry point. Initializes query building and sets the table name to be created. When using this + function, the query becomes a CREATE statement. + + :param table: An instance of a Table object or a string table name. + + :return: CreateQueryBuilder + """ + return CreateQueryBuilder().create_table(table) + + @classmethod + def drop_database(cls, database: Union[Database, str]) -> "DropQueryBuilder": + """ + Query builder entry point. Initializes query building and sets the table name to be dropped. When using this + function, the query becomes a DROP statement. + + :param database: An instance of a Database object or a string database name. + + :return: DropQueryBuilder + """ + return DropQueryBuilder().drop_database(database) + + @classmethod + def drop_table(cls, table: Union[str, "Table"]) -> "DropQueryBuilder": + """ + Query builder entry point. Initializes query building and sets the table name to be dropped. When using this + function, the query becomes a DROP statement. + + :param table: An instance of a Table object or a string table name. + + :return: DropQueryBuilder + """ + return DropQueryBuilder().drop_table(table) + + @classmethod + def drop_user(cls, user: str) -> "DropQueryBuilder": + """ + Query builder entry point. Initializes query building and sets the table name to be dropped. When using this + function, the query becomes a DROP statement. + + :param user: String user name. + + :return: DropQueryBuilder + """ + return DropQueryBuilder().drop_user(user) + + @classmethod + def drop_view(cls, view: str) -> "DropQueryBuilder": + """ + Query builder entry point. Initializes query building and sets the table name to be dropped. When using this + function, the query becomes a DROP statement. + + :param view: String view name. + + :return: DropQueryBuilder + """ + return DropQueryBuilder().drop_view(view) + + @classmethod + def into(cls, table: Union["Table", str], **kwargs: Any) -> "QueryBuilderType": + """ + Query builder entry point. Initializes query building and sets the table to insert into. When using this + function, the query becomes an INSERT query. + + :param table: + Type: Table or str + + An instance of a Table object or a string table name. + + :returns QueryBuilder + """ + return cls._builder(**kwargs).into(table) + + @classmethod + def with_(cls, table: Selectable, name: str, **kwargs: Any) -> "QueryBuilderType": + return cls._builder(**kwargs).with_(table, name) + + @classmethod + def select(cls, *terms: Union[int, float, str, bool, Term], **kwargs: Any) -> "QueryBuilderType": + """ + Query builder entry point. Initializes query building without a table and selects fields. Useful when testing + SQL functions. + + :param terms: + Type: list[expression] + + A list of terms to select. These can be any type of int, float, str, bool, or Term. They cannot be a Field + unless the function ``Query.from_`` is called first. + + :returns QueryBuilder + """ + return cls._builder(**kwargs).select(*terms) + + @classmethod + def update(cls, table: Union[str, "Table"], **kwargs) -> "QueryBuilderType": + """ + Query builder entry point. Initializes query building and sets the table to update. When using this + function, the query becomes an UPDATE query. + + :param table: + Type: Table or str + + An instance of a Table object or a string table name. + + :returns QueryBuilder + """ + return cls._builder(**kwargs).update(table) + + @classmethod + def Table(cls, table_name: str, **kwargs) -> "Table[QueryBuilderType]": + """ + Convenience method for creating a Table that uses this Query class. + + :param table_name: + Type: str + + A string table name. + + :returns Table + """ + return Table(table_name, query_cls=cls, **kwargs) + + @classmethod + def Tables(cls, *names: Union[TypedTuple[str, str], str], **kwargs: Any) -> List["Table[QueryBuilderType]"]: # type: ignore + """ + Convenience method for creating many tables that uses this Query class. + See ``Query.make_tables`` for details. + + :param names: + Type: list[str or tuple] + + A list of string table names, or name and alias tuples. + + :returns Table + """ + return make_tables(*names, query_cls=cls, **kwargs) + + +class Query(BaseQuery["QueryBuilder"]): + """ + Query is the primary class and entry point in pypika. It is used to build queries iteratively using the builder + design + pattern. + + This class is immutable. + """ + + @classmethod + def _builder(cls, **kwargs: Any) -> "QueryBuilder": + return QueryBuilder(**kwargs) + + +class Table(Selectable, Generic[QueryBuilderType]): @staticmethod def _init_schema(schema: Union[str, list, tuple, Schema, None]) -> Optional[Schema]: # This is a bit complicated in order to support backwards compatibility. It should probably be cleaned up for @@ -152,17 +345,17 @@ def _init_schema(schema: Union[str, list, tuple, Schema, None]) -> Optional[Sche def __init__( self, name: str, - schema: Optional[Union[Schema, str]] = None, + schema: Union[str, list, tuple, Schema, None] = None, alias: Optional[str] = None, - query_cls: Optional[Type["Query"]] = None, + query_cls: Type["BaseQuery[QueryBuilderType]"] = Query, # type: ignore ) -> None: super().__init__(alias) self._table_name = name self._schema = self._init_schema(schema) - self._query_cls = query_cls or Query + self._query_cls: Type["BaseQuery[QueryBuilderType]"] = query_cls self._for: Optional[Criterion] = None self._for_portion: Optional[PeriodCriterion] = None - if not issubclass(self._query_cls, Query): + if not issubclass(self._query_cls, BaseQuery): raise TypeError("Expected 'query_cls' to be subclass of Query") def get_table_name(self) -> str: @@ -230,7 +423,7 @@ def __ne__(self, other: Any) -> bool: def __hash__(self) -> int: return hash(str(self)) - def select(self, *terms: Sequence[Union[int, float, str, bool, Term, Field]]) -> "QueryBuilder": + def select(self, *terms: Union[int, float, str, bool, Term, Field]) -> "QueryBuilderType": """ Perform a SELECT operation on the current table @@ -243,7 +436,7 @@ def select(self, *terms: Sequence[Union[int, float, str, bool, Term, Field]]) -> """ return self._query_cls.from_(self).select(*terms) - def update(self) -> "QueryBuilder": + def update(self) -> "QueryBuilderType": """ Perform an UPDATE operation on the current table @@ -251,7 +444,7 @@ def update(self) -> "QueryBuilder": """ return self._query_cls.update(self) - def insert(self, *terms: Union[int, float, str, bool, Term, Field]) -> "QueryBuilder": + def insert(self, *terms: Union[int, float, str, bool, Term, Field]) -> "QueryBuilderType": """ Perform an INSERT operation on the current table @@ -265,13 +458,17 @@ def insert(self, *terms: Union[int, float, str, bool, Term, Field]) -> "QueryBui return self._query_cls.into(self).insert(*terms) -def make_tables(*names: Union[TypedTuple[str, str], str], **kwargs: Any) -> List[Table]: +def make_tables( + *names: Union[TypedTuple[str, str], str], + query_cls: "Type[BaseQuery[QueryBuilderType]]" = Query, # type: ignore + **kwargs: Any, +) -> List[Table[QueryBuilderType]]: """ Shortcut to create many tables. If `names` param is a tuple, the first position will refer to the `_table_name` while the second will be its `alias`. Any other data structure will be treated as a whole as the `_table_name`. """ - tables = [] + tables: List["Table[QueryBuilderType]"] = [] for name in names: if isinstance(name, tuple): if len(name) == 2: @@ -279,7 +476,7 @@ def make_tables(*names: Union[TypedTuple[str, str], str], **kwargs: Any) -> List name=name[0], alias=name[1], schema=kwargs.get("schema"), - query_cls=kwargs.get("query_cls"), + query_cls=query_cls, ) else: raise TypeError("expect tuple[str, str] or str, got a tuple with {} element(s)".format(len(name))) @@ -287,13 +484,13 @@ def make_tables(*names: Union[TypedTuple[str, str], str], **kwargs: Any) -> List t = Table( name=name, schema=kwargs.get("schema"), - query_cls=kwargs.get("query_cls"), + query_cls=query_cls, ) tables.append(t) return tables -class Column(SQLPart): +class Column: """Represents a column.""" def __init__( @@ -301,7 +498,7 @@ def __init__( column_name: str, column_type: Optional[str] = None, nullable: Optional[bool] = None, - default: Optional[Union[Any, Term]] = None, + default: object = None, ) -> None: self.name = column_name self.type = column_type @@ -351,7 +548,7 @@ def make_columns(*names: Union[TypedTuple[str, str], str]) -> List[Column]: return columns -class PeriodFor(SQLPart): +class PeriodFor: def __init__(self, name: str, start_column: Union[str, Column], end_column: Union[str, Column]) -> None: self.name = name self.start_column = start_column if isinstance(start_column, Column) else Column(start_column) @@ -373,177 +570,7 @@ def get_sql(self, **kwargs: Any) -> str: _TableClass = Table -class Query: - """ - Query is the primary class and entry point in pypika. It is used to build queries iteratively using the builder - design - pattern. - - This class is immutable. - """ - - @classmethod - def _builder(cls, **kwargs: Any) -> "QueryBuilder": - return QueryBuilder(**kwargs) - - @classmethod - def from_(cls, table: Union[Selectable, str], **kwargs: Any) -> "QueryBuilder": - """ - Query builder entry point. Initializes query building and sets the table to select from. When using this - function, the query becomes a SELECT query. - - :param table: - Type: Table or str - - An instance of a Table object or a string table name. - - :returns QueryBuilder - """ - return cls._builder(**kwargs).from_(table) - - @classmethod - def create_table(cls, table: Union[str, Table]) -> "CreateQueryBuilder": - """ - Query builder entry point. Initializes query building and sets the table name to be created. When using this - function, the query becomes a CREATE statement. - - :param table: An instance of a Table object or a string table name. - - :return: CreateQueryBuilder - """ - return CreateQueryBuilder().create_table(table) - - @classmethod - def drop_database(cls, database: Union[Database, str]) -> "DropQueryBuilder": - """ - Query builder entry point. Initializes query building and sets the table name to be dropped. When using this - function, the query becomes a DROP statement. - - :param database: An instance of a Database object or a string database name. - - :return: DropQueryBuilder - """ - return DropQueryBuilder().drop_database(database) - - @classmethod - def drop_table(cls, table: Union[str, Table]) -> "DropQueryBuilder": - """ - Query builder entry point. Initializes query building and sets the table name to be dropped. When using this - function, the query becomes a DROP statement. - - :param table: An instance of a Table object or a string table name. - - :return: DropQueryBuilder - """ - return DropQueryBuilder().drop_table(table) - - @classmethod - def drop_user(cls, user: str) -> "DropQueryBuilder": - """ - Query builder entry point. Initializes query building and sets the table name to be dropped. When using this - function, the query becomes a DROP statement. - - :param user: String user name. - - :return: DropQueryBuilder - """ - return DropQueryBuilder().drop_user(user) - - @classmethod - def drop_view(cls, view: str) -> "DropQueryBuilder": - """ - Query builder entry point. Initializes query building and sets the table name to be dropped. When using this - function, the query becomes a DROP statement. - - :param view: String view name. - - :return: DropQueryBuilder - """ - return DropQueryBuilder().drop_view(view) - - @classmethod - def into(cls, table: Union[Table, str], **kwargs: Any) -> "QueryBuilder": - """ - Query builder entry point. Initializes query building and sets the table to insert into. When using this - function, the query becomes an INSERT query. - - :param table: - Type: Table or str - - An instance of a Table object or a string table name. - - :returns QueryBuilder - """ - return cls._builder(**kwargs).into(table) - - @classmethod - def with_(cls, table: Union[str, Selectable], name: str, **kwargs: Any) -> "QueryBuilder": - return cls._builder(**kwargs).with_(table, name) - - @classmethod - def select(cls, *terms: Union[int, float, str, bool, Term], **kwargs: Any) -> "QueryBuilder": - """ - Query builder entry point. Initializes query building without a table and selects fields. Useful when testing - SQL functions. - - :param terms: - Type: list[expression] - - A list of terms to select. These can be any type of int, float, str, bool, or Term. They cannot be a Field - unless the function ``Query.from_`` is called first. - - :returns QueryBuilder - """ - return cls._builder(**kwargs).select(*terms) - - @classmethod - def update(cls, table: Union[str, Table], **kwargs) -> "QueryBuilder": - """ - Query builder entry point. Initializes query building and sets the table to update. When using this - function, the query becomes an UPDATE query. - - :param table: - Type: Table or str - - An instance of a Table object or a string table name. - - :returns QueryBuilder - """ - return cls._builder(**kwargs).update(table) - - @classmethod - def Table(cls, table_name: str, **kwargs) -> _TableClass: - """ - Convenience method for creating a Table that uses this Query class. - - :param table_name: - Type: str - - A string table name. - - :returns Table - """ - kwargs["query_cls"] = cls - return Table(table_name, **kwargs) - - @classmethod - def Tables(cls, *names: Union[TypedTuple[str, str], str], **kwargs: Any) -> List[_TableClass]: - """ - Convenience method for creating many tables that uses this Query class. - See ``Query.make_tables`` for details. - - :param names: - Type: list[str or tuple] - - A list of string table names, or name and alias tuples. - - :returns Table - """ - kwargs["query_cls"] = cls - return make_tables(*names, **kwargs) - - -class _SetOperation(Selectable, Term, SQLPart): +class _SetOperation(Selectable, Term): # type: ignore """ A Query class wrapper for a all set operations, Union DISTINCT or ALL, Intersect, Except or Minus @@ -562,9 +589,7 @@ def __init__( ): super().__init__(alias) self.base_query = base_query - self._set_operation: List[TypedTuple[SetOperation, Union[QueryBuilder, Selectable]]] = [ - (set_operation, set_operation_query) - ] + self._set_operation: List[TypedTuple[SetOperation, QueryBuilder]] = [(set_operation, set_operation_query)] self._orderbys: List[TypedTuple[Union[Field, WrappedConstant, None], Optional[Order]]] = [] self._limit: Optional[int] = None @@ -599,29 +624,29 @@ def offset(self, offset: int): self._offset = offset @builder - def union(self, other: Selectable): + def union(self, other: "QueryBuilder"): self._set_operation.append((SetOperation.union, other)) @builder - def union_all(self, other: Selectable): + def union_all(self, other: "QueryBuilder"): self._set_operation.append((SetOperation.union_all, other)) @builder - def intersect(self, other: Selectable): + def intersect(self, other: "QueryBuilder"): self._set_operation.append((SetOperation.intersect, other)) @builder - def except_of(self, other: Selectable): + def except_of(self, other: "QueryBuilder"): self._set_operation.append((SetOperation.except_of, other)) @builder - def minus(self, other: Selectable): + def minus(self, other: "QueryBuilder"): self._set_operation.append((SetOperation.minus, other)) - def __add__(self, other: Selectable) -> "_SetOperation": # type: ignore + def __add__(self, other: "QueryBuilder") -> "_SetOperation": # type: ignore return self.union(other) - def __mul__(self, other: Selectable) -> "_SetOperation": # type: ignore + def __mul__(self, other: "QueryBuilder") -> "_SetOperation": # type: ignore return self.union_all(other) def __sub__(self, other: "QueryBuilder") -> "_SetOperation": # type: ignore @@ -671,7 +696,7 @@ def get_sql(self, with_alias: bool = False, subquery: bool = False, **kwargs: An querystring = "({query})".format(query=querystring, **kwargs) if with_alias: - return format_alias_sql(querystring, self.alias or self._table_name, **kwargs) + return format_alias_sql(querystring, self.alias or self.get_table_name(), **kwargs) return querystring @@ -706,7 +731,7 @@ def _limit_sql(self) -> str: return " LIMIT {limit}".format(limit=self._limit) -class QueryBuilder(Selectable, Term, SQLPart): +class QueryBuilder(Selectable, Term): """ Query Builder is the main class in pypika which stores the state of a query and offers functions which allow the state to be branched immutably. @@ -716,7 +741,7 @@ class QueryBuilder(Selectable, Term, SQLPart): SECONDARY_QUOTE_CHAR: Optional[str] = "'" ALIAS_QUOTE_CHAR: Optional[str] = None QUERY_ALIAS_QUOTE_CHAR: Optional[str] = None - QUERY_CLS = Query + QUERY_CLS: Type[BaseQuery] = Query def __init__( self, @@ -750,7 +775,7 @@ def __init__( self._groupbys: List[Union[Term, WrappedConstant]] = [] self._with_totals = False self._havings: Optional[Union[Term, Criterion]] = None - self._orderbys: List[TypedTuple[Union[Field, WrappedConstant], Optional[Order]]] = [] + self._orderbys: List[TypedTuple[WrappedConstant, Optional[Order]]] = [] self._joins: List[Join] = [] self._unions: List[None] = [] self._using: List[Union[Selectable, str]] = [] @@ -776,7 +801,7 @@ def __init__( self.immutable = immutable - def __copy__(self) -> "QueryBuilder": + def __copy__(self) -> "Self": newone = type(self).__new__(type(self)) newone.__dict__.update(self.__dict__) newone._select_star_tables = copy(self._select_star_tables) @@ -838,7 +863,7 @@ def replace_table(self, current_table: Optional[Table], new_table: Optional[Tabl self._insert_table = new_table if self._insert_table == current_table else self._insert_table self._update_table = new_table if self._update_table == current_table else self._update_table - self._with = [alias_query.replace_table(current_table, new_table) for alias_query in self._with] + self._with = [alias_query.replace_table(current_table, new_table) for alias_query in self._with] # TODO: why? self._selects = [ select.replace_table(current_table, new_table) if isinstance(select, Term) else select for select in self._selects @@ -897,7 +922,7 @@ def select(self, *terms: Any): self._select_other(term) else: value = self.wrap_constant(term, wrapper_cls=self._wrapper_cls) - self._select_other(Term._assert_guard(value)) + self._select_other(value) @builder def delete(self): @@ -1048,7 +1073,7 @@ def rollup(self, *terms: Union[list, tuple, set, Term], **kwargs: Any): self._groupbys.append(Rollup(*wrapped_terms)) @builder - def orderby(self, *fields: Union[str, Field], order: Optional[Order] = None): + def orderby(self, *fields: WrappedConstantValue, order: Optional[Order] = None): table = self._from[0] if not isinstance(table, Selectable): raise TypeError("expect table is a Selectable, got {}".format(type(table).__name__)) @@ -1060,7 +1085,7 @@ def orderby(self, *fields: Union[str, Field], order: Optional[Order] = None): @builder def join( self, item: Union[Table, "QueryBuilder", AliasedQuery, _SetOperation], how: JoinType = JoinType.inner - ) -> "Joiner": + ) -> "Joiner[Self]": if isinstance(item, Table): return Joiner(self, item, how, type_label="table") @@ -1074,31 +1099,31 @@ def join( raise ValueError("Cannot join on type '%s'" % type(item)) - def inner_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner": + def inner_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner[Self]": return self.join(item, JoinType.inner) - def left_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner": + def left_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner[Self]": return self.join(item, JoinType.left) - def left_outer_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner": + def left_outer_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner[Self]": return self.join(item, JoinType.left_outer) - def right_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner": + def right_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner[Self]": return self.join(item, JoinType.right) - def right_outer_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner": + def right_outer_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner[Self]": return self.join(item, JoinType.right_outer) - def outer_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner": + def outer_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner[Self]": return self.join(item, JoinType.outer) - def full_outer_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner": + def full_outer_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner[Self]": return self.join(item, JoinType.full_outer) - def cross_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner": + def cross_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner[Self]": return self.join(item, JoinType.cross) - def hash_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner": + def hash_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner[Self]": return self.join(item, JoinType.hash) @builder @@ -1148,7 +1173,15 @@ def slice(self, slice: slice): self._offset = slice.start self._limit = slice.stop - def __getitem__(self, item: Any) -> Union["QueryBuilder", Field]: # type: ignore + @overload # type: ignore[override] + def __getitem__(self, item: str) -> Field: + ... + + @overload + def __getitem__(self, item: builtins.slice) -> "Self": + ... + + def __getitem__(self, item: Union[str, builtins.slice]) -> Union["Self", Field]: if not isinstance(item, slice): return super().__getitem__(item) return self.slice(item) @@ -1528,7 +1561,7 @@ def _from_sql(self, with_namespace: bool = False, **kwargs: Any) -> str: def _using_sql(self, with_namespace: bool = False, **kwargs: Any) -> str: return " USING {selectable}".format( selectable=",".join( - clause.get_sql(subquery=True, with_alias=True, **kwargs) if isinstance(clause, SQLPart) else clause + clause.get_sql(subquery=True, with_alias=True, **kwargs) if isinstance(clause, Selectable) else clause for clause in self._using ) ) @@ -1644,14 +1677,14 @@ def _set_sql(self, **kwargs: Any) -> str: JoinableTerm = Union[Table, "QueryBuilder", AliasedQuery, _SetOperation] -class Joiner: - def __init__(self, query: QueryBuilder, item: JoinableTerm, how: JoinType, type_label: str) -> None: +class Joiner(Generic[QueryBuilderType]): + def __init__(self, query: "QueryBuilderType", item: JoinableTerm, how: JoinType, type_label: str) -> None: self.query = query self.item = item self.how = how self.type_label = type_label - def on(self, criterion: Optional[Criterion], collate: Optional[str] = None) -> QueryBuilder: + def on(self, criterion: Optional[Criterion], collate: Optional[str] = None) -> "QueryBuilderType": if criterion is None: raise JoinException( "Parameter 'criterion' is required for a " @@ -1661,7 +1694,7 @@ def on(self, criterion: Optional[Criterion], collate: Optional[str] = None) -> Q self.query.do_join(JoinOn(self.item, self.how, criterion, collate)) return self.query - def on_field(self, *fields: Any) -> QueryBuilder: + def on_field(self, *fields: Any) -> "QueryBuilderType": if not fields: raise JoinException( "Parameter 'fields' is required for a " "{type} JOIN but was not supplied.".format(type=self.type_label) @@ -1675,21 +1708,21 @@ def on_field(self, *fields: Any) -> QueryBuilder: self.query.do_join(JoinOn(self.item, self.how, cast(Criterion, criterion))) return self.query - def using(self, *fields: Any) -> QueryBuilder: + def using(self, *fields: Any) -> "QueryBuilderType": if not fields: raise JoinException("Parameter 'fields' is required when joining with a using clause but was not supplied.") self.query.do_join(JoinUsing(self.item, self.how, [Field(field) for field in fields])) return self.query - def cross(self) -> QueryBuilder: + def cross(self) -> "QueryBuilderType": """Return cross join""" self.query.do_join(Join(self.item, JoinType.cross)) return self.query -class Join(SQLPart): +class Join: def __init__(self, item: JoinableTerm, how: JoinType) -> None: self.item = item self.how = how @@ -1719,7 +1752,7 @@ def replace_table(self, current_table: Optional[Table], new_table: Optional[Tabl :return: A copy of the join with the tables replaced. """ - self.item = self.item.replace_table(current_table, new_table) + self.item = self.item.replace_table(current_table, new_table) # TODO: why? class JoinOn(Join): @@ -1803,7 +1836,7 @@ def replace_table(self, current_table: Optional[Table], new_table: Optional[Tabl raise ValueError("new_table should not be None for {}".format(type(self).__name__)) -class CreateQueryBuilder(SQLPart): +class CreateQueryBuilder: """ Query builder used to build CREATE queries. """ @@ -1811,7 +1844,7 @@ class CreateQueryBuilder(SQLPart): QUOTE_CHAR: Optional[str] = '"' SECONDARY_QUOTE_CHAR: Optional[str] = "'" ALIAS_QUOTE_CHAR: Optional[str] = None - QUERY_CLS = Query + QUERY_CLS: Type[BaseQuery] = Query def __init__(self, dialect: Optional[Dialects] = None) -> None: self._create_table: Optional[Table] = None @@ -1971,8 +2004,8 @@ def foreign_key( columns: List[Union[str, Column]], reference_table: Union[str, Table], reference_columns: List[Union[str, Column]], - on_delete: ReferenceOption = None, - on_update: ReferenceOption = None, + on_delete: Optional[ReferenceOption] = None, + on_update: Optional[ReferenceOption] = None, ): """ Adds a foreign key constraint. @@ -2112,6 +2145,7 @@ def _primary_key_clause(self, **kwargs) -> str: ) def _foreign_key_clause(self, **kwargs) -> str: + assert self._foreign_key_reference_table is not None clause = "FOREIGN KEY ({columns}) REFERENCES {table_name} ({reference_columns})".format( columns=",".join(column.get_name_sql(**kwargs) for column in self._foreign_key), # type: ignore table_name=( @@ -2155,7 +2189,7 @@ def __repr__(self) -> str: return self.__str__() -class DropQueryBuilder(SQLPart): +class DropQueryBuilder: """ Query builder used to build DROP queries. """ @@ -2163,7 +2197,7 @@ class DropQueryBuilder(SQLPart): QUOTE_CHAR: Optional[str] = '"' SECONDARY_QUOTE_CHAR: Optional[str] = "'" ALIAS_QUOTE_CHAR: Optional[str] = None - QUERY_CLS = Query + QUERY_CLS: Type[BaseQuery] = Query def __init__(self, dialect: Optional[Dialects] = None) -> None: self._drop_target_kind: Optional[str] = None diff --git a/pypika/terms.py b/pypika/terms.py index a831c216..e062c5cd 100644 --- a/pypika/terms.py +++ b/pypika/terms.py @@ -11,15 +11,19 @@ Iterable, Iterator, List, - MutableSequence, Optional, + MutableSequence, Sequence, Set, Type, TypeVar, Union, + overload, ) +if TYPE_CHECKING: + from typing_extensions import Self + from pypika.enums import Arithmetic, Boolean, Comparator, Dialects, Equality, JSONOperators, Matching, Order from pypika.utils import ( CaseException, @@ -29,12 +33,10 @@ format_quotes, ignore_copy, resolve_is_aggregate, - SQLPart, ) if TYPE_CHECKING: - from pypika.queries import QueryBuilder, Selectable, Table - from _typeshed import Self + from pypika.queries import QueryBuilder, Selectable, Table, Schema __author__ = "Timothy Heys" @@ -42,6 +44,10 @@ NodeT = TypeVar("NodeT", bound="Node") +TermT = TypeVar("TermT", bound="Term") +ValueWrapperT = TypeVar("ValueWrapperT", bound="ValueWrapper") +IntervalT = TypeVar("IntervalT", bound="Interval") +CriterionT = TypeVar("CriterionT", bound="Criterion") class Node: @@ -58,11 +64,14 @@ def find_(self, type: Type[NodeT]) -> List[NodeT]: WrappedConstantStrict = Union["LiteralValue", "Array", "Tuple", "ValueWrapper"] - -WrappedConstant = Union[Node, WrappedConstantStrict] +WrappedConstantValue = Union["Term", int, float, bool, str, date, None] +WrappedConstantValueUnion = Union[ + WrappedConstantValue, List[WrappedConstantValue], typing.Tuple[WrappedConstantValue, ...] +] +WrappedConstant = Union["Term", WrappedConstantStrict] -class Term(Node, SQLPart): +class Term(Node): def __init__(self, alias: Optional[str] = None) -> None: self.alias = alias @@ -83,8 +92,41 @@ def tables_(self) -> Set["Table"]: def fields_(self) -> Set["Field"]: return set(self.find_(Field)) + @overload + @staticmethod + def wrap_constant(val: TermT) -> TermT: + ... + + @overload + @staticmethod + def wrap_constant(val: None) -> "NullValue": + ... + + @overload + @staticmethod + def wrap_constant(val: List[WrappedConstantValue]) -> "Array": + ... + + @overload + @staticmethod + def wrap_constant(val: typing.Tuple[WrappedConstantValue, ...]) -> "Tuple": + ... + + @overload + @staticmethod + def wrap_constant(val: Union[int, float, bool, str, date], wrapper_cls: Type["ValueWrapperT"]) -> "ValueWrapperT": + ... + + @overload + @staticmethod + def wrap_constant(val: Union[int, float, bool, str, date], wrapper_cls: None = None) -> "ValueWrapper": + ... + @staticmethod - def wrap_constant(val, wrapper_cls: Optional[Type["Term"]] = None) -> WrappedConstant: + def wrap_constant( + val: WrappedConstantValueUnion, + wrapper_cls: Optional[Type["ValueWrapper"]] = None, + ) -> WrappedConstant: """ Used for wrapping raw inputs such as numbers in Criterions and Operator. @@ -113,9 +155,34 @@ def wrap_constant(val, wrapper_cls: Optional[Type["Term"]] = None) -> WrappedCon wrapper_cls = wrapper_cls or ValueWrapper return wrapper_cls(val) + @overload + @staticmethod + def wrap_json(val: TermT, wrapper_cls: Optional[Type["ValueWrapper"]] = None) -> TermT: # type: ignore[misc] + ... + + @overload + @staticmethod + def wrap_json(val: None, wrapper_cls: Optional[Type["ValueWrapper"]] = None) -> "NullValue": # type: ignore[misc] + ... + + @overload + @staticmethod + def wrap_json(val: Union[str, int, bool], wrapper_cls: Type["ValueWrapperT"]) -> "ValueWrapperT": # type: ignore[misc] + ... + + @overload + @staticmethod + def wrap_json(val: Union[str, int, bool], wrapper_cls: None = None) -> "ValueWrapper": # type: ignore[misc] + ... + + @overload + @staticmethod + def wrap_json(val: object, wrapper_cls: Optional[Type["ValueWrapper"]] = None) -> "JSON": + ... + @staticmethod def wrap_json( - val: Union["Term", "QueryBuilder", "Interval", None, str, int, bool], wrapper_cls=None + val: object, wrapper_cls=None ) -> Union["Term", "QueryBuilder", "Interval", "NullValue", "ValueWrapper", "JSON"]: from .queries import QueryBuilder @@ -178,28 +245,28 @@ def ne(self, other: Any) -> "BasicCriterion": return self != other def glob(self, expr: str) -> "BasicCriterion": - return BasicCriterion(Matching.glob, self, Term._assert_guard(self.wrap_constant(expr))) + return BasicCriterion(Matching.glob, self, self.wrap_constant(expr)) def like(self, expr: str) -> "BasicCriterion": - return BasicCriterion(Matching.like, self, Term._assert_guard(self.wrap_constant(expr))) + return BasicCriterion(Matching.like, self, self.wrap_constant(expr)) def not_like(self, expr: str) -> "BasicCriterion": - return BasicCriterion(Matching.not_like, self, Term._assert_guard(self.wrap_constant(expr))) + return BasicCriterion(Matching.not_like, self, self.wrap_constant(expr)) def ilike(self, expr: str) -> "BasicCriterion": - return BasicCriterion(Matching.ilike, self, Term._assert_guard(self.wrap_constant(expr))) + return BasicCriterion(Matching.ilike, self, self.wrap_constant(expr)) def not_ilike(self, expr: str) -> "BasicCriterion": - return BasicCriterion(Matching.not_ilike, self, Term._assert_guard(self.wrap_constant(expr))) + return BasicCriterion(Matching.not_ilike, self, self.wrap_constant(expr)) def rlike(self, expr: str) -> "BasicCriterion": - return BasicCriterion(Matching.rlike, self, Term._assert_guard(self.wrap_constant(expr))) + return BasicCriterion(Matching.rlike, self, self.wrap_constant(expr)) def regex(self, pattern: str) -> "BasicCriterion": - return BasicCriterion(Matching.regex, self, Term._assert_guard(self.wrap_constant(pattern))) + return BasicCriterion(Matching.regex, self, self.wrap_constant(pattern)) def regexp(self, pattern: str) -> "BasicCriterion": - return BasicCriterion(Matching.regexp, self, Term._assert_guard(self.wrap_constant(pattern))) + return BasicCriterion(Matching.regexp, self, self.wrap_constant(pattern)) def between(self, lower: Any, upper: Any) -> "BetweenCriterion": return BetweenCriterion(self, self.wrap_constant(lower), self.wrap_constant(upper)) @@ -208,7 +275,7 @@ def from_to(self, start: Any, end: Any) -> "PeriodCriterion": return PeriodCriterion(self, self.wrap_constant(start), self.wrap_constant(end)) def as_of(self, expr: str) -> "BasicCriterion": - return BasicCriterion(Matching.as_of, self, Term._assert_guard(self.wrap_constant(expr))) + return BasicCriterion(Matching.as_of, self, self.wrap_constant(expr)) def all_(self) -> "All": return All(self) @@ -222,7 +289,7 @@ def notin(self, arg: Union[list, tuple, set, "Term"]) -> "ContainsCriterion": return self.isin(arg).negate() def bin_regex(self, pattern: str) -> "BasicCriterion": - return BasicCriterion(Matching.bin_regex, self, Term._assert_guard(self.wrap_constant(pattern))) + return BasicCriterion(Matching.bin_regex, self, self.wrap_constant(pattern)) def negate(self) -> "Not": return Not(self) @@ -285,22 +352,22 @@ def __rrshift__(self, other: Any) -> "ArithmeticExpression": return ArithmeticExpression(Arithmetic.rshift, self.wrap_constant(other), self) def __eq__(self, other: Any) -> "BasicCriterion": # type: ignore - return BasicCriterion(Equality.eq, self, Term._assert_guard(self.wrap_constant(other))) + return BasicCriterion(Equality.eq, self, self.wrap_constant(other)) def __ne__(self, other: Any) -> "BasicCriterion": # type: ignore - return BasicCriterion(Equality.ne, self, Term._assert_guard(self.wrap_constant(other))) + return BasicCriterion(Equality.ne, self, self.wrap_constant(other)) def __gt__(self, other: Any) -> "BasicCriterion": - return BasicCriterion(Equality.gt, self, Term._assert_guard(self.wrap_constant(other))) + return BasicCriterion(Equality.gt, self, self.wrap_constant(other)) def __ge__(self, other: Any) -> "BasicCriterion": - return BasicCriterion(Equality.gte, self, Term._assert_guard(self.wrap_constant(other))) + return BasicCriterion(Equality.gte, self, self.wrap_constant(other)) def __lt__(self, other: Any) -> "BasicCriterion": - return BasicCriterion(Equality.lt, self, Term._assert_guard(self.wrap_constant(other))) + return BasicCriterion(Equality.lt, self, self.wrap_constant(other)) def __le__(self, other: Any) -> "BasicCriterion": - return BasicCriterion(Equality.lte, self, Term._assert_guard(self.wrap_constant(other))) + return BasicCriterion(Equality.lte, self, self.wrap_constant(other)) def __getitem__(self, item: slice) -> "BetweenCriterion": if not isinstance(item, slice): @@ -316,13 +383,6 @@ def __hash__(self) -> int: def get_sql(self, **kwargs: Any) -> str: raise NotImplementedError() - @classmethod - def _assert_guard(cls, v: Any) -> "Term": - if isinstance(v, cls): - return v - else: - raise TypeError("expect Term object, got {}".format(type(v).__name__)) - class Parameter(Term): def __init__(self, placeholder: Union[str, int]) -> None: @@ -404,7 +464,7 @@ def get_value_sql(self, **kwargs: Any) -> str: return self.get_formatted_value(self.value, **kwargs) @classmethod - def get_formatted_value(cls, value: Any, **kwargs): + def get_formatted_value(cls, value: Any, **kwargs) -> str: quote_char = kwargs.get("secondary_quote_char") or "" # FIXME escape values @@ -431,7 +491,7 @@ def get_sql(self, quote_char: Optional[str] = None, secondary_quote_char: str = class JSON(Term): - def __init__(self, value: Any = None, alias: Optional[str] = None) -> None: + def __init__(self, value: object = None, alias: Optional[str] = None) -> None: super().__init__(alias) self.value = value self.table: Optional[Union[str, "Selectable"]] = None @@ -468,10 +528,10 @@ def get_sql(self, secondary_quote_char: str = "'", **kwargs: Any) -> str: return format_alias_sql(sql, self.alias, **kwargs) def get_json_value(self, key_or_index: Union[str, int]) -> "BasicCriterion": - return BasicCriterion(JSONOperators.GET_JSON_VALUE, self, Term._assert_guard(self.wrap_constant(key_or_index))) + return BasicCriterion(JSONOperators.GET_JSON_VALUE, self, self.wrap_constant(key_or_index)) def get_text_value(self, key_or_index: Union[str, int]) -> "BasicCriterion": - return BasicCriterion(JSONOperators.GET_TEXT_VALUE, self, Term._assert_guard(self.wrap_constant(key_or_index))) + return BasicCriterion(JSONOperators.GET_TEXT_VALUE, self, self.wrap_constant(key_or_index)) def get_path_json_value(self, path_json: str) -> "BasicCriterion": return BasicCriterion(JSONOperators.GET_PATH_JSON_VALUE, self, self.wrap_json(path_json)) @@ -524,18 +584,18 @@ def __init__(self, alias: Optional[str] = None) -> None: class Criterion(Term): - def __and__(self, other: Any) -> "ComplexCriterion": + def __and__(self, other: "Criterion") -> "Criterion": return ComplexCriterion(Boolean.and_, self, other) - def __or__(self, other: Any) -> "ComplexCriterion": + def __or__(self, other: "Criterion") -> "Criterion": return ComplexCriterion(Boolean.or_, self, other) - def __xor__(self, other: Any) -> "ComplexCriterion": + def __xor__(self, other: "Criterion") -> "Criterion": return ComplexCriterion(Boolean.xor_, self, other) @staticmethod - def any(terms: Iterable[Term] = ()) -> "EmptyCriterion": - crit = EmptyCriterion() + def any(terms: Iterable["Criterion"] = ()) -> "Criterion": + crit: Criterion = EmptyCriterion() for term in terms: crit |= term @@ -543,8 +603,8 @@ def any(terms: Iterable[Term] = ()) -> "EmptyCriterion": return crit @staticmethod - def all(terms: Iterable[Any] = ()) -> "EmptyCriterion": - crit = EmptyCriterion() + def all(terms: Iterable["Criterion"] = ()) -> "Criterion": + crit: Criterion = EmptyCriterion() for term in terms: crit &= term @@ -559,13 +619,13 @@ class EmptyCriterion(Criterion): def fields_(self) -> Set["Field"]: return set() - def __and__(self, other: Any) -> Any: + def __and__(self, other: CriterionT) -> CriterionT: return other - def __or__(self, other: Any) -> Any: + def __or__(self, other: CriterionT) -> CriterionT: return other - def __xor__(self, other: Any) -> Any: + def __xor__(self, other: CriterionT) -> CriterionT: return other @property @@ -656,9 +716,9 @@ def get_sql( # type: ignore class Tuple(Criterion): - def __init__(self, *values: Any) -> None: + def __init__(self, *values: WrappedConstantValueUnion) -> None: super().__init__() - self.values = [self.wrap_constant(value) for value in values] + self.values: List[Term] = [self.wrap_constant(value) for value in values] def nodes_(self) -> Iterator[Node]: yield self @@ -666,7 +726,7 @@ def nodes_(self) -> Iterator[Node]: yield from value.nodes_() def get_sql(self, **kwargs: Any) -> str: - sql = "({})".format(",".join(Term._assert_guard(term).get_sql(**kwargs) for term in self.values)) + sql = "({})".format(",".join(term.get_sql(**kwargs) for term in self.values)) return format_alias_sql(sql, self.alias, **kwargs) @property @@ -685,13 +745,13 @@ def replace_table(self, current_table: Optional["Table"], new_table: Optional["T :return: A copy of the field with the tables replaced. """ - self.values = [Term._assert_guard(value).replace_table(current_table, new_table) for value in self.values] + self.values = [value.replace_table(current_table, new_table) for value in self.values] class Array(Tuple): def get_sql(self, **kwargs: Any) -> str: dialect = kwargs.get("dialect", None) - values = ",".join(Term._assert_guard(term).get_sql(**kwargs) for term in self.values) + values = ",".join(term.get_sql(**kwargs) for term in self.values) sql = "[{}]".format(values) if dialect in (Dialects.POSTGRESQL, Dialects.REDSHIFT): @@ -1037,7 +1097,7 @@ class ArithmeticExpression(Term): add_order = [Arithmetic.add, Arithmetic.sub] - def __init__(self, operator: Arithmetic, left: Any, right: Any, alias: Optional[str] = None) -> None: + def __init__(self, operator: Arithmetic, left: Term, right: Term, alias: Optional[str] = None) -> None: """ Wrapper for an arithmetic expression. @@ -1148,8 +1208,8 @@ def get_sql(self, with_alias: bool = False, **kwargs: Any) -> str: class Case(Criterion): def __init__(self, alias: Optional[str] = None) -> None: super().__init__(alias=alias) - self._cases: List[typing.Tuple[Any, Any]] = [] - self._else: WrappedConstant | None = None + self._cases: List[typing.Tuple["Criterion", "Term"]] = [] + self._else: Optional[WrappedConstant] = None def nodes_(self) -> Iterator[Node]: yield self @@ -1170,7 +1230,7 @@ def is_aggregate(self) -> Optional[bool]: ) @builder - def when(self, criterion: Any, term: Any): + def when(self, criterion: Criterion, term: WrappedConstantValue): self._cases.append((criterion, self.wrap_constant(term))) @builder @@ -1192,10 +1252,10 @@ def replace_table(self, current_table: Optional["Table"], new_table: Optional["T ) for criterion, term in self._cases ] - self._else = Term._assert_guard(self._else).replace_table(current_table, new_table) if self._else else None + self._else = self._else.replace_table(current_table, new_table) if self._else else None @builder - def else_(self, term: Any) -> "Case": + def else_(self, term: WrappedConstantValue) -> "Self": self._else = self.wrap_constant(term) return self @@ -1207,7 +1267,7 @@ def get_sql(self, with_alias: bool = False, **kwargs: Any) -> str: "WHEN {when} THEN {then}".format(when=criterion.get_sql(**kwargs), then=term.get_sql(**kwargs)) for criterion, term in self._cases ) - else_ = " ELSE {}".format(Term._assert_guard(self._else).get_sql(**kwargs)) if self._else else "" + else_ = " ELSE {}".format(self._else.get_sql(**kwargs)) if self._else else "" case_sql = "CASE {cases}{else_} END".format(cases=cases, else_=else_) @@ -1218,7 +1278,7 @@ def get_sql(self, with_alias: bool = False, **kwargs: Any) -> str: class Not(Criterion): - def __init__(self, term: Any, alias: Optional[str] = None) -> None: + def __init__(self, term: Term, alias: Optional[str] = None) -> None: super().__init__(alias=alias) self.term = term @@ -1266,7 +1326,7 @@ def replace_table(self, current_table: Optional["Table"], new_table: Optional["T class All(Criterion): - def __init__(self, term: Any, alias: Optional[str] = None) -> None: + def __init__(self, term: Term, alias: Optional[str] = None) -> None: super().__init__(alias=alias) self.term = term @@ -1285,7 +1345,7 @@ def __init__(self, name: str, params: Optional[Sequence] = None) -> None: self.params = params def __call__(self, *args: Any, **kwargs: Any) -> "Function": - if not self._has_params(): + if self.params is None: return Function(self.name, alias=kwargs.get("alias")) if not self._is_valid_function_call(*args): @@ -1303,15 +1363,16 @@ def _has_params(self): return self.params is not None def _is_valid_function_call(self, *args): + assert self.params is not None return len(args) == len(self.params) class Function(Criterion): - def __init__(self, name: str, *args: Any, **kwargs: Any) -> None: + def __init__(self, name: str, *args: WrappedConstantValueUnion, **kwargs: Any) -> None: super().__init__(kwargs.get("alias")) self.name = name - self.args: MutableSequence[WrappedConstant] = [self.wrap_constant(param) for param in args] - self.schema = kwargs.get("schema") + self.args: MutableSequence[Term] = [self.wrap_constant(param) for param in args] + self.schema: Optional["Schema"] = kwargs.get("schema") def nodes_(self) -> Iterator[Node]: yield self @@ -1340,7 +1401,7 @@ def replace_table(self, current_table: Optional["Table"], new_table: Optional["T :return: A copy of the criterion with the tables replaced. """ - self.args = [Term._assert_guard(param).replace_table(current_table, new_table) for param in self.args] + self.args = [param.replace_table(current_table, new_table) for param in self.args] def get_special_params_sql(self, **kwargs: Any) -> Any: pass @@ -1355,7 +1416,7 @@ def get_function_sql(self, **kwargs: Any) -> str: return "{name}({args}{special})".format( name=self.name, args=",".join( - Term._assert_guard(p).get_sql(with_alias=False, subquery=True, **kwargs) + p.get_sql(with_alias=False, subquery=True, **kwargs) if hasattr(p, "get_sql") else self.get_arg_sql(p, **kwargs) for p in self.args @@ -1385,7 +1446,7 @@ def get_sql(self, **kwargs: Any) -> str: class AggregateFunction(Function): - is_aggregate = True + is_aggregate: Optional[bool] = True def __init__(self, name, *args, **kwargs): super(AggregateFunction, self).__init__(name, *args, **kwargs) @@ -1414,7 +1475,7 @@ def get_function_sql(self, **kwargs: Any): class AnalyticFunction(AggregateFunction): - is_aggregate = False + is_aggregate: Optional[bool] = False is_analytic = True def __init__(self, name: str, *args: Any, **kwargs: Any) -> None: @@ -1480,7 +1541,7 @@ def get_function_sql(self, **kwargs: Any) -> str: class WindowFrameAnalyticFunction(AnalyticFunction): class Edge: - modifier: ClassVar[Optional[str]] = None + modifier: ClassVar[str] = "" def __init__(self, value: Optional[Union[str, int]] = None) -> None: self.value = value @@ -1488,7 +1549,7 @@ def __init__(self, value: Optional[Union[str, int]] = None) -> None: def __str__(self) -> str: return "{value} {modifier}".format( value=self.value or "UNBOUNDED", - modifier=self.modifier or "", + modifier=self.modifier, ) def __init__(self, name: str, *args: Any, **kwargs: Any) -> None: @@ -1682,7 +1743,7 @@ def get_sql(self, **kwargs: Any) -> str: return self.name -class AtTimezone(Term, SQLPart): +class AtTimezone(Term): """ Generates AT TIME ZONE SQL. Examples: @@ -1690,7 +1751,7 @@ class AtTimezone(Term, SQLPart): AT TIME ZONE INTERVAL '-06:00' """ - is_aggregate = None + is_aggregate: Optional[bool] = None def __init__(self, field, zone, interval=False, alias=None): super().__init__(alias) diff --git a/pypika/utils.py b/pypika/utils.py index 07e63e8d..4fd2da8a 100644 --- a/pypika/utils.py +++ b/pypika/utils.py @@ -1,13 +1,7 @@ -from typing import Any, Callable, List, Optional, Protocol, Type, TYPE_CHECKING, runtime_checkable +from typing import Any, Callable, List, Optional, Type, Union, overload, TypeVar, TYPE_CHECKING if TYPE_CHECKING: - import sys - from typing import overload, TypeVar - - if sys.version_info >= (3, 10): - from typing import ParamSpec, Concatenate - else: - from typing_extensions import ParamSpec, Concatenate + from typing_extensions import ParamSpec, Concatenate __author__ = "Timothy Heys" __email__ = "theys@kayak.com" @@ -50,18 +44,18 @@ class FunctionException(Exception): _S = TypeVar('_S') _P = ParamSpec('_P') -if TYPE_CHECKING: - @overload - def builder(func: Callable[Concatenate[_S, _P], None]) -> Callable[Concatenate[_S, _P], _S]: - ... +@overload +def builder(func: "Callable[Concatenate[_S, _P], Union[_S, None]]") -> "Callable[Concatenate[_S, _P], _S]": + ... + - @overload - def builder(func: Callable[Concatenate[_S, _P], _T]) -> Callable[Concatenate[_S, _P], _T]: - ... +@overload +def builder(func: "Callable[Concatenate[_S, _P], _T]") -> "Callable[Concatenate[_S, _P], _T]": + ... -def builder(func): +def builder(func: "Callable[Concatenate[_S, _P], Union[_T, None]]") -> "Callable[Concatenate[_S, _P], Union[_T, _S]]": """ Decorator for wrapper "builder" functions. These are functions on the Query class or other classes used for building queries which mutate the query and return self. To make the build functions immutable, this decorator is @@ -70,7 +64,7 @@ def builder(func): """ import copy - def _copy(self, *args, **kwargs): + def _copy(self: "_S", *args: "_P.args", **kwargs: "_P.kwargs"): self_copy = copy.copy(self) if getattr(self, "immutable", True) else self result = func(self_copy, *args, **kwargs) @@ -84,7 +78,7 @@ def _copy(self, *args, **kwargs): return _copy -def ignore_copy(func: Callable) -> Callable: +def ignore_copy(func: "Callable[[_S, str], _T]") -> "Callable[[_S, str], _T]": """ Decorator for wrapping the __getattr__ function for classes that are copied via deepcopy. This prevents infinite recursion caused by deepcopy looking for magic functions in the class. Any class implementing __getattr__ that is @@ -143,16 +137,9 @@ def format_alias_sql( ) -def validate(*args: Any, exc: Exception, type: Optional[Type] = None) -> None: +def validate(*args: Any, exc: Optional[Exception], type: Optional[Type] = None) -> None: if type is not None: + assert exc is not None for arg in args: if not isinstance(arg, type): raise exc - - -@runtime_checkable -class SQLPart(Protocol): - """This protocol indicates the class can generate a part of SQL""" - - def get_sql(self, **kwargs) -> str: - ...