Add test to check that COW inputs are not materialized #119507
Conversation
[ghstack-poisoned]
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/119507. Note: Links to docs will display an error until the docs builds have been completed.
⏳ 1 Pending, 2 Unrelated Failures as of commit 74061db with merge base d534a49:
FLAKY - The following job failed but was likely due to flakiness present on trunk.
UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable.
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
Part of #97856 [ghstack-poisoned]
@pytorchbot merge |
Merge failed. Reason: This PR needs a label. If it does not have one, please add the appropriate label. To add a label, you can comment to pytorchbot. Details for Dev Infra team: raised by workflow job. |
(some of) the test failures look real |
According to the trace for that failing macos job, the race is triggered by the depthwise conv kernel. EDIT: Oh I see, the code in pytorch/aten/src/ATen/native/cpu/DepthwiseConvKernel.cpp, lines 295 to 301 (at a7f82b7), accesses the tensor's data from inside at::parallel_for, so multiple threads try to materialize the same COW storage at the same time.
By the time one of the threads tries to delete the COWDeleterContext, another thread has already changed the refcount. The refcount is atomic, but in this case that doesn't prevent the race, since the combined operation of decrementing the refcount and then deleting the context is not mutexed (see pytorch/c10/core/impl/COWDeleter.cpp, lines 18 to 33, at a7f82b7).
I'll have to add a C++ test case that makes this race happen, and then fix it. |
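To make the failure mode concrete, here is a minimal standalone sketch of that class of race; the Context struct and release() function are made up for illustration and are not the actual COWDeleterContext / COWDeleter code:

```cpp
#include <atomic>
#include <cstdint>
#include <thread>

// Hypothetical refcounted context, standing in for the real COWDeleterContext.
struct Context {
  std::atomic<int64_t> refcount{2};
};

void release(Context* ctx) {
  // The decrement on its own is atomic...
  ctx->refcount.fetch_sub(1);
  // ...but the decision to delete is a separate, later read. Two threads can
  // both observe a refcount of 0 here (double delete), or one thread can
  // delete the context while the other is still about to read it
  // (use-after-free). The two steps together are not protected by a mutex.
  if (ctx->refcount.load() == 0) {
    delete ctx;
  }
}

int main() {
  auto* ctx = new Context();
  std::thread t1(release, ctx);
  std::thread t2(release, ctx);
  t1.join();
  t2.join();
}
```

Even though every individual operation on the refcount is atomic, the decrement and the later decision to delete are two separate steps, which is exactly the gap described above.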
It turns out that fixing this race condition is trickier than I thought, unless I'm overlooking something. Here's a basic example to make sure we're on the same page:

```cpp
Tensor tensor = at::_lazy_clone(at::randn({10}));
int64_t num_threads = 100;
at::parallel_for(0, num_threads, 1, [&](int64_t begin, int64_t end) {
  void* ptr = tensor.mutable_data_ptr();
});
```

Inside of the parallel_for, every thread calls mutable_data_ptr(), which has to materialize the COW storage; say thread A gets there first. To materialize the data, thread A needs to replace the storage's data pointer and release the COWDeleterContext. During this time, thread A needs to somehow coordinate with the other threads to make them wait until it is finished materializing the storage.

I was considering using the mutex in the COWDeleterContext, and I was also considering having each thread increment the refcount of the COWDeleterContext first, but neither of those prevents the context from being deleted while another thread still needs it.

I thought of one solution that works, but it doesn't seem ideal: the idea is to add a mutex to the StorageImpl. Another option would be to provide a different mutex for each COW storage, kept outside the context so it can outlive it. Another option is maybe to create a global container of mutexes; each mutex would be paired with a reference to the storage it protects, so threads materializing the same storage contend on the same mutex.

Maybe there's a better solution that I haven't thought of. What do you think @ezyang? |
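To illustrate what the "global container of mutexes" option could look like, here is a rough sketch; the MaterializeMutexRegistry class and its interface are hypothetical and not part of PyTorch:

```cpp
#include <mutex>
#include <unordered_map>

// Hypothetical registry mapping a copy-on-write context (identified here by
// its address) to the mutex that guards its materialization.
class MaterializeMutexRegistry {
 public:
  // Returns the mutex associated with a given COW context pointer, creating
  // it on first use. The registry itself is guarded by its own mutex.
  std::mutex& get(const void* cow_context) {
    std::lock_guard<std::mutex> guard(registry_mutex_);
    return mutexes_[cow_context];
  }

  // Intended to be called once materialization is finished and the context is
  // gone. Deciding when this is safe (no other thread still holds a reference
  // to the returned mutex) is itself a lifetime problem.
  void erase(const void* cow_context) {
    std::lock_guard<std::mutex> guard(registry_mutex_);
    mutexes_.erase(cow_context);
  }

 private:
  std::mutex registry_mutex_;
  std::unordered_map<const void*, std::mutex> mutexes_;
};
```

Even in this sketch, deciding when an entry can safely be erased raises the same kind of lifetime question as the context itself, which is part of why none of these options feels ideal.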
My proposal will require some upfront work, but I think it is the simplest: let's ban multithreaded access to the mutable data pointer. The reasoning here is that in C++ it's typical that an object is not safe to access mutably from multiple threads concurrently, unless the object has been explicitly synchronized. at::Tensor traditionally falls in this bucket.

Now, in this particular case, we have code accessing the data pointer in parallel from multiple threads from parallel_for. But this is easy to fix: we should access the mutable pointer once from outside of the parallel_for, and close over the resulting pointer (with the side condition that writes must be non-aliasing; we can't test for this in the compiler). This means we have to audit all existing parallel_for calls and ban data_ptr calls from inside them. Not sure if there's a way to enforce this dynamically and cheaply, but that would help too. |
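A sketch of the access pattern being proposed, assuming a made-up fill_parallel kernel; only the placement of the mutable_data_ptr() call matters here:

```cpp
#include <ATen/ATen.h>
#include <ATen/Parallel.h>

void fill_parallel(at::Tensor& tensor) {
  // Problematic pattern: every worker thread calls mutable_data_ptr(), so a
  // lazily-cloned (COW) tensor may be materialized concurrently.
  //
  // at::parallel_for(0, tensor.numel(), 1, [&](int64_t begin, int64_t end) {
  //   auto* data = tensor.mutable_data_ptr<float>();
  //   for (int64_t i = begin; i < end; i++) data[i] = 1.0f;
  // });

  // Proposed pattern: materialize once on the calling thread, then close over
  // the raw pointer. Writes from different chunks must not alias.
  auto* data = tensor.mutable_data_ptr<float>();
  at::parallel_for(0, tensor.numel(), 1, [&](int64_t begin, int64_t end) {
    for (int64_t i = begin; i < end; i++) {
      data[i] = 1.0f;
    }
  });
}
```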
@ezyang, that might be the best idea. However, it seems like it's not actually good enough to only ban direct data_ptr calls. I've been looking through the places where this comes up, and the data pointer is often accessed indirectly rather than through a direct data_ptr call inside the parallel_for. So more specifically, we would need to ban anything that causes a materialization inside a parallel_for. Without a way to automatically detect this, I think it will be very difficult to enforce. It's not obvious to me how we would detect it (EDIT: Actually, I do have a possible idea I've added below). So I think it's worth asking: if we can't think of a good way to detect that something inside a parallel_for materializes a storage, is a ban really enforceable?

I do have a potential idea for detecting it:

```cpp
std::thread::id main_thread_id = std::this_thread::get_id();

C10_API void materialize_cow_storage(StorageImpl& storage) {
  TORCH_INTERNAL_ASSERT(
      std::this_thread::get_id() == main_thread_id,
      "Materializing a storage from a subthread is forbidden");
  ...
```

However, I tested this a bit and it doesn't seem very good, because the first thread in a parallel_for runs on the calling (main) thread, so materializations from that thread would still pass the check. I have another idea though: at the beginning of a parallel_for (and parallel_reduce), enable a thread-local guard on every thread that runs the loop body, and have the COW materialization function raise an error if that guard is set. |
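Here is a minimal standalone sketch of that thread-local guard idea; ParallelRegionGuard and materialize_cow_storage_checked are hypothetical names for illustration, not the actual PyTorch implementation:

```cpp
#include <stdexcept>

// Hypothetical RAII guard: set on each thread while it runs a parallel_for
// body, cleared when the body returns. A real implementation would also need
// to handle nested parallel regions correctly.
class ParallelRegionGuard {
 public:
  ParallelRegionGuard() { in_parallel_region_ = true; }
  ~ParallelRegionGuard() { in_parallel_region_ = false; }
  static bool in_parallel_region() { return in_parallel_region_; }

 private:
  static thread_local bool in_parallel_region_;
};

thread_local bool ParallelRegionGuard::in_parallel_region_ = false;

// Hypothetical stand-in for the COW materialization entry point: refuse to
// materialize while the current thread is inside a parallel region.
void materialize_cow_storage_checked(/* StorageImpl& storage */) {
  if (ParallelRegionGuard::in_parallel_region()) {
    throw std::runtime_error(
        "Materializing a copy-on-write storage inside parallel_for is forbidden");
  }
  // ... actual materialization would happen here ...
}
```

In a real version, the guard would be installed by at::parallel_for/parallel_reduce around the user lambda, so a kernel that materializes a COW input from inside the loop body fails loudly instead of racing.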
I tried this and it seems to work. Is this alright with you? |
@pytorchbot merge -f "unrelated CI failures" |
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team. |
Hey @kurtamohler, it looks like the XLA job failure is actually caused by your PR. See hud. To verify, I retried the job on the previous commit in trunk and it succeeded. I'll revert your stack to keep trunk green, unless you have any objections. Thank you. |
Thanks for the heads up @izaitsevfb, sure, you can revert these two PRs. There are several other failures elsewhere that I need to fix as well. |
@pytorchbot revert -m "breaks xla jobs" -c ignoredsignal |
@pytorchbot successfully started a revert job. Check the current status here. |
…)" This reverts commit 2ebf2c8. Reverted #119507 on behalf of https://github.com/izaitsevfb due to breaks xla jobs ([comment](#119507 (comment)))
@kurtamohler your PR has been successfully reverted. |
Part of #97856 [ghstack-poisoned]
Part of #97856

Pull Request resolved: #119507
Approved by: https://github.com/ezyang
ghstack dependencies: #120455
ghstack-source-id: 3fa863d
@pytorchbot merge |
Merge started. Your change will be merged once all checks pass (ETA: 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team. |
Stack from ghstack (oldest at bottom):
#119507 Add test to check that COW inputs are not materialized (this PR)
#120455 ... at::parallel_for/parallel_reduce

Part of #97856