diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index bfbe961adc7a..0149f1971477 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -69,7 +69,7 @@ try_expanding_sum_type_to_union, tuple_fallback, make_simplified_union, true_only, false_only, erase_to_union_or_bound, function_type, callable_type, try_getting_str_literals, custom_special_method, - is_literal_type_like, + is_literal_type_like, simple_literal_type, ) from mypy.message_registry import ErrorMessage import mypy.errorcodes as codes @@ -3874,26 +3874,43 @@ def visit_conditional_expr(self, e: ConditionalExpr, allow_none_return: bool = F if_type = self.analyze_cond_branch(if_map, e.if_expr, context=ctx, allow_none_return=allow_none_return) + # we want to keep the narrowest value of if_type for union'ing the branches + # however, it would be silly to pass a literal as a type context. Pass the + # underlying fallback type instead. + if_type_fallback = simple_literal_type(get_proper_type(if_type)) or if_type + # Analyze the right branch using full type context and store the type full_context_else_type = self.analyze_cond_branch(else_map, e.else_expr, context=ctx, allow_none_return=allow_none_return) + if not mypy.checker.is_valid_inferred_type(if_type): # Analyze the right branch disregarding the left branch. else_type = full_context_else_type + # we want to keep the narrowest value of else_type for union'ing the branches + # however, it would be silly to pass a literal as a type context. Pass the + # underlying fallback type instead. + else_type_fallback = simple_literal_type(get_proper_type(else_type)) or else_type # If it would make a difference, re-analyze the left # branch using the right branch's type as context. - if ctx is None or not is_equivalent(else_type, ctx): + if ctx is None or not is_equivalent(else_type_fallback, ctx): # TODO: If it's possible that the previous analysis of # the left branch produced errors that are avoided # using this context, suppress those errors. - if_type = self.analyze_cond_branch(if_map, e.if_expr, context=else_type, + if_type = self.analyze_cond_branch(if_map, e.if_expr, context=else_type_fallback, allow_none_return=allow_none_return) + elif if_type_fallback == ctx: + # There is no point re-running the analysis if if_type is equal to ctx. + # That would be an exact duplicate of the work we just did. + # This optimization is particularly important to avoid exponential blowup with nested + # if/else expressions: https://github.com/python/mypy/issues/9591 + # TODO: would checking for is_proper_subtype also work and cover more cases? + else_type = full_context_else_type else: # Analyze the right branch in the context of the left # branch's type. - else_type = self.analyze_cond_branch(else_map, e.else_expr, context=if_type, + else_type = self.analyze_cond_branch(else_map, e.else_expr, context=if_type_fallback, allow_none_return=allow_none_return) # Only create a union type if the type context is a union, to be mostly diff --git a/mypy/typeops.py b/mypy/typeops.py index e2e44b915c0c..32ec5b6e7114 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -318,7 +318,7 @@ def simple_literal_value_key(t: ProperType) -> Optional[Tuple[str, ...]]: return None -def simple_literal_type(t: ProperType) -> Optional[Instance]: +def simple_literal_type(t: Optional[ProperType]) -> Optional[Instance]: """Extract the underlying fallback Instance type for a simple Literal""" if isinstance(t, Instance) and t.last_known_value is not None: t = t.last_known_value