From 80b4f78a8d186b106a2e41e88f079fd8e836a6e7 Mon Sep 17 00:00:00 2001
From: Raj Thakur
Date: Sat, 8 Nov 2025 04:32:53 +0000
Subject: [PATCH 1/3] Add opmath cast sequence for CPU or Neuron

---
 torch_xla/csrc/aten_xla_type.cpp | 27 ++++++++++++++++++++-------
 1 file changed, 20 insertions(+), 7 deletions(-)

diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index d48251056c86..d7976cf54595 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -2529,13 +2529,26 @@ at::Tensor XLANativeFunctions::mse_loss_backward(const at::Tensor& grad_output,
 at::Tensor XLANativeFunctions::mul(const at::Tensor& self,
                                    const at::Tensor& other) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
-  using FnType = XLATensorPtr(const XLATensorPtr&, const XLATensorPtr&,
-                              std::optional<at::ScalarType>);
-  return OpConfig::From(static_cast<FnType*>(tensor_methods::mul))
-      .add_input(self)
-      .add_input(other)
-      .cast_inputs_to_common_dtype()
-      .run();
+
+  // Check device type to determine if we need opmathtype for mixed-precision
+  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
+  XlaDeviceType hw_type =
+      static_cast<XlaDeviceType>(xla_self->GetDevice().type());
+
+  auto config =
+      OpConfig([](const XLAInputVector& inputs, at::ScalarType dtype) {
+        return tensor_methods::mul(inputs[0], inputs[1], dtype);
+      })
+          .add_input(self)
+          .add_input(other)
+          .cast_inputs_to_common_dtype();
+
+  // Only use opmathtype for CPU or Neuron backend
+  if (hw_type == XlaDeviceType::CPU || hw_type == XlaDeviceType::NEURON) {
+    config.use_opmathtype_for_compute();
+  }
+
+  return config.run();
 }
 
 at::Tensor XLANativeFunctions::mul(const at::Tensor& self,

From 63e7bf1e7388eda1940e7d4cfe8c0488d69374e4 Mon Sep 17 00:00:00 2001
From: Raj Thakur
Date: Sun, 9 Nov 2025 00:19:07 +0000
Subject: [PATCH 2/3] test upcast on cpu

---
 test/test_operations_hlo.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/test/test_operations_hlo.py b/test/test_operations_hlo.py
index a7f34d6efb65..551737ce3321 100644
--- a/test/test_operations_hlo.py
+++ b/test/test_operations_hlo.py
@@ -67,13 +67,13 @@ def test_dropout_by_u8_mask(self):
     hlo_text = torch_xla._XLAC._get_xla_tensors_hlo([b])
     assert 'u8' in hlo_text
 
-  def test_bfloat16_mul_not_upcast(self):
+  def test_bfloat16_mul_upcast(self):
     a = torch.rand(5, 5, dtype=torch.bfloat16).to('xla')
     b = torch.rand(5, 5, dtype=torch.bfloat16).to('xla')
     c = a * b
     hlo_text = torch_xla._XLAC._get_xla_tensors_hlo([c])
     # Check that the output is not upcasted to float32
-    assert 'f32' not in hlo_text
+    assert 'f32' in hlo_text
 
   def test_bfloat16_float32_mul_upcast(self):
     a = torch.rand(5, 5, dtype=torch.bfloat16).to('xla')

From 214f075a0e956f8cf090826c43fe97162ef96d48 Mon Sep 17 00:00:00 2001
From: Raj Thakur
Date: Mon, 10 Nov 2025 20:49:19 +0000
Subject: [PATCH 3/3] Revert "mul: remove opmath cast sequence (#9663)"

This reverts commit 2a9138a26ee257fef05310ad3fecf7c55fe80d73.
---
 test/test_operations_hlo.py      | 16 ----------------
 torch_xla/csrc/aten_xla_type.cpp | 28 ++++++++--------------------
 2 files changed, 8 insertions(+), 36 deletions(-)

diff --git a/test/test_operations_hlo.py b/test/test_operations_hlo.py
index 551737ce3321..9b254d1464b2 100644
--- a/test/test_operations_hlo.py
+++ b/test/test_operations_hlo.py
@@ -67,22 +67,6 @@ def test_dropout_by_u8_mask(self):
     hlo_text = torch_xla._XLAC._get_xla_tensors_hlo([b])
     assert 'u8' in hlo_text
 
-  def test_bfloat16_mul_upcast(self):
-    a = torch.rand(5, 5, dtype=torch.bfloat16).to('xla')
-    b = torch.rand(5, 5, dtype=torch.bfloat16).to('xla')
-    c = a * b
-    hlo_text = torch_xla._XLAC._get_xla_tensors_hlo([c])
-    # Check that the output is not upcasted to float32
-    assert 'f32' in hlo_text
-
-  def test_bfloat16_float32_mul_upcast(self):
-    a = torch.rand(5, 5, dtype=torch.bfloat16).to('xla')
-    b = torch.rand(5, 5, dtype=torch.float32).to('xla')
-    c = a * b
-    hlo_text = torch_xla._XLAC._get_xla_tensors_hlo([c])
-    # Check that the output is upcasted to float32
-    assert 'f32' in hlo_text
-
 
 if __name__ == '__main__':
   torch.set_default_dtype(torch.float32)

diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index d7976cf54595..b6a8484a2505 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -2529,26 +2529,14 @@ at::Tensor XLANativeFunctions::mse_loss_backward(const at::Tensor& grad_output,
 at::Tensor XLANativeFunctions::mul(const at::Tensor& self,
                                    const at::Tensor& other) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
-
-  // Check device type to determine if we need opmathtype for mixed-precision
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
-  XlaDeviceType hw_type =
-      static_cast<XlaDeviceType>(xla_self->GetDevice().type());
-
-  auto config =
-      OpConfig([](const XLAInputVector& inputs, at::ScalarType dtype) {
-        return tensor_methods::mul(inputs[0], inputs[1], dtype);
-      })
-          .add_input(self)
-          .add_input(other)
-          .cast_inputs_to_common_dtype();
-
-  // Only use opmathtype for CPU or Neuron backend
-  if (hw_type == XlaDeviceType::CPU || hw_type == XlaDeviceType::NEURON) {
-    config.use_opmathtype_for_compute();
-  }
-
-  return config.run();
+  using FnType = XLATensorPtr(const XLATensorPtr&, const XLATensorPtr&,
+                              std::optional<at::ScalarType>);
+  return OpConfig::From(static_cast<FnType*>(tensor_methods::mul))
+      .add_input(self)
+      .add_input(other)
+      .cast_inputs_to_common_dtype()
+      .use_opmathtype_for_compute()
+      .run();
 }
 
 at::Tensor XLANativeFunctions::mul(const at::Tensor& self,
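
Note on verifying the change: the upcast restored by this series can be observed from Python by inspecting the generated HLO, exactly as the re-added test does. The following is a minimal standalone sketch of that check, assuming torch_xla is installed and an XLA device is available; the variable names simply mirror the test in patch 2.

import torch
import torch_xla  # registers the XLA backend so .to('xla') works

# bf16 x bf16 multiply: with the opmath cast sequence the operands are
# upcast to f32 for the computation and the result is cast back to bf16.
a = torch.rand(5, 5, dtype=torch.bfloat16).to('xla')
b = torch.rand(5, 5, dtype=torch.bfloat16).to('xla')
c = a * b

hlo_text = torch_xla._XLAC._get_xla_tensors_hlo([c])
assert 'f32' in hlo_text          # the intermediate compute shows up as f32 in the HLO
assert c.dtype == torch.bfloat16  # the output dtype is still bf16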