diff --git a/mypyc/irbuild/builder.py b/mypyc/irbuild/builder.py index 0757415f67536..573ca334a5d15 100644 --- a/mypyc/irbuild/builder.py +++ b/mypyc/irbuild/builder.py @@ -165,6 +165,9 @@ def __init__( self.runtime_args: list[list[RuntimeArg]] = [[]] self.function_name_stack: list[str] = [] self.class_ir_stack: list[ClassIR] = [] + # Keep track of whether the next statement in a block is reachable + # or not, separately for each block nesting level + self.block_reachable_stack: list[bool] = [True] self.current_module = current_module self.mapper = mapper @@ -1302,6 +1305,14 @@ def is_native_attr_ref(self, expr: MemberExpr) -> bool: and not obj_rtype.class_ir.get_method(expr.name) ) + def mark_block_unreachable(self) -> None: + """Mark statements in the innermost block being processed as unreachable. + + This should be called after a statement that unconditionally leaves the + block, such as 'break' or 'return'. + """ + self.block_reachable_stack[-1] = False + # Lacks a good type because there wasn't a reasonable type in 3.5 :( def catch_errors(self, line: int) -> Any: return catch_errors(self.module_path, line) diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index d7e01456139dd..2c17eb2bb14d0 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -118,8 +118,13 @@ def transform_block(builder: IRBuilder, block: Block) -> None: if not block.is_unreachable: + builder.block_reachable_stack.append(True) for stmt in block.body: builder.accept(stmt) + if not builder.block_reachable_stack[-1]: + # The rest of the block is unreachable, so skip it + break + builder.block_reachable_stack.pop() # Raise a RuntimeError if we hit a non-empty unreachable block. # Don't complain about empty unreachable blocks, since mypy inserts # those after `if MYPY`. diff --git a/mypyc/irbuild/visitor.py b/mypyc/irbuild/visitor.py index d8725ee04dc5c..12e186fd40d8a 100644 --- a/mypyc/irbuild/visitor.py +++ b/mypyc/irbuild/visitor.py @@ -194,6 +194,7 @@ def visit_expression_stmt(self, stmt: ExpressionStmt) -> None: def visit_return_stmt(self, stmt: ReturnStmt) -> None: transform_return_stmt(self.builder, stmt) + self.builder.mark_block_unreachable() def visit_assignment_stmt(self, stmt: AssignmentStmt) -> None: transform_assignment_stmt(self.builder, stmt) @@ -212,12 +213,15 @@ def visit_for_stmt(self, stmt: ForStmt) -> None: def visit_break_stmt(self, stmt: BreakStmt) -> None: transform_break_stmt(self.builder, stmt) + self.builder.mark_block_unreachable() def visit_continue_stmt(self, stmt: ContinueStmt) -> None: transform_continue_stmt(self.builder, stmt) + self.builder.mark_block_unreachable() def visit_raise_stmt(self, stmt: RaiseStmt) -> None: transform_raise_stmt(self.builder, stmt) + self.builder.mark_block_unreachable() def visit_try_stmt(self, stmt: TryStmt) -> None: transform_try_stmt(self.builder, stmt) diff --git a/mypyc/test-data/irbuild-unreachable.test b/mypyc/test-data/irbuild-unreachable.test index 1c024a249bf1c..b5188c91ac580 100644 --- a/mypyc/test-data/irbuild-unreachable.test +++ b/mypyc/test-data/irbuild-unreachable.test @@ -1,4 +1,4 @@ -# Test cases for unreachable expressions +# Test cases for unreachable expressions and statements [case testUnreachableMemberExpr] import sys @@ -104,3 +104,138 @@ L5: L6: y = r11 return 1 + +[case testUnreachableStatementAfterReturn] +def f(x: bool) -> int: + if x: + return 1 + f(False) + return 2 +[out] +def f(x): + x :: bool +L0: + if x goto L1 else goto L2 :: bool +L1: + return 2 +L2: + return 4 + +[case testUnreachableStatementAfterContinue] +def c() -> bool: + return False + +def f() -> None: + n = True + while n: + if c(): + continue + if int(): + f() + n = False +[out] +def c(): +L0: + return 0 +def f(): + n, r0 :: bool +L0: + n = 1 +L1: + if n goto L2 else goto L5 :: bool +L2: + r0 = c() + if r0 goto L3 else goto L4 :: bool +L3: + goto L1 +L4: + n = 0 + goto L1 +L5: + return 1 + +[case testUnreachableStatementAfterBreak] +def c() -> bool: + return False + +def f() -> None: + n = True + while n: + if c(): + break + if int(): + f() + n = False +[out] +def c(): +L0: + return 0 +def f(): + n, r0 :: bool +L0: + n = 1 +L1: + if n goto L2 else goto L5 :: bool +L2: + r0 = c() + if r0 goto L3 else goto L4 :: bool +L3: + goto L5 +L4: + n = 0 + goto L1 +L5: + return 1 + +[case testUnreachableStatementAfterRaise] +def f(x: bool) -> int: + if x: + raise ValueError() + print('hello') + return 2 +[out] +def f(x): + x :: bool + r0 :: object + r1 :: str + r2, r3 :: object +L0: + if x goto L1 else goto L2 :: bool +L1: + r0 = builtins :: module + r1 = 'ValueError' + r2 = CPyObject_GetAttr(r0, r1) + r3 = PyObject_CallFunctionObjArgs(r2, 0) + CPy_Raise(r3) + unreachable +L2: + return 4 + +[case testUnreachableStatementAfterAssertFalse] +def f(x: bool) -> int: + if x: + assert False + print('hello') + return 2 +[out] +def f(x): + x, r0 :: bool + r1 :: str + r2 :: object + r3 :: str + r4, r5 :: object +L0: + if x goto L1 else goto L4 :: bool +L1: + if 0 goto L3 else goto L2 :: bool +L2: + r0 = raise AssertionError + unreachable +L3: + r1 = 'hello' + r2 = builtins :: module + r3 = 'print' + r4 = CPyObject_GetAttr(r2, r3) + r5 = PyObject_CallFunctionObjArgs(r4, r1, 0) +L4: + return 4