Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 6 additions & 18 deletions torch/_dynamo/graph_deduplication.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,11 @@ def _get_inds_with_external_users(region: Region, inds_unique: set[int]) -> None
inds_unique.add(ind)


def _copy_nodes_and_remap_inputs(
subgraph: torch.fx.Graph, region: Region
) -> list[OrderedSet[UsageIndex]]:
def _create_subgraph(
region: Region,
inds_with_external_users: list[int],
) -> tuple[torch.fx.Graph, list[OrderedSet[UsageIndex]]]:
subgraph: torch.fx.Graph = torch.fx.Graph()
external_input_to_usages = _get_external_inputs(region)
external_node_usages = list[OrderedSet[UsageIndex]]()
region_to_subgraph_node = {}
Expand All @@ -241,24 +243,10 @@ def map_arg(node: Node) -> Node:
subgraph_node = subgraph.node_copy(node, lambda old: map_arg(old))
region_to_subgraph_node[node] = subgraph_node

return external_node_usages


def _create_subgraph_outputs(
subgraph: torch.fx.Graph, inds_to_output: list[int]
) -> None:
node_list = [n for n in subgraph.nodes if n.op not in ("placeholder", "output")]
out_tup = tuple(node_list[ind] for ind in inds_to_output)
out_tup = tuple(node_list[ind] for ind in inds_with_external_users)
subgraph.output(out_tup)


def _create_subgraph(
region: Region,
inds_with_external_users: list[int],
) -> tuple[torch.fx.Graph, list[OrderedSet[UsageIndex]]]:
subgraph: torch.fx.Graph = torch.fx.Graph()
external_node_usages = _copy_nodes_and_remap_inputs(subgraph, region)
_create_subgraph_outputs(subgraph, inds_with_external_users)
return subgraph, external_node_usages


Expand Down
Loading