F::embedding, F::embedding_bag, moved Embedding and EmbeddingBag options to embedding.h in options #28669
Conversation
Force-pushed from c325e0a to 8bb57c9.
@anjali411 Thanks a lot for the awesome work! I left some initial comments.
TORCH_ARG(bool, scale_grad_by_freq) = false;
/// ``"sum"``, ``"mean"`` or ``"max"``. Specifies the way to reduce the bag. ``"sum"`` computes the weighted sum, taking `per_sample_weights`
/// into consideration. ``"mean"`` computes the average of the values in the bag, ``"max"`` computes the max value over each bag.
typedef c10::variant<enumtype::kSum, enumtype::kMean, enumtype::kMax> mode_t;
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should probably move this line above `EmbeddingBagOptions(int64_t num_embeddings, int64_t embedding_dim);`, because the `TORCH_ARG` usage in the previous line could hide this line under `private:` :(

(There is a discussion at #28413 (comment).)
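To make the pitfall concrete, here is a simplified illustration (this is not the real `TORCH_ARG` definition, which lives in `torch/arg.h` and has more overloads): the macro ends in a `private:` section, so any declaration that follows it in the class body becomes private unless the access specifier is reset.

```cpp
// Simplified stand-in for TORCH_ARG, shown only to illustrate the access-specifier issue.
#define SKETCH_ARG(T, name)                              \
 public:                                                 \
  auto name(const T& new_##name) -> decltype(*this) {    \
    name##_ = new_##name;                                \
    return *this;                                        \
  }                                                      \
  const T& name() const noexcept { return name##_; }     \
 private:                                                \
  T name##_

struct SketchOptions {
  SKETCH_ARG(bool, scale_grad_by_freq) = false;
  // Everything declared below the macro expansion is private by default:
  typedef int mode_t;  // users can no longer refer to SketchOptions::mode_t
};
```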
#include <torch/nn/pimpl.h>
#include <torch/types.h>

#include <cstdint>
Is there a specific reason we need to include this one?
good catch! removed
@@ -1,6 +1,7 @@
#include <torch/nn/modules/embedding.h>

#include <torch/types.h>
#include <torch/cuda.h>
Is there a specific reason we need to include this one?
removed
    Tensor weight,
    const Tensor& offsets,
    const EmbeddingBagOptions& options,
    const Tensor& per_sample_weights) {
It seems that `offsets` and `per_sample_weights` should also belong to `EmbeddingBagOptions` (we can add a comment in `EmbeddingBagOptions` saying that these two only take effect in `F::embedding_bag`), and I think we can add a default constructor for `EmbeddingBagOptions`, so that people don't need to specify `num_embeddings` and `embedding_dim` when they use `EmbeddingBagOptions` for `F::embedding_bag`.
I think we can also add a default constructor for `EmbeddingOptions`, so that we can remove checks like `TORCH_CHECK((*options).num_embeddings() == embeddings.size(0))` in `Embedding::from_pretrained` and `EmbeddingBag::from_pretrained`.
That's a good point, but then what should the variables `num_embeddings` and `embedding_dim` be set to in the default constructor? They would have to be valid non-zero values; otherwise there will be a crash because of the way we initialize `weight` in `reset`.
Also, I was wondering if we should add checks for options that are not supported under certain circumstances (instead of just mentioning them in the documentation), e.g. the `sparse` option is not supported when `mode = kMax`, and `per_sample_weights` is only supported for `mode = kSum`.
> That's a good point, but then what should the variables `num_embeddings` and `embedding_dim` be set to in the default constructor? They would have to be valid non-zero values; otherwise there will be a crash because of the way we initialize `weight` in `reset`.

That's a good catch! I think we can change `num_embeddings` and `embedding_dim` to be of type `c10::optional<int64_t>`, and then add checks to make sure their values are defined in places like the `reset` method (and other places that currently use `options.num_embeddings()` and `options.embedding_dim()`).
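A minimal sketch of what that could look like, assuming `TORCH_ARG` accepts `c10::optional` members the same way it accepts plain ones (the constructor bodies and the exact check wording here are assumptions, not the final implementation):

```cpp
#include <cstdint>
#include <c10/util/Optional.h>
#include <torch/arg.h>

struct EmbeddingOptions {
  EmbeddingOptions() = default;
  EmbeddingOptions(int64_t num_embeddings, int64_t embedding_dim)
      : num_embeddings_(num_embeddings), embedding_dim_(embedding_dim) {}
  TORCH_ARG(c10::optional<int64_t>, num_embeddings) = c10::nullopt;
  TORCH_ARG(c10::optional<int64_t>, embedding_dim) = c10::nullopt;
};

// Then, in EmbeddingImpl::reset() and other places that read these values:
//   TORCH_CHECK(options.num_embeddings() && options.embedding_dim(),
//               "num_embeddings and embedding_dim must be set");
//   weight = register_parameter(
//       "weight",
//       torch::empty({*options.num_embeddings(), *options.embedding_dim()}));
```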
> Also, I was wondering if we should add checks for options that are not supported under certain circumstances (instead of just mentioning them in the documentation), e.g. the `sparse` option is not supported when `mode = kMax`, and `per_sample_weights` is only supported for `mode = kSum`.

Yes, I think we currently do have the checks for "`sparse` option is not supported when `mode = kMax`" and "`per_sample_weights` only supported for `mode = kSum`" in this file. It would be awesome to add the "not-defined" checks for `options.offsets()` and `options.per_sample_weights()` in `EmbeddingBag::reset` as well, and tell the user that these two only take effect in `F::embedding_bag`.
Force-pushed from 8bb57c9 to 8d460d3.
TORCH_CHECK((*options).embedding_dim() == embeddings.size(1), "Expects options.embeddings_dim to be ", embeddings.size(1) , "but found ", (*options).embedding_dim());
if(options.num_embeddings() && options.num_embeddings()) {
TORCH_CHECK(*options.num_embeddings() == embeddings.size(0), "Expects options.num_embeddings to be ", embeddings.size(0) , "but found ", *options.num_embeddings());
TORCH_CHECK(*options.embedding_dim() == embeddings.size(1), "Expects options.embeddings_dim to be ", embeddings.size(1) , "but found ", *options.embedding_dim());
For the `from_pretrained` methods of `Embedding`, I believe we don't have to define `options.num_embeddings()` and `options.embedding_dim()`, and can just take the correct values from `embeddings.size(0)` and `embeddings.size(1)`.
Yeah, but if the user passes options with `options.num_embeddings() != embeddings.size(0)` and/or `options.embedding_dim() != embeddings.size(1)`, there should be an error, right?

Although I do think `if(options.num_embeddings() && options.num_embeddings())` should be changed to something else, because if the user defines only one of the two values and that value is not equal to `embeddings.size(i)`, then there won't be an error, but there probably should be.
Discussed offline: we probably don't need to raise an error if the user specifies `num_embeddings` and `embedding_dim` in the options when they call `from_pretrained`, since specifying those two parameters doesn't change the behavior of `from_pretrained`. However, we should raise a warning (using `TORCH_WARN`) when the user does specify those two parameters, and the warning should say that the `num_embeddings` and `embedding_dim` options parameters are ignored in `from_pretrained`.
TORCH_CHECK((*options).embedding_dim() == embeddings.size(1), "Expects options.embeddings_dim to be ", embeddings.size(1) , "but found ", (*options).embedding_dim());
if(options.num_embeddings() && options.num_embeddings()) {
TORCH_CHECK(*options.num_embeddings() == embeddings.size(0), "Expects options.num_embeddings to be ", embeddings.size(0) , "but found ", *options.num_embeddings());
TORCH_CHECK(*options.embedding_dim() == embeddings.size(1), "Expects options.embeddings_dim to be ", embeddings.size(1) , "but found ", *options.embedding_dim());
For the `from_pretrained` methods of `EmbeddingBag`, I believe we don't have to define `options.num_embeddings()` and `options.embedding_dim()`, and can just take the correct values from `embeddings.size(0)` and `embeddings.size(1)`.
if (options.max_norm() != c10::nullopt) {
torch::NoGradGuard no_grad;
torch::embedding_renorm_(weight, input.contiguous(), *options.max_norm(), options.norm_type());
It would be awesome to implement

    def _no_grad_embedding_renorm_(weight, input, max_norm, norm_type):
        # type: (Tensor, Tensor, float, float) -> Tensor
        with torch.no_grad():
            torch.embedding_renorm_(weight, input, max_norm, norm_type)

and call `_no_grad_embedding_renorm_` from here :D
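For reference, a C++ counterpart might look roughly like this (just a sketch: the helper name mirrors the Python function above and is not an existing torch API):

```cpp
#include <torch/torch.h>

// Renormalize the referenced rows of `weight` in place without recording the
// operation for autograd, mirroring the Python _no_grad_embedding_renorm_ helper.
inline void _no_grad_embedding_renorm_(
    torch::Tensor weight,
    const torch::Tensor& input,
    double max_norm,
    double norm_type) {
  torch::NoGradGuard no_grad;  // equivalent of `with torch.no_grad():`
  torch::embedding_renorm_(weight, input, max_norm, norm_type);
}
```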
and we can pass `input` here if we call `input = input.contiguous()` above
if (options.max_norm() != c10::nullopt) {
torch::NoGradGuard no_grad;
torch::embedding_renorm_(weight, input_, *options.max_norm(), options.norm_type());
ditto regarding calling `_no_grad_embedding_renorm_` from here
torch::NoGradGuard no_grad;
torch::embedding_renorm_(weight, input.contiguous(), *options.max_norm(), options.norm_type());
}
return torch::embedding(weight, input.contiguous(), *options.padding_idx(), options.scale_grad_by_freq(), options.sparse());
and we can pass `input` here if we call `input = input.contiguous()` above
}

inline Tensor embedding_bag(
    const Tensor& input,
we can pass `input` by value and remove the need for `input_` as well :D (see the sketch after the next comment)
inline Tensor embedding_bag(
    const Tensor& input,
    Tensor weight,
    const EmbeddingBagOptions& options = {}) {
maybe we can pass `options` by value as well, to remove the need for `offsets_` and `per_sample_weights_` :D
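A rough sketch of what the two suggestions amount to (the body is reduced to the lines that change, and the final return is just a placeholder so the sketch stands alone; it is not the real implementation):

```cpp
#include <torch/torch.h>

using torch::Tensor;
using torch::nn::EmbeddingBagOptions;

inline Tensor embedding_bag_sketch(
    Tensor input,                        // by value: no separate `input_` copy needed
    Tensor weight,
    EmbeddingBagOptions options = {}) {  // by value: no `offsets_` / `per_sample_weights_` copies
  input = input.contiguous();
  // ... the rest of F::embedding_bag would read from and update `input` and
  //     `options` directly before calling into torch::embedding_bag(...).
  (void)weight;   // unused in this truncated sketch
  return input;   // placeholder return
}
```

Copying the Tensor and options handles is cheap, since tensors are reference-counted, so passing by value mainly affects readability rather than performance.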
!per_sample_weights_.defined() || c10::get_if<enumtype::kMean>(&options.mode()),
"embedding_bag: per_sample_weights was not null. ",
"per_sample_weights is only supported for mode='kSum' (got mode='",
c10::visit(torch::enumtype::enum_name{}, options.mode()), "').Please open a feature request on GitHub.");
if we pull the latest master into this PR branch, we will be able to call:

- c10::visit(torch::enumtype::enum_name{}, options.mode()), "').Please open a feature request on GitHub.");
+ torch::enumtype::get_enum_name(options.mode()), "').Please open a feature request on GitHub.");
Force-pushed from 8d460d3 to 2532d5d.
TORCH_CHECK((*options).embedding_dim() == embeddings.size(1), "Expects options.embeddings_dim to be ", embeddings.size(1) , "but found ", (*options).embedding_dim());
} else {
options = EmbeddingOptions(embeddings.size(0), embeddings.size(1));
if(options.num_embeddings()) {
nit:

- if(options.num_embeddings()) {
+ if (options.num_embeddings()) {
}
Embedding embedding((*options)._weight(embeddings));
if(options.embedding_dim()) {
nit:

- if(options.embedding_dim()) {
+ if (options.embedding_dim()) {
} else {
options = EmbeddingOptions(embeddings.size(0), embeddings.size(1));
if(options.num_embeddings()) {
TORCH_WARN(*options.num_embeddings() == embeddings.size(0), "Expects options.num_embeddings to be ", embeddings.size(0) , "but found ", *options.num_embeddings());
I believe we can write this instead:

- TORCH_WARN(*options.num_embeddings() == embeddings.size(0), "Expects options.num_embeddings to be ", embeddings.size(0) , "but found ", *options.num_embeddings());
+ TORCH_WARN("`num_embeddings` options parameter is ignored in `torch::nn::Embedding::from_pretrained`.");
}
Embedding embedding((*options)._weight(embeddings));
if(options.embedding_dim()) {
TORCH_WARN(*options.embedding_dim() == embeddings.size(1), "Expects options.num_embeddings to be ", embeddings.size(1) , "but found ", *options.embedding_dim());
I believe we can write this instead:

- TORCH_WARN(*options.embedding_dim() == embeddings.size(1), "Expects options.num_embeddings to be ", embeddings.size(1) , "but found ", *options.embedding_dim());
+ TORCH_WARN("`embedding_dim` options parameter is ignored in `torch::nn::Embedding::from_pretrained`.");
options.num_embeddings(embeddings.size(0));
options.embedding_dim(embeddings.size(1));

Embedding embedding(options._weight(embeddings));
We can probably combine the above 3 lines into one:

- Embedding embedding(options._weight(embeddings));
+ Embedding embedding(options.num_embeddings(embeddings.size(0)).embedding_dim(embeddings.size(1))._weight(embeddings));
} else {
options = EmbeddingBagOptions(embeddings.size(0), embeddings.size(1));
if(options.num_embeddings()) {
TORCH_WARN(*options.num_embeddings() == embeddings.size(0), "Expects options.num_embeddings to be ", embeddings.size(0) , "but found ", *options.num_embeddings());
I believe we can write this instead:

- TORCH_WARN(*options.num_embeddings() == embeddings.size(0), "Expects options.num_embeddings to be ", embeddings.size(0) , "but found ", *options.num_embeddings());
+ TORCH_WARN("`num_embeddings` options parameter is ignored in `torch::nn::EmbeddingBag::from_pretrained`.");
TORCH_WARN(*options.num_embeddings() == embeddings.size(0), "Expects options.num_embeddings to be ", embeddings.size(0) , "but found ", *options.num_embeddings());
}
if(options.embedding_dim()) {
TORCH_WARN(*options.embedding_dim() == embeddings.size(1), "Expects options.num_embeddings to be ", embeddings.size(1) , "but found ", *options.embedding_dim());
I believe we can write this instead:

- TORCH_WARN(*options.embedding_dim() == embeddings.size(1), "Expects options.num_embeddings to be ", embeddings.size(1) , "but found ", *options.embedding_dim());
+ TORCH_WARN("`embedding_dim` options parameter is ignored in `torch::nn::EmbeddingBag::from_pretrained`.");
EmbeddingBag embeddingbag((*options)._weight(embeddings));
options.num_embeddings(embeddings.size(0));
options.embedding_dim(embeddings.size(1));
EmbeddingBag embeddingbag(options._weight(embeddings));
We can probably combine the above 3 lines into one:

- EmbeddingBag embeddingbag(options._weight(embeddings));
+ EmbeddingBag embeddingbag(options.num_embeddings(embeddings.size(0)).embedding_dim(embeddings.size(1))._weight(embeddings));
}

if (options.max_norm() != c10::nullopt) {
torch::NoGradGuard no_grad;
We can likely remove this line because `_no_grad_embedding_renorm_` already has it :D
Force-pushed from 2532d5d to 371b201.
TORCH_CHECK(options.offsets().dim() == 1, "offsets has to be a 1D Tensor");
TORCH_CHECK(options.offsets()[0].item<int64_t>() == 0, "offsets[0] has to be 0, i.e., the first sequence in the mini-batch has to start from position 0. However, got ",
options.offsets()[0].item<int64_t>());
TORCH_CHECK(options.offsets()[-1].item<int64_t>() <= input.size(0), "offsets[-1] can not be greater than input's length({)",
nit:

- TORCH_CHECK(options.offsets()[-1].item<int64_t>() <= input.size(0), "offsets[-1] can not be greater than input's length({)",
+ TORCH_CHECK(options.offsets()[-1].item<int64_t>() <= input.size(0), "offsets[-1] can not be greater than input's length({",
TORCH_CHECK(options.offsets()[-1].item<int64_t>() <= input.size(0), "offsets[-1] can not be greater than input's length({)",
input.size(0), "}), but got offsets[-1] of {", options.offsets()[-1].item<int64_t>(), "}");
} else {
TORCH_CHECK(false, "input has to be 1D or 2D Tensor,but got Tensor of dimension ", input.dim());
nit:

- TORCH_CHECK(false, "input has to be 1D or 2D Tensor,but got Tensor of dimension ", input.dim());
+ TORCH_CHECK(false, "input has to be 1D or 2D Tensor, but got Tensor of dimension ", input.dim());
mode_enum,
options.sparse(),
per_sample_weights_));
torch::Tensor EmbeddingBagImpl::forward(const Tensor& input) {
I think we actually need to preserve `offsets` and `per_sample_weights` as arguments to this function (and allow them to take `{}` as default value), because the Python version expects those two as arguments to this function.
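The suggested signature might look something like this (a sketch; the `{}` defaults mirror Python's `None` defaults and are an assumption):

```cpp
#include <torch/torch.h>
using torch::Tensor;

// In practice this declaration would live inside the EmbeddingBagImpl class body.
Tensor forward(
    const Tensor& input,
    const Tensor& offsets = {},
    const Tensor& per_sample_weights = {});
```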
options.sparse(),
per_sample_weights_));
torch::Tensor EmbeddingBagImpl::forward(const Tensor& input) {
return F::embedding_bag(input, weight, options);
and here we will need to construct a new `EmbeddingBagOptions` (which contains values for `offsets` and `per_sample_weights`, along with all the other arguments that the Python version passes) and pass it into `F::embedding_bag`, because we don't want to in-place change this module's `options`
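Putting the previous two comments together, the definition could look roughly like this (a sketch: it assumes `EmbeddingBagOptions` still exposes `offsets`/`per_sample_weights` setters, which a later comment below proposes moving out of the options class):

```cpp
torch::Tensor EmbeddingBagImpl::forward(
    const Tensor& input,
    const Tensor& offsets,
    const Tensor& per_sample_weights) {
  // Work on a copy so the module's stored `options` is never mutated in place.
  auto call_options = options;
  call_options.offsets(offsets).per_sample_weights(per_sample_weights);
  return F::embedding_bag(input, weight, call_options);
}
```

Copying the options is cheap because the Tensor members are reference-counted handles, so the per-call copy should not be a performance concern.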
/// If given, each embedding vector with norm larger than `max_norm` is renormalized to have norm `max_norm`.
TORCH_ARG(c10::optional<float>, max_norm) = c10::nullopt;
/// The p of the p-norm to compute for the `max_norm` option. Default ``2``.
TORCH_ARG(float, norm_type) = 2.;
we probably need to change this to `double` type, to match the Python version better (`float` in Python is 64-bit)
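Concretely, the two quoted lines would become (the same change applies to the `EmbeddingBag` options below):

```cpp
TORCH_ARG(c10::optional<double>, max_norm) = c10::nullopt;
TORCH_ARG(double, norm_type) = 2.;
```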
/// If given, each embedding vector with norm larger than `max_norm` is renormalized to have norm `max_norm`.
TORCH_ARG(c10::optional<float>, max_norm) = c10::nullopt;
/// The p of the p-norm to compute for the `max_norm` option. Default ``2``.
TORCH_ARG(float, norm_type) = 2.;
we probably need to change this to `double` type, to match the Python version better (`float` in Python is 64-bit)
// per_sample_weights is a tensor of float / double weights, or NULL to indicate all weights should be taken to be ``1``. If specified, `per_sample_weights`
// must have exactly the same shape as input and is treated as having the same `offsets`, if those are not ``NULL``. Only supported for ``mode='sum'``.
// --only used in F::embedding_bag
TORCH_ARG(torch::Tensor, per_sample_weights) = Tensor();
Discussed offline: the ideal design is to have a separate `F::EmbeddingBagFuncOptions` class which contains the arguments needed by `F::embedding_bag` (and the set of arguments can be different from `torch::nn::EmbeddingBagOptions`). For now we can remove `offsets` and `per_sample_weights` from `torch::nn::EmbeddingBagOptions` and pass those two arguments explicitly to `F::embedding_bag`, and after #29265 is merged I will open a PR to add `F::EmbeddingBagFuncOptions`.
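For reference, the separate functional-options class described above might look roughly like this (member names and defaults are assumptions based on this discussion; the actual class, added in a follow-up PR, may differ):

```cpp
#include <c10/util/Optional.h>
#include <c10/util/variant.h>
#include <torch/arg.h>
#include <torch/enum.h>
#include <torch/types.h>

struct EmbeddingBagFuncOptions {
  typedef c10::variant<torch::enumtype::kSum, torch::enumtype::kMean, torch::enumtype::kMax> mode_t;

  // Arguments that are only meaningful for the functional F::embedding_bag call:
  TORCH_ARG(torch::Tensor, offsets) = torch::Tensor();
  TORCH_ARG(torch::Tensor, per_sample_weights) = torch::Tensor();

  // Arguments shared with torch::nn::EmbeddingBagOptions:
  TORCH_ARG(c10::optional<double>, max_norm) = c10::nullopt;
  TORCH_ARG(double, norm_type) = 2.;
  TORCH_ARG(bool, scale_grad_by_freq) = false;
  TORCH_ARG(mode_t, mode) = torch::kMean;
  TORCH_ARG(bool, sparse) = false;
};
```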
Force-pushed from 371b201 to 32602dc.
…ding and embeddingbag options to options file
This reverts commit f092464945133e48c69c6da5fbb2bbcb77ac55e6.
…e prev line secretly privatize the mode_t from users
…ged float data type to bool for max_norm, norm_type
Force-pushed from 32602dc to 0515002.
@anjali411 Thanks a lot for the awesome work! I left some minor comments about TORCH_WARN
} else {
options = EmbeddingOptions(embeddings.size(0), embeddings.size(1));
if (options.num_embeddings()) {
TORCH_WARN(*options.num_embeddings() == embeddings.size(0), "`num_embeddings` options parameter is ignored in `torch::nn::Embedding::from_pretrained`.");
I believe we don't need to check `*options.num_embeddings() == embeddings.size(0)`, because `options.num_embeddings()` would just be ignored. We can write:

- TORCH_WARN(*options.num_embeddings() == embeddings.size(0), "`num_embeddings` options parameter is ignored in `torch::nn::Embedding::from_pretrained`.");
+ TORCH_WARN("`num_embeddings` options parameter is ignored in `torch::nn::Embedding::from_pretrained`.");
}
Embedding embedding((*options)._weight(embeddings));
if (options.embedding_dim()) {
TORCH_WARN(*options.embedding_dim() == embeddings.size(1), "`embedding_dim` options parameter is ignored in `torch::nn::Embedding::from_pretrained`.");
I believe we don't need to check `*options.embedding_dim() == embeddings.size(1)`, because `options.embedding_dim()` would just be ignored. We can write:

- TORCH_WARN(*options.embedding_dim() == embeddings.size(1), "`embedding_dim` options parameter is ignored in `torch::nn::Embedding::from_pretrained`.");
+ TORCH_WARN("`embedding_dim` options parameter is ignored in `torch::nn::Embedding::from_pretrained`.");
} else {
options = EmbeddingBagOptions(embeddings.size(0), embeddings.size(1));
if (options.num_embeddings()) {
TORCH_WARN(*options.num_embeddings() == embeddings.size(0), "`num_embeddings` options parameter is ignored in `torch::nn::EmbeddingBag::from_pretrained`.");
ditto for this one
TORCH_WARN(*options.num_embeddings() == embeddings.size(0), "`num_embeddings` options parameter is ignored in `torch::nn::EmbeddingBag::from_pretrained`.");
}
if (options.embedding_dim()) {
TORCH_WARN(*options.embedding_dim() == embeddings.size(1), "`embedding_dim` options parameter is ignored in `torch::nn::EmbeddingBag::from_pretrained`.");
ditto for this one
@anjali411 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@anjali411 merged this pull request in 604fc9e.
No description provided.