Allow fuse unsqueeze cat sum with multiple input (#68650)
Summary:
Pull Request resolved: #68650

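Allow fusing unsqueeze + cat + sum with more than two inputs. The implementation in this diff is naive: it simply chains the inputs together with pairwise add operations. It is unclear whether fusing the multiple adds into a single operation would give a further performance gain.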
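The rewrite relies on summing a dim-0 concatenation of dim-0-unsqueezed tensors being the same as adding the original tensors elementwise. Below is a minimal sketch of that equivalence, assuming plain Python lists stand in for 1-D tensors and that the sum reduces over dim 0. The helper names (`unsqueeze0`, `cat0`, `sum0`, `add`, `unfused`, `fused`) are illustrative only and are not part of `acc_ops`.

```python
from functools import reduce


def unsqueeze0(x):
    # Model "add a leading dimension of size 1" by wrapping the row in a list.
    return [x]


def cat0(tensors):
    # Model concatenation along dim 0 by joining the outer lists.
    out = []
    for t in tensors:
        out.extend(t)
    return out


def sum0(t):
    # Model a reduction over dim 0: add the rows elementwise.
    return [sum(col) for col in zip(*t)]


def add(a, b):
    # Elementwise add of two equally sized rows.
    return [x + y for x, y in zip(a, b)]


def unfused(inputs):
    # Original pattern: sum(cat([unsqueeze(x) for x in inputs])).
    return sum0(cat0([unsqueeze0(x) for x in inputs]))


def fused(inputs):
    # Rewritten pattern: a left-to-right chain of pairwise adds.
    return reduce(add, inputs)


inputs = [[1, 2, 3], [10, 20, 30], [100, 200, 300]]
assert unfused(inputs) == fused(inputs)
```

With N inputs the chain contains N - 1 add operations, which is the cost the summary refers to when it asks whether a single fused operation could do better.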

Test Plan: unit test

Reviewed By: jfix71

Differential Revision: D32520135

fbshipit-source-id: 535b1c8c91e415d5f1af714378b9205c1ca02ffd
Shirong Wu authored and pull[bot] committed Dec 28, 2022
1 parent 45d2aae commit 4b3d2b0
Showing 1 changed file with 16 additions and 11 deletions.
@@ -81,20 +81,25 @@ def fuse_unsqueeze_cat_sum(gm: torch.fx.GraphModule):
         if node.target != acc_ops.sum:
             continue
         prev_node = node.kwargs["input"]
-        if prev_node.target != acc_ops.cat or len(prev_node.kwargs["tensors"]) != 2:
+        if prev_node.target != acc_ops.cat or prev_node.kwargs["dim"] != 0:
             continue
-        lhs, rhs = prev_node.kwargs["tensors"][0], prev_node.kwargs["tensors"][1]
-        if lhs.target != acc_ops.unsqueeze or rhs.target != acc_ops.unsqueeze:
-            continue
-        lhs_input = lhs.kwargs["input"]
-        rhs_input = rhs.kwargs["input"]
-        # prerequisite check
-        cond1 = lhs.kwargs["dim"] == 0 and rhs.kwargs["dim"] == 0
-        cond2 = prev_node.kwargs["dim"] == 0
-        if not cond1 or not cond2:
+        cat_inputs = prev_node.kwargs["tensors"]
+        valid_pass = True
+        for i in cat_inputs:
+            if i.target != acc_ops.unsqueeze or i.kwargs["dim"] != 0:
+                valid_pass = False
+                break
+
+        if not valid_pass:
             continue
+        input_val = [i.kwargs["input"] for i in cat_inputs]
+
         with gm.graph.inserting_before(node):
-            fused_node = gm.graph.call_function(acc_ops.add, kwargs={"input": lhs_input, "other": rhs_input})
+            left = input_val[0]
+            for i in range(1, len(input_val)):
+                right = input_val[i]
+                fused_node = gm.graph.call_function(acc_ops.add, kwargs={"input": left, "other": right})
+                left = fused_node
             node.replace_all_uses_with(fused_node)

     gm.graph.eliminate_dead_code()
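The matching and chaining logic in this change can be sketched outside of torch.fx. The example below is only an illustration under stated assumptions: `Node` is a hypothetical stand-in for `torch.fx.Node`, op targets are plain strings rather than `acc_ops` functions, and `build_add_chain` is an invented helper, not a function from the patched file.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    # Hypothetical stand-in for torch.fx.Node: an op name plus keyword args.
    target: str
    kwargs: dict = field(default_factory=dict)


def build_add_chain(cat_node):
    """Return the tail of a left-folded add chain, or None if the pattern does not match."""
    if cat_node.target != "cat" or cat_node.kwargs["dim"] != 0:
        return None
    cat_inputs = cat_node.kwargs["tensors"]
    # Every cat input must be an unsqueeze along dim 0 (same check as the flag loop).
    if not all(n.target == "unsqueeze" and n.kwargs["dim"] == 0 for n in cat_inputs):
        return None
    # Fold left: ((x0 + x1) + x2) + ...
    left = cat_inputs[0].kwargs["input"]
    for n in cat_inputs[1:]:
        left = Node("add", {"input": left, "other": n.kwargs["input"]})
    return left


a, b, c = Node("placeholder"), Node("placeholder"), Node("placeholder")
cat = Node("cat", {
    "tensors": [Node("unsqueeze", {"input": x, "dim": 0}) for x in (a, b, c)],
    "dim": 0,
})
chain = build_add_chain(cat)
```

Here `chain` represents `(a + b) + c`. One difference from the diff is worth noting: with a single cat input this sketch returns that input unchanged, whereas in the patched loop `fused_node` would not be assigned during that iteration, so callers may want to confirm that a one-element cat cannot reach this pass.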