From a50db0a0459f86b2b657743c92cf89f711fd6485 Mon Sep 17 00:00:00 2001
From: puhuk
Date: Fri, 15 Apr 2022 20:14:14 +0900
Subject: [PATCH] Replace usages of atomicAdd with gpuAtomicAdd

To resolve issue #5815
---
 torchvision/csrc/ops/cuda/deform_conv2d_kernel.cu | 2 +-
 torchvision/csrc/ops/cuda/ps_roi_align_kernel.cu  | 8 ++++----
 torchvision/csrc/ops/cuda/ps_roi_pool_kernel.cu   | 2 +-
 torchvision/csrc/ops/cuda/roi_align_kernel.cu     | 8 ++++----
 torchvision/csrc/ops/cuda/roi_pool_kernel.cu      | 2 +-
 5 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/torchvision/csrc/ops/cuda/deform_conv2d_kernel.cu b/torchvision/csrc/ops/cuda/deform_conv2d_kernel.cu
index 6389d80117f..d28d332b41e 100644
--- a/torchvision/csrc/ops/cuda/deform_conv2d_kernel.cu
+++ b/torchvision/csrc/ops/cuda/deform_conv2d_kernel.cu
@@ -385,7 +385,7 @@ __global__ void deformable_col2im_kernel(
           std::abs(y - yp) < 1 && std::abs(x - xp) < 1) {
         index_t grad_pos = ((b * channels + c) * height + yp) * width + xp;
         scalar_t weight = (1 - std::abs(y - yp)) * (1 - std::abs(x - xp));
-        atomicAdd(grad_im + grad_pos, mask_value * weight * col[index]);
+        gpuAtomicAdd(grad_im + grad_pos, mask_value * weight * col[index]);
       }
     }
   }
diff --git a/torchvision/csrc/ops/cuda/ps_roi_align_kernel.cu b/torchvision/csrc/ops/cuda/ps_roi_align_kernel.cu
index bcb6cd5783e..b9c624b09c8 100644
--- a/torchvision/csrc/ops/cuda/ps_roi_align_kernel.cu
+++ b/torchvision/csrc/ops/cuda/ps_roi_align_kernel.cu
@@ -285,10 +285,10 @@ __global__ void ps_roi_align_backward_kernel_impl(
         T g4 = grad_output_this_bin * w4 / count;
 
         if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
-          atomicAdd(grad_input_offset + y_low * width + x_low, g1);
-          atomicAdd(grad_input_offset + y_low * width + x_high, g2);
-          atomicAdd(grad_input_offset + y_high * width + x_low, g3);
-          atomicAdd(grad_input_offset + y_high * width + x_high, g4);
+          gpuAtomicAdd(grad_input_offset + y_low * width + x_low, g1);
+          gpuAtomicAdd(grad_input_offset + y_low * width + x_high, g2);
+          gpuAtomicAdd(grad_input_offset + y_high * width + x_low, g3);
+          gpuAtomicAdd(grad_input_offset + y_high * width + x_high, g4);
         } // if
       } // ix
     } // iy
diff --git a/torchvision/csrc/ops/cuda/ps_roi_pool_kernel.cu b/torchvision/csrc/ops/cuda/ps_roi_pool_kernel.cu
index b0a156104ec..917fff03e8d 100644
--- a/torchvision/csrc/ops/cuda/ps_roi_pool_kernel.cu
+++ b/torchvision/csrc/ops/cuda/ps_roi_pool_kernel.cu
@@ -131,7 +131,7 @@ __global__ void ps_roi_pool_backward_kernel_impl(
     for (int h = hstart; h < hend; ++h) {
       for (int w = wstart; w < wend; ++w) {
         int grad_input_index = h * width + w;
-        atomicAdd(grad_input_offset + grad_input_index, diff_val);
+        gpuAtomicAdd(grad_input_offset + grad_input_index, diff_val);
       }
     }
   }
diff --git a/torchvision/csrc/ops/cuda/roi_align_kernel.cu b/torchvision/csrc/ops/cuda/roi_align_kernel.cu
index 0bc87da8fd9..f1f886c4738 100644
--- a/torchvision/csrc/ops/cuda/roi_align_kernel.cu
+++ b/torchvision/csrc/ops/cuda/roi_align_kernel.cu
@@ -301,13 +301,13 @@ __global__ void roi_align_backward_kernel_impl(
         T g4 = grad_output_this_bin * w4 / count;
 
         if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
-          atomicAdd(
-              offset_grad_input + y_low * width + x_low, static_cast<T>(g1));
-          atomicAdd(
-              offset_grad_input + y_low * width + x_high, static_cast<T>(g2));
-          atomicAdd(
-              offset_grad_input + y_high * width + x_low, static_cast<T>(g3));
-          atomicAdd(
-              offset_grad_input + y_high * width + x_high, static_cast<T>(g4));
+          gpuAtomicAdd(
+              offset_grad_input + y_low * width + x_low, static_cast<T>(g1));
+          gpuAtomicAdd(
+              offset_grad_input + y_low * width + x_high, static_cast<T>(g2));
+          gpuAtomicAdd(
+              offset_grad_input + y_high * width + x_low, static_cast<T>(g3));
+          gpuAtomicAdd(
+              offset_grad_input + y_high * width + x_high, static_cast<T>(g4));
         } // if
       } // ix
diff --git a/torchvision/csrc/ops/cuda/roi_pool_kernel.cu b/torchvision/csrc/ops/cuda/roi_pool_kernel.cu
index f41f530cf12..e29c4438ed4 100644
--- a/torchvision/csrc/ops/cuda/roi_pool_kernel.cu
+++ b/torchvision/csrc/ops/cuda/roi_pool_kernel.cu
@@ -113,7 +113,7 @@ __global__ void roi_pool_backward_kernel_impl(
     int argmax = argmax_data_offset[ph * pooled_width + pw];
 
     if (argmax != -1) {
-      atomicAdd(
+      gpuAtomicAdd(
           grad_input_offset + argmax,
           static_cast<T>(
               grad_output[output_offset + ph * h_stride + pw * w_stride]));