diff --git a/aten/src/ATen/native/cuda/jit_utils.cpp b/aten/src/ATen/native/cuda/jit_utils.cpp index f86b624b84f7..bc0e8546e318 100644 --- a/aten/src/ATen/native/cuda/jit_utils.cpp +++ b/aten/src/ATen/native/cuda/jit_utils.cpp @@ -894,6 +894,8 @@ void codegenOutputQuery( max_dev_version = CUDAVersion(7, 5); } else if (nvrtc_version == CUDAVersion(11, 0)) { // 11.0 supports 3-8.0 max_dev_version = CUDAVersion(8, 0); + } else if (nvrtc_major == 11 && nvrtc_minor < 8) { + max_dev_version = CUDAVersion(8, 6); } else { // If the driver version is unknown (i.e. newer than this code) // assume the driver supports this device diff --git a/torch/csrc/jit/codegen/fuser/cuda/fused_kernel.cpp b/torch/csrc/jit/codegen/fuser/cuda/fused_kernel.cpp index 85bd74bfdbae..85de541f4ba7 100644 --- a/torch/csrc/jit/codegen/fuser/cuda/fused_kernel.cpp +++ b/torch/csrc/jit/codegen/fuser/cuda/fused_kernel.cpp @@ -64,6 +64,8 @@ void codegenOutputQuery( max_dev_version = CudaVersion(7, 5); } else if (nvrtc_version == CudaVersion(11, 0)) { // 11.0 supports 3-8.0 max_dev_version = CudaVersion(8, 0); + } else if (nvrtc_version.first == 11 && nvrtc_version.second < 8) { + max_dev_version = CudaVersion(8, 6); } else { // If the driver version is unknown (i.e. newer than this code) // assume the driver supports this device