Skip to content

Commit

Permalink
[quant] Refactor qembeddingbag to remove duplicate code (#45881)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #45881

Test Plan:
python test/test_quantization.py TestQuantizedEmbeddingBagOps

Imported from OSS

Reviewed By: jerryzh168

Differential Revision: D24127892

fbshipit-source-id: 344ee71d335b8c1d668c647db88775632e099dbd
  • Loading branch information
supriyar authored and facebook-github-bot committed Oct 7, 2020
1 parent 43dc7ef commit 1b31ed3
Showing 1 changed file with 27 additions and 91 deletions.
118 changes: 27 additions & 91 deletions aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp
Expand Up @@ -196,9 +196,9 @@ at::Tensor embedding_bag_4bit_helper(
#endif
return output;
}
} // namespace

at::Tensor PackedEmbeddingBagWeight::embeddingbag_byte(
at::Tensor embedding_bag_byte_helper(
const at::Tensor& packed_w,
const at::Tensor& indices,
const c10::optional<at::Tensor>& offsets_in,
bool sparse,
Expand Down Expand Up @@ -296,6 +296,23 @@ at::Tensor PackedEmbeddingBagWeight::embeddingbag_byte(
return output;
}

} // namespace

// Byte embedding-bag entry point on PackedEmbeddingBagWeight.
// All of the work is delegated to embedding_bag_byte_helper; this method only
// supplies `packed_w` (not a parameter here -- presumably the packed weight
// held by this object; confirm against the class declaration) and forwards
// every caller argument unchanged, in the same order.
at::Tensor PackedEmbeddingBagWeight::embeddingbag_byte(
    const at::Tensor& indices,
    const c10::optional<at::Tensor>& offsets_in,
    bool sparse,
    const c10::optional<at::Tensor>& per_sample_weights_,
    bool include_last_offset) {
  at::Tensor result = embedding_bag_byte_helper(
      packed_w,
      indices,
      offsets_in,
      sparse,
      per_sample_weights_,
      include_last_offset);
  return result;
}

at::Tensor PackedEmbeddingBagWeight::embeddingbag_4bit(
const at::Tensor& indices,
const c10::optional<at::Tensor>& offsets_in,
Expand Down Expand Up @@ -323,99 +340,18 @@ Tensor embedding_bag_byte_rowwise_offsets(
const c10::optional<Tensor>& offsets_in,
const bool /* scale_grad_by_freq */,
const int64_t /* mode */,
bool /* sparse */,
bool sparse,
const c10::optional<Tensor>& per_sample_weights_,
bool include_last_offset) {
TORCH_CHECK(weight.scalar_type() == at::kByte);
TORCH_CHECK(weight.dim() == 2);
TORCH_CHECK(
offsets_in.has_value(),
"embedding_bag_byte_rowwise_offsets expects offsets to be set");

auto offsets = offsets_in.value();
auto offsets_data = offsets.data_ptr<int64_t>();
const auto weight_data = weight.data_ptr<uint8_t>();
const auto indices_data = indices.data_ptr<int64_t>();

const int64_t N = weight.size(0);
const int64_t D = weight.size(1) - 8; // NB: -8 to account for scale and bias
const int64_t M = offsets.size(0);

int64_t output_size = M - 1;
std::vector<int64_t> offsets_include_last;

if (!include_last_offset) {
output_size = M;
offsets_include_last.resize(M + 1);
std::memcpy(
offsets_include_last.data(),
offsets.data_ptr<int64_t>(),
sizeof(int64_t) * M);
offsets_include_last[M] = indices.numel();
offsets_data = offsets_include_last.data();
}

std::vector<int64_t> shape = {output_size, D};
auto output = at::empty(shape, weight.options().dtype(at::kFloat));
auto* output_data = output.data_ptr<float>();

#ifdef USE_FBGEMM

auto kernel_i8_i64 =
fbgemm::GenerateEmbeddingSpMDM<uint8_t, int64_t, int64_t>(
/*block_size=*/D,
/*has_weight=*/per_sample_weights_.has_value(),
/*normalize_by_lengths=*/false,
/*prefetch=*/16, // NOLINT(cppcoreguidelines-avoid-magic-numbers)
/*is_weight_positional=*/false,
/*use_offsets=*/true);

if (weight.is_contiguous()) {
at::parallel_for(
0, output_size, 1, [&](int64_t start_idx, int64_t end_idx) {
bool success = kernel_i8_i64(
/*output_size=*/end_idx - start_idx,
/*index_size=*/offsets_data[end_idx] - offsets_data[start_idx],
/*data_size=*/N,
/*input=*/weight_data,
/*indices=*/indices_data + offsets_data[start_idx],
/*offsets_or_lengths=*/offsets_data + start_idx,
/*weights=*/
per_sample_weights_
? per_sample_weights_.value().data_ptr<float>() +
offsets_data[start_idx]
: nullptr,
/*out=*/output_data + start_idx * D);

TORCH_CHECK(
success,
"FBGEMM GenerateEmbeddingSpMDM kernel failed for 8-bit input");
});
} else {
auto weight_contig = weight.contiguous();
const auto weight_data_contig = weight_contig.data_ptr<uint8_t>();
at::parallel_for(
0, output_size, 1, [&](int64_t start_idx, int64_t end_idx) {
bool success = kernel_i8_i64(
/*output_size=*/end_idx - start_idx,
/*index_size=*/offsets_data[end_idx] - offsets_data[start_idx],
/*data_size=*/N,
/*input=*/weight_data_contig,
/*indices=*/indices_data + offsets_data[start_idx],
/*offsets_or_lengths=*/offsets_data + start_idx,
/*weights=*/
per_sample_weights_
? per_sample_weights_.value().data_ptr<float>() +
offsets_data[start_idx]
: nullptr,
/*out=*/output_data + start_idx * D);
TORCH_CHECK(
success,
"FBGEMM GenerateEmbeddingSpMDM kernel failed for 8-bit input");
});
}
#endif
return output;
return embedding_bag_byte_helper(
weight,
indices,
offsets_in,
sparse,
per_sample_weights_,
include_last_offset);
}

Tensor embedding_bag_4bit_rowwise_offsets(
Expand Down

0 comments on commit 1b31ed3

Please sign in to comment.