36 changes: 0 additions & 36 deletions test/inductor/test_split_cat_fx_passes.py
@@ -16,7 +16,6 @@ def patch(f):
"normalization_pass": {},
"remove_split_with_size_one_pass": {},
"merge_getitem_cat_pass": {},
"merge_stack_tahn_unbind_pass": {},
"merge_splits_pass": {},
"mutate_cat_pass": {},
"split_cat_pass": {},
@@ -1269,41 +1268,6 @@ def unbind_cat_to_view(x):
        )
        counters.clear()

    @patch
    def test_stack_tahn_unbind_merge(self):
        def stack_tahn_unbind(x):
            l1_out = torch.split(x, [20, 20, 20, 10, 10, 20, 20], 1)
            item0 = l1_out[0]
            item1 = l1_out[1]
            item2 = l1_out[2]
            item3 = l1_out[3]
            item4 = l1_out[4]
            item5 = l1_out[5]
            item6 = l1_out[6]
            stack = torch.stack(tensors=(item0, item1, item2), dim=0)
            cat_1 = torch.cat((item3, item4), 1)
            cat_2 = torch.cat((item5, item6), 1)
            tanh = torch.tanh(stack)
            unbind = torch.unbind(tanh, 0)
            return torch.cat((unbind[0], unbind[1], torch.cat((cat_1, cat_2), 1)), 1)

        args = [
            torch.randn(50, 120),
        ]
        for fn, expected_stack_tahn_unbind_merged in [
            (stack_tahn_unbind, 1),
        ]:
            expected = fn(*args)
            actual = torch.compile(fn)(*args)

            torch.testing.assert_close(actual, expected)
            self.assertEqual(
                counters["inductor"]["merge_stack_tahn_unbind_pass"],
                expected_stack_tahn_unbind_merged,
            )
            self.assertIn("merge_getitem_cat_pass_pre_grad", optimus_scuba_log)
            counters.clear()

    def test_numpy_compat_normalization(self):
        def fn(x, y):
            a = torch.stack([x, y], axis=1)
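For reference, the pattern that the deleted test exercised can be reproduced in isolation. The following is a minimal standalone sketch (shapes chosen arbitrarily, not code from this PR) of the stack → tanh → unbind identity that merge_stack_tahn_unbind_pass targeted:

    import torch

    # Sketch only: unbinding tanh of a stack matches applying tanh to each
    # slice individually, which is the equivalence the merge pass relied on.
    x = torch.randn(50, 120)
    item0, item1, item2 = torch.split(x, [40, 40, 40], 1)

    stacked = torch.stack((item0, item1, item2), dim=0)  # shape (3, 50, 40)
    unbound = torch.unbind(torch.tanh(stacked), dim=0)   # three (50, 40) tensors

    for merged, single in zip(unbound, (item0, item1, item2)):
        torch.testing.assert_close(merged, torch.tanh(single))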
6 changes: 6 additions & 0 deletions torch/_inductor/fx_passes/group_batch_fusion.py
@@ -921,9 +921,15 @@ def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
                self.op,
                args=(stack_inputs,),
            )
            if is_node_meta_valid(batch_op):
                batch_op.meta["example_value"] = self.op(batch_op.meta["example_value"])
            unbind_op = graph.call_function(
                torch.unbind, args=(batch_op,), kwargs={"dim": 0}
            )
            if is_node_meta_valid(unbind_op):
                unbind_op.meta["example_value"] = torch.unbind(
                    unbind_op.meta["example_value"], dim=0
                )
            for i, node in enumerate(batch_nodes):
                with graph.inserting_after(unbind_op):
                    getitem = graph.call_function(operator.getitem, args=(unbind_op, i))
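The added lines keep example_value metadata on the newly inserted FX nodes in step with the fused computation. Below is a minimal, self-contained sketch of that technique; is_node_meta_valid is re-implemented here as a simple presence check and the metadata source is picked for illustration, so it only approximates the real helper and control flow in group_batch_fusion.py:

    import operator

    import torch
    import torch.fx as fx


    def is_node_meta_valid(node):
        # Assumption: the real helper checks (at least) that "example_value" is set.
        return node is not None and "example_value" in node.meta


    graph = fx.Graph()
    x = graph.placeholder("x")
    x.meta["example_value"] = torch.randn(3, 50, 40)

    # Mirror the diff's pattern: create a batched op node and an unbind node,
    # propagating example_value metadata alongside each new node.
    batch_op = graph.call_function(torch.tanh, args=(x,))
    if is_node_meta_valid(x):
        batch_op.meta["example_value"] = torch.tanh(x.meta["example_value"])

    unbind_op = graph.call_function(torch.unbind, args=(batch_op,), kwargs={"dim": 0})
    if is_node_meta_valid(batch_op):
        unbind_op.meta["example_value"] = torch.unbind(
            batch_op.meta["example_value"], dim=0
        )

    for i in range(3):
        with graph.inserting_after(unbind_op):
            getitem = graph.call_function(operator.getitem, args=(unbind_op, i))
            getitem.meta["example_value"] = unbind_op.meta["example_value"][i]

Keeping example_value populated on each inserted node means later passes that read fake-tensor metadata still see a fully annotated graph after the fusion rewrites it.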