From 184e415502add001425ef93fb74464618973879f Mon Sep 17 00:00:00 2001 From: hauntsaninja Date: Tue, 25 Apr 2023 02:09:07 -0600 Subject: [PATCH 1/7] Speed up make_simplified_union, remove a potential crash The following code optimises make_simplified_union in the common case that there are exact duplicates in the union. In this regard, this is similar to #15104. To get this to work, I needed to use partial tuple fallbacks in a couple places (these maybe had the potential to be latent crashes anyway?) There were some interesting things going on with recursive type aliases and type state assumptions. This is about a 25% speedup on the pydantic codebase and about a 2% speedup on self check (measured with uncompiled mypy). --- mypy/subtypes.py | 7 ++- mypy/test/testtypes.py | 2 +- mypy/typeops.py | 121 +++++++++++++---------------------------- 3 files changed, 45 insertions(+), 85 deletions(-) diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 59919456ab5c..88e4c6929aaf 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -439,7 +439,7 @@ def visit_instance(self, left: Instance) -> bool: # dynamic base classes correctly, see #5456. return not isinstance(self.right, NoneType) right = self.right - if isinstance(right, TupleType) and mypy.typeops.tuple_fallback(right).type.is_enum: + if isinstance(right, TupleType) and right.partial_fallback.type.is_enum: return self._is_subtype(left, mypy.typeops.tuple_fallback(right)) if isinstance(right, Instance): if type_state.is_cached_subtype_check(self._subtype_kind, left, right): @@ -753,7 +753,10 @@ def visit_tuple_type(self, left: TupleType) -> bool: # for isinstance(x, tuple), though it's unclear why. 
return True return all(self._is_subtype(li, iter_type) for li in left.items) - elif self._is_subtype(mypy.typeops.tuple_fallback(left), right): + elif ( + self._is_subtype(left.partial_fallback, right) + and self._is_subtype(mypy.typeops.tuple_fallback(left), right) + ): return True return False elif isinstance(right, TupleType): diff --git a/mypy/test/testtypes.py b/mypy/test/testtypes.py index 601cdf27466e..3ac91e078b1c 100644 --- a/mypy/test/testtypes.py +++ b/mypy/test/testtypes.py @@ -613,7 +613,7 @@ def test_simplified_union_with_mixed_str_literals(self) -> None: ) self.assert_simplified_union( [fx.lit_str1, fx.lit_str1, fx.lit_str1_inst], - UnionType([fx.lit_str1, fx.lit_str1_inst]), + fx.lit_str1, ) def assert_simplified_union(self, original: list[Type], union: Type) -> None: diff --git a/mypy/typeops.py b/mypy/typeops.py index 8ed59b6fbe55..fb00d4363a54 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -385,25 +385,6 @@ def callable_corresponding_argument( return by_name if by_name is not None else by_pos -def simple_literal_value_key(t: ProperType) -> tuple[str, ...] | None: - """Return a hashable description of simple literal type. - - Return None if not a simple literal type. - - The return value can be used to simplify away duplicate types in - unions by comparing keys for equality. For now enum, string or - Instance with string last_known_value are supported. 
- """ - if isinstance(t, LiteralType): - if t.fallback.type.is_enum or t.fallback.type.fullname == "builtins.str": - assert isinstance(t.value, str) - return "literal", t.value, t.fallback.type.fullname - if isinstance(t, Instance): - if t.last_known_value is not None and isinstance(t.last_known_value.value, str): - return "instance", t.last_known_value.value, t.type.fullname - return None - - def simple_literal_type(t: ProperType | None) -> Instance | None: """Extract the underlying fallback Instance type for a simple Literal""" if isinstance(t, Instance) and t.last_known_value is not None: @@ -414,7 +395,6 @@ def simple_literal_type(t: ProperType | None) -> Instance | None: def is_simple_literal(t: ProperType) -> bool: - """Fast way to check if simple_literal_value_key() would return a non-None value.""" if isinstance(t, LiteralType): return t.fallback.type.is_enum or t.fallback.type.fullname == "builtins.str" if isinstance(t, Instance): @@ -500,68 +480,45 @@ def make_simplified_union( def _remove_redundant_union_items(items: list[Type], keep_erased: bool) -> list[Type]: from mypy.subtypes import is_proper_subtype - removed: set[int] = set() - seen: set[tuple[str, ...]] = set() - - # NB: having a separate fast path for Union of Literal and slow path for other things - # would arguably be cleaner, however it breaks down when simplifying the Union of two - # different enum types as try_expanding_sum_type_to_union works recursively and will - # trigger intermediate simplifications that would render the fast path useless - for i, item in enumerate(items): - proper_item = get_proper_type(item) - if i in removed: - continue - # Avoid slow nested for loop for Union of Literal of strings/enums (issue #9169) - k = simple_literal_value_key(proper_item) - if k is not None: - if k in seen: - removed.add(i) - continue - - # NB: one would naively expect that it would be safe to skip the slow path - # always for literals. One would be sorely mistaken. 
Indeed, some simplifications - # such as that of None/Optional when strict optional is false, do require that we - # proceed with the slow path. Thankfully, all literals will have the same subtype - # relationship to non-literal types, so we only need to do that walk for the first - # literal, which keeps the fast path fast even in the presence of a mixture of - # literals and other types. - safe_skip = len(seen) > 0 - seen.add(k) - if safe_skip: - continue - - # Keep track of the truthiness info for deleted subtypes which can be relevant - cbt = cbf = False - for j, tj in enumerate(items): - proper_tj = get_proper_type(tj) - if ( - i == j - # avoid further checks if this item was already marked redundant. - or j in removed - # if the current item is a simple literal then this simplification loop can - # safely skip all other simple literals as two literals will only ever be - # subtypes of each other if they are equal, which is already handled above. - # However, if the current item is not a literal, it might plausibly be a - # supertype of other literals in the union, so we must check them again. - # This is an important optimization as is_proper_subtype is pretty expensive. - or (k is not None and is_simple_literal(proper_tj)) - ): - continue - # actual redundancy checks (XXX?) - if is_redundant_literal_instance(proper_item, proper_tj) and is_proper_subtype( - tj, item, keep_erased_types=keep_erased, ignore_promotions=True - ): - # We found a redundant item in the union. - removed.add(j) - cbt = cbt or tj.can_be_true - cbf = cbf or tj.can_be_false - # if deleted subtypes had more general truthiness, use that - if not item.can_be_true and cbt: - items[i] = true_or_false(item) - elif not item.can_be_false and cbf: - items[i] = true_or_false(item) - - return [items[i] for i in range(len(items)) if i not in removed] + # The first pass through this loop, we check if later items are subtypes of earlier items. 
+ # The second pass through this loop, we check if earlier items are subtypes of later items + # (by reversing the remaining items) + for _direction in range(2): + new_items: list[Type] = [] + # seen is a map from a type to its index in new_items + seen: dict[ProperType, int] = {} + for ti in items: + proper_ti = get_proper_type(ti) + + duplicate_index = -1 + # Quickly check if we've seen this type + if proper_ti in seen: + duplicate_index = seen[proper_ti] + else: + # If not, check if we've seen a supertype of this type + for j, tj in enumerate(new_items): + tj = get_proper_type(tj) + if is_redundant_literal_instance(tj, proper_ti) and is_proper_subtype( + proper_ti, tj, keep_erased_types=keep_erased, ignore_promotions=True + ): + duplicate_index = j + break + if duplicate_index != -1: + # If deleted subtypes had more general truthiness, use that + orig_item = new_items[duplicate_index] + if not orig_item.can_be_true and ti.can_be_true: + new_items[duplicate_index] = true_or_false(orig_item) + elif not orig_item.can_be_false and ti.can_be_false: + new_items[duplicate_index] = true_or_false(orig_item) + else: + # We have a non-duplicate item, add it to new_items + seen[proper_ti] = len(new_items) + new_items.append(ti) + + items = new_items + items.reverse() + + return items def _get_type_special_method_bool_ret_type(t: Type) -> Type | None: From cd5f40bca7f450db88a70188db5f61ea72c6e8d5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 25 Apr 2023 16:26:47 +0000 Subject: [PATCH 2/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mypy/subtypes.py | 5 ++--- mypy/test/testtypes.py | 5 +---- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 88e4c6929aaf..94b7e07fc2ba 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -753,9 +753,8 @@ def visit_tuple_type(self, left: TupleType) -> bool: # 
for isinstance(x, tuple), though it's unclear why. return True return all(self._is_subtype(li, iter_type) for li in left.items) - elif ( - self._is_subtype(left.partial_fallback, right) - and self._is_subtype(mypy.typeops.tuple_fallback(left), right) + elif self._is_subtype(left.partial_fallback, right) and self._is_subtype( + mypy.typeops.tuple_fallback(left), right ): return True return False diff --git a/mypy/test/testtypes.py b/mypy/test/testtypes.py index 3ac91e078b1c..6621c14eacf8 100644 --- a/mypy/test/testtypes.py +++ b/mypy/test/testtypes.py @@ -611,10 +611,7 @@ def test_simplified_union_with_mixed_str_literals(self) -> None: [fx.lit_str1, fx.lit_str2, fx.lit_str3_inst], UnionType([fx.lit_str1, fx.lit_str2, fx.lit_str3_inst]), ) - self.assert_simplified_union( - [fx.lit_str1, fx.lit_str1, fx.lit_str1_inst], - fx.lit_str1, - ) + self.assert_simplified_union([fx.lit_str1, fx.lit_str1, fx.lit_str1_inst], fx.lit_str1) def assert_simplified_union(self, original: list[Type], union: Type) -> None: assert_equal(make_simplified_union(original), union) From b0094833435f29a933d33319b8182e6c06947236 Mon Sep 17 00:00:00 2001 From: hauntsaninja Date: Tue, 25 Apr 2023 14:34:33 -0600 Subject: [PATCH 3/7] fix performance on materialize --- mypy/typeops.py | 51 ++++++++++++++++++++++++++++++++++--------------- 1 file changed, 36 insertions(+), 15 deletions(-) diff --git a/mypy/typeops.py b/mypy/typeops.py index fb00d4363a54..525feff5b836 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -480,6 +480,20 @@ def make_simplified_union( def _remove_redundant_union_items(items: list[Type], keep_erased: bool) -> list[Type]: from mypy.subtypes import is_proper_subtype + # As an optimisation, sort so that we check is_proper_subtype against non literal types first + # test_simplify_very_large_union should pass quickly + if len(items) > 5: + literals = [] + others = [] + for item in items: + # Ignore proper type error, since this is just a speed optimisation + if isinstance(item, 
LiteralType): # type: ignore[misc] + literals.append(item) + else: + others.append(item) + items = others + items.extend(literals) + # The first pass through this loop, we check if later items are subtypes of earlier items. # The second pass through this loop, we check if earlier items are subtypes of later items # (by reversing the remaining items) @@ -490,6 +504,10 @@ def _remove_redundant_union_items(items: list[Type], keep_erased: bool) -> list[ for ti in items: proper_ti = get_proper_type(ti) + # UninhabitedType is always redundant + if isinstance(proper_ti, UninhabitedType): + continue + duplicate_index = -1 # Quickly check if we've seen this type if proper_ti in seen: @@ -497,8 +515,24 @@ def _remove_redundant_union_items(items: list[Type], keep_erased: bool) -> list[ else: # If not, check if we've seen a supertype of this type for j, tj in enumerate(new_items): + # An optimisation: LiteralTypes are only subtypes if they're equal, which we + # checked in the fast path above with seen + if isinstance(proper_ti, LiteralType) and isinstance(tj, LiteralType): # type: ignore[misc] + continue tj = get_proper_type(tj) - if is_redundant_literal_instance(tj, proper_ti) and is_proper_subtype( + # If tj is an Instance with a last_known_value, do not remove proper_ti + # (unless it's an instance with the same last_known_value) + if ( + isinstance(tj, Instance) + and tj.last_known_value is not None + and not ( + isinstance(proper_ti, Instance) + and tj.last_known_value == proper_ti.last_known_value + ) + ): + continue + + if is_proper_subtype( proper_ti, tj, keep_erased_types=keep_erased, ignore_promotions=True ): duplicate_index = j @@ -660,9 +694,7 @@ def function_type(func: FuncBase, fallback: Instance) -> FunctionLike: return Overloaded([dummy]) -def callable_type( - fdef: FuncItem, fallback: Instance, ret_type: Type | None = None -) -> CallableType: +def callable_type(fdef: FuncItem, fallback: Instance, ret_type: Type | None = None) -> CallableType: # TODO: somewhat 
unfortunate duplication with prepare_method_signature in semanal if fdef.info and not fdef.is_static and fdef.arg_names: self_type: Type = fill_typevars(fdef.info) @@ -949,17 +981,6 @@ def custom_special_method(typ: Type, name: str, check_all: bool = False) -> bool return False -def is_redundant_literal_instance(general: ProperType, specific: ProperType) -> bool: - if not isinstance(general, Instance) or general.last_known_value is None: - return True - if isinstance(specific, Instance) and specific.last_known_value == general.last_known_value: - return True - if isinstance(specific, UninhabitedType): - return True - - return False - - def separate_union_literals(t: UnionType) -> tuple[Sequence[LiteralType], Sequence[Type]]: """Separate literals from other members in a union type.""" literal_items = [] From bea67d29c7b7e002658d3e44441147e75c178aab Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 25 Apr 2023 21:39:57 +0000 Subject: [PATCH 4/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mypy/typeops.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mypy/typeops.py b/mypy/typeops.py index 525feff5b836..d6f659c959ed 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -694,7 +694,9 @@ def function_type(func: FuncBase, fallback: Instance) -> FunctionLike: return Overloaded([dummy]) -def callable_type(fdef: FuncItem, fallback: Instance, ret_type: Type | None = None) -> CallableType: +def callable_type( + fdef: FuncItem, fallback: Instance, ret_type: Type | None = None +) -> CallableType: # TODO: somewhat unfortunate duplication with prepare_method_signature in semanal if fdef.info and not fdef.is_static and fdef.arg_names: self_type: Type = fill_typevars(fdef.info) From fd8463b7a98ae74378b742066ed8fa50848d8903 Mon Sep 17 00:00:00 2001 From: hauntsaninja Date: Tue, 25 Apr 2023 17:37:19 -0600 Subject: [PATCH 5/7] more 
perf --- mypy/typeops.py | 32 ++++++++++++++------------------ 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/mypy/typeops.py b/mypy/typeops.py index d6f659c959ed..44d865b40b2b 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -480,20 +480,6 @@ def make_simplified_union( def _remove_redundant_union_items(items: list[Type], keep_erased: bool) -> list[Type]: from mypy.subtypes import is_proper_subtype - # As an optimisation, sort so that we check is_proper_subtype against non literal types first - # test_simplify_very_large_union should pass quickly - if len(items) > 5: - literals = [] - others = [] - for item in items: - # Ignore proper type error, since this is just a speed optimisation - if isinstance(item, LiteralType): # type: ignore[misc] - literals.append(item) - else: - others.append(item) - items = others - items.extend(literals) - # The first pass through this loop, we check if later items are subtypes of earlier items. # The second pass through this loop, we check if earlier items are subtypes of later items # (by reversing the remaining items) @@ -501,6 +487,7 @@ def _remove_redundant_union_items(items: list[Type], keep_erased: bool) -> list[ new_items: list[Type] = [] # seen is a map from a type to its index in new_items seen: dict[ProperType, int] = {} + unduplicated_literal_fallbacks: set[Instance] = set() for ti in items: proper_ti = get_proper_type(ti) @@ -512,13 +499,20 @@ def _remove_redundant_union_items(items: list[Type], keep_erased: bool) -> list[ # Quickly check if we've seen this type if proper_ti in seen: duplicate_index = seen[proper_ti] + elif ( + isinstance(proper_ti, LiteralType) + and proper_ti.fallback in unduplicated_literal_fallbacks + ): + # This is an optimisation for unions with many LiteralType + # We've already checked for exact duplicates. This means that any super type of + # the LiteralType must be a super type of its fallback. 
If we've gone through + # the expensive loop below and found no super type for a previous LiteralType + # with the same fallback, we can skip doing that work again and just add the type + # to new_items + pass else: # If not, check if we've seen a supertype of this type for j, tj in enumerate(new_items): - # An optimisation: LiteralTypes are only subtypes if they're equal, which we - # checked in the fast path above with seen - if isinstance(proper_ti, LiteralType) and isinstance(tj, LiteralType): # type: ignore[misc] - continue tj = get_proper_type(tj) # If tj is an Instance with a last_known_value, do not remove proper_ti # (unless it's an instance with the same last_known_value) @@ -548,6 +542,8 @@ def _remove_redundant_union_items(items: list[Type], keep_erased: bool) -> list[ # We have a non-duplicate item, add it to new_items seen[proper_ti] = len(new_items) new_items.append(ti) + if isinstance(proper_ti, LiteralType): + unduplicated_literal_fallbacks.add(proper_ti.fallback) items = new_items items.reverse() From b13266362ebccea7f78c12453bf98cbac30373a2 Mon Sep 17 00:00:00 2001 From: hauntsaninja Date: Fri, 5 May 2023 13:30:48 -0700 Subject: [PATCH 6/7] add regression test for crash --- test-data/unit/check-type-aliases.test | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/test-data/unit/check-type-aliases.test b/test-data/unit/check-type-aliases.test index 9dd56ad309f3..05a03ecaf7b0 100644 --- a/test-data/unit/check-type-aliases.test +++ b/test-data/unit/check-type-aliases.test @@ -1043,3 +1043,19 @@ class C(Generic[T]): def test(cls) -> None: cls.attr [builtins fixtures/classmethod.pyi] + +[case testRecursiveAliasTuple] +from typing_extensions import Literal, TypeAlias +from typing import Tuple, Union + +Expr: TypeAlias = Union[ + Tuple[Literal[123], int], + Tuple[Literal[456], "Expr"], +] + +def eval(e: Expr) -> int: + if e[0] == 123: + return e[1] + elif e[0] == 456: + return -eval(e[1]) +[builtins fixtures/dict.pyi] From 
bc79133751238230eb77fe54fff157026134a973 Mon Sep 17 00:00:00 2001 From: hauntsaninja Date: Fri, 5 May 2023 13:32:37 -0700 Subject: [PATCH 7/7] micro optimisations --- mypy/typeops.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/mypy/typeops.py b/mypy/typeops.py index 44d865b40b2b..a0976ee41617 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -487,7 +487,7 @@ def _remove_redundant_union_items(items: list[Type], keep_erased: bool) -> list[ new_items: list[Type] = [] # seen is a map from a type to its index in new_items seen: dict[ProperType, int] = {} - unduplicated_literal_fallbacks: set[Instance] = set() + unduplicated_literal_fallbacks: set[Instance] | None = None for ti in items: proper_ti = get_proper_type(ti) @@ -501,6 +501,7 @@ def _remove_redundant_union_items(items: list[Type], keep_erased: bool) -> list[ duplicate_index = seen[proper_ti] elif ( isinstance(proper_ti, LiteralType) + and unduplicated_literal_fallbacks is not None and proper_ti.fallback in unduplicated_literal_fallbacks ): # This is an optimisation for unions with many LiteralType @@ -543,9 +544,13 @@ def _remove_redundant_union_items(items: list[Type], keep_erased: bool) -> list[ seen[proper_ti] = len(new_items) new_items.append(ti) if isinstance(proper_ti, LiteralType): + if unduplicated_literal_fallbacks is None: + unduplicated_literal_fallbacks = set() unduplicated_literal_fallbacks.add(proper_ti.fallback) items = new_items + if len(items) <= 1: + break items.reverse() return items