-
Notifications
You must be signed in to change notification settings - Fork 22.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enabled masked for a bool tensor #19140
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@izdeby has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
So, how exactly is the deprecation cycle going to work? Are we going to change the behavior of these functions to do uint8 indexing eventually? |
it's not standard practice to deprecate things before the replacement is ready to use. I don't see a replacement ready to use here? |
There are 2 options here:
This PR implements the 2nd option as discussed offline. Should i go with the 1st option instead? |
I must have misunderstood the options, then. It's not useful to provide deprecation warnings without replacements. As a user, there's nothing I can do with that information. |
Sure, no problems, will switch to option #1. |
957da21
to
9869e12
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@izdeby has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@izdeby has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@izdeby has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@gchanan I'll let you review this, but let me know if you want me to look (I helped Iurii debug some issues while he was working on this patch.) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why do you need sum ?
aten/src/ATen/Declarations.cwrap
Outdated
options: | ||
- arguments: | ||
- arg: THTensor* self | ||
broadcast: mask inplace fallback types:Byte |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this "types:Byte" thing looks questionable?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@izdeby has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@izdeby has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
// As we dispatch on self and TH is type-chcked, we need different definitions. | ||
// This can be fixed by moving to ATen. | ||
if (mask.dtype() == at::ScalarType::Byte) { | ||
return at::legacy::th::_th_masked_fill_(self, mask, value); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
in a future PR, it might make sense to call the bool one masked_fill_ and the other one masked_fill_byte, since that is the end state, but not necessary for now.
@@ -165,6 +247,67 @@ void THCTensor_(maskedSelect)(THCState* state, | |||
THCudaCheck(cudaGetLastError()); | |||
} | |||
|
|||
void THCTensor_(maskedSelectBool)(THCState* state, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
also for the future: I think you could move these definitions into a non-generic function so you don't have to copy the code.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@izdeby has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: Added deprecation warnings for the masked methods and enabled them for a bool tensor. Pull Request resolved: pytorch/pytorch#19140 Differential Revision: D14888021 Pulled By: izdeby fbshipit-source-id: 0e42daf8f3732ca29f36d10485402bfc502716ad
Summary: Added deprecation warnings for the masked methods and enabled them for a bool tensor. Pull Request resolved: pytorch#19140 Differential Revision: D14888021 Pulled By: izdeby fbshipit-source-id: 0e42daf8f3732ca29f36d10485402bfc502716ad
Enabled masked methods them for a bool tensor.