Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions backends/aoti/passes/TARGETS
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

oncall("executorch")

runtime.python_library(
name = "passes",
srcs = [
"replace_view_copy_with_view.py",
],
visibility = [
"//executorch/...",
],
deps = [
"//caffe2:torch",
"//executorch/exir:pass_base",
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
# This pass replaces view_copy ops with view ops. This is different than
# exir/passes/replace_view_copy_with_view.py and exir/passes/reinplace.py
# because this should only be used in the AOTInductor backend, as it
# has less restrictions on whether the tensor memory is densely packed,

from typing import Dict, Iterable, Tuple
from typing import Dict, Iterable

import torch
from executorch.exir.dialects._ops import ops
Expand All @@ -15,33 +18,30 @@
from torch import fx


_SLICE_COPY_TARGETS: Tuple[torch._ops.OpOverload | EdgeOpOverload] = (
torch.ops.aten.slice_copy.Tensor,
ops.edge.aten.slice_copy.Tensor,
)

_SLICE_TARGETS: Dict[
_VIEW_TARGETS: Dict[
torch._ops.OpOverload | EdgeOpOverload, torch._ops.OpOverload | EdgeOpOverload
] = {
torch.ops.aten.slice_copy.Tensor: torch.ops.aten.slice.Tensor,
ops.edge.aten.slice_copy.Tensor: ops.edge.aten.slice.Tensor,
torch.ops.aten.select_copy.int: torch.ops.aten.select.int,
ops.edge.aten.select_copy.int: ops.edge.aten.select.int,
}


class ReplaceSliceCopyWithSlicePass(ExportPass):
"""Replace non-mutated ``slice_copy`` results with ``slice`` views."""
class ReplaceViewCopyWithViewPass(ExportPass):
"""Replace non-mutated ``view_copy`` type of ops with ``view`` ops."""

def call(self, graph_module: fx.GraphModule) -> PassResult:
graph_changed = False

for node in graph_module.graph.nodes:
if node.op != "call_function" or node.target not in _SLICE_COPY_TARGETS:
if node.op != "call_function" or node.target not in _VIEW_TARGETS:
continue

if self._has_blocking_user(node, node.users.keys()):
continue

node.target = _SLICE_TARGETS[node.target]
node.target = _VIEW_TARGETS[node.target]
graph_changed = True

if graph_changed:
Expand Down
6 changes: 3 additions & 3 deletions backends/apple/metal/metal_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from typing import Any, Dict, final, List, Optional, Set

import torch
from executorch.backends.apple.metal.replace_slice_copy_with_slice import (
ReplaceSliceCopyWithSlicePass,
from executorch.backends.aoti.passes.replace_view_copy_with_view import (
ReplaceViewCopyWithViewPass,
)
from executorch.exir._serialize._named_data_store import NamedDataStore
from executorch.exir._warnings import experimental
Expand Down Expand Up @@ -93,7 +93,7 @@ def preprocess(
mps_edge_program = move_to_device_pass(edge_program, "mps")

# replace slice_copy with slice
ReplaceSliceCopyWithSlicePass()(mps_edge_program.graph_module)
ReplaceViewCopyWithViewPass()(mps_edge_program.graph_module)

edge_program_module = mps_edge_program.module()

Expand Down
2 changes: 1 addition & 1 deletion backends/cuda/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@ runtime.python_library(
name = "cuda_backend",
srcs = [
"cuda_backend.py",
"replace_slice_copy_with_slice.py",
],
visibility = [
"//executorch/...",
],
deps = [
"//caffe2:torch",
"//executorch/backends/aoti/passes:passes",
"//executorch/exir/_serialize:lib",
"//executorch/exir/backend:backend_details",
"//executorch/exir/backend:compile_spec_schema",
Expand Down
8 changes: 4 additions & 4 deletions backends/cuda/cuda_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from typing import Any, Dict, final, List, Optional, Set

import torch
from executorch.backends.cuda.replace_slice_copy_with_slice import (
ReplaceSliceCopyWithSlicePass,
from executorch.backends.aoti.passes.replace_view_copy_with_view import (
ReplaceViewCopyWithViewPass,
)
from executorch.exir._serialize._named_data_store import NamedDataStore
from executorch.exir._warnings import experimental
Expand Down Expand Up @@ -123,8 +123,8 @@ def preprocess(
# Move the edge_program from CPU to CUDA for aoti compile
cuda_edge_program = move_to_device_pass(edge_program, "cuda")

# replace slice_copy with slice
ReplaceSliceCopyWithSlicePass()(cuda_edge_program.graph_module)
# replace slice_copy.Tensor with slice.Tensor, select_copy.int with select.int
ReplaceViewCopyWithViewPass()(cuda_edge_program.graph_module)

cuda_edge_program = cuda_edge_program.run_decompositions(
cuda_decomposition_table
Expand Down
118 changes: 0 additions & 118 deletions backends/cuda/replace_slice_copy_with_slice.py

This file was deleted.

Loading