Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fp16 -> fp32 EmbeddingBag moved into CPU impl #47076

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
38 changes: 30 additions & 8 deletions aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp
Expand Up @@ -417,12 +417,23 @@ at::Tensor PackedEmbeddingBagWeight::embeddingbag_4bit(
const c10::optional<at::Tensor>& per_sample_weights_,
const c10::optional<at::Tensor>& compressed_indices_mapping,
bool include_last_offset) {
if (per_sample_weights_.has_value()) {
TORCH_CHECK(
(per_sample_weights_.value().scalar_type() == at::kFloat ||
per_sample_weights_.value().scalar_type() == at::kHalf),
"Expect fp32 or fp16 weights, but found",
per_sample_weights_.value().scalar_type(),
" instead")
}

return embedding_bag_4bit_helper(
packed_w.contiguous(),
indices,
offsets_in,
pruned_weights,
per_sample_weights_,
per_sample_weights_.has_value()
? per_sample_weights_.value().to(at::kFloat)
: per_sample_weights_,
compressed_indices_mapping,
include_last_offset);
}
Expand Down Expand Up @@ -459,12 +470,23 @@ Tensor embedding_bag_4bit_rowwise_offsets(
const c10::optional<Tensor>& per_sample_weights_,
const c10::optional<Tensor>& compressed_indices_mapping,
bool include_last_offset) {
if (per_sample_weights_.has_value()) {
TORCH_CHECK(
(per_sample_weights_.value().scalar_type() == at::kFloat ||
per_sample_weights_.value().scalar_type() == at::kHalf),
"Expect fp32 or fp16 weights, but found",
per_sample_weights_.value().scalar_type(),
" instead")
}

return embedding_bag_4bit_helper(
weight.contiguous(),
indices,
offsets_in,
pruned_weights,
per_sample_weights_,
per_sample_weights_.has_value()
? per_sample_weights_.value().to(at::kFloat)
: per_sample_weights_,
compressed_indices_mapping,
include_last_offset);
}
Expand All @@ -491,12 +513,12 @@ class QEmbeddingBag final {
include_last_offset);
} else if (bit_rate == 4) {
return packed_weight->embeddingbag_4bit(
indices,
offsets,
pruned_weights,
per_sample_weights_,
compressed_indices_mapping,
include_last_offset);
indices,
offsets,
pruned_weights,
per_sample_weights_,
compressed_indices_mapping,
include_last_offset);
} else {
TORCH_INTERNAL_ASSERT(
"Currently only support 8-bit embedding_bag quantization");
Expand Down