namespace torch_xla {
namespace {
11- bool IsF32BatchNormWithLowerFPInputs (const xla::XlaOp& input,
12- const xla::XlaOp& weight) {
13- static constexpr std::array<xla::PrimitiveType, 9 > lowerPrecistionTypes = {
14- xla::PrimitiveType::F8E5M2, xla::PrimitiveType::F8E4M3,
15- xla::PrimitiveType::F8E4M3FN, xla::PrimitiveType::F8E4M3B11FNUZ,
16- xla::PrimitiveType::F8E3M4, xla::PrimitiveType::F8E5M2FNUZ,
17- xla::PrimitiveType::F8E4M3FNUZ, xla::PrimitiveType::F16,
18- xla::PrimitiveType::BF16};
19- if (std::find (lowerPrecistionTypes.begin (), lowerPrecistionTypes.end (),
20- ShapeHelper::ShapeOfXlaOp (input).element_type ()) !=
21- lowerPrecistionTypes.end () &&
11+ bool IsF32BatchNormWithFP16Inputs (const xla::XlaOp& input,
12+ const xla::XlaOp& weight) {
13+ if (ShapeHelper::ShapeOfXlaOp (input).element_type () ==
14+ xla::PrimitiveType::F16 &&
2215 ShapeHelper::ShapeOfXlaOp (weight).element_type () ==
2316 xla::PrimitiveType::F32) {
2417 return true ;
@@ -46,39 +39,37 @@ xla::XlaOp BatchNormVarianceInvert(xla::XlaOp variance, float eps_value) {
4639
4740BatchNormOutput BuildBatchNormTraining (xla::XlaOp input, xla::XlaOp weight,
4841 xla::XlaOp bias, float eps_value) {
49- bool is_batchnorm_with_lower_fp_inputs =
50- IsF32BatchNormWithLowerFPInputs (input, weight);
42+ bool is_batchnorm_with_fp16_inputs =
43+ IsF32BatchNormWithFP16Inputs (input, weight);
5144 // Handle the mixed precision use case.
52- if (is_batchnorm_with_lower_fp_inputs ) {
45+ if (is_batchnorm_with_fp16_inputs ) {
5346 input = xla::ConvertElementType (input, xla::PrimitiveType::F32);
5447 }
5548 xla::XlaOp outputs = xla::BatchNormTraining (input, weight, bias, eps_value,
5649 /* feature_index=*/ 1 );
5750 xla::XlaOp output = xla::GetTupleElement (outputs, 0 );
5851 xla::XlaOp batch_mean = xla::GetTupleElement (outputs, 1 );
5952 xla::XlaOp batch_variance = xla::GetTupleElement (outputs, 2 );
60- if (is_batchnorm_with_lower_fp_inputs) {
61- output = xla::ConvertElementType (
62- output, ShapeHelper::ShapeOfXlaOp (input).element_type ());
53+ if (is_batchnorm_with_fp16_inputs) {
54+ output = xla::ConvertElementType (output, xla::PrimitiveType::F16);
6355 }
6456 return {output, batch_mean, batch_variance};
6557}
6658
6759xla::XlaOp BuildBatchNormInference (xla::XlaOp input, xla::XlaOp weight,
6860 xla::XlaOp bias, xla::XlaOp mean,
6961 xla::XlaOp variance, float eps_value) {
70- bool is_batchnorm_with_lower_fp_inputs =
71- IsF32BatchNormWithLowerFPInputs (input, weight);
62+ bool is_batchnorm_with_fp16_inputs =
63+ IsF32BatchNormWithFP16Inputs (input, weight);
7264 // Handle the mixed precision use case.
73- if (is_batchnorm_with_lower_fp_inputs ) {
65+ if (is_batchnorm_with_fp16_inputs ) {
7466 input = xla::ConvertElementType (input, xla::PrimitiveType::F32);
7567 }
7668 xla::XlaOp output =
7769 xla::BatchNormInference (input, weight, bias, mean, variance, eps_value,
7870 /* feature_index=*/ 1 );
79- if (is_batchnorm_with_lower_fp_inputs) {
80- output = xla::ConvertElementType (
81- output, ShapeHelper::ShapeOfXlaOp (input).element_type ());
71+ if (is_batchnorm_with_fp16_inputs) {
72+ output = xla::ConvertElementType (output, xla::PrimitiveType::F16);
8273 }
8374 return output;
8475}
// NOTE(review): this span is a unified-diff paste, not compilable C++.
// '-' lines belong to the reverted lower-FP generalization; '+' lines are the
// FP16-only version that remains in the file after the diff is applied.
@@ -87,10 +78,10 @@ BatchNormGrads BuildBatchNormBackward(xla::XlaOp grad, xla::XlaOp input,
8778 xla::XlaOp weight, xla::XlaOp save_mean,
8879 xla::XlaOp save_invstd, bool training,
8980 float eps_value) {
// Detect the mixed-precision case: F16 activations with F32 parameters.
90- bool is_batchnorm_with_lower_fp_inputs =
91- IsF32BatchNormWithLowerFPInputs (input, weight);
81+ bool is_batchnorm_with_fp16_inputs =
82+ IsF32BatchNormWithFP16Inputs (input, weight);
9283 // Handle the mixed precision use case.
93- if (is_batchnorm_with_lower_fp_inputs ) {
84+ if (is_batchnorm_with_fp16_inputs ) {
// Upcast both the activations and the incoming gradient to F32 so the
// gradient computation below runs in full precision.
9485 input = xla::ConvertElementType (input, xla::PrimitiveType::F32);
9586 grad = xla::ConvertElementType (grad, xla::PrimitiveType::F32);
9687 }
// NOTE(review): the hunk gap here elides the call (old lines 97-99) that
// produces `grads` — presumably xla::BatchNormGrad; not visible in this view,
// so its exact arguments cannot be stated here. TODO confirm against the file.
@@ -100,9 +91,8 @@ BatchNormGrads BuildBatchNormBackward(xla::XlaOp grad, xla::XlaOp input,
// Unpack the (grad_input, grad_weight, grad_bias) tuple.
10091 xla::XlaOp grad_input = xla::GetTupleElement (grads, 0 );
10192 xla::XlaOp grad_weight = xla::GetTupleElement (grads, 1 );
10293 xla::XlaOp grad_bias = xla::GetTupleElement (grads, 2 );
// NOTE(review): the removed '-' variant read element_type() from `input`
// AFTER `input` was reassigned to F32 above, so its convert-back was a no-op
// (grad_input stayed F32). The '+' variant hard-codes F16, which is correct
// only while the predicate fires exclusively for F16 activations.
103- if (is_batchnorm_with_lower_fp_inputs) {
104- grad_input = xla::ConvertElementType (
105- grad_input, ShapeHelper::ShapeOfXlaOp (input).element_type ());
94+ if (is_batchnorm_with_fp16_inputs) {
95+ grad_input = xla::ConvertElementType (grad_input, xla::PrimitiveType::F16);
10696 }
// Only grad_input is cast back; grad_weight/grad_bias stay in the parameter
// precision (F32 in the mixed case), matching the forward pass convention.
10797 return {grad_input, grad_weight, grad_bias};
10898}
0 commit comments