Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 66 additions & 4 deletions kernels/quantized/cpu/op_embedding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,36 @@ void check_embedding_byte_args(
const int64_t weight_quant_max,
const Tensor& indices,
Tensor& out) {
ET_CHECK_MSG(
weight.dim() == 2, "weight must be 2D but got() %zd dims", weight.dim());

ET_CHECK_MSG(
weight_scales.dim() <= 2,
"weight_scales must be 1D or 2D but got() %zd dims",
weight_scales.dim());

auto weight_scales_size = weight_scales.size(0);

ET_CHECK_MSG(
weight_scales_size == weight.size(0),
"Number of scales must be == weight.size(0)=%zd"
", but got %zd",
weight_scales_size,
weight.size(0));

if (weight_scales_size >= weight.size(0)) {
if (weight_scales.dim() == 2) {
auto num_groups = weight_scales.size(1);
auto remainder = weight.size(1) % num_groups;
ET_CHECK_MSG(
remainder == 0,
"Number of groups must divide weight.size(1)=%zd"
", but got # of groups = %zd",
weight.size(1),
num_groups);
}
}

ET_CHECK_MSG(
weight.scalar_type() == ScalarType::Byte ||
weight.scalar_type() == ScalarType::Char,
Expand All @@ -50,11 +80,29 @@ void check_embedding_byte_args(
static_cast<int8_t>(weight_scales.scalar_type()));

if (opt_weight_zero_points.has_value()) {
ET_CHECK_MSG(
opt_weight_zero_points.value().dim() == weight_scales.dim(),
"weight_zero_points's rank match that of weight_scales. "
"weight_zero_points rank: %" PRId8 ", weight_scales rank: %" PRId8,
static_cast<int8_t>(opt_weight_zero_points.value().dim()),
static_cast<int8_t>(weight_scales.dim()));

ET_CHECK_MSG(
opt_weight_zero_points.value().scalar_type() == out.scalar_type(),
"weight zero points scalar type %" PRId8
" does not match out.scalar_type()",
static_cast<int8_t>(opt_weight_zero_points.value().scalar_type()));

for (int32_t i = 0; i < weight_scales.dim(); ++i) {
ET_CHECK_MSG(
opt_weight_zero_points.value().size(i) == weight_scales.size(i),
"Dimension size misatch at dim %" PRId8
"Weight_zero_point size = %zd"
", weight_scales size = %zd.",
i,
opt_weight_zero_points.value().size(i),
weight_scales.size(i));
}
}

ET_CHECK_MSG(
Expand All @@ -81,10 +129,16 @@ void embedding_byte_per_channel(
const optional<Tensor>& opt_weight_zero_points,
const Tensor& indices,
Tensor& out) {
// An embedding layer nn.Embedding(num_embeddings, embedding_dim) has a weight
// of shape (num_embeddings, embedding_dim).
// An embedding layer nn.Embedding(num_embeddings, embedding_dim) has a
// weight of shape (num_embeddings, embedding_dim).
auto embedding_dim = weight.size(1);

int32_t num_groups_per_channel = 1;
if (weight_scales.dim() == 2) {
num_groups_per_channel = weight_scales.size(1);
}
int32_t group_size = weight.size(1) / num_groups_per_channel;

CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
const int64_t* indices_ptr = indices.const_data_ptr<int64_t>();

Expand All @@ -96,16 +150,24 @@ void embedding_byte_per_channel(

for (int i = 0; i < indices.numel(); i++) {
int64_t index = indices_ptr[i];
// If using groupwise embedding
int32_t qparams_index = index * num_groups_per_channel;
CTYPE_OUT zp = 0.0;
const CTYPE_OUT* scale_ptr = scales + qparams_index;
const CTYPE_OUT* zero_points_ptr = nullptr;
if (opt_weight_zero_points.has_value()) {
zp = zero_points[index];
zero_points_ptr = zero_points + qparams_index;
}
CTYPE_OUT scale = scales[index];

const CTYPE_WEIGHT* w_data =
weight.data_ptr<CTYPE_WEIGHT>() + embedding_dim * index;

for (int j = 0; j < embedding_dim; ++j) {
int32_t group_id = j / group_size;
const CTYPE_OUT scale = scale_ptr[group_id];
if (opt_weight_zero_points.has_value()) {
zp = zero_points_ptr[group_id];
}
out_data[j] = static_cast<CTYPE_OUT>(
(static_cast<float>(w_data[j]) - static_cast<float>(zp)) *
static_cast<float>(scale));
Expand Down
204 changes: 204 additions & 0 deletions kernels/quantized/test/op_embedding_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,207 @@ TEST(OpQuantizedEmbeddingTest, ConsitencyWithReferencePattern) {
EXPECT_TENSOR_EQ(out, fp_out);
EXPECT_TENSOR_EQ(out, expected);
}

// Checks quantized_embedding_byte_out against hand-computed dequantized
// values, first with per-channel quantization (1D scales/zero-points: one
// qparam per embedding row) and then with groupwise quantization (2D
// scales/zero-points: one qparam per group of columns; group size 2 here).
// Dequantization is out[j] = (q[j] - zero_point) * scale per the kernel.
TEST(OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbedding) {
  et_pal_init();
  TensorFactory<ScalarType::Float> tf;
  TensorFactory<ScalarType::Long> tf_l;

  int64_t quant_min = 0;
  int64_t quant_max = 255;

  // Per-channel: one scale/zero-point per row of the (3, 4) weight.
  Tensor weight_scales = tf.make({3}, {0.5, 1.0, 1.5});
  Tensor weight_zero_points = tf.make({3}, {1, 5, 7});
  TensorFactory<ScalarType::Byte> tfo;
  Tensor qweight =
      tfo.make({3, 4}, {8, 10, 12, 14, 10, 12, 12, 14, 8, 9, 10, 12});

  Tensor indices = tf_l.make({3}, {0, 2, 1});

  Tensor out = tf.zeros({3, 4});
  // Rows are dequant(weight[index]) gathered in index order {0, 2, 1}.
  Tensor expected = tf.make(
      {3, 4}, {3.5, 4.5, 5.5, 6.5, 1.5, 3.0, 4.5, 7.5, 5.0, 7.0, 7.0, 9.0});

  quantized_embedding_byte_out(
      qweight,
      weight_scales,
      weight_zero_points,
      quant_min,
      quant_max,
      indices,
      out);

  EXPECT_TENSOR_EQ(out, expected);

  // Groupwise quantization. groupsize = 2: each row of the (3, 4) weight uses
  // two (scale, zero_point) pairs, one for columns [0, 1] and one for [2, 3].
  weight_scales = tf.make({3, 2}, {0.5, 1.0, 1.5, 2.0, 2.5, 3.0});
  weight_zero_points = tf.make({3, 2}, {1, 5, 7, 9, 11, 13});
  /*
  fp_weight = [3.5, 4.5, 7, 9,
               4.5, 7.5, 6, 10,
               -7.5, -5.0, -9.0, -3.0]
  */

  out = tf.zeros({3, 4});
  // fp_weight rows gathered in index order {0, 2, 1}.
  expected = tf.make(
      {3, 4}, {3.5, 4.5, 7, 9, -7.5, -5.0, -9.0, -3.0, 4.5, 7.5, 6, 10});

  quantized_embedding_byte_out(
      qweight,
      weight_scales,
      weight_zero_points,
      quant_min,
      quant_max,
      indices,
      out);

  EXPECT_TENSOR_EQ(out, expected);
}

// Death test: weight_scales has 4 entries but weight has only 3 rows, so the
// "number of scales must equal weight.size(0)" argument check should abort.
TEST(OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbeddingDeath1) {
  et_pal_init();
  TensorFactory<ScalarType::Float> tf;
  TensorFactory<ScalarType::Long> tf_l;

  int64_t quant_min = 0;
  int64_t quant_max = 255;

  // 4 scales/zero-points vs. a (3, 4) weight: one qparam row too many.
  Tensor weight_scales = tf.make({4}, {0.5, 1.0, 1.5, 3.3});
  Tensor weight_zero_points = tf.make({4}, {1, 5, 7, 5});
  TensorFactory<ScalarType::Byte> tfo;
  Tensor qweight =
      tfo.make({3, 4}, {8, 10, 12, 14, 10, 12, 12, 14, 8, 9, 10, 12});

  Tensor indices = tf_l.make({3}, {0, 2, 1});

  Tensor out = tf.zeros({3, 4});
  ET_EXPECT_DEATH(
      quantized_embedding_byte_out(
          qweight,
          weight_scales,
          weight_zero_points,
          quant_min,
          quant_max,
          indices,
          out),
      "");
}

// Death test: weight_scales has 2 entries but weight has 3 rows, so the
// "number of scales must equal weight.size(0)" argument check should abort.
TEST(OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbeddingDeath2) {
  et_pal_init();
  TensorFactory<ScalarType::Float> tf;
  TensorFactory<ScalarType::Long> tf_l;

  int64_t quant_min = 0;
  int64_t quant_max = 255;

  // 2 scales/zero-points vs. a (3, 4) weight: one qparam row too few.
  Tensor weight_scales = tf.make({2}, {0.5, 1.0});
  Tensor weight_zero_points = tf.make({2}, {1, 5});
  TensorFactory<ScalarType::Byte> tfo;
  Tensor qweight =
      tfo.make({3, 4}, {8, 10, 12, 14, 10, 12, 12, 14, 8, 9, 10, 12});

  Tensor indices = tf_l.make({3}, {0, 2, 1});

  Tensor out = tf.zeros({3, 4});
  ET_EXPECT_DEATH(
      quantized_embedding_byte_out(
          qweight,
          weight_scales,
          weight_zero_points,
          quant_min,
          quant_max,
          indices,
          out),
      "");
}

// Death test: groupwise scales declare 2 groups per row, but the weight's
// embedding_dim is 3, which 2 does not divide — the group-divisibility
// argument check should abort.
TEST(OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbeddingDeath3) {
  et_pal_init();
  TensorFactory<ScalarType::Float> tf;
  TensorFactory<ScalarType::Long> tf_l;

  int64_t quant_min = 0;
  int64_t quant_max = 255;

  // (3, 2) qparams: 2 groups per channel against weight.size(1) == 3.
  Tensor weight_scales = tf.make({3, 2}, {0.5, 1.0, 1.5, 2.5, 3.5, 3.5});
  Tensor weight_zero_points = tf.make({3, 2}, {1, 5, 7, 9, 11, 13});
  TensorFactory<ScalarType::Byte> tfo;
  Tensor qweight = tfo.make({3, 3}, {8, 10, 12, 14, 10, 12, 12, 14, 8});

  Tensor indices = tf_l.make({3}, {0, 2, 1});

  Tensor out = tf.zeros({3, 3});
  ET_EXPECT_DEATH(
      quantized_embedding_byte_out(
          qweight,
          weight_scales,
          weight_zero_points,
          quant_min,
          quant_max,
          indices,
          out),
      "");
}

// Death test: weight_zero_points is 1D while weight_scales is 2D, violating
// the rank-match check between the two qparam tensors. (The (3, 2) scales
// against a width-3 weight also fail the group-divisibility check, which
// runs first — either way the call must abort.)
TEST(OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbeddingDeath4) {
  et_pal_init();
  TensorFactory<ScalarType::Float> tf;
  TensorFactory<ScalarType::Long> tf_l;

  int64_t quant_min = 0;
  int64_t quant_max = 255;

  // Rank mismatch: 2D scales, 1D zero points.
  Tensor weight_scales = tf.make({3, 2}, {0.5, 1.0, 1.5, 2.5, 3.5, 3.5});
  Tensor weight_zero_points = tf.make({3}, {1, 5, 7});
  TensorFactory<ScalarType::Byte> tfo;
  Tensor qweight = tfo.make({3, 3}, {8, 10, 12, 14, 10, 12, 12, 14, 8});

  Tensor indices = tf_l.make({3}, {0, 2, 1});

  Tensor out = tf.zeros({3, 3});
  ET_EXPECT_DEATH(
      quantized_embedding_byte_out(
          qweight,
          weight_scales,
          weight_zero_points,
          quant_min,
          quant_max,
          indices,
          out),
      "");
}

// Death test: weight_zero_points has shape (3, 3) while weight_scales has
// shape (3, 2) — same rank, mismatched dim 1 — violating the per-dimension
// size-match check between the qparam tensors. (The (3, 2) scales against a
// width-3 weight also fail the group-divisibility check, which runs first —
// either way the call must abort.)
TEST(OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbeddingDeath5) {
  et_pal_init();
  TensorFactory<ScalarType::Float> tf;
  TensorFactory<ScalarType::Long> tf_l;

  int64_t quant_min = 0;
  int64_t quant_max = 255;

  // Same rank (2), but dim 1 differs: scales {3, 2} vs zero points {3, 3}.
  Tensor weight_scales = tf.make({3, 2}, {0.5, 1.0, 1.5, 2.5, 3.5, 3.5});
  Tensor weight_zero_points = tf.make({3, 3}, {1, 5, 7, 1, 5, 7, 1, 5, 7});
  TensorFactory<ScalarType::Byte> tfo;
  Tensor qweight = tfo.make({3, 3}, {8, 10, 12, 14, 10, 12, 12, 14, 8});

  Tensor indices = tf_l.make({3}, {0, 2, 1});

  Tensor out = tf.zeros({3, 3});
  ET_EXPECT_DEATH(
      quantized_embedding_byte_out(
          qweight,
          weight_scales,
          weight_zero_points,
          quant_min,
          quant_max,
          indices,
          out),
      "");
}