Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 30 additions & 80 deletions backends/cadence/aot/tests/test_simplify_ops_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

import executorch.backends.cadence.aot.ops_registrations # noqa
import torch
from executorch.backends.cadence.aot.compiler import export_to_edge
from executorch.backends.cadence.aot.graph_builder import single_op_builder
from executorch.backends.cadence.aot.pass_utils import count_node
from executorch.backends.cadence.aot.simplify_ops import (
Expand Down Expand Up @@ -40,82 +39,47 @@ def test_simplify_slice_scatter_op(
end: Optional[int] = None,
step: int = 1,
):
class SliceScatter(torch.nn.Module):
def __init__(
self, dim: int, start: Optional[int], end: Optional[int], step: int
):
super().__init__()
self.dim = dim
self.start = start
self.end = end
self.step = step

def forward(self, x: torch.Tensor, y: torch.Tensor):
return torch.slice_scatter(
x, y, self.dim, self.start, self.end, self.step
)

model = SliceScatter(dim, start, end, step)
x = torch.randn(in_shape)
y = torch.randn(src_shape)
graph_module = export_to_edge(model, (x, y)).exported_program().graph_module

p = SimplifySliceOpPass()

graph_after_passes = cast(PassResult, p(graph_module)).graph_module

self.assertEqual(
count_node(graph_after_passes, exir_ops.edge.aten.slice_scatter.default), 0
x = torch.randn(*in_shape)
y = torch.randn(*src_shape)
gm = single_op_builder(
placeholders=(x, y),
op=exir_ops.edge.aten.slice_scatter.default,
args=(x, y, dim, start, end, step),
)
p = SimplifySliceOpPass()
gm = cast(PassResult, p(gm)).graph_module
self.assertEqual(count_node(gm, exir_ops.edge.aten.slice_scatter.default), 0)

@parameterized.expand(
[
[(3, 16, 5), (3, 0, 5), 1, 15, 3, 3],
[(3, 16, 5), 1, 15, 3, 3],
]
)
@torch.no_grad()
def test_simplify_slice_op(
self,
in_shape: Tuple[int],
src_shape: Tuple[int],
dim: int,
start: Optional[int] = None,
end: Optional[int] = None,
step: int = 1,
):
class SliceCopy(torch.nn.Module):
def __init__(
self, dim: int, start: Optional[int], end: Optional[int], step: int
):
super().__init__()
self.dim = dim
self.start = start
self.end = end
self.step = step

def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.slice_copy(
x, dim=self.dim, start=self.start, end=self.end, step=self.step
)

# Create a model with single slice copy op.
model = SliceCopy(dim, start, end, step)
x = torch.randn(in_shape)
graph_module = export_to_edge(model, (x,)).exported_program().graph_module
self.assertEqual(
count_node(graph_module, exir_ops.edge.aten.slice_copy.Tensor), 1
x = torch.randn(*in_shape)
gm = single_op_builder(
placeholders=(x,),
op=exir_ops.edge.aten.slice_copy.Tensor,
args=(
x,
dim,
start,
end,
step,
),
)

p = SimplifySliceOpPass()

graph_after_passes = cast(PassResult, p(graph_module)).graph_module

self.assertEqual(
count_node(graph_after_passes, exir_ops.edge.aten.slice_copy.Tensor), 0
)
self.assertEqual(
count_node(graph_after_passes, exir_ops.edge.aten.full.default), 1
)
gm = cast(PassResult, p(gm)).graph_module
self.assertEqual(count_node(gm, exir_ops.edge.aten.slice_copy.Tensor), 0)
self.assertEqual(count_node(gm, exir_ops.edge.aten.full.default), 1)

def test_simplify_slice_op_args(self) -> None:
x = torch.rand(4, 5)
Expand All @@ -125,24 +89,10 @@ def test_simplify_slice_op_args(self) -> None:
args=(x, 1),
kwargs={"end": 3},
)
self.assertEqual(
[
(n.args[1:], n.kwargs)
for n in gm.graph.find_nodes(
op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
)
],
[((1,), {"end": 3})],
)

original_slice_copy = list(gm.graph.nodes)[1]
self.assertEqual(original_slice_copy.args[1:], (1,))
self.assertEqual(original_slice_copy.kwargs, {"end": 3})
gm = BindOptionalArgsPass().call(gm).graph_module

self.assertEqual(
[
(n.args[1:], n.kwargs)
for n in gm.graph.find_nodes(
op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
)
],
[((1, None, 3, 1), {})],
)
modified_slice_copy = list(gm.graph.nodes)[1]
self.assertEqual(modified_slice_copy.args[1:], (1, None, 3, 1))
self.assertEqual(modified_slice_copy.kwargs, {})
Loading