Skip to content

Commit

Permalink
[Inductor][fx pass] Remove split nodes with split section size one
Browse files Browse the repository at this point in the history
Summary: We observe that DSNN has many split nodes with only one split section. These nodes hinder the split-cat merge in a later pass, so we remove them at an early stage.

Test Plan:
# local reproduce with DSNN model
```
buck2 run mode/opt //scripts/jackiexu0313/pt2:local_model_with_pt2 -- --test_mode split_batch-group -c
```
P872705076
diffing: https://www.internalfb.com/intern/diffing/?paste_number=872698775

# unit test

```
buck2 test mode/dev-nosan //caffe2/test/inductor:split_cat_fx_passes
```
Buck UI: https://www.internalfb.com/buck2/b248410e-a556-47a2-9293-7f113b49f0d6
Test UI: https://www.internalfb.com/intern/testinfra/testrun/10696049124469023
Network: Up: 80KiB  Down: 47KiB  (reSessionID-a31dec17-d322-4757-ba84-4d262bd139cf)
Jobs completed: 24. Time elapsed: 1:52.8s.
Cache hits: 0%. Commands: 2 (cached: 0, remote: 0, local: 2)
Tests finished: Pass 9. Fail 0. Fatal 0. Skip 0. Build failure 0

Differential Revision: D50990290
  • Loading branch information
mengluy authored and facebook-github-bot committed Nov 3, 2023
1 parent 63fc482 commit 268cd80
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 6 deletions.
12 changes: 6 additions & 6 deletions test/inductor/test_split_cat_fx_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,17 +67,17 @@ def cm_with_list(x):
]
for fn, expected_split_norm_count in [
(arg_only, 1),
(arg_only_dim0, 1),
(arg_only_dim0, 0),
(kwarg1, 1),
(kwarg2, 1),
(kwarg3, 1),
(list_replace, 1),
(multi_split, 17),
(multi_split, 1),
(unequal_split, 1),
(arg_only_cm, 1),
(kwarg1_cm, 1),
(kwarg2_cm, 1),
(multi_split_cm, 17),
(multi_split_cm, 1),
(unequal_split_cm, 1),
(cm_with_list, 1),
]:
Expand Down Expand Up @@ -226,12 +226,12 @@ def split_getitem_out_of_order(x):
torch.randn(2, 32),
]
for fn, expected_split_merged in [
(multi_split, 16),
(multi_split, 0),
(multi_split_2, 16),
(multi_split_2_neg_dim, 16),
(multi_split_with_sizes, 2),
(multi_split_kwarg1, 16),
(multi_split_kwarg2, 16),
(multi_split_kwarg1, 0),
(multi_split_kwarg2, 0),
(unequal_multi_split, 3),
(unequal_multi_split_neg_index, 3),
(diff_dims, 0),
Expand Down
44 changes: 44 additions & 0 deletions torch/_inductor/fx_passes/split_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,46 @@ def _get_split_args_default(split_node):
)


# noqa: W605
# ############The pattern to be optimized is#########
#              unbind (dim=0)
#             /     ...     \
#        getitem           getitem   -> user=1
#           |                 |
#         split             split    -> dim=1, user=1, split_section_size=1
#           |                 |
#        getitem           getitem   -> user=1
#             \             /
#               cat (dim=1)          -> user=1
#                    |

# ################After transformation#############
#              unbind (dim=0)
#             /     ...     \
#        getitem           getitem   -> user=1
#             \             /
#               cat (dim=1)          -> user=1
#                    |


def remove_split_with_size_one(
    graph: torch.fx.Graph,
    node: torch.fx.Node,
    input: torch.fx.Node,
) -> None:
    """Remove a split node that has a single split section.

    Such a split is a no-op: its only output is the whole input tensor.
    Every consumer of the split's output is rewired to read ``input``
    directly, then the split's users and the split itself are erased.

    Args:
        graph: The FX graph that owns ``node``.
        node: The split node to remove.
        input: The node feeding the split. The name shadows the builtin
            but is kept so keyword callers continue to work.

    NOTE(review): this assumes every user of ``node`` is a ``getitem``
    selecting the single section, as in the pattern sketched in the
    comment above this function -- confirm against the callers.
    """
    # Snapshot the users first: erasing a user removes it from
    # ``node.users``, which must not happen while iterating over it.
    # Handling *all* users (rather than only the first one) keeps the
    # final ``erase_node(node)`` from failing when the split output is
    # read more than once, and avoids StopIteration when it has no users.
    for user in list(node.users):
        # Redirect every consumer of the split's output to the split input.
        user.replace_all_uses_with(input)
        graph.erase_node(user)
    # The split node now has no users left and can be erased safely.
    graph.erase_node(node)

    counters["inductor"]["remove_split_with_size_one"] += 1


def normalize_split_base(
match: Match,
_get_split_args: Callable[
Expand All @@ -85,6 +125,10 @@ def normalize_split_base(
if any(isinstance(section, torch.SymInt) for section in split_sections):
# TODO dynamic_shapes with assume_static_by_default=False fails while AOT Autograd tracing.
return
# remove the dummy split whose split sections size is one
if len(split_sections) == 1:
remove_split_with_size_one(graph, split_node, split_input)
return
if split_dim < 0: # Normalize split dim
split_dim += split_input.meta["example_value"].dim()
with graph.inserting_after(split_node):
Expand Down

0 comments on commit 268cd80

Please sign in to comment.