introduce TensorBase::mutable_data_ptr() #97859
Conversation
Stack from ghstack (oldest at bottom).

See D44409928 for motivation. Note that we keep the const-ness of the existing data_ptr() member so that we don't have to change all references atomically; we only change the ones here that we have higher confidence in.

Differential Revision: [D44492539](https://our.internmc.facebook.com/intern/diff/D44492539/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments; please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D44492539/)!

cc @ezyang @bhosmer @smessmer @ljk53 @bdhirsh @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @EikanWang
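The gist of the change, as a rough illustrative sketch (the real TensorBase is far more involved; this mock just shows the accessor split):

```cpp
// Minimal mock, not the actual TensorBase source. data_ptr() keeps its
// existing signature so call sites don't all have to change at once;
// mutable_data_ptr() makes intent to write explicit, which eventually
// frees data_ptr() to become genuinely const-correct.
class TensorBase {
 public:
  void* data_ptr() const { return data_; }          // existing accessor, unchanged for now
  void* mutable_data_ptr() const { return data_; }  // new: "I will write through this"

 private:
  void* data_ = nullptr;
};
```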
🔗 Helpful Links

🧪 See artifacts and rendered test results at [hud.pytorch.org/pr/97859](https://hud.pytorch.org/pr/97859)

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure as of commit 9ab65b0.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
```diff
  auto batch_stride = ndim > 2 && batch_offset >= 0 ? input_strides[ndim - 3] : 0;
- void* values_ptr = static_cast<char*>(input.data_ptr()) +
+ void* values_ptr = static_cast<char*>(input.mutable_data_ptr()) +
```
This looks wrong. Based on the description here, I doubt this pointer is actually getting mutated. Generally, you can't trust the underlying libraries to have correct const (or non-const) types.
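To illustrate the hazard with a hypothetical (made-up) C API: a void* parameter tells you nothing about whether the callee actually writes, so the signature alone can't justify mutable_data_ptr().

```cpp
// Hypothetical third-party API: declares void* even though it only reads.
extern "C" int summarize_buffer(void* data, long n);

// The non-const parameter appears to demand mutable_data_ptr(), but if the
// callee never writes, a const pointer plus a const_cast at the boundary
// describes the actual behavior more honestly.
```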
```diff
  cusparseDnVecDescr_t raw_descriptor;
  TORCH_CUDASPARSE_CHECK(cusparseCreateDnVec(
-     &raw_descriptor, input.numel(), input.data_ptr(), value_type));
+     &raw_descriptor, input.numel(), input.mutable_data_ptr(), value_type));
```
Ditto
```diff
      batch_offset * col_indices_batch_stride * col_indices.itemsize(),
  // values of the sparse matrix, size = nnz
- static_cast<char*>(values.data_ptr()) +
+ static_cast<char*>(values.mutable_data_ptr()) +
```
Ditto x3
```diff
-     values.data_ptr()));
+     crow_indices.mutable_data_ptr(),
+     col_indices.mutable_data_ptr(),
+     values.mutable_data_ptr()));
```
Ditto
aten/src/ATen/cudnn/Descriptors.h (outdated)
```diff
  AT_ASSERT(options.dtype() == kByte);
  state = at::empty({static_cast<int64_t>(state_size)}, options);
- AT_CUDNN_CHECK(cudnnSetDropoutDescriptor(mut_desc(), handle, dropout, state.data_ptr(), state_size, seed));
+ AT_CUDNN_CHECK(cudnnSetDropoutDescriptor(mut_desc(), handle, dropout, state.mutable_data_ptr(), state_size, seed));
```
Ditto
aten/src/ATen/cudnn/Descriptors.h (outdated)
```diff
  TORCH_INTERNAL_ASSERT(dropout > 0, "dropout must be nonzero; otherwise call set_no_dropout");
  state = state_;
- void *state_ptr = state.data_ptr();
+ void *state_ptr = state.mutable_data_ptr();
```
Not sure about this one, as the dropout state may get subsequently mutated through the descriptor. You will need to read the docs.
```diff
      .build();

  const auto gW_data = reinterpret_cast<char*>(grad_weight.data_ptr());
  const auto gO_data = reinterpret_cast<char*>(grad.data_ptr());
```
gO should be read-only (gW should be a write).
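In other words (a sketch that assumes a const-qualified accessor, spelled const_data_ptr() here, exists or gets added alongside mutable_data_ptr()):

```cpp
// gW is the gradient being written; gO is only read, so a const accessor
// would express that directly (const_data_ptr() is an assumed name here).
auto gW_data = reinterpret_cast<char*>(grad_weight.mutable_data_ptr());
auto gO_data = reinterpret_cast<const char*>(grad.const_data_ptr());
```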
I'm pausing review; please do a re-audit of the rest of your changes. We may also want to discuss what the API for "I promise this is non-mutating, but the downstream API is not const-correct" should look like (the most straightforward option is to use const_cast, albeit a bit wordy).
I propose the following: in debug builds, hash the tensor's contents when the supposedly non-mutating pointer is handed out, and verify afterward that nothing changed. Even if we deem that too excessive, I would vote in favor of having a member function that is explicit about what exactly we are doing. Personally, I like the idea of the debug-only hashing. Trust but verify.
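A minimal sketch of that "trust but verify" check (all names made up; it also assumes the bytes are host-accessible, which CUDA tensors would complicate):

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string_view>

// Hash the bytes when a promised-non-mutating pointer is handed out,
// re-hash when the guard is destroyed, and assert that nothing changed.
class NonMutationGuard {
 public:
  NonMutationGuard(const void* data, std::size_t nbytes)
      : data_(data), nbytes_(nbytes), before_(hash_bytes(data, nbytes)) {}
  ~NonMutationGuard() {
    // assert() compiles away under NDEBUG, so the check is debug-only.
    assert(hash_bytes(data_, nbytes_) == before_ &&
           "buffer passed as non-mutating was mutated");
  }

 private:
  static std::size_t hash_bytes(const void* p, std::size_t n) {
    return std::hash<std::string_view>{}(
        std::string_view(static_cast<const char*>(p), n));
  }
  const void* data_;
  std::size_t nbytes_;
  std::size_t before_;
};
```

Usage would be something like `NonMutationGuard g(t.data_ptr(), t.nbytes());` wrapped around the suspect call.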
There is no logical place to do the debug check. E.g., in the cudnn APIs, you stash the pointer into the descriptor, and then the actual mutation only happens later when you actually do an API call.
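A tiny sketch of why the lifetime is awkward (hypothetical types; this is just the shape of the problem):

```cpp
// The pointer is stashed at setup time; the write happens at some later,
// unrelated call. There is no single scope a hash-checking guard could wrap.
struct FakeDescriptor {
  void* stashed_state = nullptr;
};

void set_dropout_state(FakeDescriptor& d, void* state) {
  d.stashed_state = state;  // no mutation yet, just stored
}

void run_forward(FakeDescriptor& d) {
  // ... the library writes through d.stashed_state here, much later ...
}
```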
Sure, that's generally true, but does that describe all or even most functions? And for functions that are invoked with a separate plan-then-call sequence: is that sequence typically done within another function, or is there a separation between them that makes wrapping the call in a checking scope difficult? Alternatively, could such a check be implemented at the dispatcher level?
A dispatcher-level check could look something like this: upon entering a function, we set up some TLS mapping input tensors to whether or not they showed up in mutable or non-mutable argument positions according to the schema. We then error if you access a data pointer on a tensor that is not explicitly mentioned, or access a mutable data pointer on a non-mutable argument. This check is unlikely to play well with …
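A very rough sketch of that TLS bookkeeping (all names hypothetical; the real dispatcher integration would be considerably more involved):

```cpp
#include <stdexcept>
#include <unordered_map>

// On entry to a dispatched op, record each argument's storage and whether
// the schema marks it mutable; the data-pointer accessors then consult this
// map and reject disallowed accesses.
enum class ArgKind { NonMutable, Mutable };

thread_local std::unordered_map<const void*, ArgKind> g_arg_kinds;

inline void check_data_ptr_access(const void* storage, bool wants_mutable) {
  auto it = g_arg_kinds.find(storage);
  if (it == g_arg_kinds.end()) {
    throw std::runtime_error(
        "data pointer access on a tensor not mentioned in the schema");
  }
  if (wants_mutable && it->second != ArgKind::Mutable) {
    throw std::runtime_error(
        "mutable data pointer access on a non-mutable argument");
  }
}
```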
That seems reasonable to me. OK, that's not going to happen in the very short term. So what do you think about just having a naming convention to address the two cases we're concerned about here:

1. "I promise this is non-mutating, but the downstream API is not const-correct."
2. "I'm not yet sure whether this pointer actually gets mutated."

For the latter, I'm primarily concerned with being able to communicate areas of uncertainty to local experts. For example, it would be nice to put those changes into the flash attention implementation and defer to Driss to resolve them. I don't want to have to make this call authoritatively for every single function.
ok
OK, I think I have an even better idea for the first problem. How about we wrap the "unsafe" APIs with const-correct wrappers? Then we do the const_cast in exactly one place, inside the wrapper, and the call sites stay const-correct.
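A hedged sketch of that wrapper idea (the wrapper name is made up; cuSPARSE's cusparseCreateDnVec really does take a non-const void* for the values buffer):

```cpp
#include <cstdint>
#include <cusparse.h>

// Made-up wrapper name; the point is that the const_cast lives in exactly
// one audited place, and every call site gets a const-correct signature.
inline cusparseStatus_t createDnVecConstCorrect(
    cusparseDnVecDescr_t* descr,
    int64_t size,
    const void* values,  // const here, unlike the underlying cuSPARSE API
    cudaDataType value_type) {
  return cusparseCreateDnVec(descr, size, const_cast<void*>(values), value_type);
}
```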
I'm cool with that too.
Looks like this PR hasn't been updated in a while, so we're going to go ahead and mark this as Stale.