add qnnpack path for hardtanh
Summary:

Adds a QNNPACK path for the clamp kernel, which is used to implement
hardtanh.
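
In the quantized domain, hardtanh is just a clamp whose float bounds get
mapped into uint8 with the input tensor's scale and zero point, which is what
the new QNNPACK path does before calling pytorch_qnnp_create_clamp_nc_u8.
A minimal Python sketch of that reference behavior (the helper name below is
illustrative, not part of this change):

    import torch

    def reference_quantized_hardtanh(qx, min_val, max_val):
        # Quantize the float clamp bounds with the input's quantization
        # parameters, mirroring at::quantize_val in the QNNPACK clamp path.
        scale, zp = qx.q_scale(), qx.q_zero_point()
        min_q = min(255, max(0, int(round(min_val / scale)) + zp))
        max_q = min(255, max(0, int(round(max_val / scale)) + zp))
        # Clamp the raw uint8 values and rewrap them with the same qparams.
        int_repr = qx.int_repr().clamp(min_q, max_q)
        return torch._make_per_tensor_quantized_tensor(int_repr, scale, zp)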

Test Plan:

python test/test_quantized.py TestQNNPackOps.test_hardtanh
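
For a quick manual check outside the test suite, something along these lines
should exercise the new path on a QNNPACK-enabled build (sketch only; the
shapes and quantization parameters are arbitrary):

    import torch

    torch.backends.quantized.engine = 'qnnpack'
    x = torch.randn(4, 8)
    qx = torch.quantize_per_tensor(x, scale=0.1, zero_point=128,
                                   dtype=torch.quint8)
    # Dispatches to the quantized clamp op, which now uses QNNPACK for quint8.
    qy = torch.nn.quantized.functional.hardtanh(qx, -1.0, 1.0)
    print(qy.dequantize())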

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 1dc64fc6d4fa4b685602e9f35ff7eb69701714fe
Pull Request resolved: #35779
vkuzo committed Apr 20, 2020
1 parent 30e7055 commit 493fcfe
Showing 2 changed files with 117 additions and 25 deletions.
65 changes: 65 additions & 0 deletions aten/src/ATen/native/quantized/cpu/qclamp.cpp
@@ -4,7 +4,10 @@
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cpu/Loops.h>
#include <ATen/native/quantized/cpu/quantized_ops.h>
#include <ATen/native/quantized/cpu/init_qnnpack.h>
#include <ATen/native/quantized/cpu/qnnpack_utils.h>
#include <ATen/quantized/Quantizer.h>
#include <caffe2/utils/threadpool/ThreadPoolMobile.h>

#include <algorithm>

@@ -14,12 +17,74 @@ namespace native {
DEFINE_DISPATCH(qclamp_stub);

namespace {

#ifdef USE_PYTORCH_QNNPACK
Tensor qnnpack_clamp(Tensor input, Scalar min, Scalar max) {

  TORCH_CHECK(input.ndimension() > 0, "qnnpack_clamp(): Got empty input tensor");

  initQNNPACK();

  Tensor input_contig = input.contiguous(input.suggest_memory_format());
  size_t num_elems = input_contig.numel() / input_contig.size(0);

  auto min_f = min.to<float>();
  auto max_f = max.to<float>();
  uint8_t min_q =
      at::quantize_val<quint8>(input.q_scale(), input.q_zero_point(), min_f).val_;
  uint8_t max_q =
      at::quantize_val<quint8>(input.q_scale(), input.q_zero_point(), max_f).val_;

  pytorch_qnnp_operator_t clamp_op{nullptr};
  const pytorch_qnnp_status createStatus = pytorch_qnnp_create_clamp_nc_u8(
      num_elems, // channels
      min_q,
      max_q,
      0, // flags
      &clamp_op);
  TORCH_INTERNAL_ASSERT(createStatus == pytorch_qnnp_status_success,
                        "failed to create QNNPACK Clamp operator");

  Tensor qy = at::_empty_affine_quantized(
      input_contig.sizes(),
      input_contig.options(),
      input_contig.q_scale(),
      input_contig.q_zero_point());

  const pytorch_qnnp_status setupStatus = pytorch_qnnp_setup_clamp_nc_u8(
      clamp_op,
      input_contig.size(0), // batch_size
      (uint8_t*)input_contig.data_ptr<c10::quint8>(), // input_data
      num_elems, // input_stride
      (uint8_t*)qy.data_ptr<c10::quint8>(), // output_data
      num_elems); // output_stride
  TORCH_INTERNAL_ASSERT(setupStatus == pytorch_qnnp_status_success,
                        "failed to setup QNNPACK Clamp operator");

  pthreadpool_t threadpool = caffe2::mobile_pthreadpool();

  const pytorch_qnnp_status runStatus =
      pytorch_qnnp_run_operator(clamp_op, threadpool);

  TORCH_INTERNAL_ASSERT(
      runStatus == pytorch_qnnp_status_success,
      "failed to run QNNPACK Clamp operator");
  return qy;
}

#endif // USE_PYTORCH_QNNPACK

Tensor quantized_clamp_impl(
    const Tensor& qx,
    optional<Scalar> min,
    optional<Scalar> max) {
  Tensor qy;
  if (min && max) {
#ifdef USE_PYTORCH_QNNPACK
    if (at::globalContext().qEngine() == at::QEngine::QNNPACK &&
        qx.scalar_type() == kQUInt8) {
      return qnnpack_clamp(qx, *min, *max);
    }
#endif
    qclamp_stub(qx.device().type(), qx, *min, *max, qy);
  } else {
    TORCH_CHECK(
77 changes: 52 additions & 25 deletions test/quantization/test_quantized.py
@@ -456,36 +456,37 @@ def test_qclamp(self, X, min_val, max_val):
           min_val=hu.floats(-1e6, 1e6, allow_nan=False, allow_infinity=False),
           max_val=hu.floats(-1e6, 1e6, allow_nan=False, allow_infinity=False))
    def test_hardtanh(self, X, min_val, max_val):
        X, (scale, zero_point, torch_type) = X
        with override_quantized_engine('fbgemm'):
            X, (scale, zero_point, torch_type) = X

        assume(min_val <= max_val)
        Y = X.copy()
        Y[Y < min_val] = min_val
        Y[Y > max_val] = max_val
        qY = torch.quantize_per_tensor(torch.from_numpy(Y), scale=scale,
                                       zero_point=zero_point, dtype=torch_type)
        X = torch.from_numpy(X)
        qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
                                       dtype=torch_type)
            assume(min_val <= max_val)
            Y = X.copy()
            Y[Y < min_val] = min_val
            Y[Y > max_val] = max_val
            qY = torch.quantize_per_tensor(torch.from_numpy(Y), scale=scale,
                                           zero_point=zero_point, dtype=torch_type)
            X = torch.from_numpy(X)
            qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
                                           dtype=torch_type)

        ops_under_test = {
            'nn.quantized.functional.hardtanh':
                torch.nn.quantized.functional.hardtanh,
        }
            ops_under_test = {
                'nn.quantized.functional.hardtanh':
                    torch.nn.quantized.functional.hardtanh,
            }

        for name, op in ops_under_test.items():
            qY_hat = op(qX, min_val, max_val)
            self.assertEqual(qY, qY_hat, message="{} hardtanh failed".format(name))
            for name, op in ops_under_test.items():
                qY_hat = op(qX, min_val, max_val)
                self.assertEqual(qY, qY_hat, message="{} hardtanh failed".format(name))

        ops_under_test_inplace = {
            'inplace nn.quantized.functional.hardtanh':
                torch.nn.quantized.functional.hardtanh,
        }
            ops_under_test_inplace = {
                'inplace nn.quantized.functional.hardtanh':
                    torch.nn.quantized.functional.hardtanh,
            }

        for name, op_ in ops_under_test_inplace.items():
            qY_hat = qX.clone()
            op_(qY_hat, min_val, max_val, inplace=True)
            self.assertEqual(qY, qY_hat, message="{} hardtanh failed".format(name))
            for name, op_ in ops_under_test_inplace.items():
                qY_hat = qX.clone()
                op_(qY_hat, min_val, max_val, inplace=True)
                self.assertEqual(qY, qY_hat, message="{} hardtanh failed".format(name))

    """Tests the correctness of the quantized::hardswish op."""
    @given(X=hu.tensor(shapes=hu.array_shapes(1, 8, 1, 8),
@@ -2817,6 +2818,32 @@ def test_mean(self, batch_size, channels, height, width, scale, zero_point):
    def test_hardswish(self, X, Y_scale, Y_zero_point):
        _test_hardswish(self, X, Y_scale, Y_zero_point, 'qnnpack')

    """Tests the correctness of the quantized::hardtanh op."""
    @given(X=hu.tensor(shapes=hu.array_shapes(1, 8, 1, 8),
                       elements=hu.floats(-1e6, 1e6, allow_nan=False, allow_infinity=False),
                       qparams=hu.qparams(dtypes=torch.quint8)),
           min_val=hu.floats(-1e6, -9.999999974752427e-07, allow_nan=False, allow_infinity=False),
           max_val=hu.floats(9.999999974752427e-07, 1e6, allow_nan=False, allow_infinity=False))
    def test_hardtanh(self, X, min_val, max_val):
        with override_quantized_engine('qnnpack'):
            X, (scale, zero_point, torch_type) = X

            assume(min_val <= max_val)
            Y = X.copy()
            Y[Y < min_val] = min_val
            Y[Y > max_val] = max_val
            qY = torch.quantize_per_tensor(torch.from_numpy(Y), scale=scale,
                                           zero_point=zero_point, dtype=torch_type)
            X = torch.from_numpy(X)
            qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
                                           dtype=torch_type)

            qY_hat = torch.nn.quantized.functional.hardtanh(qX, min_val, max_val)
            self.assertEqual(
                qY, qY_hat,
                message="hardtanh failed:\nactual {}\nexpected {}".format(qY_hat, qY))


"""Tests the correctness of the tensor comparators."""
class TestComparatorOps(TestCase):
"""Tests the element-wise equality ops."""
