diff --git a/torch/testing/_internal/common_cuda.py b/torch/testing/_internal/common_cuda.py index b736055db61a..af6b27f322ff 100644 --- a/torch/testing/_internal/common_cuda.py +++ b/torch/testing/_internal/common_cuda.py @@ -6,6 +6,7 @@ from torch.testing._internal.common_utils import TEST_NUMBA import inspect import contextlib +from distutils.version import LooseVersion TEST_CUDA = torch.cuda.is_available() @@ -15,7 +16,7 @@ TEST_CUDNN = TEST_CUDA and torch.backends.cudnn.is_acceptable(torch.tensor(1., device=CUDA_DEVICE)) TEST_CUDNN_VERSION = torch.backends.cudnn.version() if TEST_CUDNN else 0 -CUDA11OrLater = torch.version.cuda and float(torch.version.cuda) >= 11 +CUDA11OrLater = torch.version.cuda and LooseVersion(torch.version.cuda) >= "11.0.0" CUDA9 = torch.version.cuda and torch.version.cuda.startswith('9.') SM53OrLater = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (5, 3)