Make CUDAFuture remember and restore current device in callback
CUDAFuture aims to "capture" the current state of CUDA-related stuff when the future is marked complete (e.g., by looking at current streams and recording events on them) and then "replicate" a similar state when users synchronize with the result of the future (by synchronizing the current streams with these events).

However, one "contextual" aspect of CUDA that we weren't capturing/replicating was the current device. This diff fixes that. Note that we can only do this for callbacks, not for the wait() method. I don't know whether such a discrepancy between the two actually makes the overall behavior _worse_; I'd love to hear people's opinions on this.

Differential Revision: [D25210335](https://our.internmc.facebook.com/intern/diff/D25210335/)

[ghstack-poisoned]
lw committed Dec 3, 2020
1 parent 5d9d272 commit ae4e7f3
Showing 2 changed files with 9 additions and 0 deletions.
8 changes: 8 additions & 0 deletions aten/src/ATen/cuda/CUDAFuture.h
@@ -44,6 +44,8 @@ struct TORCH_CUDA_API CUDAFuture : at::ivalue::Future {
}

void postMarkCompletedHook(const at::IValue& value) override {
currentDevice_ = c10::cuda::current_device();

// Extract them once and cache them for later use.
dataPtrs_ = extractDataPtrs(value);

@@ -98,6 +100,8 @@ struct TORCH_CUDA_API CUDAFuture : at::ivalue::Future {
}
}

c10::cuda::CUDAGuard deviceGuard(currentDevice_);

callback();
};
}
@@ -122,6 +126,10 @@ struct TORCH_CUDA_API CUDAFuture : at::ivalue::Future {
// Once WorkNCCL is gone (as part of the Future and Work merge) this should be
// fixed.
protected:
// The device that was current when markCompleted was called, which we'll
// restore when invoking callbacks.
c10::DeviceIndex currentDevice_;

std::shared_ptr<std::vector<at::cuda::CUDAEvent>> cudaEvents_;
std::vector<std::reference_wrapper<const at::DataPtr>> dataPtrs_;

1 change: 1 addition & 0 deletions torch/lib/c10d/ProcessGroupNCCL.hpp
@@ -234,6 +234,7 @@ class ProcessGroupNCCL : public ProcessGroup {
return ev.device_index() == data_ptr.device().index();
}) != cudaEvents->end());
}
currentDevice_ = c10::cuda::current_device();
cudaEvents_ = std::move(cudaEvents);
dataPtrs_ = std::move(dataPtrs);
markCompleted(std::move(value));
