Review notes (text ahead of the first file header is ignored by git apply
and patch):
  * This patch was rebuilt from a copy whose line breaks and template
    arguments had been stripped. Confirm the context lines against the tree
    before applying; the blob ids on the "index" lines are those of the
    original patch and do not reflect the edits listed below.
  * check_embedding_byte_args: the scale-count message printed its two values
    in swapped order; the always-true ">=" guard is gone; 0-D scales and a
    zero group count are rejected explicitly; message typos and the dim
    format specifier are fixed.
  * embedding_byte_per_channel: qparams_index is int64_t so the product
    index * num_groups_per_channel is not narrowed to 32 bits.
  * Death4/Death5 tests use embedding_dim = 4 so that they reach the
    zero-point rank/size checks instead of dying on group divisibility.

diff --git a/kernels/quantized/cpu/op_embedding.cpp b/kernels/quantized/cpu/op_embedding.cpp
index 7d33eb42c72..2964ecbab57 100644
--- a/kernels/quantized/cpu/op_embedding.cpp
+++ b/kernels/quantized/cpu/op_embedding.cpp
@@ -32,6 +32,41 @@ void check_embedding_byte_args(
     const int64_t weight_quant_max,
     const Tensor& indices,
     Tensor& out) {
+  ET_CHECK_MSG(
+      weight.dim() == 2, "weight must be 2D but got %zd dims", weight.dim());
+
+  // A 0-D scales tensor is rejected here: size(0) is read right below.
+  ET_CHECK_MSG(
+      weight_scales.dim() == 1 || weight_scales.dim() == 2,
+      "weight_scales must be 1D or 2D but got %zd dims",
+      weight_scales.dim());
+
+  auto weight_scales_size = weight_scales.size(0);
+
+  ET_CHECK_MSG(
+      weight_scales_size == weight.size(0),
+      "Number of scales must be == weight.size(0)=%zd"
+      ", but got %zd",
+      weight.size(0),
+      weight_scales_size);
+
+  // 2D scales mean groupwise quantization: one scale per group of columns.
+  if (weight_scales.dim() == 2) {
+    auto num_groups = weight_scales.size(1);
+    // Guard the modulo below against a zero divisor.
+    ET_CHECK_MSG(
+        num_groups > 0,
+        "Number of groups must be > 0, but got %zd",
+        num_groups);
+    auto remainder = weight.size(1) % num_groups;
+    ET_CHECK_MSG(
+        remainder == 0,
+        "Number of groups must divide weight.size(1)=%zd"
+        ", but got # of groups = %zd",
+        weight.size(1),
+        num_groups);
+  }
+
   ET_CHECK_MSG(
       weight.scalar_type() == ScalarType::Byte ||
           weight.scalar_type() == ScalarType::Char,
@@ -50,11 +85,29 @@ void check_embedding_byte_args(
       static_cast<int8_t>(weight_scales.scalar_type()));
 
   if (opt_weight_zero_points.has_value()) {
+    ET_CHECK_MSG(
+        opt_weight_zero_points.value().dim() == weight_scales.dim(),
+        "weight_zero_points's rank must match that of weight_scales. "
+        "weight_zero_points rank: %" PRId8 ", weight_scales rank: %" PRId8,
+        static_cast<int8_t>(opt_weight_zero_points.value().dim()),
+        static_cast<int8_t>(weight_scales.dim()));
+
     ET_CHECK_MSG(
         opt_weight_zero_points.value().scalar_type() == out.scalar_type(),
         "weight zero points scalar type %" PRId8
         " does not match out.scalar_type()",
         static_cast<int8_t>(opt_weight_zero_points.value().scalar_type()));
+
+    for (int32_t i = 0; i < weight_scales.dim(); ++i) {
+      ET_CHECK_MSG(
+          opt_weight_zero_points.value().size(i) == weight_scales.size(i),
+          "Dimension size mismatch at dim %" PRId32
+          ". weight_zero_points size = %zd"
+          ", weight_scales size = %zd.",
+          i,
+          opt_weight_zero_points.value().size(i),
+          weight_scales.size(i));
+    }
   }
 
   ET_CHECK_MSG(
@@ -81,10 +134,16 @@ void embedding_byte_per_channel(
     const optional<Tensor>& opt_weight_zero_points,
     const Tensor& indices,
     Tensor& out) {
-  // An embedding layer nn.Embedding(num_embeddings, embedding_dim) has a weight
-  // of shape (num_embeddings, embedding_dim).
+  // An embedding layer nn.Embedding(num_embeddings, embedding_dim) has a
+  // weight of shape (num_embeddings, embedding_dim).
   auto embedding_dim = weight.size(1);
 
+  int32_t num_groups_per_channel = 1;
+  if (weight_scales.dim() == 2) {
+    num_groups_per_channel = weight_scales.size(1);
+  }
+  int32_t group_size = weight.size(1) / num_groups_per_channel;
+
   CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
   const int64_t* indices_ptr = indices.const_data_ptr<int64_t>();
 
@@ -96,16 +155,24 @@ void embedding_byte_per_channel(
 
   for (int i = 0; i < indices.numel(); i++) {
     int64_t index = indices_ptr[i];
+    // Offset of this row's first scale/zero point (one entry per group).
+    int64_t qparams_index = index * num_groups_per_channel;
     CTYPE_OUT zp = 0.0;
+    const CTYPE_OUT* scale_ptr = scales + qparams_index;
+    const CTYPE_OUT* zero_points_ptr = nullptr;
     if (opt_weight_zero_points.has_value()) {
-      zp = zero_points[index];
+      zero_points_ptr = zero_points + qparams_index;
     }
-    CTYPE_OUT scale = scales[index];
 
     const CTYPE_WEIGHT* w_data =
         weight.data_ptr<CTYPE_WEIGHT>() + embedding_dim * index;
 
     for (int j = 0; j < embedding_dim; ++j) {
+      int32_t group_id = j / group_size;
+      const CTYPE_OUT scale = scale_ptr[group_id];
+      if (opt_weight_zero_points.has_value()) {
+        zp = zero_points_ptr[group_id];
+      }
       out_data[j] = static_cast<CTYPE_OUT>(
           (static_cast<float>(w_data[j]) - static_cast<float>(zp)) *
           static_cast<float>(scale));
diff --git a/kernels/quantized/test/op_embedding_test.cpp b/kernels/quantized/test/op_embedding_test.cpp
index 085ca7e29c3..49605977cc3 100644
--- a/kernels/quantized/test/op_embedding_test.cpp
+++ b/kernels/quantized/test/op_embedding_test.cpp
@@ -167,3 +167,218 @@ TEST(OpQuantizedEmbeddingTest, ConsitencyWithReferencePattern) {
   EXPECT_TENSOR_EQ(out, fp_out);
   EXPECT_TENSOR_EQ(out, expected);
 }
+
+TEST(OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbedding) {
+  et_pal_init();
+  TensorFactory<ScalarType::Float> tf;
+  TensorFactory<ScalarType::Int> tf_i;
+  TensorFactory<ScalarType::Long> tf_l;
+
+  int64_t quant_min = 0;
+  int64_t quant_max = 255;
+
+  Tensor weight_scales = tf.make({3}, {0.5, 1.0, 1.5});
+  Tensor weight_zero_points = tf.make({3}, {1, 5, 7});
+  TensorFactory<ScalarType::Byte> tfo;
+  Tensor qweight =
+      tfo.make({3, 4}, {8, 10, 12, 14, 10, 12, 12, 14, 8, 9, 10, 12});
+
+  Tensor indices = tf_l.make({3}, {0, 2, 1});
+
+  Tensor out = tf.zeros({3, 4});
+  Tensor expected = tf.make(
+      {3, 4}, {3.5, 4.5, 5.5, 6.5, 1.5, 3.0, 4.5, 7.5, 5.0, 7.0, 7.0, 9.0});
+
+  quantized_embedding_byte_out(
+      qweight,
+      weight_scales,
+      weight_zero_points,
+      quant_min,
+      quant_max,
+      indices,
+      out);
+
+  EXPECT_TENSOR_EQ(out, expected);
+
+  // Groupwise quantization. groupsize = 2
+  weight_scales = tf.make({3, 2}, {0.5, 1.0, 1.5, 2.0, 2.5, 3.0});
+  weight_zero_points = tf.make({3, 2}, {1, 5, 7, 9, 11, 13});
+  /*
+  fp_weight = [3.5, 4.5, 7, 9,
+            4.5, 7.5, 6, 10,
+            -7.5, -5.0, -9.0, -3.0]
+  */
+
+  out = tf.zeros({3, 4});
+  expected = tf.make(
+      {3, 4}, {3.5, 4.5, 7, 9, -7.5, -5.0, -9.0, -3.0, 4.5, 7.5, 6, 10});
+
+  quantized_embedding_byte_out(
+      qweight,
+      weight_scales,
+      weight_zero_points,
+      quant_min,
+      quant_max,
+      indices,
+      out);
+
+  EXPECT_TENSOR_EQ(out, expected);
+}
+
+// Dies: more scales (4) than embedding rows (3).
+TEST(OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbeddingDeath1) {
+  et_pal_init();
+  TensorFactory<ScalarType::Float> tf;
+  TensorFactory<ScalarType::Int> tf_i;
+  TensorFactory<ScalarType::Long> tf_l;
+
+  int64_t quant_min = 0;
+  int64_t quant_max = 255;
+
+  Tensor weight_scales = tf.make({4}, {0.5, 1.0, 1.5, 3.3});
+  Tensor weight_zero_points = tf.make({4}, {1, 5, 7, 5});
+  TensorFactory<ScalarType::Byte> tfo;
+  Tensor qweight =
+      tfo.make({3, 4}, {8, 10, 12, 14, 10, 12, 12, 14, 8, 9, 10, 12});
+
+  Tensor indices = tf_l.make({3}, {0, 2, 1});
+
+  Tensor out = tf.zeros({3, 4});
+  ET_EXPECT_DEATH(
+      quantized_embedding_byte_out(
+          qweight,
+          weight_scales,
+          weight_zero_points,
+          quant_min,
+          quant_max,
+          indices,
+          out),
+      "");
+}
+
+// Dies: fewer scales (2) than embedding rows (3).
+TEST(OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbeddingDeath2) {
+  et_pal_init();
+  TensorFactory<ScalarType::Float> tf;
+  TensorFactory<ScalarType::Int> tf_i;
+  TensorFactory<ScalarType::Long> tf_l;
+
+  int64_t quant_min = 0;
+  int64_t quant_max = 255;
+
+  Tensor weight_scales = tf.make({2}, {0.5, 1.0});
+  Tensor weight_zero_points = tf.make({2}, {1, 5});
+  TensorFactory<ScalarType::Byte> tfo;
+  Tensor qweight =
+      tfo.make({3, 4}, {8, 10, 12, 14, 10, 12, 12, 14, 8, 9, 10, 12});
+
+  Tensor indices = tf_l.make({3}, {0, 2, 1});
+
+  Tensor out = tf.zeros({3, 4});
+  ET_EXPECT_DEATH(
+      quantized_embedding_byte_out(
+          qweight,
+          weight_scales,
+          weight_zero_points,
+          quant_min,
+          quant_max,
+          indices,
+          out),
+      "");
+}
+
+// Dies: 2 groups per row do not evenly divide embedding_dim = 3.
+TEST(OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbeddingDeath3) {
+  et_pal_init();
+  TensorFactory<ScalarType::Float> tf;
+  TensorFactory<ScalarType::Int> tf_i;
+  TensorFactory<ScalarType::Long> tf_l;
+
+  int64_t quant_min = 0;
+  int64_t quant_max = 255;
+
+  Tensor weight_scales = tf.make({3, 2}, {0.5, 1.0, 1.5, 2.5, 3.5, 3.5});
+  Tensor weight_zero_points = tf.make({3, 2}, {1, 5, 7, 9, 11, 13});
+  TensorFactory<ScalarType::Byte> tfo;
+  Tensor qweight = tfo.make({3, 3}, {8, 10, 12, 14, 10, 12, 12, 14, 8});
+
+  Tensor indices = tf_l.make({3}, {0, 2, 1});
+
+  Tensor out = tf.zeros({3, 3});
+  ET_EXPECT_DEATH(
+      quantized_embedding_byte_out(
+          qweight,
+          weight_scales,
+          weight_zero_points,
+          quant_min,
+          quant_max,
+          indices,
+          out),
+      "");
+}
+
+// Dies: zero points are 1D while scales are 2D (rank mismatch).
+TEST(OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbeddingDeath4) {
+  et_pal_init();
+  TensorFactory<ScalarType::Float> tf;
+  TensorFactory<ScalarType::Int> tf_i;
+  TensorFactory<ScalarType::Long> tf_l;
+
+  int64_t quant_min = 0;
+  int64_t quant_max = 255;
+
+  Tensor weight_scales = tf.make({3, 2}, {0.5, 1.0, 1.5, 2.5, 3.5, 3.5});
+  Tensor weight_zero_points = tf.make({3}, {1, 5, 7});
+  TensorFactory<ScalarType::Byte> tfo;
+  // embedding_dim = 4 is divisible by the 2 groups, so the weight/scales shape
+  // checks pass and the zero-point rank check is the one that fires.
+  Tensor qweight =
+      tfo.make({3, 4}, {8, 10, 12, 14, 10, 12, 12, 14, 8, 9, 10, 12});
+
+  Tensor indices = tf_l.make({3}, {0, 2, 1});
+
+  Tensor out = tf.zeros({3, 4});
+  ET_EXPECT_DEATH(
+      quantized_embedding_byte_out(
+          qweight,
+          weight_scales,
+          weight_zero_points,
+          quant_min,
+          quant_max,
+          indices,
+          out),
+      "");
+}
+
+// Dies: zero points are {3, 3} while scales are {3, 2} (size mismatch).
+TEST(OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbeddingDeath5) {
+  et_pal_init();
+  TensorFactory<ScalarType::Float> tf;
+  TensorFactory<ScalarType::Int> tf_i;
+  TensorFactory<ScalarType::Long> tf_l;
+
+  int64_t quant_min = 0;
+  int64_t quant_max = 255;
+
+  Tensor weight_scales = tf.make({3, 2}, {0.5, 1.0, 1.5, 2.5, 3.5, 3.5});
+  Tensor weight_zero_points = tf.make({3, 3}, {1, 5, 7, 1, 5, 7, 1, 5, 7});
+  TensorFactory<ScalarType::Byte> tfo;
+  // embedding_dim = 4 is divisible by the 2 groups, so the weight/scales shape
+  // checks pass and the zero-point size check is the one that fires.
+  Tensor qweight =
+      tfo.make({3, 4}, {8, 10, 12, 14, 10, 12, 12, 14, 8, 9, 10, 12});
+
+  Tensor indices = tf_l.make({3}, {0, 2, 1});
+
+  Tensor out = tf.zeros({3, 4});
+  ET_EXPECT_DEATH(
+      quantized_embedding_byte_out(
+          qweight,
+          weight_scales,
+          weight_zero_points,
+          quant_min,
+          quant_max,
+          indices,
+          out),
+      "");
+}