Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ class InsertInt32CastsAfterInt64PlaceholdersPass(ArmPass):
# Key: op overload; Value: zero-based indices of positional args that must be i64.
I64_INPUT_ARG_POSITIONS = {
torch.ops.aten.one_hot.default: (0,),
torch.ops.aten.index_copy_.default: (2,),
torch.ops.aten.index_copy.default: (2,),
}

def _insert_callsite_i32_to_i64_casts(self, graph_module: torch.fx.GraphModule):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,13 @@
import torch
from executorch.backends.arm._passes import InsertInt32CastsAfterInt64PlaceholdersPass

from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
from executorch.backends.arm.test.tester.test_pipeline import (
PassPipeline,
TosaPipelineINT,
)

input_t = Tuple[torch.Tensor, torch.Tensor] # weights, indices
input_t3 = Tuple[torch.Tensor, torch.LongTensor, torch.Tensor]


class Int64InputModel(torch.nn.Module):
Expand Down Expand Up @@ -44,3 +48,67 @@ def test_int64_model_tosa_FP():
)
pipeline.pop_stage(-1) # Do not compare output
pipeline.run()


class UpcastToInt64ForIndexCopyInplaceModel(torch.nn.Module):
aten_op = "torch.ops.aten.index_copy_.default"

def forward(self, x: torch.Tensor, index: torch.LongTensor, y: torch.tensor):
return x.index_copy_(0, index, y)

def get_inputs(self) -> input_t3:
return (
torch.zeros(5, 3),
torch.tensor([0, 4, 2]),
torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float),
)


def test_upcast_to_int64_for_index_copy_inplace_tosa_INT():
module = UpcastToInt64ForIndexCopyInplaceModel()
pipeline = TosaPipelineINT[input_t3](
module,
module.get_inputs(),
aten_op=module.aten_op,
)
pipeline.pop_stage("check.quant_nodes")
pipeline.change_args(
"check_count.exir",
{
"torch.ops.higher_order.executorch_call_delegate": 0,
},
)
pipeline.pop_stage("run_method_and_compare_outputs")
pipeline.run()


class UpcastToInt64ForIndexCopyModel(torch.nn.Module):
aten_op = "torch.ops.aten.index_copy.default"

def forward(self, x: torch.Tensor, index: torch.LongTensor, y: torch.tensor):
return x.index_copy(0, index, y)

def get_inputs(self) -> input_t3:
return (
torch.zeros(5, 3),
torch.tensor([0, 4, 2]),
torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float),
)


def test_upcast_to_int64_for_index_copy_tosa_INT():
module = UpcastToInt64ForIndexCopyModel()
pipeline = TosaPipelineINT[input_t3](
module,
module.get_inputs(),
aten_op=module.aten_op,
)
pipeline.pop_stage("check.quant_nodes")
pipeline.change_args(
"check_count.exir",
{
"torch.ops.higher_order.executorch_call_delegate": 0,
},
)
pipeline.pop_stage("run_method_and_compare_outputs")
pipeline.run()
Loading