[pt][quant] Support either min or max in qclamp (#45937)
Summary:
Pull Request resolved: #45937

torch.clamp can now be used on quantized tensors when only the min argument or only the max argument is specified.
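
For example (a minimal sketch; the tensor values, scale, and zero_point below are illustrative, not taken from this PR):

```
import torch

# Quantize a float tensor per-tensor; the parameters here are just for illustration.
x = torch.randn(4, 4)
qx = torch.quantize_per_tensor(x, scale=0.1, zero_point=0, dtype=torch.quint8)

qy_min_only = torch.clamp(qx, min=-0.5)           # clamp from below only (new)
qy_max_only = torch.clamp(qx, max=0.5)            # clamp from above only (new)
qy_both     = torch.clamp(qx, min=-0.5, max=0.5)  # previously the only supported form
```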

Fixes #45928
ghstack-source-id: 114085914

Test Plan:
buck test mode/dev caffe2/test:quantization -- 'test_qclamp'  --print-passing-details
```
Started reporting to test run: https://our.intern.facebook.com/intern/testinfra/testrun/4222124686876909
    ✓ ListingSuccess: caffe2/test:quantization - main (7.602)
    ✓ Pass: caffe2/test:quantization - test_qclamp (quantization.test_quantized_op.TestQuantizedOps) (7.233)
Summary
  Pass: 1
  ListingSuccess: 1
Finished test run: https://our.intern.facebook.com/intern/testinfra/testrun/4222124686876909
```

Reviewed By: jerryzh168

Differential Revision: D24153431

fbshipit-source-id: 9735635a48bcdd88d1dd6dc2f18b59311d45ad90
dskhudia authored and facebook-github-bot committed Oct 12, 2020
1 parent bed3b40 commit 87a4baf
Showing 5 changed files with 99 additions and 12 deletions.
52 changes: 52 additions & 0 deletions aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
@@ -629,6 +629,56 @@ void qclamp_kernel(
});
}

void qclamp_min_kernel(const Tensor& qx, Scalar min_scalar, Tensor& qy) {
AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qclamp", [&]() {
qy = at::_empty_affine_quantized(
qx.sizes(),
at::device(kCPU)
.dtype(SCALAR_TYPE)
.memory_format(qx.suggest_memory_format()),
qx.q_scale(),
qx.q_zero_point(),
c10::nullopt);
using Vec = Vec256<scalar_t>;
auto iter = TensorIterator::unary_op(qy, qx);
auto min = min_scalar.to<float>();
scalar_t min_q = at::native::quantize_val<scalar_t>(
qx.q_scale(), qx.q_zero_point(), min);
auto min_vec = Vec(min_q);
cpu_kernel_vec(
iter,
[&](scalar_t value) -> scalar_t {
return scalar_t(std::max<underlying_t>(value.val_, min_q.val_));
},
[&](Vec val) -> Vec { return val.maximum(min_vec); });
});
}

void qclamp_max_kernel(const Tensor& qx, Scalar max_scalar, Tensor& qy) {
AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qclamp", [&]() {
qy = at::_empty_affine_quantized(
qx.sizes(),
at::device(kCPU)
.dtype(SCALAR_TYPE)
.memory_format(qx.suggest_memory_format()),
qx.q_scale(),
qx.q_zero_point(),
c10::nullopt);
using Vec = Vec256<scalar_t>;
auto iter = TensorIterator::unary_op(qy, qx);
auto max = max_scalar.to<float>();
scalar_t max_q = at::native::quantize_val<scalar_t>(
qx.q_scale(), qx.q_zero_point(), max);
auto max_vec = Vec(max_q);
cpu_kernel_vec(
iter,
[&](scalar_t value) -> scalar_t {
return scalar_t(std::min<underlying_t>(value.val_, max_q.val_));
},
[&](Vec val) -> Vec { return val.minimum(max_vec); });
});
}

void qthreshold_kernel(
// TODO: For future tasks, since output quantization parameters are set equal to
// the input ones, it might make sense to implement this completely in the
@@ -2811,6 +2861,8 @@ REGISTER_DISPATCH(qbatch_norm_stub, &q_batch_norm_kernel<false>);
REGISTER_DISPATCH(qcat_nhwc_stub, &qcat_nhwc_kernel<false>);
REGISTER_DISPATCH(qcat_relu_nhwc_stub, &qcat_nhwc_kernel<true>);
REGISTER_DISPATCH(qclamp_stub, &qclamp_kernel);
REGISTER_DISPATCH(qclamp_min_stub, &qclamp_min_kernel);
REGISTER_DISPATCH(qclamp_max_stub, &qclamp_max_kernel);
REGISTER_DISPATCH(qelu_stub, &qelu_kernel);
REGISTER_DISPATCH(qhardsigmoid_stub, &qhardsigmoid_kernel);
REGISTER_DISPATCH(qhardswish_stub, &qhardswish_kernel);
20 changes: 17 additions & 3 deletions aten/src/ATen/native/quantized/cpu/qclamp.cpp
@@ -15,6 +15,8 @@ namespace at {
namespace native {

DEFINE_DISPATCH(qclamp_stub);
DEFINE_DISPATCH(qclamp_min_stub);
DEFINE_DISPATCH(qclamp_max_stub);

namespace {

@@ -84,14 +86,26 @@ Tensor quantized_clamp_impl(
Tensor qy;
if (min && max) {
#ifdef USE_PYTORCH_QNNPACK
if (at::globalContext().qEngine() == at::QEngine::QNNPACK && qx.scalar_type() == kQUInt8) {
if (at::globalContext().qEngine() == at::QEngine::QNNPACK &&
qx.scalar_type() == kQUInt8) {
return qnnpack_clamp(qx, *min, *max);
}
#endif
qclamp_stub(qx.device().type(), qx, *min, *max, qy);
} else {
TORCH_CHECK(
false, "Both min and max should be specified for quantized clamp!");
#ifdef USE_PYTORCH_QNNPACK
if (at::globalContext().qEngine() == at::QEngine::QNNPACK) {
TORCH_CHECK(
false, "Both min and max should be specified for quantized clamp!");
}
#endif
if (max) {
qclamp_max_stub(qx.device().type(), qx, *max, qy);
} else if (min) {
qclamp_min_stub(qx.device().type(), qx, *min, qy);
} else {
TORCH_CHECK(false, "At least one of 'min' or 'max' must not be None");
}
}
return qy;
}
6 changes: 6 additions & 0 deletions aten/src/ATen/native/quantized/cpu/quantized_ops.h
@@ -15,6 +15,10 @@ using qclamp_fn = void (*)(
Scalar min,
Scalar max,
at::Tensor& /*qy*/);
using qclamp_minmax_fn = void (*)(
const at::Tensor& /*qx*/,
Scalar /*min or max*/,
at::Tensor& /*qy*/);
using qthreshold_fn = void (*)(
const at::Tensor& /*qx*/,
Scalar threshold,
@@ -167,6 +171,8 @@ DECLARE_DISPATCH(qbinary_fn, qmul_stub);
DECLARE_DISPATCH(qcat_nhwc_fn, qcat_nhwc_stub);
DECLARE_DISPATCH(qcat_nhwc_fn, qcat_relu_nhwc_stub);
DECLARE_DISPATCH(qclamp_fn, qclamp_stub);
DECLARE_DISPATCH(qclamp_minmax_fn, qclamp_min_stub);
DECLARE_DISPATCH(qclamp_minmax_fn, qclamp_max_stub);
DECLARE_DISPATCH(qelu_fn, qelu_stub);
DECLARE_DISPATCH(qhardsigmoid_fn, qhardsigmoid_stub);
DECLARE_DISPATCH(qhardswish_fn, qhardswish_stub);
2 changes: 1 addition & 1 deletion aten/src/ATen/native/quantized/library.cpp
@@ -54,7 +54,7 @@ TORCH_LIBRARY(quantized, m) {
m.def(TORCH_SELECTIVE_SCHEMA("quantized::batch_norm2d_relu(Tensor qx, Tensor? weight, Tensor? bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::batch_norm3d(Tensor qx, Tensor? weight, Tensor? bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::batch_norm3d_relu(Tensor qx, Tensor? weight, Tensor? bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::clamp(Tensor qx, Scalar? min, Scalar? max) -> Tensor qy"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::clamp(Tensor qx, Scalar? min=None, Scalar? max=None) -> Tensor qy"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::threshold(Tensor qx, Scalar threshold, Scalar value) -> Tensor qy"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::cat(Tensor[] qx, int dim, float? scale, int? zero_point) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::cat_relu(Tensor[] qx, int dim, float? scale, int? zero_point) -> Tensor"));
31 changes: 23 additions & 8 deletions test/quantization/test_quantized_op.py
@@ -575,22 +575,37 @@ def test_qclamp(self, X, min_val, max_val):
X, (scale, zero_point, torch_type) = X

assume(min_val <= max_val)
Y = X.copy()
Y[Y < min_val] = min_val
Y[Y > max_val] = max_val
qY = torch.quantize_per_tensor(torch.from_numpy(Y), scale=scale,
zero_point=zero_point, dtype=torch_type)
Y_clamp = torch.clamp(torch.from_numpy(X), min=min_val, max=max_val)
qY_clamp = torch.quantize_per_tensor(Y_clamp, scale=scale,
zero_point=zero_point, dtype=torch_type)

X = torch.from_numpy(X)
qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
dtype=torch_type)

ops_under_test = {
'ops.quantized': torch.ops.quantized.clamp,
}

for name, op in ops_under_test.items():
qY_hat = op(qX, min_val, max_val)
self.assertEqual(qY, qY_hat, msg="{} qclamp failed".format(name))
qY_clamp_hat = op(qX, min=min_val, max=max_val)
self.assertEqual(qY_clamp, qY_clamp_hat, msg="{} qclamp failed".format(name))

if torch.backends.quantized.engine == 'fbgemm':
with override_quantized_engine('fbgemm'):
Y_min_clamp = torch.clamp(X, min=min_val)
Y_max_clamp = torch.clamp(X, max=max_val)

qY_min_clamp = torch.quantize_per_tensor(Y_min_clamp, scale=scale,
zero_point=zero_point, dtype=torch_type)
qY_max_clamp = torch.quantize_per_tensor(Y_max_clamp, scale=scale,
zero_point=zero_point, dtype=torch_type)


for name, op in ops_under_test.items():
qY_min_clamp_hat = op(qX, min=min_val)
self.assertEqual(qY_min_clamp, qY_min_clamp_hat, msg="{} qclamp failed".format(name))
qY_max_clamp_hat = op(qX, max=max_val)
self.assertEqual(qY_max_clamp, qY_max_clamp_hat, msg="{} qclamp failed".format(name))

"""Tests the correctness of the quantized::hardtanh op."""
@skipIfNoFBGEMM