Skip to content

Commit e05dad0

Browse files
IvanKobzarevetaf
authored andcommitted
partitioner option to ignore partitioner_tag for abstract usage (#166725)
Partitioner functionality is appealing to use in different scenarios (E.g. Autoparallel) We have special logic about "partitioner_tag" from meta that is only needed for forward/backward split. Adding optional argument to avoid it and do only generic split based on inputs/outputs. Potentially we want to make `_extract_graph_with_inputs_outputs` without underscore :) Pull Request resolved: #166725 Approved by: https://github.com/bdhirsh
1 parent ea92e64 commit e05dad0

File tree

1 file changed

+16
-6
lines changed

1 file changed

+16
-6
lines changed

torch/_functorch/partitioners.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ def _extract_graph_with_inputs_outputs(
180180
outputs: list[fx.Node],
181181
outputs_descs: list[AOTOutput],
182182
subgraph: Optional[str] = None,
183+
ignore_must_be_in_fw_bw: bool = False,
183184
) -> fx.Graph:
184185
"""
185186
Given a graph, extracts out a subgraph that takes the specified nodes as
@@ -203,13 +204,22 @@ def _extract_graph_with_inputs_outputs(
203204
env[node] = new_node
204205

205206
for node in joint_graph.nodes:
206-
if _must_be_in_backward(node) and subgraph != "backward" and node not in inputs:
207-
env[node] = InvalidNode # type: ignore[assignment]
208-
continue
207+
if not ignore_must_be_in_fw_bw:
208+
if (
209+
_must_be_in_backward(node)
210+
and subgraph != "backward"
211+
and node not in inputs
212+
):
213+
env[node] = InvalidNode # type: ignore[assignment]
214+
continue
209215

210-
if _must_be_in_forward(node) and subgraph != "forward" and node not in inputs:
211-
env[node] = InvalidNode # type: ignore[assignment]
212-
continue
216+
if (
217+
_must_be_in_forward(node)
218+
and subgraph != "forward"
219+
and node not in inputs
220+
):
221+
env[node] = InvalidNode # type: ignore[assignment]
222+
continue
213223

214224
if node in env:
215225
# Node must be one of our inputs. (Any member of env which wasn't an

0 commit comments

Comments
 (0)