[Inductor][fx pass] Fuse pointwise operators in the post grad (#114778)
Summary:

We construct a unified API that makes it easy to add pointwise ops to be batched in the post-grad pass.
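As a sketch of what the unified API enables, adding another batched pointwise op is a two-line subclass inside `group_batch_fusion.py`, mirroring the `batch_aten_add`/`batch_aten_mul` registrations in this diff (`batch_aten_sub` here is a hypothetical example, not part of the commit):

```python
# Hypothetical: registering another pointwise op with the unified API.
@register_fusion("batch_aten_sub", pre_grad=False)
class BatchSubPostGradFusion(BatchPointwiseOpsPostGradFusion):
    def __init__(self, **kwargs):
        super().__init__(aten.sub.Tensor, **kwargs)
```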

Test Plan:
# unit test
```
buck2 test 'fbcode//mode/dev-nosan' fbcode//caffe2/test/inductor:group_batch_fusion
```
Buck UI: https://www.internalfb.com/buck2/19b3f641-782f-4f94-a953-3ff9ce2cfa7b
Test UI: https://www.internalfb.com/intern/testinfra/testrun/1125900251953016
Network: Up: 67KiB  Down: 32KiB  (reSessionID-c2a80f26-8227-4f78-89fc-bcbda0ae8353)
Jobs completed: 18. Time elapsed: 1:19.8s.
Cache hits: 0%. Commands: 2 (cached: 0, remote: 0, local: 2)
Tests finished: Pass 6. Fail 0. Fatal 0. Skip 0. Build failure 0
# local reproduce
### cmf
P887605070
### igctr
P892987433
### mai
P893109069
### icvr
P893075846
### oc
P893109069

Reviewed By: xuzhao9

Differential Revision: D51332067
mengluy authored and facebook-github-bot committed Nov 30, 2023
1 parent 597d3fb commit 1adc1a0
Showing 4 changed files with 167 additions and 15 deletions.
14 changes: 8 additions & 6 deletions test/inductor/test_group_batch_fusion.py
@@ -213,11 +213,13 @@ def forward(self, x):
sigmoid_2 = [torch.sigmoid(tanh_2[i]) for i in range(len(tanh_2))]
relu_1 = [torch.nn.functional.relu(sigmoid_1[i]) for i in range(len(sigmoid_1))]
relu_2 = [torch.nn.functional.relu(sigmoid_2[i]) for i in range(len(sigmoid_2))]
- return torch.cat(relu_1, dim=1) + torch.cat(relu_2, dim=1)
+ tanh = [x + y for x, y in zip(relu_1, relu_2)]
+ sigmoid = [x * x for x in tanh]
+ return torch.cat(sigmoid, dim=1)


@requires_cuda()
- @torch._inductor.config.patch(post_grad_fusion_options={"group_linear": {}})
+ @torch._inductor.config.patch(post_grad_group_fusion_options={"group_linear": {}})
class TestGroupBatchFusion(TestCase):
def compare_dict_tensors(self, ref_dict, res_dict, rtol=1e-3, atol=1e-3):
if len(set(ref_dict.keys())) != len(set(res_dict.keys())):
@@ -263,7 +265,7 @@ def test_group_linear_fusion(self):
)
self.assertEqual(
counters["inductor"]["batch_fusion"],
- 0,
+ 2,
)
ref.sum().backward()
res.sum().backward()
@@ -275,7 +277,7 @@
)
self.assertEqual(
counters["inductor"]["batch_fusion"],
- 0,
+ 5,
)
counters.clear()

@@ -306,7 +308,7 @@ def test_group_linear_fusion_different_shapes(self):
)
self.assertEqual(
counters["inductor"]["batch_fusion"],
- 0,
+ 1,
)
counters.clear()

@@ -397,7 +399,7 @@ def test_pointwise_op_pre_grad_fusion(self):
ref = module(*input)
res = traced(*input)
self.compare_pred(module, traced, input)
- self.assertEqual(counters["inductor"]["batch_fusion"], 3)
+ self.assertEqual(counters["inductor"]["batch_fusion"], 5)
self.assertEqual(
counters["inductor"]["scmerge_split_removed"],
0,
2 changes: 1 addition & 1 deletion test/inductor/test_perf.py
@@ -208,7 +208,7 @@ def f(a, b, c, d, e):
return torch.cat((a + 1, b + 2, c + 3, d + 4, e + 5)) + 10

inp = [T(10, 10) for _ in range(5)]
- self.assertExpectedInline(count_numel(f, *inp), """2000""")
+ self.assertExpectedInline(count_numel(f, *inp), """4000""")

def f(a, b):
return torch.cat([a.sum(dim=0), b.sum(dim=0)]) + 10
6 changes: 5 additions & 1 deletion torch/_inductor/config.py
@@ -104,8 +104,12 @@

# Post grad group/batch fusion and options, set to empty dict to disable fusion.
# Call `torch._inductor.fx_passes.group_batch_fusion.list_group_batch_fusions(False)` to see available fusions.
- post_grad_fusion_options: Dict[str, Dict[str, Any]] = {}
+ post_grad_group_fusion_options: Dict[str, Dict[str, Any]] = {}

+ post_grad_batch_fusion_options: Dict[str, Dict[str, Any]] = {
+     "batch_aten_add": {},
+     "batch_aten_mul": {},
+ }
# enable reordering pass for improving memory locality
reorder_for_locality = True

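For reference, a hedged sketch of toggling these knobs from user code, mirroring the `config.patch` usage in the test diff above (option names come from this commit; setting a family's dict to `{}` disables it, and `group_linear` requires fbgemm):

```python
import torch

# Keep the batched add but drop the batched mul; disable group fusions entirely.
@torch._inductor.config.patch(
    post_grad_batch_fusion_options={"batch_aten_add": {}},
    post_grad_group_fusion_options={},
)
def run(model, x):
    return torch.compile(model)(x)
```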
160 changes: 153 additions & 7 deletions torch/_inductor/fx_passes/group_batch_fusion.py
@@ -187,6 +187,135 @@ def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
graph.erase_node(original_mm)


+ def decompose_stack(graph: torch.fx.GraphModule, input_tensors: List[Any]) -> Any:
+     unsqueezed_inputs = []
+     for input_tensor in input_tensors:
+         unsqueezed_input = graph.call_function(
+             aten.unsqueeze, args=(input_tensor,), kwargs={"dim": 0}
+         )
+         unsqueezed_inputs.append(unsqueezed_input)
+     stacked_inputs = graph.call_function(
+         aten.cat, args=(unsqueezed_inputs,), kwargs={"dim": 0}
+     )
+     return stacked_inputs
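For intuition, the unsqueeze-then-cat sequence this helper emits as graph nodes is the decomposed form of `torch.stack`; a minimal eager-mode check of the equivalence:

```python
import torch

xs = [torch.randn(3, 4) for _ in range(5)]
# decompose_stack builds this unsqueeze + cat pattern as FX graph nodes.
decomposed = torch.cat([x.unsqueeze(0) for x in xs], dim=0)
assert torch.equal(decomposed, torch.stack(xs, dim=0))  # shape [5, 3, 4]
```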


+ def convert_scalar_to_tensor(
+     graph: torch.fx.GraphModule, node: torch.fx.Node, input, shape
+ ) -> Optional[torch.fx.Node]:
+     """
+     Convert scalar to tensor.
+     """
+     # We don't have a functional way of instantiating a non-contiguous tensor
+     # with full/zeros/ones right now; this hasn't shown up to be important yet.
+     fake_tensor = node.meta["val"]
+     if not fake_tensor.is_contiguous(memory_format=torch.contiguous_format):
+         return None
+     return graph.call_function(
+         aten.full.default,
+         args=(shape, input),
+         kwargs={
+             "dtype": fake_tensor.dtype,
+             "layout": torch.strided,
+             "device": fake_tensor.device,
+             "pin_memory": False,
+         },
+     )
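In eager terms, the `aten.full.default` call above materializes a Python scalar as a filled tensor matching the node's fake-tensor metadata; a minimal sketch with assumed dtype/device/shape (the pass reads these from `node.meta["val"]` and only proceeds for contiguous tensors):

```python
import torch

# Assumed values for illustration only.
t = torch.full(
    [1, 1], 2.0,
    dtype=torch.float32, layout=torch.strided, device="cpu", pin_memory=False,
)
```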


+ class BatchPointwiseOpsPostGradFusion(BatchPointwiseOpsFusionFactory):
+     """
+     Batch pointwise operators (e.g., add, mul) in the post grad pass.
+     """
+
+     def __init__(self, op, **kwargs):
+         super().__init__(op, **kwargs)
+         self.op = op
+
+     def _pointwise_node_can_be_fused(self, node: torch.fx.Node):
+         # input and other could be scalars
+         input, other = node.args
+         # input is a torch tensor
+         if isinstance(input, torch.fx.Node) and not isinstance(other, torch.fx.Node):
+             return True if "tensor_meta" in input.meta else False
+         # other is a torch tensor
+         if not isinstance(input, torch.fx.Node) and isinstance(other, torch.fx.Node):
+             return True if "tensor_meta" in other.meta else False
+         # both are scalars
+         if not isinstance(input, torch.fx.Node) and not isinstance(
+             other, torch.fx.Node
+         ):
+             return True
+         # both are torch tensors
+         return (
+             input.meta["tensor_meta"].shape == other.meta["tensor_meta"].shape
+             if "tensor_meta" in input.meta and "tensor_meta" in other.meta
+             else False
+         )

+     def match(self, node: torch.fx.Node):
+         if CallFunctionVarArgs(self.op).match(
+             node
+         ) and self._pointwise_node_can_be_fused(node):
+             alpha = node.kwargs.get("alpha", 1.0)
+             input, other = node.args
+             if hasattr(input, "meta"):
+                 shape = list(input.meta["tensor_meta"].shape)
+             elif hasattr(other, "meta"):
+                 shape = list(other.meta["tensor_meta"].shape)
+             else:
+                 shape = [1, 1]
+             group_key = (
+                 "batch_" + self.op.__name__.lower() + "_post_grad",
+                 str(shape),
+                 str(alpha),
+             )
+         else:
+             group_key = None
+         return group_key
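`match` buckets candidate nodes by op name, operand shape, and alpha, so only compatible calls are batched together; a hedged standalone illustration of the grouping idea (plain Python, not the pass's literal key strings):

```python
from collections import defaultdict

# Hypothetical bucketing: same (op, shape, alpha) -> same batch-candidate set.
calls = [("add", [4, 8], 1.0), ("add", [4, 8], 1.0), ("add", [2, 2], 1.0)]
buckets = defaultdict(list)
for idx, (op, shape, alpha) in enumerate(calls):
    buckets[(f"batch_{op}_post_grad", str(shape), str(alpha))].append(idx)
# The first two adds share a key, so they are candidates for one batched add.
```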

+     def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
+         batch_inputs, batch_others = [], []
+         alpha = subset[0].kwargs.get("alpha", 1.0)
+
+         for node in subset:
+             input, other = node.args
+             if not isinstance(input, torch.fx.Node) and not isinstance(
+                 other, torch.fx.Node
+             ):
+                 input = convert_scalar_to_tensor(graph, node, input, [1, 1])
+                 other = convert_scalar_to_tensor(graph, node, other, [1, 1])
+             elif not isinstance(input, torch.fx.Node):
+                 input = convert_scalar_to_tensor(
+                     graph, node, input, list(other.meta["tensor_meta"].shape)
+                 )
+             elif not isinstance(other, torch.fx.Node):
+                 other = convert_scalar_to_tensor(
+                     graph, node, other, list(input.meta["tensor_meta"].shape)
+                 )
+
+             batch_inputs.append(input)
+             batch_others.append(other)
+
+         with graph.inserting_before(subset[0]):
+             stack_inputs = decompose_stack(graph, batch_inputs)
+             stack_others = decompose_stack(graph, batch_others)
+
+             batch_op = graph.call_function(
+                 self.op,
+                 args=(stack_inputs, stack_others),
+                 kwargs={"alpha": alpha} if self.op == aten.add.Tensor else {},
+             )
+             for i, node in enumerate(subset):
+                 with graph.inserting_after(batch_op):
+                     getitem = graph.call_function(
+                         torch.ops.aten.select, args=((batch_op, 0, i))
+                     )
+                 node.replace_all_uses_with(getitem)
+                 getitem.meta.update(node.meta)
+                 graph.erase_node(node)
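In eager terms, `fuse` rewrites N independent pointwise calls into one batched call over stacked operands plus one `select` per original result; a minimal sketch, assuming same-shape tensor operands:

```python
import torch

a = [torch.randn(4, 8) for _ in range(3)]
b = [torch.randn(4, 8) for _ in range(3)]

# Before the pass: three separate adds.
before = [x + y for x, y in zip(a, b)]

# After the pass: one batched add, then a select per original node.
batched = torch.stack(a) + torch.stack(b)
after = [batched.select(0, i) for i in range(3)]

assert all(torch.equal(u, v) for u, v in zip(before, after))
```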


@register_fusion("batch_linear_lhs")
class BatchLinearLHSFusion(BatchFusion):
"""
@@ -549,6 +678,18 @@ def __init__(self, **kwargs):
super().__init__(torch.nn.functional.relu, **kwargs)


@register_fusion("batch_aten_add", pre_grad=False)
class BatchAddPostGradFusion(BatchPointwiseOpsPostGradFusion):
def __init__(self, **kwargs):
super().__init__(aten.add.Tensor, **kwargs)


@register_fusion("batch_aten_mul", pre_grad=False)
class BatchMulPostGradFusion(BatchPointwiseOpsPostGradFusion):
def __init__(self, **kwargs):
super().__init__(aten.mul.Tensor, **kwargs)


def find_independent_subset_greedy(
node_list: List[torch.fx.Node],
graph_search_options: Dict[str, Any],
@@ -621,7 +762,8 @@ def get_fusion_candidates(
continue

key = rule.match(node)
- if key is not None:
+ # SymInt is not hashable, so we need to skip it
+ if key is not None and not isinstance(key, torch.SymInt):
candidate_nodes = candidate_dict[key]
if node not in candidate_nodes:
candidate_nodes.append(node)
@@ -677,18 +819,22 @@ def generate_fusion_from_config(config_options: Dict[str, Any], pre_grad=True):


def group_batch_fusion_passes(graph: torch.fx.Graph, pre_grad=True):
print_graph(graph, "Before group_batch fusion in post grads pass.")
print_graph(graph, "Before group_batch fusion in pre grad pass.")
fusions: List[GroupBatchFusionBase] = []

if pre_grad:
- fusions = generate_fusion_from_config(
+ fusions += generate_fusion_from_config(
config.pre_grad_fusion_options, pre_grad=True
)
- elif has_fbgemm:  # Only group fusion (which needs fbgemm) in post grad.
- fusions = generate_fusion_from_config(
- config.post_grad_fusion_options, pre_grad=False
+ else:
+ # Only batch fusion (which doesn't need fbgemm) in post grad.
+ fusions += generate_fusion_from_config(
+ config.post_grad_batch_fusion_options, pre_grad=False
)

+ if has_fbgemm:  # Only group fusion (which needs fbgemm) in post grad.
+ fusions += generate_fusion_from_config(
+ config.post_grad_group_fusion_options, pre_grad=False
+ )
for rule in fusions:
apply_group_batch_fusion(graph, rule)
print_graph(graph, f"Apply fusion {rule.__class__.__name__}.")

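End to end, with the defaults added in `config.py` above, the post-grad batch fusions apply to compiled graphs; a hedged sketch (whether the pass actually fires depends on what the captured graph contains):

```python
import torch

def f(x, y):
    # Two independent same-shape pointwise ops the post-grad pass can batch.
    return x + 1.0, y + 2.0

compiled = torch.compile(f)
out = compiled(torch.randn(8, 8), torch.randn(8, 8))
```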