diff --git a/torchvision/csrc/cpu/ROIAlign_cpu.cpp b/torchvision/csrc/cpu/ROIAlign_cpu.cpp index 75d3e7a90b4..03545883a69 100644 --- a/torchvision/csrc/cpu/ROIAlign_cpu.cpp +++ b/torchvision/csrc/cpu/ROIAlign_cpu.cpp @@ -141,9 +141,13 @@ void ROIAlignForward( T roi_end_w = offset_rois[3] * spatial_scale - offset; T roi_end_h = offset_rois[4] * spatial_scale - offset; - // Force malformed ROIs to be 1x1 - T roi_width = std::max(roi_end_w - roi_start_w, (T)1.); - T roi_height = std::max(roi_end_h - roi_start_h, (T)1.); + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + if (!aligned) { + // Force malformed ROIs to be 1x1 + roi_width = std::max(roi_width, (T)1.); + roi_height = std::max(roi_height, (T)1.); + } T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); @@ -309,9 +313,13 @@ void ROIAlignBackward( T roi_end_w = offset_rois[3] * spatial_scale - offset; T roi_end_h = offset_rois[4] * spatial_scale - offset; - // Force malformed ROIs to be 1x1 - T roi_width = std::max(roi_end_w - roi_start_w, (T)1.); - T roi_height = std::max(roi_end_h - roi_start_h, (T)1.); + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + if (!aligned) { + // Force malformed ROIs to be 1x1 + roi_width = std::max(roi_width, (T)1.); + roi_height = std::max(roi_height, (T)1.); + } T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); diff --git a/torchvision/csrc/cuda/ROIAlign_cuda.cu b/torchvision/csrc/cuda/ROIAlign_cuda.cu index 8f8bcd10d48..84a8ba4e3bd 100644 --- a/torchvision/csrc/cuda/ROIAlign_cuda.cu +++ b/torchvision/csrc/cuda/ROIAlign_cuda.cu @@ -91,9 +91,13 @@ __global__ void RoIAlignForward( T roi_end_w = offset_rois[3] * spatial_scale - offset; T roi_end_h = offset_rois[4] * spatial_scale - offset; - // Force malformed ROIs to be 1x1 - T roi_width = max(roi_end_w - roi_start_w, (T)1.); - T roi_height = max(roi_end_h - roi_start_h, (T)1.); + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + if (!aligned) { + // Force malformed ROIs to be 1x1 + roi_width = max(roi_width, (T)1.); + roi_height = max(roi_height, (T)1.); + } T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); @@ -229,9 +233,13 @@ __global__ void RoIAlignBackward( T roi_end_w = offset_rois[3] * spatial_scale - offset; T roi_end_h = offset_rois[4] * spatial_scale - offset; - // Force malformed ROIs to be 1x1 - T roi_width = max(roi_end_w - roi_start_w, (T)1.); - T roi_height = max(roi_end_h - roi_start_h, (T)1.); + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + if (!aligned) { + // Force malformed ROIs to be 1x1 + roi_width = max(roi_width, (T)1.); + roi_height = max(roi_height, (T)1.); + } T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); T bin_size_w = static_cast(roi_width) / static_cast(pooled_width);