I'm working on the CPU fallback for the position-sensitive ROI Pool/Align layers, as in PR #1259.
The current version can be found here: https://github.com/LukasBommes/vision/tree/ps_roi_align_cpu
The CPU version of PSROIPool already works fine: all tests in test/test_ops.py run smoothly, including the newly added tests for the CPU version.
With PSROIAlign, however, I hit the problem that the CPU version is always executed, regardless of whether the input tensors live on 'cuda' or 'cpu'.
If I run the test case below, I get the following output:
```
Testing test_ps_roi_align_basic_cuda
before creating PSROIAlign Layer
after creating PSROIAlign Layer
x_is_cuda: True
rois_is_cuda: True
Tensor is CUDA in PSROIAlign_forward
Using CUDA version of PSROIAlign_forward
Executing CPU version of PSROIAlignForward
Segmentation fault (core dumped)
```
```python
@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
def test_ps_roi_align_basic_cuda(self):
    print("Testing test_ps_roi_align_basic_cuda")
    device = torch.device('cuda')
    pool_size = 3
    x = torch.rand(1, 2 * (pool_size ** 2), 7, 7, dtype=self.dtype, device=device)
    rois = torch.tensor([[0, 0, 0, 5, 5]],  # format is (batch_idx, x1, y1, x2, y2)
                        dtype=self.dtype, device=device)
    pool_h, pool_w = (pool_size, pool_size)
    print("before creating PSROIAlign Layer")
    ps_roi_align = ops.PSRoIAlign((pool_h, pool_w), spatial_scale=1,
                                  sampling_ratio=2)
    print("after creating PSROIAlign Layer")
    print("x_is_cuda: {}".format(x.is_cuda))
    print("rois_is_cuda: {}".format(rois.is_cuda))
    y = ps_roi_align(x, rois)
    print("after feedforward of data in PSROIAlign")
    gt_y = self.slow_ps_roi_align(x, rois, pool_h, pool_w, device,
                                  spatial_scale=1, sampling_ratio=2,
                                  dtype=self.dtype)
    assert torch.allclose(gt_y.cuda(), y), 'PSRoIAlign layer incorrect'
    # same check on a non-contiguous (transposed) input
    y = ps_roi_align(x.permute(0, 1, 3, 2), rois)
    gt_y = self.slow_ps_roi_align(x.permute(0, 1, 3, 2), rois, pool_h, pool_w,
                                  device, spatial_scale=1, sampling_ratio=2,
                                  dtype=self.dtype)
    assert torch.allclose(gt_y.cuda(), y), 'PSRoIAlign layer incorrect'
```
The selection between the CUDA and CPU versions of PSROIAlign_forward and PSROIAlign_backward happens in csrc/PSROIAlign.h, which I modified as shown in the snippet below.
As the console output indicates, PSROIAlign_forward_cuda is selected in this header file. Weirdly, PSROIAlign_forward_cpu is executed instead, which leads to the segfault: the tensors live on the GPU, but they are accessed from within the CPU implementation.
```cpp
#pragma once

#include "cpu/vision_cpu.h"

#ifdef WITH_CUDA
#include "cuda/vision_cuda.h"
#endif

#include <iostream>

std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward(
    const at::Tensor& input,
    const at::Tensor& rois,
    const float spatial_scale,
    const int pooled_height,
    const int pooled_width,
    const int sampling_ratio) {
  if (input.type().is_cuda()) {
    std::cout << "Tensor is CUDA in PSROIAlign_forward" << std::endl;
#ifdef WITH_CUDA
    std::cout << "Using CUDA version of PSROIAlign_forward" << std::endl;
    return PSROIAlign_forward_cuda(
        input,
        rois,
        spatial_scale,
        pooled_height,
        pooled_width,
        sampling_ratio);
#else
    AT_ERROR("Not compiled with GPU support");
#endif
  }
  std::cout << "Using CPU version of PSROIAlign_forward" << std::endl;
  return PSROIAlign_forward_cpu(
      input,
      rois,
      spatial_scale,
      pooled_height,
      pooled_width,
      sampling_ratio);
}

at::Tensor PSROIAlign_backward(
    const at::Tensor& grad,
    const at::Tensor& rois,
    const at::Tensor& mapping_channel,
    const float spatial_scale,
    const int pooled_height,
    const int pooled_width,
    const int sampling_ratio,
    const int batch_size,
    const int channels,
    const int height,
    const int width) {
  if (grad.type().is_cuda()) {
    std::cout << "Tensor is CUDA in PSROIAlign_backward" << std::endl;
#ifdef WITH_CUDA
    std::cout << "Using CUDA version of PSROIAlign_backward" << std::endl;
    return PSROIAlign_backward_cuda(
        grad,
        rois,
        mapping_channel,
        spatial_scale,
        pooled_height,
        pooled_width,
        sampling_ratio,
        batch_size,
        channels,
        height,
        width);
#else
    AT_ERROR("Not compiled with GPU support");
#endif
  }
  std::cout << "Using CPU version of PSROIAlign_backward" << std::endl;
  return PSROIAlign_backward_cpu(
      grad,
      rois,
      mapping_channel,
      spatial_scale,
      pooled_height,
      pooled_width,
      sampling_ratio,
      batch_size,
      channels,
      height,
      width);
}
```
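While debugging, an explicit device check at the top of the CPU kernel at least turns the silent segfault into a readable error (a minimal sketch; the function name and file layout are from my branch):

```cpp
// Sketch: guard at the top of PSROIAlign_forward_cpu in
// csrc/cpu/PSROIAlign_cpu.cpp so a CUDA tensor fails loudly
// instead of segfaulting when the wrong implementation runs.
AT_ASSERTM(input.device().is_cpu(), "input must be a CPU tensor");
AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor");
```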
The same scheme works fine for the PSROIPool layer, which is why I am a bit confused.
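One check that might narrow this down (just a sketch of an idea, not something from the codebase): printing the addresses the header resolves for the two implementations right before dispatch, to rule out the CUDA symbol somehow aliasing the CPU one at link time.

```cpp
// Temporary debug aid in PSROIAlign_forward (sketch): if both lines
// print the same address, the linker resolved PSROIAlign_forward_cuda
// to the CPU definition, which would explain the output above.
std::cout << "cpu  impl: " << (void*)&PSROIAlign_forward_cpu << std::endl;
#ifdef WITH_CUDA
std::cout << "cuda impl: " << (void*)&PSROIAlign_forward_cuda << std::endl;
#endif
```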
Help is greatly appreciated so I can finalize the PR.