diff --git a/backends/arm/passes/arm_pass_manager.py b/backends/arm/passes/arm_pass_manager.py index 03fbb38d04b..405ac7e0e1e 100644 --- a/backends/arm/passes/arm_pass_manager.py +++ b/backends/arm/passes/arm_pass_manager.py @@ -11,6 +11,7 @@ from executorch.backends.arm.passes.annotate_channels_last_dim_order_pass import ( AnnotateChannelsLastDimOrder, ) +from executorch.backends.arm.passes.cast_int64_pass import CastInt64ToInt32Pass from executorch.backends.arm.passes.convert_expand_copy_to_repeat import ( ConvertExpandCopyToRepeatPass, ) @@ -36,6 +37,7 @@ def transform_to_backend_pipeline( self, exported_program: ExportedProgram, compile_spec: list[CompileSpec] ): """Apply passes before transforming program to backend""" + self.add_pass(CastInt64ToInt32Pass(exported_program)) self.add_pass(SizeAdjustConv2DPass()) self.add_pass(RemoveClonePass()) self.add_pass(ConvertExpandCopyToRepeatPass()) diff --git a/backends/arm/passes/cast_int64_pass.py b/backends/arm/passes/cast_int64_pass.py new file mode 100644 index 00000000000..6bdbca62879 --- /dev/null +++ b/backends/arm/passes/cast_int64_pass.py @@ -0,0 +1,35 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from executorch.exir.pass_base import ExportPass, PassResult + + +class CastInt64ToInt32Pass(ExportPass): + def __init__(self, exported_program: torch.export.ExportedProgram): + super(CastInt64ToInt32Pass, self).__init__() + self.exported_program = exported_program + + def _to_int32(self, graph_module: torch.fx.GraphModule): + for node in graph_module.graph.nodes: + fake_tensor = node.meta["val"] + if isinstance(fake_tensor, torch._subclasses.fake_tensor.FakeTensor): + if node.meta["val"].dtype == torch.int64: + node.meta["val"] = node.meta["val"].to(torch.int32) + buffer_name = ( + self.exported_program.graph_signature.inputs_to_buffers[ + node.name + ] + ) + new_tensor = self.exported_program.state_dict[buffer_name].to( + torch.int32 + ) + self.exported_program.state_dict[buffer_name] = new_tensor + + def call(self, graph_module: torch.fx.GraphModule): + self._to_int32(graph_module) + graph_module.recompile() + graph_module = super().call(graph_module).graph_module + return PassResult(graph_module, True)