From 991f2f9c01f52de2c185f52efc8a10615261751f Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Thu, 9 Jul 2020 12:25:48 +0200 Subject: [PATCH] Fix wrong clamping in RoIAlign with aligned=True (#2438) * Fix wrong clamping in RoIAlign with aligned=True * Fix silly mistake * Bugfix pointed out during code-review --- torchvision/csrc/cpu/ROIAlign_cpu.cpp | 20 ++++++++++++++------ torchvision/csrc/cuda/ROIAlign_cuda.cu | 20 ++++++++++++++------ 2 files changed, 28 insertions(+), 12 deletions(-) diff --git a/torchvision/csrc/cpu/ROIAlign_cpu.cpp b/torchvision/csrc/cpu/ROIAlign_cpu.cpp index 75d3e7a90b4..03545883a69 100644 --- a/torchvision/csrc/cpu/ROIAlign_cpu.cpp +++ b/torchvision/csrc/cpu/ROIAlign_cpu.cpp @@ -141,9 +141,13 @@ void ROIAlignForward( T roi_end_w = offset_rois[3] * spatial_scale - offset; T roi_end_h = offset_rois[4] * spatial_scale - offset; - // Force malformed ROIs to be 1x1 - T roi_width = std::max(roi_end_w - roi_start_w, (T)1.); - T roi_height = std::max(roi_end_h - roi_start_h, (T)1.); + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + if (!aligned) { + // Force malformed ROIs to be 1x1 + roi_width = std::max(roi_width, (T)1.); + roi_height = std::max(roi_height, (T)1.); + } T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); @@ -309,9 +313,13 @@ void ROIAlignBackward( T roi_end_w = offset_rois[3] * spatial_scale - offset; T roi_end_h = offset_rois[4] * spatial_scale - offset; - // Force malformed ROIs to be 1x1 - T roi_width = std::max(roi_end_w - roi_start_w, (T)1.); - T roi_height = std::max(roi_end_h - roi_start_h, (T)1.); + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + if (!aligned) { + // Force malformed ROIs to be 1x1 + roi_width = std::max(roi_width, (T)1.); + roi_height = std::max(roi_height, (T)1.); + } T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); diff --git a/torchvision/csrc/cuda/ROIAlign_cuda.cu b/torchvision/csrc/cuda/ROIAlign_cuda.cu index 8f8bcd10d48..84a8ba4e3bd 100644 --- a/torchvision/csrc/cuda/ROIAlign_cuda.cu +++ b/torchvision/csrc/cuda/ROIAlign_cuda.cu @@ -91,9 +91,13 @@ __global__ void RoIAlignForward( T roi_end_w = offset_rois[3] * spatial_scale - offset; T roi_end_h = offset_rois[4] * spatial_scale - offset; - // Force malformed ROIs to be 1x1 - T roi_width = max(roi_end_w - roi_start_w, (T)1.); - T roi_height = max(roi_end_h - roi_start_h, (T)1.); + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + if (!aligned) { + // Force malformed ROIs to be 1x1 + roi_width = max(roi_width, (T)1.); + roi_height = max(roi_height, (T)1.); + } T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); @@ -229,9 +233,13 @@ __global__ void RoIAlignBackward( T roi_end_w = offset_rois[3] * spatial_scale - offset; T roi_end_h = offset_rois[4] * spatial_scale - offset; - // Force malformed ROIs to be 1x1 - T roi_width = max(roi_end_w - roi_start_w, (T)1.); - T roi_height = max(roi_end_h - roi_start_h, (T)1.); + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + if (!aligned) { + // Force malformed ROIs to be 1x1 + roi_width = max(roi_width, (T)1.); + roi_height = max(roi_height, (T)1.); + } T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); T bin_size_w = static_cast(roi_width) / static_cast(pooled_width);