From 71d3e1410392c6a7d84802968e35fb6506b5f853 Mon Sep 17 00:00:00 2001
From: Anthony Liu
Date: Tue, 13 Oct 2020 13:15:42 -0700
Subject: [PATCH] quantize_tensor_per_channel ARM implementation

Summary:
Currently on mobile devices quantize_tensor has a vectorized implementation
using ARM intrinsics, but quantize_tensor_per_channel does not. This patch adds
a vectorized ARM implementation for quantize_tensor_per_channel as well.

Test Plan:
Build and push to a mobile device:
```
BUILD_MOBILE_BENCHMARK=1 BUILD_MOBILE_TEST=1 ANDROID_DEBUG_SYMBOLS=1 BUILD_PYTORCH_MOBILE=1 ANDROID_ABI=arm64-v8a ./scripts/build_android.sh -DANDROID_CCACHE=$(which ccache) -DBUILD_BINARY=ON
adb push build_android/bin/quantize_per_channel /data/local/tmp
```
Then run the benchmark binary over adb shell.
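A minimal invocation looks like the following (benchmark flags are omitted here,
since they depend on how the benchmark binary is set up):
```
adb shell /data/local/tmp/quantize_per_channel
```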
Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 21882a2b455e5bf6e5d24c756719c03f8dd25bc5
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46018
---
 .../cpu/kernels/QuantizedOpKernels.cpp        | 246 +++++++++++++++++-
 1 file changed, 243 insertions(+), 3 deletions(-)

diff --git a/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
index 9dd98b71ad40..f19f46f0dbc1 100644
--- a/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
+++ b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
@@ -2568,6 +2568,214 @@ void dequantize_tensor_per_tensor_affine_cpu(
 #endif // USE_FBGEMM
 
 // TODO: add fbgemm for per channel
+#if defined(__ARM_NEON__) || defined(__aarch64__)
+// Generic template defaults to naive quantize implementation
+template <typename T>
+void quantize_tensor_per_channel_arm(
+    const float* in,
+    Tensor qtensor,
+    const float* scales,
+    const int32_t* zero_points,
+    const bool channels_last,
+    const int64_t batches,
+    const int64_t elements_per_channel,
+    const int64_t channels) {
+  auto out = qtensor.data_ptr<T>();
+  if (channels_last) {
+    for (auto b = 0; b < batches; ++b) {
+      for (auto e = 0; e < elements_per_channel; ++e) {
+        for (auto c = 0; c < channels; ++c) {
+          auto i = b * channels * elements_per_channel + e * channels + c;
+          out[i] =
+              at::native::quantize_val<T>(scales[c], zero_points[c], in[i]);
+        }
+      }
+    }
+  } else {
+    for (auto b = 0; b < batches; ++b) {
+      for (auto c = 0; c < channels; ++c) {
+        for (auto e = 0; e < elements_per_channel; ++e) {
+          auto i = b * channels * elements_per_channel +
+              c * elements_per_channel + e;
+          out[i] =
+              at::native::quantize_val<T>(scales[c], zero_points[c], in[i]);
+        }
+      }
+    }
+  }
+}
+
+// Specialized implementation from caffe2::Int8Quantize.
+// There may be a slight accuracy difference between this and the
+// implementation of quantize_val.
+// TODO Update quantize_tensor_per_channel_arm implementation to follow
+// quantize_val, i.e. f = Round(value/scale + zero_point)
+// TODO Make quantize_tensor_per_channel_arm work for other datatypes too
+// (int8, int32).
+template <>
+void quantize_tensor_per_channel_arm<c10::quint8>(
+    const float* in,
+    Tensor qtensor,
+    const float* scales,
+    const int32_t* zero_points,
+    const bool channels_last,
+    const int64_t batches,
+    const int64_t elements_per_channel,
+    const int64_t channels) {
+  auto out = (uint8_t*)qtensor.data_ptr<c10::quint8>();
+#if defined(__ARM_NEON__)
+  // Magic float and magic int to take care of rounding:
+  //   int magic_round(float f): interpret_int32(f + 12582912.0f) - 0x4B400000
+  // Some detail:
+  // 12582912.0f is 2**23 + 2**22. The trick is based on the fact that when you
+  // add a small number to a large number, the result rounds to the precision
+  // of the least significant bit of the large number. For an IEEE-754
+  // single-precision number the mantissa has 23 bits, so adding 2**23 rounds
+  // the sum to the nearest integer (ties to even). We then cast to int and
+  // subtract the same number (0x4B400000 is the integer representation of
+  // 12582912.0f) to get only the mantissa. This works for -2**22 < x < 2**22
+  // and preserves the sign for negative numbers.
+  const float32x4_t vmagic_float = vdupq_n_f32(12582912.0f);
+  const int32x4_t vmagic_int = vdupq_n_s32(0x4B400000);
+  if (channels_last) {
+    for (uint32_t b = 0; b < batches; ++b) {
+      for (uint32_t e = 0; e < elements_per_channel; ++e) {
+        uint32_t c = 0;
+        while (c + 8 < channels) {
+          const int32x4_t voffset0123 =
+              vsubq_s32(vld1q_s32(zero_points + c), vmagic_int);
+          const float32x4_t vinv_scale0123 = vrecpeq_f32(vld1q_f32(scales + c));
+          c += 4;
+          const int32x4_t voffset4567 =
+              vsubq_s32(vld1q_s32(zero_points + c), vmagic_int);
+          const float32x4_t vinv_scale4567 = vrecpeq_f32(vld1q_f32(scales + c));
+          c += 4;
+          const float32x4_t vin0123 = vld1q_f32(in);
+          in += 4;
+          const float32x4_t vin4567 = vld1q_f32(in);
+          in += 4;
+          const int32x4_t vraw0123 = vaddq_s32(
+              voffset0123,
+              vreinterpretq_s32_f32(
+                  vaddq_f32(vmagic_float, vmulq_f32(vin0123, vinv_scale0123))));
+          const int32x4_t vraw4567 = vaddq_s32(
+              voffset4567,
+              vreinterpretq_s32_f32(
+                  vaddq_f32(vmagic_float, vmulq_f32(vin4567, vinv_scale4567))));
+          const int16x8_t vraw01234567 =
+              vcombine_s16(vqmovn_s32(vraw0123), vqmovn_s32(vraw4567));
+          const uint8x8_t vout01234567 = vqmovun_s16(vraw01234567);
+          vst1_u8(out, vout01234567);
+          out += 8;
+        }
+        for (; c < channels; ++c) {
+          (*out++) =
+              at::native::quantize_val_arm(scales[c], zero_points[c], (*in++));
+        }
+      }
+    }
+  } else {
+    for (uint32_t b = 0; b < batches; ++b) {
+      for (uint32_t c = 0; c < channels; ++c) {
+        uint32_t e = 0;
+        const int32x4_t voffset = vdupq_n_s32(zero_points[c] - 0x4B400000);
+        const float32x4_t vinv_scale = vdupq_n_f32(1.0f / scales[c]);
+        for (; e + 8 < elements_per_channel; e += 8) {
+          const float32x4_t vin0123 = vld1q_f32(in);
+          in += 4;
+          const float32x4_t vin4567 = vld1q_f32(in);
+          in += 4;
+          const int32x4_t vraw0123 = vaddq_s32(
+              voffset,
+              vreinterpretq_s32_f32(
+                  vaddq_f32(vmagic_float, vmulq_f32(vin0123, vinv_scale))));
+          const int32x4_t vraw4567 = vaddq_s32(
+              voffset,
+              vreinterpretq_s32_f32(
+                  vaddq_f32(vmagic_float, vmulq_f32(vin4567, vinv_scale))));
+          const int16x8_t vraw01234567 =
+              vcombine_s16(vqmovn_s32(vraw0123), vqmovn_s32(vraw4567));
+          const uint8x8_t vout01234567 = vqmovun_s16(vraw01234567);
+          vst1_u8(out, vout01234567);
+          out += 8;
+        }
+        for (; e < elements_per_channel; ++e) {
+          (*out++) =
+              at::native::quantize_val_arm(scales[c], zero_points[c], (*in++));
+        }
+      }
+    }
+  }
+#else // defined(__ARM_NEON__)
+  // Copy zero_points (int32_t) into int16_t array
+  int16_t zero_points_int16t[channels];
+  for (int i = 0; i < channels; ++i) {
+    zero_points_int16t[i] = (int16_t)(uint16_t)zero_points[i];
+  }
+  if (channels_last) {
+    for (uint32_t b = 0; b < batches; ++b) {
+      for (uint32_t e = 0; e < elements_per_channel; ++e) {
+        uint32_t c = 0;
+        while (c + 8 < channels) {
+          const int16x8_t vzero_point = vld1q_s16(zero_points_int16t + c);
+          const float32x4_t vinv_scale0123 = vrecpeq_f32(vld1q_f32(scales + c));
+          c += 4;
+          const float32x4_t vinv_scale4567 = vrecpeq_f32(vld1q_f32(scales + c));
+          c += 4;
+          const float32x4_t vin0123 = vld1q_f32(in);
+          in += 4;
+          const float32x4_t vin4567 = vld1q_f32(in);
+          in += 4;
+          const int32x4_t v0123_rounded =
+              vcvtnq_s32_f32(vmulq_f32(vin0123, vinv_scale0123));
+          const int32x4_t v4567_rounded =
+              vcvtnq_s32_f32(vmulq_f32(vin4567, vinv_scale4567));
+          const int16x8_t v01234567_packed = vqaddq_s16(
+              vqmovn_high_s32(vqmovn_s32(v0123_rounded), v4567_rounded),
+              vzero_point);
+          const uint8x8_t vout01234567 = vqmovun_s16(v01234567_packed);
+          vst1_u8(out, vout01234567);
+          out += 8;
+        }
+        for (; c < channels; ++c) {
+          (*out++) =
+              at::native::quantize_val_arm(scales[c], zero_points[c], (*in++));
+        }
+      }
+    }
+  } else {
+    for (uint32_t b = 0; b < batches; ++b) {
+      for (uint32_t c = 0; c < channels; ++c) {
+        uint32_t e = 0;
+        const int16x8_t vzero_point = vdupq_n_s16(zero_points[c]);
+        const float32x4_t vinv_scale = vdupq_n_f32(1.0f / scales[c]);
+        for (; e + 8 < elements_per_channel; e += 8) {
+          const float32x4_t vin0123 = vld1q_f32(in);
+          in += 4;
+          const float32x4_t vin4567 = vld1q_f32(in);
+          in += 4;
+          const int32x4_t v0123_rounded =
+              vcvtnq_s32_f32(vmulq_f32(vin0123, vinv_scale));
+          const int32x4_t v4567_rounded =
+              vcvtnq_s32_f32(vmulq_f32(vin4567, vinv_scale));
+          const int16x8_t v01234567_packed = vqaddq_s16(
+              vqmovn_high_s32(vqmovn_s32(v0123_rounded), v4567_rounded),
+              vzero_point);
+          const uint8x8_t vout01234567 = vqmovun_s16(v01234567_packed);
+          vst1_u8(out, vout01234567);
+          out += 8;
+        }
+        for (; e < elements_per_channel; ++e) {
+          (*out++) =
+              at::native::quantize_val_arm(scales[c], zero_points[c], (*in++));
+        }
+      }
+    }
+  }
+#endif // defined(__ARM_NEON__)
+}
+#endif // defined(__ARM_NEON__) || defined(__aarch64__)
+
 void quantize_tensor_per_channel_affine_cpu(
     Tensor rtensor,
     Tensor qtensor,
@@ -2580,9 +2788,39 @@ void quantize_tensor_per_channel_affine_cpu(
   // Since current implemntation on channels_last format does not
   // cover per channel quant with arbitrary axis value, it is better
   // to check and fail.
-  TORCH_CHECK(rtensor.is_contiguous() || (axis <=1),
+  TORCH_CHECK(
+      rtensor.is_contiguous() || (axis <= 1),
       "If tensor is channels_last contig then per channel quantization "
       "is supported only for axis = 0 or 1.");
+#if defined(__ARM_NEON__) || defined(__aarch64__)
+  AT_DISPATCH_QINT_TYPES(
+      qtensor.scalar_type(), "quantize_tensor_per_channel_affine_cpu", [&]() {
+        const float* const rdata = rtensor.data_ptr<float>();
+        int64_t batches = size_to_dim_(axis, rtensor.sizes());
+        int64_t elements_per_channel =
+            size_from_dim_(axis + 1, rtensor.sizes());
+        int64_t channels = rtensor.size(axis);
+        auto zero_points_int = zero_points.to(at::kInt);
+        auto scales_float = scales.to(at::kFloat);
+        float* scales_data = scales_float.data_ptr<float>();
+        int32_t* zero_points_data = zero_points_int.data_ptr<int32_t>();
+        check_tensor_memory_format(rtensor, qtensor);
+        auto channels_last =
+            (axis == 1 &&
+             (rtensor.is_contiguous(MemoryFormat::ChannelsLast) ||
+              rtensor.is_contiguous(MemoryFormat::ChannelsLast3d)));
+        quantize_tensor_per_channel_arm<scalar_t>(
+            rdata,
+            qtensor,
+            scales_data,
+            zero_points_data,
+            channels_last,
+            batches,
+            elements_per_channel,
+            channels);
+      });
+#else
+  // Fallback path
   AT_DISPATCH_QINT_TYPES(
       qtensor.scalar_type(), "quantize_tensor_per_channel_affine_cpu", [&]() {
         int64_t batches = size_to_dim_(axis, rtensor.sizes());
@@ -2594,8 +2832,9 @@ void quantize_tensor_per_channel_affine_cpu(
         check_tensor_memory_format(rtensor, qtensor);
         const float* rdata = rtensor.data_ptr<float>();
         auto qdata = qtensor.data_ptr<scalar_t>();
-        if (axis == 1 && (rtensor.is_contiguous(MemoryFormat::ChannelsLast) ||
-            rtensor.is_contiguous(MemoryFormat::ChannelsLast3d))) {
+        if (axis == 1 &&
+            (rtensor.is_contiguous(MemoryFormat::ChannelsLast) ||
+             rtensor.is_contiguous(MemoryFormat::ChannelsLast3d))) {
           // This code handles per channel quant when axis = 1 and
           // channels_last contig.
           // If axis = 0 and channels_last contig, implementation
@@ -2622,6 +2861,7 @@ void quantize_tensor_per_channel_affine_cpu(
           }
         }
       });
+#endif // defined(__ARM_NEON__) || defined(__aarch64__)
 }
 
 template
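Note: the magic-number rounding used in the `__ARM_NEON__` path can be sanity-checked with a small standalone scalar sketch. `magic_round` and the test harness below are illustrative only and are not part of the patch; they mirror the trick described in the code comments.

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>

// Scalar equivalent of the NEON trick:
//   interpret_int32(f + 12582912.0f) - 0x4B400000
// Adding 2**23 + 2**22 forces the sum to integer precision (round to nearest,
// ties to even); subtracting the bit pattern of the magic constant recovers
// the rounded value. Valid for -2**22 < f < 2**22.
int32_t magic_round(float f) {
  float biased = f + 12582912.0f; // 2**23 + 2**22
  int32_t bits;
  std::memcpy(&bits, &biased, sizeof(bits)); // reinterpret the float's bits
  return bits - 0x4B400000;                  // bit pattern of 12582912.0f
}

int main() {
  const float samples[] = {-2.5f, -1.4f, -0.5f, 0.5f, 1.5f, 3.7f};
  for (float x : samples) {
    // std::nearbyint in the default rounding mode also rounds ties to even,
    // so both columns should agree for every sample.
    std::printf("%6.2f -> magic %d, nearbyint %d\n",
                x,
                (int)magic_round(x),
                (int)std::nearbyint(x));
  }
  return 0;
}
```

On top of this rounding, the kernel folds in the per-channel `zero_point` (via the pre-subtracted offset vector) and relies on the saturating narrows `vqmovn_s32` / `vqmovun_s16` to clamp the result into the `uint8` range.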