Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion mypyc/codegen/emitmodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
from mypyc.options import CompilerOptions
from mypyc.transform.copy_propagation import do_copy_propagation
from mypyc.transform.exceptions import insert_exception_handling
from mypyc.transform.flag_elimination import do_flag_elimination
from mypyc.transform.refcount import insert_ref_count_opcodes
from mypyc.transform.uninit import insert_uninit_checks

Expand Down Expand Up @@ -234,8 +235,9 @@ def compile_scc_to_ir(
insert_exception_handling(fn)
# Insert refcount handling.
insert_ref_count_opcodes(fn)
# Perform copy propagation optimization.
# Perform optimizations.
do_copy_propagation(fn, compiler_options)
do_flag_elimination(fn, compiler_options)

return modules

Expand Down
300 changes: 300 additions & 0 deletions mypyc/test-data/opt-flag-elimination.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,300 @@
-- Test cases for "flag elimination" optimization. Used to optimize away
-- registers that are always used immediately after assignment as branch conditions.

[case testFlagEliminationSimple]
def c() -> bool:
return True
def d() -> bool:
return True

def f(x: bool) -> int:
if x:
b = c()
else:
b = d()
if b:
return 1
else:
return 2
[out]
def c():
L0:
return 1
def d():
L0:
return 1
def f(x):
x, r0, r1 :: bool
L0:
if x goto L1 else goto L2 :: bool
L1:
r0 = c()
if r0 goto L4 else goto L5 :: bool
L2:
r1 = d()
if r1 goto L4 else goto L5 :: bool
L3:
unreachable
L4:
return 2
L5:
return 4

[case testFlagEliminationOneAssignment]
def c() -> bool:
return True

def f(x: bool) -> int:
# Not applied here
b = c()
if b:
return 1
else:
return 2
[out]
def c():
L0:
return 1
def f(x):
x, r0, b :: bool
L0:
r0 = c()
b = r0
if b goto L1 else goto L2 :: bool
L1:
return 2
L2:
return 4

[case testFlagEliminationThreeCases]
def c(x: int) -> bool:
return True

def f(x: bool, y: bool) -> int:
if x:
b = c(1)
elif y:
b = c(2)
else:
b = c(3)
if b:
return 1
else:
return 2
[out]
def c(x):
x :: int
L0:
return 1
def f(x, y):
x, y, r0, r1, r2 :: bool
L0:
if x goto L1 else goto L2 :: bool
L1:
r0 = c(2)
if r0 goto L6 else goto L7 :: bool
L2:
if y goto L3 else goto L4 :: bool
L3:
r1 = c(4)
if r1 goto L6 else goto L7 :: bool
L4:
r2 = c(6)
if r2 goto L6 else goto L7 :: bool
L5:
unreachable
L6:
return 2
L7:
return 4

[case testFlagEliminationAssignmentNotLastOp]
def f(x: bool) -> int:
y = 0
if x:
b = True
y = 1
else:
b = False
if b:
return 1
else:
return 2
[out]
def f(x):
x :: bool
y :: int
b :: bool
L0:
y = 0
if x goto L1 else goto L2 :: bool
L1:
b = 1
y = 2
goto L3
L2:
b = 0
L3:
if b goto L4 else goto L5 :: bool
L4:
return 2
L5:
return 4

[case testFlagEliminationAssignmentNoDirectGoto]
def f(x: bool) -> int:
if x:
b = True
else:
b = False
if x:
if b:
return 1
else:
return 2
return 4
[out]
def f(x):
x, b :: bool
L0:
if x goto L1 else goto L2 :: bool
L1:
b = 1
goto L3
L2:
b = 0
L3:
if x goto L4 else goto L7 :: bool
L4:
if b goto L5 else goto L6 :: bool
L5:
return 2
L6:
return 4
L7:
return 8

[case testFlagEliminationBranchNotNextOpAfterGoto]
def f(x: bool) -> int:
if x:
b = True
else:
b = False
y = 1 # Prevents the optimization
if b:
return 1
else:
return 2
[out]
def f(x):
x, b :: bool
y :: int
L0:
if x goto L1 else goto L2 :: bool
L1:
b = 1
goto L3
L2:
b = 0
L3:
y = 2
if b goto L4 else goto L5 :: bool
L4:
return 2
L5:
return 4

[case testFlagEliminationFlagReadTwice]
def f(x: bool) -> bool:
if x:
b = True
else:
b = False
if b:
return b # Prevents the optimization
else:
return False
[out]
def f(x):
x, b :: bool
L0:
if x goto L1 else goto L2 :: bool
L1:
b = 1
goto L3
L2:
b = 0
L3:
if b goto L4 else goto L5 :: bool
L4:
return b
L5:
return 0

[case testFlagEliminationArgumentNotEligible]
def f(x: bool, b: bool) -> bool:
if x:
b = True
else:
b = False
if b:
return True
else:
return False
[out]
def f(x, b):
x, b :: bool
L0:
if x goto L1 else goto L2 :: bool
L1:
b = 1
goto L3
L2:
b = 0
L3:
if b goto L4 else goto L5 :: bool
L4:
return 1
L5:
return 0

[case testFlagEliminationFlagNotAlwaysDefined]
def f(x: bool, y: bool) -> bool:
if x:
b = True
elif y:
b = False
else:
bb = False # b not assigned here -> can't optimize
if b:
return True
else:
return False
[out]
def f(x, y):
x, y, r0, b, bb, r1 :: bool
L0:
r0 = <error> :: bool
b = r0
if x goto L1 else goto L2 :: bool
L1:
b = 1
goto L5
L2:
if y goto L3 else goto L4 :: bool
L3:
b = 0
goto L5
L4:
bb = 0
L5:
if is_error(b) goto L6 else goto L7
L6:
r1 = raise UnboundLocalError('local variable "b" referenced before assignment')
unreachable
L7:
if b goto L8 else goto L9 :: bool
L8:
return 1
L9:
return 0
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Runner for copy propagation optimization tests."""
"""Runner for IR optimization tests."""

from __future__ import annotations

Expand All @@ -8,6 +8,7 @@
from mypy.test.config import test_temp_dir
from mypy.test.data import DataDrivenTestCase
from mypyc.common import TOP_LEVEL_NAME
from mypyc.ir.func_ir import FuncIR
from mypyc.ir.pprint import format_func
from mypyc.options import CompilerOptions
from mypyc.test.testutil import (
Expand All @@ -19,13 +20,16 @@
use_custom_builtins,
)
from mypyc.transform.copy_propagation import do_copy_propagation
from mypyc.transform.flag_elimination import do_flag_elimination
from mypyc.transform.uninit import insert_uninit_checks

files = ["opt-copy-propagation.test"]

class OptimizationSuite(MypycDataSuite):
"""Base class for IR optimization test suites.

To use this, add a base class and define "files" and "do_optimizations".
"""

class TestCopyPropagation(MypycDataSuite):
files = files
base_path = test_temp_dir

def run_case(self, testcase: DataDrivenTestCase) -> None:
Expand All @@ -41,7 +45,24 @@ def run_case(self, testcase: DataDrivenTestCase) -> None:
if fn.name == TOP_LEVEL_NAME and not testcase.name.endswith("_toplevel"):
continue
insert_uninit_checks(fn)
do_copy_propagation(fn, CompilerOptions())
self.do_optimizations(fn)
actual.extend(format_func(fn))

assert_test_output(testcase, actual, "Invalid source code output", expected_output)

def do_optimizations(self, fn: FuncIR) -> None:
raise NotImplementedError


class TestCopyPropagation(OptimizationSuite):
files = ["opt-copy-propagation.test"]

def do_optimizations(self, fn: FuncIR) -> None:
do_copy_propagation(fn, CompilerOptions())


class TestFlagElimination(OptimizationSuite):
files = ["opt-flag-elimination.test"]

def do_optimizations(self, fn: FuncIR) -> None:
do_flag_elimination(fn, CompilerOptions())
Loading