quantize_tensor_per_channel ARM implementation
Summary: On mobile devices, quantize_tensor currently has a vectorized implementation
using ARM intrinsics, but quantize_tensor_per_channel does not. This change adds an ARM-intrinsics implementation for quantize_tensor_per_channel as well.
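For context (not part of the original commit message): per-channel affine quantization computes q = zero_point[c] + round(x / scale[c]) for each value x belonging to channel c, clamped to the quantized type's range. Below is a minimal scalar sketch of that reference behavior for a contiguous layout; the helper name and the quint8 clamp range are illustrative assumptions, and the actual kernels use ATen's quantize_val / quantize_val_arm helpers.

```
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Illustrative reference (non-vectorized) per-channel quantization for a
// contiguous [batches, channels, elements_per_channel] layout, producing
// quint8-style values in [0, 255].
std::vector<uint8_t> quantize_per_channel_ref(
    const std::vector<float>& in,
    const std::vector<float>& scales,
    const std::vector<int32_t>& zero_points,
    int64_t batches,
    int64_t channels,
    int64_t elements_per_channel) {
  std::vector<uint8_t> out(in.size());
  for (int64_t b = 0; b < batches; ++b) {
    for (int64_t c = 0; c < channels; ++c) {
      for (int64_t e = 0; e < elements_per_channel; ++e) {
        const int64_t i = (b * channels + c) * elements_per_channel + e;
        // Round to nearest (ties to even under the default rounding mode),
        // then shift by the channel's zero point and clamp.
        const int32_t q = zero_points[c] +
            static_cast<int32_t>(std::nearbyint(in[i] / scales[c]));
        out[i] = static_cast<uint8_t>(
            std::min<int32_t>(255, std::max<int32_t>(0, q)));
      }
    }
  }
  return out;
}
```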

Test Plan: Build and push the benchmark binary to a mobile device:
```
BUILD_MOBILE_BENCHMARK=1 BUILD_MOBILE_TEST=1 ANDROID_DEBUG_SYMBOLS=1 BUILD_PYTORCH_MOBILE=1 ANDROID_ABI=arm64-v8a ./scripts/build_android.sh  -DANDROID_CCACHE=$(which ccache) -DBUILD_BINARY=ON
adb push build_android/bin/quantize_per_channel /data/local/tmp
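# Illustrative (assumed) follow-up step, not from the original test plan:
# run the pushed benchmark binary over adb shell; exact benchmark flags may differ.
adb shell /data/local/tmp/quantize_per_channel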
```
and then run the benchmark binary over adb shell

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 21882a2b455e5bf6e5d24c756719c03f8dd25bc5
Pull Request resolved: #46018
ajliu committed Oct 13, 2020
1 parent d071ec1 commit 71d3e14
Showing 1 changed file: aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp (243 additions, 3 deletions)
@@ -2568,6 +2568,214 @@ void dequantize_tensor_per_tensor_affine_cpu(
#endif // USE_FBGEMM

// TODO: add fbgemm for per channel
#if defined(__ARM_NEON__) || defined(__aarch64__)
// Generic template defaults to naive quantize implementation
template <typename T>
void quantize_tensor_per_channel_arm(
    const float* in,
    Tensor qtensor,
    const float* scales,
    const int32_t* zero_points,
    const bool channels_last,
    const int64_t batches,
    const int64_t elements_per_channel,
    const int64_t channels) {
  auto out = qtensor.data_ptr<T>();
  if (channels_last) {
    for (auto b = 0; b < batches; ++b) {
      for (auto e = 0; e < elements_per_channel; ++e) {
        for (auto c = 0; c < channels; ++c) {
          auto i = b * channels * elements_per_channel + e * channels + c;
          out[i] =
              at::native::quantize_val<T>(scales[c], zero_points[c], in[i]);
        }
      }
    }
  } else {
    for (auto b = 0; b < batches; ++b) {
      for (auto c = 0; c < channels; ++c) {
        for (auto e = 0; e < elements_per_channel; ++e) {
          auto i = b * channels * elements_per_channel +
              c * elements_per_channel + e;
          out[i] =
              at::native::quantize_val<T>(scales[c], zero_points[c], in[i]);
        }
      }
    }
  }
}

// Specialized implementation from caffe2::Int8Quantize.
// There may be a slight accuracy difference between this and the
// implementation of quantize_val
// TODO Update quantize_tensor_per_channel_arm implementation to follow
// quantize_val, i.e. f = Round(value/scale + zero_point)
// TODO Make quantize_tensor_per_channel_arm work for other datatypes too (int8,
// int32).
template <>
void quantize_tensor_per_channel_arm<c10::quint8>(
    const float* in,
    Tensor qtensor,
    const float* scales,
    const int32_t* zero_points,
    const bool channels_last,
    const int64_t batches,
    const int64_t elements_per_channel,
    const int64_t channels) {
  auto out = (uint8_t*)qtensor.data_ptr<c10::quint8>();
#if defined(__ARM_NEON__)
  // magic float and magic int to take care of rounding
  // int magic_round(float f): interpret_int32(f + 12582912.0f) - 0x4B400000
  // Some detail:
  // 12582912.0f is 2**23 + 2**22. The trick is based on the fact that when you
  // add a small number to a large number, the result rounds to the precision of
  // the least significant bit of the large number. For an IEEE-754
  // single-precision number the mantissa has 23 bits, and adding 2**23 causes
  // rounding to the nearest even integer. Then we cast to int and subtract the
  // same number (0x4B400000 is the integer representation of 12582912.0f) to
  // get only the mantissa. This works if -2**22 < x < 2**22, and preserves the
  // sign for negative numbers.
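  // Worked example (illustrative, not part of the original comment): with
  // value = 3.5, scale = 0.5, zero_point = 10, and ignoring the small error of
  // the vrecpeq_f32 reciprocal estimate, value * (1 / scale) = 7.0;
  // 7.0f + 12582912.0f = 12582919.0f, whose bit pattern is 0x4B400007;
  // subtracting 0x4B400000 leaves 7, and adding the (zero_point - magic)
  // offset below yields the quantized value 17.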
  const float32x4_t vmagic_float = vdupq_n_f32(12582912.0f);
  const int32x4_t vmagic_int = vdupq_n_s32(0x4B400000);
  if (channels_last) {
    for (uint32_t b = 0; b < batches; ++b) {
      for (uint32_t e = 0; e < elements_per_channel; ++e) {
        uint32_t c = 0;
        while (c + 8 < channels) {
          const int32x4_t voffset0123 =
              vsubq_s32(vld1q_s32(zero_points + c), vmagic_int);
          const float32x4_t vinv_scale0123 = vrecpeq_f32(vld1q_f32(scales + c));
          c += 4;
          const int32x4_t voffset4567 =
              vsubq_s32(vld1q_s32(zero_points + c), vmagic_int);
          const float32x4_t vinv_scale4567 = vrecpeq_f32(vld1q_f32(scales + c));
          c += 4;
          const float32x4_t vin0123 = vld1q_f32(in);
          in += 4;
          const float32x4_t vin4567 = vld1q_f32(in);
          in += 4;
          const int32x4_t vraw0123 = vaddq_s32(
              voffset0123,
              vreinterpretq_s32_f32(
                  vaddq_f32(vmagic_float, vmulq_f32(vin0123, vinv_scale0123))));
          const int32x4_t vraw4567 = vaddq_s32(
              voffset4567,
              vreinterpretq_s32_f32(
                  vaddq_f32(vmagic_float, vmulq_f32(vin4567, vinv_scale4567))));
          const int16x8_t vraw01234567 =
              vcombine_s16(vqmovn_s32(vraw0123), vqmovn_s32(vraw4567));
          const uint8x8_t vout01234567 = vqmovun_s16(vraw01234567);
          vst1_u8(out, vout01234567);
          out += 8;
        }
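        // Remaining channels (fewer than 8) fall back to the scalar helper.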
        for (; c < channels; ++c) {
          (*out++) =
              at::native::quantize_val_arm(scales[c], zero_points[c], (*in++));
        }
      }
    }
  } else {
    for (uint32_t b = 0; b < batches; ++b) {
      for (uint32_t c = 0; c < channels; ++c) {
        uint32_t e = 0;
        const int32x4_t voffset = vdupq_n_s32(zero_points[c] - 0x4B400000);
        const float32x4_t vinv_scale = vdupq_n_f32(1.0f / scales[c]);
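        // Note: unlike the channels_last path above, the reciprocal here is
        // the exact 1.0f / scale rather than a vrecpeq_f32 estimate.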
        for (; e + 8 < elements_per_channel; e += 8) {
          const float32x4_t vin0123 = vld1q_f32(in);
          in += 4;
          const float32x4_t vin4567 = vld1q_f32(in);
          in += 4;
          const int32x4_t vraw0123 = vaddq_s32(
              voffset,
              vreinterpretq_s32_f32(
                  vaddq_f32(vmagic_float, vmulq_f32(vin0123, vinv_scale))));
          const int32x4_t vraw4567 = vaddq_s32(
              voffset,
              vreinterpretq_s32_f32(
                  vaddq_f32(vmagic_float, vmulq_f32(vin4567, vinv_scale))));
          const int16x8_t vraw01234567 =
              vcombine_s16(vqmovn_s32(vraw0123), vqmovn_s32(vraw4567));
          const uint8x8_t vout01234567 = vqmovun_s16(vraw01234567);
          vst1_u8(out, vout01234567);
          out += 8;
        }
        for (; e < elements_per_channel; ++e) {
          (*out++) =
              at::native::quantize_val_arm(scales[c], zero_points[c], (*in++));
        }
      }
    }
  }
#else // defined(__ARM_NEON__)
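  // AArch64-only path (__ARM_NEON__ not defined here): uses ARMv8 intrinsics
  // such as vcvtnq_s32_f32 (round to nearest, ties to even) and
  // vqmovn_high_s32 instead of the magic-number rounding trick above.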
  // Copy zero_points (int32_t) into int16_t array
  int16_t zero_points_int16t[channels];
  for (int i = 0; i < channels; ++i) {
    zero_points_int16t[i] = (int16_t)(uint16_t)zero_points[i];
  }
  if (channels_last) {
    for (uint32_t b = 0; b < batches; ++b) {
      for (uint32_t e = 0; e < elements_per_channel; ++e) {
        uint32_t c = 0;
        while (c + 8 < channels) {
          const int16x8_t vzero_point = vld1q_s16(zero_points_int16t + c);
          const float32x4_t vinv_scale0123 = vrecpeq_f32(vld1q_f32(scales + c));
          c += 4;
          const float32x4_t vinv_scale4567 = vrecpeq_f32(vld1q_f32(scales + c));
          c += 4;
          const float32x4_t vin0123 = vld1q_f32(in);
          in += 4;
          const float32x4_t vin4567 = vld1q_f32(in);
          in += 4;
          const int32x4_t v0123_rounded =
              vcvtnq_s32_f32(vmulq_f32(vin0123, vinv_scale0123));
          const int32x4_t v4567_rounded =
              vcvtnq_s32_f32(vmulq_f32(vin4567, vinv_scale4567));
          const int16x8_t v01234567_packed = vqaddq_s16(
              vqmovn_high_s32(vqmovn_s32(v0123_rounded), v4567_rounded),
              vzero_point);
          const uint8x8_t vout01234567 = vqmovun_s16(v01234567_packed);
          vst1_u8(out, vout01234567);
          out += 8;
        }
        for (; c < channels; ++c) {
          (*out++) =
              at::native::quantize_val_arm(scales[c], zero_points[c], (*in++));
        }
      }
    }
  } else {
    for (uint32_t b = 0; b < batches; ++b) {
      for (uint32_t c = 0; c < channels; ++c) {
        uint32_t e = 0;
        const int16x8_t vzero_point = vdupq_n_s16(zero_points[c]);
        const float32x4_t vinv_scale = vdupq_n_f32(1.0f / scales[c]);
        for (; e + 8 < elements_per_channel; e += 8) {
          const float32x4_t vin0123 = vld1q_f32(in);
          in += 4;
          const float32x4_t vin4567 = vld1q_f32(in);
          in += 4;
          const int32x4_t v0123_rounded =
              vcvtnq_s32_f32(vmulq_f32(vin0123, vinv_scale));
          const int32x4_t v4567_rounded =
              vcvtnq_s32_f32(vmulq_f32(vin4567, vinv_scale));
          const int16x8_t v01234567_packed = vqaddq_s16(
              vqmovn_high_s32(vqmovn_s32(v0123_rounded), v4567_rounded),
              vzero_point);
          const uint8x8_t vout01234567 = vqmovun_s16(v01234567_packed);
          vst1_u8(out, vout01234567);
          out += 8;
        }
        for (; e < elements_per_channel; ++e) {
          (*out++) =
              at::native::quantize_val_arm(scales[c], zero_points[c], (*in++));
        }
      }
    }
  }
#endif // defined(__ARM_NEON__)
}
#endif // defined(__ARM_NEON__) || defined(__aarch64__)

void quantize_tensor_per_channel_affine_cpu(
    Tensor rtensor,
    Tensor qtensor,
@@ -2580,9 +2788,39 @@ void quantize_tensor_per_channel_affine_cpu(
  // Since the current implementation on channels_last format does not
  // cover per channel quant with arbitrary axis value, it is better
  // to check and fail.
  TORCH_CHECK(
      rtensor.is_contiguous() || (axis <= 1),
      "If tensor is channels_last contig then per channel quantization "
      "is supported only for axis = 0 or 1.");
#if defined(__ARM_NEON__) || defined(__aarch64__)
  AT_DISPATCH_QINT_TYPES(
      qtensor.scalar_type(), "quantize_tensor_per_channel_affine_cpu", [&]() {
        const float* const rdata = rtensor.data_ptr<float>();
        int64_t batches = size_to_dim_(axis, rtensor.sizes());
        int64_t elements_per_channel =
            size_from_dim_(axis + 1, rtensor.sizes());
        int64_t channels = rtensor.size(axis);
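        // Example: for an NCHW tensor quantized along axis == 1, batches == N,
        // channels == C, and elements_per_channel == H * W.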
        auto zero_points_int = zero_points.to(at::kInt);
        auto scales_float = scales.to(at::kFloat);
        float* scales_data = scales_float.data_ptr<float>();
        int32_t* zero_points_data = zero_points_int.data_ptr<int32_t>();
        check_tensor_memory_format(rtensor, qtensor);
        auto channels_last =
            (axis == 1 &&
             (rtensor.is_contiguous(MemoryFormat::ChannelsLast) ||
              rtensor.is_contiguous(MemoryFormat::ChannelsLast3d)));
        quantize_tensor_per_channel_arm<scalar_t>(
            rdata,
            qtensor,
            scales_data,
            zero_points_data,
            channels_last,
            batches,
            elements_per_channel,
            channels);
      });
#else
  // Fallback path
  AT_DISPATCH_QINT_TYPES(
      qtensor.scalar_type(), "quantize_tensor_per_channel_affine_cpu", [&]() {
        int64_t batches = size_to_dim_(axis, rtensor.sizes());
@@ -2594,8 +2832,9 @@ void quantize_tensor_per_channel_affine_cpu(
        check_tensor_memory_format(rtensor, qtensor);
        const float* rdata = rtensor.data_ptr<float>();
        auto qdata = qtensor.data_ptr<scalar_t>();
        if (axis == 1 &&
            (rtensor.is_contiguous(MemoryFormat::ChannelsLast) ||
             rtensor.is_contiguous(MemoryFormat::ChannelsLast3d))) {
          // This code handles per channel quant when axis = 1 and
          // channels_last contig.
          // If axis = 0 and channels_last contig, implementation
@@ -2622,6 +2861,7 @@ void quantize_tensor_per_channel_affine_cpu(
          }
        }
      });
#endif // defined(__ARM_NEON__) || defined(__aarch64__)
}

template<typename T, typename N, typename Q>
