[aten] embedding_bag_byte_rowwise_offsets_out #49561

Closed · wants to merge 1 commit
2 changes: 2 additions & 0 deletions aten/src/ATen/core/ivalue.h
@@ -755,6 +755,8 @@ struct TORCH_API IValue final {
// None
template <typename T>
optional<T> toOptional();
template <typename T>
optional<T> toOptional() const;

/// @private [doxygen private]
/// this is a shallow comparison of two IValues to test the object identity
8 changes: 8 additions & 0 deletions aten/src/ATen/core/ivalue_inl.h
@@ -1257,6 +1257,14 @@ inline optional<T> IValue::toOptional() {
return this->to<T>();
}

template <typename T>
inline optional<T> IValue::toOptional() const {
if (this->isNone()) {
return nullopt;
}
return this->to<T>();
}

inline bool IValue::isCustomClass() const {
return torch::isCustomClass(*this);
}
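For context, a minimal sketch (not part of the diff) of what the new const overload enables: code that only holds a const reference to an IValue — such as the static-runtime kernel added below, which reads its inputs via ProcessedNode::Input() — can extract an optional value without a const_cast. The helper name here is illustrative.

#include <ATen/ATen.h>
#include <ATen/core/ivalue.h>

// Illustrative only, not part of the PR: pull an optional Tensor out of a
// const IValue. Before this change, only the non-const toOptional() existed.
c10::optional<at::Tensor> read_optional_tensor(const c10::IValue& iv) {
  // Resolves to the new const-qualified overload of IValue::toOptional.
  return iv.toOptional<at::Tensor>();
}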
47 changes: 42 additions & 5 deletions aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp
@@ -1,6 +1,7 @@
#include <ATen/ATen.h>
#include <ATen/native/quantized/cpu/embedding_packed_params.h>
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
#include <ATen/native/quantized/cpu/qembeddingbag.h>
#include <torch/library.h>
#ifdef USE_FBGEMM
#include <fbgemm/Fbgemm.h>
@@ -199,7 +200,8 @@ at::Tensor embedding_bag_4bit_impl(
}

template <typename IndexType, typename OffsetType>
at::Tensor embedding_bag_byte_impl(
at::Tensor& embedding_bag_byte_impl(
at::Tensor& output,
const at::Tensor& weight,
const at::Tensor& indices,
const at::Tensor& offsets,
@@ -261,7 +263,7 @@ at::Tensor embedding_bag_byte_impl(
} else {
shape = {output_size, D};
}
auto output = at::empty(shape, weight.options().dtype(at::kFloat));
output.resize_(shape);
auto* output_data = output.data_ptr<float>();

const int index_size = indices.numel();
@@ -332,7 +334,8 @@ at::Tensor embedding_bag_byte_impl(
"embedding_bag_byte expects FBGEMM support. This PyTorch installation was not built with FBGEMM operators");
}

at::Tensor embedding_bag_byte_helper(
at::Tensor& embedding_bag_byte_helper(
at::Tensor& output,
const at::Tensor& weight,
const at::Tensor& indices,
const c10::optional<at::Tensor>& offsets_in,
@@ -381,6 +384,7 @@ at::Tensor embedding_bag_byte_helper(
// need to cast, which can be additional performance overhead
if (indices.scalar_type() == at::kInt && offsets.scalar_type() == at::kInt) {
return embedding_bag_byte_impl<int, int>(
output,
weight,
indices,
offsets,
@@ -392,6 +396,7 @@ at::Tensor embedding_bag_byte_helper(
} else if (
indices.scalar_type() == at::kInt && offsets.scalar_type() == at::kLong) {
return embedding_bag_byte_impl<int, int64_t>(
output,
weight,
indices,
offsets,
@@ -403,6 +408,7 @@ at::Tensor embedding_bag_byte_helper(
} else if (
indices.scalar_type() == at::kLong && offsets.scalar_type() == at::kInt) {
return embedding_bag_byte_impl<int64_t, int>(
output,
weight,
indices,
offsets,
@@ -415,6 +421,7 @@ at::Tensor embedding_bag_byte_helper(

// default case given the TORCH_CHECK above
return embedding_bag_byte_impl<int64_t, int64_t>(
output,
weight,
indices,
offsets,
@@ -521,7 +528,9 @@ at::Tensor PackedEmbeddingBagWeight::embeddingbag_byte(
const c10::optional<at::Tensor>& compressed_indices_mapping,
bool include_last_offset,
bool is_embedding_op) {
auto output = at::empty({0}, packed_w.options().dtype(at::kFloat));
return embedding_bag_byte_helper(
output,
packed_w,
indices,
offsets_in,
@@ -562,9 +571,9 @@ at::Tensor PackedEmbeddingBagWeight::embeddingbag_4bit(

namespace at {
namespace native {
namespace {

Tensor embedding_bag_byte_rowwise_offsets(
Tensor& embedding_bag_byte_rowwise_offsets_out(
Tensor& output,
const Tensor& weight,
const Tensor& indices,
const c10::optional<Tensor>& offsets_in,
@@ -575,6 +584,7 @@ Tensor embedding_bag_byte_rowwise_offsets(
const c10::optional<Tensor>& compressed_indices_mapping,
bool include_last_offset) {
return embedding_bag_byte_helper(
output,
weight,
indices,
offsets_in,
@@ -585,6 +595,33 @@ Tensor embedding_bag_byte_rowwise_offsets(
false /* is_embedding_op */);
}

namespace {

Tensor embedding_bag_byte_rowwise_offsets(
const Tensor& weight,
const Tensor& indices,
const c10::optional<Tensor>& offsets_in,
const bool /* scale_grad_by_freq */,
const int64_t /* mode */,
bool pruned_weights,
const c10::optional<Tensor>& per_sample_weights_,
const c10::optional<Tensor>& compressed_indices_mapping,
bool include_last_offset) {
auto output = at::empty({0}, weight.options().dtype(at::kFloat));
embedding_bag_byte_rowwise_offsets_out(
output,
weight,
indices,
offsets_in,
false /*unused scale_grad_by_freq*/,
0 /*unused mode*/,
pruned_weights,
per_sample_weights_,
compressed_indices_mapping,
include_last_offset);
return output;
}

Tensor embedding_bag_4bit_rowwise_offsets(
const Tensor& weight,
const Tensor& indices,
17 changes: 17 additions & 0 deletions aten/src/ATen/native/quantized/cpu/qembeddingbag.h
@@ -0,0 +1,17 @@
#include <ATen/ATen.h>

namespace at {
namespace native {
Tensor& embedding_bag_byte_rowwise_offsets_out(
Tensor& output,
const Tensor& weight,
const Tensor& indices,
const c10::optional<Tensor>& offsets_in,
const bool /* scale_grad_by_freq */,
const int64_t /* mode */,
bool pruned_weights,
const c10::optional<Tensor>& per_sample_weights_,
const c10::optional<Tensor>& compressed_indices_mapping,
bool include_last_offset);
} // native
} // at
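As a usage reference (not part of the diff), the out variant declared above follows the usual ATen _out convention: the caller owns the output tensor and the op resizes it to the result shape. A minimal sketch, with illustrative names and assuming packed_weight, indices, and offsets were produced elsewhere (e.g. by the quantized embedding-bag prepack op):

#include <ATen/ATen.h>
#include <ATen/native/quantized/cpu/qembeddingbag.h>

// Sketch only: drive the new out variant with a caller-owned buffer, the same
// pattern the static runtime uses in torch/csrc/jit/runtime/static/ops.cpp.
at::Tensor run_embedding_bag_byte(
    const at::Tensor& packed_weight,
    const at::Tensor& indices,
    const c10::optional<at::Tensor>& offsets) {
  // Start from an empty float tensor; the op resizes it to the output shape.
  auto output = at::empty({0}, packed_weight.options().dtype(at::kFloat));
  at::native::embedding_bag_byte_rowwise_offsets_out(
      output,
      packed_weight,
      indices,
      offsets,
      /*scale_grad_by_freq=*/false, // unused
      /*mode=*/0,                   // unused
      /*pruned_weights=*/false,
      /*per_sample_weights_=*/c10::nullopt,
      /*compressed_indices_mapping=*/c10::nullopt,
      /*include_last_offset=*/false);
  return output;
}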
35 changes: 35 additions & 0 deletions torch/csrc/jit/runtime/static/ops.cpp
@@ -1,5 +1,6 @@
#include <torch/csrc/jit/runtime/static/ops.h>
#include <ATen/NativeFunctions.h>
#include <ATen/native/quantized/cpu/qembeddingbag.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/vararg_functions.h>

@@ -301,6 +302,40 @@ REGISTER_OPERATOR_FUNCTOR(aten::clone, aten_clone, [](Node* n) -> SROperator {
at::native::copy_(out_t, in0_t, false);
};
});
REGISTER_OPERATOR_FUNCTOR_OPT(
quantized::embedding_bag_byte_rowwise_offsets,
quantized_embedding_bag_byte_rowwise_offsets,
false, // don't reuse byte inputs
true,
[](Node* n) -> SROperator {
return [](ProcessedNode* p_node) {
auto weight = p_node->Input(0).toTensor();
auto indices = p_node->Input(1).toTensor();
auto offsets = p_node->Input(2).toOptional<at::Tensor>();
auto pruned_weights = p_node->Input(5).toBool();
auto per_sample_weights = p_node->Input(6).toOptional<at::Tensor>();
auto compressed_indices_mapping =
p_node->Input(7).toOptional<at::Tensor>();
auto include_last_offset = p_node->Input(8).toBool();
if (p_node->Output(0).isNone()) {
p_node->Output(0) =
at::empty({0}, weight.options().dtype(at::kFloat));
}
auto out_t = p_node->Output(0).toTensor();
out_t.resize_({0});
return at::native::embedding_bag_byte_rowwise_offsets_out(
out_t,
weight,
indices,
offsets,
false, // unused scale_grad_by_freq
0, // unused mode
pruned_weights,
per_sample_weights,
compressed_indices_mapping,
include_last_offset);
};
});

std::function<void(ProcessedNode*)> getOutOfPlaceOperation(Node* n) {
auto op_name = n->kind().toQualString();