diff --git a/torchvision/csrc/cuda/DeformConv_cuda.cu b/torchvision/csrc/cuda/DeformConv_cuda.cu index c6e9a9278ed..507532e7184 100644 --- a/torchvision/csrc/cuda/DeformConv_cuda.cu +++ b/torchvision/csrc/cuda/DeformConv_cuda.cu @@ -81,6 +81,9 @@ const int kMaxParallelImgs = 32; inline unsigned int GET_THREADS() { +#ifdef __HIP_PLATFORM_HCC__ + return 256; +#endif if (at::cuda::getCurrentDeviceProperties()->major >= 6) { return 1024; }