Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Inductor][fx pass] Remove split nodes with split section size one #112922

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions test/inductor/test_split_cat_fx_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,17 +67,17 @@ def cm_with_list(x):
]
for fn, expected_split_norm_count in [
(arg_only, 1),
(arg_only_dim0, 1),
(arg_only_dim0, 0),
(kwarg1, 1),
(kwarg2, 1),
(kwarg3, 1),
(list_replace, 1),
(multi_split, 17),
(multi_split, 1),
(unequal_split, 1),
(arg_only_cm, 1),
(kwarg1_cm, 1),
(kwarg2_cm, 1),
(multi_split_cm, 17),
(multi_split_cm, 1),
(unequal_split_cm, 1),
(cm_with_list, 1),
]:
Expand Down Expand Up @@ -226,12 +226,12 @@ def split_getitem_out_of_order(x):
torch.randn(2, 32),
]
for fn, expected_split_merged in [
(multi_split, 16),
(multi_split, 0),
(multi_split_2, 16),
(multi_split_2_neg_dim, 16),
(multi_split_with_sizes, 2),
(multi_split_kwarg1, 16),
(multi_split_kwarg2, 16),
(multi_split_kwarg1, 0),
(multi_split_kwarg2, 0),
(unequal_multi_split, 3),
(unequal_multi_split_neg_index, 3),
(diff_dims, 0),
Expand Down
44 changes: 44 additions & 0 deletions torch/_inductor/fx_passes/split_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,46 @@ def _get_split_args_default(split_node):
)


# noqa: W605
# ############The pattern to be optimized is#########
# unbind (dim=0)
# / ... \
# getitem getitem -> user=1
# | |
# split split -> dim=1, user=1, split_section_size=1
# | |
# getitem getitem -> user=1
# \ /
# cat (dim=1) -> user=1
# |

# ################After transformation#############
# unbind (dim=0)
# / ... \
# getitem getitem -> user=1
# \ /
# cat (dim=1) -> user=1
# |


def remove_split_with_size_one(
graph: torch.fx.Graph,
node: torch.fx.Node,
input: torch.fx.Node,
):
# find the grand children of the split_node
next_users = find_next_users(node)
user = next(iter(node.users.keys()))
# replace the users of grand child node with the input node
for next_user in next_users:
next_user.replace_input_with(user, input)
# erase the split node and its child
graph.erase_node(user)
graph.erase_node(node)

counters["inductor"]["remove_split_with_size_one"] += 1


def normalize_split_base(
match: Match,
_get_split_args: Callable[
Expand All @@ -85,6 +125,10 @@ def normalize_split_base(
if any(isinstance(section, torch.SymInt) for section in split_sections):
# TODO dynamic_shapes with assume_static_by_default=False fails while AOT Autograd tracing.
return
# remove the dummy split whose split sections size is one
if len(split_sections) == 1:
remove_split_with_size_one(graph, split_node, split_input)
return
if split_dim < 0: # Normalize split dim
split_dim += split_input.meta["example_value"].dim()
with graph.inserting_after(split_node):
Expand Down
Loading