From 6c26ea4c06ea990eaece341415df609fb51d1b96 Mon Sep 17 00:00:00 2001
From: Yufeng Shi
Date: Wed, 21 May 2025 17:05:33 +0100
Subject: [PATCH] Arm backend: Add passes to handle int64 const and int64
 output ops

- Add ConvertInt64ConstOpsToInt32Pass to convert constant-producing ops
  that output int64 to instead output int32, when values are within int32
  bounds.
  Supported Ops:
  `torch.full`, `torch.arange`, `torch.eye`, `torch.linspace`, `torch.tensor`
- Add ConvertInt64OutputOpsToInt32Pass to
  1. convert or remove unnecessary casts to int64
  2. insert an int64->int32 cast after the argmax nodes that produce
     int64 outputs

Signed-off-by: Yufeng Shi
Change-Id: I04e5fa9a7170c5b5dc785ae8619189545de0ec2c
Co-authored-by: Erik Lundell
---
 backends/arm/README.md                        |  34 +-
 backends/arm/_passes/__init__.py              |   2 +
 backends/arm/_passes/arm_pass_manager.py      |   8 +
 .../convert_int64_const_ops_to_int32.py       |  74 +++
 .../convert_int64_output_ops_to_int32.py      | 153 ++++++
 .../test_CLIPTextModelWithProjection.py       |  22 +-
 .../stable_diffusion/test_T5EncoderModel.py   |  26 +-
 .../test_vae_AutoencoderKL.py                 |   2 +-
 backends/arm/test/models/test_conformer.py    |   4 +-
 .../test_convert_int64_const_ops_to_int32.py  | 511 ++++++++++++++++++
 .../test_convert_int64_output_ops_to_int32.py | 131 +++++
 11 files changed, 940 insertions(+), 27 deletions(-)
 create mode 100644 backends/arm/_passes/convert_int64_const_ops_to_int32.py
 create mode 100644 backends/arm/_passes/convert_int64_output_ops_to_int32.py
 create mode 100644 backends/arm/test/passes/test_convert_int64_const_ops_to_int32.py
 create mode 100644 backends/arm/test/passes/test_convert_int64_output_ops_to_int32.py

diff --git a/backends/arm/README.md b/backends/arm/README.md
index e2e49c0c10f..f2cc365ab62 100644
--- a/backends/arm/README.md
+++ b/backends/arm/README.md
@@ -195,6 +195,38 @@ List of model specific and optional passes:
 - InsertCastForOpsWithInt64InputPass
   - Functionality:
     - For LLMs such as LLama, some operators like aten.embedding have int64 input. In order to lower these operators to TOSA, this pass will insert a casting node that converts the input from int64 to int32.
-  - Example usage: backends/arm/test/models/test_llama.py
   - Supported Ops:
     - aten.embedding.default, aten.slice_copy.Tensor
+  - Example usage:
+    - backends/arm/test/models/test_llama.py
+
+- ConvertInt64ConstOpsToInt32Pass
+  - Functionalities:
+    - Rewrites constant-producing ops that output int64 to instead output int32, when values are within int32 bounds.
+  - Supported Ops:
+    - `torch.full`, `torch.arange`, `torch.eye`, `torch.linspace`, `torch.tensor`
+  - Example usage:
+    - backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py
+    - backends/arm/test/models/stable_diffusion/test_T5EncoderModel.py
+
+- ConvertInt64OutputOpsToInt32Pass
+  - Overview:
+    - Rewrites or removes operations that produce int64 outputs, converting them to int32 where possible.
+    - Overflow checks are applied selectively; for ops without such checks, users need to ensure values fit within the int32 range.
+  - Functionalities:
+    1. Handling casts to int64:
+       - (1) int32 -> int64:
+         - Removes the cast and redirects all uses of the int64 output to the original int32 value
+       - (2) other types -> int64:
+         - Rewrites the cast to produce int32 instead of int64
+       - Supported Ops:
+         - torch.ops.aten.to.\[dtype|dtype_layout\]
+         - exir_ops.edge.dim_order_ops._to_dim_order_copy.default
+    2. Post-process argmax outputs:
+       - Inserts an int64->int32 cast after argmax operations that produce int64 outputs
+       - Supported Ops:
+         - torch.ops.aten.argmax.default
+         - exir_ops.edge.aten.argmax.default
+  - Example usage:
+    - (Functionality 1) backends/arm/test/models/stable_diffusion/test_T5EncoderModel.py
+    - (Functionality 2) backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py
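
In practice the two passes are handed to the tester alongside `InsertCastForOpsWithInt64InputPass`, as the Stable Diffusion tests later in this patch do. A minimal sketch of that wiring follows; the model and the `ArmTester` import path are illustrative assumptions, the rest mirrors the test files below:

```python
import torch

from executorch.backends.arm._passes import (
    ConvertInt64ConstOpsToInt32Pass,
    ConvertInt64OutputOpsToInt32Pass,
    InsertCastForOpsWithInt64InputPass,
)
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester  # assumed path


class TinyInt64Model(torch.nn.Module):
    """Hypothetical model that produces an int64 value (argmax) internally."""

    def forward(self, x: torch.Tensor):
        return torch.argmax(x, dim=-1) + 1


(
    ArmTester(
        TinyInt64Model(),
        example_inputs=(torch.randn(1, 10),),
        compile_spec=common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+FP"),
        transform_passes=[
            InsertCastForOpsWithInt64InputPass(),
            ConvertInt64ConstOpsToInt32Pass(),
            ConvertInt64OutputOpsToInt32Pass(),
        ],
    )
    .export()
    .to_edge_transform_and_lower()
)
```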
diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py
index 4d2449f946c..a728f894ee5 100644
--- a/backends/arm/_passes/__init__.py
+++ b/backends/arm/_passes/__init__.py
@@ -16,6 +16,8 @@
 from .convert_any_default_dim_dims_pass import ConvertAnyDefaultDimDimsPass  # noqa
 from .convert_expand_copy_to_repeat import ConvertExpandCopyToRepeatPass  # noqa
 from .convert_full_like_to_full_pass import ConvertFullLikeToFullPass  # noqa
+from .convert_int64_const_ops_to_int32 import ConvertInt64ConstOpsToInt32Pass  # noqa
+from .convert_int64_output_ops_to_int32 import ConvertInt64OutputOpsToInt32Pass  # noqa
 from .convert_int_pow_to_mul import ConvertIntPowToMuls  # noqa
 from .convert_minmax_pass import ConvertMinMaxPass  # noqa
 from .convert_split_to_slice import ConvertSplitToSlicePass  # noqa
diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py
index 7592be1d7da..086ab1435c0 100644
--- a/backends/arm/_passes/arm_pass_manager.py
+++ b/backends/arm/_passes/arm_pass_manager.py
@@ -20,6 +20,8 @@
     ConvertAnyDefaultDimDimsPass,
     ConvertExpandCopyToRepeatPass,
     ConvertFullLikeToFullPass,
+    ConvertInt64ConstOpsToInt32Pass,
+    ConvertInt64OutputOpsToInt32Pass,
     ConvertIntPowToMuls,
     ConvertMinMaxPass,
     ConvertMmToBmmPass,
@@ -98,6 +100,7 @@
 from executorch.backends.transforms.remove_getitem_op import RemoveGetItemPass
 from executorch.exir import ExportedProgram
 from executorch.exir.pass_manager import PassManager
+from executorch.exir.passes.remove_graph_asserts_pass import RemoveGraphAssertsPass
 from torch.fx import GraphModule
@@ -258,6 +261,11 @@ def transform_to_backend_pipeline(self, exported_program: ExportedProgram):
         )

     def transform_for_annotation_pipeline(self, graph_module: GraphModule):
+        self.add_pass(
+            RemoveGraphAssertsPass()
+        )  # ConvertInt64ConstOpsToInt32Pass requires this pass to remove the assert nodes from the graph
+        self.add_pass(ConvertInt64ConstOpsToInt32Pass())
+        self.add_pass(ConvertInt64OutputOpsToInt32Pass())
         self.add_pass(InsertCastForOpsWithInt64InputPass())
         self.add_pass(DecomposeEmbeddingPass())
         self.add_pass(DecomposeScaledDotProductAttention())
diff --git a/backends/arm/_passes/convert_int64_const_ops_to_int32.py b/backends/arm/_passes/convert_int64_const_ops_to_int32.py
new file mode 100644
index 00000000000..704c89dbd78
--- /dev/null
+++ b/backends/arm/_passes/convert_int64_const_ops_to_int32.py
@@ -0,0 +1,74 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+
+import logging
+
+import torch
+from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT
+from executorch.exir.pass_base import ExportPass, PassResult
+
+
+logger = logging.getLogger(__name__)
+INT32_MIN = torch.iinfo(torch.int32).min
+INT32_MAX = torch.iinfo(torch.int32).max
+
+
+class ConvertInt64ConstOpsToInt32Pass(ExportPass):
+    """
+    Rewrite constant ops that produce int64 to int32 where safe.
+
+    List of supported operators:
+    1. `torch.full`
+    2. `torch.arange`
+    3. `torch.eye`
+    4. `torch.linspace`
+    5. `torch.tensor`
+    """
+
+    torch_ops = [
+        torch.ops.aten.full.default,
+        torch.ops.aten.arange.default,
+        torch.ops.aten.arange.start,
+        torch.ops.aten.arange.start_step,
+        torch.ops.aten.eye.default,
+        torch.ops.aten.linspace.default,
+    ]
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        modified = False
+        for node in graph_module.graph.nodes:
+            if node.op != "call_function":
+                continue
+
+            if node.target not in ComputeConstantOpsAOT.targeted_ops + self.torch_ops:
+                continue
+
+            data = node.target(*node.args, **node.kwargs)
+            if data.dtype is not torch.int64:
+                continue
+
+            min_val, max_val = torch.min(data), torch.max(data)
+            if INT32_MIN <= min_val and max_val <= INT32_MAX:
+                logger.warning(
+                    f"Casting {node.name} from torch.int64 to torch.int32"
+                    f" defined in {node.meta.get('stack_trace','[no stack trace found]')}"
+                )
+                node.update_kwarg("dtype", torch.int32)
+                modified = True
+            else:
+                logger.warning(
+                    f"[{node.name}] has values: min={min_val}, max={max_val}, which exceed the int32 range "
+                    f"([{INT32_MIN}, {INT32_MAX}]); not converting dtype to int32."
+                )
+
+        if modified:
+            graph_module.graph.eliminate_dead_code()
+            graph_module.recompile()
+            graph_module = super().call(graph_module).graph_module
+
+        return PassResult(graph_module, modified)
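
To see the rewrite in isolation, a minimal sketch is shown below; the driver module and the direct pass invocation are illustrative assumptions, while the pass import comes from this patch:

```python
import torch

from executorch.backends.arm._passes import ConvertInt64ConstOpsToInt32Pass


class Offsets(torch.nn.Module):
    def forward(self, x: torch.Tensor):
        # arange is a constant-producing op; its int64 values all fit in int32.
        return torch.arange(10, dtype=torch.int64) + x


ep = torch.export.export(Offsets(), (torch.zeros(10, dtype=torch.int32),))
result = ConvertInt64ConstOpsToInt32Pass()(ep.graph_module)
# The arange node's dtype kwarg should now be torch.int32, so the add and
# everything downstream stays in int32 and can be lowered to TOSA.
print(result.graph_module.graph)
```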
+ """ + + aten_cast_ops = ( + torch.ops.aten.to.dtype, + torch.ops.aten.to.dtype_layout, + ) + edge_cast_ops = (exir_ops.edge.dim_order_ops._to_dim_order_copy.default,) + + aten_argmax_ops = (torch.ops.aten.argmax.default,) + edge_argmax_ops = (exir_ops.edge.aten.argmax.default,) + + aten_ops = aten_cast_ops + aten_argmax_ops + edge_ops = edge_cast_ops + edge_argmax_ops + + # dtype is specified in args + cast_ops_args = ( + torch.ops.aten.to.dtype, # to_2: node.args: (gt, torch.int64) node.kwargs: {} + ) + # dtype is specified in kwargs + cast_ops_kwargs = ( + torch.ops.aten.to.dtype_layout, # to_1: node.args: (unsqueeze,) node.kwargs: {'dtype': torch.int64, 'layout': torch.strided, 'device': device(type='cpu')} + exir_ops.edge.dim_order_ops._to_dim_order_copy.default, # node.args: (aten_gt_scalar,) node.kwargs: {'dtype': torch.int64, 'dim_order': [0, 1]} + ) + + def _get_decomposition(self, op): + if op in self.edge_ops: + return exir_ops.edge.aten._to_copy.default + + if op in self.aten_ops: + return torch.ops.aten._to_copy.default + + raise RuntimeError( + f"[{self.__class__.__name__}] Can't get decomposition for op {op}" + ) + + def _convert_casting_operators(self, node: torch.fx.Node): + input_node = node.all_input_nodes[0] + input_dtype = get_first_fake_tensor(input_node).dtype + # Case 1: int32 -> int64 - removes the ops + if input_dtype == torch.int32: + users = [user for user in node.users if node != user] + for user in users: + logger.warning( + f"Removing int32->int64 casting node {node.name} defined in" + f" {node.meta.get('stack_trace','[no stack trace found]')}" + ) + user.replace_input_with(node, input_node) + # Case 2: other types -> int64 - rewrites to cast to int32 + else: + if node.target in self.cast_ops_kwargs: + set_node_arg(node, "dtype", torch.int32) + elif node.target in self.cast_ops_args: + set_node_arg(node, 1, torch.int32) + else: + raise RuntimeError(f"Unexpected target {node.target} in {node.name}") + output_dtype = get_first_fake_tensor(node).dtype + logger.warning( + f"Converting casting node {node.name} from {input_dtype}->{output_dtype} to" + f" {input_dtype}->torch.int32 defined in {node.meta.get('stack_trace','[no stack trace found]')}" + ) + + def _convert_argmax_operators(self, node: torch.fx.Node, graph: torch.fx.Graph): + output_tensor = node + to_copy_op = self._get_decomposition(node.target) + with graph.inserting_after(node): + cast_after = create_node( + graph, + to_copy_op, + args=(output_tensor,), + kwargs={ + "dtype": torch.int32, + }, + ) + users = [user for user in node.users if user != cast_after] + for user in users: + user.replace_input_with(output_tensor, cast_after) + logger.warning( + f"Inserting a casting node {cast_after.name} after {node.name} to cast int64 output" + f" to int32 for {node.name} defined in {node.meta.get('stack_trace','[no stack trace found]')}" + ) + + def call(self, graph_module: torch.fx.GraphModule): + modified = False + graph = graph_module.graph + for node in list(graph.nodes): + if node.op != "call_function": + continue + if node.target not in self.aten_ops + self.edge_ops: + continue + output_dtype = get_first_fake_tensor(node).dtype + if output_dtype != torch.int64: + continue + + if node.target in self.aten_cast_ops + self.edge_cast_ops: + self._convert_casting_operators(node) + elif node.target in self.aten_argmax_ops + self.edge_argmax_ops: + # TODO: Add range check based on the input tensor shape before casting the output + self._convert_argmax_operators(node, graph) + else: + raise 
diff --git a/backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py b/backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py
index 9561e2132ee..f89e06deda0 100644
--- a/backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py
+++ b/backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py
@@ -7,7 +7,11 @@
 import unittest

 import torch
-from executorch.backends.arm._passes import InsertCastForOpsWithInt64InputPass
+from executorch.backends.arm._passes import (
+    ConvertInt64ConstOpsToInt32Pass,
+    ConvertInt64OutputOpsToInt32Pass,
+    InsertCastForOpsWithInt64InputPass,
+)

 from executorch.backends.arm.test import common
 from executorch.backends.arm.test.models.stable_diffusion.stable_diffusion_module_test_configs import (
@@ -28,13 +32,11 @@ class TestCLIPTextModelWithProjection(unittest.TestCase):
     # for that is some assert ops are removed by passes in the
     # .to_executorch step, i.e. after Arm partitioner.
     ops_after_partitioner = {
-        "executorch_exir_dialects_edge__ops_aten__to_copy_default": 3,
+        "executorch_exir_dialects_edge__ops_aten__to_copy_default": 4,
         "executorch_exir_dialects_edge__ops_aten_argmax_default": 1,
-        "executorch_exir_dialects_edge__ops_aten_index_Tensor": 1,
-        "executorch_exir_dialects_edge__ops_aten_lt_Tensor": 1,
-        "executorch_exir_dialects_edge__ops_aten_view_copy_default": 2,
+        "executorch_exir_dialects_edge__ops_aten_view_copy_default": 1,
         "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 1,
-        "torch.ops.higher_order.executorch_call_delegate": 3,
+        "torch.ops.higher_order.executorch_call_delegate": 2,
     }

     def _prepare_inputs(
@@ -60,7 +62,7 @@ def prepare_model_and_inputs(self):

         return text_encoder_model, text_encoder_model_inputs

-    def test_CLIPTextModelWithProjection_tosa_MI(self):
+    def test_CLIPTextModelWithProjection_tosa_FP(self):
         text_encoder_model, text_encoder_model_inputs = self.prepare_model_and_inputs()
         with torch.no_grad():
             (
@@ -68,7 +70,11 @@
                 text_encoder_model,
                 example_inputs=text_encoder_model_inputs,
                 compile_spec=common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+FP"),
-                transform_passes=[InsertCastForOpsWithInt64InputPass()],
+                transform_passes=[
+                    InsertCastForOpsWithInt64InputPass(),
+                    ConvertInt64ConstOpsToInt32Pass(),
+                    ConvertInt64OutputOpsToInt32Pass(),
+                ],
             )
             .export()
             .to_edge_transform_and_lower()
diff --git a/backends/arm/test/models/stable_diffusion/test_T5EncoderModel.py b/backends/arm/test/models/stable_diffusion/test_T5EncoderModel.py
index 0567d32eebb..0628f010f08 100644
--- a/backends/arm/test/models/stable_diffusion/test_T5EncoderModel.py
+++ b/backends/arm/test/models/stable_diffusion/test_T5EncoderModel.py
@@ -7,7 +7,11 @@
 import unittest

 import torch
-from executorch.backends.arm._passes import InsertCastForOpsWithInt64InputPass
+from executorch.backends.arm._passes import (
+    ConvertInt64ConstOpsToInt32Pass,
+    ConvertInt64OutputOpsToInt32Pass,
+    InsertCastForOpsWithInt64InputPass,
+)

 from executorch.backends.arm.test import common
 from executorch.backends.arm.test.models.stable_diffusion.stable_diffusion_module_test_configs import (
@@ -29,19 +33,7 @@ class TestT5EncoderModel(unittest.TestCase):
    # 
.to_executorch step, i.e. after Arm partitioner. ops_after_partitioner = { "executorch_exir_dialects_edge__ops_aten__to_copy_default": 2, - "executorch_exir_dialects_edge__ops_aten_abs_default": 1, - "executorch_exir_dialects_edge__ops_aten_add_Tensor": 3, - "executorch_exir_dialects_edge__ops_aten_arange_start_step": 2, - "executorch_exir_dialects_edge__ops_aten_full_like_default": 1, - "executorch_exir_dialects_edge__ops_aten_gt_Scalar": 1, - "executorch_exir_dialects_edge__ops_aten_lt_Scalar": 1, - "executorch_exir_dialects_edge__ops_aten_minimum_default": 1, - "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 1, - "executorch_exir_dialects_edge__ops_aten_sub_Tensor": 1, - "executorch_exir_dialects_edge__ops_aten_unsqueeze_copy_default": 2, "executorch_exir_dialects_edge__ops_aten_view_copy_default": 1, - "executorch_exir_dialects_edge__ops_aten_where_self": 1, - "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 3, "torch.ops.higher_order.executorch_call_delegate": 2, } @@ -68,7 +60,7 @@ def prepare_model_and_inputs(self): return t5_encoder_model, t5_encoder_model_inputs - def test_T5EncoderModel_tosa_MI(self): + def test_T5EncoderModel_tosa_FP(self): t5_encoder_model, t5_encoder_model_inputs = self.prepare_model_and_inputs() with torch.no_grad(): ( @@ -76,7 +68,11 @@ def test_T5EncoderModel_tosa_MI(self): t5_encoder_model, example_inputs=t5_encoder_model_inputs, compile_spec=common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+FP"), - transform_passes=[InsertCastForOpsWithInt64InputPass()], + transform_passes=[ + InsertCastForOpsWithInt64InputPass(), + ConvertInt64ConstOpsToInt32Pass(), + ConvertInt64OutputOpsToInt32Pass(), + ], ) .export() .to_edge_transform_and_lower() diff --git a/backends/arm/test/models/stable_diffusion/test_vae_AutoencoderKL.py b/backends/arm/test/models/stable_diffusion/test_vae_AutoencoderKL.py index cab4ca53d9c..ab0f4892fb8 100644 --- a/backends/arm/test/models/stable_diffusion/test_vae_AutoencoderKL.py +++ b/backends/arm/test/models/stable_diffusion/test_vae_AutoencoderKL.py @@ -41,7 +41,7 @@ def forward(self, *args, **kwargs): return auto_encoder_model, auto_encoder_model_inputs - def test_AutoencoderKL_tosa_MI(self): + def test_AutoencoderKL_tosa_FP(self): auto_encoder_model, auto_encoder_model_inputs = self.prepare_model_and_inputs() with torch.no_grad(): ( diff --git a/backends/arm/test/models/test_conformer.py b/backends/arm/test/models/test_conformer.py index 6a66b25d27d..3119145aef1 100644 --- a/backends/arm/test/models/test_conformer.py +++ b/backends/arm/test/models/test_conformer.py @@ -65,7 +65,7 @@ def test_conformer_tosa_INT(): pipeline = TosaPipelineINT[input_t]( TestConformer.conformer, TestConformer.model_example_inputs, - aten_op=TestConformer.aten_ops, + aten_op=[], # RemoveGraphAssertsPass is added in transform_for_annotation_pipeline to remove the assert ops exir_op=[], use_to_edge_transform_and_lower=True, ) @@ -132,7 +132,7 @@ def test_conformer_vgf_INT(): pipeline = VgfPipeline[input_t]( TestConformer.conformer, TestConformer.model_example_inputs, - aten_op=TestConformer.aten_ops, + aten_op=[], # RemoveGraphAssertsPass is added in transform_for_annotation_pipeline to remove the assert ops exir_op=[], tosa_version="TOSA-1.0+INT", use_to_edge_transform_and_lower=True, diff --git a/backends/arm/test/passes/test_convert_int64_const_ops_to_int32.py b/backends/arm/test/passes/test_convert_int64_const_ops_to_int32.py new file mode 100644 index 00000000000..ddb31625849 --- /dev/null +++ 
b/backends/arm/test/passes/test_convert_int64_const_ops_to_int32.py @@ -0,0 +1,511 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple, Union + +import pytest + +import torch +from executorch.backends.arm._passes import ( + ConvertInt64ConstOpsToInt32Pass, + ConvertInt64OutputOpsToInt32Pass, +) + +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import ( + TosaPipelineFP, + TosaPipelineINT, +) + +input_t1 = Tuple[torch.Tensor] # Input x +input_t2 = Tuple[torch.Tensor, torch.Tensor] # Input x, y + + +##################################################### +## Test arange(dtype=int64) -> arange(dtype=int32) ## +##################################################### + + +class ArangeDefaultIncrementViewLessThan(torch.nn.Module): + + def forward(self, x: torch.Tensor): + return (torch.arange(10, dtype=torch.int64) + 1).view(-1, 1) < x + + test_data = { + "randint": ( + torch.randint( + 0, + 10, + (1,), + dtype=torch.int32, + ), + ), + } + + +@common.parametrize("test_data", ArangeDefaultIncrementViewLessThan.test_data) +def test_convert_arange_default_int64_dtype_to_int32_pass_tosa_FP(test_data: input_t1): + module = ArangeDefaultIncrementViewLessThan() + aten_ops_checks = [ + "torch.ops.aten.lt.Tensor", + "torch.ops.aten.view.default", + ] + exir_ops_checks = [ + "executorch_exir_dialects_edge__ops_aten_lt_Tensor", + "executorch_exir_dialects_edge__ops_aten_view_copy_default", + ] + pipeline = TosaPipelineFP[input_t1]( + module, + test_data, + aten_ops_checks, + exir_ops_checks, + transform_passes=[ConvertInt64ConstOpsToInt32Pass()], + ) + pipeline.run() + + +@common.parametrize("test_data", ArangeDefaultIncrementViewLessThan.test_data) +def test_convert_arange_default_int64_dtype_to_int32_pass_tosa_INT(test_data: input_t1): + module = ArangeDefaultIncrementViewLessThan() + aten_ops_checks = [ + "torch.ops.aten.lt.Tensor", + "torch.ops.aten.view.default", + ] + exir_ops_checks = [ + "executorch_exir_dialects_edge__ops_aten_lt_Tensor", + "executorch_exir_dialects_edge__ops_aten_view_copy_default", + ] + pipeline = TosaPipelineINT[input_t1]( + module, + test_data, + aten_ops_checks, + exir_ops_checks, + ) + pipeline.pop_stage("check.quant_nodes") + pipeline.run() + + +class ArangeStartIncrementViewLessThan(torch.nn.Module): + + def forward(self, x: torch.Tensor): + return (torch.arange(0, 10, dtype=torch.int64) + 1).view(-1, 1) < x + + test_data = { + "randint": ( + torch.randint( + 0, + 10, + (1,), + dtype=torch.int32, + ), + ), + } + + +@common.parametrize("test_data", ArangeStartIncrementViewLessThan.test_data) +def test_convert_arange_start_int64_dtype_to_int32_pass_tosa_FP(test_data: input_t1): + module = ArangeStartIncrementViewLessThan() + aten_ops_checks = [ + "torch.ops.aten.lt.Tensor", + "torch.ops.aten.view.default", + ] + exir_ops_checks = [ + "executorch_exir_dialects_edge__ops_aten_lt_Tensor", + "executorch_exir_dialects_edge__ops_aten_view_copy_default", + ] + pipeline = TosaPipelineFP[input_t1]( + module, + test_data, + aten_ops_checks, + exir_ops_checks, + transform_passes=[ConvertInt64ConstOpsToInt32Pass()], + ) + pipeline.run() + + +@common.parametrize("test_data", ArangeStartIncrementViewLessThan.test_data) +def test_convert_arange_start_int64_dtype_to_int32_pass_tosa_INT(test_data: input_t1): + module = ArangeStartIncrementViewLessThan() + aten_ops_checks = [ 
+ "torch.ops.aten.lt.Tensor", + "torch.ops.aten.view.default", + ] + exir_ops_checks = [ + "executorch_exir_dialects_edge__ops_aten_lt_Tensor", + "executorch_exir_dialects_edge__ops_aten_view_copy_default", + ] + pipeline = TosaPipelineINT[input_t1]( + module, + test_data, + aten_ops_checks, + exir_ops_checks, + ) + pipeline.pop_stage("check.quant_nodes") + pipeline.run() + + +class ArangeStartStepIncrementViewLessThan(torch.nn.Module): + + def forward(self, x: torch.Tensor): + return (torch.arange(0, 10, 2, dtype=torch.int64) + 1).view(-1, 1) < x + + test_data = { + "randint": ( + torch.randint( + 0, + 10, + (1,), + dtype=torch.int32, + ), + ), + } + + +@common.parametrize("test_data", ArangeStartStepIncrementViewLessThan.test_data) +def test_convert_arange_start_step_int64_dtype_to_int32_pass_tosa_FP( + test_data: input_t1, +): + module = ArangeStartStepIncrementViewLessThan() + aten_ops_checks = [ + "torch.ops.aten.lt.Tensor", + "torch.ops.aten.view.default", + ] + exir_ops_checks = [ + "executorch_exir_dialects_edge__ops_aten_lt_Tensor", + "executorch_exir_dialects_edge__ops_aten_view_copy_default", + ] + pipeline = TosaPipelineFP[input_t1]( + module, + test_data, + aten_ops_checks, + exir_ops_checks, + transform_passes=[ConvertInt64ConstOpsToInt32Pass()], + ) + pipeline.run() + + +@common.parametrize("test_data", ArangeStartStepIncrementViewLessThan.test_data) +def test_convert_arange_start_step_int64_dtype_to_int32_pass_tosa_INT( + test_data: input_t1, +): + module = ArangeStartStepIncrementViewLessThan() + aten_ops_checks = [ + "torch.ops.aten.lt.Tensor", + "torch.ops.aten.view.default", + ] + exir_ops_checks = [ + "executorch_exir_dialects_edge__ops_aten_lt_Tensor", + "executorch_exir_dialects_edge__ops_aten_view_copy_default", + ] + pipeline = TosaPipelineINT[input_t1]( + module, + test_data, + aten_ops_checks, + exir_ops_checks, + ) + pipeline.pop_stage("check.quant_nodes") + pipeline.run() + + +######################################################### +## Test arange(dtype=None) -> arange(dtype=None/int32) ## +######################################################### + + +class ArangeAddDtypeNone(torch.nn.Module): + aten_op: str = "torch.ops.aten.arange.start_step" + exir_op: str = "executorch_exir_dialects_edge__ops_aten_arange_start_step" + + def __init__(self, start: float, stop: float, step: float): + super().__init__() + self.args = (start, stop, step) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.arange(*self.args) + x + + test_data = { + "int64": (lambda: (torch.randn(10, 1),), (0, 10, 1)), + "float32_start": (lambda: (torch.randn(10, 1),), (0.0, 10, 1)), + "float32_stop": (lambda: (torch.randn(10, 1),), (0, 10.0, 1)), + "float32_step": (lambda: (torch.randn(10, 1),), (0, 10, 1.0)), + "int64_bool_0": (lambda: (torch.randn(10, 1),), (False, True, True)), + "int64_bool_1": (lambda: (torch.randn(10, 1),), (False, True, True * 10)), + "float32_bool_0": (lambda: (torch.randn(10, 1),), (0.0, True, True)), + "float32_bool_1": (lambda: (torch.randn(10, 1),), (False, True, True * 10.0)), + } + + +@common.parametrize("test_data", ArangeAddDtypeNone.test_data) +def test_arange_dtype_none_tosa_FP(test_data): + input_data, init_data = test_data + pipeline = TosaPipelineFP[input_t1]( + ArangeAddDtypeNone(*init_data), + input_data(), + ArangeAddDtypeNone.aten_op, + ArangeAddDtypeNone.exir_op, + ) + pipeline.run() + + +@common.parametrize("test_data", ArangeAddDtypeNone.test_data) +def test_arange_dtype_none_tosa_INT(test_data): + input_data, init_data = 
test_data
+    pipeline = TosaPipelineINT[input_t1](
+        ArangeAddDtypeNone(*init_data),
+        input_data(),
+        ArangeAddDtypeNone.aten_op,
+        ArangeAddDtypeNone.exir_op,
+    )
+    pipeline.pop_stage("check.quant_nodes")
+    pipeline.run()
+
+
+#################################################
+## Test full(dtype=int64) -> full(dtype=int32) ##
+#################################################
+
+
+class FullIncrementViewMulXLessThanY(torch.nn.Module):
+
+    def forward(self, x: torch.Tensor, y: torch.Tensor):
+        return (
+            (
+                torch.full(
+                    (
+                        1,
+                        3,
+                        5,
+                    ),
+                    10,
+                    dtype=torch.int64,
+                )
+                + 1
+            ).view(-1, 1)
+            * x
+        ) < y
+
+    test_data = {
+        "randint": (
+            torch.randint(
+                0,
+                10,
+                (1,),
+                dtype=torch.int32,
+            ),
+            torch.randint(
+                0,
+                10,
+                (1,),
+                dtype=torch.int32,
+            ),
+        ),
+    }
+
+
+@common.parametrize("test_data", FullIncrementViewMulXLessThanY.test_data)
+def test_convert_full_int64_dtype_to_int32_pass_tosa_FP(test_data: input_t1):
+    """
+    There are three int64-related placeholders in the original graph:
+    1. _lifted_tensor_constant0: 1
+    2. x
+    3. y
+    Ideally, after applying ConvertInt64ConstOpsToInt32Pass to convert the aten.full from int64 to int32,
+    the int32 type should propagate throughout the graph, and no int64 values should remain.
+    However, due to unexpected retracing behavior, a cast from int32 -> int64 for x was reintroduced.
+
+    Applying ConvertInt64OutputOpsToInt32Pass afterward resolves this issue,
+    removing the int64 cast and producing a fully delegated int32 graph.
+    """
+    module = FullIncrementViewMulXLessThanY()
+    aten_ops_checks = [
+        "torch.ops.aten.full.default",
+        "torch.ops.aten.add.Tensor",
+        "torch.ops.aten.view.default",
+        "torch.ops.aten.mul.Tensor",
+        "torch.ops.aten.lt.Tensor",
+    ]
+    exir_ops_checks = [
+        "executorch_exir_dialects_edge__ops_aten_full_default",
+        "executorch_exir_dialects_edge__ops_aten_add_Tensor",
+        "executorch_exir_dialects_edge__ops_aten_view_copy_default",
+        "executorch_exir_dialects_edge__ops_aten_mul_Tensor",
+        "executorch_exir_dialects_edge__ops_aten_lt_Tensor",
+    ]
+    pipeline = TosaPipelineFP[input_t2](
+        module,
+        test_data,
+        aten_ops_checks,
+        exir_ops_checks,
+        transform_passes=[
+            ConvertInt64ConstOpsToInt32Pass(),
+            ConvertInt64OutputOpsToInt32Pass(),
+        ],
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", FullIncrementViewMulXLessThanY.test_data)
+def test_convert_full_int64_dtype_to_int32_pass_tosa_INT(test_data: input_t1):
+    """
+    For the INT profile, _lifted_tensor_constant0 is still int64 after applying ConvertInt64ConstOpsToInt32Pass(),
+    and an int64->int32 cast is inserted at the beginning of the graph.
+    TODO: Explore why _lifted_tensor_constant0 is handled differently in the FP and INT profiles.
+    Find a way to optimize out the int64->int32 cast.
+ """ + module = FullIncrementViewMulXLessThanY() + aten_ops_checks = [ + "torch.ops.aten.full.default", + "torch.ops.aten.add.Tensor", + "torch.ops.aten.view.default", + "torch.ops.aten.mul.Tensor", + "torch.ops.aten.lt.Tensor", + ] + exir_ops_checks = [ + "executorch_exir_dialects_edge__ops_aten_full_default", + "executorch_exir_dialects_edge__ops_aten_add_Tensor", + "executorch_exir_dialects_edge__ops_aten_view_copy_default", + "executorch_exir_dialects_edge__ops_aten_mul_Tensor", + "executorch_exir_dialects_edge__ops_aten_lt_Tensor", + ] + pipeline = TosaPipelineINT[input_t2]( + module, + test_data, + aten_ops_checks, + exir_ops_checks, + ) + pipeline.pop_stage("check.quant_nodes") + pipeline.run() + + +class RejectFullIncrementViewMulXLessThanY(torch.nn.Module): + + def forward(self, x: torch.Tensor, y: torch.Tensor): + return ( + ( + torch.full( + ( + 1, + 3, + 5, + ), + torch.iinfo(torch.int32).max + 1, + dtype=torch.int64, + ) + + 1 + ).view(-1, 1) + * x + ) < y + + test_data = { + "randint": ( + torch.randint( + 0, + 10, + (1,), + dtype=torch.int32, + ), + torch.randint( + 0, + 10, + (1,), + dtype=torch.int32, + ), + ), + } + + +@common.parametrize("test_data", RejectFullIncrementViewMulXLessThanY.test_data) +@pytest.mark.xfail( + reason="MLETORCH-1254: Add operator support check for aten.arange and aten.full" +) +def test_reject_convert_full_int64_dtype_to_int32_pass_tosa_FP(test_data: input_t1): + module = RejectFullIncrementViewMulXLessThanY() + aten_ops_checks = [ + "torch.ops.aten.full.default", + "torch.ops.aten.add.Tensor", + "torch.ops.aten.view.default", + "torch.ops.aten.mul.Tensor", + "torch.ops.aten.lt.Tensor", + ] + pipeline = TosaPipelineFP[input_t2]( + module, + test_data, + aten_ops_checks, + exir_op=[], + transform_passes=[ + ConvertInt64ConstOpsToInt32Pass(), + ConvertInt64OutputOpsToInt32Pass(), + ], + ) + pipeline.run() + + +##################################################### +## Test full(dtype=None) -> full(dtype=None/int32) ## +##################################################### + + +class AddConstFullDtypeNone(torch.nn.Module): + # Input + a full with constant value. 
+ exir_op = "executorch_exir_dialects_edge__ops_aten_full_default" + + def __init__(self, size: tuple, fill_value: Union[bool, float, int]): + super().__init__() + self.args = (size, fill_value) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.full(*self.args) + x + + test_data = { + "int64": (lambda: (torch.randn(1),), ((1, 2, 3), 10)), + "float32": (lambda: (torch.randn(1),), ((1, 2, 3), 10.0)), + } + + test_data_bool = { + "bool": (lambda: (torch.randn(1),), ((1, 2, 3), True)), + } + + +@common.parametrize("test_data", AddConstFullDtypeNone.test_data) +def test_full_dtype_none_tosa_FP(test_data): + input_data, init_data = test_data + pipeline = TosaPipelineFP[input_t1]( + AddConstFullDtypeNone(*init_data), + input_data(), + aten_op=[], + exir_op=AddConstFullDtypeNone.exir_op, + ) + pipeline.run() + + +@common.parametrize("test_data", AddConstFullDtypeNone.test_data_bool) +def test_full_dtype_none_tosa_FP_bool(test_data): + input_data, init_data = test_data + pipeline = TosaPipelineFP[input_t1]( + AddConstFullDtypeNone(*init_data), + input_data(), + aten_op=[], + exir_op=AddConstFullDtypeNone.exir_op, + ) + pipeline.change_args( + "check_count.exir", + {"torch.ops.higher_order.executorch_call_delegate": 2}, + ) + pipeline.run() + + +@common.parametrize( + "test_data", AddConstFullDtypeNone.test_data | AddConstFullDtypeNone.test_data_bool +) +def test_full_dtype_none_tosa_INT(test_data): + input_data, init_data = test_data + pipeline = TosaPipelineINT[input_t1]( + AddConstFullDtypeNone(*init_data), + input_data(), + aten_op=[], + exir_op=AddConstFullDtypeNone.exir_op, + ) + pipeline.pop_stage("check.quant_nodes") + pipeline.run() diff --git a/backends/arm/test/passes/test_convert_int64_output_ops_to_int32.py b/backends/arm/test/passes/test_convert_int64_output_ops_to_int32.py new file mode 100644 index 00000000000..cfed4245eed --- /dev/null +++ b/backends/arm/test/passes/test_convert_int64_output_ops_to_int32.py @@ -0,0 +1,131 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+
+from typing import Tuple
+
+import torch
+from executorch.backends.arm._passes import ConvertInt64OutputOpsToInt32Pass
+
+from executorch.backends.arm.test import common
+
+from executorch.backends.arm.test.tester.test_pipeline import TosaPipelineFP
+
+input_t1 = Tuple[torch.Tensor]  # Input x
+
+
+#########################################
+## Test [int32 | other types] -> int64 ##
+#########################################
+
+
+class CastingToInt64Model(torch.nn.Module):
+    def __init__(self, target_dtype):
+        super().__init__()
+        self.target_dtype = target_dtype
+
+    def forward(self, x: torch.Tensor):
+        return x.to(dtype=self.target_dtype)
+
+
+test_data_suite_convert = {
+    "fp32_input": lambda: (torch.rand((1, 2, 3, 4), dtype=torch.float32), torch.int64),
+    "fp16_input": lambda: (torch.rand((1, 2, 3, 4), dtype=torch.float16), torch.int64),
+    "int16_input": lambda: (
+        torch.randint(-127, 128, (1, 2, 3, 4), dtype=torch.int16),
+        torch.int64,
+    ),
+    "int8_input": lambda: (
+        torch.randint(-127, 128, (1, 2, 3, 4), dtype=torch.int8),
+        torch.int64,
+    ),
+}
+
+
+test_data_suite_remove = {
+    "int32_input": lambda: (
+        torch.randint(-127, 128, (1, 2, 3, 4), dtype=torch.int32),
+        torch.int64,
+    ),
+}
+
+
+@common.parametrize("test_data", test_data_suite_convert)
+def test_convert_or_remove_casting_to_int64_convert_tosa_FP(test_data: Tuple):
+    test_tensor, target_dtype = test_data()
+    module = CastingToInt64Model(target_dtype)
+
+    pipeline = TosaPipelineFP[input_t1](
+        module,
+        (test_tensor,),
+        aten_op="torch.ops.aten.to.dtype",
+        exir_op=[],
+        transform_passes=[ConvertInt64OutputOpsToInt32Pass()],
+    )
+    pipeline.pop_stage(
+        "run_method_and_compare_outputs"
+    )  # As expected: RuntimeError: Int did not match Long
+    pipeline.run()
+
+
+@common.parametrize("test_data", test_data_suite_remove)
+def test_convert_or_remove_casting_to_int64_remove_tosa_FP(test_data: Tuple):
+    test_tensor, target_dtype = test_data()
+    module = CastingToInt64Model(target_dtype)
+
+    pipeline = TosaPipelineFP[input_t1](
+        module,
+        (test_tensor,),
+        aten_op=[],
+        exir_op=[],
+        transform_passes=[ConvertInt64OutputOpsToInt32Pass()],
+    )
+    pipeline.change_args(
+        "check_count.exir", {"torch.ops.higher_order.executorch_call_delegate": 0}
+    )  # Empty graph without nodes
+    pipeline.pop_stage("run_method_and_compare_outputs")
+    pipeline.run()
+
+
+######################################################
+## Test argmax (int64 output) -> insert int32 cast  ##
+######################################################
+
+
+class Int64OutputModel(torch.nn.Module):
+
+    def forward(self, x: torch.Tensor):
+        # return torch.argmax(x)  # RuntimeError: Int did not match Long; this is
+        # expected, as _argmax_i32 is meant to generate int32 output
+        # return (10 * torch.argmax(x) + 10).to(dtype=torch.int32)  # [1]. This behavior is deprecated, and in a future PyTorch release outputs will not be resized unless they have zero elements. You can explicitly reuse an out tensor t by resizing it, inplace, to zero elements with t.resize_(0). 
(function _resize_output_check) + return (10 * torch.argmax(x, dim=-1) + 10) + 1.5 + + def get_inputs(self) -> input_t1: + return ( + torch.randint( + 0, + 10, + (2, 4, 6, 8), + ), + ) + + +def test_insert_int64_output_to_int32_cast_tosa_FP(): + module = Int64OutputModel() + aten_ops_checks = [ + "torch.ops.aten.argmax.default", + "torch.ops.aten.mul.Tensor", + "torch.ops.aten.add.Tensor", + ] + exir_ops_checks = [ + "executorch_exir_dialects_edge__ops_aten_mul_Tensor", + "executorch_exir_dialects_edge__ops_aten_add_Tensor", + ] + pipeline = TosaPipelineFP[input_t1]( + module, + module.get_inputs(), + aten_op=aten_ops_checks, + exir_op=exir_ops_checks, + transform_passes=[ConvertInt64OutputOpsToInt32Pass()], + ) + pipeline.run()
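
For completeness, the cast-removal path (Functionality 1) covered by the suites above can also be reproduced outside the test pipeline. A sketch under the same assumptions as earlier examples (the export driver is illustrative, not part of this patch):

```python
import torch

from executorch.backends.arm._passes import ConvertInt64OutputOpsToInt32Pass


class Upcast(torch.nn.Module):
    def forward(self, x: torch.Tensor):
        # int32 -> int64 widening that TOSA lowering does not need.
        return x.to(dtype=torch.int64)


ep = torch.export.export(Upcast(), (torch.randint(0, 10, (4,), dtype=torch.int32),))
result = ConvertInt64OutputOpsToInt32Pass()(ep.graph_module)
# The cast node is expected to be removed and its users rewired to x, matching
# the "remove" suite above, which checks for zero delegate calls (empty graph).
print(result.graph_module.graph)
```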