[pt][quant] Optimized qadd_scalar (#34925)
Summary:
Pull Request resolved: #34925

Adds an optimized path for qadd_scalar: the scalar is added directly in the quantized domain instead of going through a dequantize / quantize_per_tensor round trip. For the profiled model below, quantized::add_scalar time drops from 55.840ms to 4.637ms.
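
Concretely, the new kernel subtracts the input zero point, adds the already-quantized scalar in int32, and requantizes with multiplier self_scale / out_scale. Below is a minimal scalar sketch of that per-element math, assuming a quint8 tensor and illustrative helper names (the actual, vectorized implementation is qadd_scalar_kernel in QuantizedOpKernels.cpp further down):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// Requantize an int32 accumulator into the output's quantized domain:
// q_out = clamp(round(acc * multiplier) + out_zero_point, 0, 255).
inline uint8_t requantize(float multiplier, int32_t out_zero_point, int32_t acc) {
  int32_t q = static_cast<int32_t>(std::nearbyint(acc * multiplier)) + out_zero_point;
  return static_cast<uint8_t>(std::min(255, std::max(0, q)));
}

// Per-element scalar add in quantized space: c_q = round(scalar / self_scale)
// is already an int32, and multiplier = self_scale / out_scale.
inline uint8_t qadd_scalar_elem(
    uint8_t self_q, int32_t self_zero_point, int32_t c_q,
    float multiplier, int32_t out_zero_point) {
  int32_t acc = (static_cast<int32_t>(self_q) - self_zero_point) + c_q;
  return requantize(multiplier, out_zero_point, acc);
}
```

Keeping the arithmetic in int32 lets the loop map directly onto the Vec256<c10::qint32> operator+ added in vec256_qint.h.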

### Before
```
  -------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
Name                       Self CPU total %  Self CPU total   CPU total %      CPU total        CPU time avg     Number of Calls
-------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
quantize_per_tensor        0.12%            155.807us        0.12%            155.807us        155.807us        1
quantized::conv2d          25.50%           31.981ms         25.50%           31.981ms         273.343us        117
quantized::add_scalar      44.53%           55.840ms         44.53%           55.840ms         809.281us        69
quantized::relu6           1.25%            1.570ms          1.25%            1.570ms          22.749us         69
quantized::mul_scalar      10.73%           13.449ms         10.73%           13.449ms         194.914us        69
quantized::mul             16.67%           20.904ms         16.67%           20.904ms         227.220us        92
adaptive_avg_pool2d        0.03%            41.713us         0.69%            862.922us        35.955us         24
_adaptive_avg_pool2d       0.65%            821.209us        0.65%            821.209us        34.217us         24
sigmoid                    0.15%            182.344us        0.15%            182.344us        7.928us          23
quantized::add             0.34%            431.939us        0.34%            431.939us        26.996us         16
dropout                    0.00%            1.936us          0.00%            1.936us          1.936us          1
view                       0.01%            10.281us         0.01%            10.281us         10.281us         1
dequantize                 0.00%            4.562us          0.00%            4.562us          4.562us          1
-------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
Self CPU time total: 125.394ms
```
### After
```
 -------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
Name                       Self CPU total %  Self CPU total   CPU total %      CPU total        CPU time avg     Number of Calls
-------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
quantize_per_tensor        0.18%            130.534us        0.18%            130.534us        130.534us        1
quantized::conv2d          42.29%           31.267ms         42.29%           31.267ms         267.243us        117
quantized::add_scalar      6.27%            4.637ms          6.27%            4.637ms          67.205us         69
quantized::relu6           1.77%            1.312ms          1.77%            1.312ms          19.008us         69
quantized::mul_scalar      18.92%           13.991ms         18.92%           13.991ms         202.768us        69
quantized::mul             28.49%           21.059ms         28.49%           21.059ms         228.904us        92
adaptive_avg_pool2d        0.06%            45.242us         1.27%            942.522us        39.272us         24
_adaptive_avg_pool2d       1.21%            897.280us        1.21%            897.280us        37.387us         24
sigmoid                    0.22%            160.282us        0.22%            160.282us        6.969us          23
quantized::add             0.56%            416.276us        0.56%            416.276us        26.017us         16
dropout                    0.00%            1.245us          0.00%            1.245us          1.245us          1
view                       0.01%            7.122us          0.01%            7.122us          7.122us          1
dequantize                 0.01%            5.952us          0.01%            5.952us          5.952us          1
-------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
Self CPU time total: 73.930ms
```
ghstack-source-id: 100595212

Test Plan: buck test //caffe2/test:quantized -- 'test_qadd'  --print-passing-details

Differential Revision: D20500848

fbshipit-source-id: c292d15da121e6d13cc4eb92f10549874ff6ab0f
dskhudia authored and facebook-github-bot committed Mar 23, 2020
1 parent 3e4076a commit 506996c
Showing 4 changed files with 103 additions and 11 deletions.
31 changes: 31 additions & 0 deletions aten/src/ATen/cpu/vec256/vec256_qint.h
@@ -362,6 +362,26 @@ Vec256<c10::qint32> inline operator*(
#endif
}

template <>
Vec256<c10::qint32> inline operator+(
const Vec256<c10::qint32>& a,
const Vec256<c10::qint32>& b) {
#ifdef __AVX2__
return _mm256_add_epi32(a, b);
#else
// Pray the compiler can autovectorize this
int32_t a_vals[a.size()];
int32_t b_vals[b.size()];
a.store(a_vals);
b.store(b_vals);
int32_t result_vals[a.size()];
for (size_t i = 0; i < a.size(); ++i) {
result_vals[i] = a_vals[i] + b_vals[i];
}
return Vec256<c10::qint32>::loadu(result_vals);
#endif
}

#ifdef __AVX2__
/*
* Convert values from int32 back to int8/uint8
@@ -1149,6 +1169,17 @@ Vec256<c10::qint32> inline operator*(
return retval;
}

template <>
Vec256<c10::qint32> inline operator+(
const Vec256<c10::qint32>& a,
const Vec256<c10::qint32>& b) {
Vec256<c10::qint32> retval;
for (size_t i = 0; i < a.size(); ++i) {
retval.vals[i] = a.vals[i] + b.vals[i];
}
return retval;
}

template <>
struct Vec256<c10::qint8> : public Vec256QuantizedConverter<
c10::qint8,
49 changes: 49 additions & 0 deletions aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
@@ -525,6 +525,53 @@ void qelu_kernel(const Tensor& qx, Scalar alpha, Tensor& qy) {
});
}

// Note: out is assumed to be the same size as self and other.
// Note: Addition is only supported when self and out are of the same dtype.
// Note: other is already assumed to be in int32, i.e., it's
// round(float/self_scale)
template <bool ReLUFused = false>
void qadd_scalar_kernel(Tensor& out, const Tensor& self, Scalar other) {
int64_t zero_point = out.q_zero_point();
float scale = out.q_scale();
float inv_scale = 1.0f / scale;
int64_t self_zero_point = self.q_zero_point();
float self_scale = self.q_scale();

float multiplier = self_scale * inv_scale;

AT_DISPATCH_QINT_TYPES(self.scalar_type(), "qadd_scalar", [&]() {
using Vec = Vec256<scalar_t>;
auto iter = TensorIterator::unary_op(out, self);
auto other_val = other.to<int32_t>();
auto other_vec = Vec256<c10::qint32>(static_cast<c10::qint32>(other_val));
cpu_kernel_vec(
iter,
[&](scalar_t a) -> scalar_t {
int32_t a_sub_z = static_cast<int32_t>(a.val_) -
static_cast<int32_t>(self_zero_point);
int32_t c = a_sub_z + other_val;
scalar_t res =
at::requantize_from_int<scalar_t>(multiplier, zero_point, c);
if (ReLUFused) {
res.val_ = std::max<scalar_t::underlying>(res.val_, zero_point);
}
return res;
},
[&](Vec a) -> Vec {
Vec::int_vec_return_type a_sub_z =
a.widening_subtract(Vec(static_cast<scalar_t>(self_zero_point)));
Vec::int_vec_return_type c;
for (int i = 0; i < Vec::int_num_vecs(); ++i) {
c[i] = a_sub_z[i] + other_vec;
}
Vec rv = Vec::requantize_from_int(c, multiplier, zero_point);
if (ReLUFused) {
rv = rv.maximum(Vec(static_cast<scalar_t>(zero_point)));
}
return rv;
});
});
}
// Note: out is assumed to be the same size as self and other.
// Note: Addition is only supported when self, other, out are of the same dtype.
template <bool ReLUFused = false>
@@ -1462,6 +1509,8 @@ REGISTER_DISPATCH(qtanh_stub, &qtanh_kernel);
REGISTER_DISPATCH(qelu_stub, &qelu_kernel);
REGISTER_DISPATCH(qadd_relu_stub, &qadd_kernel<true>);
REGISTER_DISPATCH(qadd_stub, &qadd_kernel<false>);
REGISTER_DISPATCH(qadd_scalar_relu_stub, &qadd_scalar_kernel<true>);
REGISTER_DISPATCH(qadd_scalar_stub, &qadd_scalar_kernel<false>);
REGISTER_DISPATCH(qmul_relu_stub, &qmul_kernel<true>);
REGISTER_DISPATCH(qmul_stub, &qmul_kernel<false>);
REGISTER_DISPATCH(qmaxpool_2d_nhwc_stub, &qmaxpool_2d_nhwc_kernel);
30 changes: 19 additions & 11 deletions aten/src/ATen/native/quantized/cpu/qadd.cpp
@@ -16,6 +16,8 @@ namespace native {

DEFINE_DISPATCH(qadd_relu_stub);
DEFINE_DISPATCH(qadd_stub);
DEFINE_DISPATCH(qadd_scalar_relu_stub);
DEFINE_DISPATCH(qadd_scalar_stub);

namespace {

@@ -46,8 +48,9 @@ Tensor _add_out(Tensor& out, const Tensor& self, const Tensor& other) {

template <bool ReLUFused = false>
Tensor _add_scalar_out(Tensor& out, const Tensor& self, Scalar other) {
TORCH_CHECK(self.qscheme() == kPerTensorAffine,
"Only per tensor affine is supported for now!!");
TORCH_CHECK(
self.qscheme() == kPerTensorAffine,
"Only per tensor affine is supported for now!!");
// To implement tensor-scalar addition in quantized space, we simply
// adjust the quantization parameters based on the following rules:
//
@@ -61,11 +64,11 @@ Tensor _add_scalar_out(Tensor& out, const Tensor& self, Scalar other) {
// If q_min > z - c_q
// s' = [(q_max - (z - c_q)]/[q_max - q_min] * s
// z' = q_min
// Xq' = torch.quantize_linear(Xq.dequantize() + c_q.dequantize() , s', z')
// Xq' = at::requantize_from_int(Xq - z + c_q, s/s', z')
// If q_max < z - c_q
// s' = [z - c_q -q_min]/[q_max - q_min] * s
// z' = q_max
// Xq' = torch.quantize_linear(Xq.dequantize() + c_q.dequantize(), s', z')
// Xq' = at::requantize_from_int(Xq - z + c_q, s/s', z')
// Else
// s' = s
// z' = z - c_q
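// For example, with illustrative quint8 numbers (q_min = 0, q_max = 255,
// s = 0.1, z = 10) and scalar c = 0.5, we get c_q = round(0.5 / 0.1) = 5.
// Then z - c_q = 5 lies in [q_min, q_max], so the last rule applies:
// s' = 0.1, z' = 5, and the stored integers are untouched; only the zero
// point shifts, since (q - 5) * 0.1 == (q - 10) * 0.1 + 0.5.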
@@ -85,24 +88,29 @@ Tensor _add_scalar_out(Tensor& out, const Tensor& self, Scalar other) {
if (q_min > z - c_q) {
s_prime = (((double)q_max - (z - c_q))) / ((double)q_max - q_min) * s;
z_prime = q_min;
auto dequantized_add = self.dequantize() + c_q * s;
out.set_quantizer_(make_per_tensor_affine_quantizer(
s_prime, z_prime, self.scalar_type()));
if (ReLUFused) {
dequantized_add.relu_();
qadd_scalar_relu_stub(self.device().type(), out, self, c_q);
} else {
qadd_scalar_stub(self.device().type(), out, self, c_q);
}
out.copy_(at::quantize_per_tensor(dequantized_add, s_prime, z_prime, self.scalar_type()));
} else if (q_max < z - c_q) {
s_prime = ((double)(z - c_q) - q_min) / ((double)q_max - q_min) * s;
z_prime = q_max;
auto dequantized_add = self.dequantize() + c_q * s;
out.set_quantizer_(make_per_tensor_affine_quantizer(
s_prime, z_prime, self.scalar_type()));
if (ReLUFused) {
dequantized_add.relu_();
qadd_scalar_relu_stub(self.device().type(), out, self, c_q);
} else {
qadd_scalar_stub(self.device().type(), out, self, c_q);
}
out.copy_(at::quantize_per_tensor(dequantized_add, s_prime, z_prime, self.scalar_type()));
} else {
s_prime = s;
z_prime = z - c_q;
out.copy_(self);
out.set_quantizer_(make_per_tensor_affine_quantizer(s_prime, z_prime, self.scalar_type()));
out.set_quantizer_(make_per_tensor_affine_quantizer(
s_prime, z_prime, self.scalar_type()));
if (ReLUFused) {
at::native::quantized_relu_(out);
}
4 changes: 4 additions & 0 deletions aten/src/ATen/native/quantized/cpu/quantized_ops.h
@@ -22,6 +22,8 @@ using qelu_fn = void(*)(
at::Tensor& /*qy*/);
using qbinary_fn =
void (*)(Tensor& /*out*/, const Tensor& /*self*/, const Tensor& /*other*/);
using qadd_scalar_fn =
void (*)(Tensor& /*out*/, const Tensor& /*self*/, Scalar other /*other*/);
using qmaxpool_2d_fn = void (*)(
const Tensor& qx,
int64_t iC, // input/output channels
@@ -127,6 +129,8 @@ DECLARE_DISPATCH(qbinary_fn, qadd_stub);
DECLARE_DISPATCH(qbinary_fn, qadd_relu_stub);
DECLARE_DISPATCH(qbinary_fn, qmul_stub);
DECLARE_DISPATCH(qbinary_fn, qmul_relu_stub);
DECLARE_DISPATCH(qadd_scalar_fn, qadd_scalar_stub);
DECLARE_DISPATCH(qadd_scalar_fn, qadd_scalar_relu_stub);
DECLARE_DISPATCH(qelu_fn, qelu_stub);
DECLARE_DISPATCH(qmaxpool_2d_fn, qmaxpool_2d_nhwc_stub);
DECLARE_DISPATCH(qadaptive_avg_pool2d_fn, qadaptive_avg_pool2d_nhwc_stub);
