Skip to content

Commit

Permalink
[quant] torch.mean add path for unsupported QNNPACK modes
Browse files Browse the repository at this point in the history
ghstack-source-id: 892ced57ae1a4300388aa4e3f2e0e32dbee9822d
Pull Request resolved: #45533
  • Loading branch information
z-a-f committed Sep 29, 2020
1 parent 1ed1a2f commit bdc2d84
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 4 deletions.
9 changes: 8 additions & 1 deletion aten/src/ATen/native/quantized/cpu/qreduction.cpp
Expand Up @@ -83,7 +83,14 @@ Tensor& mean_out_quantized_cpu(
c10::optional<ScalarType> opt_dtype) {
#ifdef USE_PYTORCH_QNNPACK
if (at::globalContext().qEngine() == at::QEngine::QNNPACK &&
self.scalar_type() == kQUInt8) {
self.scalar_type() == kQUInt8 &&
// QNNPACK currently is only supported for NCHW + dim=(2, 3)
// Remove these checks after generic version is implemented.
self.ndimension() == 4 &&
dim.size() == 2 &&
dim[0] == 2 &&
dim[1] == 3
){
result = qnnpack_mean(self, dim);
return result;
}
Expand Down
8 changes: 5 additions & 3 deletions test/quantization/test_quantized_op.py
Expand Up @@ -1726,12 +1726,14 @@ def test_cat_nhwc(self, X, relu):
torch.testing.assert_allclose(out.dequantize(), ref.dequantize())
self.assertNotEqual(out.stride(), sorted(out.stride()))

@given(X=hu.tensor(shapes=hu.array_shapes(min_dims=3, max_dims=3,
min_side=1, max_side=2),
@given(X=hu.tensor(shapes=hu.array_shapes(min_dims=1, max_dims=5,
min_side=1, max_side=4),
qparams=hu.qparams()),
dim=st.integers(1, 2))
dim=st.integers(-1, 5))
@override_qengines
def test_mean(self, X, dim):
X, (scale, zero_point, torch_type) = X
assume(dim < X.ndim)
qX = torch.quantize_per_tensor(torch.tensor(X).float(), scale, zero_point, torch_type)

Y = torch.mean(qX.dequantize(), dim)
Expand Down

0 comments on commit bdc2d84

Please sign in to comment.