96 changes: 89 additions & 7 deletions torchvision/csrc/ops/quantized/cpu/qroi_align_kernel.cpp
@@ -1,5 +1,4 @@
#include <ATen/ATen.h>
-#include <ATen/native/quantized/affine_quantizer.h>
#include <torch/library.h>

#include "../../cpu/roi_align_common.h"
@@ -9,6 +8,90 @@ namespace ops {

namespace {

// BEGIN copy-pasted code from pytorch core
// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/quantized/affine_quantizer_base.cpp
// We're vendoring the quantize_val() and dequantize_val() functions here. The
// reason is that these functions belong in at::native, which is incompatible
// with android xplat support.

// FIXME: Remove this section once we can use at::native for android xplat
// builds, or when quantize_val() and dequantize_val() aren't in at::native

#ifdef USE_FBGEMM
template <typename T>
T quantize_val(double scale, int64_t zero_point, float value) {
// Internally, fbgemm::Quantize uses std::nearbyint.
// std::nearbyint returns the nearest integer value according to the current
// rounding mode; the default rounding mode rounds half-way cases to even on
// the most popular processor architectures, such as x86 and ARM. This is
// typically faster than alternatives like std::round, which rounds half-way
// cases away from zero, and it is consistent with SIMD implementations, for
// example _mm512_cvtps_epi32 on x86, or _mm512_round_ps with the
// _MM_FROUND_CUR_DIRECTION option, which also follow the current rounding mode.
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int32_t qvalue;
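// Consistent with the comment above: fbgemm::Quantize computes
// zero_point + nearbyint(value / scale) under the current rounding mode,
// then saturates to the requested result precision.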
// NOLINTNEXTLINE(bugprone-signed-char-misuse)
qvalue = fbgemm::Quantize<typename T::underlying, false /*LEGACY*/>(
value,
static_cast<int32_t>(zero_point),
static_cast<float>(scale),
/*result_precision=*/CHAR_BIT * sizeof(typename T::underlying));
return static_cast<T>(qvalue);
}

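// Inverse mapping: fbgemm::Dequantize computes (value - zero_point) * scale.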
template <typename T>
inline float dequantize_val(double scale, int64_t zero_point, T value) {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
fbgemm::TensorQuantizationParams qparams;
qparams.scale = static_cast<float>(scale);
qparams.zero_point = static_cast<int32_t>(zero_point);
return fbgemm::Dequantize<typename T::underlying>(value.val_, qparams);
}
#else // USE_FBGEMM

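// Legacy Android NDK toolchains (where __NDK_MAJOR__ is not defined) lack
// std::nearbyint, so fall back to the global C library functions there.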
#if defined(__ANDROID__) && !defined(__NDK_MAJOR__)
template <class T>
inline float Round(const float x) {
return ::nearbyintf(x);
}
inline double Round(const double x) {
return ::nearbyint(x);
}
#else
template <class T>
inline T Round(const T x) {
return std::nearbyint(x);
}
#endif

template <typename T>
T quantize_val(double scale, int64_t zero_point, float value) {
// std::nearbyint returns the nearest integer value according to the current
// rounding mode; the default rounding mode rounds half-way cases to even on
// the most popular processor architectures, such as x86 and ARM. This is
// typically faster than alternatives like std::round, which rounds half-way
// cases away from zero, and it is consistent with SIMD implementations, for
// example _mm512_cvtps_epi32 on x86, or _mm512_round_ps with the
// _MM_FROUND_CUR_DIRECTION option, which also follow the current rounding mode.
int64_t qvalue;
constexpr int64_t qmin = std::numeric_limits<typename T::underlying>::min();
constexpr int64_t qmax = std::numeric_limits<typename T::underlying>::max();
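// e.g. [0, 255] for c10::quint8 and [-128, 127] for c10::qint8.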
float inv_scale = 1.0f / static_cast<float>(scale);
qvalue = static_cast<int64_t>(zero_point + Round(value * inv_scale));
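// Saturate to the representable range of the underlying integer type.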
qvalue = std::max<int64_t>(qvalue, qmin);
qvalue = std::min<int64_t>(qvalue, qmax);
return static_cast<T>(qvalue);
}

template <typename T>
float dequantize_val(double scale, int64_t zero_point, T value) {
// We need to convert the qint8 value to float to ensure the subtraction
// subexpression returns a float
return (static_cast<float>(value.val_) - zero_point) * scale;
}
#endif // USE_FBGEMM
// END copy-pasted code from pytorch core
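// Illustrative round-trip (example values, not from the original source):
// with scale=0.1 and zero_point=10, quantize_val<c10::quint8>(0.1, 10, 2.5f)
// produces the quantized value 35, and dequantize_val(0.1, 10, q) maps it
// back to (35 - 10) * 0.1 = 2.5f.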

template <typename T>
void qroi_align_forward_kernel_impl(
int n_rois,
@@ -46,19 +129,19 @@ void qroi_align_forward_kernel_impl(
// Do not use rounding; this implementation detail is critical
float offset = aligned ? 0.5 : 0.;
float roi_start_w =
-    at::native::dequantize_val(rois_scale, rois_zp, offset_rois[1]) *
+    dequantize_val(rois_scale, rois_zp, offset_rois[1]) *
spatial_scale -
offset;
float roi_start_h =
-    at::native::dequantize_val(rois_scale, rois_zp, offset_rois[2]) *
+    dequantize_val(rois_scale, rois_zp, offset_rois[2]) *
spatial_scale -
offset;
float roi_end_w =
-    at::native::dequantize_val(rois_scale, rois_zp, offset_rois[3]) *
+    dequantize_val(rois_scale, rois_zp, offset_rois[3]) *
spatial_scale -
offset;
float roi_end_h =
-    at::native::dequantize_val(rois_scale, rois_zp, offset_rois[4]) *
+    dequantize_val(rois_scale, rois_zp, offset_rois[4]) *
spatial_scale -
offset;

@@ -134,8 +217,7 @@ void qroi_align_forward_kernel_impl(

output_val /= count; // Average pooling

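// Requantize the averaged value with the input's scale and zero point:
// the output tensor shares the input's quantization parameters.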
-output[index] =
-    at::native::quantize_val<T>(input_scale, input_zp, output_val);
+output[index] = quantize_val<T>(input_scale, input_zp, output_val);
} // for pw
} // for ph
} // for c