
Conversation

@dagitses (Collaborator) commented Mar 29, 2023

See D44409928 for motivation.

Note that we keep the const-ness of the existing data_ptr() member so
that we don't have to change all references atomically. We just change
the ones here that we have higher confidence in.
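
For context, a minimal sketch of the split being described, assuming only that mutable_data_ptr() is the accessor used in this diff and that data_ptr() keeps its existing signature; the function below is purely illustrative:

```cpp
#include <ATen/ATen.h>

// Illustrative only: a call site that actually writes through the pointer
// opts in to mutable_data_ptr(); unaudited call sites keep using data_ptr().
void zero_first_byte(at::Tensor& t) {
  auto* bytes = static_cast<char*>(t.mutable_data_ptr());  // audited: mutates
  bytes[0] = 0;
}
```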

Differential Revision: [D44492539](https://our.internmc.facebook.com/intern/diff/D44492539/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D44492539/)!

[ghstack-poisoned]
@pytorch-bot (bot) commented Mar 29, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/97859

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure as of commit 9ab65b0.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@dagitses added the module: internals label and removed the release notes: quantization label Mar 29, 2023
@dagitses marked this pull request as draft March 29, 2023 06:08
@dagitses marked this pull request as ready for review March 29, 2023 10:14
@github-actions bot added the module: cpu, NNC, and release notes: quantization labels Mar 29, 2023

  auto batch_stride = ndim > 2 && batch_offset >= 0 ? input_strides[ndim - 3] : 0;
- void* values_ptr = static_cast<char*>(input.data_ptr()) +
+ void* values_ptr = static_cast<char*>(input.mutable_data_ptr()) +

Contributor:

This looks wrong. Based on the description here, I doubt this pointer is actually getting mutated. Generally, you can't trust the underlying libraries to have const-correct types.

  cusparseDnVecDescr_t raw_descriptor;
  TORCH_CUDASPARSE_CHECK(cusparseCreateDnVec(
-     &raw_descriptor, input.numel(), input.data_ptr(), value_type));
+     &raw_descriptor, input.numel(), input.mutable_data_ptr(), value_type));

Contributor:

Ditto

  batch_offset * col_indices_batch_stride * col_indices.itemsize(),
  // values of the sparse matrix, size = nnz
- static_cast<char*>(values.data_ptr()) +
+ static_cast<char*>(values.mutable_data_ptr()) +

Contributor:

Ditto x3

- values.data_ptr()));
+ crow_indices.mutable_data_ptr(),
+ col_indices.mutable_data_ptr(),
+ values.mutable_data_ptr()));

Contributor:

Ditto

  AT_ASSERT(options.dtype() == kByte);
  state = at::empty({static_cast<int64_t>(state_size)}, options);
- AT_CUDNN_CHECK(cudnnSetDropoutDescriptor(mut_desc(), handle, dropout, state.data_ptr(), state_size, seed));
+ AT_CUDNN_CHECK(cudnnSetDropoutDescriptor(mut_desc(), handle, dropout, state.mutable_data_ptr(), state_size, seed));

Contributor:

Ditto

  TORCH_INTERNAL_ASSERT(dropout > 0, "dropout must be nonzero; otherwise call set_no_dropout");
  state = state_;
- void *state_ptr = state.data_ptr();
+ void *state_ptr = state.mutable_data_ptr();

Contributor:

Not sure about this one, as the dropout state may get subsequently mutated through the descriptor. You will need to read the docs.

.build();

const auto gW_data = reinterpret_cast<char*>(grad_weight.data_ptr());
const auto gO_data = reinterpret_cast<char*>(grad.data_ptr());

Contributor:

gO should be read-only (gW should be write).
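
A minimal sketch of the access pattern this comment suggests for those two lines, assuming a const_data_ptr() accessor is available alongside mutable_data_ptr() (that availability is an assumption here, not something shown in this diff):

```cpp
#include <ATen/ATen.h>

// Sketch only: the incoming gradient (gO) is read, not written, so it takes a
// pointer-to-const; the weight gradient (gW) is the output, so it takes a
// mutable pointer.
void gradient_pointers_example(const at::Tensor& grad, at::Tensor& grad_weight) {
  const char* gO_data = static_cast<const char*>(grad.const_data_ptr());
  char* gW_data = static_cast<char*>(grad_weight.mutable_data_ptr());
  (void)gO_data;  // only read from in the real kernel
  (void)gW_data;  // written to in the real kernel
}
```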

@ezyang (Contributor) commented Mar 29, 2023

I'm pausing review; please do a re-audit of the rest of your changes. We may also want to discuss what the API for "I promise this is non-mutating but the downstream API is not const-correct" should look like (the most straightforward option is to use const_cast, albeit a bit wordy).
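
For illustration, a hedged sketch of the const_cast route mentioned above; legacy_read_only_api is a made-up stand-in for a C API that is not const-correct, and a const_data_ptr() accessor is assumed:

```cpp
#include <cstdint>

#include <ATen/ATen.h>

// Hypothetical legacy C API: it only reads from `data`, but its signature
// takes a non-const pointer.
extern "C" void legacy_read_only_api(void* data, int64_t numel);

// The caller stays const-correct; the cast is confined (and visible) at the
// boundary of the non-const-correct API.
void call_legacy(const at::Tensor& input) {
  legacy_read_only_api(
      const_cast<void*>(input.const_data_ptr()),  // promised non-mutating
      input.numel());
}
```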

@dagitses (Collaborator, Author) commented:
I propose the following:

  • In the short term, we have tensor.must_audit_mutable_data_ptr(). This can just reflect the status quo of what we have and gives us a convenient string to search for.
  • regarding const_cast, it might be nice to have a debug-only API that helps verify the mutable borrow is not mutated, something like:
{
  auto data_ptr = tensor.borrow_const_data_ptr();
  someCudaApi(data_ptr.subtle_as_non_const_for_naughty_api());
}

Here the result of borrow_const_data_ptr() is an RAII object that hashes the data upon construction in debug mode and hashes on destruction in debug mode and asserts they are identical.

Even if we deem that too excessive, I would vote in favor of having a member function that is explicit about what exactly we are doing, e.g. unsafe_get_non_const_data_ptr_for_non_const_correct_api(). Otherwise, I would feel compelled to add a wordy comment any place I used const_cast to justify it.

Personally, I like the idea of the debug-only hashing. Trust but verify.
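
A hedged sketch of what that debug-only hashing borrow could look like; the class name, the accessor names, and the CPU-only byte hash are all illustrative assumptions, not an existing API:

```cpp
#include <cstddef>
#include <cstdint>

#include <ATen/ATen.h>
#include <c10/util/Exception.h>

// Hypothetical RAII borrow: in debug builds it hashes the tensor's bytes at
// construction and again at destruction, asserting they match so that a
// "non-mutating" use through a non-const pointer is actually verified.
// Assumes a CPU tensor, since the hash walks host memory directly.
class ConstDataBorrow {
 public:
  explicit ConstDataBorrow(const at::Tensor& t) : tensor_(t) {
#ifndef NDEBUG
    initial_hash_ = hash_bytes();
#endif
  }

  ~ConstDataBorrow() {
#ifndef NDEBUG
    TORCH_INTERNAL_ASSERT(
        hash_bytes() == initial_hash_,
        "tensor was mutated through a const borrow");
#endif
  }

  // Escape hatch for APIs that are not const-correct.
  void* subtle_as_non_const_for_naughty_api() const {
    return const_cast<void*>(tensor_.const_data_ptr());
  }

 private:
  uint64_t hash_bytes() const {
    // FNV-1a over the raw bytes; fine for a debug check, not cryptographic.
    const auto* p = static_cast<const uint8_t*>(tensor_.const_data_ptr());
    uint64_t h = 1469598103934665603ull;
    for (size_t i = 0; i < tensor_.nbytes(); ++i) {
      h = (h ^ p[i]) * 1099511628211ull;
    }
    return h;
  }

  const at::Tensor& tensor_;
  uint64_t initial_hash_ = 0;
};
```

Used as in the snippet above: construct the borrow in a scope around the external call; in release builds the hashing compiles away.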

@ezyang (Contributor) commented Mar 29, 2023

There is no logical place to do the debug check. E.g., in the cuDNN APIs, you stash the pointer into the descriptor, and then the actual mutation would only happen later when you actually do an API call.

@dagitses (Collaborator, Author) commented:
> There is no logical place to do the debug check. E.g., in the cuDNN APIs, you stash the pointer into the descriptor, and then the actual mutation would only happen later when you actually do an API call.

Sure, that's generally true, but does that describe all or even most functions? And for functions that are invoked with a separate plan-then-execute sequence: is that sequence typically done within a single function, or is there a separation between the two steps that makes wrapping the call in a checking scope difficult?

Alternatively, could such a check be implemented at the dispatcher level?

@ezyang (Contributor) commented Mar 29, 2023

A dispatcher-level check could look something like this: upon entering a function, we set up some TLS mapping input tensors to whether or not they showed up in mutable or non-mutable argument positions according to the schema. We then error if you access a data pointer on a tensor that is not explicitly mentioned, or access a mutable data pointer on a non-mutable argument.

This check is unlikely to play well with __torch_dispatch__ though, so it will take some thought on how to design it correctly. See also @albanD constantly having to fend off people who want to delete the refcount asserts from autograd.
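
Purely as a sketch of the idea (none of this exists in the dispatcher today; the names and the per-op registration are assumptions):

```cpp
#include <unordered_map>

#include <c10/core/TensorImpl.h>
#include <c10/util/Exception.h>

// How a tensor appeared in the operator schema for the op currently running.
enum class ArgAccess { Const, Mutable };

// Hypothetical TLS table populated by the dispatcher on entry to an op.
thread_local std::unordered_map<const c10::TensorImpl*, ArgAccess>
    current_op_args;

// Hypothetically called from the data-pointer accessors to validate access
// against the schema-derived table.
inline void check_data_access(const c10::TensorImpl* impl, bool wants_mutable) {
  auto it = current_op_args.find(impl);
  TORCH_CHECK(
      it != current_op_args.end(),
      "accessing data of a tensor that is not an argument of the current op");
  TORCH_CHECK(
      !wants_mutable || it->second == ArgAccess::Mutable,
      "mutable data access on a non-mutable (const) operator argument");
}
```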

@dagitses (Collaborator, Author) commented:
> A dispatcher-level check could look something like this: upon entering a function, we set up some TLS mapping input tensors to whether or not they showed up in mutable or non-mutable argument positions according to the schema. We then error if you access a data pointer on a tensor that is not explicitly mentioned, or access a mutable data pointer on a non-mutable argument.
>
> This check is unlikely to play well with __torch_dispatch__ though, so it will take some thought on how to design it correctly. See also @albanD constantly having to fend off people who want to delete the refcount asserts from autograd.

That seems reasonable to me. OK, that's not going to happen in the very short-term. So what do you think about just having a naming convention to address the two cases we're concerned about here:

  • tensor.get_data_as_non_const_for_external_api(): for where we'd use const_cast
  • tensor.must_audit_get_mutable_data(): for where we want to move off of the const-returning accessor but need to scrutinize more closely in the medium term.

For the latter, I'm primarily concerned with being able to communicate areas of uncertainty to local experts. For example, it would be nice to put those changes into the flash attention implementation and defer to Driss to resolve them. I don't want to have to make this call authoritatively for every single function.

@ezyang (Contributor) commented Mar 29, 2023

ok

@dagitses (Collaborator, Author) commented:
> ok

OK, I think I have an even better idea for the first problem. How about we wrap the "unsafe" APIs with const-correct wrappers? Then we do the const_cast inside the wrapper. This makes the most sense to me because we should be auditing APIs for const-correctness, not callers.
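
A hedged sketch of that wrapper idea, using cusparseCreateDnVec from the diff above as the example; the wrapper name and the audit conclusion ("this call only records the pointer") are assumptions for illustration:

```cpp
#include <cstdint>

#include <cusparse.h>

// Hypothetical const-correct wrapper: the const_cast lives here, once, after
// the wrapped API has been audited not to write through `values`.
inline cusparseStatus_t createDnVecConst(
    cusparseDnVecDescr_t* descr,
    int64_t size,
    const void* values,  // const-correct on the caller's side
    cudaDataType value_type) {
  // Audited assumption: cusparseCreateDnVec only records the pointer; any
  // mutation would happen through later calls that take the descriptor.
  return cusparseCreateDnVec(descr, size, const_cast<void*>(values), value_type);
}
```

A call site could then pass a const data pointer directly and stay const-correct end to end.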

@ezyang (Contributor) commented Mar 30, 2023

I'm cool with that too.

@github-actions (bot) commented:
Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions bot added the Stale label Jun 18, 2023
@github-actions bot closed this Jul 18, 2023
@ezyang added the ezyang's list label Jul 18, 2023
@facebook-github-bot deleted the gh/dagitses/35/head branch August 17, 2023 14:16