diff --git a/backends/arm/_passes/match_where_self_arg_dtype_pass.py b/backends/arm/_passes/match_where_self_arg_dtype_pass.py index 154602129f8..fdbd4433bab 100644 --- a/backends/arm/_passes/match_where_self_arg_dtype_pass.py +++ b/backends/arm/_passes/match_where_self_arg_dtype_pass.py @@ -49,7 +49,7 @@ def call(self, graph_module: torch.fx.GraphModule): input_dtype = input_.meta["val"].dtype other_dtype = other_.meta["val"].dtype - target_dtype = torch.float32 + target_dtype = input_dtype if input_dtype != other_dtype: target_dtype = get_largest_dtype(input_dtype, other_dtype) diff --git a/backends/arm/test/ops/test_where.py b/backends/arm/test/ops/test_where.py index 7bfd27ac0a8..a60cf587a3e 100644 --- a/backends/arm/test/ops/test_where.py +++ b/backends/arm/test/ops/test_where.py @@ -121,6 +121,12 @@ def scalar_condition(input: torch.Tensor): scalar_condition, ) +int32_scalar_cond = Where( + 1, + torch.int32, + scalar_condition, +) + test_modules_common = { "two_dim_tensor_cond": lambda: two_dim_tensor_cond, "three_dim_tensor_cond": lambda: three_dim_tensor_cond, @@ -134,6 +140,7 @@ def scalar_condition(input: torch.Tensor): **test_modules_common, "float32_tensor_cond_tuple_dtype": lambda: float32_tensor_cond_tuple_dtype, "float32_tensor_cond_tuple_dtype_bool": lambda: float32_tensor_cond_tuple_dtype_bool, + "int32_scalar_cond": lambda: int32_scalar_cond, } test_modules_BI = {