Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion backends/arm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,38 @@ List of model specific and optional passes:
- InsertCastForOpsWithInt64InputPass
- Functionality:
- For LLMs such as Llama, some operators like aten.embedding have int64 input. In order to lower these operators to TOSA, this pass will insert a casting node that converts the input from int64 to int32.
- Example usage: backends/arm/test/models/test_llama.py
- Supported Ops:
- aten.embedding.default, aten.slice_copy.Tensor
- Example usage:
- backends/arm/test/models/test_llama.py

- ConvertInt64ConstOpsToInt32Pass
- Functionalities:
- Rewrites constant-producing ops that output int64 to instead output int32, when values are within int32 bounds.
- Supported Ops:
- `torch.full`, `torch.arange`, `torch.eye`, `torch.linspace`, `torch.tensor`
- Example usage:
- backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py
- backends/arm/test/models/stable_diffusion/test_T5EncoderModel.py

- ConvertInt64OutputOpsToInt32Pass
- Overview:
- Rewrites or removes operations that produce int64 outputs, converting them to int32 where possible.
- Overflow checks are applied selectively; for ops without such checks, users need to ensure values fit within the int32 range.
- Functionalities:
1. Handling casting to int64:
- (1) int32 -> int64:
- Removes the cast and redirects uses of int64 to int32
- (2) other types -> int64:
- Rewrites the cast to other types -> int32
- Supported Ops:
- torch.ops.aten.to.\[dtype|dtype_layout\]
- exir_ops.edge.dim_order_ops._to_dim_order_copy.default
2. Post-process argmax outputs:
- Inserts an int64->int32 cast after the argmax operations that produce int64 outputs:
- Supported Ops:
- torch.ops.aten.argmax.default
- exir_ops.edge.aten.argmax.default
- Example usage:
- (Functionality 1) backends/arm/test/models/stable_diffusion/test_T5EncoderModel.py
- (Functionality 2) backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py
2 changes: 2 additions & 0 deletions backends/arm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from .convert_any_default_dim_dims_pass import ConvertAnyDefaultDimDimsPass # noqa
from .convert_expand_copy_to_repeat import ConvertExpandCopyToRepeatPass # noqa
from .convert_full_like_to_full_pass import ConvertFullLikeToFullPass # noqa
from .convert_int64_const_ops_to_int32 import ConvertInt64ConstOpsToInt32Pass # noqa
from .convert_int64_output_ops_to_int32 import ConvertInt64OutputOpsToInt32Pass # noqa
from .convert_int_pow_to_mul import ConvertIntPowToMuls # noqa
from .convert_minmax_pass import ConvertMinMaxPass # noqa
from .convert_split_to_slice import ConvertSplitToSlicePass # noqa
Expand Down
8 changes: 8 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
ConvertAnyDefaultDimDimsPass,
ConvertExpandCopyToRepeatPass,
ConvertFullLikeToFullPass,
ConvertInt64ConstOpsToInt32Pass,
ConvertInt64OutputOpsToInt32Pass,
ConvertIntPowToMuls,
ConvertMinMaxPass,
ConvertMmToBmmPass,
Expand Down Expand Up @@ -98,6 +100,7 @@
from executorch.backends.transforms.remove_getitem_op import RemoveGetItemPass
from executorch.exir import ExportedProgram
from executorch.exir.pass_manager import PassManager
from executorch.exir.passes.remove_graph_asserts_pass import RemoveGraphAssertsPass
from torch.fx import GraphModule


Expand Down Expand Up @@ -258,6 +261,11 @@ def transform_to_backend_pipeline(self, exported_program: ExportedProgram):
)

def transform_for_annotation_pipeline(self, graph_module: GraphModule):
self.add_pass(
RemoveGraphAssertsPass()
) # ConvertInt64ConstOpsToInt32Pass requires this pass to remove the assertation in Graph
self.add_pass(ConvertInt64ConstOpsToInt32Pass())
self.add_pass(ConvertInt64OutputOpsToInt32Pass())
self.add_pass(InsertCastForOpsWithInt64InputPass())
self.add_pass(DecomposeEmbeddingPass())
self.add_pass(DecomposeScaledDotProductAttention())
Expand Down
74 changes: 74 additions & 0 deletions backends/arm/_passes/convert_int64_const_ops_to_int32.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe


import logging

import torch
from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT
from executorch.exir.pass_base import ExportPass, PassResult


logger = logging.getLogger(__name__)
INT32_MIN = torch.iinfo(torch.int32).min
INT32_MAX = torch.iinfo(torch.int32).max


class ConvertInt64ConstOpsToInt32Pass(ExportPass):
    """
    Rewrite constant ops that produce int64 to int32 where safe.

    List of supported operators:
    1. `torch.full`
    2. `torch.arange`
    3. `torch.eye`
    4. `torch.linspace`
    5. `torch.tensor`

    A node is rewritten only when the constant it produces can be evaluated
    ahead of time and every value lies inside the int32 range. Nodes whose
    values fall outside that range are left untouched and a warning is logged.
    """

    torch_ops = [
        torch.ops.aten.full.default,
        torch.ops.aten.arange.default,
        torch.ops.aten.arange.start,
        torch.ops.aten.arange.start_step,
        torch.ops.aten.eye.default,
        torch.ops.aten.linspace.default,
    ]

    def call(self, graph_module: torch.fx.GraphModule):
        """
        Set ``dtype=torch.int32`` on every eligible int64 constant-producing node.

        Args:
            graph_module: The graph module to transform in place.

        Returns:
            A ``PassResult`` holding the (possibly re-traced) graph module and
            a flag telling whether any node was rewritten.
        """
        modified = False
        # Build the lookup list once instead of on every loop iteration.
        targeted_ops = ComputeConstantOpsAOT.targeted_ops + self.torch_ops

        for node in graph_module.graph.nodes:
            if node.op != "call_function":
                continue

            if node.target not in targeted_ops:
                continue

            # The op is evaluated eagerly below, which is only possible when
            # none of its arguments are other graph nodes (e.g. symbolic sizes).
            if node.all_input_nodes:
                continue

            data = node.target(*node.args, **node.kwargs)
            if data.dtype is not torch.int64:
                continue

            if data.numel() == 0:
                # torch.min/torch.max raise on empty tensors; an empty constant
                # has no value that could overflow, so it is always convertible.
                fits_in_int32 = True
            else:
                min_val, max_val = torch.min(data), torch.max(data)
                fits_in_int32 = bool(INT32_MIN <= min_val and max_val <= INT32_MAX)

            if fits_in_int32:
                logger.warning(
                    f"Casting {node.name} from torch.int64 to torch.int32"
                    f" defined in {node.meta.get('stack_trace','[no stack trace found]')}"
                )
                node.update_kwarg("dtype", torch.int32)
                modified = True
            else:
                logger.warning(
                    f"[{node.name}] has values: min={min_val}, max={max_val}, which exceeds int32 range "
                    f"([{INT32_MIN}, {INT32_MAX}]); not converting dtype to int32."
                )

        if modified:
            graph_module.graph.eliminate_dead_code()
            graph_module.recompile()
            # Re-run the base pass so node metadata reflects the new dtypes.
            graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, modified)
153 changes: 153 additions & 0 deletions backends/arm/_passes/convert_int64_output_ops_to_int32.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe


import logging

import torch
from executorch.backends.arm._passes.arm_pass_utils import (
create_node,
get_first_fake_tensor,
set_node_arg,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


logger = logging.getLogger(__name__)


class ConvertInt64OutputOpsToInt32Pass(ExportPass):
    """
    Rewrites or removes operations that produce int64 outputs, converting them
    to int32 where possible.


    Currently, this pass handles casting and argmax operators:
        1. int32 -> int64:
            removes the cast and redirects all uses to the original int32 value.
        2. other types -> int64:
            rewrites the cast to produce int32 instead of int64.
        3. torch.argmax()
            inserts an int64->int32 cast after the argmax node

    Future extensions may include operators that return int64 outputs by default
    (e.g., `argmin`), rewriting them or inserting an int64 -> int32 cast to yield
    int32 results.

    Note: Overflow checks are applied selectively in this pass. For operators without
    such checks, it is the user's responsibility to ensure that values fit within
    the int32 range.
    """

    aten_cast_ops = (
        torch.ops.aten.to.dtype,
        torch.ops.aten.to.dtype_layout,
    )
    edge_cast_ops = (exir_ops.edge.dim_order_ops._to_dim_order_copy.default,)

    aten_argmax_ops = (torch.ops.aten.argmax.default,)
    edge_argmax_ops = (exir_ops.edge.aten.argmax.default,)

    aten_ops = aten_cast_ops + aten_argmax_ops
    edge_ops = edge_cast_ops + edge_argmax_ops

    # dtype is specified in args
    cast_ops_args = (
        torch.ops.aten.to.dtype,  # to_2: node.args: (gt, torch.int64) node.kwargs: {}
    )
    # dtype is specified in kwargs
    cast_ops_kwargs = (
        torch.ops.aten.to.dtype_layout,  # to_1: node.args: (unsqueeze,) node.kwargs: {'dtype': torch.int64, 'layout': torch.strided, 'device': device(type='cpu')}
        exir_ops.edge.dim_order_ops._to_dim_order_copy.default,  # node.args: (aten_gt_scalar,) node.kwargs: {'dtype': torch.int64, 'dim_order': [0, 1]}
    )

    def _get_decomposition(self, op):
        """
        Return the `_to_copy` op in the same dialect (edge or aten) as `op`.

        Raises:
            RuntimeError: If `op` is not one of the ops handled by this pass.
        """
        if op in self.edge_ops:
            return exir_ops.edge.aten._to_copy.default

        if op in self.aten_ops:
            return torch.ops.aten._to_copy.default

        raise RuntimeError(
            f"[{self.__class__.__name__}] Can't get decomposition for op {op}"
        )

    def _convert_casting_operators(self, node: torch.fx.Node):
        """
        Remove or retarget a cast node whose output dtype is int64.

        Args:
            node: A cast node from `aten_cast_ops` or `edge_cast_ops`.

        Raises:
            RuntimeError: If the position of the dtype argument is not known
                for `node.target`.
        """
        input_node = node.all_input_nodes[0]
        input_dtype = get_first_fake_tensor(input_node).dtype
        # Case 1: int32 -> int64 - removes the ops
        if input_dtype == torch.int32:
            # Log once per removed cast rather than once per consumer.
            logger.warning(
                f"Removing int32->int64 casting node {node.name} defined in"
                f" {node.meta.get('stack_trace','[no stack trace found]')}"
            )
            # Snapshot the users: rewiring mutates node.users while we iterate.
            for user in list(node.users):
                user.replace_input_with(node, input_node)
            return

        # Case 2: other types -> int64 - rewrites to cast to int32
        # Read the dtype before mutating the node so the log reports the
        # original output dtype.
        output_dtype = get_first_fake_tensor(node).dtype
        if node.target in self.cast_ops_kwargs:
            set_node_arg(node, "dtype", torch.int32)
        elif node.target in self.cast_ops_args:
            set_node_arg(node, 1, torch.int32)
        else:
            raise RuntimeError(f"Unexpected target {node.target} in {node.name}")
        logger.warning(
            f"Converting casting node {node.name} from {input_dtype}->{output_dtype} to"
            f" {input_dtype}->torch.int32 defined in {node.meta.get('stack_trace','[no stack trace found]')}"
        )

    def _convert_argmax_operators(self, node: torch.fx.Node, graph: torch.fx.Graph):
        """
        Insert an int64->int32 `_to_copy` after an argmax node and reroute its users.

        Args:
            node: The argmax node producing an int64 output.
            graph: The graph that owns `node`.
        """
        to_copy_op = self._get_decomposition(node.target)
        with graph.inserting_after(node):
            cast_after = create_node(
                graph,
                to_copy_op,
                args=(node,),
                kwargs={
                    "dtype": torch.int32,
                },
            )
            # The new cast itself consumes `node`; leave that edge in place.
            users = [user for user in node.users if user != cast_after]
            for user in users:
                user.replace_input_with(node, cast_after)
            logger.warning(
                f"Inserting a casting node {cast_after.name} after {node.name} to cast int64 output"
                f" to int32 for {node.name} defined in {node.meta.get('stack_trace','[no stack trace found]')}"
            )

    def call(self, graph_module: torch.fx.GraphModule):
        """
        Apply the int64 -> int32 rewrites to every handled node in the graph.

        Args:
            graph_module: The graph module to transform in place.

        Returns:
            A ``PassResult`` holding the (possibly re-traced) graph module and
            a flag telling whether any node was handled.
        """
        modified = False
        graph = graph_module.graph
        # Build the lookup tuples once instead of on every loop iteration.
        handled_ops = self.aten_ops + self.edge_ops
        cast_ops = self.aten_cast_ops + self.edge_cast_ops
        argmax_ops = self.aten_argmax_ops + self.edge_argmax_ops

        # Iterate over a snapshot: the argmax path inserts new nodes.
        for node in list(graph.nodes):
            if node.op != "call_function":
                continue
            if node.target not in handled_ops:
                continue
            output_dtype = get_first_fake_tensor(node).dtype
            if output_dtype != torch.int64:
                continue

            if node.target in cast_ops:
                self._convert_casting_operators(node)
            elif node.target in argmax_ops:
                # TODO: Add range check based on the input tensor shape before casting the output
                self._convert_argmax_operators(node, graph)
            else:
                raise RuntimeError(f"Unexpected target {node.target} in {node.name}")

            modified = True

        if modified:
            graph_module.graph.eliminate_dead_code()
            graph_module.recompile()
            # Re-run the base pass so node metadata reflects the new dtypes.
            graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, modified)
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@
import unittest

import torch
from executorch.backends.arm._passes import InsertCastForOpsWithInt64InputPass
from executorch.backends.arm._passes import (
ConvertInt64ConstOpsToInt32Pass,
ConvertInt64OutputOpsToInt32Pass,
InsertCastForOpsWithInt64InputPass,
)

from executorch.backends.arm.test import common
from executorch.backends.arm.test.models.stable_diffusion.stable_diffusion_module_test_configs import (
Expand All @@ -28,13 +32,11 @@ class TestCLIPTextModelWithProjection(unittest.TestCase):
# for that is some assert ops are removed by passes in the
# .to_executorch step, i.e. after Arm partitioner.
ops_after_partitioner = {
"executorch_exir_dialects_edge__ops_aten__to_copy_default": 3,
"executorch_exir_dialects_edge__ops_aten__to_copy_default": 4,
"executorch_exir_dialects_edge__ops_aten_argmax_default": 1,
"executorch_exir_dialects_edge__ops_aten_index_Tensor": 1,
"executorch_exir_dialects_edge__ops_aten_lt_Tensor": 1,
"executorch_exir_dialects_edge__ops_aten_view_copy_default": 2,
"executorch_exir_dialects_edge__ops_aten_view_copy_default": 1,
"executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 1,
"torch.ops.higher_order.executorch_call_delegate": 3,
"torch.ops.higher_order.executorch_call_delegate": 2,
}

def _prepare_inputs(
Expand All @@ -60,15 +62,19 @@ def prepare_model_and_inputs(self):

return text_encoder_model, text_encoder_model_inputs

def test_CLIPTextModelWithProjection_tosa_MI(self):
def test_CLIPTextModelWithProjection_tosa_FP(self):
text_encoder_model, text_encoder_model_inputs = self.prepare_model_and_inputs()
with torch.no_grad():
(
ArmTester(
text_encoder_model,
example_inputs=text_encoder_model_inputs,
compile_spec=common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+FP"),
transform_passes=[InsertCastForOpsWithInt64InputPass()],
transform_passes=[
InsertCastForOpsWithInt64InputPass(),
ConvertInt64ConstOpsToInt32Pass(),
ConvertInt64OutputOpsToInt32Pass(),
],
)
.export()
.to_edge_transform_and_lower()
Expand Down
Loading
Loading