diff --git a/test/test_autocast.py b/test/test_autocast.py
index 9d7b9f7b7b1e..acbd0e03be39 100644
--- a/test/test_autocast.py
+++ b/test/test_autocast.py
@@ -484,20 +484,6 @@ def test_autocast_tpu_check_dtype(self):
     assert not torch.is_autocast_xla_enabled()
 
 
-class TestOtherOps(unittest.TestCase):
-
-  @unittest.skipIf(not (xm.get_xla_supported_devices("TPU") or
-                        xm.get_xla_supported_devices("GPU")),
-                   f"bfloat16 is only enabled for TPU and GPU")
-  def test_batch_norm(self):
-    device = xm.xla_device()
-    data = torch.randn(4, 16, 32, 32, device=device, dtype=torch.bfloat16)
-    with autocast(device, dtype=torch.bfloat16):
-      output = torch.nn.BatchNorm2d(16)(data)
-    xm.mark_step()
-    self.assertEqual(output.dtype, torch.bfloat16)
-
-
 if __name__ == "__main__":
   test = unittest.main(verbosity=FLAGS.verbosity, exit=False)
   sys.exit(0 if test.result.wasSuccessful() else 1)
diff --git a/torch_xla/csrc/batch_norm.cpp b/torch_xla/csrc/batch_norm.cpp
index fa9a365e7743..c33fc0c523c6 100644
--- a/torch_xla/csrc/batch_norm.cpp
+++ b/torch_xla/csrc/batch_norm.cpp
@@ -8,17 +8,10 @@ namespace torch_xla {
 namespace {
 
-bool IsF32BatchNormWithLowerFPInputs(const xla::XlaOp& input,
-                                     const xla::XlaOp& weight) {
-  static constexpr std::array lowerPrecistionTypes = {
-      xla::PrimitiveType::F8E5M2,     xla::PrimitiveType::F8E4M3,
-      xla::PrimitiveType::F8E4M3FN,   xla::PrimitiveType::F8E4M3B11FNUZ,
-      xla::PrimitiveType::F8E3M4,     xla::PrimitiveType::F8E5M2FNUZ,
-      xla::PrimitiveType::F8E4M3FNUZ, xla::PrimitiveType::F16,
-      xla::PrimitiveType::BF16};
-  if (std::find(lowerPrecistionTypes.begin(), lowerPrecistionTypes.end(),
-                ShapeHelper::ShapeOfXlaOp(input).element_type()) !=
-          lowerPrecistionTypes.end() &&
+bool IsF32BatchNormWithFP16Inputs(const xla::XlaOp& input,
+                                  const xla::XlaOp& weight) {
+  if (ShapeHelper::ShapeOfXlaOp(input).element_type() ==
+          xla::PrimitiveType::F16 &&
       ShapeHelper::ShapeOfXlaOp(weight).element_type() ==
           xla::PrimitiveType::F32) {
     return true;
   }
@@ -46,10 +39,10 @@ xla::XlaOp BatchNormVarianceInvert(xla::XlaOp variance, float eps_value) {
 
 BatchNormOutput BuildBatchNormTraining(xla::XlaOp input, xla::XlaOp weight,
                                        xla::XlaOp bias, float eps_value) {
-  bool is_batchnorm_with_lower_fp_inputs =
-      IsF32BatchNormWithLowerFPInputs(input, weight);
+  bool is_batchnorm_with_fp16_inputs =
+      IsF32BatchNormWithFP16Inputs(input, weight);
   // Handle the mixed precision use case.
-  if (is_batchnorm_with_lower_fp_inputs) {
+  if (is_batchnorm_with_fp16_inputs) {
     input = xla::ConvertElementType(input, xla::PrimitiveType::F32);
   }
   xla::XlaOp outputs = xla::BatchNormTraining(input, weight, bias, eps_value,
@@ -57,9 +50,8 @@ BatchNormOutput BuildBatchNormTraining(xla::XlaOp input, xla::XlaOp weight,
   xla::XlaOp output = xla::GetTupleElement(outputs, 0);
   xla::XlaOp batch_mean = xla::GetTupleElement(outputs, 1);
   xla::XlaOp batch_variance = xla::GetTupleElement(outputs, 2);
-  if (is_batchnorm_with_lower_fp_inputs) {
-    output = xla::ConvertElementType(
-        output, ShapeHelper::ShapeOfXlaOp(input).element_type());
+  if (is_batchnorm_with_fp16_inputs) {
+    output = xla::ConvertElementType(output, xla::PrimitiveType::F16);
   }
   return {output, batch_mean, batch_variance};
 }
@@ -67,18 +59,17 @@ BatchNormOutput BuildBatchNormTraining(xla::XlaOp input, xla::XlaOp weight,
 xla::XlaOp BuildBatchNormInference(xla::XlaOp input, xla::XlaOp weight,
                                    xla::XlaOp bias, xla::XlaOp mean,
                                    xla::XlaOp variance, float eps_value) {
-  bool is_batchnorm_with_lower_fp_inputs =
-      IsF32BatchNormWithLowerFPInputs(input, weight);
+  bool is_batchnorm_with_fp16_inputs =
+      IsF32BatchNormWithFP16Inputs(input, weight);
   // Handle the mixed precision use case.
-  if (is_batchnorm_with_lower_fp_inputs) {
+  if (is_batchnorm_with_fp16_inputs) {
     input = xla::ConvertElementType(input, xla::PrimitiveType::F32);
   }
   xla::XlaOp output = xla::BatchNormInference(input, weight, bias, mean,
                                               variance, eps_value,
                                               /*feature_index=*/1);
-  if (is_batchnorm_with_lower_fp_inputs) {
-    output = xla::ConvertElementType(
-        output, ShapeHelper::ShapeOfXlaOp(input).element_type());
+  if (is_batchnorm_with_fp16_inputs) {
+    output = xla::ConvertElementType(output, xla::PrimitiveType::F16);
   }
   return output;
 }
@@ -87,10 +78,10 @@ BatchNormGrads BuildBatchNormBackward(xla::XlaOp grad, xla::XlaOp input,
                                       xla::XlaOp weight, xla::XlaOp save_mean,
                                       xla::XlaOp save_invstd, bool training,
                                       float eps_value) {
-  bool is_batchnorm_with_lower_fp_inputs =
-      IsF32BatchNormWithLowerFPInputs(input, weight);
+  bool is_batchnorm_with_fp16_inputs =
+      IsF32BatchNormWithFP16Inputs(input, weight);
   // Handle the mixed precision use case.
-  if (is_batchnorm_with_lower_fp_inputs) {
+  if (is_batchnorm_with_fp16_inputs) {
     input = xla::ConvertElementType(input, xla::PrimitiveType::F32);
     grad = xla::ConvertElementType(grad, xla::PrimitiveType::F32);
   }
@@ -100,9 +91,8 @@ BatchNormGrads BuildBatchNormBackward(xla::XlaOp grad, xla::XlaOp input,
   xla::XlaOp grad_input = xla::GetTupleElement(grads, 0);
   xla::XlaOp grad_weight = xla::GetTupleElement(grads, 1);
   xla::XlaOp grad_bias = xla::GetTupleElement(grads, 2);
-  if (is_batchnorm_with_lower_fp_inputs) {
-    grad_input = xla::ConvertElementType(
-        grad_input, ShapeHelper::ShapeOfXlaOp(input).element_type());
+  if (is_batchnorm_with_fp16_inputs) {
+    grad_input = xla::ConvertElementType(grad_input, xla::PrimitiveType::F16);
   }
   return {grad_input, grad_weight, grad_bias};
 }
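
This change narrows the mixed-precision batch-norm check back to float16 inputs with float32 parameters and drops the bfloat16 test that exercised the wider check. For reference, a minimal sketch of an FP16 analogue of the removed test, covering the branch that `IsF32BatchNormWithFP16Inputs` still matches; it assumes the same `torch_xla.amp.autocast` helper used elsewhere in test_autocast.py and a backend where float16 autocast is available (e.g. XLA:GPU):

```python
# Hypothetical FP16 analogue of the removed test_batch_norm; illustrative
# sketch only, not part of this diff.
import torch
import torch_xla.core.xla_model as xm
from torch_xla.amp import autocast

device = xm.xla_device()
# float16 activations with float32 BatchNorm parameters hit the branch kept
# above: the input is upcast to F32 for xla::BatchNormTraining and the
# output is cast back to F16.
data = torch.randn(4, 16, 32, 32, device=device, dtype=torch.float16)
bn = torch.nn.BatchNorm2d(16).to(device)  # parameters stay float32
with autocast(device, dtype=torch.float16):
  output = bn(data)
xm.mark_step()
assert output.dtype == torch.float16
```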