feat: add a pre-AOT lowering pass to remove detach ops #2756

Closed · wants to merge 1 commit
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/backend/backends.py
@@ -13,6 +13,7 @@
 from torch_tensorrt.dynamo.lowering import (
     apply_lowering_passes,
     get_decompositions,
+    remove_detach,
     repair_input_aliasing,
 )
 from torch_tensorrt.dynamo.utils import (
@@ -74,6 +75,7 @@ def _pretraced_backend(
             fake_mode, "allow_non_fake_inputs", True
         ), fake_mode:
             repair_input_aliasing(gm)
+            remove_detach(gm)

             # Invoke AOTAutograd to translate operators to aten
             gm = aot_export_joint_simple(
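
Note: for context, a minimal end-to-end sketch of how this call site is reached. This is illustrative and not part of the diff; it assumes a CUDA device, a torch_tensorrt build that includes this pass, and the "torch_tensorrt" backend that the package registers with torch.compile. The module M is hypothetical.

import torch
import torch_tensorrt  # registers the "torch_tensorrt" compile backend

class M(torch.nn.Module):  # hypothetical example module
    def forward(self, x):
        # x.detach() is the kind of node remove_detach strips before AOT export
        return x.detach() + 1

model = M().eval().cuda()
x = torch.rand(5).cuda()

# _pretraced_backend now calls remove_detach(gm) right after
# repair_input_aliasing(gm), before aot_export_joint_simple runs.
compiled = torch.compile(model, backend="torch_tensorrt")
print(compiled(x))
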
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/lowering/__init__.py
@@ -4,5 +4,6 @@
 )
 from ._decompositions import get_decompositions  # noqa: F401
 from ._fusers import *  # noqa: F401
+from ._remove_detach import remove_detach
 from ._repair_input_aliasing import repair_input_aliasing
 from .passes import apply_lowering_passes
32 changes: 32 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/_remove_detach.py
@@ -0,0 +1,32 @@
+import logging
+
+import torch
+from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
+    clean_up_graph_after_modifications,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def remove_detach(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    """
+    Remove detach ops from the graph
+    """
+    modified_graph = False
+
+    for node in gm.graph.nodes:
+        # If the node's only user is a detach node, remove the detach
+        if len(node.users) == 1 and list(node.users)[0].target == "detach":
+            modified_graph = True
+            detach_node = list(node.users)[0]
+            logger.debug(
+                f"Removing node {detach_node} from the graph. It is a detach node with a single user."
+            )
+            detach_node.replace_all_uses_with(node)
+            gm.graph.erase_node(detach_node)
+
+    if modified_graph:
+        gm = clean_up_graph_after_modifications(gm)
+        logger.debug(f"Removed detach nodes:\n{gm.graph}")
+
+    return gm
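
As a standalone illustration (not part of this diff), the pass can also be run directly on a symbolically traced module. torch.fx.symbolic_trace records x.detach() as a call_method node whose target is the string "detach", which is exactly what the check above matches; the Detach module below is hypothetical.

import torch
from torch_tensorrt.dynamo.lowering import remove_detach

class Detach(torch.nn.Module):
    def forward(self, x):
        return x.detach() + 1  # traced as call_method with target "detach"

gm = torch.fx.symbolic_trace(Detach())
print(gm.graph)        # the graph still contains the detach node
gm = remove_detach(gm)
print(gm.graph)        # detach removed; the add now consumes x directly
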
29 changes: 29 additions & 0 deletions tests/py/dynamo/lowering/test_decompositions.py
@@ -485,6 +485,35 @@ def forward(self, x):
f"Select_scatter TRT outputs don't match with the original model.",
)

def test_lowering_detach_removal(self):
class Detach(torch.nn.Module):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

def forward(self, x):
y = torch.ops.aten.detach.default(x) + 1
return y

# Operations expected to be removed
unexpected_ops = {torch.ops.aten.detach.default}

inputs = [
torch.rand(
5,
),
]

fx_graph = torch.fx.symbolic_trace(Detach())
unexpected_ops_seen, _ = lower_graph_testing(
fx_graph, inputs, unexpected_ops=unexpected_ops, min_block_size=1
)

self.assertEquals(
len(unexpected_ops_seen),
0,
f"The following unexpected ops were encountered: {unexpected_ops_seen}",
)


if __name__ == "__main__":
run_tests()