Skip to content

Commit

Permalink
[quant] Support 2 dim input in quantized batchnorm 1d
Browse files Browse the repository at this point in the history
Summary:
aliging quantized batchnorm behavior with fp batchnorm

Test Plan:
python test/test_quantization.py TestQuantizedOps.test_batch_norm
python test/test_quantization.py TestQuantizedOps.test_batch_norm_relu

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
  • Loading branch information
jerryzh168 committed Feb 2, 2021
1 parent 109bc10 commit 0309da1
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 12 deletions.
17 changes: 11 additions & 6 deletions aten/src/ATen/native/quantized/cpu/qbatch_norm.cpp
Expand Up @@ -60,10 +60,10 @@ Tensor q_batch_norm1d_impl(
return out;
}
int64_t ndim = qx.dim();
TORCH_CHECK(ndim == 3, "Expecting the input tensor of rank 3.");
TORCH_CHECK(ndim == 2 || ndim == 3, "Expecting the input tensor of rank 2 or 3.");
const int64_t N = qx.size(0);
const int64_t C = qx.size(1);
const int64_t H = qx.size(2);
const int64_t H = ndim == 3 ? qx.size(2) : 1;

TORCH_CHECK(weight.numel() == C, "Expect weight size to match C");
TORCH_CHECK(bias.numel() == C, "Expect weight size to match C");
Expand All @@ -82,8 +82,13 @@ Tensor q_batch_norm1d_impl(
const float* mean_data = mean.template data_ptr<float>();
const float* var_data = var.template data_ptr<float>();

// create a fake W dimension so we can use NHWC
qx = qx.unsqueeze(-1);
if (ndim == 2) {
// create a fake H and W dimension so we can use NHWC
qx = qx.unsqueeze(-1).unsqueeze(-1);
} else {
// create a fake W dimension so we can use NHWC
qx = qx.unsqueeze(-1);
}

auto oSizes = qx.sizes();
auto qx_nhwc = qx.contiguous(MemoryFormat::ChannelsLast);
Expand Down Expand Up @@ -341,7 +346,7 @@ Tensor q_batch_norm_impl(
int64_t output_zero_point) {
Tensor qy;
int64_t dim = qx.dim();
if (dim == 3) {
if (dim == 2 || dim == 3) {
qy = q_batch_norm1d_impl<ReluFused>(
qx, mb_weight, mb_bias, mean, var, eps, output_scale, output_zero_point);
} else if (dim == 4) {
Expand All @@ -351,7 +356,7 @@ Tensor q_batch_norm_impl(
qy = q_batch_norm3d_impl<ReluFused>(
qx, mb_weight, mb_bias, mean, var, eps, output_scale, output_zero_point);
} else {
TORCH_CHECK(false, "quantized::batch_norm only support 3d, 4d or 5d inputs.");
TORCH_CHECK(false, "quantized::batch_norm only support 2d, 3d, 4d or 5d inputs.");
}
return qy;
}
Expand Down
12 changes: 6 additions & 6 deletions test/quantization/test_quantized_op.py
Expand Up @@ -2090,7 +2090,7 @@ def test_instance_norm(self):
@skipIfNoFBGEMM
def test_batch_norm_relu(self):
# hypothesis too slow for this test, create test cases manually
max_sides = (3, 4, 5)
max_sides = (2, 3, 4, 5)
side_lens = (1, 8, 11)
torch_types = (torch.qint8, torch.quint8)
combined = [max_sides, side_lens, torch_types]
Expand All @@ -2114,7 +2114,7 @@ def test_batch_norm_relu(self):
bias = torch.rand(c).float()
eps = 0.001
qx = torch.quantize_per_tensor(X, scale_x, zero_point_x, dtype_x)
if len(X.shape) == 3:
if len(X.shape) == 2 or len(X.shape) == 3:
qy = torch.ops.quantized.batch_norm1d_relu(
qx, weight, bias, mean, var, eps, Y_scale, Y_zero_point)
elif len(X.shape) == 4:
Expand All @@ -2141,7 +2141,7 @@ def test_batch_norm_relu(self):
@skipIfNoFBGEMM
def test_batch_norm(self):
# hypothesis too slow for this test, create test cases manually
max_sides = (3, 4, 5)
max_sides = (2, 3, 4, 5)
side_lens = (1, 8, 11)
torch_types = (torch.qint8, torch.quint8)
combined = [max_sides, side_lens, torch_types]
Expand All @@ -2165,13 +2165,13 @@ def test_batch_norm(self):
bias = torch.rand(c).float()
eps = 0.001
qx = torch.quantize_per_tensor(X, scale_x, zero_point_x, dtype_x)
if len(X.shape) == 3:
if len(X.shape) == 2 or len(X.shape) == 3:
qy = torch.ops.quantized.batch_norm1d(
qx, weight, bias, mean, var, eps, Y_scale, Y_zero_point)
if len(X.shape) == 4:
elif len(X.shape) == 4:
qy = torch.ops.quantized.batch_norm2d(
qx, weight, bias, mean, var, eps, Y_scale, Y_zero_point)
if len(X.shape) == 5:
elif len(X.shape) == 5:
qy = torch.ops.quantized.batch_norm3d(
qx, weight, bias, mean, var, eps, Y_scale, Y_zero_point)

Expand Down

0 comments on commit 0309da1

Please sign in to comment.