fp16 -> fp32 EmbeddingBag moved into CPU impl (#47076)
Summary:
Pull Request resolved: #47076

Pull Request resolved: pytorch/glow#5038

Eliminate double casting in glow when submitting fp16 per-sample weights.
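
Why a single cast site in the CPU impl is enough: Tensor::to() is a no-op when the dtype already matches, so fp32 weights pass through untouched while fp16 weights are widened exactly once. A minimal standalone sketch of that behavior (illustrative only, not part of this commit):

#include <ATen/ATen.h>

int main() {
  at::Tensor half_w = at::rand({8}).to(at::kHalf); // fp16 per-sample weights
  at::Tensor float_w = at::rand({8});              // fp32 per-sample weights

  at::Tensor widened = half_w.to(at::kFloat); // one real fp16 -> fp32 cast
  at::Tensor same = float_w.to(at::kFloat);   // no-op: dtype already matches

  TORCH_INTERNAL_ASSERT(same.is_same(float_w)); // same tensor object, no copy
  TORCH_INTERNAL_ASSERT(widened.scalar_type() == at::kFloat);
  return 0;
}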

Test Plan:
buck test glow/glow/torch_glow/tests:embedding_bag_test

Due to dependency conflicts between glow and caffe2, the test has been reverted from this diff and landed separately.

Reviewed By: allwu

Differential Revision: D24421367

fbshipit-source-id: eb3615144a2cad3d593543428dfdec165ad301df
b-koopman authored and facebook-github-bot committed Nov 13, 2020
1 parent 6a4d55f commit 7b8bd91
Showing 1 changed file with 30 additions and 8 deletions:
aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp
@@ -512,12 +512,23 @@ at::Tensor PackedEmbeddingBagWeight::embeddingbag_4bit(
     const c10::optional<at::Tensor>& per_sample_weights_,
     const c10::optional<at::Tensor>& compressed_indices_mapping,
     bool include_last_offset) {
+  if (per_sample_weights_.has_value()) {
+    TORCH_CHECK(
+        (per_sample_weights_.value().scalar_type() == at::kFloat ||
+         per_sample_weights_.value().scalar_type() == at::kHalf),
+        "Expect fp32 or fp16 weights, but found",
+        per_sample_weights_.value().scalar_type(),
+        " instead")
+  }
+
   return embedding_bag_4bit_helper(
       packed_w.contiguous(),
       indices,
       offsets_in,
       pruned_weights,
-      per_sample_weights_,
+      per_sample_weights_.has_value()
+          ? per_sample_weights_.value().to(at::kFloat)
+          : per_sample_weights_,
       compressed_indices_mapping,
       include_last_offset);
 }
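
Caller-side effect of the hunk above: fp16 per-sample weights can now be handed to the packed-weight method directly, with no caller-side widening. A hedged sketch (packed_weight is assumed to be obtained elsewhere; the argument order follows the call site shown later in this diff):

// Hypothetical caller; packed_weight points to a packed 4-bit
// EmbeddingBag weight obtained elsewhere (not shown here).
at::Tensor indices = at::tensor({0, 1, 2, 3}, at::kLong);
c10::optional<at::Tensor> offsets = at::tensor({0, 2}, at::kLong);

// fp16 weights: previously the caller had to cast these to fp32 itself;
// after this change the CPU impl validates and widens them once.
c10::optional<at::Tensor> per_sample_weights =
    at::rand({4}).to(at::kHalf);

at::Tensor out = packed_weight->embeddingbag_4bit(
    indices,
    offsets,
    /*pruned_weights=*/false,
    per_sample_weights,
    /*compressed_indices_mapping=*/c10::nullopt,
    /*include_last_offset=*/false);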
@@ -557,12 +568,23 @@ Tensor embedding_bag_4bit_rowwise_offsets(
     const c10::optional<Tensor>& per_sample_weights_,
     const c10::optional<Tensor>& compressed_indices_mapping,
     bool include_last_offset) {
+  if (per_sample_weights_.has_value()) {
+    TORCH_CHECK(
+        (per_sample_weights_.value().scalar_type() == at::kFloat ||
+         per_sample_weights_.value().scalar_type() == at::kHalf),
+        "Expect fp32 or fp16 weights, but found",
+        per_sample_weights_.value().scalar_type(),
+        " instead")
+  }
+
   return embedding_bag_4bit_helper(
       weight.contiguous(),
       indices,
       offsets_in,
       pruned_weights,
-      per_sample_weights_,
+      per_sample_weights_.has_value()
+          ? per_sample_weights_.value().to(at::kFloat)
+          : per_sample_weights_,
       compressed_indices_mapping,
       include_last_offset);
 }
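
The validate-and-cast block is now duplicated verbatim in the two entry points above. A hypothetical follow-up refactor, not part of this commit, could hoist it into a file-local helper (name is illustrative; assumes the file's existing includes for at::Tensor, c10::optional, and TORCH_CHECK):

namespace {
// Hypothetical helper (not in the commit): shared by
// PackedEmbeddingBagWeight::embeddingbag_4bit and
// embedding_bag_4bit_rowwise_offsets to avoid the duplicated block.
c10::optional<at::Tensor> fp16_to_fp32_per_sample_weights(
    const c10::optional<at::Tensor>& per_sample_weights_) {
  if (!per_sample_weights_.has_value()) {
    return per_sample_weights_; // nothing to validate or cast
  }
  const auto st = per_sample_weights_.value().scalar_type();
  TORCH_CHECK(
      st == at::kFloat || st == at::kHalf,
      "Expect fp32 or fp16 weights, but found ", st, " instead");
  // Widens fp16 to fp32; returns the input tensor as-is if already fp32.
  return per_sample_weights_.value().to(at::kFloat);
}
} // namespace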
@@ -591,12 +613,12 @@ class QEmbeddingBag final {
           false /* is_embedding_op */);
     } else if (bit_rate == 4) {
       return packed_weight->embeddingbag_4bit(
-        indices,
-        offsets,
-        pruned_weights,
-        per_sample_weights_,
-        compressed_indices_mapping,
-        include_last_offset);
+          indices,
+          offsets,
+          pruned_weights,
+          per_sample_weights_,
+          compressed_indices_mapping,
+          include_last_offset);
     } else {
       TORCH_INTERNAL_ASSERT(
           "Currently only support 8-bit embedding_bag quantization");
