From 918af66b1033464f0c29e77ba7d2982b68ae9e19 Mon Sep 17 00:00:00 2001 From: Yufeng Shi Date: Wed, 18 Jun 2025 14:45:14 +0100 Subject: [PATCH] Arm backend: Fix bug of inserting unnecessary casts for aten.where.self - In MatchWhereSelfDtypePass, target_dtype was initialized with fp32. This works when at least one of the inputs is fp32. But when both inputs are int32, the pass will incorrectly insert int32->fp32 casts. These casts are unnecessary and may introduce operand dtype mismatch issues. - Fix it by initializing target_dtype with input_dtype. Change-Id: Id67aed8e90f886dc2f2491946847ad01702d5aa5 Signed-off-by: Yufeng Shi --- backends/arm/_passes/match_where_self_arg_dtype_pass.py | 2 +- backends/arm/test/ops/test_where.py | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/backends/arm/_passes/match_where_self_arg_dtype_pass.py b/backends/arm/_passes/match_where_self_arg_dtype_pass.py index 154602129f8..fdbd4433bab 100644 --- a/backends/arm/_passes/match_where_self_arg_dtype_pass.py +++ b/backends/arm/_passes/match_where_self_arg_dtype_pass.py @@ -49,7 +49,7 @@ def call(self, graph_module: torch.fx.GraphModule): input_dtype = input_.meta["val"].dtype other_dtype = other_.meta["val"].dtype - target_dtype = torch.float32 + target_dtype = input_dtype if input_dtype != other_dtype: target_dtype = get_largest_dtype(input_dtype, other_dtype) diff --git a/backends/arm/test/ops/test_where.py b/backends/arm/test/ops/test_where.py index 7bfd27ac0a8..a60cf587a3e 100644 --- a/backends/arm/test/ops/test_where.py +++ b/backends/arm/test/ops/test_where.py @@ -121,6 +121,12 @@ def scalar_condition(input: torch.Tensor): scalar_condition, ) +int32_scalar_cond = Where( + 1, + torch.int32, + scalar_condition, +) + test_modules_common = { "two_dim_tensor_cond": lambda: two_dim_tensor_cond, "three_dim_tensor_cond": lambda: three_dim_tensor_cond, @@ -134,6 +140,7 @@ def scalar_condition(input: torch.Tensor): **test_modules_common, "float32_tensor_cond_tuple_dtype": lambda: float32_tensor_cond_tuple_dtype, "float32_tensor_cond_tuple_dtype_bool": lambda: float32_tensor_cond_tuple_dtype_bool, + "int32_scalar_cond": lambda: int32_scalar_cond, } test_modules_BI = {