
Make CUDAFuture remember and restore current device in callback #48789

Status: Closed — wants to merge 4 commits
8 changes: 8 additions & 0 deletions aten/src/ATen/cuda/CUDAFuture.h
@@ -31,6 +31,8 @@ struct TORCH_CUDA_API CUDAFuture : at::ivalue::Future {
}

void postMarkCompletedHook(const at::IValue& value) override {
currentDevice_ = c10::cuda::current_device();
Contributor:
Hmm, why are we recording the device when the future is marked as completed, instead of remembering the device that was current when the callback was inserted (through then/addCallback)?

Contributor (author):
I guess those are two different approaches to this. The "philosophy" I was following is this: then() and addCallback() are used to perform a computation after another one is complete. If we were dealing with sync operations, one would do this:

do_sth_sync()
do_sth_later()

but with async ops this needs to change and become

fut = do_sth_async()
fut.then(do_sth_later)

In the sync scenario, do_sth_later() runs in the same "environment"/"context" as do_sth_sync() (same current device, same current streams, ...). In this diff I was trying to recreate this in the async case, by "recording" the environment's state at the end of the async operation, and then "recreating" it in the callback.

Note that the approach we have for streams is somewhat similar. We do not run the callback in the streams that were current when the callback was inserted. (To be honest, we also do not run it in the streams that were current when the async op finished, as that would mean running computations in the I/O streams, but we do synchronize the "fresh" streams with those streams.)

Note also that I am not opposed to changing this behavior and replicating the "environment" that was current when the user inserted the callback. That has its own set of advantages. For streams, for example, it allows the user to very precisely control what streams the callback will use.

Contributor:
> For streams, for example, it allows the user to very precisely control what streams the callback will use.

Yep, same for device as well, especially when the callback fn is an imported function and users cannot easily change it. The current behavior should be sufficient to unblock RPC use cases. So I am OK to land this PR and modify it later if necessary.


// Extract them once and cache them for later uses.
dataPtrs_ = extractDataPtrs(value);

@@ -85,6 +87,8 @@ struct TORCH_CUDA_API CUDAFuture : at::ivalue::Future {
}
}

c10::cuda::CUDAGuard deviceGuard(currentDevice_);

callback();
};
}
@@ -109,6 +113,10 @@ struct TORCH_CUDA_API CUDAFuture : at::ivalue::Future {
// Once WorkNCCL is gone (as part of the Future and Work merge) this should be
// fixed.
protected:
// The device that was current when markCompleted was called, which we'll
// restore when invoking callbacks.
c10::DeviceIndex currentDevice_;

// The events that correspond to the completion of the async I/O kernels. They
// are recorded on the appropriate streams when the future is marked completed
// and can then be queried/waited/blocked on. There is one event for each
1 change: 1 addition & 0 deletions torch/lib/c10d/ProcessGroupNCCL.hpp
@@ -234,6 +234,7 @@ class ProcessGroupNCCL : public ProcessGroup {
return ev.device_index() == data_ptr.device().index();
}) != cudaEvents->end());
}
currentDevice_ = c10::cuda::current_device();
cudaEvents_ = std::move(cudaEvents);
dataPtrs_ = std::move(dataPtrs);
markCompleted(std::move(value));