209 changes: 196 additions & 13 deletions kernels/quantized/cpu/op_dequantize.cpp
@@ -11,6 +11,9 @@
#include <algorithm>
#include <cinttypes>
#include <cmath>
#if defined(__aarch64__) || defined(__ARM_NEON)
#include <arm_neon.h>
#endif

/**
* For an input tensor, use the scale and zero_point arguments to quantize it.
@@ -22,6 +25,8 @@ namespace native {
using Tensor = exec_aten::Tensor;
using Scalar = exec_aten::Scalar;
using ScalarType = exec_aten::ScalarType;
using StridesType = exec_aten::StridesType;
using SizesType = exec_aten::SizesType;

namespace {

@@ -63,6 +68,183 @@ void check_dequantize_per_tensor_args(
quant_max);
}

/**
 * Applies the function `fn` to a tensor `in` by iterating over ("unpacking")
 * the given dimension `dim`. `fn` should have the following signature:
 * void fn(const size_t inner_size, const size_t outer_idx,
 *     const size_t unpacked_dim_idx)
 * where `inner_size` is the number of elements in the dimensions trailing
 * `dim`, `outer_idx` indexes the dimensions leading `dim`, and
 * `unpacked_dim_idx` is the current index along `dim`.
 */
template <typename Fn>
void apply_over_unpacked_dim(
const Fn& fn,
const exec_aten::Tensor& in,
const int64_t& dim) {
if (in.numel() == 0) {
return;
}

ET_CHECK_MSG(in.dim() > 0, "Input tensor must have at least one dimension");
ET_CHECK_VALID_DIM(dim, in.dim());

const size_t d = ET_NORMALIZE_IX(dim, in.dim());
const size_t dim_size = in.size(d);
const size_t outer_size = getLeadingDims(in, d);
const size_t inner_size = getTrailingDims(in, d);
// Loop through all outer dimensions
for (size_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
// Loop through dim
for (size_t unpacked_dim_idx = 0; unpacked_dim_idx < dim_size;
++unpacked_dim_idx) {
fn(inner_size, outer_idx, unpacked_dim_idx);
}
}
}
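// Illustration only -- a hypothetical sketch, not used by the kernel: for a
// contiguous tensor, each (outer_idx, unpacked_dim_idx) pair produced above
// addresses a run of `inner_size` consecutive elements whose base offset is
// computed as below. For example, shape {2, 3, 4} with dim = 1 gives
// outer_size = 2, dim_size = 3, inner_size = 4, so (1, 2) maps to offset 20.
inline size_t unpacked_dim_base_offset_sketch(
size_t outer_idx,
size_t unpacked_dim_idx,
size_t dim_size,
size_t inner_size) {
return outer_idx * dim_size * inner_size + unpacked_dim_idx * inner_size;
}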

void dequantize_optimized(
const int8_t* in,
const double scale,
const int64_t zero_point,
float* out,
int64_t quant_min,
int64_t quant_max,
size_t numel) {
ET_CHECK_MSG(
zero_point >= quant_min,
"zero_point %" PRId64 " must be >= quant_min %" PRId64,
zero_point,
quant_min);
ET_CHECK_MSG(
zero_point <= quant_max,
"zero_point %" PRId64 " must be <= quant_max %" PRId64,
zero_point,
quant_max);
size_t i = 0;
#if defined(__aarch64__) || defined(__ARM_NEON)
int8x8_t zero_point_vec = vdup_n_s8(zero_point);
float32x4_t scales = vdupq_n_f32(static_cast<float>(scale));
constexpr int32_t kVecSize = 16;
const size_t num_vecs = numel / kVecSize;
const int8_t* in_copy = in;
float* out_copy = out;
for (; i < num_vecs; i++) {
int8x16_t in_vec = vld1q_s8(in_copy);
int16x8_t sub_vec_0_7 = vsubl_s8(vget_low_s8(in_vec), zero_point_vec);
int32x4_t sub_vec_0_3 = vmovl_s16(vget_low_s16(sub_vec_0_7));
int32x4_t sub_vec_4_7 = vmovl_s16(vget_high_s16(sub_vec_0_7));
float32x4_t out_vec_0_3 = vmulq_f32(vcvtq_f32_s32(sub_vec_0_3), scales);
float32x4_t out_vec_4_7 = vmulq_f32(vcvtq_f32_s32(sub_vec_4_7), scales);

int16x8_t sub_vec_8_15 = vsubl_s8(vget_high_s8(in_vec), zero_point_vec);
int32x4_t sub_vec_8_11 = vmovl_s16(vget_low_s16(sub_vec_8_15));
int32x4_t sub_vec_12_15 = vmovl_s16(vget_high_s16(sub_vec_8_15));
float32x4_t out_vec_8_11 = vmulq_f32(vcvtq_f32_s32(sub_vec_8_11), scales);
float32x4_t out_vec_12_15 = vmulq_f32(vcvtq_f32_s32(sub_vec_12_15), scales);
vst1q_f32(out_copy + 0, out_vec_0_3);
vst1q_f32(out_copy + 4, out_vec_4_7);
vst1q_f32(out_copy + 8, out_vec_8_11);
vst1q_f32(out_copy + 12, out_vec_12_15);
in_copy += kVecSize;
out_copy += kVecSize;
}
i = i * kVecSize;
#endif
for (; i < numel; i++) {
out[i] = (in[i] - zero_point) * scale;
}
}
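// For exposition only -- a minimal scalar sketch of the math implemented by
// dequantize_optimized above (dequantize_reference_sketch is a hypothetical
// name, not part of the kernel). Both the NEON path and the scalar tail
// compute out[i] = (in[i] - zero_point) * scale, e.g. (100 - 30) * 0.5 = 35.
inline void dequantize_reference_sketch(
const int8_t* in,
double scale,
int64_t zero_point,
float* out,
size_t numel) {
for (size_t i = 0; i < numel; ++i) {
out[i] = static_cast<float>((in[i] - zero_point) * scale);
}
}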

float get_scale(const Tensor& scale, size_t channel_ix) {
ET_CHECK_MSG(
(scale.scalar_type() == ScalarType::Double) ||
(scale.scalar_type() == ScalarType::Float),
"scale.scalar_type() %" PRId8 " is not double or float type",
static_cast<int8_t>(scale.scalar_type()));
if (scale.scalar_type() == ScalarType::Double) {
return static_cast<float>(scale.const_data_ptr<double>()[channel_ix]);
} else {
return scale.const_data_ptr<float>()[channel_ix];
}
}

bool can_use_optimized_dequantize_per_channel(
const Tensor& in,
const ScalarType in_dtype,
exec_aten::optional<ScalarType>& out_dtype) {
bool is_contiguous = false;
#ifdef USE_ATEN_LIB
is_contiguous = in.is_contiguous();
#else
is_contiguous = executorch::runtime::is_contiguous_dim_order(
in.dim_order().data(), in.dim());
#endif
if (!is_contiguous || (in_dtype != ScalarType::Char) ||
(out_dtype.has_value() && out_dtype.value() != ScalarType::Float)) {
return false;
}
return true;
}

void dequantize_per_channel_optimized(
const Tensor& in,
const Tensor& scales,
const optional<Tensor>& opt_zero_points,
Tensor& out,
int64_t axis,
int64_t quant_min,
int64_t quant_max,
ScalarType in_dtype,
exec_aten::optional<ScalarType>& out_dtype) {
check_dequantize_per_tensor_args(
in, quant_min, quant_max, in_dtype, out_dtype, out);
ET_CHECK_MSG(
in_dtype == ScalarType::Char,
"in.scalar_type() %" PRId8 " is not supported",
static_cast<int8_t>(in.scalar_type()));
if (out_dtype.has_value()) {
ET_CHECK_MSG(
out_dtype.value() == ScalarType::Float,
"Only float output is supported");
}
const int8_t* in_data = in.const_data_ptr<int8_t>();
float* out_data = out.mutable_data_ptr<float>();
const int64_t* zero_points_data = nullptr;
if (opt_zero_points.has_value()) {
zero_points_data = opt_zero_points.value().const_data_ptr<int64_t>();
}
const StridesType axis_stride = in.strides()[axis];
const StridesType outer_stride = in.size(axis) * axis_stride;
apply_over_unpacked_dim(
[in_data,
out_data,
&scales,
zero_points_data,
axis_stride,
outer_stride,
quant_min,
quant_max](
SizesType numel, SizesType outer_idx, SizesType unpacked_dim_idx) {
const int8_t* in_data_local =
in_data + outer_idx * outer_stride + unpacked_dim_idx * axis_stride;
const double scale = get_scale(scales, unpacked_dim_idx);
const int64_t zero_point = zero_points_data != nullptr
? zero_points_data[unpacked_dim_idx]
: 0;
float* out_data_local = out_data + outer_idx * outer_stride +
unpacked_dim_idx * axis_stride;
dequantize_optimized(
in_data_local,
scale,
zero_point,
out_data_local,
quant_min,
quant_max,
numel);
},
in,
axis);
}
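// For intuition only -- a hypothetical scalar model (not part of the kernel)
// of what dequantize_per_channel_optimized computes for a contiguous
// {rows, channels} int8 tensor with axis = 1: every element in column c is
// dequantized with scales[c] and zero_points[c].
inline void dequantize_per_channel_sketch(
const int8_t* in,
const float* scales,
const int64_t* zero_points,
float* out,
size_t rows,
size_t channels) {
for (size_t r = 0; r < rows; ++r) {
for (size_t c = 0; c < channels; ++c) {
// e.g. in = 100, zero_point = 30, scale = 0.5 -> out = 35.0f
out[r * channels + c] = static_cast<float>(
(in[r * channels + c] - zero_points[c]) * scales[c]);
}
}
}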

} // namespace

/**
Expand Down Expand Up @@ -172,19 +354,6 @@ Tensor& dequantize_per_tensor_tensor_args_out(
return out;
}

float get_scale(const Tensor& scale, size_t channel_ix) {
ET_CHECK_MSG(
(scale.scalar_type() == ScalarType::Double) ||
(scale.scalar_type() == ScalarType::Float),
"scale.scalar_type() %" PRId8 " is not double or float type",
static_cast<int8_t>(scale.scalar_type()));
if (scale.scalar_type() == ScalarType::Double) {
return static_cast<float>(scale.const_data_ptr<double>()[channel_ix]);
} else {
return scale.const_data_ptr<float>()[channel_ix];
}
}

Tensor& dequantize_per_channel_out(
const Tensor& input,
const Tensor& scale,
Expand Down Expand Up @@ -229,6 +398,20 @@ Tensor& dequantize_per_channel_out(
check_dequantize_per_tensor_args(
input, quant_min, quant_max, dtype, out_dtype, out);

if (can_use_optimized_dequantize_per_channel(input, dtype, out_dtype)) {
dequantize_per_channel_optimized(
input,
scale,
opt_zero_points,
out,
axis,
quant_min,
quant_max,
dtype,
out_dtype);
return out;
}

// a list contains all dimensions except axis
int64_t dims[kTensorDimensionLimit];
for (int64_t i = 0; i < input.dim() - 1; i++) {
50 changes: 42 additions & 8 deletions kernels/quantized/test/op_dequantize_test.cpp
@@ -123,13 +123,13 @@ TEST(OpDequantizeOutTest, TensorArgOverload) {
EXPECT_TENSOR_EQ(out, expected);
}

TEST(OpDequantizeOutTest, DequantizePerChannel) {
et_pal_init();
TensorFactory<ScalarType::Byte> tf_byte;
template <ScalarType DTYPE>
void test_per_channel_dtype() {
TensorFactory<DTYPE> tf;
TensorFactory<ScalarType::Double> tf_double;
TensorFactory<ScalarType::Long> tf_long;

Tensor input = tf_byte.full({3, 2}, 100);
Tensor input = tf.full({3, 2}, 100);
Tensor scale = tf_double.make({2}, {0.5, 1});
Tensor zero_point = tf_long.make({2}, {30, 60});
int64_t quant_min = 0;
@@ -147,7 +147,7 @@ TEST(OpDequantizeOutTest, DequantizePerChannel) {
/*axis=*/1,
quant_min,
quant_max,
ScalarType::Byte,
DTYPE,
optional<ScalarType>(),
out);

@@ -168,15 +168,15 @@
/*axis=*/0,
quant_min,
quant_max,
ScalarType::Byte,
DTYPE,
optional<ScalarType>(),
out);

EXPECT_TENSOR_EQ(out, expected);

// Test with a different axis
out = tfo.zeros({3});
input = tf_byte.make({3}, {100, 100, 100});
input = tf.make({3}, {100, 100, 100});
scale = tf_double.make({3}, {0.5, 0.75, 1});
zero_point = tf_long.make({3}, {30, 50, 60});
// (100 - 30) * 0.5
@@ -190,8 +190,42 @@
/*axis=*/0,
quant_min,
quant_max,
ScalarType::Byte,
DTYPE,
optional<ScalarType>(),
out);
EXPECT_TENSOR_EQ(out, expected);

// Test with a larger tensor (19 elements per channel) so both the
// vectorized path and the scalar tail are exercised
input = tf.full({3, 19}, 100);
out = tfo.zeros({3, 19});
scale = tf_double.make({3}, {0.5, 0.75, 1});
zero_point = tf_long.make({3}, {30, 50, 60});
// (100 - 30) * 0.5
// (100 - 50) * 0.75
// (100 - 60) * 1
expected = tfo.make(
{3, 19},
{35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35,
35, 35, 35, 35, 35, 35, 35, 37.5, 37.5, 37.5, 37.5, 37.5,
37.5, 37.5, 37.5, 37.5, 37.5, 37.5, 37.5, 37.5, 37.5, 37.5, 37.5, 37.5,
37.5, 37.5, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40,
40, 40, 40, 40, 40, 40, 40, 40, 40});
dequantize_per_channel_out(
input,
scale,
zero_point,
/*axis=*/0,
quant_min,
quant_max,
DTYPE,
optional<ScalarType>(),
out);

EXPECT_TENSOR_EQ(out, expected);
}

TEST(OpDequantizeOutTest, DequantizePerChannel) {
et_pal_init();
test_per_channel_dtype<ScalarType::Byte>();
test_per_channel_dtype<ScalarType::Char>();
}