Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion torchvision/csrc/ops/cuda/deform_conv2d_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ __global__ void deformable_col2im_kernel(
std::abs(y - yp) < 1 && std::abs(x - xp) < 1) {
index_t grad_pos = ((b * channels + c) * height + yp) * width + xp;
scalar_t weight = (1 - std::abs(y - yp)) * (1 - std::abs(x - xp));
atomicAdd(grad_im + grad_pos, mask_value * weight * col[index]);
gpuAtomicAdd(grad_im + grad_pos, mask_value * weight * col[index]);
}
}
}
Expand Down
8 changes: 4 additions & 4 deletions torchvision/csrc/ops/cuda/ps_roi_align_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -285,10 +285,10 @@ __global__ void ps_roi_align_backward_kernel_impl(
T g4 = grad_output_this_bin * w4 / count;

if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
atomicAdd(grad_input_offset + y_low * width + x_low, g1);
atomicAdd(grad_input_offset + y_low * width + x_high, g2);
atomicAdd(grad_input_offset + y_high * width + x_low, g3);
atomicAdd(grad_input_offset + y_high * width + x_high, g4);
gpuAtomicAdd(grad_input_offset + y_low * width + x_low, g1);
gpuAtomicAdd(grad_input_offset + y_low * width + x_high, g2);
gpuAtomicAdd(grad_input_offset + y_high * width + x_low, g3);
gpuAtomicAdd(grad_input_offset + y_high * width + x_high, g4);
} // if
} // ix
} // iy
Expand Down
2 changes: 1 addition & 1 deletion torchvision/csrc/ops/cuda/ps_roi_pool_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ __global__ void ps_roi_pool_backward_kernel_impl(
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
int grad_input_index = h * width + w;
atomicAdd(grad_input_offset + grad_input_index, diff_val);
gpuAtomicAdd(grad_input_offset + grad_input_index, diff_val);
}
}
}
Expand Down
8 changes: 4 additions & 4 deletions torchvision/csrc/ops/cuda/roi_align_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -301,13 +301,13 @@ __global__ void roi_align_backward_kernel_impl(
T g4 = grad_output_this_bin * w4 / count;

if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
atomicAdd(
gpuAtomicAdd(
offset_grad_input + y_low * width + x_low, static_cast<T>(g1));
atomicAdd(
gpuAtomicAdd(
offset_grad_input + y_low * width + x_high, static_cast<T>(g2));
atomicAdd(
gpuAtomicAdd(
offset_grad_input + y_high * width + x_low, static_cast<T>(g3));
atomicAdd(
gpuAtomicAdd(
offset_grad_input + y_high * width + x_high, static_cast<T>(g4));
} // if
} // ix
Expand Down
2 changes: 1 addition & 1 deletion torchvision/csrc/ops/cuda/roi_pool_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ __global__ void roi_pool_backward_kernel_impl(
int argmax = argmax_data_offset[ph * pooled_width + pw];

if (argmax != -1) {
atomicAdd(
gpuAtomicAdd(
grad_input_offset + argmax,
static_cast<T>(
grad_output[output_offset + ph * h_stride + pw * w_stride]));
Expand Down