-
Notifications
You must be signed in to change notification settings - Fork 25k
Use c10::variant-based enums for Reduction #27942
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Use c10::variant-based enums for Reduction gh-metadata: pytorch pytorch 27942 gh/yf225/11/head
Use c10::variant-based enums for Reduction gh-metadata: pytorch pytorch 27942 gh/yf225/11/head
Use c10::variant-based enums for Reduction gh-metadata: pytorch pytorch 27942 gh/yf225/11/head
Use c10::variant-based enums for Reduction gh-metadata: pytorch pytorch 27942 gh/yf225/11/head
Use c10::variant-based enums for Reduction gh-metadata: pytorch pytorch 27942 gh/yf225/11/head
@@ -90,7 +92,7 @@ struct TORCH_API TripletMarginLossOptions { | |||
/// E. Riba et al. Default: False | |||
TORCH_ARG(bool, swap) = false; | |||
/// Specifies the reduction to apply to the output. Default: Mean | |||
TORCH_ARG(Reduction::Reduction, reduction) = Reduction::Mean; | |||
TORCH_ARG(10::variant<enumtype::kNone, enumtype::kMean, enumtype::kSum>, reduction) = torch::kMean; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry I was looking if a PR of mine was in the list and found this PR which kind of concerns what I was doing :P I think you are missing a c
here (should be c10
) and in previous declarations too!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@CarMiranda Thanks a lot for the catch! I just fixed it. After this PR is merged, I will do a sweep to change all torch::nn
layers that use Reduction
to use the corresponding variant type. :D
Use c10::variant-based enums for Reduction gh-metadata: pytorch pytorch 27942 gh/yf225/11/head
Use c10::variant-based enums for Reduction gh-metadata: pytorch pytorch 27942 gh/yf225/11/head
Use c10::variant-based enums for Reduction gh-metadata: pytorch pytorch 27942 gh/yf225/11/head
Use c10::variant-based enums for Reduction gh-metadata: pytorch pytorch 27942 gh/yf225/11/head
Use c10::variant-based enums for Reduction gh-metadata: pytorch pytorch 27942 gh/yf225/11/head
Use c10::variant-based enums for Reduction gh-metadata: pytorch pytorch 27942 gh/yf225/11/head
Use c10::variant-based enums for Reduction gh-metadata: pytorch pytorch 27942 gh/yf225/11/head
Use c10::variant-based enums for Reduction gh-metadata: pytorch pytorch 27942 gh/yf225/11/head
Use c10::variant-based enums for Reduction gh-metadata: pytorch pytorch 27942 gh/yf225/11/head
Use c10::variant-based enums for Reduction gh-metadata: pytorch pytorch 27942 gh/yf225/11/head
Use c10::variant-based enums for Reduction gh-metadata: pytorch pytorch 27942 gh/yf225/11/head
Use c10::variant-based enums for Reduction gh-metadata: pytorch pytorch 27942 gh/yf225/11/head
Use c10::variant-based enums for Reduction gh-metadata: pytorch pytorch 27942 gh/yf225/11/head
Use c10::variant-based enums for Reduction gh-metadata: pytorch pytorch 27942 gh/yf225/11/head
Use c10::variant-based enums for Reduction gh-metadata: pytorch pytorch 27942 gh/yf225/11/head
Use c10::variant-based enums for Reduction gh-metadata: pytorch pytorch 27942 gh/yf225/11/head
Use c10::variant-based enums for Reduction gh-metadata: pytorch pytorch 27942 gh/yf225/11/head
Use c10::variant-based enums for Reduction gh-metadata: pytorch pytorch 27942 gh/yf225/11/head
Use c10::variant-based enums for Reduction gh-metadata: pytorch pytorch 27942 gh/yf225/11/head
Use c10::variant-based enums for Reduction gh-metadata: pytorch pytorch 27942 gh/yf225/11/head
…div / mse_loss / binary_cross_entropy" Use c10::variant-based enums for Reduction gh-metadata: pytorch pytorch 27942 gh/yf225/11/head
torch/csrc/api/include/torch/enum.h
Outdated
// ``` | ||
// Tensor some_functional( | ||
// const Tensor& input, | ||
// const SomeOptions& options = {}) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This smells wrong to me. Why don't you just take it by value? It's a very small struct.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I updated the comment here and will do a sweep to take Options
by value in all functionals in a follow-up PR. Although it doesn't fix the problem that TORCH_OPTIONS_CTOR_VARIANT_ARG3
/TORCH_OPTIONS_CTOR_VARIANT_ARG4
try to address though :(
return torch::l1_loss( | ||
input, | ||
target, | ||
c10::visit(enumtype::_reduction_get_enum{}, options.reduction())); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of manually typing out c10::visit
everywhere, why not come up with a good API for doing this and call that instead? Especially since _reduction_get_enum
is underscored...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the suggestion! I added torch::enumtype::reduction_get_enum()
API for this purpose :D
Can you say more about what "fix F::kl_div / mse_loss / binary_cross_entropy" means? |
…div / mse_loss / binary_cross_entropy" Use c10::variant-based enums for Reduction gh-metadata: pytorch pytorch 27942 gh/yf225/11/head
…div / mse_loss / binary_cross_entropy" Use c10::variant-based enums for Reduction gh-metadata: pytorch pytorch 27942 gh/yf225/11/head
…div / mse_loss / binary_cross_entropy" Use c10::variant-based enums for Reduction gh-metadata: pytorch pytorch 27942 gh/yf225/11/head
Use c10::variant-based enums for Reduction gh-metadata: pytorch pytorch 27942 gh/yf225/11/head
Use c10::variant-based enums for Reduction gh-metadata: pytorch pytorch 27942 gh/yf225/11/head
Use c10::variant-based enums for Reduction gh-metadata: pytorch pytorch 27942 gh/yf225/11/head
Use c10::variant-based enums for Reduction gh-metadata: pytorch pytorch 27942 gh/yf225/11/head
Use c10::variant-based enums for Reduction gh-metadata: pytorch pytorch 27942 gh/yf225/11/head
Use c10::variant-based enums for Reduction gh-metadata: pytorch pytorch 27942 gh/yf225/11/head
Use c10::variant-based enums for Reduction gh-metadata: pytorch pytorch 27942 gh/yf225/11/head
Use c10::variant-based enums for Reduction gh-metadata: pytorch pytorch 27942 gh/yf225/11/head
The original logic in F::kl_div / mse_loss / binary_cross_entropy doesn't match that of Python version. I moved the changes to another PR since it is not strictly related to the Reduction enum changes. |
Use c10::variant-based enums for Reduction gh-metadata: pytorch pytorch 27942 gh/yf225/11/head
Stack from ghstack:
Differential Revision: D18202857