diff --git a/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
index 9dd98b71ad40..8137049a75c8 100644
--- a/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
+++ b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
@@ -2568,7 +2568,9 @@ void dequantize_tensor_per_tensor_affine_cpu(
 #endif // USE_FBGEMM

 // TODO: add fbgemm for per channel
-void quantize_tensor_per_channel_affine_cpu(
+// Generic template defaults to naive quantize implementation
+template <typename T>
+void quantize_tensor_per_channel_impl(
     Tensor rtensor,
     Tensor qtensor,
     Tensor scales,
@@ -2580,47 +2582,253 @@ void quantize_tensor_per_channel_affine_cpu(
   // Since current implementation on channels_last format does not
   // cover per channel quant with arbitrary axis value, it is better
   // to check and fail.
-  TORCH_CHECK(rtensor.is_contiguous() || (axis <=1),
+  int64_t batches = size_to_dim_(axis, rtensor.sizes());
+  int64_t elements_per_channel = size_from_dim_(axis + 1, rtensor.sizes());
+  int64_t channels = rtensor.size(axis);
+  auto scales_data = scales.data_ptr<double>();
+  auto zero_points_data = zero_points.data_ptr<int64_t>();
+  const float* in = rtensor.data_ptr<float>();
+  auto out = qtensor.data_ptr<T>();
+  if (axis == 1 &&
+      (rtensor.is_contiguous(MemoryFormat::ChannelsLast) ||
+       rtensor.is_contiguous(MemoryFormat::ChannelsLast3d))) {
+    // This code handles per channel quant when axis = 1 and
+    // channels_last contig.
+    // If axis = 0 and channels_last contig, implementation for channels
+    // first (NCHW) works.
+    for (auto b = 0; b < batches; ++b) {
+      for (auto e = 0; e < elements_per_channel; ++e) {
+        for (auto c = 0; c < channels; ++c) {
+          auto i = b * channels * elements_per_channel + e * channels + c;
+          out[i] = at::native::quantize_val<T>(
+              scales_data[c], zero_points_data[c], in[i]);
+        }
+      }
+    }
+  } else {
+    for (auto b = 0; b < batches; ++b) {
+      for (auto c = 0; c < channels; ++c) {
+        for (auto e = 0; e < elements_per_channel; ++e) {
+          auto i = b * channels * elements_per_channel +
+              c * elements_per_channel + e;
+          out[i] = at::native::quantize_val<T>(
+              scales_data[c], zero_points_data[c], in[i]);
+        }
+      }
+    }
+  }
+}
+
+#if defined(__ARM_NEON__) || defined(__aarch64__)
+// Specialized implementation from caffe2::Int8Quantize.
+// There may be a slight accuracy difference between this and the
+// implementation of quantize_val.
+// TODO Update quantize_tensor_per_channel_impl implementation to follow
+// quantize_val, i.e. f = Round(value/scale + zero_point)
+// TODO Make quantize_tensor_per_channel_impl work for other datatypes too
+// (int8, int32).
+template <>
+void quantize_tensor_per_channel_impl<c10::quint8>(
+    Tensor rtensor,
+    Tensor qtensor,
+    Tensor scales,
+    Tensor zero_points,
+    int64_t axis) {
+  int64_t batches = size_to_dim_(axis, rtensor.sizes());
+  int64_t elements_per_channel = size_from_dim_(axis + 1, rtensor.sizes());
+  int64_t channels = rtensor.size(axis);
+  auto scales_data = scales.data_ptr<double>();
+  auto zero_points_data = zero_points.data_ptr<int64_t>();
+  const float* in = rtensor.data_ptr<float>();
+  auto out = (uint8_t*)qtensor.data_ptr<c10::quint8>();
+#if defined(__ARM_NEON__)
+  // magic float and magic int to take care of rounding
+  // int magic_round(float f): interpret_int32(f + 12582912.0f) - 0x4B400000
+  // Some detail:
+  // 12582912.0f is 2**23 + 2**22. The trick is based on the fact that when you
+  // add a small number to a large number, the result rounds to the precision of
+  // the least significant bit of the large number.
+  // For an IEEE-754 single-precision number the mantissa has 23 bits, and
+  // adding 2**23 would cause rounding to the nearest even integer. Then we
+  // cast to int and subtract the same number (0x4B400000 is the integer
+  // representation of 12582912.0f) to get only the mantissa. This works if
+  // -2**22 < x < 2**22, but preserves the sign for negative numbers.
+  const float32x4_t vmagic_float = vdupq_n_f32(12582912.0f);
+  // Copy reciprocal of scales (double) into float array
+  // Copy zero_points with magic int (int64_t) into int32_t array
+  std::vector<float> inv_scales(channels);
+  std::vector<int32_t> zero_points_int32t(channels);
+  for (int i = 0; i < channels; ++i) {
+    inv_scales[i] = 1.0f / (float)scales_data[i];
+    zero_points_int32t[i] = (int32_t)(uint32_t)zero_points_data[i] - 0x4B400000;
+  }
+  if (axis == 1 &&
+      (rtensor.is_contiguous(MemoryFormat::ChannelsLast) ||
+       rtensor.is_contiguous(MemoryFormat::ChannelsLast3d))) {
+    // This code handles per channel quant when axis = 1 and
+    // channels_last contig.
+    // If axis = 0 and channels_last contig, implementation for channels
+    // first (NCHW) works.
+    for (uint32_t b = 0; b < batches; ++b) {
+      for (uint32_t e = 0; e < elements_per_channel; ++e) {
+        uint32_t c = 0;
+        while (c + 8 < channels) {
+          const int32x4_t voffset0123 = vld1q_s32(&zero_points_int32t[c]);
+          const float32x4_t vinv_scale0123 = vld1q_f32(&inv_scales[c]);
+          c += 4;
+          const int32x4_t voffset4567 = vld1q_s32(&zero_points_int32t[c]);
+          const float32x4_t vinv_scale4567 = vld1q_f32(&inv_scales[c]);
+          c += 4;
+          const float32x4_t vin0123 = vld1q_f32(in);
+          in += 4;
+          const float32x4_t vin4567 = vld1q_f32(in);
+          in += 4;
+          const int32x4_t vraw0123 = vaddq_s32(
+              voffset0123,
+              vreinterpretq_s32_f32(
+                  vaddq_f32(vmagic_float, vmulq_f32(vin0123, vinv_scale0123))));
+          const int32x4_t vraw4567 = vaddq_s32(
+              voffset4567,
+              vreinterpretq_s32_f32(
+                  vaddq_f32(vmagic_float, vmulq_f32(vin4567, vinv_scale4567))));
+          const int16x8_t vraw01234567 =
+              vcombine_s16(vqmovn_s32(vraw0123), vqmovn_s32(vraw4567));
+          const uint8x8_t vout01234567 = vqmovun_s16(vraw01234567);
+          vst1_u8(out, vout01234567);
+          out += 8;
+        }
+        for (; c < channels; ++c) {
+          (*out++) =
+              at::native::quantize_val_arm(scales_data[c], zero_points_data[c], (*in++));
+        }
+      }
+    }
+  } else {
+    for (uint32_t b = 0; b < batches; ++b) {
+      for (uint32_t c = 0; c < channels; ++c) {
+        uint32_t e = 0;
+        const int32x4_t voffset = vdupq_n_s32(zero_points_int32t[c]);
+        const float32x4_t vinv_scale = vdupq_n_f32(inv_scales[c]);
+        for (; e + 8 < elements_per_channel; e += 8) {
+          const float32x4_t vin0123 = vld1q_f32(in);
+          in += 4;
+          const float32x4_t vin4567 = vld1q_f32(in);
+          in += 4;
+          const int32x4_t vraw0123 = vaddq_s32(
+              voffset,
+              vreinterpretq_s32_f32(
+                  vaddq_f32(vmagic_float, vmulq_f32(vin0123, vinv_scale))));
+          const int32x4_t vraw4567 = vaddq_s32(
+              voffset,
+              vreinterpretq_s32_f32(
+                  vaddq_f32(vmagic_float, vmulq_f32(vin4567, vinv_scale))));
+          const int16x8_t vraw01234567 =
+              vcombine_s16(vqmovn_s32(vraw0123), vqmovn_s32(vraw4567));
+          const uint8x8_t vout01234567 = vqmovun_s16(vraw01234567);
+          vst1_u8(out, vout01234567);
+          out += 8;
+        }
+        for (; e < elements_per_channel; ++e) {
+          (*out++) =
+              at::native::quantize_val_arm(scales_data[c], zero_points_data[c], (*in++));
+        }
+      }
+    }
+  }
+#else // defined(__ARM_NEON__)
+  // Copy reciprocal of scales (double) into float array
+  // Copy zero_points (int64_t) into int16_t array
+  std::vector<float> inv_scales(channels);
+  std::vector<int16_t> zero_points_int16t(channels);
+  for (int i = 0; i < channels; ++i) {
+    inv_scales[i] = 1.0f / (float)scales_data[i];
+    zero_points_int16t[i] = (int16_t)(uint16_t)zero_points_data[i];
+  }
+  if (axis == 1 &&
+      (rtensor.is_contiguous(MemoryFormat::ChannelsLast) ||
+       rtensor.is_contiguous(MemoryFormat::ChannelsLast3d))) {
+    // This code handles per channel quant when axis = 1 and
+    // channels_last contig.
+    // If axis = 0 and channels_last contig, implementation for channels
+    // first (NCHW) works.
+    for (uint32_t b = 0; b < batches; ++b) {
+      for (uint32_t e = 0; e < elements_per_channel; ++e) {
+        uint32_t c = 0;
+        while (c + 8 < channels) {
+          const int16x8_t vzero_point = vld1q_s16(&zero_points_int16t[c]);
+          const float32x4_t vinv_scale0123 = vld1q_f32(&inv_scales[c]);
+          c += 4;
+          const float32x4_t vinv_scale4567 = vld1q_f32(&inv_scales[c]);
+          c += 4;
+          const float32x4_t vin0123 = vld1q_f32(in);
+          in += 4;
+          const float32x4_t vin4567 = vld1q_f32(in);
+          in += 4;
+          const int32x4_t v0123_rounded =
+              vcvtnq_s32_f32(vmulq_f32(vin0123, vinv_scale0123));
+          const int32x4_t v4567_rounded =
+              vcvtnq_s32_f32(vmulq_f32(vin4567, vinv_scale4567));
+          const int16x8_t v01234567_packed = vqaddq_s16(
+              vqmovn_high_s32(vqmovn_s32(v0123_rounded), v4567_rounded),
+              vzero_point);
+          const uint8x8_t vout01234567 = vqmovun_s16(v01234567_packed);
+          vst1_u8(out, vout01234567);
+          out += 8;
+        }
+        for (; c < channels; ++c) {
+          (*out++) =
+              at::native::quantize_val_arm(scales_data[c], zero_points_data[c], (*in++));
+        }
+      }
+    }
+  } else {
+    for (uint32_t b = 0; b < batches; ++b) {
+      for (uint32_t c = 0; c < channels; ++c) {
+        uint32_t e = 0;
+        const int16x8_t vzero_point = vdupq_n_s16(zero_points_int16t[c]);
+        const float32x4_t vinv_scale = vdupq_n_f32(inv_scales[c]);
+        for (; e + 8 < elements_per_channel; e += 8) {
+          const float32x4_t vin0123 = vld1q_f32(in);
+          in += 4;
+          const float32x4_t vin4567 = vld1q_f32(in);
+          in += 4;
+          const int32x4_t v0123_rounded =
+              vcvtnq_s32_f32(vmulq_f32(vin0123, vinv_scale));
+          const int32x4_t v4567_rounded =
+              vcvtnq_s32_f32(vmulq_f32(vin4567, vinv_scale));
+          const int16x8_t v01234567_packed = vqaddq_s16(
+              vqmovn_high_s32(vqmovn_s32(v0123_rounded), v4567_rounded),
+              vzero_point);
+          const uint8x8_t vout01234567 = vqmovun_s16(v01234567_packed);
+          vst1_u8(out, vout01234567);
+          out += 8;
+        }
+        for (; e < elements_per_channel; ++e) {
+          (*out++) =
+              at::native::quantize_val_arm(scales_data[c], zero_points_data[c], (*in++));
+        }
+      }
+    }
+  }
+#endif // defined(__ARM_NEON__)
+}
+#endif // defined(__ARM_NEON__) || defined(__aarch64__)
+
+void quantize_tensor_per_channel_affine_cpu(
+    Tensor rtensor,
+    Tensor qtensor,
+    Tensor scales,
+    Tensor zero_points,
+    int64_t axis) {
+  TORCH_CHECK(
+      rtensor.is_contiguous() || (axis <= 1),
       "If tensor is channels_last contig then per channel quantization "
       "is supported only for axis = 0 or 1.");
   AT_DISPATCH_QINT_TYPES(
       qtensor.scalar_type(), "quantize_tensor_per_channel_affine_cpu", [&]() {
-        int64_t batches = size_to_dim_(axis, rtensor.sizes());
-        int64_t elements_per_channel =
-            size_from_dim_(axis + 1, rtensor.sizes());
-        int64_t channel = rtensor.size(axis);
-        auto scales_data = scales.data_ptr<double>();
-        auto zero_points_data = zero_points.data_ptr<int64_t>();
         check_tensor_memory_format(rtensor, qtensor);
-        const float* rdata = rtensor.data_ptr<float>();
-        auto qdata = qtensor.data_ptr<scalar_t>();
-        if (axis == 1 && (rtensor.is_contiguous(MemoryFormat::ChannelsLast) ||
-            rtensor.is_contiguous(MemoryFormat::ChannelsLast3d))) {
-          // This code handles per channel quant when axis = 1 and
-          // channels_last contig.
-          // If axis = 0 and channels_last contig, implementation
-          // for channels first (NCHW) works.
-          for (auto b = 0; b < batches; ++b) {
-            for (auto e = 0; e < elements_per_channel; ++e) {
-              for (auto c = 0; c < channel; ++c) {
-                auto i = b * channel * elements_per_channel + e * channel + c;
-                qdata[i] = quantize_val<scalar_t>(
-                    scales_data[c], zero_points_data[c], rdata[i]);
-              }
-            }
-          }
-        } else {
-          for (auto b = 0; b < batches; ++b) {
-            for (auto c = 0; c < channel; ++c) {
-              for (auto e = 0; e < elements_per_channel; ++e) {
-                auto i = b * channel * elements_per_channel +
-                    c * elements_per_channel + e;
-                qdata[i] = quantize_val<scalar_t>(
-                    scales_data[c], zero_points_data[c], rdata[i]);
-              }
-            }
-          }
-        }
+        quantize_tensor_per_channel_impl<scalar_t>(
+            rtensor, qtensor, scales, zero_points, axis);
       });
 }
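For reference, a minimal scalar sketch of the magic-number rounding that the __ARM_NEON__ branch above relies on; in the kernel the -0x4B400000 correction is folded into the preloaded per-channel offsets (zero_points_int32t). The sketch is not part of the patch, and magic_round / quantize_one are illustrative names only.

#include <cstdint>
#include <cstdio>
#include <cstring>

// magic_round(f) == nearest integer to f (ties to even) for |f| < 2**22:
// adding 12582912.0f (2**23 + 2**22) forces the sum into a range where the
// float's least significant mantissa bit is exactly 1.0; the bit pattern
// minus 0x4B400000 (the bits of 12582912.0f) is then the rounded integer.
int32_t magic_round(float f) {
  float biased = f + 12582912.0f;
  int32_t bits;
  std::memcpy(&bits, &biased, sizeof(bits)); // reinterpret the float's bits
  return bits - 0x4B400000;
}

// One quantized value, mirroring a single NEON lane:
// q = saturate_to_uint8(magic_round(value * inv_scale) + zero_point)
uint8_t quantize_one(float value, float inv_scale, int32_t zero_point) {
  int32_t q = magic_round(value * inv_scale) + zero_point;
  if (q < 0) q = 0;     // the NEON path saturates via vqmovn/vqmovun instead
  if (q > 255) q = 255;
  return (uint8_t)q;
}

int main() {
  // scale = 0.5 (inv_scale = 2), zero_point = 10: 3.2 -> round(6.4) + 10 = 16
  std::printf("%d\n", quantize_one(3.2f, 2.0f, 10));
  return 0;
}

The vector code performs the same computation four lanes at a time and uses vqmovn_s32 / vqmovun_s16 for the final saturating narrow to uint8.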
diff --git a/aten/src/ATen/test/CMakeLists.txt b/aten/src/ATen/test/CMakeLists.txt
index 4268db33fa16..89712d7bbcfb 100644
--- a/aten/src/ATen/test/CMakeLists.txt
+++ b/aten/src/ATen/test/CMakeLists.txt
@@ -82,7 +82,8 @@ list(APPEND ATen_VULKAN_TEST_SRCS
 list(APPEND ATen_MOBILE_TEST_SRCS
   ${CMAKE_CURRENT_SOURCE_DIR}/vec256_test.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/cpu_profiling_allocator_test.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/cpu_caching_allocator_test.cpp)
+  ${CMAKE_CURRENT_SOURCE_DIR}/cpu_caching_allocator_test.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/quantized_test.cpp)

 list(APPEND ATen_VEC256_TEST_SRCS
   ${CMAKE_CURRENT_SOURCE_DIR}/vec256_test.cpp
diff --git a/aten/src/ATen/test/quantized_test.cpp b/aten/src/ATen/test/quantized_test.cpp
index a2b64618ccfe..71ce5053334b 100644
--- a/aten/src/ATen/test/quantized_test.cpp
+++ b/aten/src/ATen/test/quantized_test.cpp
@@ -101,9 +101,8 @@ TEST(TestQTensor, EmptyQuantized) {
   int zero_point = 10;
   int val = 100;
   int numel = 10;
-  Tensor q = at::_empty_affine_quantized({numel},
-      at::device(at::kCPU).dtype(kQUInt8),
-      scale, zero_point);
+  Tensor q = at::_empty_affine_quantized(
+      {numel}, at::device(at::kCPU).dtype(kQUInt8), scale, zero_point);
   // Assigning to QTensor
   auto* q_data = q.data_ptr<quint8>();
   for (int i = 0; i < numel; ++i) {
@@ -142,7 +141,66 @@ TEST(TestQTensor, EmptyPerchannelQuantized) {
   for (int i = 0; i < numel; ++i) {
     ASSERT_EQ(
         r_data[i],
-        (val - zero_points[i].item().to<int>()) *
-            scales[i].item().to<float>());
+        (val - zero_points[i].item().to<int>()) * scales[i].item().to<float>());
+  }
+}
+
+TEST(TestQTensor, QuantizePerChannel4d) {
+  int C = 32, H = 10, W = 10;
+  auto scales = rand({C}).toType(kDouble);
+  auto zero_points = randint(10, {C}).toType(kLong);
+  int ch_axis = 1;
+  // create 4d tensor where each H x W image is a range(0, H*W)
+  Tensor tensor = at::empty({1, C, H, W}, at::device(at::kCPU).dtype(kFloat));
+  auto* tensor_data = tensor.data_ptr<float>();
+  for (int c = 0, i = 0; c < C; ++c) {
+    for (int e = 0; e < H * W; ++e, ++i) {
+      tensor_data[i] = e;
+    }
+  }
+  // quantize and check values
+  Tensor q = at::native::quantize_per_channel_cpu(
+      tensor, scales, zero_points, ch_axis, kQUInt8);
+  auto* q_data = (uint8_t*)q.data_ptr<quint8>();
+  for (int c = 0, i = 0; c < C; ++c) {
+    auto scale = scales[c].item<double>();
+    auto zero_point = zero_points[c].item<int64_t>();
+    for (int e = 0; e < H * W; ++e, ++i) {
+      // clamp qval to 255 if val is greater than max uint8_t value
+      int qval = std::min<int>((int)round(e / scale) + zero_point, 255);
+      ASSERT_EQ((int)q_data[i], qval);
+    }
+  }
+}
+
+TEST(TestQTensor, QuantizePerChannel4dChannelsLast) {
+  int C = 32, H = 10, W = 10;
+  auto scales = rand({C}).toType(kFloat);
+  auto zero_points = randint(10, {C}).toType(kInt);
+  int ch_axis = 1;
+  // create 4d tensor where each H x W image is a range(0, H*W)
+  Tensor tensor = at::empty(
+      {1, C, H, W},
+      at::device(at::kCPU).dtype(kFloat).memory_format(
+          at::MemoryFormat::ChannelsLast));
+  auto* tensor_data = tensor.data_ptr<float>();
+  for (int e = 0, i = 0; e < H * W; ++e) {
+    for (int c = 0; c < C; ++c, ++i) {
+      tensor_data[i] = e;
+    }
+  }
+
+  // quantize and check values
+  Tensor q = at::native::quantize_per_channel_cpu(
+      tensor, scales, zero_points, ch_axis, kQUInt8);
+  auto* q_data = (uint8_t*)q.data_ptr<quint8>();
+  for (int e = 0, i = 0; e < H * W; ++e) {
+    for (int c = 0; c < C; ++c, ++i) {
+      auto scale = scales[c].item<float>();
+      auto zero_point = zero_points[c].item<int>();
+      // clamp qval to 255 if val is greater than max uint8_t value
+      int qval = std::min<int>((int)round(e / scale) + zero_point, 255);
+      ASSERT_EQ((int)q_data[i], qval);
+    }
   }
 }
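As a usage note, the channels_last, axis = 1 case covered by the new tests can also be reached through the public ATen op. The snippet below is an illustrative sketch, not part of the patch; the shapes and the printout are arbitrary.

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // A channels_last float tensor, like the QuantizePerChannel4dChannelsLast test builds.
  at::Tensor x =
      at::rand({1, 32, 10, 10}).contiguous(at::MemoryFormat::ChannelsLast);
  at::Tensor scales = at::rand({32}, at::kDouble);
  at::Tensor zero_points = at::randint(10, {32}, at::kLong);
  // Public per-channel quantize op; on CPU this should end up in the
  // quantize_tensor_per_channel_affine_cpu kernel patched above.
  at::Tensor q =
      at::quantize_per_channel(x, scales, zero_points, /*axis=*/1, at::kQUInt8);
  std::cout << q.int_repr().sizes() << std::endl;
  return 0;
}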