[inductor] post_grad batched linear fusion (pytorch#112504)
Summary:

Fuse independent nn.Linear() calls with aten.bmm and aten.cat: shape-compatible aten.mm/aten.addmm nodes in the post-grad (aten-level) graph are batched into a single aten.bmm over inputs and weights stacked with aten.cat.
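
For context, a minimal standalone sketch (not part of this commit) of the rewrite the pass performs: shape-compatible matmuls are computed as one batched matmul over stacked operands, with one select per original consumer. In the actual pass the stack is decomposed into aten.unsqueeze + aten.cat, and a bias add is emitted for addmm nodes.
```python
import torch

def independent_linears(inputs, weights):
    # Before the pass: one aten::mm per layer.
    return [x @ w for x, w in zip(inputs, weights)]

def batched_linears(inputs, weights):
    # After the pass: one aten::bmm over stacked operands
    # (the pass decomposes the stack into unsqueeze + cat).
    stacked_x = torch.stack(inputs)
    stacked_w = torch.stack(weights)
    out = torch.bmm(stacked_x, stacked_w)
    # One aten::select per original mm to feed each downstream consumer.
    return [out[i] for i in range(out.shape[0])]

inputs = [torch.randn(10, 10) for _ in range(10)]
weights = [torch.randn(10, 10) for _ in range(10)]
ref = independent_linears(inputs, weights)
fused = batched_linears(inputs, weights)
assert all(torch.allclose(a, b, atol=1e-6) for a, b in zip(ref, fused))
```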

Test Plan:
Without the BMM fusion:
```
buck2 run mode/opt //pytorch/benchmark:run -- test_module -d cuda --module test_linear_module --torchdynamo inductor --torchinductor_cudagraph 0 --torchinductor_batch_fusion 0
```
https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/test/torchbench_test_module_20231030_072536_6535183793.json.gz&bucket=pyper_traces

100 aten::mm operators

With the BMM fusion:
```
buck2 run mode/opt //pytorch/benchmark:run -- test_module -d cuda --module test_linear_module --torchdynamo inductor --torchinductor_cudagraph 0 --torchinductor_batch_fusion 1
```

20 aten::bmm operators

https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/test/torchbench_test_module_20231030_072157_6535183793.json.gz&bucket=pyper_traces

Passes the accuracy test:
```
$ buck2 run mode/opt //pytorch/benchmark:run -- test_module -d cuda --module test_linear_module --torchdynamo inductor --torchinductor_cudagraph 0 --torchinductor_batch_fusion 1 --accuracy
Running eval method from test_module on cuda in dynamo inductor mode with input batch size 4 and precision tf32.
Accuracy:                            pass
```
The bmm and the input cat appear to have been fused successfully.

Checking the Triton codegen:

```
TORCH_LOGS=+dynamo,+aot,+inductor buck2 run mode/opt //pytorch/benchmark:run -- test_module -d cuda --module test_linear_module --torchdynamo inductor --torchinductor_cudagraph 0 --torchinductor_batch_fusion 1 --dump_triton 1
```

Triton code dump: https://www.internalfb.com/intern/everpaste/?handle=GHp1ABaqYuTjYCUBALiTWmteaI1PbsIXAAAB
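
As a usage note, here is a minimal, hedged sketch of enabling the new post_grad_batch_fusion flag directly from user code rather than through the benchmark CLI; the ManyLinears module name and shapes are illustrative and mirror the unit test added below:
```python
import torch
import torch._inductor.config as inductor_config

# Turn on the post-grad batched linear fusion added in this diff (off by default).
inductor_config.post_grad_batch_fusion = True

class ManyLinears(torch.nn.Module):
    # Illustrative module: 10 independent Linear layers whose outputs are summed,
    # mirroring the unit test added below.
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.ModuleList(torch.nn.Linear(10, 10) for _ in range(10))

    def forward(self, inputs):
        return sum(layer(x) for layer, x in zip(self.layers, inputs))

model = ManyLinears().cuda()
inputs = [torch.randn(10, 10, device="cuda") for _ in range(10)]

compiled = torch.compile(model)  # default backend is inductor
out = compiled(inputs)
```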

Differential Revision: D46910718
xuzhao9 authored and facebook-github-bot committed Oct 31, 2023
1 parent 481a7a9 commit 8107494
Showing 3 changed files with 126 additions and 1 deletion.
35 changes: 35 additions & 0 deletions test/inductor/test_group_batch_fusion.py
@@ -435,5 +435,40 @@ def test_batch_relu_pre_grad_fusion(self):
counters.clear()


class TestBMMFusionModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.my_modules = torch.nn.ModuleList()
for _ in range(10):
self.my_modules.append(torch.nn.Linear(10, 10))

def forward(self, inputs):
output = None
for linear, input in zip(self.my_modules, inputs):
if output is None:
output = linear(input)
else:
output += linear(input)
return output


@requires_cuda()
@torch._inductor.config.patch(post_grad_batch_fusion=True)
class TestBatchLinearPostGradFusion(TestCase):
def test_batch_linear_post_grad_fusion(self):
pt1_module = TestBMMFusionModule().cuda()
inputs = []
for _ in range(10):
inputs.append(torch.randn(10, 10).cuda())
eager_output = pt1_module(inputs)
pt2_module = torch.compile(pt1_module)
pt2_output = pt2_module(inputs)
self.assertTrue(torch.allclose(eager_output, pt2_output))
self.assertEqual(
counters["inductor"]["post_grad_batch_fusion"],
2,
)


if __name__ == "__main__":
run_tests()
3 changes: 3 additions & 0 deletions torch/_inductor/config.py
@@ -79,6 +79,9 @@
# enable pattern match with batch fusion (using torch op)
batch_fusion = True

# enable post-grad pattern match with batch fusion (using torch op)
post_grad_batch_fusion = False

# enable reordering pass for improving memory locality
reorder_for_locality = True

89 changes: 88 additions & 1 deletion torch/_inductor/fx_passes/group_batch_fusion.py
@@ -65,6 +65,87 @@ class BatchFusion(GroupBatchFusionBase):

pass

class BatchLinearPostGradFusion(GroupBatchFusionBase):
"""
Fuse ops in a batch way in post grad (aten level).
"""
def _decompose_stack(
self, graph: torch.fx.GraphModule, input_tensors: List[Any]
) -> Any:
unsqueezed_inputs = []
for input_tensor in input_tensors:
unsqueezed_input = graph.call_function(
aten.unsqueeze, args=(input_tensor, 0)
)
unsqueezed_inputs.append(unsqueezed_input)
stacked_inputs = graph.call_function(
aten.cat,
args=(unsqueezed_inputs, 0),
)
return stacked_inputs

def _addmm_node_can_be_fused(self, node):
return (
node.kwargs.get("beta", 1.0) == 1.0 and node.kwargs.get("alpha", 1.0) == 1.0
)

def match(self, node):
if CallFunctionVarArgs(aten.mm).match(node):
input_m, weight_m = node.args
bias_m = None
elif CallFunctionVarArgs(aten.addmm.default).match(
node
) and self._addmm_node_can_be_fused(node):
bias_m, input_m, weight_m = node.args
else:
return None

m, k = input_m.meta["tensor_meta"].shape
n = weight_m.meta["tensor_meta"].shape[1]
batch_key = ("batch_linear", m, k, n, bias_m is not None)
return batch_key

def fuse(self, graph, subset):
batch_inputs = []
batch_weights = []
batch_biases = []
batch_nodes = []

for node in subset:
if CallFunctionVarArgs(aten.addmm.default).match(node):
bias, input, weight = node.args
elif CallFunctionVarArgs(aten.mm.default).match(node):
input, weight = node.args
bias = None
batch_nodes.append(node)
batch_inputs.append(input)
batch_weights.append(weight)
batch_biases.append(bias)

with graph.inserting_before(subset[-1]):
fused_inputs = self._decompose_stack(graph, batch_inputs)
fused_weights = self._decompose_stack(graph, batch_weights)
fused_bmm = graph.call_function(
torch.ops.aten.bmm,
args=(fused_inputs, fused_weights),
)

for i, original_mm in enumerate(batch_nodes):
has_bias = False
with graph.inserting_after(fused_bmm):
new_mm = graph.call_function(
torch.ops.aten.select, args=((fused_bmm, 0, i))
)
if batch_biases[i]:
has_bias = True
new_bias_add = graph.call_function(
torch.ops.aten.add, args=((batch_biases[i], new_mm))
)
new_mm_cont = new_bias_add if has_bias else new_mm
original_mm.replace_all_uses_with(new_mm_cont)
new_mm_cont.meta.update(original_mm.meta)
graph.erase_node(original_mm)


class GroupLinearFusion(GroupFusion):
def _addmm_node_can_be_fused(self, node: torch.fx.Node):
@@ -603,7 +684,8 @@ def get_fusion_candidates(
continue

key = rule.match(node)
if key is not None:
# SymInt is not hashable, so we need to skip it
if key is not None and not isinstance(key, torch.SymInt):
candidate_nodes = candidate_dict[key]
if node not in candidate_nodes:
candidate_nodes.append(node)
@@ -633,6 +715,8 @@ def apply_group_batch_fusion(graph: torch.fx.GraphModule, rule: GroupBatchFusion
fused_set.update(subset)
if isinstance(rule, GroupFusion):
counters["inductor"]["group_fusion"] += 1
elif isinstance(rule, BatchFusion):
counters["inductor"]["post_grad_batch_fusion"] += 1
else:
counters["inductor"]["batch_fusion"] += 1

@@ -653,6 +737,9 @@ def group_batch_fusion_post_grad_passes(graph: torch.fx.Graph):
if config.group_fusion and has_fbgemm:
fusions += [GroupLinearFusion()]

if config.post_grad_batch_fusion:
fusions += [BatchLinearPostGradFusion()]

for rule in fusions:
apply_group_batch_fusion(graph, rule)
print_graph(graph, f"Apply fusion {rule.__class__.__name__}.")
