[quant] Refactor qembeddingbag to remove duplicate code #45881

Closed · wants to merge 4 commits
Changes from 2 commits
118 changes: 27 additions & 91 deletions aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp
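The gist of the diff: the 8-bit embedding-bag lookup previously existed twice, once inside `PackedEmbeddingBagWeight::embeddingbag_byte` and once inside the `embedding_bag_byte_rowwise_offsets` operator. The hunks below move it into a single free helper, `embedding_bag_byte_helper`, which takes the weight tensor as an explicit first argument; both former implementations become thin wrappers around it. A minimal, self-contained sketch of that delegation pattern (a stand-in `Tensor` type instead of ATen, parameter lists trimmed to what matters for the shape of the change):

#include <cstdint>

// Stand-in for at::Tensor -- just enough structure to show the call graph.
struct Tensor { std::int64_t rows = 0; };

namespace {
// Single shared implementation. The weight is an explicit first parameter,
// so the helper no longer depends on the packed-weight class instance.
Tensor embedding_bag_byte_helper(const Tensor& packed_w,
                                 const Tensor& indices,
                                 bool include_last_offset) {
  Tensor out;
  out.rows = indices.rows; // placeholder for the real gather-and-reduce
  (void)packed_w;
  (void)include_last_offset;
  return out;
}
} // namespace

struct PackedEmbeddingBagWeight {
  Tensor packed_w;
  // Class method: a thin wrapper that forwards the member weight.
  Tensor embeddingbag_byte(const Tensor& indices, bool include_last_offset) {
    return embedding_bag_byte_helper(packed_w, indices, include_last_offset);
  }
};

// Free operator: a thin wrapper that forwards the caller-supplied weight.
Tensor embedding_bag_byte_rowwise_offsets(const Tensor& weight,
                                          const Tensor& indices,
                                          bool include_last_offset) {
  return embedding_bag_byte_helper(weight, indices, include_last_offset);
}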
@@ -197,9 +197,9 @@ at::Tensor embedding_bag_4bit_helper(
#endif
return output;
}
-} // namespace

-at::Tensor PackedEmbeddingBagWeight::embeddingbag_byte(
+at::Tensor embedding_bag_byte_helper(
+    const at::Tensor& packed_w,
const at::Tensor& indices,
const c10::optional<at::Tensor>& offsets_in,
bool sparse,
@@ -297,6 +297,23 @@ at::Tensor PackedEmbeddingBagWeight::embeddingbag_byte(
return output;
}

+} // namespace
+
+at::Tensor PackedEmbeddingBagWeight::embeddingbag_byte(
+    const at::Tensor& indices,
+    const c10::optional<at::Tensor>& offsets_in,
+    bool sparse,
+    const c10::optional<at::Tensor>& per_sample_weights_,
+    bool include_last_offset) {
+  return embedding_bag_byte_helper(
+      packed_w,
+      indices,
+      offsets_in,
+      sparse,
+      per_sample_weights_,
+      include_last_offset);
+}
+
at::Tensor PackedEmbeddingBagWeight::embeddingbag_4bit(
const at::Tensor& indices,
const c10::optional<at::Tensor>& offsets_in,
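The hunk above cuts off inside the signature of `PackedEmbeddingBagWeight::embeddingbag_4bit`; its body is collapsed in this view. Given that the first hunk shows an `embedding_bag_4bit_helper` already living in the anonymous namespace, the 4-bit method presumably delegates the same way the byte method now does. A hypothetical sketch, assuming its parameter list mirrors the byte path (the collapsed lines may differ):

at::Tensor PackedEmbeddingBagWeight::embeddingbag_4bit(
    const at::Tensor& indices,
    const c10::optional<at::Tensor>& offsets_in,
    bool sparse,
    const c10::optional<at::Tensor>& per_sample_weights_,
    bool include_last_offset) {
  // Hypothetical: forward the member packed_w to the shared 4-bit helper,
  // mirroring embeddingbag_byte above. The actual parameters are hidden
  // behind the collapsed portion of this diff.
  return embedding_bag_4bit_helper(
      packed_w,
      indices,
      offsets_in,
      sparse,
      per_sample_weights_,
      include_last_offset);
}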
@@ -324,99 +341,18 @@ Tensor embedding_bag_byte_rowwise_offsets(
const c10::optional<Tensor>& offsets_in,
const bool /* scale_grad_by_freq */,
const int64_t /* mode */,
-    bool /* sparse */,
+    bool sparse,
const c10::optional<Tensor>& per_sample_weights_,
bool include_last_offset) {
-  TORCH_CHECK(weight.scalar_type() == at::kByte);
-  TORCH_CHECK(weight.ndimension() == 2);
-  TORCH_CHECK(
-      offsets_in.has_value(),
-      "embedding_bag_byte_rowwise_offsets expects offsets to be set");
-
-  auto offsets = offsets_in.value();
-  auto offsets_data = offsets.data_ptr<int64_t>();
-  const auto weight_data = weight.data_ptr<uint8_t>();
-  const auto indices_data = indices.data_ptr<int64_t>();
-
-  const int64_t N = weight.size(0);
-  const int64_t D = weight.size(1) - 8; // NB: -8 to account for scale and bias
-  const int64_t M = offsets.size(0);
-
-  int64_t output_size = M - 1;
-  std::vector<int64_t> offsets_include_last;
-
-  if (!include_last_offset) {
-    output_size = M;
-    offsets_include_last.resize(M + 1);
-    std::memcpy(
-        offsets_include_last.data(),
-        offsets.data_ptr<int64_t>(),
-        sizeof(int64_t) * M);
-    offsets_include_last[M] = indices.numel();
-    offsets_data = offsets_include_last.data();
-  }
-
-  std::vector<int64_t> shape = {output_size, D};
-  auto output = at::empty(shape, weight.options().dtype(at::kFloat));
-  auto* output_data = output.data_ptr<float>();
-
-#ifdef USE_FBGEMM
-
-  auto kernel_i8_i64 =
-      fbgemm::GenerateEmbeddingSpMDM<uint8_t, int64_t, int64_t>(
-          /*block_size=*/D,
-          /*has_weight=*/per_sample_weights_.has_value(),
-          /*normalize_by_lengths=*/false,
-          /*prefetch=*/16, // NOLINT(cppcoreguidelines-avoid-magic-numbers)
-          /*is_weight_positional=*/false,
-          /*use_offsets=*/true);
-
-  if (weight.is_contiguous()) {
-    at::parallel_for(
-        0, output_size, 1, [&](int64_t start_idx, int64_t end_idx) {
-          bool success = kernel_i8_i64(
-              /*output_size=*/end_idx - start_idx,
-              /*index_size=*/offsets_data[end_idx] - offsets_data[start_idx],
-              /*data_size=*/N,
-              /*input=*/weight_data,
-              /*indices=*/indices_data + offsets_data[start_idx],
-              /*offsets_or_lengths=*/offsets_data + start_idx,
-              /*weights=*/
-              per_sample_weights_
-                  ? per_sample_weights_.value().data_ptr<float>() +
-                      offsets_data[start_idx]
-                  : nullptr,
-              /*out=*/output_data + start_idx * D);
-
-          TORCH_CHECK(
-              success,
-              "FBGEMM GenerateEmbeddingSpMDM kernel failed for 8-bit input");
-        });
-  } else {
-    auto weight_contig = weight.contiguous();
-    const auto weight_data_contig = weight_contig.data_ptr<uint8_t>();
-    at::parallel_for(
-        0, output_size, 1, [&](int64_t start_idx, int64_t end_idx) {
-          bool success = kernel_i8_i64(
-              /*output_size=*/end_idx - start_idx,
-              /*index_size=*/offsets_data[end_idx] - offsets_data[start_idx],
-              /*data_size=*/N,
-              /*input=*/weight_data_contig,
-              /*indices=*/indices_data + offsets_data[start_idx],
-              /*offsets_or_lengths=*/offsets_data + start_idx,
-              /*weights=*/
-              per_sample_weights_
-                  ? per_sample_weights_.value().data_ptr<float>() +
-                      offsets_data[start_idx]
-                  : nullptr,
-              /*out=*/output_data + start_idx * D);
-          TORCH_CHECK(
-              success,
-              "FBGEMM GenerateEmbeddingSpMDM kernel failed for 8-bit input");
-        });
-  }
-#endif
-  return output;
+  return embedding_bag_byte_helper(
+      weight,
+      indices,
+      offsets_in,
+      sparse,
+      per_sample_weights_,
+      include_last_offset);
}

Tensor embedding_bag_4bit_rowwise_offsets(
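Two things are worth noting about the replaced body above. First, the operator previously ignored its `sparse` argument (`bool /* sparse */`); it is now named and forwarded to the shared helper. Second, the deleted code carried the `include_last_offset` bookkeeping that the helper must preserve: when the caller does not supply a terminating offset, the M offsets are copied into a buffer of M + 1 entries whose last element is `indices.numel()`, and FBGEMM's `GenerateEmbeddingSpMDM` kernel then runs over bag ranges in parallel, with `offsets_data[start_idx]` locating each chunk's slice of indices and per-sample weights. A self-contained sketch of just the offsets-extension step, with names mirroring the deleted code:

#include <cstdint>
#include <cstring>
#include <vector>

// When include_last_offset is false, the kernel still needs M bag-start
// offsets plus one terminating offset equal to the total number of indices,
// so the caller's offsets are copied into a vector one element longer.
std::vector<std::int64_t> extend_offsets(
    const std::int64_t* offsets_data,
    std::int64_t M,            // number of offsets supplied by the caller
    std::int64_t num_indices,  // indices.numel() in the original code
    bool include_last_offset) {
  if (include_last_offset) {
    // The caller already appended the terminating offset; use it as-is.
    return std::vector<std::int64_t>(offsets_data, offsets_data + M);
  }
  std::vector<std::int64_t> offsets_include_last(M + 1);
  std::memcpy(
      offsets_include_last.data(), offsets_data, sizeof(std::int64_t) * M);
  offsets_include_last[M] = num_indices; // terminate the last bag
  return offsets_include_last;
}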