From 0309da1fc19fc0abad7d0dd92d541ca0337b756c Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Tue, 2 Feb 2021 15:46:07 -0800
Subject: [PATCH] [quant] Support 2 dim input in quantized batchnorm 1d

Summary:
Aligning quantized batchnorm behavior with fp batchnorm, which accepts both
2d (N, C) and 3d (N, C, L) inputs.

Test Plan:
python test/test_quantization.py TestQuantizedOps.test_batch_norm
python test/test_quantization.py TestQuantizedOps.test_batch_norm_relu

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
---
 .../ATen/native/quantized/cpu/qbatch_norm.cpp | 17 +++++++++++------
 test/quantization/test_quantized_op.py        | 12 ++++++------
 2 files changed, 17 insertions(+), 12 deletions(-)

diff --git a/aten/src/ATen/native/quantized/cpu/qbatch_norm.cpp b/aten/src/ATen/native/quantized/cpu/qbatch_norm.cpp
index b053940abba2..dc9f4b8e0eb3 100644
--- a/aten/src/ATen/native/quantized/cpu/qbatch_norm.cpp
+++ b/aten/src/ATen/native/quantized/cpu/qbatch_norm.cpp
@@ -60,10 +60,10 @@ Tensor q_batch_norm1d_impl(
     return out;
   }
   int64_t ndim = qx.dim();
-  TORCH_CHECK(ndim == 3, "Expecting the input tensor of rank 3.");
+  TORCH_CHECK(ndim == 2 || ndim == 3, "Expecting an input tensor of rank 2 or 3.");
   const int64_t N = qx.size(0);
   const int64_t C = qx.size(1);
-  const int64_t H = qx.size(2);
+  const int64_t H = ndim == 3 ? qx.size(2) : 1;
 
   TORCH_CHECK(weight.numel() == C, "Expect weight size to match C");
   TORCH_CHECK(bias.numel() == C, "Expect weight size to match C");
@@ -82,8 +82,13 @@ Tensor q_batch_norm1d_impl(
   const float* mean_data = mean.template data_ptr<float>();
   const float* var_data = var.template data_ptr<float>();
 
-  // create a fake W dimension so we can use NHWC
-  qx = qx.unsqueeze(-1);
+  if (ndim == 2) {
+    // create a fake H and W dimension so we can use NHWC
+    qx = qx.unsqueeze(-1).unsqueeze(-1);
+  } else {
+    // create a fake W dimension so we can use NHWC
+    qx = qx.unsqueeze(-1);
+  }
 
   auto oSizes = qx.sizes();
   auto qx_nhwc = qx.contiguous(MemoryFormat::ChannelsLast);
@@ -341,7 +346,7 @@ Tensor q_batch_norm_impl(
     int64_t output_zero_point) {
   Tensor qy;
   int64_t dim = qx.dim();
-  if (dim == 3) {
+  if (dim == 2 || dim == 3) {
     qy = q_batch_norm1d_impl<ReluFused>(
         qx, mb_weight, mb_bias, mean, var, eps, output_scale, output_zero_point);
   } else if (dim == 4) {
@@ -351,7 +356,7 @@ Tensor q_batch_norm_impl(
     qy = q_batch_norm3d_impl<ReluFused>(
         qx, mb_weight, mb_bias, mean, var, eps, output_scale, output_zero_point);
   } else {
-    TORCH_CHECK(false, "quantized::batch_norm only support 3d, 4d or 5d inputs.");
+    TORCH_CHECK(false, "quantized::batch_norm only supports 2d, 3d, 4d or 5d inputs.");
   }
   return qy;
 }
diff --git a/test/quantization/test_quantized_op.py b/test/quantization/test_quantized_op.py
index bc68d722e342..ea8bdd135bfa 100644
--- a/test/quantization/test_quantized_op.py
+++ b/test/quantization/test_quantized_op.py
@@ -2090,7 +2090,7 @@ def test_instance_norm(self):
     @skipIfNoFBGEMM
     def test_batch_norm_relu(self):
         # hypothesis too slow for this test, create test cases manually
-        max_sides = (3, 4, 5)
+        max_sides = (2, 3, 4, 5)
         side_lens = (1, 8, 11)
         torch_types = (torch.qint8, torch.quint8)
         combined = [max_sides, side_lens, torch_types]
@@ -2114,7 +2114,7 @@ def test_batch_norm_relu(self):
                 bias = torch.rand(c).float()
                 eps = 0.001
                 qx = torch.quantize_per_tensor(X, scale_x, zero_point_x, dtype_x)
-                if len(X.shape) == 3:
+                if len(X.shape) == 2 or len(X.shape) == 3:
                     qy = torch.ops.quantized.batch_norm1d_relu(
                         qx, weight, bias, mean, var, eps, Y_scale, Y_zero_point)
                 elif len(X.shape) == 4:
@@ -2141,7 +2141,7 @@ def test_batch_norm_relu(self):
     @skipIfNoFBGEMM
     def test_batch_norm(self):
         # hypothesis too slow for this test, create test cases manually
-        max_sides = (3, 4, 5)
+        max_sides = (2, 3, 4, 5)
         side_lens = (1, 8, 11)
         torch_types = (torch.qint8, torch.quint8)
         combined = [max_sides, side_lens, torch_types]
@@ -2165,13 +2165,13 @@ def test_batch_norm(self):
                 bias = torch.rand(c).float()
                 eps = 0.001
                 qx = torch.quantize_per_tensor(X, scale_x, zero_point_x, dtype_x)
-                if len(X.shape) == 3:
+                if len(X.shape) == 2 or len(X.shape) == 3:
                     qy = torch.ops.quantized.batch_norm1d(
                         qx, weight, bias, mean, var, eps, Y_scale, Y_zero_point)
-                if len(X.shape) == 4:
+                elif len(X.shape) == 4:
                     qy = torch.ops.quantized.batch_norm2d(
                         qx, weight, bias, mean, var, eps, Y_scale, Y_zero_point)
-                if len(X.shape) == 5:
+                elif len(X.shape) == 5:
                     qy = torch.ops.quantized.batch_norm3d(
                         qx, weight, bias, mean, var, eps, Y_scale, Y_zero_point)
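
Illustration (not part of the patch): a minimal sketch of the behavior this
change enables. The op name and argument order come from the tests above; the
shapes, scales, and zero points are arbitrary example values, and an
FBGEMM-capable build is assumed (the tests are gated on @skipIfNoFBGEMM).

    import torch

    torch.backends.quantized.engine = 'fbgemm'  # assumes FBGEMM is available

    N, C = 4, 8
    x = torch.randn(N, C)  # 2d (N, C) input, previously rejected by the rank-3 check
    qx = torch.quantize_per_tensor(x, scale=0.1, zero_point=0, dtype=torch.quint8)

    # per-channel affine parameters and running statistics
    weight = torch.rand(C)
    bias = torch.rand(C)
    mean = torch.zeros(C)
    var = torch.ones(C)
    eps = 0.001

    # q_batch_norm1d_impl now adds fake H and W dimensions internally so the
    # NHWC kernel can be reused; this mirrors fp BatchNorm1d, which accepts
    # both (N, C) and (N, C, L) inputs.
    qy = torch.ops.quantized.batch_norm1d(
        qx, weight, bias, mean, var, eps, 0.1, 0)
    print(qy.shape)  # torch.Size([4, 8]), same shape as the input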