Implementation of torch::cuda::synchronize #50072

Closed
wants to merge 2 commits
5 changes: 5 additions & 0 deletions aten/src/ATen/cuda/detail/CUDAHooks.cpp
@@ -369,6 +369,11 @@ int CUDAHooks::getNumGPUs() const {
  return at::cuda::device_count();
}

void CUDAHooks::deviceSynchronize(int64_t device_index) const {
  at::DeviceGuard device_guard(at::Device(at::DeviceType::CUDA, device_index));
  c10::cuda::device_synchronize();
}

// Sigh, the registry doesn't support namespaces :(
using at::CUDAHooksRegistry;
using at::RegistererCUDAHooksRegistry;
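For reference, a minimal sketch (not part of this PR, error handling omitted, sync_device_sketch is a hypothetical name) of what the guarded synchronize above amounts to in raw CUDA runtime calls: at::DeviceGuard switches the current device for the scope and restores it on exit, and c10::cuda::device_synchronize() wraps cudaDeviceSynchronize().

#include <cuda_runtime.h>

void sync_device_sketch(int device_index) {
  int prev = 0;
  cudaGetDevice(&prev);          // remember the current device (DeviceGuard does this on construction)
  cudaSetDevice(device_index);   // switch to the requested device
  cudaDeviceSynchronize();       // block the host until all work on that device finishes
  cudaSetDevice(prev);           // DeviceGuard restores the previous device when it leaves scope
}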
1 change: 1 addition & 0 deletions aten/src/ATen/cuda/detail/CUDAHooks.h
@@ -38,6 +38,7 @@ struct CUDAHooks : public at::CUDAHooksInterface {
  int64_t cuFFTGetPlanCacheSize(int64_t device_index) const override;
  void cuFFTClearPlanCache(int64_t device_index) const override;
  int getNumGPUs() const override;
  void deviceSynchronize(int64_t device_index) const override;
};

}}} // at::cuda::detail
4 changes: 4 additions & 0 deletions aten/src/ATen/detail/CUDAHooksInterface.h
@@ -181,6 +181,10 @@ struct TORCH_API CUDAHooksInterface {
  virtual int getNumGPUs() const {
    return 0;
  }

  virtual void deviceSynchronize(int64_t device_index) const {
    TORCH_CHECK(false, "Cannot synchronize CUDA device without ATen_cuda library. ", CUDA_HELP);
  }
};

// NB: dummy argument to suppress "ISO C++11 requires at least one argument
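The default implementation above is the CPU-only fallback: builds without the ATen_cuda library see this virtual stub and fail with TORCH_CHECK, while CUDA builds register the CUDAHooks override. A simplified sketch of that pattern (struct names shortened and hypothetical; the real code registers the override through at::CUDAHooksRegistry rather than naming it directly):

#include <cstdint>
#include <stdexcept>

struct HooksInterfaceSketch {
  virtual ~HooksInterfaceSketch() = default;
  // Seen by CPU-only builds; fails loudly if anyone asks for a CUDA sync.
  virtual void deviceSynchronize(int64_t /*device_index*/) const {
    throw std::runtime_error("Cannot synchronize CUDA device without ATen_cuda library.");
  }
};

struct CudaHooksSketch : HooksInterfaceSketch {
  // Only compiled and registered when the CUDA library is linked in.
  void deviceSynchronize(int64_t device_index) const override {
    // switch to device_index, then synchronize it
  }
};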
3 changes: 3 additions & 0 deletions torch/csrc/api/include/torch/cuda.h
@@ -23,5 +23,8 @@ void TORCH_API manual_seed(uint64_t seed);
/// Sets the seed for all available GPUs.
void TORCH_API manual_seed_all(uint64_t seed);

/// Waits for all kernels in all streams on a CUDA device to complete.
void TORCH_API synchronize(int64_t device_index = -1);

} // namespace cuda
} // namespace torch
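A hypothetical usage sketch of the new public API (assumes a CUDA-enabled build with at least one visible GPU; the default device_index of -1 is taken here to mean the current device):

#include <torch/torch.h>

int main() {
  if (!torch::cuda::is_available()) {
    return 0;
  }
  auto a = torch::randn({1024, 1024}, torch::kCUDA);
  auto b = a.matmul(a);           // kernel is enqueued asynchronously on the current stream
  torch::cuda::synchronize();     // default -1: wait for work on the current device
  torch::cuda::synchronize(0);    // or wait on an explicit device index
  return 0;
}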
9 changes: 9 additions & 0 deletions torch/csrc/api/src/cuda.cpp
@@ -1,6 +1,7 @@
#include <torch/cuda.h>

#include <ATen/Context.h>
#include <c10/core/DeviceGuard.h>

#include <cstddef>

@@ -49,5 +50,13 @@ void manual_seed_all(uint64_t seed) {
}
}

void synchronize(int64_t device_index) {
  TORCH_CHECK(is_available(), "No CUDA GPUs are available");
  int64_t num_gpus = cuda::device_count();
  TORCH_CHECK(device_index == -1 || device_index < num_gpus,
      "Device index out of range: ", device_index);
  at::detail::getCUDAHooks().deviceSynchronize(device_index);
}

} // namespace cuda
} // namespace torch
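Both TORCH_CHECK calls above throw c10::Error, which derives from std::exception, so a caller can trap a bad argument; a small sketch (error message wording assumed from the checks above):

#include <torch/torch.h>
#include <iostream>

int main() {
  try {
    torch::cuda::synchronize(/*device_index=*/1024);  // well beyond any real device_count()
  } catch (const c10::Error& e) {
    // Message contains "No CUDA GPUs are available" on CPU-only machines,
    // or "Device index out of range: 1024" when CUDA is available.
    std::cerr << e.what() << std::endl;
  }
  return 0;
}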