From b71eee97438c6b0be944b2e98b9777f6b5d46bbf Mon Sep 17 00:00:00 2001 From: Supriya Rao Date: Mon, 5 Oct 2020 17:44:56 -0700 Subject: [PATCH] [quant] Refactor qembeddingbag to remove duplicate code Summary: Test Plan: python test/test_quantization.py TestQuantizedEmbeddingBagOps Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] --- .../native/quantized/cpu/qembeddingbag.cpp | 118 ++++-------------- 1 file changed, 27 insertions(+), 91 deletions(-) diff --git a/aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp b/aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp index d6ee2292ec1e..7f4a234fa169 100644 --- a/aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp +++ b/aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp @@ -197,9 +197,9 @@ at::Tensor embedding_bag_4bit_helper( #endif return output; } -} // namespace -at::Tensor PackedEmbeddingBagWeight::embeddingbag_byte( +at::Tensor embedding_bag_byte_helper( + const at::Tensor& packed_w, const at::Tensor& indices, const c10::optional& offsets_in, bool sparse, @@ -297,6 +297,23 @@ at::Tensor PackedEmbeddingBagWeight::embeddingbag_byte( return output; } +} // namespace + +at::Tensor PackedEmbeddingBagWeight::embeddingbag_byte( + const at::Tensor& indices, + const c10::optional& offsets_in, + bool sparse, + const c10::optional& per_sample_weights_, + bool include_last_offset) { + return embedding_bag_byte_helper( + packed_w, + indices, + offsets_in, + sparse, + per_sample_weights_, + include_last_offset); +} + at::Tensor PackedEmbeddingBagWeight::embeddingbag_4bit( const at::Tensor& indices, const c10::optional& offsets_in, @@ -324,99 +341,18 @@ Tensor embedding_bag_byte_rowwise_offsets( const c10::optional& offsets_in, const bool /* scale_grad_by_freq */, const int64_t /* mode */, - bool /* sparse */, + bool sparse, const c10::optional& per_sample_weights_, bool include_last_offset) { TORCH_CHECK(weight.scalar_type() == at::kByte); TORCH_CHECK(weight.ndimension() == 2); - TORCH_CHECK( - offsets_in.has_value(), - "embedding_bag_byte_rowwise_offsets expects offsets to be set"); - - auto offsets = offsets_in.value(); - auto offsets_data = offsets.data_ptr(); - const auto weight_data = weight.data_ptr(); - const auto indices_data = indices.data_ptr(); - - const int64_t N = weight.size(0); - const int64_t D = weight.size(1) - 8; // NB: -8 to account for scale and bias - const int64_t M = offsets.size(0); - - int64_t output_size = M - 1; - std::vector offsets_include_last; - - if (!include_last_offset) { - output_size = M; - offsets_include_last.resize(M + 1); - std::memcpy( - offsets_include_last.data(), - offsets.data_ptr(), - sizeof(int64_t) * M); - offsets_include_last[M] = indices.numel(); - offsets_data = offsets_include_last.data(); - } - - std::vector shape = {output_size, D}; - auto output = at::empty(shape, weight.options().dtype(at::kFloat)); - auto* output_data = output.data_ptr(); - -#ifdef USE_FBGEMM - - auto kernel_i8_i64 = - fbgemm::GenerateEmbeddingSpMDM( - /*block_size=*/D, - /*has_weight=*/per_sample_weights_.has_value(), - /*normalize_by_lengths=*/false, - /*prefetch=*/16, // NOLINT(cppcoreguidelines-avoid-magic-numbers) - /*is_weight_positional=*/false, - /*use_offsets=*/true); - - if (weight.is_contiguous()) { - at::parallel_for( - 0, output_size, 1, [&](int64_t start_idx, int64_t end_idx) { - bool success = kernel_i8_i64( - /*output_size=*/end_idx - start_idx, - /*index_size=*/offsets_data[end_idx] - offsets_data[start_idx], - /*data_size=*/N, - /*input=*/weight_data, - /*indices=*/indices_data + offsets_data[start_idx], - /*offsets_or_lengths=*/offsets_data + start_idx, - /*weights=*/ - per_sample_weights_ - ? per_sample_weights_.value().data_ptr() + - offsets_data[start_idx] - : nullptr, - /*out=*/output_data + start_idx * D); - - TORCH_CHECK( - success, - "FBGEMM GenerateEmbeddingSpMDM kernel failed for 8-bit input"); - }); - } else { - auto weight_contig = weight.contiguous(); - const auto weight_data_contig = weight_contig.data_ptr(); - at::parallel_for( - 0, output_size, 1, [&](int64_t start_idx, int64_t end_idx) { - bool success = kernel_i8_i64( - /*output_size=*/end_idx - start_idx, - /*index_size=*/offsets_data[end_idx] - offsets_data[start_idx], - /*data_size=*/N, - /*input=*/weight_data_contig, - /*indices=*/indices_data + offsets_data[start_idx], - /*offsets_or_lengths=*/offsets_data + start_idx, - /*weights=*/ - per_sample_weights_ - ? per_sample_weights_.value().data_ptr() + - offsets_data[start_idx] - : nullptr, - /*out=*/output_data + start_idx * D); - TORCH_CHECK( - success, - "FBGEMM GenerateEmbeddingSpMDM kernel failed for 8-bit input"); - }); - } -#endif - return output; + return embedding_bag_byte_helper( + weight, + indices, + offsets_in, + sparse, + per_sample_weights_, + include_last_offset); } Tensor embedding_bag_4bit_rowwise_offsets(