From f78a65a9c62e19d08a9787ebb181fff3fcacf0cc Mon Sep 17 00:00:00 2001 From: ashishfarmer Date: Thu, 12 Nov 2020 17:44:41 +0000 Subject: [PATCH] fix GET_THREADS() for ROCm --- torchvision/csrc/cuda/DeformConv_cuda.cu | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torchvision/csrc/cuda/DeformConv_cuda.cu b/torchvision/csrc/cuda/DeformConv_cuda.cu index c6e9a9278ed..507532e7184 100644 --- a/torchvision/csrc/cuda/DeformConv_cuda.cu +++ b/torchvision/csrc/cuda/DeformConv_cuda.cu @@ -81,6 +81,9 @@ const int kMaxParallelImgs = 32; inline unsigned int GET_THREADS() { +#ifdef __HIP_PLATFORM_HCC__ + return 256; +#endif if (at::cuda::getCurrentDeviceProperties()->major >= 6) { return 1024; }