Implementation of torch::cuda::synchronize (#50072)
Summary:
Adding `torch::cuda::synchronize()` to libtorch. Note that the implementation here adds a new method to the `CUDAHooksInterface`. An alternative that was suggested to me is to add a method to the `DeviceGuard` interface.
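For illustration, a minimal usage sketch of the new API from libtorch C++ code might look like the following (this example is not part of the diff; it assumes a CUDA-enabled build):

```cpp
#include <torch/torch.h>

int main() {
  if (torch::cuda::is_available()) {
    // Kernel launches are asynchronous with respect to the host.
    auto x = torch::randn({1024, 1024}, torch::kCUDA);
    auto y = x.matmul(x);

    // Block the host until all streams on the current device finish
    // (device_index defaults to -1, i.e. the current device).
    torch::cuda::synchronize();

    // Or synchronize a specific device by index.
    torch::cuda::synchronize(0);
  }
  return 0;
}
```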

Fixes #47722

Pull Request resolved: #50072

Reviewed By: H-Huang

Differential Revision: D25804342

Pulled By: jbschlosser

fbshipit-source-id: 45aa61d7c6fbfd3178caf2eb5ec053d6c01b5a43
jbschlosser authored and facebook-github-bot committed Jan 6, 2021
1 parent e606e60 commit 7d9eb6c
Showing 5 changed files with 22 additions and 0 deletions.
5 changes: 5 additions & 0 deletions aten/src/ATen/cuda/detail/CUDAHooks.cpp
@@ -369,6 +369,11 @@ int CUDAHooks::getNumGPUs() const {
  return at::cuda::device_count();
}

void CUDAHooks::deviceSynchronize(int64_t device_index) const {
  at::DeviceGuard device_guard(at::Device(at::DeviceType::CUDA, device_index));
  c10::cuda::device_synchronize();
}

// Sigh, the registry doesn't support namespaces :(
using at::CUDAHooksRegistry;
using at::RegistererCUDAHooksRegistry;
1 change: 1 addition & 0 deletions aten/src/ATen/cuda/detail/CUDAHooks.h
@@ -38,6 +38,7 @@ struct CUDAHooks : public at::CUDAHooksInterface {
  int64_t cuFFTGetPlanCacheSize(int64_t device_index) const override;
  void cuFFTClearPlanCache(int64_t device_index) const override;
  int getNumGPUs() const override;
  void deviceSynchronize(int64_t device_index) const override;
};

}}} // at::cuda::detail
4 changes: 4 additions & 0 deletions aten/src/ATen/detail/CUDAHooksInterface.h
@@ -181,6 +181,10 @@ struct TORCH_API CUDAHooksInterface {
  virtual int getNumGPUs() const {
    return 0;
  }

  virtual void deviceSynchronize(int64_t device_index) const {
    TORCH_CHECK(false, "Cannot synchronize CUDA device without ATen_cuda library. ", CUDA_HELP);
  }
};

// NB: dummy argument to suppress "ISO C++11 requires at least one argument
3 changes: 3 additions & 0 deletions torch/csrc/api/include/torch/cuda.h
@@ -23,5 +23,8 @@ void TORCH_API manual_seed(uint64_t seed);
/// Sets the seed for all available GPUs.
void TORCH_API manual_seed_all(uint64_t seed);

/// Waits for all kernels in all streams on a CUDA device to complete.
void TORCH_API synchronize(int64_t device_index = -1);

} // namespace cuda
} // namespace torch
9 changes: 9 additions & 0 deletions torch/csrc/api/src/cuda.cpp
@@ -1,6 +1,7 @@
#include <torch/cuda.h>

#include <ATen/Context.h>
#include <c10/core/DeviceGuard.h>

#include <cstddef>

@@ -49,5 +50,13 @@ void manual_seed_all(uint64_t seed) {
  }
}

void synchronize(int64_t device_index) {
  TORCH_CHECK(is_available(), "No CUDA GPUs are available");
  int64_t num_gpus = cuda::device_count();
  TORCH_CHECK(device_index == -1 || device_index < num_gpus,
      "Device index out of range: ", device_index);
  at::detail::getCUDAHooks().deviceSynchronize(device_index);
}

} // namespace cuda
} // namespace torch
