Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,10 @@ def _insert_placeholder_i64_to_i32_casts(self, graph_module: torch.fx.GraphModul
modified = False
graph = graph_module.graph
for node in graph.nodes:
if node.op != "placeholder":
if node.op not in ("placeholder", "get_attr"):
continue
if "val" not in node.meta:
continue # Ignore submodule get_attrs
node_val = node.meta["val"]
if not self._is_tensor_of_dtype(node_val, torch.int64):
continue
Expand Down
18 changes: 16 additions & 2 deletions backends/arm/test/ops/test_clamp.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,9 @@ def test_clamp_vgf_quant(test_data):
pipeline.run()


aten_op_tensor = "torch.ops.aten.clamp.Tensor"
aten_op_tensor = [
"torch.ops.aten.clamp.Tensor",
]
exir_op_tensor = "executorch_exir_dialects_edge__ops_aten_clamp_Tensor"

test_data_suite_tensor_FP = {
Expand Down Expand Up @@ -413,10 +415,22 @@ def test_clamp_tosa_INT_tensor(test_data):
input_tensor, min_val, max_val = test_data()
model = Clamp(min_val, max_val)

# Check that int64 inputs are cast to int32 in the TOSA pipeline
if any(
t.dtype == torch.int64
for t in (input_tensor, min_val, max_val)
if isinstance(t, torch.Tensor)
):
aten_op = aten_op_tensor + [
"torch.ops.dim_order_ops._to_dim_order_copy.default"
]
else:
aten_op = aten_op_tensor

pipeline = TosaPipelineINT[input_t](
model,
(input_tensor,),
aten_op_tensor,
aten_op,
exir_op_tensor,
)
pipeline.run()
Expand Down
Loading