Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions exir/passes/BUCK
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,7 @@ fbcode_target(_kind = runtime.python_library,
"propagate_device_pass.py",
],
deps = [
":device_copy_ops_registry",
"//caffe2:torch",
"//executorch/exir:delegate",
"//executorch/exir:lowered_backend_module",
Expand Down
207 changes: 156 additions & 51 deletions exir/passes/propagate_device_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,14 @@

# pyre-strict

import copy
import logging
import operator
from typing import Optional

# Import to register the et_copy ops so torch.ops.et_copy is available.
import executorch.exir.passes._device_copy_ops_registry # noqa: F401

import executorch.exir.schema as schema

import torch
Expand Down Expand Up @@ -124,23 +129,150 @@ def _tag_specs_with_device(
return False


def _clone_spec_with_device(
    spec: TensorSpec,
    device_type: schema.DeviceType,
    device_index: int = 0,
) -> TensorSpec:
    """Return a shallow duplicate of *spec* whose device fields are overridden.

    The duplicate's memory-planning fields are re-initialized so that the new
    spec does not share planning state with the original spec.
    """
    duplicate = copy.copy(spec)
    duplicate.init_mem_planning_fields()
    _set_device_on_spec(duplicate, device_type, device_index)
    return duplicate


class PropagateDevicePass(PassBase):
"""
After to_backend, walk the graph and set device metadata on TensorSpecs
based on partitioner-assigned delegation info.

Rules:
1. Delegated nodes: Input and output tensors of a delegate call are marked
with the target device derived from the delegate's CompileSpec
(key="target_device").
2. Non-delegated nodes: Remain on CPU (default).
3. Getitem nodes that extract from a delegate call inherit the device from
the delegate call's output spec at the corresponding index.
After to_backend, walk the graph and insert H2D/D2H copy ops at delegate
boundaries based on partitioner-assigned device info.

When a delegate has a target_device CompileSpec (e.g., "cuda:0"):
- For each delegate input: insert et_copy._h2d_copy before the delegate call.
The original input node stays CPU; the h2d_copy output is tagged as device.
- For each delegate output: insert et_copy._d2h_copy after each getitem.
The getitem stays device; the d2h_copy output is tagged as CPU.
- Getitem nodes that extract from a delegate call inherit the device.

Skip-copy optimizations:
- skip_h2d_for_method_inputs: If the input is a graph-level placeholder
feeding directly to a delegate, don't insert H2D — tag the placeholder
as device instead (user provides GPU tensor at runtime).
- skip_d2h_for_method_outputs: If the getitem feeds directly to graph
output, don't insert D2H — the output stays on device.
"""

def __init__(self) -> None:
    """No pass-level configuration is needed; defer to PassBase setup."""
    super().__init__()

def _is_placeholder(self, node: torch.fx.Node) -> bool:
"""Check if a node is a graph-level input (placeholder)."""
return node.op == "placeholder"

def _feeds_directly_to_output(self, node: torch.fx.Node) -> bool:
"""Check if all users of a node are output nodes."""
return all(user.op == "output" for user in node.users)

def _insert_h2d_copies(
    self,
    graph_module: torch.fx.GraphModule,
    node: torch.fx.Node,
    target_device_type: schema.DeviceType,
    device_index: int,
) -> bool:
    """Insert an ``et_copy._h2d_copy`` node in front of every tensor input
    of the delegate call *node* and rewire the call to consume the copies.

    Args:
        graph_module: Graph being mutated in place.
        node: The ``executorch_call_delegate`` call_function node.
        target_device_type: Device the copies' output specs are tagged with.
        device_index: Index of the target device (e.g. 0 for "cuda:0").

    Returns:
        True if at least one H2D copy node was inserted.
    """
    changed = False
    # Build the replacement argument list up front; node.args is a tuple,
    # so it is edited as a list and reassigned once at the end.
    new_args = list(node.args)
    # args[0] is the get_attr node for the lowered module, so skip it;
    # start=1 keeps `i` aligned with positions in `new_args`.
    for i, arg in enumerate(node.args[1:], start=1):
        if not isinstance(arg, torch.fx.Node):
            continue
        arg_spec = arg.meta.get("spec")
        # Only tensor inputs receive a copy; args without a TensorSpec are
        # passed through unchanged.
        if not isinstance(arg_spec, TensorSpec):
            continue

        with graph_module.graph.inserting_before(node):
            h2d_node = graph_module.graph.call_function(
                torch.ops.et_copy._h2d_copy.default,
                (arg,),
            )
            # The copy's spec mirrors the input but is tagged with the
            # delegate's target device; the original input spec stays CPU.
            h2d_spec = _clone_spec_with_device(
                arg_spec, target_device_type, device_index
            )
            h2d_node.meta["spec"] = h2d_spec
            # NOTE(review): val/tensor_meta are copied verbatim from the CPU
            # input — presumably fake-tensor metadata whose device field is
            # not consulted downstream; confirm whether device matters here.
            h2d_node.meta["val"] = arg.meta.get("val")
            if "tensor_meta" in arg.meta:
                h2d_node.meta["tensor_meta"] = arg.meta["tensor_meta"]
            new_args[i] = h2d_node
            changed = True

    # NOTE(review): graph-level placeholders are not special-cased here, so
    # an H2D copy is inserted even for direct method inputs; confirm whether
    # the skip-h2d-for-method-inputs optimization is handled elsewhere.
    node.args = tuple(new_args)
    return changed

def _insert_d2h_for_getitem(
    self,
    graph_module: torch.fx.GraphModule,
    node: torch.fx.Node,
) -> bool:
    """If *node* is a getitem extracting from a delegate call, tag its spec
    with the delegate device and insert a D2H copy after it.

    Args:
        graph_module: Graph being mutated in place.
        node: A candidate ``operator.getitem`` node.

    Returns:
        True if the getitem spec was retagged and a D2H copy inserted;
        False if *node* is not a well-formed getitem over a delegate
        call's output specs.
    """
    source_node = node.args[0]
    # Only getitems whose source is an executorch_call_delegate call are
    # eligible; everything else is left untouched.
    if not (
        isinstance(source_node, torch.fx.Node)
        and source_node.op == "call_function"
        and source_node.target == executorch_call_delegate
    ):
        return False

    spec = node.meta.get("spec")
    source_specs = source_node.meta.get("spec")
    idx = node.args[1]
    # Bail out unless metadata has exactly the expected shape: a TensorSpec
    # on the getitem and an indexable collection of specs on the delegate,
    # with the getitem index in range.
    if not (
        isinstance(spec, TensorSpec)
        and isinstance(source_specs, (tuple, list))
        and isinstance(idx, int)
        and idx < len(source_specs)
    ):
        return False

    source_spec = source_specs[idx]
    if not isinstance(source_spec, TensorSpec):
        return False

    # The getitem itself lives on the delegate's device (inherited from the
    # matching delegate output spec)...
    _set_device_on_spec(spec, source_spec.device, source_spec.device_index)

    with graph_module.graph.inserting_after(node):
        d2h_node = graph_module.graph.call_function(
            torch.ops.et_copy._d2h_copy.default,
            (node,),
        )
        # ...while the D2H copy hands a CPU tensor to downstream consumers.
        d2h_spec = _clone_spec_with_device(spec, schema.DeviceType.CPU, 0)
        d2h_node.meta["spec"] = d2h_spec
        # NOTE(review): val/tensor_meta are reused from the getitem verbatim;
        # presumably fake-tensor device is not consulted downstream — confirm.
        d2h_node.meta["val"] = node.meta.get("val")
        if "tensor_meta" in node.meta:
            d2h_node.meta["tensor_meta"] = node.meta["tensor_meta"]

    # Redirect every consumer of the getitem to the D2H copy, except the
    # copy node itself. Binding the copy via a default argument avoids the
    # late-binding-closure pitfall inside the lambda.
    node.replace_all_uses_with(
        d2h_node,
        delete_user_cb=lambda user, _d2h=d2h_node: user != _d2h,
    )
    return True

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
# Two-pass approach:
# Pass 1 – For each delegate with a target_device CompileSpec, insert
# H2D copy nodes before delegate inputs and tag the delegate
# output specs with the target device. Delegates without a
# target_device are left untouched (no copies, specs stay CPU).
# Pass 2 – For each getitem that extracts from a device-tagged delegate
# (tracked in device_delegates), propagate the device onto the
# getitem spec and insert a D2H copy after it so downstream
# non-delegated ops receive CPU tensors.
changed = False
for node in graph_module.graph.nodes:
device_delegates: set[torch.fx.Node] = set()

# Pass 1: insert H2D copies and tag delegate output specs.
for node in list(graph_module.graph.nodes):
if node.op == "call_function" and node.target == executorch_call_delegate:
lowered_module = _get_lowered_module(graph_module, node)
if lowered_module is None:
Expand All @@ -151,18 +283,12 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
continue

target_device_type, device_index = result
device_delegates.add(node)

changed |= self._insert_h2d_copies(
graph_module, node, target_device_type, device_index
)

# Tag delegate input tensors.
# args[0] is the get_attr node for the lowered module; skip it.
for arg in node.args[1:]:
if isinstance(arg, torch.fx.Node):
changed |= _tag_specs_with_device(
arg.meta.get("spec"),
target_device_type,
device_index,
)

# Tag delegate output tensors.
changed |= _tag_specs_with_device(
node.meta.get("spec"),
target_device_type,
Expand All @@ -177,34 +303,13 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
lowered_module.backend_id,
)

# Second pass: propagate device through getitem nodes that extract
# individual outputs from a delegate call.
for node in graph_module.graph.nodes:
if node.op == "call_function" and node.target.__name__ == "getitem":
source_node = node.args[0]
if (
isinstance(source_node, torch.fx.Node)
and source_node.op == "call_function"
and source_node.target == executorch_call_delegate
):
spec = node.meta.get("spec")
source_specs = source_node.meta.get("spec")
idx = node.args[1]
if (
spec is not None
and isinstance(spec, TensorSpec)
and source_specs is not None
and isinstance(source_specs, (tuple, list))
and isinstance(idx, int)
and idx < len(source_specs)
):
source_spec = source_specs[idx]
if isinstance(source_spec, TensorSpec):
_set_device_on_spec(
spec,
source_spec.device,
source_spec.device_index,
)
changed = True
# Second pass: propagate device through getitem nodes and insert D2H
# only for delegates that have a target_device.
for node in list(graph_module.graph.nodes):
if node.op == "call_function" and node.target == operator.getitem:
source = node.args[0]
if isinstance(source, torch.fx.Node) and source in device_delegates:
changed |= self._insert_d2h_for_getitem(graph_module, node)

graph_module.recompile()
return PassResult(graph_module, changed)
1 change: 1 addition & 0 deletions exir/tests/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,7 @@ python_unittest(
"//executorch/exir/backend/test:backend_with_compiler_demo",
"//executorch/exir/dialects:lib",
"//executorch/exir/passes:propagate_device_pass",
"//executorch/exir/passes:device_copy_ops_registry",
],
)

Expand Down
Loading
Loading