
Commit 10cee98

tengyifei authored and qihqi committed
Revert "fix batch_norm amp autocast" (#8547)
1 parent 2041bef · commit 10cee98

File tree: 2 files changed (+19, -43 lines)

test/test_autocast.py

Lines changed: 0 additions & 14 deletions
@@ -484,20 +484,6 @@ def test_autocast_tpu_check_dtype(self):
     assert not torch.is_autocast_xla_enabled()


-class TestOtherOps(unittest.TestCase):
-
-  @unittest.skipIf(not (xm.get_xla_supported_devices("TPU") or
-                        xm.get_xla_supported_devices("GPU")),
-                   f"bfloat16 is only enabled for TPU and GPU")
-  def test_batch_norm(self):
-    device = xm.xla_device()
-    data = torch.randn(4, 16, 32, 32, device=device, dtype=torch.bfloat16)
-    with autocast(device, dtype=torch.bfloat16):
-      output = torch.nn.BatchNorm2d(16)(data)
-    xm.mark_step()
-    self.assertEqual(output.dtype, torch.bfloat16)
-
-
 if __name__ == "__main__":
   test = unittest.main(verbosity=FLAGS.verbosity, exit=False)
   sys.exit(0 if test.result.wasSuccessful() else 1)

torch_xla/csrc/batch_norm.cpp

Lines changed: 19 additions & 29 deletions
@@ -8,17 +8,10 @@
 namespace torch_xla {
 namespace {

-bool IsF32BatchNormWithLowerFPInputs(const xla::XlaOp& input,
-                                     const xla::XlaOp& weight) {
-  static constexpr std::array<xla::PrimitiveType, 9> lowerPrecistionTypes = {
-      xla::PrimitiveType::F8E5M2, xla::PrimitiveType::F8E4M3,
-      xla::PrimitiveType::F8E4M3FN, xla::PrimitiveType::F8E4M3B11FNUZ,
-      xla::PrimitiveType::F8E3M4, xla::PrimitiveType::F8E5M2FNUZ,
-      xla::PrimitiveType::F8E4M3FNUZ, xla::PrimitiveType::F16,
-      xla::PrimitiveType::BF16};
-  if (std::find(lowerPrecistionTypes.begin(), lowerPrecistionTypes.end(),
-                ShapeHelper::ShapeOfXlaOp(input).element_type()) !=
-          lowerPrecistionTypes.end() &&
+bool IsF32BatchNormWithFP16Inputs(const xla::XlaOp& input,
+                                  const xla::XlaOp& weight) {
+  if (ShapeHelper::ShapeOfXlaOp(input).element_type() ==
+          xla::PrimitiveType::F16 &&
       ShapeHelper::ShapeOfXlaOp(weight).element_type() ==
           xla::PrimitiveType::F32) {
     return true;
@@ -46,39 +39,37 @@ xla::XlaOp BatchNormVarianceInvert(xla::XlaOp variance, float eps_value) {

 BatchNormOutput BuildBatchNormTraining(xla::XlaOp input, xla::XlaOp weight,
                                        xla::XlaOp bias, float eps_value) {
-  bool is_batchnorm_with_lower_fp_inputs =
-      IsF32BatchNormWithLowerFPInputs(input, weight);
+  bool is_batchnorm_with_fp16_inputs =
+      IsF32BatchNormWithFP16Inputs(input, weight);
   // Handle the mixed precision use case.
-  if (is_batchnorm_with_lower_fp_inputs) {
+  if (is_batchnorm_with_fp16_inputs) {
     input = xla::ConvertElementType(input, xla::PrimitiveType::F32);
   }
   xla::XlaOp outputs = xla::BatchNormTraining(input, weight, bias, eps_value,
                                               /*feature_index=*/1);
   xla::XlaOp output = xla::GetTupleElement(outputs, 0);
   xla::XlaOp batch_mean = xla::GetTupleElement(outputs, 1);
   xla::XlaOp batch_variance = xla::GetTupleElement(outputs, 2);
-  if (is_batchnorm_with_lower_fp_inputs) {
-    output = xla::ConvertElementType(
-        output, ShapeHelper::ShapeOfXlaOp(input).element_type());
+  if (is_batchnorm_with_fp16_inputs) {
+    output = xla::ConvertElementType(output, xla::PrimitiveType::F16);
   }
   return {output, batch_mean, batch_variance};
 }

 xla::XlaOp BuildBatchNormInference(xla::XlaOp input, xla::XlaOp weight,
                                    xla::XlaOp bias, xla::XlaOp mean,
                                    xla::XlaOp variance, float eps_value) {
-  bool is_batchnorm_with_lower_fp_inputs =
-      IsF32BatchNormWithLowerFPInputs(input, weight);
+  bool is_batchnorm_with_fp16_inputs =
+      IsF32BatchNormWithFP16Inputs(input, weight);
   // Handle the mixed precision use case.
-  if (is_batchnorm_with_lower_fp_inputs) {
+  if (is_batchnorm_with_fp16_inputs) {
     input = xla::ConvertElementType(input, xla::PrimitiveType::F32);
   }
   xla::XlaOp output =
       xla::BatchNormInference(input, weight, bias, mean, variance, eps_value,
                               /*feature_index=*/1);
-  if (is_batchnorm_with_lower_fp_inputs) {
-    output = xla::ConvertElementType(
-        output, ShapeHelper::ShapeOfXlaOp(input).element_type());
+  if (is_batchnorm_with_fp16_inputs) {
+    output = xla::ConvertElementType(output, xla::PrimitiveType::F16);
   }
   return output;
 }
@@ -87,10 +78,10 @@ BatchNormGrads BuildBatchNormBackward(xla::XlaOp grad, xla::XlaOp input,
                                       xla::XlaOp weight, xla::XlaOp save_mean,
                                       xla::XlaOp save_invstd, bool training,
                                       float eps_value) {
-  bool is_batchnorm_with_lower_fp_inputs =
-      IsF32BatchNormWithLowerFPInputs(input, weight);
+  bool is_batchnorm_with_fp16_inputs =
+      IsF32BatchNormWithFP16Inputs(input, weight);
   // Handle the mixed precision use case.
-  if (is_batchnorm_with_lower_fp_inputs) {
+  if (is_batchnorm_with_fp16_inputs) {
     input = xla::ConvertElementType(input, xla::PrimitiveType::F32);
     grad = xla::ConvertElementType(grad, xla::PrimitiveType::F32);
   }
@@ -100,9 +91,8 @@ BatchNormGrads BuildBatchNormBackward(xla::XlaOp grad, xla::XlaOp input,
   xla::XlaOp grad_input = xla::GetTupleElement(grads, 0);
   xla::XlaOp grad_weight = xla::GetTupleElement(grads, 1);
   xla::XlaOp grad_bias = xla::GetTupleElement(grads, 2);
-  if (is_batchnorm_with_lower_fp_inputs) {
-    grad_input = xla::ConvertElementType(
-        grad_input, ShapeHelper::ShapeOfXlaOp(input).element_type());
+  if (is_batchnorm_with_fp16_inputs) {
+    grad_input = xla::ConvertElementType(grad_input, xla::PrimitiveType::F16);
   }
   return {grad_input, grad_weight, grad_bias};
 }
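
Both the reverted helper and the restored one wrap the XLA batch-norm builders with the same pattern: when the weights are F32 but the input is a lower-precision float, the input (and, in the backward pass, the gradient) is upcast to F32 before the op runs, and the result is converted back afterward (to F16 in the restored code). The reverted version had broadened the dtype check from F16 alone to the FP8 and BF16 types. The sketch below restates that upcast/compute/downcast pattern in plain PyTorch; it is illustrative only, and the function name and dtype set are not from this repository.

```python
# Hedged illustration of the mixed-precision pattern in batch_norm.cpp,
# restated with plain PyTorch rather than the XLA builder API.
import torch
import torch.nn.functional as F

# The reverted C++ helper also listed the FP8 element types; fp16/bf16 are
# shown here because plain PyTorch exposes them directly.
LOWER_PRECISION_FLOATS = {torch.float16, torch.bfloat16}


def batch_norm_mixed_precision(x, weight, bias, running_mean, running_var,
                               training=True, eps=1e-5):
  orig_dtype = x.dtype
  mixed = orig_dtype in LOWER_PRECISION_FLOATS and weight.dtype == torch.float32
  if mixed:
    x = x.float()  # upcast the low-precision input so the op runs in F32
  out = F.batch_norm(x, running_mean, running_var, weight, bias,
                     training=training, eps=eps)
  if mixed:
    out = out.to(orig_dtype)  # cast the result back down afterward
  return out
```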
