From 72e3ab276940d6aaec6bf16fd07d069823115078 Mon Sep 17 00:00:00 2001 From: Naomi Seyfer Date: Wed, 29 Mar 2017 16:05:18 -0700 Subject: [PATCH] Make TypeQuery more general, handling nonboolean queries. Instead of TypeQuery always returning a boolean and having the strategy be an enum, the strategy is now a Callable describing how to combine partial results, and the two default strategies are plain old funcitons. To preserve the short-circuiting behavior of the previous code, this PR uses an exception. This is a pure refactor that I am using in my experimentation regarding fixing https://github.com/python/mypy/issues/1551. It should result in exactly no change to current behavior. It's separable from the other things I'm experimenting with, so I'm filing it as a separate pull request now. It enables me to rewrite the code that pulls type variables out of types as a TypeQuery. Consider waiting to merge this PR until I have some code that uses it ready for review. Or merge it now, if you think it's a pleasant cleanup instead of an ugly complication. I'm of two minds on that particular question. --- mypy/checkexpr.py | 6 +-- mypy/constraints.py | 2 +- mypy/stats.py | 2 +- mypy/types.py | 108 ++++++++++++++++++++++---------------------- 4 files changed, 59 insertions(+), 59 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 51a6f92e9ce64..f6a0a9d716a5e 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -2340,7 +2340,7 @@ def replace_callable_return_type(c: CallableType, new_ret_type: Type) -> Callabl return c.copy_modified(ret_type=new_ret_type) -class ArgInferSecondPassQuery(types.TypeQuery): +class ArgInferSecondPassQuery(types.TypeQuery[bool]): """Query whether an argument type should be inferred in the second pass. The result is True if the type has a type variable in a callable return @@ -2354,7 +2354,7 @@ def visit_callable_type(self, t: CallableType) -> bool: return self.query_types(t.arg_types) or t.accept(HasTypeVarQuery()) -class HasTypeVarQuery(types.TypeQuery): +class HasTypeVarQuery(types.TypeQuery[bool]): """Visitor for querying whether a type has a type variable component.""" def __init__(self) -> None: super().__init__(False, types.ANY_TYPE_STRATEGY) @@ -2367,7 +2367,7 @@ def has_erased_component(t: Type) -> bool: return t is not None and t.accept(HasErasedComponentsQuery()) -class HasErasedComponentsQuery(types.TypeQuery): +class HasErasedComponentsQuery(types.TypeQuery[bool]): """Visitor for querying whether a type has an erased component.""" def __init__(self) -> None: super().__init__(False, types.ANY_TYPE_STRATEGY) diff --git a/mypy/constraints.py b/mypy/constraints.py index d6e44bea857d8..e1c5f47b8f995 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -250,7 +250,7 @@ def is_complete_type(typ: Type) -> bool: return typ.accept(CompleteTypeVisitor()) -class CompleteTypeVisitor(TypeQuery): +class CompleteTypeVisitor(TypeQuery[bool]): def __init__(self) -> None: super().__init__(default=True, strategy=ALL_TYPES_STRATEGY) diff --git a/mypy/stats.py b/mypy/stats.py index 2b809a6d6267e..39c954bfac90a 100644 --- a/mypy/stats.py +++ b/mypy/stats.py @@ -226,7 +226,7 @@ def is_imprecise(t: Type) -> bool: return t.accept(HasAnyQuery()) -class HasAnyQuery(TypeQuery): +class HasAnyQuery(TypeQuery[bool]): def __init__(self) -> None: super().__init__(False, ANY_TYPE_STRATEGY) diff --git a/mypy/types.py b/mypy/types.py index 41ef8ec238f51..54d45a14eaf26 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -5,7 +5,7 @@ from collections import OrderedDict from typing import ( Any, TypeVar, Dict, List, Tuple, cast, Generic, Set, Sequence, Optional, Union, Iterable, - NamedTuple, + NamedTuple, Callable, ) import mypy.nodes @@ -1500,112 +1500,112 @@ def keywords_str(self, a: Iterable[Tuple[str, Type]]) -> str: ]) -# These constants define the method used by TypeQuery to combine multiple -# query results, e.g. for tuple types. The strategy is not used for empty -# result lists; in that case the default value takes precedence. -ANY_TYPE_STRATEGY = 0 # Return True if any of the results are True. -ALL_TYPES_STRATEGY = 1 # Return True if all of the results are True. +# Combination strategies for boolean type queries +def ANY_TYPE_STRATEGY(current: bool, accumulated: bool) -> bool: + """True if any type's result is True""" + if accumulated: + raise ShortCircuitQuery() + return current -class TypeQuery(TypeVisitor[bool]): - """Visitor for performing simple boolean queries of types. +def ALL_TYPES_STRATEGY(current: bool, accumulated: bool) -> bool: + """True if all types' results are True""" + if not accumulated: + raise ShortCircuitQuery() + return current - This class allows defining the default value for leafs to simplify the - implementation of many queries. - """ - default = False # Default result - strategy = 0 # Strategy for combining multiple values (ANY_TYPE_STRATEGY or ALL_TYPES_...). +class ShortCircuitQuery(Exception): + pass - def __init__(self, default: bool, strategy: int) -> None: - """Construct a query visitor. - Use the given default result and strategy for combining - multiple results. The strategy must be either - ANY_TYPE_STRATEGY or ALL_TYPES_STRATEGY. - """ +class TypeQuery(Generic[T], TypeVisitor[T]): + """Visitor for performing queries of types. + + default is used as the query result unless a method for that type is + overridden. + + strategy is used to combine a partial result with a result for a particular + type in a series of types. + + Common use cases involve a boolean query using ANY_TYPE_STRATEGY and a + default of False or ALL_TYPES_STRATEGY and a default of True. + """ + + def __init__(self, default: T, strategy: Callable[[T, T], T]) -> None: self.default = default self.strategy = strategy - def visit_unbound_type(self, t: UnboundType) -> bool: + def visit_unbound_type(self, t: UnboundType) -> T: return self.default - def visit_type_list(self, t: TypeList) -> bool: + def visit_type_list(self, t: TypeList) -> T: return self.default - def visit_error_type(self, t: ErrorType) -> bool: + def visit_error_type(self, t: ErrorType) -> T: return self.default - def visit_any(self, t: AnyType) -> bool: + def visit_any(self, t: AnyType) -> T: return self.default - def visit_uninhabited_type(self, t: UninhabitedType) -> bool: + def visit_uninhabited_type(self, t: UninhabitedType) -> T: return self.default - def visit_none_type(self, t: NoneTyp) -> bool: + def visit_none_type(self, t: NoneTyp) -> T: return self.default - def visit_erased_type(self, t: ErasedType) -> bool: + def visit_erased_type(self, t: ErasedType) -> T: return self.default - def visit_deleted_type(self, t: DeletedType) -> bool: + def visit_deleted_type(self, t: DeletedType) -> T: return self.default - def visit_type_var(self, t: TypeVarType) -> bool: + def visit_type_var(self, t: TypeVarType) -> T: return self.default - def visit_partial_type(self, t: PartialType) -> bool: + def visit_partial_type(self, t: PartialType) -> T: return self.default - def visit_instance(self, t: Instance) -> bool: + def visit_instance(self, t: Instance) -> T: return self.query_types(t.args) - def visit_callable_type(self, t: CallableType) -> bool: + def visit_callable_type(self, t: CallableType) -> T: # FIX generics return self.query_types(t.arg_types + [t.ret_type]) - def visit_tuple_type(self, t: TupleType) -> bool: + def visit_tuple_type(self, t: TupleType) -> T: return self.query_types(t.items) - def visit_typeddict_type(self, t: TypedDictType) -> bool: + def visit_typeddict_type(self, t: TypedDictType) -> T: return self.query_types(t.items.values()) - def visit_star_type(self, t: StarType) -> bool: + def visit_star_type(self, t: StarType) -> T: return t.type.accept(self) - def visit_union_type(self, t: UnionType) -> bool: + def visit_union_type(self, t: UnionType) -> T: return self.query_types(t.items) - def visit_overloaded(self, t: Overloaded) -> bool: + def visit_overloaded(self, t: Overloaded) -> T: return self.query_types(t.items()) - def visit_type_type(self, t: TypeType) -> bool: + def visit_type_type(self, t: TypeType) -> T: return t.item.accept(self) - def query_types(self, types: Iterable[Type]) -> bool: + def query_types(self, types: Iterable[Type]) -> T: """Perform a query for a list of types. - Use the strategy constant to combine the results. + Use the strategy to combine the results. """ if not types: # Use default result for empty list. return self.default - if self.strategy == ANY_TYPE_STRATEGY: - # Return True if at least one component is true. - res = False + res = self.default + try: for t in types: - res = res or t.accept(self) - if res: - break - return res - else: - # Return True if all components are true. - res = True - for t in types: - res = res and t.accept(self) - if not res: - break - return res + res = self.strategy(t.accept(self), res) + except ShortCircuitQuery: + pass + return res def strip_type(typ: Type) -> Type: