C++ API: torch::nn::Softmax #27446
Conversation
```cpp
/// Options for the `Softmax` module.
struct TORCH_API SoftmaxOptions {
  // Dimension along which Softmax will be computed.
  TORCH_ARG(int, dim);
```
I suggest having -1 as the default value.
Thanks. I fixed it.
After discussion with the team, we agreed that we should use c10::optional<int64_t> instead of just int here. The default value should be c10::nullopt. I'm sorry for misleading you.
No problem. Thanks for the guidance.
Does this rule also apply to torch::Dtype?
dim should be a non-optional argument, because allowing it to be null is already deprecated in the Python version, and the C++ version should just match the newest Python behavior.
Thanks, I made dim a non-optional argument.
Thanks for reviewing this PR.
```cpp
if (dim == -1) {
  int input_dim = input.dim();
  if (input_dim == 0 || input_dim == 1 || input_dim == 3) {
```
Please extract this into _get_softmax_dim(...) as it is in Python, and reuse it in LogSoftmax. Use TORCH_WARN for the warning.
I fixed it.
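For reference, a rough sketch of what such a helper could look like, mirroring the Python _get_softmax_dim (illustrative only; as noted later in the review, this helper ended up being removed again):

```cpp
#include <torch/torch.h>
#include <string>

// Sketch of a helper mirroring Python's _get_softmax_dim: it warns that the
// implicit dimension choice is deprecated and picks a default dim based on
// the number of input dimensions.
inline int64_t _get_softmax_dim(const std::string& name, int64_t ndim) {
  TORCH_WARN(
      "Implicit dimension choice for ", name,
      " has been deprecated. Change the call to include dim=X as an argument.");
  return (ndim == 0 || ndim == 1 || ndim == 3) ? 0 : 1;
}
```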
@nuka137 Thanks so much for the great contribution! I left some comments mostly regarding the design of SoftmaxOptions.
```cpp
    ret = 1;
  }
  return ret;
}
```
In the Python version, this function exists to maintain backward compatibility for use cases where people don't pass the dim argument to torch.nn.functional.softmax. Since torch::nn::functional::softmax in the C++ API is a new addition, we don't need to maintain any backward compatibility, and we should remove this function and make dim a non-optional argument to match the newest Python API design.
I understood the background and deleted the function _get_softmax_dim.
```cpp
// The desired data type of returned tensor.
// If specified, the input tensor is casted to dtype before the operation
// is performed. This is useful for preventing data type overflows.
TORCH_ARG(torch::Dtype, dtype);
```
I think semantically dtype is not one of the options of a softmax function/module, because it doesn't control how the softmax function itself should perform the computation, but rather how the returned tensor should be stored, which is mechanical rather than mathematical.
OK, deleted dtype from SoftmaxOptions.
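For context, the behavior that the removed option documented can still be obtained at the call site by casting the input first; a minimal sketch (variable names are made up):

```cpp
#include <torch/torch.h>

int main() {
  auto input = torch::randn({2, 5});  // float32 by default
  // Casting the input before the call has the same effect as a dtype
  // argument: the softmax is computed in double precision, which guards
  // against data type overflow.
  auto output = torch::softmax(input.to(torch::kFloat64), /*dim=*/1);
  return 0;
}
```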
```cpp
/// Options for the Softmax functional and module.
struct TORCH_API SoftmaxOptions {
  SoftmaxOptions(int dim = -1, torch::Dtype dtype = torch::Dtype::Undefined);
```
I think it should be

```diff
- SoftmaxOptions(int dim = -1, torch::Dtype dtype = torch::Dtype::Undefined);
+ SoftmaxOptions(int dim);
```

because dim shouldn't be an optional argument.
Thanks for your suggestion. I fixed it.
```cpp
if (dtype != torch::Dtype::Undefined) {
  stream << "dtype=" << options.dtype();
}
stream << ")";
```
To match the Python version, we can just print dim in pretty_print.
Fixed it to print dim only.
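A minimal sketch of what the simplified pretty_print could look like (the exact output string is an assumption):

```cpp
// Prints only dim, matching the Python repr, e.g. "torch::nn::Softmax(dim=1)".
void SoftmaxImpl::pretty_print(std::ostream& stream) const {
  stream << "torch::nn::Softmax(dim=" << options.dim() << ")";
}
```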
```cpp
  return ret;
}

inline Tensor softmax(const Tensor& input, const SoftmaxOptions& options) {
```
I think the signature should be

```diff
- inline Tensor softmax(const Tensor& input, const SoftmaxOptions& options) {
+ inline Tensor softmax(const Tensor& input, const SoftmaxOptions& options, c10::optional<torch::Dtype> dtype = c10::nullopt) {
```
Thanks for your suggestion. I fixed it.
Thanks for reviewing.
```cpp
SoftmaxOptions(int dim);

// Dimension along which Softmax will be computed.
TORCH_ARG(int, dim);
```
I think dim should have type int64_t, because that's the type used for tensor dimensions in the C++ API.
OK, fixed it.
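Putting the points above together, the options struct could end up roughly like this (a sketch, not necessarily the exact code that was merged):

```cpp
/// Options for the `Softmax` module.
struct TORCH_API SoftmaxOptions {
  // Deliberately non-explicit, so that a plain integer converts to
  // SoftmaxOptions, e.g. Softmax(/*dim=*/1) or F::softmax(input, /*dim=*/1).
  /* implicit */ SoftmaxOptions(int64_t dim) : dim_(dim) {}

  /// Dimension along which Softmax will be computed.
  TORCH_ARG(int64_t, dim);
};
```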
```cpp
/// about the exact behavior of this module.
class TORCH_API SoftmaxImpl : public torch::nn::Cloneable<SoftmaxImpl> {
 public:
  explicit SoftmaxImpl(const SoftmaxOptions& options_);
```
We should also add explicit SoftmaxImpl(int64_t dim) : SoftmaxImpl(SoftmaxOptions(dim)) {} here, to serve the torch::nn::Softmax(/*dim=*/3) use case.
Fixed it.
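For reference, the module class might then expose both constructors roughly like this (a sketch with the method list abbreviated; details are assumptions):

```cpp
class TORCH_API SoftmaxImpl : public torch::nn::Cloneable<SoftmaxImpl> {
 public:
  // Convenience constructor, enabling torch::nn::Softmax(/*dim=*/3).
  explicit SoftmaxImpl(int64_t dim) : SoftmaxImpl(SoftmaxOptions(dim)) {}
  explicit SoftmaxImpl(const SoftmaxOptions& options_);

  void reset() override;
  void pretty_print(std::ostream& stream) const override;
  Tensor forward(const Tensor& input);

  /// The options with which this module was constructed.
  SoftmaxOptions options;
};
```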
```cpp
int dim = options.dim();
Tensor ret;

if (dtype == torch::Dtype::Undefined) {
```
We should change this to

```diff
- if (dtype == torch::Dtype::Undefined) {
+ if (dtype == c10::nullopt) {
```
Sorry, I missed this part. It's fixed now.
```cpp
inline Tensor softmax(const Tensor& input, const SoftmaxOptions& options,
                      c10::optional<torch::Dtype> dtype = c10::nullopt) {
  int dim = options.dim();
```
We should change this accordingly:

```diff
- int dim = options.dim();
+ int64_t dim = options.dim();
```
Thanks. I fixed it.
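Combining the suggestions on this function, the functional could end up roughly like this (a sketch; it assumes the function sits next to SoftmaxOptions in the same header):

```cpp
inline Tensor softmax(const Tensor& input, const SoftmaxOptions& options,
                      c10::optional<torch::Dtype> dtype = c10::nullopt) {
  int64_t dim = options.dim();
  Tensor ret;

  if (dtype == c10::nullopt) {
    ret = input.softmax(dim);
  } else {
    // Cast to the requested dtype as part of the softmax computation.
    ret = input.softmax(dim, *dtype);
  }

  return ret;
}
```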
test/cpp/api/functional.cpp
```cpp
TEST_F(FunctionalTest, Softmax) {
  auto input = torch::arange(10, torch::kFloat).reshape({2, 5});
  auto output = F::softmax(input, SoftmaxOptions(1));
```
We can change this to

```diff
- auto output = F::softmax(input, SoftmaxOptions(1));
+ auto output = F::softmax(input, /*dim=*/1);
```

to also test SoftmaxOptions's implicit constructor. ;)
Fixed it.
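For illustration, a version of the test in that form (the assertion below only checks the basic softmax invariant and may differ from the assertions actually used in the PR):

```cpp
TEST_F(FunctionalTest, Softmax) {
  auto input = torch::arange(10, torch::kFloat).reshape({2, 5});
  // Exercises SoftmaxOptions's implicit constructor from an integer.
  auto output = F::softmax(input, /*dim=*/1);
  // Each row of a softmax taken along dim 1 should sum to 1.
  ASSERT_TRUE(torch::allclose(output.sum(/*dim=*/1), torch::ones({2})));
}
```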
test/cpp/api/modules.cpp
```cpp
}

TEST_F(ModulesTest, Softmax) {
  Softmax m(SoftmaxOptions(1));
```
We can change this to

```diff
- Softmax m(SoftmaxOptions(1));
+ Softmax m(/*dim=*/1);
```

to also test SoftmaxOptions's implicit constructor. ;)
This is fixed as well.
@nuka137 I think similar issues in this PR also exist in

I've addressed all the review comments again.

Sure, I will update those PRs once this PR is merged.
Thanks so much for the awesome work @nuka137! I will merge this today :D
@yf225 is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: Add torch::nn::Softmax2d module support for the C++ API. Softmax2d is only available as a module in the Python API, so this PR adds only the module as well. This PR is WIP because it uses the function in #27446. After #27446 is merged, I will remove WIP.
Related Issue: #25883
Reviewer: yf225
Pull Request resolved: #27509
Differential Revision: D17899715
Pulled By: yf225
fbshipit-source-id: bd891bc995f5a92bf4f5405f8bf07d1bd5de2479
Summary: Add torch::nn::Softmax module support for the C++ API.
Related Issue: pytorch#25883
Reviewer: yf225
Pull Request resolved: pytorch#27446
Differential Revision: D17839546
Pulled By: yf225
fbshipit-source-id: 7c7fb55111b261614de7c3a75fa1019fbde93c67
Add torch::nn::Softmax module support for the C++ API
Related Issue: #25883
Reviewer: @yf225