diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 0bafb2298eae..027c0d85a854 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1497,6 +1497,11 @@ def check_boolean_op(self, e: OpExpr, context: Context) -> Type: restricted_left_type = true_only(left_type) result_is_left = not left_type.can_be_false + if e.right_unreachable: + right_map = None + elif e.right_always: + left_map = None + right_type = self.analyze_cond_branch(right_map, e.right, left_type) if right_map is None: diff --git a/mypy/nodes.py b/mypy/nodes.py index 48715662d9d6..cfd74b69f577 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -1440,6 +1440,10 @@ class OpExpr(Expression): right = None # type: Expression # Inferred type for the operator method type (when relevant). method_type = None # type: Optional[mypy.types.Type] + # Is the right side going to be evaluated every time? + right_always = False + # Is the right side unreachable? + right_unreachable = False def __init__(self, op: str, left: Expression, right: Expression) -> None: self.op = op diff --git a/mypy/semanal.py b/mypy/semanal.py index b46fbc1e736f..523edc8563e0 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -3064,6 +3064,19 @@ def visit_member_expr(self, expr: MemberExpr) -> None: def visit_op_expr(self, expr: OpExpr) -> None: expr.left.accept(self) + + if expr.op in ('and', 'or'): + inferred = infer_condition_value(expr.left, + pyversion=self.options.python_version, + platform=self.options.platform) + if ((inferred == ALWAYS_FALSE and expr.op == 'and') or + (inferred == ALWAYS_TRUE and expr.op == 'or')): + expr.right_unreachable = True + return + elif ((inferred == ALWAYS_TRUE and expr.op == 'and') or + (inferred == ALWAYS_FALSE and expr.op == 'or')): + expr.right_always = True + expr.right.accept(self) def visit_comparison_expr(self, expr: ComparisonExpr) -> None: @@ -3986,7 +3999,7 @@ def infer_reachability_of_if_statement(s: IfStmt, pyversion: Tuple[int, int], platform: str) -> None: for i in range(len(s.expr)): - result = infer_if_condition_value(s.expr[i], pyversion, platform) + result = infer_condition_value(s.expr[i], pyversion, platform) if result in (ALWAYS_FALSE, MYPY_FALSE): # The condition is considered always false, so we skip the if/elif body. mark_block_unreachable(s.body[i]) @@ -4004,8 +4017,8 @@ def infer_reachability_of_if_statement(s: IfStmt, break -def infer_if_condition_value(expr: Expression, pyversion: Tuple[int, int], platform: str) -> int: - """Infer whether if condition is always true/false. +def infer_condition_value(expr: Expression, pyversion: Tuple[int, int], platform: str) -> int: + """Infer whether the given condition is always true/false. Return ALWAYS_TRUE if always true, ALWAYS_FALSE if always false, MYPY_TRUE if true under mypy and false at runtime, MYPY_FALSE if @@ -4023,6 +4036,17 @@ def infer_if_condition_value(expr: Expression, pyversion: Tuple[int, int], platf name = expr.name elif isinstance(expr, MemberExpr): name = expr.name + elif isinstance(expr, OpExpr) and expr.op in ('and', 'or'): + left = infer_condition_value(expr.left, pyversion, platform) + if ((left == ALWAYS_TRUE and expr.op == 'and') or + (left == ALWAYS_FALSE and expr.op == 'or')): + # Either `True and ` or `False or `: the result will + # always be the right-hand-side. + return infer_condition_value(expr.right, pyversion, platform) + else: + # The result will always be the left-hand-side (e.g. ALWAYS_* or + # TRUTH_VALUE_UNKNOWN). + return left else: result = consider_sys_version_info(expr, pyversion) if result == TRUTH_VALUE_UNKNOWN: diff --git a/test-data/unit/check-unreachable-code.test b/test-data/unit/check-unreachable-code.test index 75e6f88c0c35..fa29290dbbf4 100644 --- a/test-data/unit/check-unreachable-code.test +++ b/test-data/unit/check-unreachable-code.test @@ -457,3 +457,63 @@ else: reveal_type(x) # E: Revealed type is 'builtins.str' [builtins fixtures/ops.pyi] [out] + +[case testShortCircuitInExpression] +import typing +def make() -> bool: pass +PY2 = PY3 = make() + +a = PY2 and 's' +b = PY3 and 's' +c = PY2 or 's' +d = PY3 or 's' +e = (PY2 or PY3) and 's' +f = (PY3 or PY2) and 's' +g = (PY2 or PY3) or 's' +h = (PY3 or PY2) or 's' +reveal_type(a) # E: Revealed type is 'builtins.bool' +reveal_type(b) # E: Revealed type is 'builtins.str' +reveal_type(c) # E: Revealed type is 'builtins.str' +reveal_type(d) # E: Revealed type is 'builtins.bool' +reveal_type(e) # E: Revealed type is 'builtins.str' +reveal_type(f) # E: Revealed type is 'builtins.str' +reveal_type(g) # E: Revealed type is 'builtins.bool' +reveal_type(h) # E: Revealed type is 'builtins.bool' +[builtins fixtures/ops.pyi] +[out] + +[case testShortCircuitAndWithConditionalAssignment] +# flags: --platform linux +import sys + +def f(): pass +PY2 = f() +if PY2 and sys.platform == 'linux': + x = 'foo' +else: + x = 3 +reveal_type(x) # E: Revealed type is 'builtins.int' +if sys.platform == 'linux' and PY2: + y = 'foo' +else: + y = 3 +reveal_type(y) # E: Revealed type is 'builtins.int' +[builtins fixtures/ops.pyi] + +[case testShortCircuitOrWithConditionalAssignment] +# flags: --platform linux +import sys + +def f(): pass +PY2 = f() +if PY2 or sys.platform == 'linux': + x = 'foo' +else: + x = 3 +reveal_type(x) # E: Revealed type is 'builtins.str' +if sys.platform == 'linux' or PY2: + y = 'foo' +else: + y = 3 +reveal_type(y) # E: Revealed type is 'builtins.str' +[builtins fixtures/ops.pyi] diff --git a/test-data/unit/fixtures/ops.pyi b/test-data/unit/fixtures/ops.pyi index 8e18aeae2afd..ae48a1f2019a 100644 --- a/test-data/unit/fixtures/ops.pyi +++ b/test-data/unit/fixtures/ops.pyi @@ -29,6 +29,7 @@ class bool: pass class str: def __init__(self, x: 'int') -> None: pass def __add__(self, x: 'str') -> 'str': pass + def __eq__(self, x: object) -> bool: pass def startswith(self, x: 'str') -> bool: pass class unicode: pass