[PT2][Optimus] Read the patterns from the config instead of hard-code passes #125136

Closed · wants to merge 1 commit
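This PR replaces Optimus's hard-coded pre/post-grad pass lists (and per-pass boolean flags such as `split_cat_fx_passes` and `decompose_mem_bound_mm`) with two config dicts, `pre_grad_fusion_options` and `post_grad_fusion_options`, that map pass names to per-pass option dicts. A minimal usage sketch, mirroring the pass names enabled in the tests below (the toy module is illustrative, not from this PR):

```python
import torch
import torch._inductor.config as inductor_config


class SplitCatModel(torch.nn.Module):
    # Toy module (illustrative only) whose split/cat graph the
    # pre-grad passes can rewrite.
    def forward(self, x):
        pieces = torch.split(x, 4, dim=-1)
        return torch.cat(pieces, dim=-1)


# Each enabled pass is keyed by name; an empty per-pass dict means
# "enable with default options".
with inductor_config.patch(
    pre_grad_fusion_options={
        "normalization_pass": {},
        "merge_splits_pass": {},
        "split_cat_pass": {},
        "unbind_stack_pass": {},
    },
    post_grad_fusion_options={},
):
    compiled = torch.compile(SplitCatModel())
    out = compiled(torch.randn(8, 16))
```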
13 changes: 13 additions & 0 deletions test/inductor/test_aot_inductor.py
@@ -326,6 +326,19 @@ def forward(self, x, y):
with config.patch({"freezing": True}):
self.check_model(Model(self.device), example_inputs)

@torch._inductor.config.patch(
pre_grad_fusion_options={
"normalization_pass": {},
"remove_split_with_size_one_pass": {},
"merge_getitem_cat_pass": {},
"merge_stack_tahn_unbind_pass": {},
"merge_splits_pass": {},
"mutate_cat_pass": {},
"split_cat_pass": {},
"unbind_stack_pass": {},
},
post_grad_fusion_options={},
)
def test_simple_split(self):
class Model(torch.nn.Module):
def __init__(self):
4 changes: 3 additions & 1 deletion test/inductor/test_decompose_mem_bound_mm.py
@@ -50,7 +50,9 @@ def forward(self, input1, input2):

@requires_cuda
@torch._inductor.config.patch(
decompose_mem_bound_mm=True,
post_grad_fusion_options={
"decompose_mm_pass": {},
}
)
@instantiate_parametrized_tests
class TestDecomposeMemMM(TestCase):
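The boolean `decompose_mem_bound_mm` knob is retired in favor of the same dict mechanism: enabling the decomposition now means registering `decompose_mm_pass` in `post_grad_fusion_options`, as the hunk above does. Roughly (the helper name below is made up for illustration):

```python
import torch
import torch._inductor.config as inductor_config


# Before this PR the test toggled a dedicated boolean flag:
#   @inductor_config.patch(decompose_mem_bound_mm=True)
# After this PR the pass is enabled by name, like any other fusion.
@inductor_config.patch(post_grad_fusion_options={"decompose_mm_pass": {}})
def run_with_decompose_mm(fn, *args):
    return torch.compile(fn)(*args)
```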
17 changes: 13 additions & 4 deletions test/inductor/test_pattern_matcher.py
@@ -520,7 +520,7 @@ def fn(a, b, c):
torch.randn(16, 16, device="cuda"),
torch.randn(16, 16, device="cuda"),
]
self.common(fn, args, 2, 5)
self.common(fn, args, 1, 4)

def test_cat_addmm(self):
def fn(a, b, c):
@@ -538,7 +538,7 @@ def fn(a, b, c):
torch.randn(16, 16, device="cuda"),
torch.randn(16, 16, device="cuda"),
]
self.common(fn, args, 2, 5)
self.common(fn, args, 1, 4)

def test_cat_slice_cat_cuda(self):
def fn(a, b):
@@ -839,7 +839,9 @@ def foo(x, y):

def test_match_with_mutation(self):
counter = 0
test_pass = PatternMatcherPass(prevent_match_across_mutations=True)
test_pass = PatternMatcherPass(
prevent_match_across_mutations=True, pass_name="test"
)

@register_graph_pattern(
CallFunction(
@@ -892,7 +894,14 @@ def fn5(x, y):
]

with unittest.mock.patch(
"torch._inductor.fx_passes.pre_grad.pattern_matcher_passes", [test_pass]
"torch._inductor.fx_passes.pre_grad.config.pre_grad_fusion_options",
{"test": {}},
), unittest.mock.patch(
"torch._inductor.fx_passes.pre_grad.PRE_GRAD_FUSIONS",
[],
), unittest.mock.patch(
"torch._inductor.fx_passes.pre_grad.PRE_GRAD_PATTERNS",
{"test": test_pass},
):
for fn in (fn0, fn1, fn2, fn3, fn4, fn5):
counter = 0
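The mocks above reflect the new lookup scheme: instead of a flat `pattern_matcher_passes` list, pre-grad passes live in a name-keyed `PRE_GRAD_PATTERNS` dict, so the test can register its own `PatternMatcherPass` under the name "test" and enable it through `pre_grad_fusion_options`. A hypothetical sketch of the registry helper, not the PR's exact code:

```python
from typing import Dict

from torch._inductor.pattern_matcher import PatternMatcherPass

# Hypothetical registry: pass name -> shared PatternMatcherPass instance.
PRE_GRAD_PATTERNS: Dict[str, PatternMatcherPass] = {}


def construct_pattern_matcher_pass(pass_name: str) -> PatternMatcherPass:
    # Create the pass object on first use; @register_graph_pattern
    # handlers then attach to it via pass_dict=construct_pattern_matcher_pass(name).
    if pass_name not in PRE_GRAD_PATTERNS:
        PRE_GRAD_PATTERNS[pass_name] = PatternMatcherPass(
            prevent_match_across_mutations=True, pass_name=pass_name
        )
    return PRE_GRAD_PATTERNS[pass_name]
```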
19 changes: 17 additions & 2 deletions test/inductor/test_split_cat_fx_passes.py
@@ -13,7 +13,19 @@


def patch(f):
f = torch._inductor.config.patch(split_cat_fx_passes=True)(f)
f = torch._inductor.config.patch(
pre_grad_fusion_options={
"normalization_pass": {},
"remove_split_with_size_one_pass": {},
"merge_getitem_cat_pass": {},
"merge_stack_tahn_unbind_pass": {},
"merge_splits_pass": {},
"mutate_cat_pass": {},
"split_cat_pass": {},
"unbind_stack_pass": {},
},
post_grad_fusion_options={},
)(f)
return f


@@ -605,7 +617,10 @@ def multi_split_cat(x1, x2):
)
counters.clear()

@torch._inductor.config.patch(split_cat_fx_passes=False)
@torch._inductor.config.patch(
pre_grad_fusion_options={},
post_grad_fusion_options={},
)
def test_config_flag_is_respected(self):
def split_with_cat(x):
fs = torch.split(x, [4, 4, 24], dim=-1)
5 changes: 1 addition & 4 deletions torch/_inductor/fx_passes/decompose_mem_bound_mm.py
@@ -8,7 +8,7 @@
from .. import config

from ..pattern_matcher import Arg, CallFunction, Match, register_graph_pattern
from .split_cat import construct_pattern_matcher_pass, get_config_flag
from .split_cat import construct_pattern_matcher_pass

aten = torch.ops.aten
log = logging.getLogger(__name__)
@@ -94,7 +94,6 @@ def print_decompose_pattern(match: Match, inputs: List[torch.fx.Node]):
@register_graph_pattern(
CallFunction(aten.bmm, Arg(), Arg()),
pass_dict=construct_pattern_matcher_pass("decompose_mm_pass"),
extra_check=get_config_flag("decompose_mm_pass", "decompose_mem_bound_mm"),
)
def decompose_bmm(match: Match, mat1: torch.fx.Node, mat2: torch.fx.Node):
def repl(mat1, mat2):
@@ -111,7 +110,6 @@ def repl(mat1, mat2):
@register_graph_pattern(
CallFunction(aten.addmm, Arg(), Arg(), Arg()),
pass_dict=construct_pattern_matcher_pass("decompose_mm_pass"),
extra_check=get_config_flag("decompose_mm_pass", "decompose_mem_bound_mm"),
)
def decompose_addmm(
match: Match,
@@ -133,7 +131,6 @@ def repl(mat1, mat2, mat3):
@register_graph_pattern(
CallFunction(aten.mm, Arg(), Arg()),
pass_dict=construct_pattern_matcher_pass("decompose_mm_pass"),
extra_check=get_config_flag("decompose_mm_pass", "decompose_mem_bound_mm"),
)
def decompose_mm(
match: Match,
3 changes: 0 additions & 3 deletions torch/_inductor/fx_passes/group_batch_fusion.py
@@ -302,7 +302,6 @@ def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):

if all(bias is None for bias in group_biases):
group_biases = None # type: ignore[assignment]
group_biases: Optional[List[Any]]

with graph.inserting_before(subset[0]):
fused_mm = graph.call_function(
@@ -649,10 +648,8 @@ def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):

if all(bias is None for bias in group_biases):
group_biases = None # type: ignore[assignment]
group_biases: Optional[List[Any]]
if all(weight is None for weight in group_weights):
group_weights = None # type: ignore[assignment]
group_weights: Optional[List[Any]]
assert all(
eps == group_epss[0] for eps in group_epss
), "all epsilon values must be equal"
9 changes: 6 additions & 3 deletions torch/_inductor/fx_passes/post_grad.py
@@ -44,7 +44,7 @@
from ..utils import decode_device, is_pointwise_use
from ..virtualized import V
from .ddp_fusion import fuse_ddp_communication
from .group_batch_fusion import group_batch_fusion_passes
from .group_batch_fusion import group_batch_fusion_passes, POST_GRAD_FUSIONS
from .pre_grad import is_same_dict, save_inductor_dict
from .reinplace import reinplace_inplaceable_ops
from .split_cat import POST_GRAD_PATTERNS
@@ -54,7 +54,6 @@
aten = torch.ops.aten
prims = torch.ops.prims

pattern_matcher_passes = POST_GRAD_PATTERNS.values()
# First pass_patterns[0] are applied, then [1], then [2]
pass_patterns = [
PatternMatcherPass(),
@@ -89,7 +88,11 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
remove_noop_ops(gm.graph)
for patterns in pass_patterns:
patterns.apply(gm.graph) # type: ignore[arg-type]
for pattern_matcher_pass in pattern_matcher_passes:
for pass_name in config.post_grad_fusion_options:
# skip all patterns for group batch fusions
if pass_name in POST_GRAD_FUSIONS:
continue
pattern_matcher_pass = POST_GRAD_PATTERNS[pass_name]
inductor_before_change = save_inductor_dict(
[pattern_matcher_pass.pass_name]
)
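Group/batch fusions and pattern-matcher passes now share the same config dict; names registered in `POST_GRAD_FUSIONS` are consumed earlier by `group_batch_fusion_passes`, which is why the loop above skips them. A sketch of the two kinds coexisting in one config (the group-fusion name is illustrative; see the registry in `group_batch_fusion.py` for the real ones):

```python
import torch._inductor.config as inductor_config

# Names in POST_GRAD_FUSIONS are consumed by group_batch_fusion_passes;
# the pattern-matcher loop in post_grad_passes handles the rest.
with inductor_config.patch(
    post_grad_fusion_options={
        "group_linear": {},       # group fusion (name illustrative)
        "decompose_mm_pass": {},  # pattern-matcher pass
    },
):
    ...  # compile and run models here
```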
20 changes: 10 additions & 10 deletions torch/_inductor/fx_passes/pre_grad.py
@@ -24,7 +24,7 @@
stable_topological_sort,
)
from ..utils import is_cpu_device, pass_execution_and_save
from .group_batch_fusion import group_batch_fusion_passes
from .group_batch_fusion import group_batch_fusion_passes, PRE_GRAD_FUSIONS
from .misc_patterns import numpy_compat_normalization
from .split_cat import PRE_GRAD_PATTERNS

@@ -85,12 +85,6 @@ def remove_split_ops(graph, shape_prop):
return None


# split_cat related fusions
pattern_matcher_passes = list(PRE_GRAD_PATTERNS.values())
# non-split_cat related fusions
# TODO: move them to the fusions dict too.
pattern_matcher_passes.append(efficient_conv_bn_eval_pass)

pattern_matcher_passes_aten: List[PatternMatcherPass] = [
remove_split_with_size_one_pass_aten,
merge_getitem_cat_pass_aten,
@@ -134,6 +128,7 @@ def pre_grad_passes(gm: torch.fx.GraphModule, example_inputs=None):
def shape_prop(mod) -> None:
ShapeProp(
gm=mod,
# pyre-fixme[16]: Module `torch._dynamo.utils` has no attribute `detect_fake_mode`
fake_mode=detect_fake_mode(example_inputs),
).propagate(*example_inputs)

@@ -202,10 +197,13 @@ def shape_prop(mod) -> None:
if example_inputs is not None:
gm = fuse_fx(gm, example_inputs)
numpy_compat_normalization(gm.graph)

optimus_scuba_log["before_recompile_pre_grad"] = upload_graph(gm.graph)
group_batch_fusion_passes(gm.graph, pre_grad=True)
for pattern_matcher_pass in pattern_matcher_passes:
for pass_name in config.pre_grad_fusion_options:
# skip all patterns for group batch fusions
if pass_name in PRE_GRAD_FUSIONS:
continue
pattern_matcher_pass = PRE_GRAD_PATTERNS[pass_name]
inductor_before_change = save_inductor_dict(
[pattern_matcher_pass.pass_name]
)
@@ -214,6 +212,8 @@ def shape_prop(mod) -> None:
optimus_scuba_log[
f"{pattern_matcher_pass.pass_name}_pre_grad"
] = upload_graph(gm.graph)
# TODO: move efficient_conv_bn_eval_pass to the fusions dict too.
efficient_conv_bn_eval_pass.apply(gm.graph) # type: ignore[arg-type]

if config.pre_grad_custom_pass is not None:
config.pre_grad_custom_pass(gm.graph)
@@ -249,7 +249,7 @@

def fuse_fx(gm: torch.fx.GraphModule, example_inputs) -> torch.fx.GraphModule:
is_cpu = is_cpu_device(example_inputs)

# pyre-fixme[16]: Module `torch._dynamo.utils` has no attribute `detect_fake_mode`
fake_mode = detect_fake_mode(example_inputs)

gm = sink_cat_after_pointwise(gm)