quantize_tensor_per_channel ARM implementation
Summary: Currently on mobile devices quantize_tensor has a vectorized implementation
using ARM intrinsics; however quantize_tensor_per_channel does not.
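
For context, here is a minimal sketch (not part of this PR) of the per-channel quantization call whose CPU kernel is vectorized below; the shapes mirror the new QuantizePerChannel4d test, and the helper name is illustrative only:

```cpp
#include <ATen/ATen.h>

// Illustrative sketch: the op whose CPU kernel this commit vectorizes.
// On CPU this ends up in the quantize_tensor_per_channel_affine_cpu kernel
// changed in QuantizedOpKernels.cpp.
at::Tensor quantize_per_channel_example() {
  auto rtensor = at::rand({1, 32, 10, 10});                    // float NCHW input
  auto scales = at::rand({32}).toType(at::kDouble);            // one scale per channel
  auto zero_points = at::randint(10, {32}).toType(at::kLong);  // one zero point per channel
  // axis = 1 quantizes along the channel dimension
  return at::quantize_per_channel(rtensor, scales, zero_points, /*axis=*/1, at::kQUInt8);
}
```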

Test Plan: Build and push to mobile device
```
BUILD_MOBILE_BENCHMARK=1 BUILD_MOBILE_TEST=1 ANDROID_DEBUG_SYMBOLS=1 BUILD_PYTORCH_MOBILE=1 ANDROID_ABI=arm64-v8a ./scripts/build_android.sh  -DANDROID_CCACHE=$(which ccache) -DBUILD_BINARY=ON
adb push build_android/bin/quantize_per_channel /data/local/tmp
```
and then run the benchmark binary over adb shell

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 73198c6b8608833effb1840697f883f7402cf1fc
Pull Request resolved: #46018
ajliu committed Oct 17, 2020
1 parent dae747a commit 41bdb3b
Showing 3 changed files with 310 additions and 43 deletions.
282 changes: 245 additions & 37 deletions aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
@@ -2568,7 +2568,9 @@ void dequantize_tensor_per_tensor_affine_cpu(
#endif // USE_FBGEMM

// TODO: add fbgemm for per channel
void quantize_tensor_per_channel_affine_cpu(
// Generic template defaults to naive quantize implementation
template <typename T>
void quantize_tensor_per_channel_impl(
Tensor rtensor,
Tensor qtensor,
Tensor scales,
@@ -2580,47 +2582,253 @@ void quantize_tensor_per_channel_affine_cpu(
// Since the current implementation on channels_last format does not
// cover per channel quant with arbitrary axis value, it is better
// to check and fail.
TORCH_CHECK(rtensor.is_contiguous() || (axis <= 1),
    "If tensor is channels_last contig then per channel quantization "
    "is supported only for axis = 0 or 1.");
int64_t batches = size_to_dim_(axis, rtensor.sizes());
int64_t elements_per_channel = size_from_dim_(axis + 1, rtensor.sizes());
int64_t channels = rtensor.size(axis);
auto scales_data = scales.data_ptr<double>();
auto zero_points_data = zero_points.data_ptr<int64_t>();
const float* in = rtensor.data_ptr<float>();
auto out = qtensor.data_ptr<T>();
if (axis == 1 &&
(rtensor.is_contiguous(MemoryFormat::ChannelsLast) ||
rtensor.is_contiguous(MemoryFormat::ChannelsLast3d))) {
// This code handles per channel quant when axis = 1 and
// channels_last contig.
// If axis = 0 and channels_last contig, implementation for channels
// first (NCHW) works.
for (auto b = 0; b < batches; ++b) {
for (auto e = 0; e < elements_per_channel; ++e) {
for (auto c = 0; c < channels; ++c) {
auto i = b * channels * elements_per_channel + e * channels + c;
out[i] = at::native::quantize_val<T>(
scales_data[c], zero_points_data[c], in[i]);
}
}
}
} else {
for (auto b = 0; b < batches; ++b) {
for (auto c = 0; c < channels; ++c) {
for (auto e = 0; e < elements_per_channel; ++e) {
auto i = b * channels * elements_per_channel +
c * elements_per_channel + e;
out[i] = at::native::quantize_val<T>(
scales_data[c], zero_points_data[c], in[i]);
}
}
}
}
}

#if defined(__ARM_NEON__) || defined(__aarch64__)
// Specialized implementation from caffe2::Int8Quantize.
// There may be a slight accuracy difference between this and the
// implementation of quantize_val
// TODO Update quantize_tensor_per_channel_impl implementation to follow
// quantize_val, i.e. f = Round(value/scale + zero_point)
// TODO Make quantize_tensor_per_channel_impl work for other datatypes too
// (int8, int32).
template <>
void quantize_tensor_per_channel_impl<c10::quint8>(
Tensor rtensor,
Tensor qtensor,
Tensor scales,
Tensor zero_points,
int64_t axis) {
int64_t batches = size_to_dim_(axis, rtensor.sizes());
int64_t elements_per_channel = size_from_dim_(axis + 1, rtensor.sizes());
int64_t channels = rtensor.size(axis);
auto scales_data = scales.data_ptr<double>();
auto zero_points_data = zero_points.data_ptr<int64_t>();
const float* in = rtensor.data_ptr<float>();
auto out = (uint8_t*)qtensor.data_ptr<c10::quint8>();
#if defined(__ARM_NEON__)
// magic float and magic int to take care of rounding
// int magic_round(float f): interpret_int32(f + 12582912.0f) - 0x4B400000
// Some detail:
// 12582912.0f is 2**23 + 2**22. The trick is based on the fact that when you
// add a small number to a large number, the result rounds to the precision of
// the least significant bit of the large number. For IEEE-754
// single-precision numbers the mantissa has 23 bits, and adding 2**23 would
// cause rounding to the nearest even integer. Then we cast to int and subtract the
// same number (0x4B400000 is the integer representation of 12582912.0f) to
// get only the mantissa. This works if -2**22 < x < 2**22, but preserves the
// sign for negative numbers.
const float32x4_t vmagic_float = vdupq_n_f32(12582912.0f);
// Copy reciprocal of scales (double) into float array
// Copy zero_points with magic int (int64_t) into int32_t array
std::vector<float> inv_scales(channels);
std::vector<int32_t> zero_points_int32t(channels);
for (int i = 0; i < channels; ++i) {
inv_scales[i] = 1.0f / (float)scales_data[i];
zero_points_int32t[i] = (int32_t)(uint32_t)zero_points_data[i] - 0x4B400000;
}
if (axis == 1 &&
(rtensor.is_contiguous(MemoryFormat::ChannelsLast) ||
rtensor.is_contiguous(MemoryFormat::ChannelsLast3d))) {
// This code handles per channel quant when axis = 1 and
// channels_last contig.
// If axis = 0 and channels_last contig, implementation for channels
// first (NCHW) works.
for (uint32_t b = 0; b < batches; ++b) {
for (uint32_t e = 0; e < elements_per_channel; ++e) {
uint32_t c = 0;
while (c + 8 < channels) {
const int32x4_t voffset0123 = vld1q_s32(&zero_points_int32t[c]);
const float32x4_t vinv_scale0123 = vld1q_f32(&inv_scales[c]);
c += 4;
const int32x4_t voffset4567 = vld1q_s32(&zero_points_int32t[c]);
const float32x4_t vinv_scale4567 = vld1q_f32(&inv_scales[c]);
c += 4;
const float32x4_t vin0123 = vld1q_f32(in);
in += 4;
const float32x4_t vin4567 = vld1q_f32(in);
in += 4;
const int32x4_t vraw0123 = vaddq_s32(
voffset0123,
vreinterpretq_s32_f32(
vaddq_f32(vmagic_float, vmulq_f32(vin0123, vinv_scale0123))));
const int32x4_t vraw4567 = vaddq_s32(
voffset4567,
vreinterpretq_s32_f32(
vaddq_f32(vmagic_float, vmulq_f32(vin4567, vinv_scale4567))));
const int16x8_t vraw01234567 =
vcombine_s16(vqmovn_s32(vraw0123), vqmovn_s32(vraw4567));
const uint8x8_t vout01234567 = vqmovun_s16(vraw01234567);
vst1_u8(out, vout01234567);
out += 8;
}
for (; c < channels; ++c) {
(*out++) =
at::native::quantize_val_arm(scales_data[c], zero_points_data[c], (*in++));
}
}
}
} else {
for (uint32_t b = 0; b < batches; ++b) {
for (uint32_t c = 0; c < channels; ++c) {
uint32_t e = 0;
const int32x4_t voffset = vdupq_n_s32(zero_points_int32t[c]);
const float32x4_t vinv_scale = vdupq_n_f32(inv_scales[c]);
for (; e + 8 < elements_per_channel; e += 8) {
const float32x4_t vin0123 = vld1q_f32(in);
in += 4;
const float32x4_t vin4567 = vld1q_f32(in);
in += 4;
const int32x4_t vraw0123 = vaddq_s32(
voffset,
vreinterpretq_s32_f32(
vaddq_f32(vmagic_float, vmulq_f32(vin0123, vinv_scale))));
const int32x4_t vraw4567 = vaddq_s32(
voffset,
vreinterpretq_s32_f32(
vaddq_f32(vmagic_float, vmulq_f32(vin4567, vinv_scale))));
const int16x8_t vraw01234567 =
vcombine_s16(vqmovn_s32(vraw0123), vqmovn_s32(vraw4567));
const uint8x8_t vout01234567 = vqmovun_s16(vraw01234567);
vst1_u8(out, vout01234567);
out += 8;
}
for (; e < elements_per_channel; ++e) {
(*out++) =
at::native::quantize_val_arm(scales_data[c], zero_points_data[c], (*in++));
}
}
}
}
#else // defined(__ARM_NEON__)
// Copy scales (double) into float array
// Copy zero_points (int64_t) into int16_t array
std::vector<float> inv_scales(channels);
std::vector<int16_t> zero_points_int16t(channels);
for (int i = 0; i < channels; ++i) {
inv_scales[i] = 1.0f / (float)scales_data[i];
zero_points_int16t[i] = (int16_t)(uint16_t)zero_points_data[i];
}
if (axis == 1 &&
(rtensor.is_contiguous(MemoryFormat::ChannelsLast) ||
rtensor.is_contiguous(MemoryFormat::ChannelsLast3d))) {
// This code handles per channel quant when axis = 1 and
// channels_last contig.
// If axis = 0 and channels_last contig, implementation for channels
// first (NCHW) works.
for (uint32_t b = 0; b < batches; ++b) {
for (uint32_t e = 0; e < elements_per_channel; ++e) {
uint32_t c = 0;
while (c + 8 < channels) {
const int16x8_t vzero_point = vld1q_s16(&zero_points_int16t[c]);
const float32x4_t vinv_scale0123 = vld1q_f32(&inv_scales[c]);
c += 4;
const float32x4_t vinv_scale4567 = vld1q_f32(&inv_scales[c]);
c += 4;
const float32x4_t vin0123 = vld1q_f32(in);
in += 4;
const float32x4_t vin4567 = vld1q_f32(in);
in += 4;
const int32x4_t v0123_rounded =
vcvtnq_s32_f32(vmulq_f32(vin0123, vinv_scale0123));
const int32x4_t v4567_rounded =
vcvtnq_s32_f32(vmulq_f32(vin4567, vinv_scale4567));
const int16x8_t v01234567_packed = vqaddq_s16(
vqmovn_high_s32(vqmovn_s32(v0123_rounded), v4567_rounded),
vzero_point);
const uint8x8_t vout01234567 = vqmovun_s16(v01234567_packed);
vst1_u8(out, vout01234567);
out += 8;
}
for (; c < channels; ++c) {
(*out++) =
at::native::quantize_val_arm(scales_data[c], zero_points_data[c], (*in++));
}
}
}
} else {
for (uint32_t b = 0; b < batches; ++b) {
for (uint32_t c = 0; c < channels; ++c) {
uint32_t e = 0;
const int16x8_t vzero_point = vdupq_n_s16(zero_points_int16t[c]);
const float32x4_t vinv_scale = vdupq_n_f32(inv_scales[c]);
for (; e + 8 < elements_per_channel; e += 8) {
const float32x4_t vin0123 = vld1q_f32(in);
in += 4;
const float32x4_t vin4567 = vld1q_f32(in);
in += 4;
const int32x4_t v0123_rounded =
vcvtnq_s32_f32(vmulq_f32(vin0123, vinv_scale));
const int32x4_t v4567_rounded =
vcvtnq_s32_f32(vmulq_f32(vin4567, vinv_scale));
const int16x8_t v01234567_packed = vqaddq_s16(
vqmovn_high_s32(vqmovn_s32(v0123_rounded), v4567_rounded),
vzero_point);
const uint8x8_t vout01234567 = vqmovun_s16(v01234567_packed);
vst1_u8(out, vout01234567);
out += 8;
}
for (; e < elements_per_channel; ++e) {
(*out++) =
at::native::quantize_val_arm(scales_data[c], zero_points_data[c], (*in++));
}
}
}
}
#endif // defined(__ARM_NEON__)
}
#endif // defined(__ARM_NEON__) || defined(__aarch64__)

void quantize_tensor_per_channel_affine_cpu(
Tensor rtensor,
Tensor qtensor,
Tensor scales,
Tensor zero_points,
int64_t axis) {
TORCH_CHECK(
rtensor.is_contiguous() || (axis <= 1),
"If tensor is channels_last contig then per channel quantization "
"is supported only for axis = 0 or 1.");
AT_DISPATCH_QINT_TYPES(
qtensor.scalar_type(), "quantize_tensor_per_channel_affine_cpu", [&]() {
int64_t batches = size_to_dim_(axis, rtensor.sizes());
int64_t elements_per_channel =
size_from_dim_(axis + 1, rtensor.sizes());
int64_t channel = rtensor.size(axis);
auto scales_data = scales.data_ptr<double>();
auto zero_points_data = zero_points.data_ptr<int64_t>();
check_tensor_memory_format(rtensor, qtensor);
const float* rdata = rtensor.data_ptr<float>();
auto qdata = qtensor.data_ptr<scalar_t>();
if (axis == 1 && (rtensor.is_contiguous(MemoryFormat::ChannelsLast) ||
rtensor.is_contiguous(MemoryFormat::ChannelsLast3d))) {
// This code handles per channel quant when axis = 1 and
// channels_last contig.
// If axis = 0 and channels_last contig, implementation
// for channels first (NCHW) works.
for (auto b = 0; b < batches; ++b) {
for (auto e = 0; e < elements_per_channel; ++e) {
for (auto c = 0; c < channel; ++c) {
auto i = b * channel * elements_per_channel + e * channel + c;
qdata[i] = quantize_val<scalar_t>(
scales_data[c], zero_points_data[c], rdata[i]);
}
}
}
} else {
for (auto b = 0; b < batches; ++b) {
for (auto c = 0; c < channel; ++c) {
for (auto e = 0; e < elements_per_channel; ++e) {
auto i = b * channel * elements_per_channel +
c * elements_per_channel + e;
qdata[i] = quantize_val<scalar_t>(
scales_data[c], zero_points_data[c], rdata[i]);
}
}
}
}
quantize_tensor_per_channel_impl<scalar_t>(
rtensor, qtensor, scales, zero_points, axis);
});
}
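
Not part of the diff: a scalar sketch of the magic-float rounding trick described in the comments of the __ARM_NEON__ path above, assuming IEEE-754 floats and round-to-nearest-even mode. The helpers magic_round and quantize_one are illustrative names, not functions from this commit:

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <cstring>

// Scalar model of the NEON magic-float rounding used above:
// interpret_int32(f + 12582912.0f) - 0x4B400000 rounds f to the nearest
// integer (ties to even) for -2**22 < f < 2**22.
static int32_t magic_round(float f) {
  const float magic_float = 12582912.0f;  // 2**23 + 2**22
  const int32_t magic_int = 0x4B400000;   // bit pattern of 12582912.0f
  float biased = f + magic_float;         // forces rounding at integer granularity
  int32_t bits;
  std::memcpy(&bits, &biased, sizeof(bits));  // reinterpret the float's bits
  return bits - magic_int;
}

// Scalar model of one lane of the vectorized quint8 path: multiply by the
// per-channel inverse scale, round, add the zero point, clamp to [0, 255]
// (the NEON code gets the clamp for free from the saturating narrows).
static uint8_t quantize_one(float value, float inv_scale, int32_t zero_point) {
  int32_t q = magic_round(value * inv_scale) + zero_point;
  return static_cast<uint8_t>(std::min(std::max(q, 0), 255));
}

int main() {
  std::printf("%d %d %d\n", magic_round(2.4f), magic_round(2.5f), magic_round(-3.7f));
  // prints "2 2 -4": nearest integer, ties to even, sign preserved
  std::printf("%d\n", (int)quantize_one(1.3f, /*inv_scale=*/10.0f, /*zero_point=*/2));
  // prints "15": round(1.3 * 10) + 2
  return 0;
}
```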

3 changes: 2 additions & 1 deletion aten/src/ATen/test/CMakeLists.txt
@@ -82,7 +82,8 @@ list(APPEND ATen_VULKAN_TEST_SRCS
list(APPEND ATen_MOBILE_TEST_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/vec256_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpu_profiling_allocator_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpu_caching_allocator_test.cpp)
${CMAKE_CURRENT_SOURCE_DIR}/cpu_caching_allocator_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/quantized_test.cpp)

list(APPEND ATen_VEC256_TEST_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/vec256_test.cpp
68 changes: 63 additions & 5 deletions aten/src/ATen/test/quantized_test.cpp
@@ -101,9 +101,8 @@ TEST(TestQTensor, EmptyQuantized) {
int zero_point = 10;
int val = 100;
int numel = 10;
Tensor q = at::_empty_affine_quantized({numel},
at::device(at::kCPU).dtype(kQUInt8),
scale, zero_point);
Tensor q = at::_empty_affine_quantized(
{numel}, at::device(at::kCPU).dtype(kQUInt8), scale, zero_point);
// Assigning to QTensor
auto* q_data = q.data_ptr<quint8>();
for (int i = 0; i < numel; ++i) {
@@ -142,7 +141,66 @@ TEST(TestQTensor, EmptyPerchannelQuantized) {
for (int i = 0; i < numel; ++i) {
ASSERT_EQ(
r_data[i],
(val - zero_points[i].item().to<int>()) *
scales[i].item().to<float>());
(val - zero_points[i].item().to<int>()) * scales[i].item().to<float>());
}
}

TEST(TestQTensor, QuantizePerChannel4d) {
int C = 32, H = 10, W = 10;
auto scales = rand({C}).toType(kDouble);
auto zero_points = randint(10, {C}).toType(kLong);
int ch_axis = 1;
// create 4d tensor where each H x W image is a range(0, H*W)
Tensor tensor = at::empty({1, C, H, W}, at::device(at::kCPU).dtype(kFloat));
auto* tensor_data = tensor.data_ptr<float>();
for (int c = 0, i = 0; c < C; ++c) {
for (int e = 0; e < H * W; ++e, ++i) {
tensor_data[i] = e;
}
}
// quantize and check values
Tensor q = at::native::quantize_per_channel_cpu(
tensor, scales, zero_points, ch_axis, kQUInt8);
auto* q_data = (uint8_t*)q.data_ptr<quint8>();
for (int c = 0, i = 0; c < C; ++c) {
auto scale = scales[c].item<float>();
auto zero_point = zero_points[c].item<int>();
for (int e = 0; e < H * W; ++e, ++i) {
// clamp qval to 255 if val is greater than the max uint8_t value
int qval = std::min((int)round(e / scale) + zero_point, 255);
ASSERT_EQ((int)q_data[i], qval);
}
}
}

TEST(TestQTensor, QuantizePerChannel4dChannelsLast) {
int C = 32, H = 10, W = 10;
auto scales = rand({C}).toType(kFloat);
auto zero_points = randint(10, {C}).toType(kInt);
int ch_axis = 1;
// create 4d tensor where each H x W image is a range(0, H*W)
Tensor tensor = at::empty(
{1, C, H, W},
at::device(at::kCPU).dtype(kFloat).memory_format(
at::MemoryFormat::ChannelsLast));
auto* tensor_data = tensor.data_ptr<float>();
for (int e = 0, i = 0; e < H * W; ++e) {
for (int c = 0; c < C; ++c, ++i) {
tensor_data[i] = e;
}
}

// quantize and check values
Tensor q = at::native::quantize_per_channel_cpu(
tensor, scales, zero_points, ch_axis, kQUInt8);
auto* q_data = (uint8_t*)q.data_ptr<quint8>();
for (int e = 0, i = 0; e < H * W; ++e) {
for (int c = 0; c < C; ++c, ++i) {
auto scale = scales[c].item<float>();
auto zero_point = zero_points[c].item<int>();
// clamp qval to 255 if val is greater than the max uint8_t value
int qval = std::min((int)round(e / scale) + zero_point, 255);
ASSERT_EQ((int)q_data[i], qval);
}
}
}
