Skip to content

Commit

Permalink
[1.8] Do not print warning if CUDA driver not found (#51806) (#52050)
Browse files Browse the repository at this point in the history
Summary:
It frequently happens when PyTorch compiled with CUDA support is installed on machine that does not have NVIDIA GPUs.

Fixes #47038

Pull Request resolved: #51806

Reviewed By: ezyang

Differential Revision: D26285827

Pulled By: malfet

fbshipit-source-id: 9fd5e690d0135a2b219c1afa803fb69de9729f5e
  • Loading branch information
malfet committed Feb 12, 2021
1 parent f071020 commit c307a3f
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions c10/cuda/CUDAFunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ int32_t driver_version() {
return driver_version;
}

int device_count_impl() {
int device_count_impl(bool fail_if_no_driver) {
int count;
auto err = cudaGetDeviceCount(&count);
if (err == cudaSuccess) {
Expand All @@ -34,6 +34,11 @@ int device_count_impl() {
case cudaErrorInsufficientDriver: {
auto version = driver_version();
if (version <= 0) {
if (!fail_if_no_driver) {
// No CUDA driver means no devices
count = 0;
break;
}
TORCH_CHECK(
false,
"Found no NVIDIA driver on your system. Please check that you "
Expand Down Expand Up @@ -95,9 +100,9 @@ DeviceIndex device_count() noexcept {
// initialize number of devices only once
static int count = []() {
try {
auto result = device_count_impl();
auto result = device_count_impl(/*fail_if_no_driver=*/false);
TORCH_INTERNAL_ASSERT(result <= std::numeric_limits<DeviceIndex>::max(), "Too many CUDA devices, DeviceIndex overflowed");
return device_count_impl();
return result;
} catch (const c10::Error& ex) {
// We don't want to fail, but still log the warning
// msg() returns the message without the stack trace
Expand All @@ -110,7 +115,7 @@ DeviceIndex device_count() noexcept {

DeviceIndex device_count_ensure_non_zero() {
// Call the implementation every time to throw the exception
int count = device_count_impl();
int count = device_count_impl(/*fail_if_no_driver=*/true);
// Zero gpus doesn't produce a warning in `device_count` but we fail here
TORCH_CHECK(count, "No CUDA GPUs are available");
return static_cast<DeviceIndex>(count);
Expand Down

0 comments on commit c307a3f

Please sign in to comment.