diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 9b9cb148c69f..9b430529033c 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -2375,7 +2375,7 @@ def replace_callable_return_type(c: CallableType, new_ret_type: Type) -> Callabl return c.copy_modified(ret_type=new_ret_type) -class ArgInferSecondPassQuery(types.TypeQuery): +class ArgInferSecondPassQuery(types.TypeQuery[bool]): """Query whether an argument type should be inferred in the second pass. The result is True if the type has a type variable in a callable return @@ -2383,16 +2383,16 @@ class ArgInferSecondPassQuery(types.TypeQuery): a type variable. """ def __init__(self) -> None: - super().__init__(False, types.ANY_TYPE_STRATEGY) + super().__init__(any) def visit_callable_type(self, t: CallableType) -> bool: return self.query_types(t.arg_types) or t.accept(HasTypeVarQuery()) -class HasTypeVarQuery(types.TypeQuery): +class HasTypeVarQuery(types.TypeQuery[bool]): """Visitor for querying whether a type has a type variable component.""" def __init__(self) -> None: - super().__init__(False, types.ANY_TYPE_STRATEGY) + super().__init__(any) def visit_type_var(self, t: TypeVarType) -> bool: return True @@ -2402,10 +2402,10 @@ def has_erased_component(t: Type) -> bool: return t is not None and t.accept(HasErasedComponentsQuery()) -class HasErasedComponentsQuery(types.TypeQuery): +class HasErasedComponentsQuery(types.TypeQuery[bool]): """Visitor for querying whether a type has an erased component.""" def __init__(self) -> None: - super().__init__(False, types.ANY_TYPE_STRATEGY) + super().__init__(any) def visit_erased_type(self, t: ErasedType) -> bool: return True diff --git a/mypy/constraints.py b/mypy/constraints.py index d6e44bea857d..006e255fa9ea 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -6,8 +6,7 @@ from mypy.types import ( CallableType, Type, TypeVisitor, UnboundType, AnyType, NoneTyp, TypeVarType, Instance, TupleType, TypedDictType, UnionType, Overloaded, ErasedType, PartialType, - DeletedType, UninhabitedType, TypeType, TypeVarId, TypeQuery, ALL_TYPES_STRATEGY, - is_named_instance + DeletedType, UninhabitedType, TypeType, TypeVarId, TypeQuery, is_named_instance ) from mypy.maptype import map_instance_to_supertype from mypy import nodes @@ -250,9 +249,9 @@ def is_complete_type(typ: Type) -> bool: return typ.accept(CompleteTypeVisitor()) -class CompleteTypeVisitor(TypeQuery): +class CompleteTypeVisitor(TypeQuery[bool]): def __init__(self) -> None: - super().__init__(default=True, strategy=ALL_TYPES_STRATEGY) + super().__init__(all) def visit_none_type(self, t: NoneTyp) -> bool: return experiments.STRICT_OPTIONAL diff --git a/mypy/stats.py b/mypy/stats.py index 2b809a6d6267..3739763069bd 100644 --- a/mypy/stats.py +++ b/mypy/stats.py @@ -7,8 +7,7 @@ from mypy.traverser import TraverserVisitor from mypy.types import ( - Type, AnyType, Instance, FunctionLike, TupleType, TypeVarType, - TypeQuery, ANY_TYPE_STRATEGY, CallableType + Type, AnyType, Instance, FunctionLike, TupleType, TypeVarType, TypeQuery, CallableType ) from mypy import nodes from mypy.nodes import ( @@ -226,9 +225,9 @@ def is_imprecise(t: Type) -> bool: return t.accept(HasAnyQuery()) -class HasAnyQuery(TypeQuery): +class HasAnyQuery(TypeQuery[bool]): def __init__(self) -> None: - super().__init__(False, ANY_TYPE_STRATEGY) + super().__init__(any) def visit_any(self, t: AnyType) -> bool: return True diff --git a/mypy/types.py b/mypy/types.py index 1bba37724ced..e6239a944cdf 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -5,7 +5,7 @@ from collections import OrderedDict from typing import ( Any, TypeVar, Dict, List, Tuple, cast, Generic, Set, Sequence, Optional, Union, Iterable, - NamedTuple, + NamedTuple, Callable, ) import mypy.nodes @@ -1500,109 +1500,78 @@ def keywords_str(self, a: Iterable[Tuple[str, Type]]) -> str: ]) -# These constants define the method used by TypeQuery to combine multiple -# query results, e.g. for tuple types. The strategy is not used for empty -# result lists; in that case the default value takes precedence. -ANY_TYPE_STRATEGY = 0 # Return True if any of the results are True. -ALL_TYPES_STRATEGY = 1 # Return True if all of the results are True. +class TypeQuery(Generic[T], TypeVisitor[T]): + """Visitor for performing queries of types. + strategy is used to combine results for a series of types -class TypeQuery(TypeVisitor[bool]): - """Visitor for performing simple boolean queries of types. - - This class allows defining the default value for leafs to simplify the - implementation of many queries. + Common use cases involve a boolean query using `any` or `all` """ - default = False # Default result - strategy = 0 # Strategy for combining multiple values (ANY_TYPE_STRATEGY or ALL_TYPES_...). - - def __init__(self, default: bool, strategy: int) -> None: - """Construct a query visitor. - - Use the given default result and strategy for combining - multiple results. The strategy must be either - ANY_TYPE_STRATEGY or ALL_TYPES_STRATEGY. - """ - self.default = default + def __init__(self, strategy: Callable[[Iterable[T]], T]) -> None: self.strategy = strategy - def visit_unbound_type(self, t: UnboundType) -> bool: - return self.default + def visit_unbound_type(self, t: UnboundType) -> T: + return self.query_types(t.args) - def visit_type_list(self, t: TypeList) -> bool: - return self.default + def visit_type_list(self, t: TypeList) -> T: + return self.query_types(t.items) - def visit_any(self, t: AnyType) -> bool: - return self.default + def visit_any(self, t: AnyType) -> T: + return self.strategy([]) - def visit_uninhabited_type(self, t: UninhabitedType) -> bool: - return self.default + def visit_uninhabited_type(self, t: UninhabitedType) -> T: + return self.strategy([]) - def visit_none_type(self, t: NoneTyp) -> bool: - return self.default + def visit_none_type(self, t: NoneTyp) -> T: + return self.strategy([]) - def visit_erased_type(self, t: ErasedType) -> bool: - return self.default + def visit_erased_type(self, t: ErasedType) -> T: + return self.strategy([]) - def visit_deleted_type(self, t: DeletedType) -> bool: - return self.default + def visit_deleted_type(self, t: DeletedType) -> T: + return self.strategy([]) - def visit_type_var(self, t: TypeVarType) -> bool: - return self.default + def visit_type_var(self, t: TypeVarType) -> T: + return self.strategy([]) - def visit_partial_type(self, t: PartialType) -> bool: - return self.default + def visit_partial_type(self, t: PartialType) -> T: + return self.query_types(t.inner_types) - def visit_instance(self, t: Instance) -> bool: + def visit_instance(self, t: Instance) -> T: return self.query_types(t.args) - def visit_callable_type(self, t: CallableType) -> bool: + def visit_callable_type(self, t: CallableType) -> T: # FIX generics return self.query_types(t.arg_types + [t.ret_type]) - def visit_tuple_type(self, t: TupleType) -> bool: + def visit_tuple_type(self, t: TupleType) -> T: return self.query_types(t.items) - def visit_typeddict_type(self, t: TypedDictType) -> bool: + def visit_typeddict_type(self, t: TypedDictType) -> T: return self.query_types(t.items.values()) - def visit_star_type(self, t: StarType) -> bool: + def visit_star_type(self, t: StarType) -> T: return t.type.accept(self) - def visit_union_type(self, t: UnionType) -> bool: + def visit_union_type(self, t: UnionType) -> T: return self.query_types(t.items) - def visit_overloaded(self, t: Overloaded) -> bool: + def visit_overloaded(self, t: Overloaded) -> T: return self.query_types(t.items()) - def visit_type_type(self, t: TypeType) -> bool: + def visit_type_type(self, t: TypeType) -> T: return t.item.accept(self) - def query_types(self, types: Iterable[Type]) -> bool: + def visit_ellipsis_type(self, t: EllipsisType) -> T: + return self.strategy([]) + + def query_types(self, types: Iterable[Type]) -> T: """Perform a query for a list of types. - Use the strategy constant to combine the results. + Use the strategy to combine the results. """ - if not types: - # Use default result for empty list. - return self.default - if self.strategy == ANY_TYPE_STRATEGY: - # Return True if at least one component is true. - res = False - for t in types: - res = res or t.accept(self) - if res: - break - return res - else: - # Return True if all components are true. - res = True - for t in types: - res = res and t.accept(self) - if not res: - break - return res + return self.strategy(t.accept(self) for t in types) def strip_type(typ: Type) -> Type: