diff --git a/test/test_operations_hlo.py b/test/test_operations_hlo.py index a7f34d6efb65..9b254d1464b2 100644 --- a/test/test_operations_hlo.py +++ b/test/test_operations_hlo.py @@ -67,22 +67,6 @@ def test_dropout_by_u8_mask(self): hlo_text = torch_xla._XLAC._get_xla_tensors_hlo([b]) assert 'u8' in hlo_text - def test_bfloat16_mul_not_upcast(self): - a = torch.rand(5, 5, dtype=torch.bfloat16).to('xla') - b = torch.rand(5, 5, dtype=torch.bfloat16).to('xla') - c = a * b - hlo_text = torch_xla._XLAC._get_xla_tensors_hlo([c]) - # Check that the output is not upcasted to float32 - assert 'f32' not in hlo_text - - def test_bfloat16_float32_mul_upcast(self): - a = torch.rand(5, 5, dtype=torch.bfloat16).to('xla') - b = torch.rand(5, 5, dtype=torch.float32).to('xla') - c = a * b - hlo_text = torch_xla._XLAC._get_xla_tensors_hlo([c]) - # Check that the output is upcasted to float32 - assert 'f32' in hlo_text - if __name__ == '__main__': torch.set_default_dtype(torch.float32) diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index d48251056c86..b6a8484a2505 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -2535,6 +2535,7 @@ at::Tensor XLANativeFunctions::mul(const at::Tensor& self, .add_input(self) .add_input(other) .cast_inputs_to_common_dtype() + .use_opmathtype_for_compute() .run(); }