Skip to content

Commit

Permalink
Fix wrong clamping in RoIAlign with aligned=True (#2438) (#2445)
Browse files Browse the repository at this point in the history
Summary:
* Fix wrong clamping in RoIAlign with aligned=True

* Fix silly mistake

* Bugfix pointed out during code-review

Pull Request resolved: #2445

Reviewed By: zhangguanheng66

Differential Revision: D22458789

Pulled By: fmassa

fbshipit-source-id: cbe4d7df64b56b2c0b44c21c3fb155d40c74e057
  • Loading branch information
fmassa authored and facebook-github-bot committed Jul 9, 2020
1 parent 496ed83 commit 3ea1969
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 12 deletions.
20 changes: 14 additions & 6 deletions torchvision/csrc/cpu/ROIAlign_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,13 @@ void ROIAlignForward(
T roi_end_w = offset_rois[3] * spatial_scale - offset;
T roi_end_h = offset_rois[4] * spatial_scale - offset;

// Force malformed ROIs to be 1x1
T roi_width = std::max(roi_end_w - roi_start_w, (T)1.);
T roi_height = std::max(roi_end_h - roi_start_h, (T)1.);
T roi_width = roi_end_w - roi_start_w;
T roi_height = roi_end_h - roi_start_h;
if (!aligned) {
// Force malformed ROIs to be 1x1
roi_width = std::max(roi_width, (T)1.);
roi_height = std::max(roi_height, (T)1.);
}

T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
Expand Down Expand Up @@ -309,9 +313,13 @@ void ROIAlignBackward(
T roi_end_w = offset_rois[3] * spatial_scale - offset;
T roi_end_h = offset_rois[4] * spatial_scale - offset;

// Force malformed ROIs to be 1x1
T roi_width = std::max(roi_end_w - roi_start_w, (T)1.);
T roi_height = std::max(roi_end_h - roi_start_h, (T)1.);
T roi_width = roi_end_w - roi_start_w;
T roi_height = roi_end_h - roi_start_h;
if (!aligned) {
// Force malformed ROIs to be 1x1
roi_width = std::max(roi_width, (T)1.);
roi_height = std::max(roi_height, (T)1.);
}

T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
Expand Down
20 changes: 14 additions & 6 deletions torchvision/csrc/cuda/ROIAlign_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,13 @@ __global__ void RoIAlignForward(
T roi_end_w = offset_rois[3] * spatial_scale - offset;
T roi_end_h = offset_rois[4] * spatial_scale - offset;

// Force malformed ROIs to be 1x1
T roi_width = max(roi_end_w - roi_start_w, (T)1.);
T roi_height = max(roi_end_h - roi_start_h, (T)1.);
T roi_width = roi_end_w - roi_start_w;
T roi_height = roi_end_h - roi_start_h;
if (!aligned) {
// Force malformed ROIs to be 1x1
roi_width = max(roi_width, (T)1.);
roi_height = max(roi_height, (T)1.);
}

T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
Expand Down Expand Up @@ -229,9 +233,13 @@ __global__ void RoIAlignBackward(
T roi_end_w = offset_rois[3] * spatial_scale - offset;
T roi_end_h = offset_rois[4] * spatial_scale - offset;

// Force malformed ROIs to be 1x1
T roi_width = max(roi_end_w - roi_start_w, (T)1.);
T roi_height = max(roi_end_h - roi_start_h, (T)1.);
T roi_width = roi_end_w - roi_start_w;
T roi_height = roi_end_h - roi_start_h;
if (!aligned) {
// Force malformed ROIs to be 1x1
roi_width = max(roi_width, (T)1.);
roi_height = max(roi_height, (T)1.);
}

T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
Expand Down

0 comments on commit 3ea1969

Please sign in to comment.