diff --git a/test/test_operations_hlo.py b/test/test_operations_hlo.py index 9b254d1464b..a7f34d6efb6 100644 --- a/test/test_operations_hlo.py +++ b/test/test_operations_hlo.py @@ -67,6 +67,22 @@ def test_dropout_by_u8_mask(self): hlo_text = torch_xla._XLAC._get_xla_tensors_hlo([b]) assert 'u8' in hlo_text + def test_bfloat16_mul_not_upcast(self): + a = torch.rand(5, 5, dtype=torch.bfloat16).to('xla') + b = torch.rand(5, 5, dtype=torch.bfloat16).to('xla') + c = a * b + hlo_text = torch_xla._XLAC._get_xla_tensors_hlo([c]) + # Check that the output is not upcasted to float32 + assert 'f32' not in hlo_text + + def test_bfloat16_float32_mul_upcast(self): + a = torch.rand(5, 5, dtype=torch.bfloat16).to('xla') + b = torch.rand(5, 5, dtype=torch.float32).to('xla') + c = a * b + hlo_text = torch_xla._XLAC._get_xla_tensors_hlo([c]) + # Check that the output is upcasted to float32 + assert 'f32' in hlo_text + if __name__ == '__main__': torch.set_default_dtype(torch.float32) diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index b6a8484a250..d48251056c8 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -2535,7 +2535,6 @@ at::Tensor XLANativeFunctions::mul(const at::Tensor& self, .add_input(self) .add_input(other) .cast_inputs_to_common_dtype() - .use_opmathtype_for_compute() .run(); }