From 55c07edcb04a305a601d0b22a3fcdf896143ac7e Mon Sep 17 00:00:00 2001
From: Sungjoon Shon
Date: Wed, 1 Oct 2025 16:22:28 +0000
Subject: [PATCH] mul: remove opmath cast sequence
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Remove the explicit opmath-driven cast chain (bf16→f32→bf16, etc.) from
`mul`. The op now executes in the dtype chosen by standard dtype
promotion, without inserting unconditional upcast/downcast steps. The
`use_opmathtype_for_compute()` helper itself is kept for future use.
---
 test/test_operations_hlo.py      | 16 ++++++++++++++++
 torch_xla/csrc/aten_xla_type.cpp |  1 -
 2 files changed, 16 insertions(+), 1 deletion(-)

diff --git a/test/test_operations_hlo.py b/test/test_operations_hlo.py
index 9b254d1464b2..a7f34d6efb65 100644
--- a/test/test_operations_hlo.py
+++ b/test/test_operations_hlo.py
@@ -67,6 +67,22 @@ def test_dropout_by_u8_mask(self):
     hlo_text = torch_xla._XLAC._get_xla_tensors_hlo([b])
     assert 'u8' in hlo_text
 
+  def test_bfloat16_mul_not_upcast(self):
+    a = torch.rand(5, 5, dtype=torch.bfloat16).to('xla')
+    b = torch.rand(5, 5, dtype=torch.bfloat16).to('xla')
+    c = a * b
+    hlo_text = torch_xla._XLAC._get_xla_tensors_hlo([c])
+    # Check that the output is not upcast to float32.
+    assert 'f32' not in hlo_text
+
+  def test_bfloat16_float32_mul_upcast(self):
+    a = torch.rand(5, 5, dtype=torch.bfloat16).to('xla')
+    b = torch.rand(5, 5, dtype=torch.float32).to('xla')
+    c = a * b
+    hlo_text = torch_xla._XLAC._get_xla_tensors_hlo([c])
+    # Check that the output is upcast to float32.
+    assert 'f32' in hlo_text
+
 
 if __name__ == '__main__':
   torch.set_default_dtype(torch.float32)
diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index b6a8484a2505..d48251056c86 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -2535,7 +2535,6 @@ at::Tensor XLANativeFunctions::mul(const at::Tensor& self,
       .add_input(self)
       .add_input(other)
       .cast_inputs_to_common_dtype()
-      .use_opmathtype_for_compute()
       .run();
 }
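
For reference (not part of the patch): the behavior the new tests assert
follows PyTorch's standard dtype-promotion rules and can be checked
without an XLA device at all. A minimal sketch using the stock
`torch.promote_types` API:

    import torch

    # bf16 * bf16 stays bf16, so no f32 cast should appear in the
    # lowered HLO once the opmath upcast/downcast chain is removed.
    assert torch.promote_types(torch.bfloat16, torch.bfloat16) == torch.bfloat16

    # bf16 * f32 promotes to f32; this upcast comes from dtype
    # promotion (cast_inputs_to_common_dtype), not from opmath.
    assert torch.promote_types(torch.bfloat16, torch.float32) == torch.float32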