From 42ecf87594dc1777d4a0f4c03422d05b6a75978c Mon Sep 17 00:00:00 2001 From: fish Date: Sat, 25 Feb 2023 22:48:50 +0800 Subject: [PATCH 01/23] feat: fix type hint fix python3.6 support fix builder overload add ignore_copy type hint fix validate type hint --- pypika/utils.py | 42 +++++++++++++++++------------------------- 1 file changed, 17 insertions(+), 25 deletions(-) diff --git a/pypika/utils.py b/pypika/utils.py index 07e63e8d..35a58e9e 100644 --- a/pypika/utils.py +++ b/pypika/utils.py @@ -1,13 +1,5 @@ -from typing import Any, Callable, List, Optional, Protocol, Type, TYPE_CHECKING, runtime_checkable - -if TYPE_CHECKING: - import sys - from typing import overload, TypeVar - - if sys.version_info >= (3, 10): - from typing import ParamSpec, Concatenate - else: - from typing_extensions import ParamSpec, Concatenate +from typing import Any, Callable, List, Optional, Type, Union, overload, TypeVar +from typing_extensions import ParamSpec, Concatenate, Protocol, runtime_checkable __author__ = "Timothy Heys" __email__ = "theys@kayak.com" @@ -45,23 +37,22 @@ class FunctionException(Exception): pass -if TYPE_CHECKING: - _T = TypeVar('_T') - _S = TypeVar('_S') - _P = ParamSpec('_P') +_T = TypeVar('_T') +_S = TypeVar('_S') +_P = ParamSpec('_P') -if TYPE_CHECKING: - @overload - def builder(func: Callable[Concatenate[_S, _P], None]) -> Callable[Concatenate[_S, _P], _S]: - ... +@overload +def builder(func: Callable[Concatenate[_S, _P], Union[_S, None]]) -> Callable[Concatenate[_S, _P], _S]: + ... - @overload - def builder(func: Callable[Concatenate[_S, _P], _T]) -> Callable[Concatenate[_S, _P], _T]: - ... + +@overload +def builder(func: Callable[Concatenate[_S, _P], _T]) -> Callable[Concatenate[_S, _P], _T]: + ... -def builder(func): +def builder(func: Callable[Concatenate[_S, _P], Union[_T, None]]) -> Callable[Concatenate[_S, _P], Union[_T, _S]]: """ Decorator for wrapper "builder" functions. These are functions on the Query class or other classes used for building queries which mutate the query and return self. To make the build functions immutable, this decorator is @@ -70,7 +61,7 @@ def builder(func): """ import copy - def _copy(self, *args, **kwargs): + def _copy(self: _S, *args: _P.args, **kwargs: _P.kwargs): self_copy = copy.copy(self) if getattr(self, "immutable", True) else self result = func(self_copy, *args, **kwargs) @@ -84,7 +75,7 @@ def _copy(self, *args, **kwargs): return _copy -def ignore_copy(func: Callable) -> Callable: +def ignore_copy(func: Callable[[_S, str], _T]) -> Callable[[_S, str], _T]: """ Decorator for wrapping the __getattr__ function for classes that are copied via deepcopy. This prevents infinite recursion caused by deepcopy looking for magic functions in the class. Any class implementing __getattr__ that is @@ -143,8 +134,9 @@ def format_alias_sql( ) -def validate(*args: Any, exc: Exception, type: Optional[Type] = None) -> None: +def validate(*args: Any, exc: Optional[Exception], type: Optional[Type] = None) -> None: if type is not None: + assert exc is not None for arg in args: if not isinstance(arg, type): raise exc From c1feea7c956e2288257d28219c1fa6fa1924156a Mon Sep 17 00:00:00 2001 From: fish Date: Sat, 25 Feb 2023 22:48:50 +0800 Subject: [PATCH 02/23] feat: fix type hint fix python3.6 support fix builder overload add ignore_copy type hint fix validate type hint --- pypika/utils.py | 42 +++++++++++++++++------------------------- 1 file changed, 17 insertions(+), 25 deletions(-) diff --git a/pypika/utils.py b/pypika/utils.py index 07e63e8d..35a58e9e 100644 --- a/pypika/utils.py +++ b/pypika/utils.py @@ -1,13 +1,5 @@ -from typing import Any, Callable, List, Optional, Protocol, Type, TYPE_CHECKING, runtime_checkable - -if TYPE_CHECKING: - import sys - from typing import overload, TypeVar - - if sys.version_info >= (3, 10): - from typing import ParamSpec, Concatenate - else: - from typing_extensions import ParamSpec, Concatenate +from typing import Any, Callable, List, Optional, Type, Union, overload, TypeVar +from typing_extensions import ParamSpec, Concatenate, Protocol, runtime_checkable __author__ = "Timothy Heys" __email__ = "theys@kayak.com" @@ -45,23 +37,22 @@ class FunctionException(Exception): pass -if TYPE_CHECKING: - _T = TypeVar('_T') - _S = TypeVar('_S') - _P = ParamSpec('_P') +_T = TypeVar('_T') +_S = TypeVar('_S') +_P = ParamSpec('_P') -if TYPE_CHECKING: - @overload - def builder(func: Callable[Concatenate[_S, _P], None]) -> Callable[Concatenate[_S, _P], _S]: - ... +@overload +def builder(func: Callable[Concatenate[_S, _P], Union[_S, None]]) -> Callable[Concatenate[_S, _P], _S]: + ... - @overload - def builder(func: Callable[Concatenate[_S, _P], _T]) -> Callable[Concatenate[_S, _P], _T]: - ... + +@overload +def builder(func: Callable[Concatenate[_S, _P], _T]) -> Callable[Concatenate[_S, _P], _T]: + ... -def builder(func): +def builder(func: Callable[Concatenate[_S, _P], Union[_T, None]]) -> Callable[Concatenate[_S, _P], Union[_T, _S]]: """ Decorator for wrapper "builder" functions. These are functions on the Query class or other classes used for building queries which mutate the query and return self. To make the build functions immutable, this decorator is @@ -70,7 +61,7 @@ def builder(func): """ import copy - def _copy(self, *args, **kwargs): + def _copy(self: _S, *args: _P.args, **kwargs: _P.kwargs): self_copy = copy.copy(self) if getattr(self, "immutable", True) else self result = func(self_copy, *args, **kwargs) @@ -84,7 +75,7 @@ def _copy(self, *args, **kwargs): return _copy -def ignore_copy(func: Callable) -> Callable: +def ignore_copy(func: Callable[[_S, str], _T]) -> Callable[[_S, str], _T]: """ Decorator for wrapping the __getattr__ function for classes that are copied via deepcopy. This prevents infinite recursion caused by deepcopy looking for magic functions in the class. Any class implementing __getattr__ that is @@ -143,8 +134,9 @@ def format_alias_sql( ) -def validate(*args: Any, exc: Exception, type: Optional[Type] = None) -> None: +def validate(*args: Any, exc: Optional[Exception], type: Optional[Type] = None) -> None: if type is not None: + assert exc is not None for arg in args: if not isinstance(arg, type): raise exc From b824f2905ec3af612d4391f8e05c6e301e8084a0 Mon Sep 17 00:00:00 2001 From: fish Date: Sun, 26 Feb 2023 01:08:40 +0800 Subject: [PATCH 03/23] feat: add terms.py type hint add wrap_constant's overload change wrap_constant's value type: Node -> Term add wrap_json's overload remove Term._assert_guard add get_formatted_value's return type add python3.6, 3.7, 3.8, 3.9 support (Case.__init__) change WindowFrameAnalyticFunction.Edge.modifier default value fix other type --- pypika/terms.py | 182 +++++++++++++++++++++++++++++++++--------------- 1 file changed, 124 insertions(+), 58 deletions(-) diff --git a/pypika/terms.py b/pypika/terms.py index a831c216..2b3ef073 100644 --- a/pypika/terms.py +++ b/pypika/terms.py @@ -11,14 +11,16 @@ Iterable, Iterator, List, - MutableSequence, Optional, + MutableSequence, Sequence, Set, Type, TypeVar, Union, + overload, ) +from typing_extensions import Self from pypika.enums import Arithmetic, Boolean, Comparator, Dialects, Equality, JSONOperators, Matching, Order from pypika.utils import ( @@ -33,8 +35,7 @@ ) if TYPE_CHECKING: - from pypika.queries import QueryBuilder, Selectable, Table - from _typeshed import Self + from pypika.queries import QueryBuilder, Selectable, Table, Schema __author__ = "Timothy Heys" @@ -42,6 +43,10 @@ NodeT = TypeVar("NodeT", bound="Node") +TermT = TypeVar("TermT", bound="Term") +ValueWrapperT = TypeVar("ValueWrapperT", bound="ValueWrapper") +IntervalT = TypeVar("IntervalT", bound="Interval") +CriterionT = TypeVar("CriterionT", bound="Criterion") class Node: @@ -58,8 +63,12 @@ def find_(self, type: Type[NodeT]) -> List[NodeT]: WrappedConstantStrict = Union["LiteralValue", "Array", "Tuple", "ValueWrapper"] +WrappedConstantValue = Union["Term", int, float, bool, str, date, None] +WrappedConstantValueUnion = Union[ + WrappedConstantValue, List[WrappedConstantValue], typing.Tuple[WrappedConstantValue, ...] +] +WrappedConstant = Union["Term", WrappedConstantStrict] -WrappedConstant = Union[Node, WrappedConstantStrict] class Term(Node, SQLPart): @@ -83,8 +92,41 @@ def tables_(self) -> Set["Table"]: def fields_(self) -> Set["Field"]: return set(self.find_(Field)) + @overload + @staticmethod + def wrap_constant(val: TermT) -> TermT: + ... + + @overload + @staticmethod + def wrap_constant(val: None) -> "NullValue": + ... + + @overload + @staticmethod + def wrap_constant(val: List[WrappedConstantValue]) -> "Array": + ... + + @overload @staticmethod - def wrap_constant(val, wrapper_cls: Optional[Type["Term"]] = None) -> WrappedConstant: + def wrap_constant(val: typing.Tuple[WrappedConstantValue, ...]) -> "Tuple": + ... + + @overload + @staticmethod + def wrap_constant(val: Union[int, float, bool, str, date], wrapper_cls: Type["ValueWrapperT"]) -> "ValueWrapperT": + ... + + @overload + @staticmethod + def wrap_constant(val: Union[int, float, bool, str, date], wrapper_cls: None = None) -> "ValueWrapper": + ... + + @staticmethod + def wrap_constant( + val: WrappedConstantValueUnion, + wrapper_cls: Optional[Type["ValueWrapper"]] = None, + ) -> WrappedConstant: """ Used for wrapping raw inputs such as numbers in Criterions and Operator. @@ -113,9 +155,39 @@ def wrap_constant(val, wrapper_cls: Optional[Type["Term"]] = None) -> WrappedCon wrapper_cls = wrapper_cls or ValueWrapper return wrapper_cls(val) + @overload + @staticmethod + def wrap_json(val: TermT, wrapper_cls: Optional[Type["ValueWrapper"]] = None) -> TermT: + ... + + @overload + @staticmethod + def wrap_json(val: IntervalT, wrapper_cls: Optional[Type["ValueWrapper"]] = None) -> IntervalT: + ... + + @overload + @staticmethod + def wrap_json(val: None, wrapper_cls: Optional[Type["ValueWrapper"]] = None) -> "NullValue": + ... + + @overload + @staticmethod + def wrap_json(val: Union[str, int, bool], wrapper_cls: Type["ValueWrapperT"]) -> "ValueWrapperT": + ... + + @overload + @staticmethod + def wrap_json(val: Union[str, int, bool], wrapper_cls: None = None) -> "ValueWrapper": + ... + + @overload + @staticmethod + def wrap_json(val: object, wrapper_cls: Optional[Type["ValueWrapper"]] = None) -> "JSON": + ... + @staticmethod def wrap_json( - val: Union["Term", "QueryBuilder", "Interval", None, str, int, bool], wrapper_cls=None + val: object, wrapper_cls=None ) -> Union["Term", "QueryBuilder", "Interval", "NullValue", "ValueWrapper", "JSON"]: from .queries import QueryBuilder @@ -178,28 +250,28 @@ def ne(self, other: Any) -> "BasicCriterion": return self != other def glob(self, expr: str) -> "BasicCriterion": - return BasicCriterion(Matching.glob, self, Term._assert_guard(self.wrap_constant(expr))) + return BasicCriterion(Matching.glob, self, self.wrap_constant(expr)) def like(self, expr: str) -> "BasicCriterion": - return BasicCriterion(Matching.like, self, Term._assert_guard(self.wrap_constant(expr))) + return BasicCriterion(Matching.like, self, self.wrap_constant(expr)) def not_like(self, expr: str) -> "BasicCriterion": - return BasicCriterion(Matching.not_like, self, Term._assert_guard(self.wrap_constant(expr))) + return BasicCriterion(Matching.not_like, self, self.wrap_constant(expr)) def ilike(self, expr: str) -> "BasicCriterion": - return BasicCriterion(Matching.ilike, self, Term._assert_guard(self.wrap_constant(expr))) + return BasicCriterion(Matching.ilike, self, self.wrap_constant(expr)) def not_ilike(self, expr: str) -> "BasicCriterion": - return BasicCriterion(Matching.not_ilike, self, Term._assert_guard(self.wrap_constant(expr))) + return BasicCriterion(Matching.not_ilike, self, self.wrap_constant(expr)) def rlike(self, expr: str) -> "BasicCriterion": - return BasicCriterion(Matching.rlike, self, Term._assert_guard(self.wrap_constant(expr))) + return BasicCriterion(Matching.rlike, self, self.wrap_constant(expr)) def regex(self, pattern: str) -> "BasicCriterion": - return BasicCriterion(Matching.regex, self, Term._assert_guard(self.wrap_constant(pattern))) + return BasicCriterion(Matching.regex, self, self.wrap_constant(pattern)) def regexp(self, pattern: str) -> "BasicCriterion": - return BasicCriterion(Matching.regexp, self, Term._assert_guard(self.wrap_constant(pattern))) + return BasicCriterion(Matching.regexp, self, self.wrap_constant(pattern)) def between(self, lower: Any, upper: Any) -> "BetweenCriterion": return BetweenCriterion(self, self.wrap_constant(lower), self.wrap_constant(upper)) @@ -208,7 +280,7 @@ def from_to(self, start: Any, end: Any) -> "PeriodCriterion": return PeriodCriterion(self, self.wrap_constant(start), self.wrap_constant(end)) def as_of(self, expr: str) -> "BasicCriterion": - return BasicCriterion(Matching.as_of, self, Term._assert_guard(self.wrap_constant(expr))) + return BasicCriterion(Matching.as_of, self, self.wrap_constant(expr)) def all_(self) -> "All": return All(self) @@ -222,7 +294,7 @@ def notin(self, arg: Union[list, tuple, set, "Term"]) -> "ContainsCriterion": return self.isin(arg).negate() def bin_regex(self, pattern: str) -> "BasicCriterion": - return BasicCriterion(Matching.bin_regex, self, Term._assert_guard(self.wrap_constant(pattern))) + return BasicCriterion(Matching.bin_regex, self, self.wrap_constant(pattern)) def negate(self) -> "Not": return Not(self) @@ -285,22 +357,22 @@ def __rrshift__(self, other: Any) -> "ArithmeticExpression": return ArithmeticExpression(Arithmetic.rshift, self.wrap_constant(other), self) def __eq__(self, other: Any) -> "BasicCriterion": # type: ignore - return BasicCriterion(Equality.eq, self, Term._assert_guard(self.wrap_constant(other))) + return BasicCriterion(Equality.eq, self, self.wrap_constant(other)) def __ne__(self, other: Any) -> "BasicCriterion": # type: ignore - return BasicCriterion(Equality.ne, self, Term._assert_guard(self.wrap_constant(other))) + return BasicCriterion(Equality.ne, self, self.wrap_constant(other)) def __gt__(self, other: Any) -> "BasicCriterion": - return BasicCriterion(Equality.gt, self, Term._assert_guard(self.wrap_constant(other))) + return BasicCriterion(Equality.gt, self, self.wrap_constant(other)) def __ge__(self, other: Any) -> "BasicCriterion": - return BasicCriterion(Equality.gte, self, Term._assert_guard(self.wrap_constant(other))) + return BasicCriterion(Equality.gte, self, self.wrap_constant(other)) def __lt__(self, other: Any) -> "BasicCriterion": - return BasicCriterion(Equality.lt, self, Term._assert_guard(self.wrap_constant(other))) + return BasicCriterion(Equality.lt, self, self.wrap_constant(other)) def __le__(self, other: Any) -> "BasicCriterion": - return BasicCriterion(Equality.lte, self, Term._assert_guard(self.wrap_constant(other))) + return BasicCriterion(Equality.lte, self, self.wrap_constant(other)) def __getitem__(self, item: slice) -> "BetweenCriterion": if not isinstance(item, slice): @@ -316,13 +388,6 @@ def __hash__(self) -> int: def get_sql(self, **kwargs: Any) -> str: raise NotImplementedError() - @classmethod - def _assert_guard(cls, v: Any) -> "Term": - if isinstance(v, cls): - return v - else: - raise TypeError("expect Term object, got {}".format(type(v).__name__)) - class Parameter(Term): def __init__(self, placeholder: Union[str, int]) -> None: @@ -404,7 +469,7 @@ def get_value_sql(self, **kwargs: Any) -> str: return self.get_formatted_value(self.value, **kwargs) @classmethod - def get_formatted_value(cls, value: Any, **kwargs): + def get_formatted_value(cls, value: Any, **kwargs) -> str: quote_char = kwargs.get("secondary_quote_char") or "" # FIXME escape values @@ -431,7 +496,7 @@ def get_sql(self, quote_char: Optional[str] = None, secondary_quote_char: str = class JSON(Term): - def __init__(self, value: Any = None, alias: Optional[str] = None) -> None: + def __init__(self, value: object = None, alias: Optional[str] = None) -> None: super().__init__(alias) self.value = value self.table: Optional[Union[str, "Selectable"]] = None @@ -468,10 +533,10 @@ def get_sql(self, secondary_quote_char: str = "'", **kwargs: Any) -> str: return format_alias_sql(sql, self.alias, **kwargs) def get_json_value(self, key_or_index: Union[str, int]) -> "BasicCriterion": - return BasicCriterion(JSONOperators.GET_JSON_VALUE, self, Term._assert_guard(self.wrap_constant(key_or_index))) + return BasicCriterion(JSONOperators.GET_JSON_VALUE, self, self.wrap_constant(key_or_index)) def get_text_value(self, key_or_index: Union[str, int]) -> "BasicCriterion": - return BasicCriterion(JSONOperators.GET_TEXT_VALUE, self, Term._assert_guard(self.wrap_constant(key_or_index))) + return BasicCriterion(JSONOperators.GET_TEXT_VALUE, self, self.wrap_constant(key_or_index)) def get_path_json_value(self, path_json: str) -> "BasicCriterion": return BasicCriterion(JSONOperators.GET_PATH_JSON_VALUE, self, self.wrap_json(path_json)) @@ -656,9 +721,9 @@ def get_sql( # type: ignore class Tuple(Criterion): - def __init__(self, *values: Any) -> None: + def __init__(self, *values: WrappedConstantValueUnion) -> None: super().__init__() - self.values = [self.wrap_constant(value) for value in values] + self.values: List[Term] = [self.wrap_constant(value) for value in values] def nodes_(self) -> Iterator[Node]: yield self @@ -666,7 +731,7 @@ def nodes_(self) -> Iterator[Node]: yield from value.nodes_() def get_sql(self, **kwargs: Any) -> str: - sql = "({})".format(",".join(Term._assert_guard(term).get_sql(**kwargs) for term in self.values)) + sql = "({})".format(",".join(term.get_sql(**kwargs) for term in self.values)) return format_alias_sql(sql, self.alias, **kwargs) @property @@ -685,13 +750,13 @@ def replace_table(self, current_table: Optional["Table"], new_table: Optional["T :return: A copy of the field with the tables replaced. """ - self.values = [Term._assert_guard(value).replace_table(current_table, new_table) for value in self.values] + self.values = [value.replace_table(current_table, new_table) for value in self.values] class Array(Tuple): def get_sql(self, **kwargs: Any) -> str: dialect = kwargs.get("dialect", None) - values = ",".join(Term._assert_guard(term).get_sql(**kwargs) for term in self.values) + values = ",".join(term.get_sql(**kwargs) for term in self.values) sql = "[{}]".format(values) if dialect in (Dialects.POSTGRESQL, Dialects.REDSHIFT): @@ -1037,7 +1102,7 @@ class ArithmeticExpression(Term): add_order = [Arithmetic.add, Arithmetic.sub] - def __init__(self, operator: Arithmetic, left: Any, right: Any, alias: Optional[str] = None) -> None: + def __init__(self, operator: Arithmetic, left: Term, right: Term, alias: Optional[str] = None) -> None: """ Wrapper for an arithmetic expression. @@ -1148,8 +1213,8 @@ def get_sql(self, with_alias: bool = False, **kwargs: Any) -> str: class Case(Criterion): def __init__(self, alias: Optional[str] = None) -> None: super().__init__(alias=alias) - self._cases: List[typing.Tuple[Any, Any]] = [] - self._else: WrappedConstant | None = None + self._cases: List[typing.Tuple["Criterion", "Term"]] = [] + self._else: Optional[WrappedConstant] = None def nodes_(self) -> Iterator[Node]: yield self @@ -1170,7 +1235,7 @@ def is_aggregate(self) -> Optional[bool]: ) @builder - def when(self, criterion: Any, term: Any): + def when(self, criterion: Criterion, term: WrappedConstantValue): self._cases.append((criterion, self.wrap_constant(term))) @builder @@ -1192,10 +1257,10 @@ def replace_table(self, current_table: Optional["Table"], new_table: Optional["T ) for criterion, term in self._cases ] - self._else = Term._assert_guard(self._else).replace_table(current_table, new_table) if self._else else None + self._else = self._else.replace_table(current_table, new_table) if self._else else None @builder - def else_(self, term: Any) -> "Case": + def else_(self, term: WrappedConstantValue) -> "Case": self._else = self.wrap_constant(term) return self @@ -1207,7 +1272,7 @@ def get_sql(self, with_alias: bool = False, **kwargs: Any) -> str: "WHEN {when} THEN {then}".format(when=criterion.get_sql(**kwargs), then=term.get_sql(**kwargs)) for criterion, term in self._cases ) - else_ = " ELSE {}".format(Term._assert_guard(self._else).get_sql(**kwargs)) if self._else else "" + else_ = " ELSE {}".format(self._else.get_sql(**kwargs)) if self._else else "" case_sql = "CASE {cases}{else_} END".format(cases=cases, else_=else_) @@ -1218,7 +1283,7 @@ def get_sql(self, with_alias: bool = False, **kwargs: Any) -> str: class Not(Criterion): - def __init__(self, term: Any, alias: Optional[str] = None) -> None: + def __init__(self, term: Term, alias: Optional[str] = None) -> None: super().__init__(alias=alias) self.term = term @@ -1266,7 +1331,7 @@ def replace_table(self, current_table: Optional["Table"], new_table: Optional["T class All(Criterion): - def __init__(self, term: Any, alias: Optional[str] = None) -> None: + def __init__(self, term: Term, alias: Optional[str] = None) -> None: super().__init__(alias=alias) self.term = term @@ -1285,7 +1350,7 @@ def __init__(self, name: str, params: Optional[Sequence] = None) -> None: self.params = params def __call__(self, *args: Any, **kwargs: Any) -> "Function": - if not self._has_params(): + if self.params is None: return Function(self.name, alias=kwargs.get("alias")) if not self._is_valid_function_call(*args): @@ -1303,15 +1368,16 @@ def _has_params(self): return self.params is not None def _is_valid_function_call(self, *args): + assert self.params is not None return len(args) == len(self.params) class Function(Criterion): - def __init__(self, name: str, *args: Any, **kwargs: Any) -> None: + def __init__(self, name: str, *args: WrappedConstantValueUnion, **kwargs: Any) -> None: super().__init__(kwargs.get("alias")) self.name = name - self.args: MutableSequence[WrappedConstant] = [self.wrap_constant(param) for param in args] - self.schema = kwargs.get("schema") + self.args: MutableSequence[Term] = [self.wrap_constant(param) for param in args] + self.schema: Optional["Schema"] = kwargs.get("schema") def nodes_(self) -> Iterator[Node]: yield self @@ -1340,7 +1406,7 @@ def replace_table(self, current_table: Optional["Table"], new_table: Optional["T :return: A copy of the criterion with the tables replaced. """ - self.args = [Term._assert_guard(param).replace_table(current_table, new_table) for param in self.args] + self.args = [param.replace_table(current_table, new_table) for param in self.args] def get_special_params_sql(self, **kwargs: Any) -> Any: pass @@ -1355,7 +1421,7 @@ def get_function_sql(self, **kwargs: Any) -> str: return "{name}({args}{special})".format( name=self.name, args=",".join( - Term._assert_guard(p).get_sql(with_alias=False, subquery=True, **kwargs) + p.get_sql(with_alias=False, subquery=True, **kwargs) if hasattr(p, "get_sql") else self.get_arg_sql(p, **kwargs) for p in self.args @@ -1385,7 +1451,7 @@ def get_sql(self, **kwargs: Any) -> str: class AggregateFunction(Function): - is_aggregate = True + is_aggregate: Optional[bool] = True def __init__(self, name, *args, **kwargs): super(AggregateFunction, self).__init__(name, *args, **kwargs) @@ -1414,7 +1480,7 @@ def get_function_sql(self, **kwargs: Any): class AnalyticFunction(AggregateFunction): - is_aggregate = False + is_aggregate: Optional[bool] = False is_analytic = True def __init__(self, name: str, *args: Any, **kwargs: Any) -> None: @@ -1480,7 +1546,7 @@ def get_function_sql(self, **kwargs: Any) -> str: class WindowFrameAnalyticFunction(AnalyticFunction): class Edge: - modifier: ClassVar[Optional[str]] = None + modifier: ClassVar[str] = "" def __init__(self, value: Optional[Union[str, int]] = None) -> None: self.value = value @@ -1488,7 +1554,7 @@ def __init__(self, value: Optional[Union[str, int]] = None) -> None: def __str__(self) -> str: return "{value} {modifier}".format( value=self.value or "UNBOUNDED", - modifier=self.modifier or "", + modifier=self.modifier, ) def __init__(self, name: str, *args: Any, **kwargs: Any) -> None: @@ -1690,7 +1756,7 @@ class AtTimezone(Term, SQLPart): AT TIME ZONE INTERVAL '-06:00' """ - is_aggregate = None + is_aggregate: Optional[bool] = None def __init__(self, field, zone, interval=False, alias=None): super().__init__(alias) From 30b8bc8aa651a75c6402061e7b7da03b9df5fa97 Mon Sep 17 00:00:00 2001 From: fish Date: Sun, 26 Feb 2023 01:08:40 +0800 Subject: [PATCH 04/23] feat: add terms.py type hint add wrap_constant's overload change wrap_constant's value type: Node -> Term add wrap_json's overload remove Term._assert_guard add get_formatted_value's return type add python3.6, 3.7, 3.8, 3.9 support (Case.__init__) change WindowFrameAnalyticFunction.Edge.modifier default value fix other type --- pypika/terms.py | 182 +++++++++++++++++++++++++++++++++--------------- 1 file changed, 124 insertions(+), 58 deletions(-) diff --git a/pypika/terms.py b/pypika/terms.py index a831c216..2b3ef073 100644 --- a/pypika/terms.py +++ b/pypika/terms.py @@ -11,14 +11,16 @@ Iterable, Iterator, List, - MutableSequence, Optional, + MutableSequence, Sequence, Set, Type, TypeVar, Union, + overload, ) +from typing_extensions import Self from pypika.enums import Arithmetic, Boolean, Comparator, Dialects, Equality, JSONOperators, Matching, Order from pypika.utils import ( @@ -33,8 +35,7 @@ ) if TYPE_CHECKING: - from pypika.queries import QueryBuilder, Selectable, Table - from _typeshed import Self + from pypika.queries import QueryBuilder, Selectable, Table, Schema __author__ = "Timothy Heys" @@ -42,6 +43,10 @@ NodeT = TypeVar("NodeT", bound="Node") +TermT = TypeVar("TermT", bound="Term") +ValueWrapperT = TypeVar("ValueWrapperT", bound="ValueWrapper") +IntervalT = TypeVar("IntervalT", bound="Interval") +CriterionT = TypeVar("CriterionT", bound="Criterion") class Node: @@ -58,8 +63,12 @@ def find_(self, type: Type[NodeT]) -> List[NodeT]: WrappedConstantStrict = Union["LiteralValue", "Array", "Tuple", "ValueWrapper"] +WrappedConstantValue = Union["Term", int, float, bool, str, date, None] +WrappedConstantValueUnion = Union[ + WrappedConstantValue, List[WrappedConstantValue], typing.Tuple[WrappedConstantValue, ...] +] +WrappedConstant = Union["Term", WrappedConstantStrict] -WrappedConstant = Union[Node, WrappedConstantStrict] class Term(Node, SQLPart): @@ -83,8 +92,41 @@ def tables_(self) -> Set["Table"]: def fields_(self) -> Set["Field"]: return set(self.find_(Field)) + @overload + @staticmethod + def wrap_constant(val: TermT) -> TermT: + ... + + @overload + @staticmethod + def wrap_constant(val: None) -> "NullValue": + ... + + @overload + @staticmethod + def wrap_constant(val: List[WrappedConstantValue]) -> "Array": + ... + + @overload @staticmethod - def wrap_constant(val, wrapper_cls: Optional[Type["Term"]] = None) -> WrappedConstant: + def wrap_constant(val: typing.Tuple[WrappedConstantValue, ...]) -> "Tuple": + ... + + @overload + @staticmethod + def wrap_constant(val: Union[int, float, bool, str, date], wrapper_cls: Type["ValueWrapperT"]) -> "ValueWrapperT": + ... + + @overload + @staticmethod + def wrap_constant(val: Union[int, float, bool, str, date], wrapper_cls: None = None) -> "ValueWrapper": + ... + + @staticmethod + def wrap_constant( + val: WrappedConstantValueUnion, + wrapper_cls: Optional[Type["ValueWrapper"]] = None, + ) -> WrappedConstant: """ Used for wrapping raw inputs such as numbers in Criterions and Operator. @@ -113,9 +155,39 @@ def wrap_constant(val, wrapper_cls: Optional[Type["Term"]] = None) -> WrappedCon wrapper_cls = wrapper_cls or ValueWrapper return wrapper_cls(val) + @overload + @staticmethod + def wrap_json(val: TermT, wrapper_cls: Optional[Type["ValueWrapper"]] = None) -> TermT: + ... + + @overload + @staticmethod + def wrap_json(val: IntervalT, wrapper_cls: Optional[Type["ValueWrapper"]] = None) -> IntervalT: + ... + + @overload + @staticmethod + def wrap_json(val: None, wrapper_cls: Optional[Type["ValueWrapper"]] = None) -> "NullValue": + ... + + @overload + @staticmethod + def wrap_json(val: Union[str, int, bool], wrapper_cls: Type["ValueWrapperT"]) -> "ValueWrapperT": + ... + + @overload + @staticmethod + def wrap_json(val: Union[str, int, bool], wrapper_cls: None = None) -> "ValueWrapper": + ... + + @overload + @staticmethod + def wrap_json(val: object, wrapper_cls: Optional[Type["ValueWrapper"]] = None) -> "JSON": + ... + @staticmethod def wrap_json( - val: Union["Term", "QueryBuilder", "Interval", None, str, int, bool], wrapper_cls=None + val: object, wrapper_cls=None ) -> Union["Term", "QueryBuilder", "Interval", "NullValue", "ValueWrapper", "JSON"]: from .queries import QueryBuilder @@ -178,28 +250,28 @@ def ne(self, other: Any) -> "BasicCriterion": return self != other def glob(self, expr: str) -> "BasicCriterion": - return BasicCriterion(Matching.glob, self, Term._assert_guard(self.wrap_constant(expr))) + return BasicCriterion(Matching.glob, self, self.wrap_constant(expr)) def like(self, expr: str) -> "BasicCriterion": - return BasicCriterion(Matching.like, self, Term._assert_guard(self.wrap_constant(expr))) + return BasicCriterion(Matching.like, self, self.wrap_constant(expr)) def not_like(self, expr: str) -> "BasicCriterion": - return BasicCriterion(Matching.not_like, self, Term._assert_guard(self.wrap_constant(expr))) + return BasicCriterion(Matching.not_like, self, self.wrap_constant(expr)) def ilike(self, expr: str) -> "BasicCriterion": - return BasicCriterion(Matching.ilike, self, Term._assert_guard(self.wrap_constant(expr))) + return BasicCriterion(Matching.ilike, self, self.wrap_constant(expr)) def not_ilike(self, expr: str) -> "BasicCriterion": - return BasicCriterion(Matching.not_ilike, self, Term._assert_guard(self.wrap_constant(expr))) + return BasicCriterion(Matching.not_ilike, self, self.wrap_constant(expr)) def rlike(self, expr: str) -> "BasicCriterion": - return BasicCriterion(Matching.rlike, self, Term._assert_guard(self.wrap_constant(expr))) + return BasicCriterion(Matching.rlike, self, self.wrap_constant(expr)) def regex(self, pattern: str) -> "BasicCriterion": - return BasicCriterion(Matching.regex, self, Term._assert_guard(self.wrap_constant(pattern))) + return BasicCriterion(Matching.regex, self, self.wrap_constant(pattern)) def regexp(self, pattern: str) -> "BasicCriterion": - return BasicCriterion(Matching.regexp, self, Term._assert_guard(self.wrap_constant(pattern))) + return BasicCriterion(Matching.regexp, self, self.wrap_constant(pattern)) def between(self, lower: Any, upper: Any) -> "BetweenCriterion": return BetweenCriterion(self, self.wrap_constant(lower), self.wrap_constant(upper)) @@ -208,7 +280,7 @@ def from_to(self, start: Any, end: Any) -> "PeriodCriterion": return PeriodCriterion(self, self.wrap_constant(start), self.wrap_constant(end)) def as_of(self, expr: str) -> "BasicCriterion": - return BasicCriterion(Matching.as_of, self, Term._assert_guard(self.wrap_constant(expr))) + return BasicCriterion(Matching.as_of, self, self.wrap_constant(expr)) def all_(self) -> "All": return All(self) @@ -222,7 +294,7 @@ def notin(self, arg: Union[list, tuple, set, "Term"]) -> "ContainsCriterion": return self.isin(arg).negate() def bin_regex(self, pattern: str) -> "BasicCriterion": - return BasicCriterion(Matching.bin_regex, self, Term._assert_guard(self.wrap_constant(pattern))) + return BasicCriterion(Matching.bin_regex, self, self.wrap_constant(pattern)) def negate(self) -> "Not": return Not(self) @@ -285,22 +357,22 @@ def __rrshift__(self, other: Any) -> "ArithmeticExpression": return ArithmeticExpression(Arithmetic.rshift, self.wrap_constant(other), self) def __eq__(self, other: Any) -> "BasicCriterion": # type: ignore - return BasicCriterion(Equality.eq, self, Term._assert_guard(self.wrap_constant(other))) + return BasicCriterion(Equality.eq, self, self.wrap_constant(other)) def __ne__(self, other: Any) -> "BasicCriterion": # type: ignore - return BasicCriterion(Equality.ne, self, Term._assert_guard(self.wrap_constant(other))) + return BasicCriterion(Equality.ne, self, self.wrap_constant(other)) def __gt__(self, other: Any) -> "BasicCriterion": - return BasicCriterion(Equality.gt, self, Term._assert_guard(self.wrap_constant(other))) + return BasicCriterion(Equality.gt, self, self.wrap_constant(other)) def __ge__(self, other: Any) -> "BasicCriterion": - return BasicCriterion(Equality.gte, self, Term._assert_guard(self.wrap_constant(other))) + return BasicCriterion(Equality.gte, self, self.wrap_constant(other)) def __lt__(self, other: Any) -> "BasicCriterion": - return BasicCriterion(Equality.lt, self, Term._assert_guard(self.wrap_constant(other))) + return BasicCriterion(Equality.lt, self, self.wrap_constant(other)) def __le__(self, other: Any) -> "BasicCriterion": - return BasicCriterion(Equality.lte, self, Term._assert_guard(self.wrap_constant(other))) + return BasicCriterion(Equality.lte, self, self.wrap_constant(other)) def __getitem__(self, item: slice) -> "BetweenCriterion": if not isinstance(item, slice): @@ -316,13 +388,6 @@ def __hash__(self) -> int: def get_sql(self, **kwargs: Any) -> str: raise NotImplementedError() - @classmethod - def _assert_guard(cls, v: Any) -> "Term": - if isinstance(v, cls): - return v - else: - raise TypeError("expect Term object, got {}".format(type(v).__name__)) - class Parameter(Term): def __init__(self, placeholder: Union[str, int]) -> None: @@ -404,7 +469,7 @@ def get_value_sql(self, **kwargs: Any) -> str: return self.get_formatted_value(self.value, **kwargs) @classmethod - def get_formatted_value(cls, value: Any, **kwargs): + def get_formatted_value(cls, value: Any, **kwargs) -> str: quote_char = kwargs.get("secondary_quote_char") or "" # FIXME escape values @@ -431,7 +496,7 @@ def get_sql(self, quote_char: Optional[str] = None, secondary_quote_char: str = class JSON(Term): - def __init__(self, value: Any = None, alias: Optional[str] = None) -> None: + def __init__(self, value: object = None, alias: Optional[str] = None) -> None: super().__init__(alias) self.value = value self.table: Optional[Union[str, "Selectable"]] = None @@ -468,10 +533,10 @@ def get_sql(self, secondary_quote_char: str = "'", **kwargs: Any) -> str: return format_alias_sql(sql, self.alias, **kwargs) def get_json_value(self, key_or_index: Union[str, int]) -> "BasicCriterion": - return BasicCriterion(JSONOperators.GET_JSON_VALUE, self, Term._assert_guard(self.wrap_constant(key_or_index))) + return BasicCriterion(JSONOperators.GET_JSON_VALUE, self, self.wrap_constant(key_or_index)) def get_text_value(self, key_or_index: Union[str, int]) -> "BasicCriterion": - return BasicCriterion(JSONOperators.GET_TEXT_VALUE, self, Term._assert_guard(self.wrap_constant(key_or_index))) + return BasicCriterion(JSONOperators.GET_TEXT_VALUE, self, self.wrap_constant(key_or_index)) def get_path_json_value(self, path_json: str) -> "BasicCriterion": return BasicCriterion(JSONOperators.GET_PATH_JSON_VALUE, self, self.wrap_json(path_json)) @@ -656,9 +721,9 @@ def get_sql( # type: ignore class Tuple(Criterion): - def __init__(self, *values: Any) -> None: + def __init__(self, *values: WrappedConstantValueUnion) -> None: super().__init__() - self.values = [self.wrap_constant(value) for value in values] + self.values: List[Term] = [self.wrap_constant(value) for value in values] def nodes_(self) -> Iterator[Node]: yield self @@ -666,7 +731,7 @@ def nodes_(self) -> Iterator[Node]: yield from value.nodes_() def get_sql(self, **kwargs: Any) -> str: - sql = "({})".format(",".join(Term._assert_guard(term).get_sql(**kwargs) for term in self.values)) + sql = "({})".format(",".join(term.get_sql(**kwargs) for term in self.values)) return format_alias_sql(sql, self.alias, **kwargs) @property @@ -685,13 +750,13 @@ def replace_table(self, current_table: Optional["Table"], new_table: Optional["T :return: A copy of the field with the tables replaced. """ - self.values = [Term._assert_guard(value).replace_table(current_table, new_table) for value in self.values] + self.values = [value.replace_table(current_table, new_table) for value in self.values] class Array(Tuple): def get_sql(self, **kwargs: Any) -> str: dialect = kwargs.get("dialect", None) - values = ",".join(Term._assert_guard(term).get_sql(**kwargs) for term in self.values) + values = ",".join(term.get_sql(**kwargs) for term in self.values) sql = "[{}]".format(values) if dialect in (Dialects.POSTGRESQL, Dialects.REDSHIFT): @@ -1037,7 +1102,7 @@ class ArithmeticExpression(Term): add_order = [Arithmetic.add, Arithmetic.sub] - def __init__(self, operator: Arithmetic, left: Any, right: Any, alias: Optional[str] = None) -> None: + def __init__(self, operator: Arithmetic, left: Term, right: Term, alias: Optional[str] = None) -> None: """ Wrapper for an arithmetic expression. @@ -1148,8 +1213,8 @@ def get_sql(self, with_alias: bool = False, **kwargs: Any) -> str: class Case(Criterion): def __init__(self, alias: Optional[str] = None) -> None: super().__init__(alias=alias) - self._cases: List[typing.Tuple[Any, Any]] = [] - self._else: WrappedConstant | None = None + self._cases: List[typing.Tuple["Criterion", "Term"]] = [] + self._else: Optional[WrappedConstant] = None def nodes_(self) -> Iterator[Node]: yield self @@ -1170,7 +1235,7 @@ def is_aggregate(self) -> Optional[bool]: ) @builder - def when(self, criterion: Any, term: Any): + def when(self, criterion: Criterion, term: WrappedConstantValue): self._cases.append((criterion, self.wrap_constant(term))) @builder @@ -1192,10 +1257,10 @@ def replace_table(self, current_table: Optional["Table"], new_table: Optional["T ) for criterion, term in self._cases ] - self._else = Term._assert_guard(self._else).replace_table(current_table, new_table) if self._else else None + self._else = self._else.replace_table(current_table, new_table) if self._else else None @builder - def else_(self, term: Any) -> "Case": + def else_(self, term: WrappedConstantValue) -> "Case": self._else = self.wrap_constant(term) return self @@ -1207,7 +1272,7 @@ def get_sql(self, with_alias: bool = False, **kwargs: Any) -> str: "WHEN {when} THEN {then}".format(when=criterion.get_sql(**kwargs), then=term.get_sql(**kwargs)) for criterion, term in self._cases ) - else_ = " ELSE {}".format(Term._assert_guard(self._else).get_sql(**kwargs)) if self._else else "" + else_ = " ELSE {}".format(self._else.get_sql(**kwargs)) if self._else else "" case_sql = "CASE {cases}{else_} END".format(cases=cases, else_=else_) @@ -1218,7 +1283,7 @@ def get_sql(self, with_alias: bool = False, **kwargs: Any) -> str: class Not(Criterion): - def __init__(self, term: Any, alias: Optional[str] = None) -> None: + def __init__(self, term: Term, alias: Optional[str] = None) -> None: super().__init__(alias=alias) self.term = term @@ -1266,7 +1331,7 @@ def replace_table(self, current_table: Optional["Table"], new_table: Optional["T class All(Criterion): - def __init__(self, term: Any, alias: Optional[str] = None) -> None: + def __init__(self, term: Term, alias: Optional[str] = None) -> None: super().__init__(alias=alias) self.term = term @@ -1285,7 +1350,7 @@ def __init__(self, name: str, params: Optional[Sequence] = None) -> None: self.params = params def __call__(self, *args: Any, **kwargs: Any) -> "Function": - if not self._has_params(): + if self.params is None: return Function(self.name, alias=kwargs.get("alias")) if not self._is_valid_function_call(*args): @@ -1303,15 +1368,16 @@ def _has_params(self): return self.params is not None def _is_valid_function_call(self, *args): + assert self.params is not None return len(args) == len(self.params) class Function(Criterion): - def __init__(self, name: str, *args: Any, **kwargs: Any) -> None: + def __init__(self, name: str, *args: WrappedConstantValueUnion, **kwargs: Any) -> None: super().__init__(kwargs.get("alias")) self.name = name - self.args: MutableSequence[WrappedConstant] = [self.wrap_constant(param) for param in args] - self.schema = kwargs.get("schema") + self.args: MutableSequence[Term] = [self.wrap_constant(param) for param in args] + self.schema: Optional["Schema"] = kwargs.get("schema") def nodes_(self) -> Iterator[Node]: yield self @@ -1340,7 +1406,7 @@ def replace_table(self, current_table: Optional["Table"], new_table: Optional["T :return: A copy of the criterion with the tables replaced. """ - self.args = [Term._assert_guard(param).replace_table(current_table, new_table) for param in self.args] + self.args = [param.replace_table(current_table, new_table) for param in self.args] def get_special_params_sql(self, **kwargs: Any) -> Any: pass @@ -1355,7 +1421,7 @@ def get_function_sql(self, **kwargs: Any) -> str: return "{name}({args}{special})".format( name=self.name, args=",".join( - Term._assert_guard(p).get_sql(with_alias=False, subquery=True, **kwargs) + p.get_sql(with_alias=False, subquery=True, **kwargs) if hasattr(p, "get_sql") else self.get_arg_sql(p, **kwargs) for p in self.args @@ -1385,7 +1451,7 @@ def get_sql(self, **kwargs: Any) -> str: class AggregateFunction(Function): - is_aggregate = True + is_aggregate: Optional[bool] = True def __init__(self, name, *args, **kwargs): super(AggregateFunction, self).__init__(name, *args, **kwargs) @@ -1414,7 +1480,7 @@ def get_function_sql(self, **kwargs: Any): class AnalyticFunction(AggregateFunction): - is_aggregate = False + is_aggregate: Optional[bool] = False is_analytic = True def __init__(self, name: str, *args: Any, **kwargs: Any) -> None: @@ -1480,7 +1546,7 @@ def get_function_sql(self, **kwargs: Any) -> str: class WindowFrameAnalyticFunction(AnalyticFunction): class Edge: - modifier: ClassVar[Optional[str]] = None + modifier: ClassVar[str] = "" def __init__(self, value: Optional[Union[str, int]] = None) -> None: self.value = value @@ -1488,7 +1554,7 @@ def __init__(self, value: Optional[Union[str, int]] = None) -> None: def __str__(self) -> str: return "{value} {modifier}".format( value=self.value or "UNBOUNDED", - modifier=self.modifier or "", + modifier=self.modifier, ) def __init__(self, name: str, *args: Any, **kwargs: Any) -> None: @@ -1690,7 +1756,7 @@ class AtTimezone(Term, SQLPart): AT TIME ZONE INTERVAL '-06:00' """ - is_aggregate = None + is_aggregate: Optional[bool] = None def __init__(self, field, zone, interval=False, alias=None): super().__init__(alias) From 00a604489dca46e87facc343340d6c7c64b363ea Mon Sep 17 00:00:00 2001 From: fish Date: Sun, 26 Feb 2023 14:23:36 +0800 Subject: [PATCH 05/23] feat: add queries.py type hint Selectable: change Base Type: Node -> Term Table: add generic type Query: add generic type __copy__: return Self type remove Term._assert_guard Joiner: add generic type fix other type --- pypika/queries.py | 143 ++++++++++++++++++++++++++-------------------- 1 file changed, 82 insertions(+), 61 deletions(-) diff --git a/pypika/queries.py b/pypika/queries.py index 81d235c2..cd693f96 100644 --- a/pypika/queries.py +++ b/pypika/queries.py @@ -2,6 +2,7 @@ from functools import reduce from itertools import chain import operator +import builtins from typing import ( Any, Callable, @@ -16,7 +17,10 @@ Set, cast, TypeVar, + overload, + TYPE_CHECKING, ) +from typing_extensions import Self from pypika.enums import Dialects, JoinType, ReferenceOption, SetOperation, Order from pypika.terms import ( @@ -26,7 +30,6 @@ Field, Function, Index, - Node, Rollup, Star, Term, @@ -34,6 +37,7 @@ ValueWrapper, Criterion, PeriodCriterion, + WrappedConstantValue, WrappedConstant, ) from pypika.utils import ( @@ -53,9 +57,16 @@ _T = TypeVar("_T") +SchemaT = TypeVar("SchemaT", bound="Schema") +if TYPE_CHECKING: + from typing_extensions import TypeVar + QueryBuilderType = TypeVar("QueryBuilderType", bound="QueryBuilder", covariant=True, default="QueryBuilder") +else: + QueryBuilderType = TypeVar("QueryBuilderType", bound="QueryBuilder", covariant=True) -class Selectable(Node): + +class Selectable(Term): def __init__(self, alias: Optional[str]) -> None: self.alias = alias @@ -79,10 +90,13 @@ def __getitem__(self, name: str) -> Field: return self.field(name) def get_table_name(self) -> str: - if not self.alias: + if self.alias is None: raise TypeError("expect str, got None") return self.alias + def get_sql(self, **kwargs) -> str: + raise NotImplementedError + class AliasedQuery(Selectable, SQLPart): def __init__(self, name: str, query: Optional[Selectable] = None) -> None: @@ -136,7 +150,7 @@ def __getattr__(self, item: str) -> Schema: return Schema(item, parent=self) -class Table(Selectable): +class Table(Selectable, Generic[QueryBuilderType]): @staticmethod def _init_schema(schema: Union[str, list, tuple, Schema, None]) -> Optional[Schema]: # This is a bit complicated in order to support backwards compatibility. It should probably be cleaned up for @@ -152,14 +166,14 @@ def _init_schema(schema: Union[str, list, tuple, Schema, None]) -> Optional[Sche def __init__( self, name: str, - schema: Optional[Union[Schema, str]] = None, + schema: Union[str, list, tuple, Schema, None] = None, alias: Optional[str] = None, - query_cls: Optional[Type["Query"]] = None, + query_cls: Optional[Type["Query[QueryBuilderType]"]] = None, ) -> None: super().__init__(alias) self._table_name = name self._schema = self._init_schema(schema) - self._query_cls = query_cls or Query + self._query_cls: Type["Query[QueryBuilderType]"] = query_cls or Query self._for: Optional[Criterion] = None self._for_portion: Optional[PeriodCriterion] = None if not issubclass(self._query_cls, Query): @@ -230,7 +244,7 @@ def __ne__(self, other: Any) -> bool: def __hash__(self) -> int: return hash(str(self)) - def select(self, *terms: Sequence[Union[int, float, str, bool, Term, Field]]) -> "QueryBuilder": + def select(self, *terms: Union[int, float, str, bool, Term, Field]) -> "QueryBuilderType": """ Perform a SELECT operation on the current table @@ -243,7 +257,7 @@ def select(self, *terms: Sequence[Union[int, float, str, bool, Term, Field]]) -> """ return self._query_cls.from_(self).select(*terms) - def update(self) -> "QueryBuilder": + def update(self) -> "QueryBuilderType": """ Perform an UPDATE operation on the current table @@ -251,7 +265,7 @@ def update(self) -> "QueryBuilder": """ return self._query_cls.update(self) - def insert(self, *terms: Union[int, float, str, bool, Term, Field]) -> "QueryBuilder": + def insert(self, *terms: Union[int, float, str, bool, Term, Field]) -> "QueryBuilderType": """ Perform an INSERT operation on the current table @@ -265,13 +279,15 @@ def insert(self, *terms: Union[int, float, str, bool, Term, Field]) -> "QueryBui return self._query_cls.into(self).insert(*terms) -def make_tables(*names: Union[TypedTuple[str, str], str], **kwargs: Any) -> List[Table]: +def make_tables( + *names: Union[TypedTuple[str, str], str], query_cls: "Optional[Type[Query[QueryBuilderType]]]" = None, **kwargs: Any +) -> List[Table[QueryBuilderType]]: """ Shortcut to create many tables. If `names` param is a tuple, the first position will refer to the `_table_name` while the second will be its `alias`. Any other data structure will be treated as a whole as the `_table_name`. """ - tables = [] + tables: List["Table[QueryBuilderType]"] = [] for name in names: if isinstance(name, tuple): if len(name) == 2: @@ -279,7 +295,7 @@ def make_tables(*names: Union[TypedTuple[str, str], str], **kwargs: Any) -> List name=name[0], alias=name[1], schema=kwargs.get("schema"), - query_cls=kwargs.get("query_cls"), + query_cls=query_cls, ) else: raise TypeError("expect tuple[str, str] or str, got a tuple with {} element(s)".format(len(name))) @@ -301,7 +317,7 @@ def __init__( column_name: str, column_type: Optional[str] = None, nullable: Optional[bool] = None, - default: Optional[Union[Any, Term]] = None, + default: object = None, ) -> None: self.name = column_name self.type = column_type @@ -373,7 +389,7 @@ def get_sql(self, **kwargs: Any) -> str: _TableClass = Table -class Query: +class Query(Generic[QueryBuilderType]): """ Query is the primary class and entry point in pypika. It is used to build queries iteratively using the builder design @@ -383,11 +399,11 @@ class Query: """ @classmethod - def _builder(cls, **kwargs: Any) -> "QueryBuilder": + def _builder(cls, **kwargs: Any) -> "QueryBuilderType": return QueryBuilder(**kwargs) @classmethod - def from_(cls, table: Union[Selectable, str], **kwargs: Any) -> "QueryBuilder": + def from_(cls, table: Union[Selectable, str], **kwargs: Any) -> "QueryBuilderType": """ Query builder entry point. Initializes query building and sets the table to select from. When using this function, the query becomes a SELECT query. @@ -462,7 +478,7 @@ def drop_view(cls, view: str) -> "DropQueryBuilder": return DropQueryBuilder().drop_view(view) @classmethod - def into(cls, table: Union[Table, str], **kwargs: Any) -> "QueryBuilder": + def into(cls, table: Union[Table, str], **kwargs: Any) -> "QueryBuilderType": """ Query builder entry point. Initializes query building and sets the table to insert into. When using this function, the query becomes an INSERT query. @@ -477,11 +493,11 @@ def into(cls, table: Union[Table, str], **kwargs: Any) -> "QueryBuilder": return cls._builder(**kwargs).into(table) @classmethod - def with_(cls, table: Union[str, Selectable], name: str, **kwargs: Any) -> "QueryBuilder": + def with_(cls, table: Selectable, name: str, **kwargs: Any) -> "QueryBuilderType": return cls._builder(**kwargs).with_(table, name) @classmethod - def select(cls, *terms: Union[int, float, str, bool, Term], **kwargs: Any) -> "QueryBuilder": + def select(cls, *terms: Union[int, float, str, bool, Term], **kwargs: Any) -> "QueryBuilderType": """ Query builder entry point. Initializes query building without a table and selects fields. Useful when testing SQL functions. @@ -497,7 +513,7 @@ def select(cls, *terms: Union[int, float, str, bool, Term], **kwargs: Any) -> "Q return cls._builder(**kwargs).select(*terms) @classmethod - def update(cls, table: Union[str, Table], **kwargs) -> "QueryBuilder": + def update(cls, table: Union[str, Table], **kwargs) -> "QueryBuilderType": """ Query builder entry point. Initializes query building and sets the table to update. When using this function, the query becomes an UPDATE query. @@ -512,7 +528,7 @@ def update(cls, table: Union[str, Table], **kwargs) -> "QueryBuilder": return cls._builder(**kwargs).update(table) @classmethod - def Table(cls, table_name: str, **kwargs) -> _TableClass: + def Table(cls, table_name: str, **kwargs) -> Table[QueryBuilderType]: """ Convenience method for creating a Table that uses this Query class. @@ -523,11 +539,10 @@ def Table(cls, table_name: str, **kwargs) -> _TableClass: :returns Table """ - kwargs["query_cls"] = cls - return Table(table_name, **kwargs) + return Table(table_name, query_cls=cls, **kwargs) @classmethod - def Tables(cls, *names: Union[TypedTuple[str, str], str], **kwargs: Any) -> List[_TableClass]: + def Tables(cls, *names: Union[TypedTuple[str, str], str], **kwargs: Any) -> List["Table[QueryBuilderType]"]: """ Convenience method for creating many tables that uses this Query class. See ``Query.make_tables`` for details. @@ -539,8 +554,7 @@ def Tables(cls, *names: Union[TypedTuple[str, str], str], **kwargs: Any) -> List :returns Table """ - kwargs["query_cls"] = cls - return make_tables(*names, **kwargs) + return make_tables(*names, query_cls=cls, **kwargs) class _SetOperation(Selectable, Term, SQLPart): @@ -562,9 +576,7 @@ def __init__( ): super().__init__(alias) self.base_query = base_query - self._set_operation: List[TypedTuple[SetOperation, Union[QueryBuilder, Selectable]]] = [ - (set_operation, set_operation_query) - ] + self._set_operation: List[TypedTuple[SetOperation, QueryBuilder]] = [(set_operation, set_operation_query)] self._orderbys: List[TypedTuple[Union[Field, WrappedConstant, None], Optional[Order]]] = [] self._limit: Optional[int] = None @@ -599,29 +611,29 @@ def offset(self, offset: int): self._offset = offset @builder - def union(self, other: Selectable): + def union(self, other: "QueryBuilder"): self._set_operation.append((SetOperation.union, other)) @builder - def union_all(self, other: Selectable): + def union_all(self, other: "QueryBuilder"): self._set_operation.append((SetOperation.union_all, other)) @builder - def intersect(self, other: Selectable): + def intersect(self, other: "QueryBuilder"): self._set_operation.append((SetOperation.intersect, other)) @builder - def except_of(self, other: Selectable): + def except_of(self, other: "QueryBuilder"): self._set_operation.append((SetOperation.except_of, other)) @builder - def minus(self, other: Selectable): + def minus(self, other: "QueryBuilder"): self._set_operation.append((SetOperation.minus, other)) - def __add__(self, other: Selectable) -> "_SetOperation": # type: ignore + def __add__(self, other: "QueryBuilder") -> "_SetOperation": # type: ignore return self.union(other) - def __mul__(self, other: Selectable) -> "_SetOperation": # type: ignore + def __mul__(self, other: "QueryBuilder") -> "_SetOperation": # type: ignore return self.union_all(other) def __sub__(self, other: "QueryBuilder") -> "_SetOperation": # type: ignore @@ -671,7 +683,7 @@ def get_sql(self, with_alias: bool = False, subquery: bool = False, **kwargs: An querystring = "({query})".format(query=querystring, **kwargs) if with_alias: - return format_alias_sql(querystring, self.alias or self._table_name, **kwargs) + return format_alias_sql(querystring, self.alias or self.get_table_name(), **kwargs) return querystring @@ -750,7 +762,7 @@ def __init__( self._groupbys: List[Union[Term, WrappedConstant]] = [] self._with_totals = False self._havings: Optional[Union[Term, Criterion]] = None - self._orderbys: List[TypedTuple[Union[Field, WrappedConstant], Optional[Order]]] = [] + self._orderbys: List[TypedTuple[WrappedConstant, Optional[Order]]] = [] self._joins: List[Join] = [] self._unions: List[None] = [] self._using: List[Union[Selectable, str]] = [] @@ -776,7 +788,7 @@ def __init__( self.immutable = immutable - def __copy__(self) -> "QueryBuilder": + def __copy__(self) -> Self: newone = type(self).__new__(type(self)) newone.__dict__.update(self.__dict__) newone._select_star_tables = copy(self._select_star_tables) @@ -897,7 +909,7 @@ def select(self, *terms: Any): self._select_other(term) else: value = self.wrap_constant(term, wrapper_cls=self._wrapper_cls) - self._select_other(Term._assert_guard(value)) + self._select_other(value) @builder def delete(self): @@ -1048,7 +1060,7 @@ def rollup(self, *terms: Union[list, tuple, set, Term], **kwargs: Any): self._groupbys.append(Rollup(*wrapped_terms)) @builder - def orderby(self, *fields: Union[str, Field], order: Optional[Order] = None): + def orderby(self, *fields: WrappedConstantValue, order: Optional[Order] = None): table = self._from[0] if not isinstance(table, Selectable): raise TypeError("expect table is a Selectable, got {}".format(type(table).__name__)) @@ -1060,7 +1072,7 @@ def orderby(self, *fields: Union[str, Field], order: Optional[Order] = None): @builder def join( self, item: Union[Table, "QueryBuilder", AliasedQuery, _SetOperation], how: JoinType = JoinType.inner - ) -> "Joiner": + ) -> "Joiner[Self]": if isinstance(item, Table): return Joiner(self, item, how, type_label="table") @@ -1074,31 +1086,31 @@ def join( raise ValueError("Cannot join on type '%s'" % type(item)) - def inner_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner": + def inner_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner[Self]": return self.join(item, JoinType.inner) - def left_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner": + def left_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner[Self]": return self.join(item, JoinType.left) - def left_outer_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner": + def left_outer_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner[Self]": return self.join(item, JoinType.left_outer) - def right_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner": + def right_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner[Self]": return self.join(item, JoinType.right) - def right_outer_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner": + def right_outer_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner[Self]": return self.join(item, JoinType.right_outer) - def outer_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner": + def outer_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner[Self]": return self.join(item, JoinType.outer) - def full_outer_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner": + def full_outer_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner[Self]": return self.join(item, JoinType.full_outer) - def cross_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner": + def cross_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner[Self]": return self.join(item, JoinType.cross) - def hash_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner": + def hash_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner[Self]": return self.join(item, JoinType.hash) @builder @@ -1148,7 +1160,15 @@ def slice(self, slice: slice): self._offset = slice.start self._limit = slice.stop - def __getitem__(self, item: Any) -> Union["QueryBuilder", Field]: # type: ignore + @overload + def __getitem__(self, item: str) -> Field: + ... + + @overload + def __getitem__(self, item: builtins.slice) -> Self: + ... + + def __getitem__(self, item: Union[str, builtins.slice]) -> Union[Self, Field]: if not isinstance(item, slice): return super().__getitem__(item) return self.slice(item) @@ -1644,14 +1664,14 @@ def _set_sql(self, **kwargs: Any) -> str: JoinableTerm = Union[Table, "QueryBuilder", AliasedQuery, _SetOperation] -class Joiner: - def __init__(self, query: QueryBuilder, item: JoinableTerm, how: JoinType, type_label: str) -> None: +class Joiner(Generic[QueryBuilderType]): + def __init__(self, query: "QueryBuilderType", item: JoinableTerm, how: JoinType, type_label: str) -> None: self.query = query self.item = item self.how = how self.type_label = type_label - def on(self, criterion: Optional[Criterion], collate: Optional[str] = None) -> QueryBuilder: + def on(self, criterion: Optional[Criterion], collate: Optional[str] = None) -> "QueryBuilderType": if criterion is None: raise JoinException( "Parameter 'criterion' is required for a " @@ -1661,7 +1681,7 @@ def on(self, criterion: Optional[Criterion], collate: Optional[str] = None) -> Q self.query.do_join(JoinOn(self.item, self.how, criterion, collate)) return self.query - def on_field(self, *fields: Any) -> QueryBuilder: + def on_field(self, *fields: Any) -> "QueryBuilderType": if not fields: raise JoinException( "Parameter 'fields' is required for a " "{type} JOIN but was not supplied.".format(type=self.type_label) @@ -1675,14 +1695,14 @@ def on_field(self, *fields: Any) -> QueryBuilder: self.query.do_join(JoinOn(self.item, self.how, cast(Criterion, criterion))) return self.query - def using(self, *fields: Any) -> QueryBuilder: + def using(self, *fields: Any) -> "QueryBuilderType": if not fields: raise JoinException("Parameter 'fields' is required when joining with a using clause but was not supplied.") self.query.do_join(JoinUsing(self.item, self.how, [Field(field) for field in fields])) return self.query - def cross(self) -> QueryBuilder: + def cross(self) -> "QueryBuilderType": """Return cross join""" self.query.do_join(Join(self.item, JoinType.cross)) @@ -1971,8 +1991,8 @@ def foreign_key( columns: List[Union[str, Column]], reference_table: Union[str, Table], reference_columns: List[Union[str, Column]], - on_delete: ReferenceOption = None, - on_update: ReferenceOption = None, + on_delete: Optional[ReferenceOption] = None, + on_update: Optional[ReferenceOption] = None, ): """ Adds a foreign key constraint. @@ -2112,6 +2132,7 @@ def _primary_key_clause(self, **kwargs) -> str: ) def _foreign_key_clause(self, **kwargs) -> str: + assert self._foreign_key_reference_table is not None clause = "FOREIGN KEY ({columns}) REFERENCES {table_name} ({reference_columns})".format( columns=",".join(column.get_name_sql(**kwargs) for column in self._foreign_key), # type: ignore table_name=( From 00a857370ebbd8386aaf7b5236f5a9e775f1e943 Mon Sep 17 00:00:00 2001 From: fish Date: Sun, 26 Feb 2023 14:23:36 +0800 Subject: [PATCH 06/23] feat: add queries.py type hint Selectable: change Base Type: Node -> Term Table: add generic type Query: add generic type __copy__: return Self type remove Term._assert_guard Joiner: add generic type fix other type --- pypika/queries.py | 143 ++++++++++++++++++++++++++-------------------- 1 file changed, 82 insertions(+), 61 deletions(-) diff --git a/pypika/queries.py b/pypika/queries.py index 81d235c2..cd693f96 100644 --- a/pypika/queries.py +++ b/pypika/queries.py @@ -2,6 +2,7 @@ from functools import reduce from itertools import chain import operator +import builtins from typing import ( Any, Callable, @@ -16,7 +17,10 @@ Set, cast, TypeVar, + overload, + TYPE_CHECKING, ) +from typing_extensions import Self from pypika.enums import Dialects, JoinType, ReferenceOption, SetOperation, Order from pypika.terms import ( @@ -26,7 +30,6 @@ Field, Function, Index, - Node, Rollup, Star, Term, @@ -34,6 +37,7 @@ ValueWrapper, Criterion, PeriodCriterion, + WrappedConstantValue, WrappedConstant, ) from pypika.utils import ( @@ -53,9 +57,16 @@ _T = TypeVar("_T") +SchemaT = TypeVar("SchemaT", bound="Schema") +if TYPE_CHECKING: + from typing_extensions import TypeVar + QueryBuilderType = TypeVar("QueryBuilderType", bound="QueryBuilder", covariant=True, default="QueryBuilder") +else: + QueryBuilderType = TypeVar("QueryBuilderType", bound="QueryBuilder", covariant=True) -class Selectable(Node): + +class Selectable(Term): def __init__(self, alias: Optional[str]) -> None: self.alias = alias @@ -79,10 +90,13 @@ def __getitem__(self, name: str) -> Field: return self.field(name) def get_table_name(self) -> str: - if not self.alias: + if self.alias is None: raise TypeError("expect str, got None") return self.alias + def get_sql(self, **kwargs) -> str: + raise NotImplementedError + class AliasedQuery(Selectable, SQLPart): def __init__(self, name: str, query: Optional[Selectable] = None) -> None: @@ -136,7 +150,7 @@ def __getattr__(self, item: str) -> Schema: return Schema(item, parent=self) -class Table(Selectable): +class Table(Selectable, Generic[QueryBuilderType]): @staticmethod def _init_schema(schema: Union[str, list, tuple, Schema, None]) -> Optional[Schema]: # This is a bit complicated in order to support backwards compatibility. It should probably be cleaned up for @@ -152,14 +166,14 @@ def _init_schema(schema: Union[str, list, tuple, Schema, None]) -> Optional[Sche def __init__( self, name: str, - schema: Optional[Union[Schema, str]] = None, + schema: Union[str, list, tuple, Schema, None] = None, alias: Optional[str] = None, - query_cls: Optional[Type["Query"]] = None, + query_cls: Optional[Type["Query[QueryBuilderType]"]] = None, ) -> None: super().__init__(alias) self._table_name = name self._schema = self._init_schema(schema) - self._query_cls = query_cls or Query + self._query_cls: Type["Query[QueryBuilderType]"] = query_cls or Query self._for: Optional[Criterion] = None self._for_portion: Optional[PeriodCriterion] = None if not issubclass(self._query_cls, Query): @@ -230,7 +244,7 @@ def __ne__(self, other: Any) -> bool: def __hash__(self) -> int: return hash(str(self)) - def select(self, *terms: Sequence[Union[int, float, str, bool, Term, Field]]) -> "QueryBuilder": + def select(self, *terms: Union[int, float, str, bool, Term, Field]) -> "QueryBuilderType": """ Perform a SELECT operation on the current table @@ -243,7 +257,7 @@ def select(self, *terms: Sequence[Union[int, float, str, bool, Term, Field]]) -> """ return self._query_cls.from_(self).select(*terms) - def update(self) -> "QueryBuilder": + def update(self) -> "QueryBuilderType": """ Perform an UPDATE operation on the current table @@ -251,7 +265,7 @@ def update(self) -> "QueryBuilder": """ return self._query_cls.update(self) - def insert(self, *terms: Union[int, float, str, bool, Term, Field]) -> "QueryBuilder": + def insert(self, *terms: Union[int, float, str, bool, Term, Field]) -> "QueryBuilderType": """ Perform an INSERT operation on the current table @@ -265,13 +279,15 @@ def insert(self, *terms: Union[int, float, str, bool, Term, Field]) -> "QueryBui return self._query_cls.into(self).insert(*terms) -def make_tables(*names: Union[TypedTuple[str, str], str], **kwargs: Any) -> List[Table]: +def make_tables( + *names: Union[TypedTuple[str, str], str], query_cls: "Optional[Type[Query[QueryBuilderType]]]" = None, **kwargs: Any +) -> List[Table[QueryBuilderType]]: """ Shortcut to create many tables. If `names` param is a tuple, the first position will refer to the `_table_name` while the second will be its `alias`. Any other data structure will be treated as a whole as the `_table_name`. """ - tables = [] + tables: List["Table[QueryBuilderType]"] = [] for name in names: if isinstance(name, tuple): if len(name) == 2: @@ -279,7 +295,7 @@ def make_tables(*names: Union[TypedTuple[str, str], str], **kwargs: Any) -> List name=name[0], alias=name[1], schema=kwargs.get("schema"), - query_cls=kwargs.get("query_cls"), + query_cls=query_cls, ) else: raise TypeError("expect tuple[str, str] or str, got a tuple with {} element(s)".format(len(name))) @@ -301,7 +317,7 @@ def __init__( column_name: str, column_type: Optional[str] = None, nullable: Optional[bool] = None, - default: Optional[Union[Any, Term]] = None, + default: object = None, ) -> None: self.name = column_name self.type = column_type @@ -373,7 +389,7 @@ def get_sql(self, **kwargs: Any) -> str: _TableClass = Table -class Query: +class Query(Generic[QueryBuilderType]): """ Query is the primary class and entry point in pypika. It is used to build queries iteratively using the builder design @@ -383,11 +399,11 @@ class Query: """ @classmethod - def _builder(cls, **kwargs: Any) -> "QueryBuilder": + def _builder(cls, **kwargs: Any) -> "QueryBuilderType": return QueryBuilder(**kwargs) @classmethod - def from_(cls, table: Union[Selectable, str], **kwargs: Any) -> "QueryBuilder": + def from_(cls, table: Union[Selectable, str], **kwargs: Any) -> "QueryBuilderType": """ Query builder entry point. Initializes query building and sets the table to select from. When using this function, the query becomes a SELECT query. @@ -462,7 +478,7 @@ def drop_view(cls, view: str) -> "DropQueryBuilder": return DropQueryBuilder().drop_view(view) @classmethod - def into(cls, table: Union[Table, str], **kwargs: Any) -> "QueryBuilder": + def into(cls, table: Union[Table, str], **kwargs: Any) -> "QueryBuilderType": """ Query builder entry point. Initializes query building and sets the table to insert into. When using this function, the query becomes an INSERT query. @@ -477,11 +493,11 @@ def into(cls, table: Union[Table, str], **kwargs: Any) -> "QueryBuilder": return cls._builder(**kwargs).into(table) @classmethod - def with_(cls, table: Union[str, Selectable], name: str, **kwargs: Any) -> "QueryBuilder": + def with_(cls, table: Selectable, name: str, **kwargs: Any) -> "QueryBuilderType": return cls._builder(**kwargs).with_(table, name) @classmethod - def select(cls, *terms: Union[int, float, str, bool, Term], **kwargs: Any) -> "QueryBuilder": + def select(cls, *terms: Union[int, float, str, bool, Term], **kwargs: Any) -> "QueryBuilderType": """ Query builder entry point. Initializes query building without a table and selects fields. Useful when testing SQL functions. @@ -497,7 +513,7 @@ def select(cls, *terms: Union[int, float, str, bool, Term], **kwargs: Any) -> "Q return cls._builder(**kwargs).select(*terms) @classmethod - def update(cls, table: Union[str, Table], **kwargs) -> "QueryBuilder": + def update(cls, table: Union[str, Table], **kwargs) -> "QueryBuilderType": """ Query builder entry point. Initializes query building and sets the table to update. When using this function, the query becomes an UPDATE query. @@ -512,7 +528,7 @@ def update(cls, table: Union[str, Table], **kwargs) -> "QueryBuilder": return cls._builder(**kwargs).update(table) @classmethod - def Table(cls, table_name: str, **kwargs) -> _TableClass: + def Table(cls, table_name: str, **kwargs) -> Table[QueryBuilderType]: """ Convenience method for creating a Table that uses this Query class. @@ -523,11 +539,10 @@ def Table(cls, table_name: str, **kwargs) -> _TableClass: :returns Table """ - kwargs["query_cls"] = cls - return Table(table_name, **kwargs) + return Table(table_name, query_cls=cls, **kwargs) @classmethod - def Tables(cls, *names: Union[TypedTuple[str, str], str], **kwargs: Any) -> List[_TableClass]: + def Tables(cls, *names: Union[TypedTuple[str, str], str], **kwargs: Any) -> List["Table[QueryBuilderType]"]: """ Convenience method for creating many tables that uses this Query class. See ``Query.make_tables`` for details. @@ -539,8 +554,7 @@ def Tables(cls, *names: Union[TypedTuple[str, str], str], **kwargs: Any) -> List :returns Table """ - kwargs["query_cls"] = cls - return make_tables(*names, **kwargs) + return make_tables(*names, query_cls=cls, **kwargs) class _SetOperation(Selectable, Term, SQLPart): @@ -562,9 +576,7 @@ def __init__( ): super().__init__(alias) self.base_query = base_query - self._set_operation: List[TypedTuple[SetOperation, Union[QueryBuilder, Selectable]]] = [ - (set_operation, set_operation_query) - ] + self._set_operation: List[TypedTuple[SetOperation, QueryBuilder]] = [(set_operation, set_operation_query)] self._orderbys: List[TypedTuple[Union[Field, WrappedConstant, None], Optional[Order]]] = [] self._limit: Optional[int] = None @@ -599,29 +611,29 @@ def offset(self, offset: int): self._offset = offset @builder - def union(self, other: Selectable): + def union(self, other: "QueryBuilder"): self._set_operation.append((SetOperation.union, other)) @builder - def union_all(self, other: Selectable): + def union_all(self, other: "QueryBuilder"): self._set_operation.append((SetOperation.union_all, other)) @builder - def intersect(self, other: Selectable): + def intersect(self, other: "QueryBuilder"): self._set_operation.append((SetOperation.intersect, other)) @builder - def except_of(self, other: Selectable): + def except_of(self, other: "QueryBuilder"): self._set_operation.append((SetOperation.except_of, other)) @builder - def minus(self, other: Selectable): + def minus(self, other: "QueryBuilder"): self._set_operation.append((SetOperation.minus, other)) - def __add__(self, other: Selectable) -> "_SetOperation": # type: ignore + def __add__(self, other: "QueryBuilder") -> "_SetOperation": # type: ignore return self.union(other) - def __mul__(self, other: Selectable) -> "_SetOperation": # type: ignore + def __mul__(self, other: "QueryBuilder") -> "_SetOperation": # type: ignore return self.union_all(other) def __sub__(self, other: "QueryBuilder") -> "_SetOperation": # type: ignore @@ -671,7 +683,7 @@ def get_sql(self, with_alias: bool = False, subquery: bool = False, **kwargs: An querystring = "({query})".format(query=querystring, **kwargs) if with_alias: - return format_alias_sql(querystring, self.alias or self._table_name, **kwargs) + return format_alias_sql(querystring, self.alias or self.get_table_name(), **kwargs) return querystring @@ -750,7 +762,7 @@ def __init__( self._groupbys: List[Union[Term, WrappedConstant]] = [] self._with_totals = False self._havings: Optional[Union[Term, Criterion]] = None - self._orderbys: List[TypedTuple[Union[Field, WrappedConstant], Optional[Order]]] = [] + self._orderbys: List[TypedTuple[WrappedConstant, Optional[Order]]] = [] self._joins: List[Join] = [] self._unions: List[None] = [] self._using: List[Union[Selectable, str]] = [] @@ -776,7 +788,7 @@ def __init__( self.immutable = immutable - def __copy__(self) -> "QueryBuilder": + def __copy__(self) -> Self: newone = type(self).__new__(type(self)) newone.__dict__.update(self.__dict__) newone._select_star_tables = copy(self._select_star_tables) @@ -897,7 +909,7 @@ def select(self, *terms: Any): self._select_other(term) else: value = self.wrap_constant(term, wrapper_cls=self._wrapper_cls) - self._select_other(Term._assert_guard(value)) + self._select_other(value) @builder def delete(self): @@ -1048,7 +1060,7 @@ def rollup(self, *terms: Union[list, tuple, set, Term], **kwargs: Any): self._groupbys.append(Rollup(*wrapped_terms)) @builder - def orderby(self, *fields: Union[str, Field], order: Optional[Order] = None): + def orderby(self, *fields: WrappedConstantValue, order: Optional[Order] = None): table = self._from[0] if not isinstance(table, Selectable): raise TypeError("expect table is a Selectable, got {}".format(type(table).__name__)) @@ -1060,7 +1072,7 @@ def orderby(self, *fields: Union[str, Field], order: Optional[Order] = None): @builder def join( self, item: Union[Table, "QueryBuilder", AliasedQuery, _SetOperation], how: JoinType = JoinType.inner - ) -> "Joiner": + ) -> "Joiner[Self]": if isinstance(item, Table): return Joiner(self, item, how, type_label="table") @@ -1074,31 +1086,31 @@ def join( raise ValueError("Cannot join on type '%s'" % type(item)) - def inner_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner": + def inner_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner[Self]": return self.join(item, JoinType.inner) - def left_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner": + def left_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner[Self]": return self.join(item, JoinType.left) - def left_outer_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner": + def left_outer_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner[Self]": return self.join(item, JoinType.left_outer) - def right_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner": + def right_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner[Self]": return self.join(item, JoinType.right) - def right_outer_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner": + def right_outer_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner[Self]": return self.join(item, JoinType.right_outer) - def outer_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner": + def outer_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner[Self]": return self.join(item, JoinType.outer) - def full_outer_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner": + def full_outer_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner[Self]": return self.join(item, JoinType.full_outer) - def cross_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner": + def cross_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner[Self]": return self.join(item, JoinType.cross) - def hash_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner": + def hash_join(self, item: Union[Table, "QueryBuilder", AliasedQuery]) -> "Joiner[Self]": return self.join(item, JoinType.hash) @builder @@ -1148,7 +1160,15 @@ def slice(self, slice: slice): self._offset = slice.start self._limit = slice.stop - def __getitem__(self, item: Any) -> Union["QueryBuilder", Field]: # type: ignore + @overload + def __getitem__(self, item: str) -> Field: + ... + + @overload + def __getitem__(self, item: builtins.slice) -> Self: + ... + + def __getitem__(self, item: Union[str, builtins.slice]) -> Union[Self, Field]: if not isinstance(item, slice): return super().__getitem__(item) return self.slice(item) @@ -1644,14 +1664,14 @@ def _set_sql(self, **kwargs: Any) -> str: JoinableTerm = Union[Table, "QueryBuilder", AliasedQuery, _SetOperation] -class Joiner: - def __init__(self, query: QueryBuilder, item: JoinableTerm, how: JoinType, type_label: str) -> None: +class Joiner(Generic[QueryBuilderType]): + def __init__(self, query: "QueryBuilderType", item: JoinableTerm, how: JoinType, type_label: str) -> None: self.query = query self.item = item self.how = how self.type_label = type_label - def on(self, criterion: Optional[Criterion], collate: Optional[str] = None) -> QueryBuilder: + def on(self, criterion: Optional[Criterion], collate: Optional[str] = None) -> "QueryBuilderType": if criterion is None: raise JoinException( "Parameter 'criterion' is required for a " @@ -1661,7 +1681,7 @@ def on(self, criterion: Optional[Criterion], collate: Optional[str] = None) -> Q self.query.do_join(JoinOn(self.item, self.how, criterion, collate)) return self.query - def on_field(self, *fields: Any) -> QueryBuilder: + def on_field(self, *fields: Any) -> "QueryBuilderType": if not fields: raise JoinException( "Parameter 'fields' is required for a " "{type} JOIN but was not supplied.".format(type=self.type_label) @@ -1675,14 +1695,14 @@ def on_field(self, *fields: Any) -> QueryBuilder: self.query.do_join(JoinOn(self.item, self.how, cast(Criterion, criterion))) return self.query - def using(self, *fields: Any) -> QueryBuilder: + def using(self, *fields: Any) -> "QueryBuilderType": if not fields: raise JoinException("Parameter 'fields' is required when joining with a using clause but was not supplied.") self.query.do_join(JoinUsing(self.item, self.how, [Field(field) for field in fields])) return self.query - def cross(self) -> QueryBuilder: + def cross(self) -> "QueryBuilderType": """Return cross join""" self.query.do_join(Join(self.item, JoinType.cross)) @@ -1971,8 +1991,8 @@ def foreign_key( columns: List[Union[str, Column]], reference_table: Union[str, Table], reference_columns: List[Union[str, Column]], - on_delete: ReferenceOption = None, - on_update: ReferenceOption = None, + on_delete: Optional[ReferenceOption] = None, + on_update: Optional[ReferenceOption] = None, ): """ Adds a foreign key constraint. @@ -2112,6 +2132,7 @@ def _primary_key_clause(self, **kwargs) -> str: ) def _foreign_key_clause(self, **kwargs) -> str: + assert self._foreign_key_reference_table is not None clause = "FOREIGN KEY ({columns}) REFERENCES {table_name} ({reference_columns})".format( columns=",".join(column.get_name_sql(**kwargs) for column in self._foreign_key), # type: ignore table_name=( From c97a96659d430a372a3c34c878274999a4d3ce75 Mon Sep 17 00:00:00 2001 From: fish Date: Sun, 26 Feb 2023 14:30:56 +0800 Subject: [PATCH 07/23] feat: add dialects.py type hint add Query generic __copy__: return Self type change classmethod first argument: self -> cls fix bug: in queries.make_tables --- pypika/dialects.py | 48 ++++++++++++++++++++++++++++------------------ pypika/queries.py | 2 +- 2 files changed, 30 insertions(+), 20 deletions(-) diff --git a/pypika/dialects.py b/pypika/dialects.py index ab3f1f20..387ee5e6 100644 --- a/pypika/dialects.py +++ b/pypika/dialects.py @@ -1,6 +1,7 @@ import itertools from copy import copy -from typing import Any, Iterable, List, NoReturn, Optional, Set, Union, Tuple as TypedTuple, cast +from typing import Any, Iterable, List, Optional, Set, Union, Tuple as TypedTuple, cast +from typing_extensions import Self, NoReturn from pypika.enums import Dialects from pypika.queries import ( @@ -13,11 +14,20 @@ QueryBuilder, JoinOn, ) -from pypika.terms import ArithmeticExpression, Criterion, EmptyCriterion, Field, Function, Star, Term, ValueWrapper +from pypika.terms import ( + ArithmeticExpression, + Criterion, + EmptyCriterion, + Field, + Function, + Star, + Term, + ValueWrapper, +) from pypika.utils import QueryException, builder, format_quotes -class SnowflakeQuery(Query): +class SnowflakeQuery(Query["SnowflakeQueryBuilder"]): """ Defines a query class for use with Snowflake. """ @@ -61,7 +71,7 @@ def __init__(self) -> None: super().__init__(dialect=Dialects.SNOWFLAKE) -class MySQLQuery(Query): +class MySQLQuery(Query["MySQLQueryBuilder"]): """ Defines a query class for use with MySQL. """ @@ -97,7 +107,7 @@ def __init__(self, **kwargs: Any) -> None: self._for_update_skip_locked = False self._for_update_of: Set[str] = set() - def __copy__(self) -> "MySQLQueryBuilder": + def __copy__(self) -> Self: newone = cast(MySQLQueryBuilder, super().__copy__()) newone._duplicate_updates = copy(self._duplicate_updates) newone._ignore_duplicates = copy(self._ignore_duplicates) @@ -228,7 +238,7 @@ class MySQLDropQueryBuilder(DropQueryBuilder): QUOTE_CHAR = "`" -class VerticaQuery(Query): +class VerticaQuery(Query["VerticaQueryBuilder"]): """ Defines a query class for use with Vertica. """ @@ -350,7 +360,7 @@ def __str__(self) -> str: return self.get_sql() -class OracleQuery(Query): +class OracleQuery(Query["OracleQueryBuilder"]): """ Defines a query class for use with Oracle. """ @@ -374,7 +384,7 @@ def get_sql(self, *args: Any, **kwargs: Any) -> str: return super().get_sql(*args, **kwargs) -class PostgreSQLQuery(Query): +class PostgreSQLQuery(Query["PostgreSQLQueryBuilder"]): """ Defines a query class for use with PostgreSQL. """ @@ -406,7 +416,7 @@ def __init__(self, **kwargs: Any) -> None: self._for_update_skip_locked = False self._for_update_of: Set[str] = set() - def __copy__(self) -> "PostgreSQLQueryBuilder": + def __copy__(self) -> Self: newone = cast(PostgreSQLQueryBuilder, super().__copy__()) newone._returns = copy(self._returns) newone._on_conflict_do_updates = copy(self._on_conflict_do_updates) @@ -428,7 +438,7 @@ def for_update(self, nowait: bool = False, skip_locked: bool = False, of: TypedT self._for_update_of = set(of) @builder - def on_conflict(self, *target_fields: Union[str, Term]) -> None: + def on_conflict(self, *target_fields: Union[str, Term, None]) -> None: if not self._insert_table: raise QueryException("On conflict only applies to insert query") @@ -655,7 +665,7 @@ def get_sql(self, with_alias: bool = False, subquery: bool = False, **kwargs: An return querystring -class RedshiftQuery(Query): +class RedshiftQuery(Query["RedShiftQueryBuilder"]): """ Defines a query class for use with Amazon Redshift. """ @@ -669,7 +679,7 @@ class RedShiftQueryBuilder(QueryBuilder): QUERY_CLS = RedshiftQuery -class MSSQLQuery(Query): +class MSSQLQuery(Query["MSSQLQueryBuilder"]): """ Defines a query class for use with Microsoft SQL Server. """ @@ -751,7 +761,7 @@ def _select_sql(self, **kwargs: Any) -> str: ) -class ClickHouseQuery(Query): +class ClickHouseQuery(Query["ClickHouseQueryBuilder"]): """ Defines a query class for use with Yandex ClickHouse. """ @@ -767,23 +777,23 @@ def drop_database(cls, database: Union[Database, str]) -> "ClickHouseDropQueryBu return ClickHouseDropQueryBuilder().drop_database(database) @classmethod - def drop_table(self, table: Union[Table, str]) -> "ClickHouseDropQueryBuilder": + def drop_table(cls, table: Union[Table, str]) -> "ClickHouseDropQueryBuilder": return ClickHouseDropQueryBuilder().drop_table(table) @classmethod - def drop_dictionary(self, dictionary: str) -> "ClickHouseDropQueryBuilder": + def drop_dictionary(cls, dictionary: str) -> "ClickHouseDropQueryBuilder": return ClickHouseDropQueryBuilder().drop_dictionary(dictionary) @classmethod - def drop_quota(self, quota: str) -> "ClickHouseDropQueryBuilder": + def drop_quota(cls, quota: str) -> "ClickHouseDropQueryBuilder": return ClickHouseDropQueryBuilder().drop_quota(quota) @classmethod - def drop_user(self, user: str) -> "ClickHouseDropQueryBuilder": + def drop_user(cls, user: str) -> "ClickHouseDropQueryBuilder": return ClickHouseDropQueryBuilder().drop_user(user) @classmethod - def drop_view(self, view: str) -> "ClickHouseDropQueryBuilder": + def drop_view(cls, view: str) -> "ClickHouseDropQueryBuilder": return ClickHouseDropQueryBuilder().drop_view(view) @@ -858,7 +868,7 @@ def get_value_sql(self, **kwargs: Any) -> str: return super().get_value_sql(**kwargs) -class SQLLiteQuery(Query): +class SQLLiteQuery(Query["SQLLiteQueryBuilder"]): """ Defines a query class for use with Microsoft SQL Server. """ diff --git a/pypika/queries.py b/pypika/queries.py index cd693f96..28230a2f 100644 --- a/pypika/queries.py +++ b/pypika/queries.py @@ -303,7 +303,7 @@ def make_tables( t = Table( name=name, schema=kwargs.get("schema"), - query_cls=kwargs.get("query_cls"), + query_cls=query_cls, ) tables.append(t) return tables From 83c461293c1dfdb515d2ca18f684f97a42297f46 Mon Sep 17 00:00:00 2001 From: fish Date: Sun, 26 Feb 2023 14:30:56 +0800 Subject: [PATCH 08/23] feat: add dialects.py type hint add Query generic __copy__: return Self type change classmethod first argument: self -> cls fix bug: in queries.make_tables --- pypika/dialects.py | 48 ++++++++++++++++++++++++++++------------------ pypika/queries.py | 2 +- 2 files changed, 30 insertions(+), 20 deletions(-) diff --git a/pypika/dialects.py b/pypika/dialects.py index ab3f1f20..387ee5e6 100644 --- a/pypika/dialects.py +++ b/pypika/dialects.py @@ -1,6 +1,7 @@ import itertools from copy import copy -from typing import Any, Iterable, List, NoReturn, Optional, Set, Union, Tuple as TypedTuple, cast +from typing import Any, Iterable, List, Optional, Set, Union, Tuple as TypedTuple, cast +from typing_extensions import Self, NoReturn from pypika.enums import Dialects from pypika.queries import ( @@ -13,11 +14,20 @@ QueryBuilder, JoinOn, ) -from pypika.terms import ArithmeticExpression, Criterion, EmptyCriterion, Field, Function, Star, Term, ValueWrapper +from pypika.terms import ( + ArithmeticExpression, + Criterion, + EmptyCriterion, + Field, + Function, + Star, + Term, + ValueWrapper, +) from pypika.utils import QueryException, builder, format_quotes -class SnowflakeQuery(Query): +class SnowflakeQuery(Query["SnowflakeQueryBuilder"]): """ Defines a query class for use with Snowflake. """ @@ -61,7 +71,7 @@ def __init__(self) -> None: super().__init__(dialect=Dialects.SNOWFLAKE) -class MySQLQuery(Query): +class MySQLQuery(Query["MySQLQueryBuilder"]): """ Defines a query class for use with MySQL. """ @@ -97,7 +107,7 @@ def __init__(self, **kwargs: Any) -> None: self._for_update_skip_locked = False self._for_update_of: Set[str] = set() - def __copy__(self) -> "MySQLQueryBuilder": + def __copy__(self) -> Self: newone = cast(MySQLQueryBuilder, super().__copy__()) newone._duplicate_updates = copy(self._duplicate_updates) newone._ignore_duplicates = copy(self._ignore_duplicates) @@ -228,7 +238,7 @@ class MySQLDropQueryBuilder(DropQueryBuilder): QUOTE_CHAR = "`" -class VerticaQuery(Query): +class VerticaQuery(Query["VerticaQueryBuilder"]): """ Defines a query class for use with Vertica. """ @@ -350,7 +360,7 @@ def __str__(self) -> str: return self.get_sql() -class OracleQuery(Query): +class OracleQuery(Query["OracleQueryBuilder"]): """ Defines a query class for use with Oracle. """ @@ -374,7 +384,7 @@ def get_sql(self, *args: Any, **kwargs: Any) -> str: return super().get_sql(*args, **kwargs) -class PostgreSQLQuery(Query): +class PostgreSQLQuery(Query["PostgreSQLQueryBuilder"]): """ Defines a query class for use with PostgreSQL. """ @@ -406,7 +416,7 @@ def __init__(self, **kwargs: Any) -> None: self._for_update_skip_locked = False self._for_update_of: Set[str] = set() - def __copy__(self) -> "PostgreSQLQueryBuilder": + def __copy__(self) -> Self: newone = cast(PostgreSQLQueryBuilder, super().__copy__()) newone._returns = copy(self._returns) newone._on_conflict_do_updates = copy(self._on_conflict_do_updates) @@ -428,7 +438,7 @@ def for_update(self, nowait: bool = False, skip_locked: bool = False, of: TypedT self._for_update_of = set(of) @builder - def on_conflict(self, *target_fields: Union[str, Term]) -> None: + def on_conflict(self, *target_fields: Union[str, Term, None]) -> None: if not self._insert_table: raise QueryException("On conflict only applies to insert query") @@ -655,7 +665,7 @@ def get_sql(self, with_alias: bool = False, subquery: bool = False, **kwargs: An return querystring -class RedshiftQuery(Query): +class RedshiftQuery(Query["RedShiftQueryBuilder"]): """ Defines a query class for use with Amazon Redshift. """ @@ -669,7 +679,7 @@ class RedShiftQueryBuilder(QueryBuilder): QUERY_CLS = RedshiftQuery -class MSSQLQuery(Query): +class MSSQLQuery(Query["MSSQLQueryBuilder"]): """ Defines a query class for use with Microsoft SQL Server. """ @@ -751,7 +761,7 @@ def _select_sql(self, **kwargs: Any) -> str: ) -class ClickHouseQuery(Query): +class ClickHouseQuery(Query["ClickHouseQueryBuilder"]): """ Defines a query class for use with Yandex ClickHouse. """ @@ -767,23 +777,23 @@ def drop_database(cls, database: Union[Database, str]) -> "ClickHouseDropQueryBu return ClickHouseDropQueryBuilder().drop_database(database) @classmethod - def drop_table(self, table: Union[Table, str]) -> "ClickHouseDropQueryBuilder": + def drop_table(cls, table: Union[Table, str]) -> "ClickHouseDropQueryBuilder": return ClickHouseDropQueryBuilder().drop_table(table) @classmethod - def drop_dictionary(self, dictionary: str) -> "ClickHouseDropQueryBuilder": + def drop_dictionary(cls, dictionary: str) -> "ClickHouseDropQueryBuilder": return ClickHouseDropQueryBuilder().drop_dictionary(dictionary) @classmethod - def drop_quota(self, quota: str) -> "ClickHouseDropQueryBuilder": + def drop_quota(cls, quota: str) -> "ClickHouseDropQueryBuilder": return ClickHouseDropQueryBuilder().drop_quota(quota) @classmethod - def drop_user(self, user: str) -> "ClickHouseDropQueryBuilder": + def drop_user(cls, user: str) -> "ClickHouseDropQueryBuilder": return ClickHouseDropQueryBuilder().drop_user(user) @classmethod - def drop_view(self, view: str) -> "ClickHouseDropQueryBuilder": + def drop_view(cls, view: str) -> "ClickHouseDropQueryBuilder": return ClickHouseDropQueryBuilder().drop_view(view) @@ -858,7 +868,7 @@ def get_value_sql(self, **kwargs: Any) -> str: return super().get_value_sql(**kwargs) -class SQLLiteQuery(Query): +class SQLLiteQuery(Query["SQLLiteQueryBuilder"]): """ Defines a query class for use with Microsoft SQL Server. """ diff --git a/pypika/queries.py b/pypika/queries.py index cd693f96..28230a2f 100644 --- a/pypika/queries.py +++ b/pypika/queries.py @@ -303,7 +303,7 @@ def make_tables( t = Table( name=name, schema=kwargs.get("schema"), - query_cls=kwargs.get("query_cls"), + query_cls=query_cls, ) tables.append(t) return tables From 975417d3cee8c7183cd952b9ad5d80762e0ed4c4 Mon Sep 17 00:00:00 2001 From: fish Date: Sun, 26 Feb 2023 14:37:56 +0800 Subject: [PATCH 09/23] feat: fix array.py --- pypika/clickhouse/array.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/pypika/clickhouse/array.py b/pypika/clickhouse/array.py index 40de67c5..20c422f6 100644 --- a/pypika/clickhouse/array.py +++ b/pypika/clickhouse/array.py @@ -1,5 +1,5 @@ import abc -from typing import Union +from typing import Union, Optional, TYPE_CHECKING from pypika.terms import ( Field, @@ -8,9 +8,14 @@ ) from pypika.utils import format_alias_sql +if TYPE_CHECKING: + from pypika.queries import Schema + class Array(Term): - def __init__(self, values: list, converter_cls=None, converter_options: dict = None, alias: str = None): + def __init__( + self, values: list, converter_cls=None, converter_options: Optional[dict] = None, alias: Optional[str] = None + ): super().__init__(alias) self._values = values self._converter_cls = converter_cls @@ -35,14 +40,14 @@ def __init__( self, left_array: Union[Array, Field], right_array: Union[Array, Field], - alias: str = None, - schema: str = None, + alias: Optional[str] = None, + schema: Optional["Schema"] = None, ): self._left_array = left_array self._right_array = right_array self.alias = alias self.schema = schema - self.args = tuple() + self.args = [] self.name = "hasAny" def get_sql(self, with_alias=False, with_namespace=False, quote_char=None, dialect=None, **kwargs): @@ -57,7 +62,7 @@ def get_sql(self, with_alias=False, with_namespace=False, quote_char=None, diale class _AbstractArrayFunction(Function, metaclass=abc.ABCMeta): - def __init__(self, array: Union[Array, Field], alias: str = None, schema: str = None): + def __init__(self, array: Union[Array, Field], alias: Optional[str] = None, schema: Optional["Schema"] = None): self.schema = schema self.alias = alias self.name = self.clickhouse_function() From 105b13db2959dbf5d00ed1641c21d4d1ee8a0d05 Mon Sep 17 00:00:00 2001 From: fish Date: Sun, 26 Feb 2023 14:37:56 +0800 Subject: [PATCH 10/23] feat: fix array.py --- pypika/clickhouse/array.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/pypika/clickhouse/array.py b/pypika/clickhouse/array.py index 40de67c5..20c422f6 100644 --- a/pypika/clickhouse/array.py +++ b/pypika/clickhouse/array.py @@ -1,5 +1,5 @@ import abc -from typing import Union +from typing import Union, Optional, TYPE_CHECKING from pypika.terms import ( Field, @@ -8,9 +8,14 @@ ) from pypika.utils import format_alias_sql +if TYPE_CHECKING: + from pypika.queries import Schema + class Array(Term): - def __init__(self, values: list, converter_cls=None, converter_options: dict = None, alias: str = None): + def __init__( + self, values: list, converter_cls=None, converter_options: Optional[dict] = None, alias: Optional[str] = None + ): super().__init__(alias) self._values = values self._converter_cls = converter_cls @@ -35,14 +40,14 @@ def __init__( self, left_array: Union[Array, Field], right_array: Union[Array, Field], - alias: str = None, - schema: str = None, + alias: Optional[str] = None, + schema: Optional["Schema"] = None, ): self._left_array = left_array self._right_array = right_array self.alias = alias self.schema = schema - self.args = tuple() + self.args = [] self.name = "hasAny" def get_sql(self, with_alias=False, with_namespace=False, quote_char=None, dialect=None, **kwargs): @@ -57,7 +62,7 @@ def get_sql(self, with_alias=False, with_namespace=False, quote_char=None, diale class _AbstractArrayFunction(Function, metaclass=abc.ABCMeta): - def __init__(self, array: Union[Array, Field], alias: str = None, schema: str = None): + def __init__(self, array: Union[Array, Field], alias: Optional[str] = None, schema: Optional["Schema"] = None): self.schema = schema self.alias = alias self.name = self.clickhouse_function() From 33ca0b6398c447c66fe36adcc05ec61e59763b9e Mon Sep 17 00:00:00 2001 From: fish Date: Sun, 26 Feb 2023 14:41:08 +0800 Subject: [PATCH 11/23] feat: fix search_string.py --- pypika/clickhouse/search_string.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pypika/clickhouse/search_string.py b/pypika/clickhouse/search_string.py index 22a03027..5445f1b7 100644 --- a/pypika/clickhouse/search_string.py +++ b/pypika/clickhouse/search_string.py @@ -1,11 +1,13 @@ import abc +from typing import Optional + from pypika.terms import Function from pypika.utils import format_alias_sql class _AbstractSearchString(Function, metaclass=abc.ABCMeta): - def __init__(self, name, pattern: str, alias: str = None): + def __init__(self, name, pattern: str, alias: Optional[str] = None): super(_AbstractSearchString, self).__init__(self.clickhouse_function(), name, alias=alias) self._pattern = pattern @@ -50,7 +52,7 @@ def clickhouse_function(cls) -> str: class _AbstractMultiSearchString(Function, metaclass=abc.ABCMeta): - def __init__(self, name, patterns: list, alias: str = None): + def __init__(self, name, patterns: list, alias: Optional[str] = None): super(_AbstractMultiSearchString, self).__init__(self.clickhouse_function(), name, alias=alias) self._patterns = patterns From 1e5371d953d643100bc0f843636f3a45aa0e7e81 Mon Sep 17 00:00:00 2001 From: fish Date: Sun, 26 Feb 2023 14:41:08 +0800 Subject: [PATCH 12/23] feat: fix search_string.py --- pypika/clickhouse/search_string.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pypika/clickhouse/search_string.py b/pypika/clickhouse/search_string.py index 22a03027..5445f1b7 100644 --- a/pypika/clickhouse/search_string.py +++ b/pypika/clickhouse/search_string.py @@ -1,11 +1,13 @@ import abc +from typing import Optional + from pypika.terms import Function from pypika.utils import format_alias_sql class _AbstractSearchString(Function, metaclass=abc.ABCMeta): - def __init__(self, name, pattern: str, alias: str = None): + def __init__(self, name, pattern: str, alias: Optional[str] = None): super(_AbstractSearchString, self).__init__(self.clickhouse_function(), name, alias=alias) self._pattern = pattern @@ -50,7 +52,7 @@ def clickhouse_function(cls) -> str: class _AbstractMultiSearchString(Function, metaclass=abc.ABCMeta): - def __init__(self, name, patterns: list, alias: str = None): + def __init__(self, name, patterns: list, alias: Optional[str] = None): super(_AbstractMultiSearchString, self).__init__(self.clickhouse_function(), name, alias=alias) self._patterns = patterns From 50d7817101d98269ca67f2abf436dc97b09d4a03 Mon Sep 17 00:00:00 2001 From: fish Date: Sun, 26 Feb 2023 14:41:42 +0800 Subject: [PATCH 13/23] feat: fix type_conversion.py --- pypika/clickhouse/type_conversion.py | 33 +++++++++++++++------------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/pypika/clickhouse/type_conversion.py b/pypika/clickhouse/type_conversion.py index 80229b7e..a42ea6c6 100644 --- a/pypika/clickhouse/type_conversion.py +++ b/pypika/clickhouse/type_conversion.py @@ -2,22 +2,25 @@ Field, Function, ) +from pypika.queries import Schema from pypika.utils import format_alias_sql +from typing import Optional + class ToString(Function): - def __init__(self, name, alias: str = None): + def __init__(self, name, alias: Optional[str] = None): super(ToString, self).__init__("toString", name, alias=alias) class ToFixedString(Function): - def __init__(self, field, length: int, alias: str = None, schema: str = None): + def __init__(self, field, length: int, alias: Optional[str] = None, schema: Optional[Schema] = None): self._length = length self._field = field self.alias = alias self.name = "toFixedString" self.schema = schema - self.args = () + self.args = [] def get_sql(self, with_alias=False, with_namespace=False, quote_char=None, dialect=None, **kwargs): sql = "{name}({field},{length})".format( @@ -29,60 +32,60 @@ def get_sql(self, with_alias=False, with_namespace=False, quote_char=None, diale class ToInt8(Function): - def __init__(self, name, alias: str = None): + def __init__(self, name, alias: Optional[str] = None): super(ToInt8, self).__init__("toInt8", name, alias=alias) class ToInt16(Function): - def __init__(self, name, alias: str = None): + def __init__(self, name, alias: Optional[str] = None): super(ToInt16, self).__init__("toInt16", name, alias=alias) class ToInt32(Function): - def __init__(self, name, alias: str = None): + def __init__(self, name, alias: Optional[str] = None): super(ToInt32, self).__init__("toInt32", name, alias=alias) class ToInt64(Function): - def __init__(self, name, alias: str = None): + def __init__(self, name, alias: Optional[str] = None): super(ToInt64, self).__init__("toInt64", name, alias=alias) class ToUInt8(Function): - def __init__(self, name, alias: str = None): + def __init__(self, name, alias: Optional[str] = None): super(ToUInt8, self).__init__("toUInt8", name, alias=alias) class ToUInt16(Function): - def __init__(self, name, alias: str = None): + def __init__(self, name, alias: Optional[str] = None): super(ToUInt16, self).__init__("toUInt16", name, alias=alias) class ToUInt32(Function): - def __init__(self, name, alias: str = None): + def __init__(self, name, alias: Optional[str] = None): super(ToUInt32, self).__init__("toUInt32", name, alias=alias) class ToUInt64(Function): - def __init__(self, name, alias: str = None): + def __init__(self, name, alias: Optional[str] = None): super(ToUInt64, self).__init__("toUInt64", name, alias=alias) class ToFloat32(Function): - def __init__(self, name, alias: str = None): + def __init__(self, name, alias: Optional[str] = None): super(ToFloat32, self).__init__("toFloat32", name, alias=alias) class ToFloat64(Function): - def __init__(self, name, alias: str = None): + def __init__(self, name, alias: Optional[str] = None): super(ToFloat64, self).__init__("toFloat64", name, alias=alias) class ToDate(Function): - def __init__(self, name, alias: str = None): + def __init__(self, name, alias: Optional[str] = None): super(ToDate, self).__init__("toDate", name, alias=alias) class ToDateTime(Function): - def __init__(self, name, alias: str = None): + def __init__(self, name, alias: Optional[str] = None): super(ToDateTime, self).__init__("toDateTime", name, alias=alias) From f04ed6190c41dbe2fec7f0cd54429b822d12308b Mon Sep 17 00:00:00 2001 From: fish Date: Sun, 26 Feb 2023 14:41:42 +0800 Subject: [PATCH 14/23] feat: fix type_conversion.py --- pypika/clickhouse/type_conversion.py | 33 +++++++++++++++------------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/pypika/clickhouse/type_conversion.py b/pypika/clickhouse/type_conversion.py index 80229b7e..a42ea6c6 100644 --- a/pypika/clickhouse/type_conversion.py +++ b/pypika/clickhouse/type_conversion.py @@ -2,22 +2,25 @@ Field, Function, ) +from pypika.queries import Schema from pypika.utils import format_alias_sql +from typing import Optional + class ToString(Function): - def __init__(self, name, alias: str = None): + def __init__(self, name, alias: Optional[str] = None): super(ToString, self).__init__("toString", name, alias=alias) class ToFixedString(Function): - def __init__(self, field, length: int, alias: str = None, schema: str = None): + def __init__(self, field, length: int, alias: Optional[str] = None, schema: Optional[Schema] = None): self._length = length self._field = field self.alias = alias self.name = "toFixedString" self.schema = schema - self.args = () + self.args = [] def get_sql(self, with_alias=False, with_namespace=False, quote_char=None, dialect=None, **kwargs): sql = "{name}({field},{length})".format( @@ -29,60 +32,60 @@ def get_sql(self, with_alias=False, with_namespace=False, quote_char=None, diale class ToInt8(Function): - def __init__(self, name, alias: str = None): + def __init__(self, name, alias: Optional[str] = None): super(ToInt8, self).__init__("toInt8", name, alias=alias) class ToInt16(Function): - def __init__(self, name, alias: str = None): + def __init__(self, name, alias: Optional[str] = None): super(ToInt16, self).__init__("toInt16", name, alias=alias) class ToInt32(Function): - def __init__(self, name, alias: str = None): + def __init__(self, name, alias: Optional[str] = None): super(ToInt32, self).__init__("toInt32", name, alias=alias) class ToInt64(Function): - def __init__(self, name, alias: str = None): + def __init__(self, name, alias: Optional[str] = None): super(ToInt64, self).__init__("toInt64", name, alias=alias) class ToUInt8(Function): - def __init__(self, name, alias: str = None): + def __init__(self, name, alias: Optional[str] = None): super(ToUInt8, self).__init__("toUInt8", name, alias=alias) class ToUInt16(Function): - def __init__(self, name, alias: str = None): + def __init__(self, name, alias: Optional[str] = None): super(ToUInt16, self).__init__("toUInt16", name, alias=alias) class ToUInt32(Function): - def __init__(self, name, alias: str = None): + def __init__(self, name, alias: Optional[str] = None): super(ToUInt32, self).__init__("toUInt32", name, alias=alias) class ToUInt64(Function): - def __init__(self, name, alias: str = None): + def __init__(self, name, alias: Optional[str] = None): super(ToUInt64, self).__init__("toUInt64", name, alias=alias) class ToFloat32(Function): - def __init__(self, name, alias: str = None): + def __init__(self, name, alias: Optional[str] = None): super(ToFloat32, self).__init__("toFloat32", name, alias=alias) class ToFloat64(Function): - def __init__(self, name, alias: str = None): + def __init__(self, name, alias: Optional[str] = None): super(ToFloat64, self).__init__("toFloat64", name, alias=alias) class ToDate(Function): - def __init__(self, name, alias: str = None): + def __init__(self, name, alias: Optional[str] = None): super(ToDate, self).__init__("toDate", name, alias=alias) class ToDateTime(Function): - def __init__(self, name, alias: str = None): + def __init__(self, name, alias: Optional[str] = None): super(ToDateTime, self).__init__("toDateTime", name, alias=alias) From 3c0be4c30502fd5706ab294011aedf6ee3747ddd Mon Sep 17 00:00:00 2001 From: fish Date: Sun, 26 Feb 2023 16:08:04 +0800 Subject: [PATCH 15/23] feat: fix return type: Term.any Term.all --- pypika/terms.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pypika/terms.py b/pypika/terms.py index 2b3ef073..3fbc45fd 100644 --- a/pypika/terms.py +++ b/pypika/terms.py @@ -599,7 +599,7 @@ def __xor__(self, other: Any) -> "ComplexCriterion": return ComplexCriterion(Boolean.xor_, self, other) @staticmethod - def any(terms: Iterable[Term] = ()) -> "EmptyCriterion": + def any(terms: Iterable[Term] = ()) -> "Criterion": crit = EmptyCriterion() for term in terms: @@ -608,7 +608,7 @@ def any(terms: Iterable[Term] = ()) -> "EmptyCriterion": return crit @staticmethod - def all(terms: Iterable[Any] = ()) -> "EmptyCriterion": + def all(terms: Iterable[Any] = ()) -> "Criterion": crit = EmptyCriterion() for term in terms: From 496e68dcd2362a7511f326c59ffee84119b0af5c Mon Sep 17 00:00:00 2001 From: fish Date: Sun, 26 Feb 2023 16:08:04 +0800 Subject: [PATCH 16/23] feat: fix return type: Term.any Term.all --- pypika/terms.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pypika/terms.py b/pypika/terms.py index 2b3ef073..3fbc45fd 100644 --- a/pypika/terms.py +++ b/pypika/terms.py @@ -599,7 +599,7 @@ def __xor__(self, other: Any) -> "ComplexCriterion": return ComplexCriterion(Boolean.xor_, self, other) @staticmethod - def any(terms: Iterable[Term] = ()) -> "EmptyCriterion": + def any(terms: Iterable[Term] = ()) -> "Criterion": crit = EmptyCriterion() for term in terms: @@ -608,7 +608,7 @@ def any(terms: Iterable[Term] = ()) -> "EmptyCriterion": return crit @staticmethod - def all(terms: Iterable[Any] = ()) -> "EmptyCriterion": + def all(terms: Iterable[Any] = ()) -> "Criterion": crit = EmptyCriterion() for term in terms: From 05d5de91cabf6c8645e868bab420a2e3eee35b49 Mon Sep 17 00:00:00 2001 From: fish Date: Wed, 1 Mar 2023 00:14:44 +0800 Subject: [PATCH 17/23] style: reformat file using black --- pypika/terms.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pypika/terms.py b/pypika/terms.py index 3fbc45fd..15e8068f 100644 --- a/pypika/terms.py +++ b/pypika/terms.py @@ -70,7 +70,6 @@ def find_(self, type: Type[NodeT]) -> List[NodeT]: WrappedConstant = Union["Term", WrappedConstantStrict] - class Term(Node, SQLPart): def __init__(self, alias: Optional[str] = None) -> None: self.alias = alias From 6338289caaf4431d6aab8fd6b6643f99d69fdefd Mon Sep 17 00:00:00 2001 From: fish Date: Wed, 1 Mar 2023 00:14:44 +0800 Subject: [PATCH 18/23] style: reformat file using black --- pypika/terms.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pypika/terms.py b/pypika/terms.py index 3fbc45fd..15e8068f 100644 --- a/pypika/terms.py +++ b/pypika/terms.py @@ -70,7 +70,6 @@ def find_(self, type: Type[NodeT]) -> List[NodeT]: WrappedConstant = Union["Term", WrappedConstantStrict] - class Term(Node, SQLPart): def __init__(self, alias: Optional[str] = None) -> None: self.alias = alias From a30ce3e6216ea33edc7de9a968194fcd62f9df1f Mon Sep 17 00:00:00 2001 From: fish Date: Wed, 1 Mar 2023 00:44:49 +0800 Subject: [PATCH 19/23] feat: add type hint in Criterion --- pypika/terms.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/pypika/terms.py b/pypika/terms.py index 15e8068f..8dc492eb 100644 --- a/pypika/terms.py +++ b/pypika/terms.py @@ -588,17 +588,17 @@ def __init__(self, alias: Optional[str] = None) -> None: class Criterion(Term): - def __and__(self, other: Any) -> "ComplexCriterion": + def __and__(self, other: "Criterion") -> "ComplexCriterion": return ComplexCriterion(Boolean.and_, self, other) - def __or__(self, other: Any) -> "ComplexCriterion": + def __or__(self, other: "Criterion") -> "ComplexCriterion": return ComplexCriterion(Boolean.or_, self, other) - def __xor__(self, other: Any) -> "ComplexCriterion": + def __xor__(self, other: "Criterion") -> "ComplexCriterion": return ComplexCriterion(Boolean.xor_, self, other) @staticmethod - def any(terms: Iterable[Term] = ()) -> "Criterion": + def any(terms: Iterable["Criterion"] = ()) -> "Criterion": crit = EmptyCriterion() for term in terms: @@ -607,7 +607,7 @@ def any(terms: Iterable[Term] = ()) -> "Criterion": return crit @staticmethod - def all(terms: Iterable[Any] = ()) -> "Criterion": + def all(terms: Iterable["Criterion"] = ()) -> "Criterion": crit = EmptyCriterion() for term in terms: @@ -623,13 +623,13 @@ class EmptyCriterion(Criterion): def fields_(self) -> Set["Field"]: return set() - def __and__(self, other: Any) -> Any: + def __and__(self, other: CriterionT) -> CriterionT: return other - def __or__(self, other: Any) -> Any: + def __or__(self, other: CriterionT) -> CriterionT: return other - def __xor__(self, other: Any) -> Any: + def __xor__(self, other: CriterionT) -> CriterionT: return other @property From 040b2586f2f418f089649927661bf0ac55a724d0 Mon Sep 17 00:00:00 2001 From: fish Date: Wed, 1 Mar 2023 13:10:43 +0800 Subject: [PATCH 20/23] fix: fix mypy error in terms.py --- pypika/terms.py | 23 +++++++++-------------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/pypika/terms.py b/pypika/terms.py index 8dc492eb..157cab4d 100644 --- a/pypika/terms.py +++ b/pypika/terms.py @@ -156,27 +156,22 @@ def wrap_constant( @overload @staticmethod - def wrap_json(val: TermT, wrapper_cls: Optional[Type["ValueWrapper"]] = None) -> TermT: + def wrap_json(val: TermT, wrapper_cls: Optional[Type["ValueWrapper"]] = None) -> TermT: # type: ignore[misc] ... @overload @staticmethod - def wrap_json(val: IntervalT, wrapper_cls: Optional[Type["ValueWrapper"]] = None) -> IntervalT: + def wrap_json(val: None, wrapper_cls: Optional[Type["ValueWrapper"]] = None) -> "NullValue": # type: ignore[misc] ... @overload @staticmethod - def wrap_json(val: None, wrapper_cls: Optional[Type["ValueWrapper"]] = None) -> "NullValue": + def wrap_json(val: Union[str, int, bool], wrapper_cls: Type["ValueWrapperT"]) -> "ValueWrapperT": # type: ignore[misc] ... @overload @staticmethod - def wrap_json(val: Union[str, int, bool], wrapper_cls: Type["ValueWrapperT"]) -> "ValueWrapperT": - ... - - @overload - @staticmethod - def wrap_json(val: Union[str, int, bool], wrapper_cls: None = None) -> "ValueWrapper": + def wrap_json(val: Union[str, int, bool], wrapper_cls: None = None) -> "ValueWrapper": # type: ignore[misc] ... @overload @@ -588,18 +583,18 @@ def __init__(self, alias: Optional[str] = None) -> None: class Criterion(Term): - def __and__(self, other: "Criterion") -> "ComplexCriterion": + def __and__(self, other: "Criterion") -> "Criterion": return ComplexCriterion(Boolean.and_, self, other) - def __or__(self, other: "Criterion") -> "ComplexCriterion": + def __or__(self, other: "Criterion") -> "Criterion": return ComplexCriterion(Boolean.or_, self, other) - def __xor__(self, other: "Criterion") -> "ComplexCriterion": + def __xor__(self, other: "Criterion") -> "Criterion": return ComplexCriterion(Boolean.xor_, self, other) @staticmethod def any(terms: Iterable["Criterion"] = ()) -> "Criterion": - crit = EmptyCriterion() + crit: Criterion = EmptyCriterion() for term in terms: crit |= term @@ -608,7 +603,7 @@ def any(terms: Iterable["Criterion"] = ()) -> "Criterion": @staticmethod def all(terms: Iterable["Criterion"] = ()) -> "Criterion": - crit = EmptyCriterion() + crit: Criterion = EmptyCriterion() for term in terms: crit &= term From e87d84915b3731b1e3074b8fd06e01e05bb6f644 Mon Sep 17 00:00:00 2001 From: fish Date: Sun, 5 Mar 2023 15:36:13 +0800 Subject: [PATCH 21/23] feat: fix type checking using mypy --- pypika/dialects.py | 24 +-- pypika/queries.py | 387 +++++++++++++++++++++++---------------------- 2 files changed, 211 insertions(+), 200 deletions(-) diff --git a/pypika/dialects.py b/pypika/dialects.py index 387ee5e6..f63edbb9 100644 --- a/pypika/dialects.py +++ b/pypika/dialects.py @@ -10,7 +10,7 @@ DropQueryBuilder, Selectable, Table, - Query, + BaseQuery, QueryBuilder, JoinOn, ) @@ -27,7 +27,7 @@ from pypika.utils import QueryException, builder, format_quotes -class SnowflakeQuery(Query["SnowflakeQueryBuilder"]): +class SnowflakeQuery(BaseQuery["SnowflakeQueryBuilder"]): """ Defines a query class for use with Snowflake. """ @@ -71,7 +71,7 @@ def __init__(self) -> None: super().__init__(dialect=Dialects.SNOWFLAKE) -class MySQLQuery(Query["MySQLQueryBuilder"]): +class MySQLQuery(BaseQuery["MySQLQueryBuilder"]): """ Defines a query class for use with MySQL. """ @@ -108,7 +108,7 @@ def __init__(self, **kwargs: Any) -> None: self._for_update_of: Set[str] = set() def __copy__(self) -> Self: - newone = cast(MySQLQueryBuilder, super().__copy__()) + newone = super().__copy__() newone._duplicate_updates = copy(self._duplicate_updates) newone._ignore_duplicates = copy(self._ignore_duplicates) return newone @@ -238,7 +238,7 @@ class MySQLDropQueryBuilder(DropQueryBuilder): QUOTE_CHAR = "`" -class VerticaQuery(Query["VerticaQueryBuilder"]): +class VerticaQuery(BaseQuery["VerticaQueryBuilder"]): """ Defines a query class for use with Vertica. """ @@ -360,7 +360,7 @@ def __str__(self) -> str: return self.get_sql() -class OracleQuery(Query["OracleQueryBuilder"]): +class OracleQuery(BaseQuery["OracleQueryBuilder"]): """ Defines a query class for use with Oracle. """ @@ -384,7 +384,7 @@ def get_sql(self, *args: Any, **kwargs: Any) -> str: return super().get_sql(*args, **kwargs) -class PostgreSQLQuery(Query["PostgreSQLQueryBuilder"]): +class PostgreSQLQuery(BaseQuery["PostgreSQLQueryBuilder"]): """ Defines a query class for use with PostgreSQL. """ @@ -417,7 +417,7 @@ def __init__(self, **kwargs: Any) -> None: self._for_update_of: Set[str] = set() def __copy__(self) -> Self: - newone = cast(PostgreSQLQueryBuilder, super().__copy__()) + newone = super().__copy__() newone._returns = copy(self._returns) newone._on_conflict_do_updates = copy(self._on_conflict_do_updates) return newone @@ -665,7 +665,7 @@ def get_sql(self, with_alias: bool = False, subquery: bool = False, **kwargs: An return querystring -class RedshiftQuery(Query["RedShiftQueryBuilder"]): +class RedshiftQuery(BaseQuery["RedShiftQueryBuilder"]): """ Defines a query class for use with Amazon Redshift. """ @@ -679,7 +679,7 @@ class RedShiftQueryBuilder(QueryBuilder): QUERY_CLS = RedshiftQuery -class MSSQLQuery(Query["MSSQLQueryBuilder"]): +class MSSQLQuery(BaseQuery["MSSQLQueryBuilder"]): """ Defines a query class for use with Microsoft SQL Server. """ @@ -761,7 +761,7 @@ def _select_sql(self, **kwargs: Any) -> str: ) -class ClickHouseQuery(Query["ClickHouseQueryBuilder"]): +class ClickHouseQuery(BaseQuery["ClickHouseQueryBuilder"]): """ Defines a query class for use with Yandex ClickHouse. """ @@ -868,7 +868,7 @@ def get_value_sql(self, **kwargs: Any) -> str: return super().get_value_sql(**kwargs) -class SQLLiteQuery(Query["SQLLiteQueryBuilder"]): +class SQLLiteQuery(BaseQuery["SQLLiteQueryBuilder"]): """ Defines a query class for use with Microsoft SQL Server. """ diff --git a/pypika/queries.py b/pypika/queries.py index 28230a2f..38fdc1dd 100644 --- a/pypika/queries.py +++ b/pypika/queries.py @@ -18,7 +18,6 @@ cast, TypeVar, overload, - TYPE_CHECKING, ) from typing_extensions import Self @@ -32,6 +31,7 @@ Index, Rollup, Star, + Node, Term, Tuple, ValueWrapper, @@ -58,15 +58,10 @@ _T = TypeVar("_T") SchemaT = TypeVar("SchemaT", bound="Schema") -if TYPE_CHECKING: - from typing_extensions import TypeVar +QueryBuilderType = TypeVar("QueryBuilderType", bound="QueryBuilder", covariant=True) - QueryBuilderType = TypeVar("QueryBuilderType", bound="QueryBuilder", covariant=True, default="QueryBuilder") -else: - QueryBuilderType = TypeVar("QueryBuilderType", bound="QueryBuilder", covariant=True) - -class Selectable(Term): +class Selectable(Node): def __init__(self, alias: Optional[str]) -> None: self.alias = alias @@ -146,10 +141,192 @@ def get_sql(self, quote_char: Optional[str] = None, **kwargs: Any) -> str: class Database(Schema): @ignore_copy - def __getattr__(self, item: str) -> Schema: + def __getattr__(self, item: str) -> Schema: # type: ignore return Schema(item, parent=self) +class BaseQuery(Generic[QueryBuilderType]): + """ + Query is the primary class and entry point in pypika. It is used to build queries iteratively using the builder + design + pattern. + + This class is immutable. + """ + + @classmethod + def _builder(cls, **kwargs: Any) -> "QueryBuilderType": + raise NotImplementedError + + @classmethod + def from_(cls, table: Union[Selectable, str], **kwargs: Any) -> "QueryBuilderType": + """ + Query builder entry point. Initializes query building and sets the table to select from. When using this + function, the query becomes a SELECT query. + + :param table: + Type: Table or str + + An instance of a Table object or a string table name. + + :returns QueryBuilder + """ + return cls._builder(**kwargs).from_(table) + + @classmethod + def create_table(cls, table: Union[str, "Table"]) -> "CreateQueryBuilder": + """ + Query builder entry point. Initializes query building and sets the table name to be created. When using this + function, the query becomes a CREATE statement. + + :param table: An instance of a Table object or a string table name. + + :return: CreateQueryBuilder + """ + return CreateQueryBuilder().create_table(table) + + @classmethod + def drop_database(cls, database: Union[Database, str]) -> "DropQueryBuilder": + """ + Query builder entry point. Initializes query building and sets the table name to be dropped. When using this + function, the query becomes a DROP statement. + + :param database: An instance of a Database object or a string database name. + + :return: DropQueryBuilder + """ + return DropQueryBuilder().drop_database(database) + + @classmethod + def drop_table(cls, table: Union[str, "Table"]) -> "DropQueryBuilder": + """ + Query builder entry point. Initializes query building and sets the table name to be dropped. When using this + function, the query becomes a DROP statement. + + :param table: An instance of a Table object or a string table name. + + :return: DropQueryBuilder + """ + return DropQueryBuilder().drop_table(table) + + @classmethod + def drop_user(cls, user: str) -> "DropQueryBuilder": + """ + Query builder entry point. Initializes query building and sets the table name to be dropped. When using this + function, the query becomes a DROP statement. + + :param user: String user name. + + :return: DropQueryBuilder + """ + return DropQueryBuilder().drop_user(user) + + @classmethod + def drop_view(cls, view: str) -> "DropQueryBuilder": + """ + Query builder entry point. Initializes query building and sets the table name to be dropped. When using this + function, the query becomes a DROP statement. + + :param view: String view name. + + :return: DropQueryBuilder + """ + return DropQueryBuilder().drop_view(view) + + @classmethod + def into(cls, table: Union["Table", str], **kwargs: Any) -> "QueryBuilderType": + """ + Query builder entry point. Initializes query building and sets the table to insert into. When using this + function, the query becomes an INSERT query. + + :param table: + Type: Table or str + + An instance of a Table object or a string table name. + + :returns QueryBuilder + """ + return cls._builder(**kwargs).into(table) + + @classmethod + def with_(cls, table: Selectable, name: str, **kwargs: Any) -> "QueryBuilderType": + return cls._builder(**kwargs).with_(table, name) + + @classmethod + def select(cls, *terms: Union[int, float, str, bool, Term], **kwargs: Any) -> "QueryBuilderType": + """ + Query builder entry point. Initializes query building without a table and selects fields. Useful when testing + SQL functions. + + :param terms: + Type: list[expression] + + A list of terms to select. These can be any type of int, float, str, bool, or Term. They cannot be a Field + unless the function ``Query.from_`` is called first. + + :returns QueryBuilder + """ + return cls._builder(**kwargs).select(*terms) + + @classmethod + def update(cls, table: Union[str, "Table"], **kwargs) -> "QueryBuilderType": + """ + Query builder entry point. Initializes query building and sets the table to update. When using this + function, the query becomes an UPDATE query. + + :param table: + Type: Table or str + + An instance of a Table object or a string table name. + + :returns QueryBuilder + """ + return cls._builder(**kwargs).update(table) + + @classmethod + def Table(cls, table_name: str, **kwargs) -> "Table[QueryBuilderType]": + """ + Convenience method for creating a Table that uses this Query class. + + :param table_name: + Type: str + + A string table name. + + :returns Table + """ + return Table(table_name, query_cls=cls, **kwargs) + + @classmethod + def Tables(cls, *names: Union[TypedTuple[str, str], str], **kwargs: Any) -> List["Table[QueryBuilderType]"]: # type: ignore + """ + Convenience method for creating many tables that uses this Query class. + See ``Query.make_tables`` for details. + + :param names: + Type: list[str or tuple] + + A list of string table names, or name and alias tuples. + + :returns Table + """ + return make_tables(*names, query_cls=cls, **kwargs) + + +class Query(BaseQuery["QueryBuilder"]): + """ + Query is the primary class and entry point in pypika. It is used to build queries iteratively using the builder + design + pattern. + + This class is immutable. + """ + + @classmethod + def _builder(cls, **kwargs: Any) -> "QueryBuilder": + return QueryBuilder(**kwargs) + + class Table(Selectable, Generic[QueryBuilderType]): @staticmethod def _init_schema(schema: Union[str, list, tuple, Schema, None]) -> Optional[Schema]: @@ -168,15 +345,15 @@ def __init__( name: str, schema: Union[str, list, tuple, Schema, None] = None, alias: Optional[str] = None, - query_cls: Optional[Type["Query[QueryBuilderType]"]] = None, + query_cls: Type["BaseQuery[QueryBuilderType]"] = Query, # type: ignore ) -> None: super().__init__(alias) self._table_name = name self._schema = self._init_schema(schema) - self._query_cls: Type["Query[QueryBuilderType]"] = query_cls or Query + self._query_cls: Type["BaseQuery[QueryBuilderType]"] = query_cls self._for: Optional[Criterion] = None self._for_portion: Optional[PeriodCriterion] = None - if not issubclass(self._query_cls, Query): + if not issubclass(self._query_cls, BaseQuery): raise TypeError("Expected 'query_cls' to be subclass of Query") def get_table_name(self) -> str: @@ -280,7 +457,9 @@ def insert(self, *terms: Union[int, float, str, bool, Term, Field]) -> "QueryBui def make_tables( - *names: Union[TypedTuple[str, str], str], query_cls: "Optional[Type[Query[QueryBuilderType]]]" = None, **kwargs: Any + *names: Union[TypedTuple[str, str], str], + query_cls: "Type[BaseQuery[QueryBuilderType]]" = Query, # type: ignore + **kwargs: Any, ) -> List[Table[QueryBuilderType]]: """ Shortcut to create many tables. If `names` param is a tuple, the first @@ -389,175 +568,7 @@ def get_sql(self, **kwargs: Any) -> str: _TableClass = Table -class Query(Generic[QueryBuilderType]): - """ - Query is the primary class and entry point in pypika. It is used to build queries iteratively using the builder - design - pattern. - - This class is immutable. - """ - - @classmethod - def _builder(cls, **kwargs: Any) -> "QueryBuilderType": - return QueryBuilder(**kwargs) - - @classmethod - def from_(cls, table: Union[Selectable, str], **kwargs: Any) -> "QueryBuilderType": - """ - Query builder entry point. Initializes query building and sets the table to select from. When using this - function, the query becomes a SELECT query. - - :param table: - Type: Table or str - - An instance of a Table object or a string table name. - - :returns QueryBuilder - """ - return cls._builder(**kwargs).from_(table) - - @classmethod - def create_table(cls, table: Union[str, Table]) -> "CreateQueryBuilder": - """ - Query builder entry point. Initializes query building and sets the table name to be created. When using this - function, the query becomes a CREATE statement. - - :param table: An instance of a Table object or a string table name. - - :return: CreateQueryBuilder - """ - return CreateQueryBuilder().create_table(table) - - @classmethod - def drop_database(cls, database: Union[Database, str]) -> "DropQueryBuilder": - """ - Query builder entry point. Initializes query building and sets the table name to be dropped. When using this - function, the query becomes a DROP statement. - - :param database: An instance of a Database object or a string database name. - - :return: DropQueryBuilder - """ - return DropQueryBuilder().drop_database(database) - - @classmethod - def drop_table(cls, table: Union[str, Table]) -> "DropQueryBuilder": - """ - Query builder entry point. Initializes query building and sets the table name to be dropped. When using this - function, the query becomes a DROP statement. - - :param table: An instance of a Table object or a string table name. - - :return: DropQueryBuilder - """ - return DropQueryBuilder().drop_table(table) - - @classmethod - def drop_user(cls, user: str) -> "DropQueryBuilder": - """ - Query builder entry point. Initializes query building and sets the table name to be dropped. When using this - function, the query becomes a DROP statement. - - :param user: String user name. - - :return: DropQueryBuilder - """ - return DropQueryBuilder().drop_user(user) - - @classmethod - def drop_view(cls, view: str) -> "DropQueryBuilder": - """ - Query builder entry point. Initializes query building and sets the table name to be dropped. When using this - function, the query becomes a DROP statement. - - :param view: String view name. - - :return: DropQueryBuilder - """ - return DropQueryBuilder().drop_view(view) - - @classmethod - def into(cls, table: Union[Table, str], **kwargs: Any) -> "QueryBuilderType": - """ - Query builder entry point. Initializes query building and sets the table to insert into. When using this - function, the query becomes an INSERT query. - - :param table: - Type: Table or str - - An instance of a Table object or a string table name. - - :returns QueryBuilder - """ - return cls._builder(**kwargs).into(table) - - @classmethod - def with_(cls, table: Selectable, name: str, **kwargs: Any) -> "QueryBuilderType": - return cls._builder(**kwargs).with_(table, name) - - @classmethod - def select(cls, *terms: Union[int, float, str, bool, Term], **kwargs: Any) -> "QueryBuilderType": - """ - Query builder entry point. Initializes query building without a table and selects fields. Useful when testing - SQL functions. - - :param terms: - Type: list[expression] - - A list of terms to select. These can be any type of int, float, str, bool, or Term. They cannot be a Field - unless the function ``Query.from_`` is called first. - - :returns QueryBuilder - """ - return cls._builder(**kwargs).select(*terms) - - @classmethod - def update(cls, table: Union[str, Table], **kwargs) -> "QueryBuilderType": - """ - Query builder entry point. Initializes query building and sets the table to update. When using this - function, the query becomes an UPDATE query. - - :param table: - Type: Table or str - - An instance of a Table object or a string table name. - - :returns QueryBuilder - """ - return cls._builder(**kwargs).update(table) - - @classmethod - def Table(cls, table_name: str, **kwargs) -> Table[QueryBuilderType]: - """ - Convenience method for creating a Table that uses this Query class. - - :param table_name: - Type: str - - A string table name. - - :returns Table - """ - return Table(table_name, query_cls=cls, **kwargs) - - @classmethod - def Tables(cls, *names: Union[TypedTuple[str, str], str], **kwargs: Any) -> List["Table[QueryBuilderType]"]: - """ - Convenience method for creating many tables that uses this Query class. - See ``Query.make_tables`` for details. - - :param names: - Type: list[str or tuple] - - A list of string table names, or name and alias tuples. - - :returns Table - """ - return make_tables(*names, query_cls=cls, **kwargs) - - -class _SetOperation(Selectable, Term, SQLPart): +class _SetOperation(Selectable, Term, SQLPart): # type: ignore """ A Query class wrapper for a all set operations, Union DISTINCT or ALL, Intersect, Except or Minus @@ -728,7 +739,7 @@ class QueryBuilder(Selectable, Term, SQLPart): SECONDARY_QUOTE_CHAR: Optional[str] = "'" ALIAS_QUOTE_CHAR: Optional[str] = None QUERY_ALIAS_QUOTE_CHAR: Optional[str] = None - QUERY_CLS = Query + QUERY_CLS: Type[BaseQuery] = Query def __init__( self, @@ -850,7 +861,7 @@ def replace_table(self, current_table: Optional[Table], new_table: Optional[Tabl self._insert_table = new_table if self._insert_table == current_table else self._insert_table self._update_table = new_table if self._update_table == current_table else self._update_table - self._with = [alias_query.replace_table(current_table, new_table) for alias_query in self._with] + self._with = [alias_query.replace_table(current_table, new_table) for alias_query in self._with] # TODO: why? self._selects = [ select.replace_table(current_table, new_table) if isinstance(select, Term) else select for select in self._selects @@ -1160,7 +1171,7 @@ def slice(self, slice: slice): self._offset = slice.start self._limit = slice.stop - @overload + @overload # type: ignore[override] def __getitem__(self, item: str) -> Field: ... @@ -1739,7 +1750,7 @@ def replace_table(self, current_table: Optional[Table], new_table: Optional[Tabl :return: A copy of the join with the tables replaced. """ - self.item = self.item.replace_table(current_table, new_table) + self.item = self.item.replace_table(current_table, new_table) # TODO: why? class JoinOn(Join): @@ -1831,7 +1842,7 @@ class CreateQueryBuilder(SQLPart): QUOTE_CHAR: Optional[str] = '"' SECONDARY_QUOTE_CHAR: Optional[str] = "'" ALIAS_QUOTE_CHAR: Optional[str] = None - QUERY_CLS = Query + QUERY_CLS: Type[BaseQuery] = Query def __init__(self, dialect: Optional[Dialects] = None) -> None: self._create_table: Optional[Table] = None @@ -2184,7 +2195,7 @@ class DropQueryBuilder(SQLPart): QUOTE_CHAR: Optional[str] = '"' SECONDARY_QUOTE_CHAR: Optional[str] = "'" ALIAS_QUOTE_CHAR: Optional[str] = None - QUERY_CLS = Query + QUERY_CLS: Type[BaseQuery] = Query def __init__(self, dialect: Optional[Dialects] = None) -> None: self._drop_target_kind: Optional[str] = None From 0795d329cd24c4eafeb0cc5b8156b52cc099ffb9 Mon Sep 17 00:00:00 2001 From: fish Date: Sun, 5 Mar 2023 15:38:17 +0800 Subject: [PATCH 22/23] feat: fix Case.else_ return type --- pypika/terms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pypika/terms.py b/pypika/terms.py index 157cab4d..4946b90c 100644 --- a/pypika/terms.py +++ b/pypika/terms.py @@ -1254,7 +1254,7 @@ def replace_table(self, current_table: Optional["Table"], new_table: Optional["T self._else = self._else.replace_table(current_table, new_table) if self._else else None @builder - def else_(self, term: WrappedConstantValue) -> "Case": + def else_(self, term: WrappedConstantValue) -> Self: self._else = self.wrap_constant(term) return self From c2d46024dc4a5b47cc799b8b1f06be3cd2c0e3fd Mon Sep 17 00:00:00 2001 From: fish Date: Sun, 5 Mar 2023 16:06:52 +0800 Subject: [PATCH 23/23] feat: fix python 3.8+ support --- pypika/dialects.py | 12 +++++++----- pypika/queries.py | 32 +++++++++++++++++--------------- pypika/terms.py | 11 ++++++----- pypika/utils.py | 31 +++++++++++++------------------ 4 files changed, 43 insertions(+), 43 deletions(-) diff --git a/pypika/dialects.py b/pypika/dialects.py index f63edbb9..5a2fb4e7 100644 --- a/pypika/dialects.py +++ b/pypika/dialects.py @@ -1,7 +1,9 @@ import itertools from copy import copy -from typing import Any, Iterable, List, Optional, Set, Union, Tuple as TypedTuple, cast -from typing_extensions import Self, NoReturn +from typing import Any, Iterable, List, Optional, Set, Union, Tuple as TypedTuple, cast, TYPE_CHECKING + +if TYPE_CHECKING: + from typing_extensions import Self, NoReturn from pypika.enums import Dialects from pypika.queries import ( @@ -107,7 +109,7 @@ def __init__(self, **kwargs: Any) -> None: self._for_update_skip_locked = False self._for_update_of: Set[str] = set() - def __copy__(self) -> Self: + def __copy__(self) -> "Self": newone = super().__copy__() newone._duplicate_updates = copy(self._duplicate_updates) newone._ignore_duplicates = copy(self._ignore_duplicates) @@ -416,7 +418,7 @@ def __init__(self, **kwargs: Any) -> None: self._for_update_skip_locked = False self._for_update_of: Set[str] = set() - def __copy__(self) -> Self: + def __copy__(self) -> "Self": newone = super().__copy__() newone._returns = copy(self._returns) newone._on_conflict_do_updates = copy(self._on_conflict_do_updates) @@ -809,7 +811,7 @@ def _update_sql(self, **kwargs: Any) -> str: return "ALTER TABLE {table}".format(table=self._update_table.get_sql(**kwargs)) def _from_sql(self, with_namespace: bool = False, **kwargs: Any) -> str: - def _error_none(v) -> NoReturn: + def _error_none(v) -> "NoReturn": raise TypeError("expect Selectable or QueryBuilder, got {}".format(type(v).__name__)) selectable = ",".join( diff --git a/pypika/queries.py b/pypika/queries.py index 38fdc1dd..78cef989 100644 --- a/pypika/queries.py +++ b/pypika/queries.py @@ -18,8 +18,11 @@ cast, TypeVar, overload, + TYPE_CHECKING, ) -from typing_extensions import Self + +if TYPE_CHECKING: + from typing_extensions import Self from pypika.enums import Dialects, JoinType, ReferenceOption, SetOperation, Order from pypika.terms import ( @@ -49,7 +52,6 @@ format_alias_sql, format_quotes, ignore_copy, - SQLPart, ) __author__ = "Timothy Heys" @@ -93,7 +95,7 @@ def get_sql(self, **kwargs) -> str: raise NotImplementedError -class AliasedQuery(Selectable, SQLPart): +class AliasedQuery(Selectable): def __init__(self, name: str, query: Optional[Selectable] = None) -> None: super().__init__(alias=name) self.name = name @@ -111,7 +113,7 @@ def __hash__(self) -> int: return hash(str(self.name)) -class Schema(SQLPart): +class Schema: def __init__(self, name: str, parent: Optional["Schema"] = None) -> None: self._name = name self._parent = parent @@ -488,7 +490,7 @@ def make_tables( return tables -class Column(SQLPart): +class Column: """Represents a column.""" def __init__( @@ -546,7 +548,7 @@ def make_columns(*names: Union[TypedTuple[str, str], str]) -> List[Column]: return columns -class PeriodFor(SQLPart): +class PeriodFor: def __init__(self, name: str, start_column: Union[str, Column], end_column: Union[str, Column]) -> None: self.name = name self.start_column = start_column if isinstance(start_column, Column) else Column(start_column) @@ -568,7 +570,7 @@ def get_sql(self, **kwargs: Any) -> str: _TableClass = Table -class _SetOperation(Selectable, Term, SQLPart): # type: ignore +class _SetOperation(Selectable, Term): # type: ignore """ A Query class wrapper for a all set operations, Union DISTINCT or ALL, Intersect, Except or Minus @@ -729,7 +731,7 @@ def _limit_sql(self) -> str: return " LIMIT {limit}".format(limit=self._limit) -class QueryBuilder(Selectable, Term, SQLPart): +class QueryBuilder(Selectable, Term): """ Query Builder is the main class in pypika which stores the state of a query and offers functions which allow the state to be branched immutably. @@ -799,7 +801,7 @@ def __init__( self.immutable = immutable - def __copy__(self) -> Self: + def __copy__(self) -> "Self": newone = type(self).__new__(type(self)) newone.__dict__.update(self.__dict__) newone._select_star_tables = copy(self._select_star_tables) @@ -1176,10 +1178,10 @@ def __getitem__(self, item: str) -> Field: ... @overload - def __getitem__(self, item: builtins.slice) -> Self: + def __getitem__(self, item: builtins.slice) -> "Self": ... - def __getitem__(self, item: Union[str, builtins.slice]) -> Union[Self, Field]: + def __getitem__(self, item: Union[str, builtins.slice]) -> Union["Self", Field]: if not isinstance(item, slice): return super().__getitem__(item) return self.slice(item) @@ -1559,7 +1561,7 @@ def _from_sql(self, with_namespace: bool = False, **kwargs: Any) -> str: def _using_sql(self, with_namespace: bool = False, **kwargs: Any) -> str: return " USING {selectable}".format( selectable=",".join( - clause.get_sql(subquery=True, with_alias=True, **kwargs) if isinstance(clause, SQLPart) else clause + clause.get_sql(subquery=True, with_alias=True, **kwargs) if isinstance(clause, Selectable) else clause for clause in self._using ) ) @@ -1720,7 +1722,7 @@ def cross(self) -> "QueryBuilderType": return self.query -class Join(SQLPart): +class Join: def __init__(self, item: JoinableTerm, how: JoinType) -> None: self.item = item self.how = how @@ -1834,7 +1836,7 @@ def replace_table(self, current_table: Optional[Table], new_table: Optional[Tabl raise ValueError("new_table should not be None for {}".format(type(self).__name__)) -class CreateQueryBuilder(SQLPart): +class CreateQueryBuilder: """ Query builder used to build CREATE queries. """ @@ -2187,7 +2189,7 @@ def __repr__(self) -> str: return self.__str__() -class DropQueryBuilder(SQLPart): +class DropQueryBuilder: """ Query builder used to build DROP queries. """ diff --git a/pypika/terms.py b/pypika/terms.py index 4946b90c..e062c5cd 100644 --- a/pypika/terms.py +++ b/pypika/terms.py @@ -20,7 +20,9 @@ Union, overload, ) -from typing_extensions import Self + +if TYPE_CHECKING: + from typing_extensions import Self from pypika.enums import Arithmetic, Boolean, Comparator, Dialects, Equality, JSONOperators, Matching, Order from pypika.utils import ( @@ -31,7 +33,6 @@ format_quotes, ignore_copy, resolve_is_aggregate, - SQLPart, ) if TYPE_CHECKING: @@ -70,7 +71,7 @@ def find_(self, type: Type[NodeT]) -> List[NodeT]: WrappedConstant = Union["Term", WrappedConstantStrict] -class Term(Node, SQLPart): +class Term(Node): def __init__(self, alias: Optional[str] = None) -> None: self.alias = alias @@ -1254,7 +1255,7 @@ def replace_table(self, current_table: Optional["Table"], new_table: Optional["T self._else = self._else.replace_table(current_table, new_table) if self._else else None @builder - def else_(self, term: WrappedConstantValue) -> Self: + def else_(self, term: WrappedConstantValue) -> "Self": self._else = self.wrap_constant(term) return self @@ -1742,7 +1743,7 @@ def get_sql(self, **kwargs: Any) -> str: return self.name -class AtTimezone(Term, SQLPart): +class AtTimezone(Term): """ Generates AT TIME ZONE SQL. Examples: diff --git a/pypika/utils.py b/pypika/utils.py index 35a58e9e..4fd2da8a 100644 --- a/pypika/utils.py +++ b/pypika/utils.py @@ -1,5 +1,7 @@ -from typing import Any, Callable, List, Optional, Type, Union, overload, TypeVar -from typing_extensions import ParamSpec, Concatenate, Protocol, runtime_checkable +from typing import Any, Callable, List, Optional, Type, Union, overload, TypeVar, TYPE_CHECKING + +if TYPE_CHECKING: + from typing_extensions import ParamSpec, Concatenate __author__ = "Timothy Heys" __email__ = "theys@kayak.com" @@ -37,22 +39,23 @@ class FunctionException(Exception): pass -_T = TypeVar('_T') -_S = TypeVar('_S') -_P = ParamSpec('_P') +if TYPE_CHECKING: + _T = TypeVar('_T') + _S = TypeVar('_S') + _P = ParamSpec('_P') @overload -def builder(func: Callable[Concatenate[_S, _P], Union[_S, None]]) -> Callable[Concatenate[_S, _P], _S]: +def builder(func: "Callable[Concatenate[_S, _P], Union[_S, None]]") -> "Callable[Concatenate[_S, _P], _S]": ... @overload -def builder(func: Callable[Concatenate[_S, _P], _T]) -> Callable[Concatenate[_S, _P], _T]: +def builder(func: "Callable[Concatenate[_S, _P], _T]") -> "Callable[Concatenate[_S, _P], _T]": ... -def builder(func: Callable[Concatenate[_S, _P], Union[_T, None]]) -> Callable[Concatenate[_S, _P], Union[_T, _S]]: +def builder(func: "Callable[Concatenate[_S, _P], Union[_T, None]]") -> "Callable[Concatenate[_S, _P], Union[_T, _S]]": """ Decorator for wrapper "builder" functions. These are functions on the Query class or other classes used for building queries which mutate the query and return self. To make the build functions immutable, this decorator is @@ -61,7 +64,7 @@ def builder(func: Callable[Concatenate[_S, _P], Union[_T, None]]) -> Callable[Co """ import copy - def _copy(self: _S, *args: _P.args, **kwargs: _P.kwargs): + def _copy(self: "_S", *args: "_P.args", **kwargs: "_P.kwargs"): self_copy = copy.copy(self) if getattr(self, "immutable", True) else self result = func(self_copy, *args, **kwargs) @@ -75,7 +78,7 @@ def _copy(self: _S, *args: _P.args, **kwargs: _P.kwargs): return _copy -def ignore_copy(func: Callable[[_S, str], _T]) -> Callable[[_S, str], _T]: +def ignore_copy(func: "Callable[[_S, str], _T]") -> "Callable[[_S, str], _T]": """ Decorator for wrapping the __getattr__ function for classes that are copied via deepcopy. This prevents infinite recursion caused by deepcopy looking for magic functions in the class. Any class implementing __getattr__ that is @@ -140,11 +143,3 @@ def validate(*args: Any, exc: Optional[Exception], type: Optional[Type] = None) for arg in args: if not isinstance(arg, type): raise exc - - -@runtime_checkable -class SQLPart(Protocol): - """This protocol indicates the class can generate a part of SQL""" - - def get_sql(self, **kwargs) -> str: - ...