quantize_tensor_per_channel ARM implementation #46018

259 changes: 254 additions & 5 deletions aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
@@ -2568,6 +2568,225 @@ void dequantize_tensor_per_tensor_affine_cpu(
#endif // USE_FBGEMM

// TODO: add fbgemm for per channel
#if defined(__ARM_NEON__) || defined(__aarch64__)
// Generic template defaults to naive quantize implementation
template <typename T>
void quantize_tensor_per_channel_arm(
const float* in,
Tensor qtensor,
const double* scales,
const int64_t* zero_points,
const bool channels_last,
const int64_t batches,
const int64_t elements_per_channel,
const int64_t channels) {
auto out = qtensor.data_ptr<T>();
if (channels_last) {
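// Channels-last layout (e.g. NHWC): the channel index varies fastest.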
for (auto b = 0; b < batches; ++b) {
for (auto e = 0; e < elements_per_channel; ++e) {
for (auto c = 0; c < channels; ++c) {
auto i = b * channels * elements_per_channel + e * channels + c;
out[i] =
at::native::quantize_val<T>(scales[c], zero_points[c], in[i]);
}
}
}
} else {
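// Contiguous layout (e.g. NCHW): elements within a channel are adjacent.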
for (auto b = 0; b < batches; ++b) {
for (auto c = 0; c < channels; ++c) {
for (auto e = 0; e < elements_per_channel; ++e) {
auto i = b * channels * elements_per_channel +
c * elements_per_channel + e;
out[i] =
at::native::quantize_val<T>(scales[c], zero_points[c], in[i]);
}
}
}
}
}

// Specialized implementation from caffe2::Int8Quantize.
// There may be a slight accuracy difference between this and the
// implementation of quantize_val.
// TODO Update quantize_tensor_per_channel_arm implementation to follow
// quantize_val, i.e. f = Round(value/scale + zero_point)
// TODO Make quantize_tensor_per_channel_arm work for other datatypes too (int8,
// int32).
template <>
void quantize_tensor_per_channel_arm<c10::quint8>(
const float* in,
Tensor qtensor,
const double* scales,
const int64_t* zero_points,
const bool channels_last,
const int64_t batches,
const int64_t elements_per_channel,
const int64_t channels) {
auto out = (uint8_t*)qtensor.data_ptr<c10::quint8>();
#if defined(__ARM_NEON__)
// magic float and magic int to take care of rounding
// int magic_round(float f): interpret_int32(f + 12582912.0f) - 0x4B400000
// Some detail:
// 12582912.0f is 2**23 + 2**22. The trick is based on the fact that when you
// add a small number to a large number, the result rounds to the precision of
// the least significant bit of the large number. An IEEE-754 single-precision
// mantissa has 23 bits, so adding 12582912.0f pushes the sum into a range
// where the spacing between representable floats is 1, and the sum is rounded
// to the nearest integer (ties to even). We then reinterpret the bits as an
// int and subtract the same number (0x4B400000 is the integer representation
// of 12582912.0f) to recover the rounded value. This works if
// -2**22 < x < 2**22, and it preserves the sign of negative numbers. (A scalar
// sketch of this trick appears after this function.)
const float32x4_t vmagic_float = vdupq_n_f32(12582912.0f);
const int32x4_t vmagic_int = vdupq_n_s32(0x4B400000);
// Copy zero_points (int64_t) into int32_t array
// Copy reciprocal of scales (double) into float array
int32_t zero_points_int32t[channels];
float inv_scales[channels];
for (int i = 0; i < channels; ++i) {
zero_points_int32t[i] = (int32_t)(uint32_t)zero_points[i];
inv_scales[i] = 1.0f / (float)scales[i];
}
if (channels_last) {
for (uint32_t b = 0; b < batches; ++b) {
for (uint32_t e = 0; e < elements_per_channel; ++e) {
uint32_t c = 0;
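// Vectorized main loop: quantize 8 channels at a time with magic-number rounding.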
while (c + 8 < channels) {
const int32x4_t voffset0123 =
vsubq_s32(vld1q_s32(zero_points_int32t + c), vmagic_int);
const float32x4_t vinv_scale0123 = vld1q_f32(inv_scales + c);
c += 4;
const int32x4_t voffset4567 =
vsubq_s32(vld1q_s32(zero_points_int32t + c), vmagic_int);
const float32x4_t vinv_scale4567 = vld1q_f32(inv_scales + c);
c += 4;
const float32x4_t vin0123 = vld1q_f32(in);
in += 4;
const float32x4_t vin4567 = vld1q_f32(in);
in += 4;
const int32x4_t vraw0123 = vaddq_s32(
voffset0123,
vreinterpretq_s32_f32(
vaddq_f32(vmagic_float, vmulq_f32(vin0123, vinv_scale0123))));
const int32x4_t vraw4567 = vaddq_s32(
voffset4567,
vreinterpretq_s32_f32(
vaddq_f32(vmagic_float, vmulq_f32(vin4567, vinv_scale4567))));
const int16x8_t vraw01234567 =
vcombine_s16(vqmovn_s32(vraw0123), vqmovn_s32(vraw4567));
const uint8x8_t vout01234567 = vqmovun_s16(vraw01234567);
vst1_u8(out, vout01234567);
out += 8;
}
for (; c < channels; ++c) {
(*out++) = at::native::quantize_val_arm(
scales[c], zero_points[c], (*in++));
}
}
}
} else {
for (uint32_t b = 0; b < batches; ++b) {
for (uint32_t c = 0; c < channels; ++c) {
uint32_t e = 0;
const int32x4_t voffset = vdupq_n_s32(zero_points_int32t[c] - 0x4B400000);
const float32x4_t vinv_scale = vdupq_n_f32(inv_scales[c]);
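// Vectorized main loop: quantize 8 elements of the current channel per iteration.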
for (; e + 8 < elements_per_channel; e += 8) {
const float32x4_t vin0123 = vld1q_f32(in);
in += 4;
const float32x4_t vin4567 = vld1q_f32(in);
in += 4;
const int32x4_t vraw0123 = vaddq_s32(
voffset,
vreinterpretq_s32_f32(
vaddq_f32(vmagic_float, vmulq_f32(vin0123, vinv_scale))));
const int32x4_t vraw4567 = vaddq_s32(
voffset,
vreinterpretq_s32_f32(
vaddq_f32(vmagic_float, vmulq_f32(vin4567, vinv_scale))));
const int16x8_t vraw01234567 =
vcombine_s16(vqmovn_s32(vraw0123), vqmovn_s32(vraw4567));
const uint8x8_t vout01234567 = vqmovun_s16(vraw01234567);
vst1_u8(out, vout01234567);
out += 8;
}
for (; e < elements_per_channel; ++e) {
(*out++) = at::native::quantize_val_arm(
scales[c], zero_points[c], (*in++));
}
}
}
}
#else // defined(__ARM_NEON__)
// Copy zero_points (int64_t) into int16_t array
// Copy scales (double) into float array
int16_t zero_points_int16t[channels];
float inv_scales[channels];
for (int i = 0; i < channels; ++i) {
zero_points_int16t[i] = (int16_t)(uint16_t)zero_points[i];
inv_scales[i] = 1.0f / (float)scales[i];
}
if (channels_last) {
for (uint32_t b = 0; b < batches; ++b) {
for (uint32_t e = 0; e < elements_per_channel; ++e) {
uint32_t c = 0;
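// Vectorized main loop: quantize 8 channels at a time, rounding with
// vcvtnq_s32_f32 (round to nearest, ties to even).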
while (c + 8 < channels) {
const int16x8_t vzero_point = vld1q_s16(zero_points_int16t + c);
const float32x4_t vinv_scale0123 = vld1q_f32(inv_scales + c);
c += 4;
const float32x4_t vinv_scale4567 = vld1q_f32(inv_scales + c);
c += 4;
const float32x4_t vin0123 = vld1q_f32(in);
in += 4;
const float32x4_t vin4567 = vld1q_f32(in);
in += 4;
const int32x4_t v0123_rounded =
vcvtnq_s32_f32(vmulq_f32(vin0123, vinv_scale0123));
const int32x4_t v4567_rounded =
vcvtnq_s32_f32(vmulq_f32(vin4567, vinv_scale4567));
const int16x8_t v01234567_packed = vqaddq_s16(
vqmovn_high_s32(vqmovn_s32(v0123_rounded), v4567_rounded),
vzero_point);
const uint8x8_t vout01234567 = vqmovun_s16(v01234567_packed);
vst1_u8(out, vout01234567);
out += 8;
}
for (; c < channels; ++c) {
(*out++) =
at::native::quantize_val_arm(scales[c], zero_points[c], (*in++));
}
}
}
} else {
for (uint32_t b = 0; b < batches; ++b) {
for (uint32_t c = 0; c < channels; ++c) {
uint32_t e = 0;
const int16x8_t vzero_point = vdupq_n_s16(zero_points_int16t[c]);
const float32x4_t vinv_scale = vdupq_n_f32(inv_scales[c]);
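// Vectorized main loop: quantize 8 elements of the current channel per iteration.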
for (; e + 8 < elements_per_channel; e += 8) {
const float32x4_t vin0123 = vld1q_f32(in);
in += 4;
const float32x4_t vin4567 = vld1q_f32(in);
in += 4;
const int32x4_t v0123_rounded =
vcvtnq_s32_f32(vmulq_f32(vin0123, vinv_scale));
const int32x4_t v4567_rounded =
vcvtnq_s32_f32(vmulq_f32(vin4567, vinv_scale));
const int16x8_t v01234567_packed = vqaddq_s16(
vqmovn_high_s32(vqmovn_s32(v0123_rounded), v4567_rounded),
vzero_point);
const uint8x8_t vout01234567 = vqmovun_s16(v01234567_packed);
vst1_u8(out, vout01234567);
out += 8;
}
for (; e < elements_per_channel; ++e) {
(*out++) =
at::native::quantize_val_arm(scales[c], zero_points[c], (*in++));
}
}
}
}
#endif // defined(__ARM_NEON__)
}
#endif // defined(__ARM_NEON__) || defined(__aarch64__)
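
// Illustration only: a minimal scalar sketch (hypothetical helper, not used by
// the kernels above) of the magic-number rounding trick from the NEON path,
// assuming IEEE-754 single-precision floats, <cstring> for std::memcpy, and
// inputs in (-2**22, 2**22). The TODO above notes that quantize_val instead
// follows f = Round(value/scale + zero_point).
inline int32_t magic_round_sketch(float x) {
  // Adding 2**23 + 2**22 forces the sum to round at integer precision; the
  // float's bit pattern then equals 0x4B400000 plus the rounded integer, so
  // subtracting the bit pattern of 12582912.0f recovers round(x), sign included.
  float shifted = x + 12582912.0f;
  int32_t bits;
  std::memcpy(&bits, &shifted, sizeof(bits));
  return bits - 0x4B400000;
}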

void quantize_tensor_per_channel_affine_cpu(
Tensor rtensor,
Tensor qtensor,
@@ -2580,9 +2799,37 @@ void quantize_tensor_per_channel_affine_cpu(
// Since the current implementation for the channels_last format does not
// cover per-channel quantization with an arbitrary axis value, it is better
// to check and fail.
TORCH_CHECK(
rtensor.is_contiguous() || (axis <= 1),
"If tensor is channels_last contig then per channel quantization "
"is supported only for axis = 0 or 1.");
#if defined(__ARM_NEON__) || defined(__aarch64__)
AT_DISPATCH_QINT_TYPES(
qtensor.scalar_type(), "quantize_tensor_per_channel_affine_cpu", [&]() {
const float* const rdata = rtensor.data_ptr<float>();
int64_t batches = size_to_dim_(axis, rtensor.sizes());
int64_t elements_per_channel =
size_from_dim_(axis + 1, rtensor.sizes());
int64_t channels = rtensor.size(axis);
double* scales_data = scales.data_ptr<double>();
int64_t* zero_points_data = zero_points.data_ptr<int64_t>();
auto channels_last =
(axis == 1 &&
(rtensor.is_contiguous(MemoryFormat::ChannelsLast) ||
rtensor.is_contiguous(MemoryFormat::ChannelsLast3d)));
check_tensor_memory_format(rtensor, qtensor);
quantize_tensor_per_channel_arm<scalar_t>(
rdata,
qtensor,
scales_data,
zero_points_data,
channels_last,
batches,
elements_per_channel,
channels);
});
#else
// Fallback path for builds without ARM NEON / aarch64 support.
AT_DISPATCH_QINT_TYPES(
qtensor.scalar_type(), "quantize_tensor_per_channel_affine_cpu", [&]() {
int64_t batches = size_to_dim_(axis, rtensor.sizes());
@@ -2594,12 +2841,13 @@ void quantize_tensor_per_channel_affine_cpu(
check_tensor_memory_format(rtensor, qtensor);
const float* rdata = rtensor.data_ptr<float>();
auto qdata = qtensor.data_ptr<scalar_t>();
if (axis == 1 &&
(rtensor.is_contiguous(MemoryFormat::ChannelsLast) ||
rtensor.is_contiguous(MemoryFormat::ChannelsLast3d))) {
// This code handles per channel quant when axis = 1 and
// channels_last contig.
// If axis = 0 and channels_last contig, implementation for channels
// first (NCHW) works.
for (auto b = 0; b < batches; ++b) {
for (auto e = 0; e < elements_per_channel; ++e) {
for (auto c = 0; c < channel; ++c) {
@@ -2622,6 +2870,7 @@ void quantize_tensor_per_channel_affine_cpu(
}
}
});
#endif // defined(__ARM_NEON__) || defined(__aarch64__)
}
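
// Illustration only (hypothetical 4-D example, not part of this change): how
// the dispatch above splits the sizes of a tensor quantized along axis = 1.
//   sizes = {N, C, H, W}, axis = 1
//   batches              = size_to_dim_(1, sizes)    -> N
//   channels             = sizes[1]                  -> C
//   elements_per_channel = size_from_dim_(2, sizes)  -> H * W
// A contiguous (NCHW) tensor is then indexed as
//   i = b * C * (H * W) + c * (H * W) + e,
// while a channels_last (NHWC) tensor is indexed as
//   i = b * (H * W) * C + e * C + c.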

template<typename T, typename N, typename Q>