[Inductor][fx pass] Fuse pointwise operators in the post grad (#114778)
Summary:

We construct a unified API that makes it easy to add pointwise ops to be batched in the post grad pass.
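
The diff below wires this up for `aten.add.Tensor` and `aten.mul.Tensor`; with the shared `BatchPointwiseOpsPostGradFusion` base class, registering another op takes only a few lines. As a rough sketch (not part of this commit; the op choice and class name are illustrative, and the snippet assumes the surrounding definitions in `torch/_inductor/fx_passes/group_batch_fusion.py`):

```python
# Hypothetical extension of the unified API introduced in this commit.
# Assumes register_fusion, BatchPointwiseOpsPostGradFusion, and `aten`
# from torch/_inductor/fx_passes/group_batch_fusion.py; aten.div.Tensor
# is only an example op, not something this commit registers.
@register_fusion("batch_aten_div", pre_grad=False)
class BatchDivPostGradFusion(BatchPointwiseOpsPostGradFusion):
    def __init__(self, **kwargs):
        super().__init__(aten.div.Tensor, **kwargs)
```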

Test Plan:
# unit test
```
buck2 test 'fbcode//mode/dev-nosan' fbcode//caffe2/test/inductor:group_batch_fusion
```
Buck UI: https://www.internalfb.com/buck2/6c5d1d31-e4d1-4865-bf79-1e7ac3b6e051
Test UI: https://www.internalfb.com/intern/testinfra/testrun/1970325050015770
Network: Up: 72KiB  Down: 22KiB  (reSessionID-44adc8b2-54e9-453a-bd20-710cefefaed1)
Jobs completed: 20. Time elapsed: 1:44.6s.
Cache hits: 0%. Commands: 2 (cached: 0, remote: 0, local: 2)
Tests finished: Pass 7. Fail 0. Fatal 0. Skip 0. Build failure 0
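
In an open-source checkout, the corresponding test can presumably be run directly (the file path is assumed from the buck target name and not verified here):

```
python test/inductor/test_group_batch_fusion.py
```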
# local reproduce
### cmf
P887605070
### igctr
P892987433
### mai
P893109069
### icvr
P893075846
### oc
P893109069
### mixed precision training
P898569125

# e2e test
baseline: f509792379
proposal: f509792025
https://pxl.cl/3Xbcf

Reviewed By: xuzhao9

Differential Revision: D51332067
mengluy authored and facebook-github-bot committed Dec 7, 2023
1 parent 622688f commit 4228fe8
Showing 3 changed files with 100 additions and 6 deletions.
92 changes: 89 additions & 3 deletions torch/_inductor/fx_passes/group_batch_fusion.py
@@ -86,11 +86,12 @@ def list_group_batch_fusions(pre_grad=True) -> List[str]:
 def decompose_stack(graph: torch.fx.GraphModule, input_tensors: List[Any]) -> Any:
     unsqueezed_inputs = []
     for input_tensor in input_tensors:
-        unsqueezed_input = graph.call_function(aten.unsqueeze, args=(input_tensor, 0))
+        unsqueezed_input = graph.call_function(
+            aten.unsqueeze, args=(input_tensor,), kwargs={"dim": 0}
+        )
         unsqueezed_inputs.append(unsqueezed_input)
     stacked_inputs = graph.call_function(
-        aten.cat,
-        args=(unsqueezed_inputs, 0),
+        aten.cat, args=(unsqueezed_inputs,), kwargs={"dim": 0}
     )
     return stacked_inputs

@@ -276,6 +277,79 @@ def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
             graph.erase_node(original_mm)


+class BatchPointwiseOpsPostGradFusion(BatchPointwiseOpsFusionFactory):
+    """
+    Batch pointwise operator (e.g., add, mul) in post grad pass.
+    """
+
+    def __init__(self, op, **kwargs):
+        super().__init__(op, **kwargs)
+        self.op = op
+
+    def _pointwise_node_can_be_fused(self, node: torch.fx.Node):
+        # note: we only consider the case where the inputs are tensors
+        # for mixed precision training, we need to make sure the inputs
+        # of the aten.cat when do the stack should be the same dtype
+        # otherwise, the output of the aten.cat may be not the same as
+        # its inputs, and cause dtype not same error in mm or addmm
+        input, other = node.args
+        return (
+            input.meta["tensor_meta"].shape == other.meta["tensor_meta"].shape
+            if hasattr(input, "meta")
+            and hasattr(other, "meta")
+            and "dtype" in input.meta
+            and "dtype" in other.meta
+            and "tensor_meta" in input.meta
+            and "tensor_meta" in other.meta
+            else False
+        )
+
+    def match(self, node: torch.fx.Node):
+        if CallFunctionVarArgs(self.op).match(
+            node
+        ) and self._pointwise_node_can_be_fused(node):
+            alpha = node.kwargs.get("alpha", 1.0)
+            input, other = node.args
+            shape = list(input.meta["tensor_meta"].shape)
+            group_key = (
+                "batch_" + self.op.__name__.lower() + "_post_grad",
+                str(shape),
+                str(input.meta["dtype"]),
+                str(other.meta["dtype"]),
+                str(alpha),
+            )
+        else:
+            group_key = None
+        return group_key
+
+    def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
+        batch_inputs, batch_others = [], []
+        alpha = subset[0].kwargs.get("alpha", 1.0)
+
+        for node in subset:
+            input, other = node.args
+            batch_inputs.append(input)
+            batch_others.append(other)
+
+        with graph.inserting_before(subset[0]):
+            stack_inputs = decompose_stack(graph, batch_inputs)
+            stack_others = decompose_stack(graph, batch_others)
+
+            batch_op = graph.call_function(
+                self.op,
+                args=(stack_inputs, stack_others),
+                kwargs={"alpha": alpha} if self.op == aten.add.Tensor else {},
+            )
+            for i, original_add in enumerate(subset):
+                with graph.inserting_after(batch_op):
+                    new_add = graph.call_function(
+                        torch.ops.aten.select, args=((batch_op, 0, i))
+                    )
+                original_add.replace_all_uses_with(new_add)
+                new_add.meta.update(original_add.meta)
+                graph.erase_node(original_add)
+
+
 @register_fusion("batch_linear_lhs")
 class BatchLinearLHSFusion(BatchFusion):
     """
@@ -638,6 +712,18 @@ def __init__(self, **kwargs):
         super().__init__(torch.nn.functional.relu, **kwargs)


+@register_fusion("batch_aten_add", pre_grad=False)
+class BatchAddPostGradFusion(BatchPointwiseOpsPostGradFusion):
+    def __init__(self, **kwargs):
+        super().__init__(aten.add.Tensor, **kwargs)
+
+
+@register_fusion("batch_aten_mul", pre_grad=False)
+class BatchMulPostGradFusion(BatchPointwiseOpsPostGradFusion):
+    def __init__(self, **kwargs):
+        super().__init__(aten.mul.Tensor, **kwargs)
+
+
 def find_independent_subset_greedy(
     node_list: List[torch.fx.Node],
     graph_search_options: Dict[str, Any],
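
For intuition, the rewrite that `BatchPointwiseOpsPostGradFusion.fuse` performs on the FX graph is roughly equivalent to the following eager-mode sketch (hypothetical tensors and shapes; the real pass manipulates `torch.fx` nodes rather than eager tensors):

```python
import torch

# Three independent, same-shape adds, as they might appear before fusion.
xs = [torch.randn(4, 8) for _ in range(3)]
ys = [torch.randn(4, 8) for _ in range(3)]
before = [x + y for x, y in zip(xs, ys)]

# What decompose_stack plus the single batched op amount to:
# unsqueeze each operand on dim 0, concatenate, run one batched add,
# then select the per-op slices back out.
stacked_x = torch.cat([x.unsqueeze(0) for x in xs], dim=0)
stacked_y = torch.cat([y.unsqueeze(0) for y in ys], dim=0)
batched = torch.add(stacked_x, stacked_y)
after = [batched.select(0, i) for i in range(3)]

assert all(torch.allclose(a, b) for a, b in zip(before, after))
```
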
6 changes: 4 additions & 2 deletions torch/_inductor/fx_passes/post_grad.py
@@ -83,14 +83,16 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
     if config.pattern_matcher:
         lazy_init()

+        print_graph(gm.graph, "Before group batch fusion in post grad pass.")
         group_batch_fusion_passes(gm.graph, pre_grad=False)
+        print_graph(gm.graph, "After group batch fusion in post grad pass.")
         remove_noop_ops(gm.graph)
         print_graph(gm.graph, "Before split cat in post grad pass.")
         for patterns in pass_patterns:
             patterns.apply(gm.graph)
             print_graph(
                 gm.graph,
-                f"Apply split cat pattern matcher {patterns.__class__.__name__} in post grad.",
+                "Apply split cat pattern matcher PatternMatcherPass in post grad.",
             )
         if is_inference:
             inference_patterns.apply(gm.graph)
@@ -110,7 +112,7 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
     gm.recompile()
     gm.graph.lint()

-    print_graph(gm.graph, "Aftre recompile in post grad pass.")
+    print_graph(gm.graph, "After recompile in post grad pass.")


 @init_once_fakemode
8 changes: 7 additions & 1 deletion torch/_inductor/fx_passes/pre_grad.py
@@ -74,17 +74,23 @@ def pre_grad_passes(gm: torch.fx.GraphModule, example_inputs):
         else:
             gm = fuse_fx(gm, example_inputs)
         numpy_compat_normalization(gm.graph)
+        print_graph(gm.graph, "Before group batch fusion in pre grad pass.")
         group_batch_fusion_passes(gm.graph, pre_grad=True)
+        print_graph(gm.graph, "Before split cat in pre grad pass.")
         for pattern_matcher_pass in pattern_matcher_passes:
             pattern_matcher_pass.apply(gm.graph)
+            print_graph(
+                gm.graph,
+                "Apply split cat pattern matcher PatternMatcherPass in pre grad.",
+            )

     if config.pre_grad_custom_pass is not None:
         config.pre_grad_custom_pass(gm.graph)
     stable_topological_sort(gm.graph)
     gm.graph.lint()
     gm.recompile()

-    print_graph(gm.graph, "Aftre recompile in pre grad pass.")
+    print_graph(gm.graph, "After recompile in pre grad pass.")

     return gm
