Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[quant] Support 2 dim input in quantized batchnorm 1d #51597

Closed
wants to merge 1 commit
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
17 changes: 11 additions & 6 deletions aten/src/ATen/native/quantized/cpu/qbatch_norm.cpp
Expand Up @@ -60,10 +60,10 @@ Tensor q_batch_norm1d_impl(
return out;
}
int64_t ndim = qx.dim();
TORCH_CHECK(ndim == 3, "Expecting the input tensor of rank 3.");
TORCH_CHECK(ndim == 2 || ndim == 3, "Expecting the input tensor of rank 2 or 3.");
const int64_t N = qx.size(0);
const int64_t C = qx.size(1);
const int64_t H = qx.size(2);
const int64_t H = ndim == 3 ? qx.size(2) : 1;

TORCH_CHECK(weight.numel() == C, "Expect weight size to match C");
TORCH_CHECK(bias.numel() == C, "Expect weight size to match C");
Expand All @@ -82,8 +82,13 @@ Tensor q_batch_norm1d_impl(
const float* mean_data = mean.template data_ptr<float>();
const float* var_data = var.template data_ptr<float>();

// create a fake W dimension so we can use NHWC
qx = qx.unsqueeze(-1);
if (ndim == 2) {
// create a fake H and W dimension so we can use NHWC
qx = qx.unsqueeze(-1).unsqueeze(-1);
} else {
// create a fake W dimension so we can use NHWC
qx = qx.unsqueeze(-1);
}

auto oSizes = qx.sizes();
auto qx_nhwc = qx.contiguous(MemoryFormat::ChannelsLast);
Expand Down Expand Up @@ -341,7 +346,7 @@ Tensor q_batch_norm_impl(
int64_t output_zero_point) {
Tensor qy;
int64_t dim = qx.dim();
if (dim == 3) {
if (dim == 2 || dim == 3) {
qy = q_batch_norm1d_impl<ReluFused>(
qx, mb_weight, mb_bias, mean, var, eps, output_scale, output_zero_point);
} else if (dim == 4) {
Expand All @@ -351,7 +356,7 @@ Tensor q_batch_norm_impl(
qy = q_batch_norm3d_impl<ReluFused>(
qx, mb_weight, mb_bias, mean, var, eps, output_scale, output_zero_point);
} else {
TORCH_CHECK(false, "quantized::batch_norm only support 3d, 4d or 5d inputs.");
TORCH_CHECK(false, "quantized::batch_norm only support 2d, 3d, 4d or 5d inputs.");
}
return qy;
}
Expand Down
12 changes: 6 additions & 6 deletions test/quantization/test_quantized_op.py
Expand Up @@ -2090,7 +2090,7 @@ def test_instance_norm(self):
@skipIfNoFBGEMM
def test_batch_norm_relu(self):
# hypothesis too slow for this test, create test cases manually
max_sides = (3, 4, 5)
max_sides = (2, 3, 4, 5)
side_lens = (1, 8, 11)
torch_types = (torch.qint8, torch.quint8)
combined = [max_sides, side_lens, torch_types]
Expand All @@ -2114,7 +2114,7 @@ def test_batch_norm_relu(self):
bias = torch.rand(c).float()
eps = 0.001
qx = torch.quantize_per_tensor(X, scale_x, zero_point_x, dtype_x)
if len(X.shape) == 3:
if len(X.shape) == 2 or len(X.shape) == 3:
qy = torch.ops.quantized.batch_norm1d_relu(
qx, weight, bias, mean, var, eps, Y_scale, Y_zero_point)
elif len(X.shape) == 4:
Expand All @@ -2141,7 +2141,7 @@ def test_batch_norm_relu(self):
@skipIfNoFBGEMM
def test_batch_norm(self):
# hypothesis too slow for this test, create test cases manually
max_sides = (3, 4, 5)
max_sides = (2, 3, 4, 5)
side_lens = (1, 8, 11)
torch_types = (torch.qint8, torch.quint8)
combined = [max_sides, side_lens, torch_types]
Expand All @@ -2165,13 +2165,13 @@ def test_batch_norm(self):
bias = torch.rand(c).float()
eps = 0.001
qx = torch.quantize_per_tensor(X, scale_x, zero_point_x, dtype_x)
if len(X.shape) == 3:
if len(X.shape) == 2 or len(X.shape) == 3:
qy = torch.ops.quantized.batch_norm1d(
qx, weight, bias, mean, var, eps, Y_scale, Y_zero_point)
if len(X.shape) == 4:
elif len(X.shape) == 4:
qy = torch.ops.quantized.batch_norm2d(
qx, weight, bias, mean, var, eps, Y_scale, Y_zero_point)
if len(X.shape) == 5:
elif len(X.shape) == 5:
qy = torch.ops.quantized.batch_norm3d(
qx, weight, bias, mean, var, eps, Y_scale, Y_zero_point)

Expand Down