[Inductor][fx pass] Fuse pointwise operators in the post-grad pass #114778

Closed · wants to merge 1 commit
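For orientation before the diff: the new post-grad fusion batches groups of independent, same-shape pointwise ops (aten.add, aten.mul) by stacking their inputs along a new leading dimension, applying one batched op, and recovering each original result with a select. Below is a minimal eager-mode sketch of that rewrite; the helper name batched_add and the tensor sizes are illustrative only, not part of the PR.

import torch

def batched_add(inputs, others, alpha=1.0):
    # decompose_stack: unsqueeze each tensor on dim 0, then cat along dim 0
    stacked_inputs = torch.cat([t.unsqueeze(0) for t in inputs], dim=0)
    stacked_others = torch.cat([t.unsqueeze(0) for t in others], dim=0)
    # one batched pointwise op instead of len(inputs) separate ones
    batch_out = torch.add(stacked_inputs, stacked_others, alpha=alpha)
    # each original op's result is a select along the stacking dimension
    return [batch_out.select(0, i) for i in range(len(inputs))]

xs = [torch.randn(4, 8) for _ in range(3)]
ys = [torch.randn(4, 8) for _ in range(3)]
assert all(
    torch.allclose(fused, x + y)
    for fused, x, y in zip(batched_add(xs, ys), xs, ys)
)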
90 changes: 87 additions & 3 deletions torch/_inductor/fx_passes/group_batch_fusion.py
@@ -86,11 +86,12 @@ def list_group_batch_fusions(pre_grad=True) -> List[str]:
def decompose_stack(graph: torch.fx.GraphModule, input_tensors: List[Any]) -> Any:
unsqueezed_inputs = []
for input_tensor in input_tensors:
- unsqueezed_input = graph.call_function(aten.unsqueeze, args=(input_tensor, 0))
+ unsqueezed_input = graph.call_function(
+     aten.unsqueeze, args=(input_tensor,), kwargs={"dim": 0}
+ )
unsqueezed_inputs.append(unsqueezed_input)
stacked_inputs = graph.call_function(
- aten.cat,
- args=(unsqueezed_inputs, 0),
+ aten.cat, args=(unsqueezed_inputs,), kwargs={"dim": 0}
)
return stacked_inputs

@@ -276,6 +277,77 @@ def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
graph.erase_node(original_mm)


class BatchPointwiseOpsPostGradFusion(BatchPointwiseOpsFusionFactory):
"""
Batch pointwise operator (e.g., add, mul) in post grad pass.
"""

def __init__(self, op, **kwargs):
super().__init__(op, **kwargs)
self.op = op

def _pointwise_node_can_be_fused(self, node: torch.fx.Node):
# Note: we only consider the case where both inputs are tensors.
# For mixed precision training, the inputs of the aten.cat used to build
# the stack must share a dtype; otherwise the output of aten.cat may not
# match its inputs' dtype and trigger a dtype-mismatch error in downstream
# ops such as mm or addmm.
input, other = node.args
return (
input.meta["tensor_meta"].shape == other.meta["tensor_meta"].shape
if hasattr(input, "meta")
and hasattr(other, "meta")
and "tensor_meta" in input.meta
and "tensor_meta" in other.meta
else False
)

def match(self, node: torch.fx.Node):
if CallFunctionVarArgs(self.op).match(
node
) and self._pointwise_node_can_be_fused(node):
alpha = node.kwargs.get("alpha", 1.0)
input, other = node.args
shape = list(input.meta["tensor_meta"].shape)
group_key = (
"batch_" + self.op.__name__.lower() + "_post_grad",
str(shape),
str(input.meta["tensor_meta"].dtype),
str(other.meta["tensor_meta"].dtype),
str(alpha),
)
else:
group_key = None
return group_key

def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
batch_inputs, batch_others = [], []
alpha = subset[0].kwargs.get("alpha", 1.0)

for node in subset:
input, other = node.args
batch_inputs.append(input)
batch_others.append(other)

with graph.inserting_before(subset[0]):
stack_inputs = decompose_stack(graph, batch_inputs)
stack_others = decompose_stack(graph, batch_others)

batch_op = graph.call_function(
self.op,
args=(stack_inputs, stack_others),
kwargs={"alpha": alpha} if self.op == aten.add.Tensor else {},
)
for i, original_add in enumerate(subset):
with graph.inserting_after(batch_op):
new_add = graph.call_function(
torch.ops.aten.select, args=(batch_op, 0, i)
)
original_add.replace_all_uses_with(new_add)
new_add.meta.update(original_add.meta)
graph.erase_node(original_add)


@register_fusion("batch_linear_lhs")
class BatchLinearLHSFusion(BatchFusion):
"""
@@ -638,6 +710,18 @@ def __init__(self, **kwargs):
super().__init__(torch.nn.functional.relu, **kwargs)


@register_fusion("batch_aten_add", pre_grad=False)
class BatchAddPostGradFusion(BatchPointwiseOpsPostGradFusion):
def __init__(self, **kwargs):
super().__init__(aten.add.Tensor, **kwargs)


@register_fusion("batch_aten_mul", pre_grad=False)
class BatchMulPostGradFusion(BatchPointwiseOpsPostGradFusion):
def __init__(self, **kwargs):
super().__init__(aten.mul.Tensor, **kwargs)


def find_independent_subset_greedy(
node_list: List[torch.fx.Node],
graph_search_options: Dict[str, Any],
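A note on the grouping criterion in BatchPointwiseOpsPostGradFusion.match above: two ops are only batched together when their shape, both input dtypes, and alpha agree, which is what keeps the stacked aten.cat dtype-safe. The following eager-mode sanity check mirrors that invariant; group_key_for_add is a hypothetical helper, and the real pass builds its key from FX node metadata (tensor_meta) rather than live tensors.

import torch

def group_key_for_add(a: torch.Tensor, b: torch.Tensor, alpha=1.0):
    # Fusable only if shapes match; the key records shape, both dtypes,
    # and alpha so that ops in one group can share a single stacked add.
    if a.shape != b.shape:
        return None
    return (
        "batch_add_post_grad",
        str(list(a.shape)),
        str(a.dtype),
        str(b.dtype),
        str(alpha),
    )

x = torch.randn(4, 8)
y = torch.randn(4, 8)
z = torch.randn(4, 8, dtype=torch.float64)
assert group_key_for_add(x, y) == group_key_for_add(y, x)   # same group
assert group_key_for_add(x, z) != group_key_for_add(x, y)   # dtype differs
assert group_key_for_add(x, torch.randn(2, 8)) is None      # shape differs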
6 changes: 4 additions & 2 deletions torch/_inductor/fx_passes/post_grad.py
@@ -83,14 +83,16 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
if config.pattern_matcher:
lazy_init()

print_graph(gm.graph, "Before group batch fusion in post grad pass.")
group_batch_fusion_passes(gm.graph, pre_grad=False)
print_graph(gm.graph, "After group batch fusion in post grad pass.")
remove_noop_ops(gm.graph)
print_graph(gm.graph, "Before split cat in post grad pass.")
for patterns in pass_patterns:
patterns.apply(gm.graph)
print_graph(
    gm.graph,
-   f"Apply split cat pattern matcher {patterns.__class__.__name__} in post grad.",
+   "Apply split cat pattern matcher PatternMatcherPass in post grad.",
)
if is_inference:
inference_patterns.apply(gm.graph)
@@ -110,7 +112,7 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
gm.recompile()
gm.graph.lint()

- print_graph(gm.graph, "Aftre recompile in post grad pass.")
+ print_graph(gm.graph, "After recompile in post grad pass.")


@init_once_fakemode
8 changes: 7 additions & 1 deletion torch/_inductor/fx_passes/pre_grad.py
@@ -74,17 +74,23 @@ def pre_grad_passes(gm: torch.fx.GraphModule, example_inputs):
else:
gm = fuse_fx(gm, example_inputs)
numpy_compat_normalization(gm.graph)
print_graph(gm.graph, "Before group batch fusion in pre grad pass.")
group_batch_fusion_passes(gm.graph, pre_grad=True)
print_graph(gm.graph, "Before split cat in pre grad pass.")
for pattern_matcher_pass in pattern_matcher_passes:
pattern_matcher_pass.apply(gm.graph)
print_graph(
gm.graph,
"Apply split cat pattern matcher PatternMatcherPass in pre grad.",
)

if config.pre_grad_custom_pass is not None:
config.pre_grad_custom_pass(gm.graph)
stable_topological_sort(gm.graph)
gm.graph.lint()
gm.recompile()

- print_graph(gm.graph, "Aftre recompile in pre grad pass.")
+ print_graph(gm.graph, "After recompile in pre grad pass.")

return gm

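Finally, a hypothetical end-to-end repro of the pattern these passes target: several independent, same-shape adds in one compiled function yield the repeated aten.add.Tensor nodes that the batch_aten_add fusion groups in the post-grad graph. This assumes the batch fusion is enabled in the inductor configuration; how it is gated is not shown in this diff.

import torch

def many_adds(xs, ys):
    # independent, same-shape, same-dtype adds: candidates for batch_aten_add
    return [x + y for x, y in zip(xs, ys)]

compiled = torch.compile(many_adds)
xs = [torch.randn(32, 64) for _ in range(8)]
ys = [torch.randn(32, 64) for _ in range(8)]
out = compiled(xs, ys)
# If the fusion fires, the post-grad graph holds one stacked add plus eight
# selects instead of eight separate aten.add.Tensor nodes; the print_graph
# calls added in post_grad.py dump the graph before and after this rewrite.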