diff --git a/test/test_ops.py b/test/test_ops.py index 7db8c6981d0..3059f282b56 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -20,11 +20,10 @@ def slow_roi_pooling(self, x, rois, pool_h, pool_w, spatial_scale=1, c = x.size(1) y = torch.zeros(rois.size(0), c, pool_h, pool_w, dtype=dtype, device=device) - rois = torch.round(rois * spatial_scale) - - for n in range(0, y.size(0)): + for n in range(0, x.size(0)): for r, roi in enumerate(rois): if roi[0] == n: + roi[1:] = torch.round(roi[1:] * spatial_scale) start_h, end_h = int(roi[2].item()), int(roi[4].item()) + 1 start_w, end_w = int(roi[1].item()), int(roi[3].item()) + 1 roi_x = x[roi[0].long(), :, start_h:end_h, start_w:end_w] @@ -58,6 +57,12 @@ def test_roi_pool_basic_cpu(self): gt_y = self.slow_roi_pooling(x.permute(0, 1, 3, 2), rois, pool_h, pool_w, device=device, dtype=self.dtype) assert torch.allclose(gt_y, y), 'RoIPool layer incorrect on CPU' + # spatial-scale != 1 + y = ops.RoIPool((pool_h, pool_w), 2)(x.permute(0, 1, 3, 2), rois) + gt_y = self.slow_roi_pooling(x.permute(0, 1, 3, 2), rois, pool_h, pool_w, + spatial_scale=2, device=device, dtype=self.dtype) + assert torch.allclose(gt_y, y), 'RoIPool layer incorrect on CPU' + def test_roi_pool_cpu(self): device = torch.device('cpu') x = torch.rand(2, 1, 10, 10, dtype=self.dtype, device=device) @@ -487,6 +492,661 @@ def script_func(input, rois): assert gradcheck(lambda x: script_func(x, rois), (x,)), 'gradcheck failed for scripted roi_align on CUDA' +def bilinear_interpolate(data, height, width, y, x): + if y < -1.0 or y > height or x < -1.0 or x > width: + return 0. + + if y <= 0: + y = 0. + if x <= 0: + x = 0. + + y_low, x_low = int(y), int(x) + y_high, x_high = 0, 0 + + if y_low >= height - 1: + y_high = y_low = height - 1 + y = float(y_low) + else: + y_high = y_low + 1 + + if x_low >= width - 1: + x_high = x_low = width - 1 + x = float(x_low) + else: + x_high = x_low + 1 + + ly = y - y_low + lx = x - x_low + hy, hx = 1. - ly, 1. 
- lx + + v1 = data[y_low * width + x_low] + v2 = data[y_low * width + x_high] + v3 = data[y_high * width + x_low] + v4 = data[y_high * width + x_high] + w1, w2, w3, w4 = hy * hx, hy * lx, ly * hx, ly * lx + + return w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4 + + +class PSRoIAlignTester(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.dtype = torch.float64 + + def slow_ps_roi_align(self, in_data, rois, pool_h, pool_w, device, spatial_scale=1, + sampling_ratio=-1, dtype=torch.float64): + if device is None: + device = torch.device("cpu") + num_input_channels = in_data.size(1) + assert num_input_channels % (pool_h * pool_w) == 0, "input channels must be divisible by ph * pw" + num_output_channels = int(num_input_channels / (pool_h * pool_w)) + out_data = torch.zeros(rois.size(0), num_output_channels, pool_h, pool_w, dtype=dtype, device=device) + + for n in range(0, in_data.size(0)): + for r, roi in enumerate(rois): + if roi[0] != n: + continue + roi[1:] = (roi[1:] * spatial_scale) - 0.5 + c_in = 0 + roi_height = float(roi[4].item() - roi[2].item()) + roi_width = float(roi[3].item() - roi[1].item()) + bin_h, bin_w = roi_height / float(pool_h), roi_width / float(pool_w) + for c_out in range(0, num_output_channels): + for j in range(0, pool_h): + start_h = float(j) * bin_h + roi[2].item() + + for i in range(0, pool_w): + start_w = float(i) * bin_w + roi[1].item() + + roi_bin_grid_h = sampling_ratio if sampling_ratio > 0 else int(np.ceil(roi_height / pool_h)) + roi_bin_grid_w = sampling_ratio if sampling_ratio > 0 else int(np.ceil(roi_width / pool_w)) + + val = 0. + for iy in range(0, roi_bin_grid_h): + y = start_h + (iy + 0.5) * bin_h / float(roi_bin_grid_h) + for ix in range(0, roi_bin_grid_w): + x = start_w + (ix + 0.5) * bin_w / float(roi_bin_grid_w) + val += bilinear_interpolate( + in_data[n, c_in, :, :].flatten(), + in_data.size(-2), + in_data.size(-1), + y, x + ) + count = roi_bin_grid_h * roi_bin_grid_w + out_data[r, c_out, j, i] = val / count + c_in += 1 + return out_data + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable") + def test_ps_roi_align_basic_cuda(self): + device = torch.device('cuda') + pool_size = 3 + x = torch.rand(1, 2 * (pool_size ** 2), 7, 7, dtype=self.dtype, device=device) + rois = torch.tensor([[0, 0, 0, 5, 5]], # format is (xyxy) + dtype=self.dtype, device=device) + + pool_h, pool_w = (pool_size, pool_size) + ps_roi_align = ops.PSRoIAlign((pool_h, pool_w), spatial_scale=1, sampling_ratio=2) + y = ps_roi_align(x, rois) + + gt_y = self.slow_ps_roi_align(x, rois, pool_h, pool_w, device, + spatial_scale=1, sampling_ratio=2, + dtype=self.dtype) + assert torch.allclose(gt_y.cuda(), y), 'PSRoIAlign layer incorrect' + + y = ps_roi_align(x.permute(0, 1, 3, 2), rois) + gt_y = self.slow_ps_roi_align(x.permute(0, 1, 3, 2), rois, pool_h, pool_w, device, + spatial_scale=1, sampling_ratio=-1, + dtype=self.dtype) + assert torch.allclose(gt_y.cuda(), y), 'PSRoIAlign layer incorrect' + + def test_ps_roi_align_basic_cpu(self): + device = torch.device('cpu') + pool_size = 3 + x = torch.rand(1, 2 * (pool_size ** 2), 7, 7, dtype=self.dtype, device=device) + rois = torch.tensor([[0, 0, 0, 5, 5]], # format is (xyxy) + dtype=self.dtype, device=device) + + pool_h, pool_w = (pool_size, pool_size) + ps_roi_align = ops.PSRoIAlign((pool_h, pool_w), spatial_scale=1, sampling_ratio=2) + y = ps_roi_align(x, rois) + + gt_y = self.slow_ps_roi_align(x, rois, pool_h, pool_w, device, + spatial_scale=1, sampling_ratio=2, + dtype=self.dtype) + assert torch.allclose(gt_y, y), 
'PSRoIAlign layer incorrect on CPU' + + y = ps_roi_align(x.permute(0, 1, 3, 2), rois) + gt_y = self.slow_ps_roi_align(x.permute(0, 1, 3, 2), rois, pool_h, pool_w, device, + spatial_scale=1, sampling_ratio=-1, + dtype=self.dtype) + assert torch.allclose(gt_y, y), 'PSRoIAlign layer incorrect on CPU' + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable") + def test_ps_roi_align_cuda(self): + device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + pool_size = 5 + x = torch.rand(2, 2 * (pool_size ** 2), 10, 10, dtype=self.dtype, device=device) + rois = torch.tensor([[0, 0, 0, 9, 9], # format is (xyxy) + [0, 0, 5, 4, 9], + [0, 5, 5, 9, 9], + [1, 0, 0, 9, 9]], + dtype=self.dtype, device=device) + + pool_h, pool_w = (pool_size, pool_size) + ps_roi_align = ops.PSRoIAlign((pool_h, pool_w), spatial_scale=1, sampling_ratio=2) + y = ps_roi_align(x, rois) + + gt_y = self.slow_ps_roi_align(x, rois, pool_h, pool_w, device, + spatial_scale=1, sampling_ratio=2, + dtype=self.dtype) + assert torch.allclose(gt_y.cuda(), y), 'PSRoIAlign layer incorrect' + + y = ps_roi_align(x.permute(0, 1, 3, 2), rois) + gt_y = self.slow_ps_roi_align(x.permute(0, 1, 3, 2), rois, pool_h, pool_w, + device, spatial_scale=1, sampling_ratio=2, + dtype=self.dtype) + assert torch.allclose(gt_y.cuda(), y), 'PSRoIAlign layer incorrect' + + def test_ps_roi_align_cpu(self): + device = torch.device('cpu') + pool_size = 5 + x = torch.rand(2, 2 * (pool_size ** 2), 10, 10, dtype=self.dtype, device=device) + rois = torch.tensor([[0, 0, 0, 9, 9], # format is (xyxy) + [0, 0, 5, 4, 9], + [0, 5, 5, 9, 9], + [1, 0, 0, 9, 9]], + dtype=self.dtype, device=device) + + pool_h, pool_w = (pool_size, pool_size) + ps_roi_align = ops.PSRoIAlign((pool_h, pool_w), spatial_scale=1, sampling_ratio=2) + y = ps_roi_align(x, rois) + + gt_y = self.slow_ps_roi_align(x, rois, pool_h, pool_w, device, + spatial_scale=1, sampling_ratio=2, + dtype=self.dtype) + assert torch.allclose(gt_y, y), 'PSRoIAlign layer incorrect on CPU' + + y = ps_roi_align(x.permute(0, 1, 3, 2), rois) + gt_y = self.slow_ps_roi_align(x.permute(0, 1, 3, 2), rois, pool_h, pool_w, + device, spatial_scale=1, sampling_ratio=2, + dtype=self.dtype) + assert torch.allclose(gt_y, y), 'PSRoIAlign layer incorrect on CPU' + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable") + def test_ps_roi_align_gradient_cuda(self): + device = torch.device('cuda') + pool_size = 3 + layer = ops.PSRoIAlign((pool_size, pool_size), spatial_scale=1, + sampling_ratio=-1).to(dtype=self.dtype, device=device) + x = torch.ones(1, pool_size ** 2, 5, 5, dtype=self.dtype, device=device, requires_grad=True) + rois = torch.tensor([ + [0, 0, 0, 4, 4], + [0, 0, 3, 5, 5], + [0, 1, 0, 2, 4]], + dtype=self.dtype, device=device) + + y = layer(x, rois) + s = y.sum() + s.backward() + gt_grad = torch.tensor([[[[8.125e-01, 6.875e-01, 0.0, 0.0, 0.0, ], + [2.7083333333e-01, 2.2916666667e-01, 0.0, 0.0, 0.0, ], + [1.0416666667e-01, 6.25e-02, 0.0, 0.0, 0.0, ], + [5.2083333333e-01, 3.125e-01, 0.0, 0.0, 0.0, ], + [0.0, 0.0, 0.0, 0.0, 0.0, ]], + [[8.3266726847e-17, 1.125e00, 3.750e-01, 0.0, 0.0, ], + [2.7755575616e-17, 3.750e-01, 1.250e-01, 0.0, 0.0, ], + [0.0, 3.4722222222e-02, 9.7222222222e-02, 3.4722222222e-02, 0.0, ], + [0.0, 1.7361111111e-01, 4.8611111111e-01, 1.7361111111e-01, 0.0, ], + [0.0, 0.0, 0.0, 0.0, 0.0, ]], + [[0.0, 5.000e-01, 4.375e-01, 5.000e-01, 6.25e-02, ], + [0.0, 1.6666666667e-01, 1.4583333333e-01, 1.6666666667e-01, 2.0833333333e-02, ], + [0.0, 0.0, 0.0, 6.25e-02, 
1.0416666667e-01, ], + [0.0, 0.0, 0.0, 3.125e-01, 5.2083333333e-01, ], + [0.0, 0.0, 0.0, 0.0, 0.0, ]], + [[0.0, 0.0, 0.0, 0.0, 0.0, ], + [5.4166666667e-01, 4.5833333333e-01, 0.0, 0.0, 0.0, ], + [5.4166666667e-01, 4.5833333333e-01, 0.0, 0.0, 0.0, ], + [3.125e-01, 1.875e-01, 0.0, 0.0, 0.0, ], + [3.125e-01, 1.875e-01, 0.0, 0.0, 0.0, ]], + [[0.0, 0.0, 0.0, 0.0, 0.0, ], + [5.5511151231e-17, 7.500e-01, 2.500e-01, 0.0, 0.0, ], + [5.5511151231e-17, 7.500e-01, 2.500e-01, 0.0, 0.0, ], + [0.0, 1.0416666667e-01, 2.9166666667e-01, 1.0416666667e-01, 0.0, ], + [0.0, 1.0416666667e-01, 2.9166666667e-01, 1.0416666667e-01, 0.0, ]], + [[0.0, 0.0, 0.0, 0.0, 0.0, ], + [0.0, 3.3333333333e-01, 2.9166666667e-01, 3.3333333333e-01, 4.1666666667e-02, ], + [0.0, 3.3333333333e-01, 2.9166666667e-01, 3.3333333333e-01, 4.1666666667e-02, ], + [0.0, 0.0, 0.0, 1.875e-01, 3.125e-01, ], + [0.0, 0.0, 0.0, 1.875e-01, 3.125e-01, ]], + [[0.0, 0.0, 0.0, 0.0, 0.0, ], + [0.0, 0.0, 0.0, 0.0, 0.0, ], + [2.7083333333e-01, 2.2916666667e-01, 0.0, 0.0, 0.0, ], + [7.2222222222e-01, 6.1111111111e-01, 0.0, 0.0, 0.0, ], + [7.1527777778e-01, 4.5138888889e-01, 0.0, 0.0, 0.0, ]], + [[0.0, 0.0, 0.0, 0.0, 0.0, ], + [0.0, 0.0, 0.0, 0.0, 0.0, ], + [2.7755575616e-17, 3.750e-01, 1.250e-01, 0.0, 0.0, ], + [7.4014868308e-17, 1.000e00, 3.3333333333e-01, 0.0, 0.0, ], + [9.2518585385e-18, 3.3333333333e-01, 6.25e-01, 2.0833333333e-01, 0.0, ]], + [[0.0, 0.0, 0.0, 0.0, 0.0, ], + [0.0, 0.0, 0.0, 0.0, 0.0, ], + [0.0, 1.6666666667e-01, 1.4583333333e-01, 1.6666666667e-01, 2.0833333333e-02, ], + [0.0, 4.4444444444e-01, 3.8888888889e-01, 4.4444444444e-01, 5.5555555556e-02, ], + [0.0, 5.5555555556e-02, 4.8611111111e-02, 4.3055555556e-01, 6.3194444444e-01, ]]]], + device=device, dtype=self.dtype) + assert torch.allclose(x.grad, gt_grad), 'gradient incorrect for PSRoIAlign' + + def test_ps_roi_align_gradient_cpu(self): + device = torch.device('cpu') + pool_size = 3 + layer = ops.PSRoIAlign((pool_size, pool_size), spatial_scale=1, + sampling_ratio=-1).to(dtype=self.dtype, device=device) + x = torch.ones(1, pool_size ** 2, 5, 5, dtype=self.dtype, device=device, requires_grad=True) + rois = torch.tensor([ + [0, 0, 0, 4, 4], + [0, 0, 3, 5, 5], + [0, 1, 0, 2, 4]], + dtype=self.dtype, device=device) + + y = layer(x, rois) + s = y.sum() + s.backward() + gt_grad = torch.tensor([[[[8.125e-01, 6.875e-01, 0.0, 0.0, 0.0, ], + [2.7083333333e-01, 2.2916666667e-01, 0.0, 0.0, 0.0, ], + [1.0416666667e-01, 6.25e-02, 0.0, 0.0, 0.0, ], + [5.2083333333e-01, 3.125e-01, 0.0, 0.0, 0.0, ], + [0.0, 0.0, 0.0, 0.0, 0.0, ]], + [[8.3266726847e-17, 1.125e00, 3.750e-01, 0.0, 0.0, ], + [2.7755575616e-17, 3.750e-01, 1.250e-01, 0.0, 0.0, ], + [0.0, 3.4722222222e-02, 9.7222222222e-02, 3.4722222222e-02, 0.0, ], + [0.0, 1.7361111111e-01, 4.8611111111e-01, 1.7361111111e-01, 0.0, ], + [0.0, 0.0, 0.0, 0.0, 0.0, ]], + [[0.0, 5.000e-01, 4.375e-01, 5.000e-01, 6.25e-02, ], + [0.0, 1.6666666667e-01, 1.4583333333e-01, 1.6666666667e-01, 2.0833333333e-02, ], + [0.0, 0.0, 0.0, 6.25e-02, 1.0416666667e-01, ], + [0.0, 0.0, 0.0, 3.125e-01, 5.2083333333e-01, ], + [0.0, 0.0, 0.0, 0.0, 0.0, ]], + [[0.0, 0.0, 0.0, 0.0, 0.0, ], + [5.4166666667e-01, 4.5833333333e-01, 0.0, 0.0, 0.0, ], + [5.4166666667e-01, 4.5833333333e-01, 0.0, 0.0, 0.0, ], + [3.125e-01, 1.875e-01, 0.0, 0.0, 0.0, ], + [3.125e-01, 1.875e-01, 0.0, 0.0, 0.0, ]], + [[0.0, 0.0, 0.0, 0.0, 0.0, ], + [5.5511151231e-17, 7.500e-01, 2.500e-01, 0.0, 0.0, ], + [5.5511151231e-17, 7.500e-01, 2.500e-01, 0.0, 0.0, ], + [0.0, 1.0416666667e-01, 2.9166666667e-01, 1.0416666667e-01, 
0.0, ], + [0.0, 1.0416666667e-01, 2.9166666667e-01, 1.0416666667e-01, 0.0, ]], + [[0.0, 0.0, 0.0, 0.0, 0.0, ], + [0.0, 3.3333333333e-01, 2.9166666667e-01, 3.3333333333e-01, 4.1666666667e-02, ], + [0.0, 3.3333333333e-01, 2.9166666667e-01, 3.3333333333e-01, 4.1666666667e-02, ], + [0.0, 0.0, 0.0, 1.875e-01, 3.125e-01, ], + [0.0, 0.0, 0.0, 1.875e-01, 3.125e-01, ]], + [[0.0, 0.0, 0.0, 0.0, 0.0, ], + [0.0, 0.0, 0.0, 0.0, 0.0, ], + [2.7083333333e-01, 2.2916666667e-01, 0.0, 0.0, 0.0, ], + [7.2222222222e-01, 6.1111111111e-01, 0.0, 0.0, 0.0, ], + [7.1527777778e-01, 4.5138888889e-01, 0.0, 0.0, 0.0, ]], + [[0.0, 0.0, 0.0, 0.0, 0.0, ], + [0.0, 0.0, 0.0, 0.0, 0.0, ], + [2.7755575616e-17, 3.750e-01, 1.250e-01, 0.0, 0.0, ], + [7.4014868308e-17, 1.000e00, 3.3333333333e-01, 0.0, 0.0, ], + [9.2518585385e-18, 3.3333333333e-01, 6.25e-01, 2.0833333333e-01, 0.0, ]], + [[0.0, 0.0, 0.0, 0.0, 0.0, ], + [0.0, 0.0, 0.0, 0.0, 0.0, ], + [0.0, 1.6666666667e-01, 1.4583333333e-01, 1.6666666667e-01, 2.0833333333e-02, ], + [0.0, 4.4444444444e-01, 3.8888888889e-01, 4.4444444444e-01, 5.5555555556e-02, ], + [0.0, 5.5555555556e-02, 4.8611111111e-02, 4.3055555556e-01, 6.3194444444e-01, ]]]], + device=device, dtype=self.dtype) + assert torch.allclose(x.grad, gt_grad), 'gradient incorrect for PSRoIAlign on CPU' + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable") + def test_ps_roi_align_gradcheck_cuda(self): + device = torch.device('cuda') + pool_size = 5 + x = torch.rand(1, pool_size ** 2, 10, 10, dtype=self.dtype, device=device, requires_grad=True) + rois = torch.tensor([ + [0, 0, 0, 9, 9], + [0, 0, 5, 5, 9], + [0, 5, 5, 9, 9]], dtype=self.dtype, device=device) + + m = ops.PSRoIAlign((pool_size, pool_size), spatial_scale=1, + sampling_ratio=2).to(dtype=self.dtype, device=device) + + def func(input): + return m(input, rois) + + assert gradcheck(func, (x,)), 'gradcheck failed for PSRoIAlign CUDA' + assert gradcheck(func, (x.permute(0, 1, 3, 2),)), 'gradcheck failed for PSRoIAlign CUDA' + + def test_ps_roi_align_gradcheck_cpu(self): + device = torch.device('cpu') + pool_size = 5 + x = torch.rand(1, pool_size ** 2, 10, 10, dtype=self.dtype, device=device, requires_grad=True) + rois = torch.tensor([ + [0, 0, 0, 9, 9], + [0, 0, 5, 5, 9], + [0, 5, 5, 9, 9]], dtype=self.dtype, device=device) + + m = ops.PSRoIAlign((pool_size, pool_size), spatial_scale=1, + sampling_ratio=2).to(dtype=self.dtype, device=device) + + def func(input): + return m(input, rois) + + assert gradcheck(func, (x,)), 'gradcheck failed for PSRoIAlign on CPU' + assert gradcheck(func, (x.permute(0, 1, 3, 2),)), 'gradcheck failed for PSRoIAlign on CPU' + + +class PSRoIPoolTester(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.dtype = torch.float64 + + def slow_ps_roi_pooling(self, x, rois, pool_h, pool_w, device, spatial_scale=1, + dtype=torch.float64): + if device is None: + device = torch.device("cpu") + num_input_channels = x.size(1) + assert num_input_channels % (pool_h * pool_w) == 0, "input channels must be divisible by ph * pw" + num_output_channels = int(num_input_channels / (pool_h * pool_w)) + y = torch.zeros(rois.size(0), num_output_channels, pool_h, pool_w, dtype=dtype, device=device) + + rois = torch.round(rois * spatial_scale).int() + for n in range(0, x.size(0)): + for r, roi in enumerate(rois): + if roi[0] != n: + continue + c_in = 0 + for c_out in range(0, num_output_channels): + roi_height = max(roi[4].item() - roi[2].item(), 1) + roi_width = max(roi[3].item() - roi[1].item(), 1) + bin_h, bin_w = roi_height / 
float(pool_h), roi_width / float(pool_w) + + for j in range(0, pool_h): + start_h = int(np.floor(j * bin_h)) + roi[2].item() + end_h = int(np.ceil((j + 1) * bin_h)) + roi[2].item() + + # range-check + start_h = min(max(start_h, 0), x.size(2)) + end_h = min(max(end_h, 0), x.size(2)) + + for i in range(0, pool_w): + start_w = int(np.floor(i * bin_w)) + roi[1].item() + end_w = int(np.ceil((i + 1) * bin_w)) + roi[1].item() + + # range-check + start_w = min(max(start_w, 0), x.size(3)) + end_w = min(max(end_w, 0), x.size(3)) + + is_empty = (end_h <= start_h) or (end_w <= start_w) + area = (end_h - start_h) * (end_w - start_w) + + if not is_empty: + t = torch.sum(x[n, c_in, slice(start_h, end_h), slice(start_w, end_w)]) + y[r, c_out, j, i] = t / area + c_in += 1 + return y + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable") + def test_ps_roi_pool_basic_cuda(self): + device = torch.device('cuda') + pool_size = 3 + x = torch.rand(1, pool_size ** 2, 10, 10, dtype=self.dtype, device=device) + rois = torch.tensor([[0, 0, 0, 4, 4]], # format is (xyxy) + dtype=self.dtype, device=device) + + pool_h, pool_w = (pool_size, pool_size) + ps_roi_pool = ops.PSRoIPool((pool_h, pool_w), 1) + y = ps_roi_pool(x, rois) + + gt_y = self.slow_ps_roi_pooling(x, rois, pool_h, pool_w, device, dtype=self.dtype) + assert torch.allclose(gt_y.cuda(), y), 'PSRoIPool layer incorrect' + + y = ps_roi_pool(x.permute(0, 1, 3, 2), rois) + gt_y = self.slow_ps_roi_pooling(x.permute(0, 1, 3, 2), rois, pool_h, pool_w, device, dtype=self.dtype) + assert torch.allclose(gt_y.cuda(), y), 'PSRoIPool layer incorrect' + + def test_ps_roi_pool_basic_cpu(self): + device = torch.device('cpu') + pool_size = 3 + x = torch.rand(1, pool_size ** 2, 10, 10, dtype=self.dtype, device=device) + rois = torch.tensor([[0, 0, 0, 4, 4]], # format is (xyxy) + dtype=self.dtype, device=device) + + pool_h, pool_w = (pool_size, pool_size) + ps_roi_pool = ops.PSRoIPool((pool_h, pool_w), 1) + y = ps_roi_pool(x, rois) + + gt_y = self.slow_ps_roi_pooling(x, rois, pool_h, pool_w, device, dtype=self.dtype) + assert torch.allclose(gt_y, y), 'PSRoIPool layer incorrect on CPU' + + y = ps_roi_pool(x.permute(0, 1, 3, 2), rois) + gt_y = self.slow_ps_roi_pooling(x.permute(0, 1, 3, 2), rois, pool_h, pool_w, device, dtype=self.dtype) + assert torch.allclose(gt_y, y), 'PSRoIPool layer incorrect on CPU' + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable") + def test_ps_roi_pool_cuda(self): + device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + pool_size = 5 + x = torch.rand(2, 2 * (pool_size ** 2), 10, 10, dtype=self.dtype, device=device) + rois = torch.tensor([[0, 0, 0, 9, 9], # format is (xyxy) + [0, 0, 5, 4, 9], + [0, 5, 5, 9, 9], + [1, 0, 0, 9, 9]], + dtype=self.dtype, device=device) + + pool_h, pool_w = (pool_size, pool_size) + ps_roi_pool = ops.PSRoIPool((pool_h, pool_w), 1) + y = ps_roi_pool(x, rois) + + gt_y = self.slow_ps_roi_pooling(x, rois, pool_h, pool_w, device, dtype=self.dtype) + + assert torch.allclose(gt_y.cuda(), y), 'PSRoIPool layer incorrect' + + y = ps_roi_pool(x.permute(0, 1, 3, 2), rois) + gt_y = self.slow_ps_roi_pooling(x.permute(0, 1, 3, 2), rois, pool_h, pool_w, device, dtype=self.dtype) + assert torch.allclose(gt_y.cuda(), y), 'PSRoIPool layer incorrect' + + def test_ps_roi_pool_cpu(self): + device = torch.device('cpu') + pool_size = 5 + x = torch.rand(2, 2 * (pool_size ** 2), 10, 10, dtype=self.dtype, device=device) + rois = torch.tensor([[0, 0, 0, 9, 9], # format is (xyxy) + [0, 0, 
5, 4, 9], + [0, 5, 5, 9, 9], + [1, 0, 0, 9, 9]], + dtype=self.dtype, device=device) + + pool_h, pool_w = (pool_size, pool_size) + ps_roi_pool = ops.PSRoIPool((pool_h, pool_w), 1) + y = ps_roi_pool(x, rois) + + gt_y = self.slow_ps_roi_pooling(x, rois, pool_h, pool_w, device, dtype=self.dtype) + assert torch.allclose(gt_y, y), 'PSRoIPool layer incorrect on CPU' + + y = ps_roi_pool(x.permute(0, 1, 3, 2), rois) + gt_y = self.slow_ps_roi_pooling(x.permute(0, 1, 3, 2), rois, pool_h, pool_w, device, dtype=self.dtype) + assert torch.allclose(gt_y, y), 'PSRoIPool layer incorrect on CPU' + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable") + def test_ps_roi_pool_gradient_cuda(self): + device = torch.device('cuda') + pool_size = 3 + layer = ops.PSRoIPool((pool_size, pool_size), 1).to(dtype=self.dtype, device=device) + x = torch.ones(1, pool_size ** 2, 5, 5, dtype=self.dtype, device=device, requires_grad=True) + rois = torch.tensor([ + [0, 0, 0, 4, 4], + [0, 0, 3, 5, 5], + [0, 1, 0, 2, 4]], + dtype=self.dtype, device=device) + + y = layer(x, rois) + s = y.sum() + s.backward() + gt_grad = torch.tensor([[[[0.2500, 0.7500, 0.0000, 0.0000, 0.0000], + [0.2500, 0.7500, 0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [0.5000, 0.5000, 0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]], + + [[0.0000, 0.7500, 0.2500, 0.0000, 0.0000], + [0.0000, 0.7500, 0.2500, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [0.0000, 1. / 3, 1. / 3, 1. / 3, 0.0000], + [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]], + + [[0.0000, 0.5000, 0.2500, 0.2500, 0.0000], + [0.0000, 0.5000, 0.2500, 0.2500, 0.0000], + [0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000, 0.5000, 0.5000], + [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]], + + [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [0.2500, 0.7500, 0.0000, 0.0000, 0.0000], + [0.2500, 0.7500, 0.0000, 0.0000, 0.0000], + [0.2500, 0.2500, 0.0000, 0.0000, 0.0000], + [0.2500, 0.2500, 0.0000, 0.0000, 0.0000]], + + [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [0.0000, 0.7500, 0.2500, 0.0000, 0.0000], + [0.0000, 0.7500, 0.2500, 0.0000, 0.0000], + [0.0000, 1. / 6, 1. / 6, 1. / 6, 0.0000], + [0.0000, 1. / 6, 1. / 6, 1. / 6, 0.0000]], + + [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [0.0000, 0.5000, 0.2500, 0.2500, 0.0000], + [0.0000, 0.5000, 0.2500, 0.2500, 0.0000], + [0.0000, 0.0000, 0.0000, 0.2500, 0.2500], + [0.0000, 0.0000, 0.0000, 0.2500, 0.2500]], + + [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [0.2500, 0.7500, 0.0000, 0.0000, 0.0000], + [0.2500, 0.7500, 0.0000, 0.0000, 0.0000], + [0.5000, 0.5000, 0.0000, 0.0000, 0.0000]], + + [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [0.0000, 0.7500, 0.2500, 0.0000, 0.0000], + [0.0000, 0.7500, 0.2500, 0.0000, 0.0000], + [0.0000, 1. / 3, 1. / 3, 1. 
/ 3, 0.0000]], + + [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [0.0000, 0.5000, 0.2500, 0.2500, 0.0000], + [0.0000, 0.5000, 0.2500, 0.2500, 0.0000], + [0.0000, 0.0000, 0.0000, 0.5000, 0.5000]]]], + device=device, dtype=self.dtype) + assert torch.allclose(x.grad, gt_grad), 'gradient incorrect for PSRoIPool' + + def test_ps_roi_pool_gradient_cpu(self): + device = torch.device('cpu') + pool_size = 3 + layer = ops.PSRoIPool((pool_size, pool_size), 1).to(dtype=self.dtype, device=device) + x = torch.ones(1, pool_size ** 2, 5, 5, dtype=self.dtype, device=device, requires_grad=True) + rois = torch.tensor([ + [0, 0, 0, 4, 4], + [0, 0, 3, 5, 5], + [0, 1, 0, 2, 4]], + dtype=self.dtype, device=device) + + y = layer(x, rois) + s = y.sum() + s.backward() + gt_grad = torch.tensor([[[[0.2500, 0.7500, 0.0000, 0.0000, 0.0000], + [0.2500, 0.7500, 0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [0.5000, 0.5000, 0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]], + + [[0.0000, 0.7500, 0.2500, 0.0000, 0.0000], + [0.0000, 0.7500, 0.2500, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [0.0000, 1. / 3, 1. / 3, 1. / 3, 0.0000], + [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]], + + [[0.0000, 0.5000, 0.2500, 0.2500, 0.0000], + [0.0000, 0.5000, 0.2500, 0.2500, 0.0000], + [0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000, 0.5000, 0.5000], + [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]], + + [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [0.2500, 0.7500, 0.0000, 0.0000, 0.0000], + [0.2500, 0.7500, 0.0000, 0.0000, 0.0000], + [0.2500, 0.2500, 0.0000, 0.0000, 0.0000], + [0.2500, 0.2500, 0.0000, 0.0000, 0.0000]], + + [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [0.0000, 0.7500, 0.2500, 0.0000, 0.0000], + [0.0000, 0.7500, 0.2500, 0.0000, 0.0000], + [0.0000, 1. / 6, 1. / 6, 1. / 6, 0.0000], + [0.0000, 1. / 6, 1. / 6, 1. / 6, 0.0000]], + + [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [0.0000, 0.5000, 0.2500, 0.2500, 0.0000], + [0.0000, 0.5000, 0.2500, 0.2500, 0.0000], + [0.0000, 0.0000, 0.0000, 0.2500, 0.2500], + [0.0000, 0.0000, 0.0000, 0.2500, 0.2500]], + + [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [0.2500, 0.7500, 0.0000, 0.0000, 0.0000], + [0.2500, 0.7500, 0.0000, 0.0000, 0.0000], + [0.5000, 0.5000, 0.0000, 0.0000, 0.0000]], + + [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [0.0000, 0.7500, 0.2500, 0.0000, 0.0000], + [0.0000, 0.7500, 0.2500, 0.0000, 0.0000], + [0.0000, 1. / 3, 1. / 3, 1. 
/ 3, 0.0000]], + + [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000, 0.0000, 0.0000], + [0.0000, 0.5000, 0.2500, 0.2500, 0.0000], + [0.0000, 0.5000, 0.2500, 0.2500, 0.0000], + [0.0000, 0.0000, 0.0000, 0.5000, 0.5000]]]], + device=device, dtype=self.dtype) + assert torch.allclose(x.grad, gt_grad), 'gradient incorrect for PSRoIPool on CPU' + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable") + def test_ps_roi_pool_gradcheck_cuda(self): + device = torch.device('cuda') + pool_size = 5 + x = torch.rand(1, pool_size ** 2, 10, 10, dtype=self.dtype, device=device, requires_grad=True) + rois = torch.tensor([ + [0, 0, 0, 9, 9], + [0, 0, 5, 5, 9], + [0, 5, 5, 9, 9]], dtype=self.dtype, device=device) + + m = ops.PSRoIPool((pool_size, pool_size), 1).to(dtype=self.dtype, device=device) + + def func(input): + return m(input, rois) + + assert gradcheck(func, (x,)), 'gradcheck failed for PSRoIPool CUDA' + assert gradcheck(func, (x.permute(0, 1, 3, 2),)), 'gradcheck failed for PSRoIPool CUDA' + + def test_ps_roi_pool_gradcheck_cpu(self): + device = torch.device('cpu') + pool_size = 5 + x = torch.rand(1, pool_size ** 2, 10, 10, dtype=self.dtype, device=device, requires_grad=True) + rois = torch.tensor([ + [0, 0, 0, 9, 9], + [0, 0, 5, 5, 9], + [0, 5, 5, 9, 9]], dtype=self.dtype, device=device) + + m = ops.PSRoIPool((pool_size, pool_size), 1).to(dtype=self.dtype, device=device) + + def func(input): + return m(input, rois) + + assert gradcheck(func, (x,)), 'gradcheck failed for PSRoIPool on CPU' + assert gradcheck(func, (x.permute(0, 1, 3, 2),)), 'gradcheck failed for PSRoIPool on CPU' + + class NMSTester(unittest.TestCase): def reference_nms(self, boxes, scores, iou_threshold): """ diff --git a/torchvision/csrc/PSROIAlign.h b/torchvision/csrc/PSROIAlign.h new file mode 100644 index 00000000000..a57be93c540 --- /dev/null +++ b/torchvision/csrc/PSROIAlign.h @@ -0,0 +1,80 @@ +#pragma once + +#include "cpu/vision_cpu.h" + +#ifdef WITH_CUDA +#include "cuda/vision_cuda.h" +#endif + +std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward( + const at::Tensor& input, + const at::Tensor& rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int sampling_ratio) { + if (input.type().is_cuda()) { +#ifdef WITH_CUDA + return PSROIAlign_forward_cuda( + input, + rois, + spatial_scale, + pooled_height, + pooled_width, + sampling_ratio); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + return PSROIAlign_forward_cpu( + input, + rois, + spatial_scale, + pooled_height, + pooled_width, + sampling_ratio); +} + +at::Tensor PSROIAlign_backward( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& mapping_channel, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int sampling_ratio, + const int batch_size, + const int channels, + const int height, + const int width) { + if (grad.type().is_cuda()) { +#ifdef WITH_CUDA + return PSROIAlign_backward_cuda( + grad, + rois, + mapping_channel, + spatial_scale, + pooled_height, + pooled_width, + sampling_ratio, + batch_size, + channels, + height, + width); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + return PSROIAlign_backward_cpu( + grad, + rois, + mapping_channel, + spatial_scale, + pooled_height, + pooled_width, + sampling_ratio, + batch_size, + channels, + height, + width); +}
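Reviewer note: `PSROIAlign_forward` returns both the pooled output and a `channel_mapping` tensor that the backward pass consumes, and the kernels require the input channel count to be a multiple of `pooled_height * pooled_width`. A minimal sketch of the resulting shape contract, assuming a build that registers these ops as `torchvision.ops.PSRoIAlign` (the same API the tests above use):

```python
import torch
import torchvision.ops as ops

pooled_h = pooled_w = 3
# C_in must factor as C_out * pooled_h * pooled_w; here C_out = 4.
x = torch.rand(1, 4 * pooled_h * pooled_w, 7, 7)
rois = torch.tensor([[0, 0.0, 0.0, 5.0, 5.0]])  # (batch_index, x1, y1, x2, y2)
y = ops.PSRoIAlign((pooled_h, pooled_w), spatial_scale=1.0, sampling_ratio=2)(x, rois)
print(y.shape)  # torch.Size([1, 4, 3, 3]) == (num_rois, C_out, pooled_h, pooled_w)
```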
diff --git a/torchvision/csrc/PSROIPool.h b/torchvision/csrc/PSROIPool.h new file mode 100644 index 00000000000..70ac70df75f --- /dev/null +++ b/torchvision/csrc/PSROIPool.h @@ -0,0 +1,66 @@ +#pragma once + +#include "cpu/vision_cpu.h" + +#ifdef WITH_CUDA +#include "cuda/vision_cuda.h" +#endif + +std::tuple<at::Tensor, at::Tensor> PSROIPool_forward( + const at::Tensor& input, + const at::Tensor& rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width) { + if (input.type().is_cuda()) { +#ifdef WITH_CUDA + return PSROIPool_forward_cuda( + input, rois, spatial_scale, pooled_height, pooled_width); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + return PSROIPool_forward_cpu( + input, rois, spatial_scale, pooled_height, pooled_width); +} + +at::Tensor PSROIPool_backward( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& mapping_channel, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int batch_size, + const int channels, + const int height, + const int width) { + if (grad.type().is_cuda()) { +#ifdef WITH_CUDA + return PSROIPool_backward_cuda( + grad, + rois, + mapping_channel, + spatial_scale, + pooled_height, + pooled_width, + batch_size, + channels, + height, + width); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + return PSROIPool_backward_cpu( + grad, + rois, + mapping_channel, + spatial_scale, + pooled_height, + pooled_width, + batch_size, + channels, + height, + width); +}
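The two headers differ mainly in how ROI coordinates reach the feature map: PSRoIPool snaps `rois * spatial_scale` to integer bin edges with `round`, while PSRoIAlign keeps coordinates continuous, applies a half-pixel shift, and bilinearly samples inside each bin. A small sketch of the two mappings, mirroring the arithmetic in the kernels below:

```python
# One ROI edge value, mapped the way each kernel maps it.
spatial_scale = 0.5
x1 = 7.0                                # ROI coordinate in input-image space
pool_start = round(x1 * spatial_scale)  # PSRoIPool: snapped to integer bin -> 4
align_start = x1 * spatial_scale - 0.5  # PSRoIAlign: stays continuous -> 3.0
print(pool_start, align_start)
```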
diff --git a/torchvision/csrc/cpu/PSROIAlign_cpu.cpp b/torchvision/csrc/cpu/PSROIAlign_cpu.cpp new file mode 100644 index 00000000000..274038c3395 --- /dev/null +++ b/torchvision/csrc/cpu/PSROIAlign_cpu.cpp @@ -0,0 +1,414 @@ +#include <ATen/ATen.h> +#include <ATen/TensorUtils.h> +#include <cmath> + +template <typename T> +T bilinear_interpolate( + const T* input, + const int height, + const int width, + T y, + T x, + const int index /* index for debug only*/) { + // deal with cases that inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) { + // empty + return 0; + } + + if (y <= 0) + y = 0; + if (x <= 0) + x = 0; + + int y_low = (int)y; + int x_low = (int)x; + int y_high; + int x_high; + + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = (T)y_low; + } else { + y_high = y_low + 1; + } + + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = (T)x_low; + } else { + x_high = x_low + 1; + } + + T ly = y - y_low; + T lx = x - x_low; + T hy = 1. - ly, hx = 1. - lx; + + // do bilinear interpolation + T v1 = input[y_low * width + x_low]; + T v2 = input[y_low * width + x_high]; + T v3 = input[y_high * width + x_low]; + T v4 = input[y_high * width + x_high]; + T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + + T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + + return val; +} + +template <typename T> +void PSROIAlignForwardCPU( + const int nthreads, + const T* input, + const T spatial_scale, + const int channels, + const int height, + const int width, + const int pooled_height, + const int pooled_width, + const int sampling_ratio, + const T* rois, + const int channels_out, + T* output, + int* channel_mapping) { + int num_rois = nthreads / channels_out / pooled_width / pooled_height; + for (int n = 0; n < num_rois; n++) { + + // [start, end) interval for spatial sampling + const T* offset_rois = rois + n * 5; + int roi_batch_ind = offset_rois[0]; + + // Do not use rounding; this implementation detail is critical + T roi_start_w = offset_rois[1] * spatial_scale - static_cast<T>(0.5); + T roi_start_h = offset_rois[2] * spatial_scale - static_cast<T>(0.5); + T roi_end_w = offset_rois[3] * spatial_scale - static_cast<T>(0.5); + T roi_end_h = offset_rois[4] * spatial_scale - static_cast<T>(0.5); + + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height); + T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width); + + int c_in = 0; + for (int c_out = 0; c_out < channels_out; ++c_out) { + + for (int ph = 0; ph < pooled_height; ++ph) { + for (int pw = 0; pw < pooled_width; ++pw) { + + int index = + ((n * channels_out + c_out) * pooled_height + ph) * pooled_width + pw; + + // Do not use floor/ceil; this implementation detail is critical + T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h; + T wstart = static_cast<T>(pw) * bin_size_w + roi_start_w; + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_height / pooled_height); + int roi_bin_grid_w = + (sampling_ratio > 0) ? 
sampling_ratio : ceil(roi_width / pooled_width); + const T count = roi_bin_grid_h * roi_bin_grid_w; + + const T* offset_input = + input + (roi_batch_ind * channels + c_in) * height * width; + + T out_sum = 0; + for (int iy = 0; iy < roi_bin_grid_h; iy++) { + const T y = hstart + + static_cast<T>(iy + .5f) * bin_size_h / + static_cast<T>(roi_bin_grid_h); + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = wstart + + static_cast<T>(ix + .5f) * bin_size_w / + static_cast<T>(roi_bin_grid_w); + T val = bilinear_interpolate(offset_input, height, width, y, x, index); + out_sum += val; + } + } + + out_sum /= count; + output[index] = out_sum; + channel_mapping[index] = c_in; + c_in++; + } + } + } + } +} + +template <typename T> +void bilinear_interpolate_gradient( + const int height, + const int width, + T y, + T x, + T& w1, + T& w2, + T& w3, + T& w4, + int& x_low, + int& x_high, + int& y_low, + int& y_high, + const int index /* index for debug only*/) { + // deal with cases that inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) { + // empty + w1 = w2 = w3 = w4 = 0.; + x_low = x_high = y_low = y_high = -1; + return; + } + + if (y <= 0) + y = 0; + if (x <= 0) + x = 0; + + y_low = (int)y; + x_low = (int)x; + + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = (T)y_low; + } else { + y_high = y_low + 1; + } + + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = (T)x_low; + } else { + x_high = x_low + 1; + } + + T ly = y - y_low; + T lx = x - x_low; + T hy = 1. - ly, hx = 1. - lx; + + // reference in forward + // T v1 = input[y_low * width + x_low]; + // T v2 = input[y_low * width + x_high]; + // T v3 = input[y_high * width + x_low]; + // T v4 = input[y_high * width + x_high]; + // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + + w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + + return; +} + +template <typename T> +inline void add(T* address, const T& val) { + *address += val; +} + +template <typename T> +void PSROIAlignBackwardCPU( + const int nthreads, + const T* grad_output, + const int* channel_mapping, + const int num_rois, + const T spatial_scale, + const int channels, + const int height, + const int width, + const int pooled_height, + const int pooled_width, + const int sampling_ratio, + const int channels_out, + T* grad_input, + const T* rois) { + for (int index = 0; index < nthreads; index++) { + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int n = index / pooled_width / pooled_height / channels_out; + + const T* offset_rois = rois + n * 5; + int roi_batch_ind = offset_rois[0]; + + // Do not use rounding; this implementation detail is critical + T roi_start_w = offset_rois[1] * spatial_scale - static_cast<T>(0.5); + T roi_start_h = offset_rois[2] * spatial_scale - static_cast<T>(0.5); + T roi_end_w = offset_rois[3] * spatial_scale - static_cast<T>(0.5); + T roi_end_h = offset_rois[4] * spatial_scale - static_cast<T>(0.5); + + // Force too small ROIs to be 1x1 + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + T bin_size_h = roi_height / static_cast<T>(pooled_height); + T bin_size_w = roi_width / static_cast<T>(pooled_width); + + int c_in = channel_mapping[index]; + T* grad_input_offset = + grad_input + (roi_batch_ind * channels + c_in) * height * width; + + // Do not use floor/ceil; this implementation detail is critical + T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h; + T wstart = static_cast<T>(pw) * bin_size_w + roi_start_w; + + const T grad_output_this_bin = 
grad_output[index]; + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_height / pooled_height); // e.g., = 2 + int roi_bin_grid_w = + (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + const T count = roi_bin_grid_h * roi_bin_grid_w; + + for (int iy = 0; iy < roi_bin_grid_h; iy++) { + const T y = hstart + + static_cast<T>(iy + .5f) * bin_size_h / + static_cast<T>(roi_bin_grid_h); + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = wstart + + static_cast<T>(ix + .5f) * bin_size_w / + static_cast<T>(roi_bin_grid_w); + + T w1, w2, w3, w4; + int x_low, x_high, y_low, y_high; + + bilinear_interpolate_gradient( + height, + width, + y, + x, + w1, + w2, + w3, + w4, + x_low, + x_high, + y_low, + y_high, + index); + + T g1 = grad_output_this_bin * w1 / count; + T g2 = grad_output_this_bin * w2 / count; + T g3 = grad_output_this_bin * w3 / count; + T g4 = grad_output_this_bin * w4 / count; + + if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { + add(grad_input_offset + y_low * width + x_low, g1); + add(grad_input_offset + y_low * width + x_high, g2); + add(grad_input_offset + y_high * width + x_low, g3); + add(grad_input_offset + y_high * width + x_high, g4); + } // if + } // ix + } // iy + } +} + +std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward_cpu( + const at::Tensor& input, + const at::Tensor& rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int sampling_ratio) { + // Check if input tensors are CPU tensors + AT_ASSERTM(input.device().is_cpu(), "input must be a CPU tensor"); + AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor"); + + at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; + + at::CheckedFrom c = "PSROIAlign_forward_cpu"; + at::checkAllSameType(c, {input_t, rois_t}); + + int num_rois = rois.size(0); + int channels = input.size(1); + int height = input.size(2); + int width = input.size(3); + + AT_ASSERTM( + channels % (pooled_height * pooled_width) == 0, + "input channels must be a multiple of pooling height * pooling width"); + int channels_out = channels / (pooled_height * pooled_width); + + auto output = at::zeros( + {num_rois, channels_out, pooled_height, pooled_width}, input.options()); + auto channel_mapping = + at::zeros(output.sizes(), input.options().dtype(at::kInt)); + + auto output_size = output.numel(); + if (output_size == 0) { + return std::make_tuple(output, channel_mapping); + } + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.scalar_type(), "PSROIAlign_forward", [&] { + PSROIAlignForwardCPU<scalar_t>( + output_size, + input.contiguous().data<scalar_t>(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + sampling_ratio, + rois.contiguous().data<scalar_t>(), + channels_out, + output.data<scalar_t>(), + channel_mapping.data<int>()); + }); + return std::make_tuple(output, channel_mapping); +} + +at::Tensor PSROIAlign_backward_cpu( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& channel_mapping, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int sampling_ratio, + const int batch_size, + const int channels, + const int height, + const int width) { + // Check if input tensors are CPU tensors + AT_ASSERTM(grad.device().is_cpu(), "grad must be a CPU tensor"); + AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor"); + AT_ASSERTM( + channel_mapping.device().is_cpu(), + "channel_mapping must be a CPU tensor"); + + at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}, + channel_mapping_t{channel_mapping, "channel_mapping", 3}; + + at::CheckedFrom c = "PSROIAlign_backward_cpu"; + at::checkAllSameType(c, {grad_t, rois_t}); + + auto num_rois = rois.size(0); + auto grad_input = + at::zeros({batch_size, channels, height, width}, grad.options()); + + // handle possibly empty gradients + if (grad.numel() == 0) { + return grad_input; + } + + int channels_out = channels / (pooled_height * pooled_width); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad.scalar_type(), "PSROIAlign_backward", [&] { + PSROIAlignBackwardCPU<scalar_t>( + grad.numel(), + grad.contiguous().data<scalar_t>(), + channel_mapping.data<int>(), + num_rois, + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + sampling_ratio, + channels_out, + grad_input.data<scalar_t>(), + rois.contiguous().data<scalar_t>()); + }); + return grad_input; +}
diff --git a/torchvision/csrc/cpu/PSROIPool_cpu.cpp b/torchvision/csrc/cpu/PSROIPool_cpu.cpp new file mode 100644 index 00000000000..92614f11385 --- /dev/null +++ b/torchvision/csrc/cpu/PSROIPool_cpu.cpp @@ -0,0 +1,250 @@ +#include <ATen/ATen.h> +#include <ATen/TensorUtils.h> +#include <algorithm> +#include <cmath> + +template <typename T> +inline void add(T* address, const T& val) { + *address += val; +} + +template <typename T> +void PSROIPoolForward( + const T* input, + const T spatial_scale, + const int channels, + const int height, + const int width, + const int pooled_height, + const int pooled_width, + const T* rois, + const int channels_out, + const int num_rois, + T* output, + int* channel_mapping) { + for (int n = 0; n < num_rois; ++n) { + const T* offset_rois = rois + n * 5; + int roi_batch_ind = offset_rois[0]; + int roi_start_w = round(offset_rois[1] * spatial_scale); + int roi_start_h = round(offset_rois[2] * spatial_scale); + int roi_end_w = round(offset_rois[3] * spatial_scale); + int roi_end_h = round(offset_rois[4] * spatial_scale); + + // Force too small ROIs to be 1x1 + int roi_width = std::max(roi_end_w - roi_start_w, 1); + int roi_height = std::max(roi_end_h - roi_start_h, 1); + T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height); + T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width); + + int c_in = 0; + for (int c_out = 0; c_out < channels_out; ++c_out) { + + for (int ph = 0; ph < pooled_height; ++ph) { + for (int pw = 0; pw < pooled_width; ++pw) { + int hstart = static_cast<int>(floor(static_cast<T>(ph) * bin_size_h)); + int wstart = static_cast<int>(floor(static_cast<T>(pw) * bin_size_w)); + int hend = static_cast<int>(ceil(static_cast<T>(ph + 1) * bin_size_h)); + int wend = static_cast<int>(ceil(static_cast<T>(pw + 1) * bin_size_w)); + + // Add roi offsets and clip to input boundaries + hstart = std::min(std::max(hstart + roi_start_h, 0), height - 1); + hend = std::min(std::max(hend + roi_start_h, 0), height - 1); + wstart = std::min(std::max(wstart + roi_start_w, 0), width - 1); + wend = std::min(std::max(wend + roi_start_w, 0), width - 1); + bool is_empty = (hend <= hstart) || (wend <= wstart); + + const T* offset_input = + input + (roi_batch_ind * channels + c_in) * height * width; + + T out_sum = 0; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + int input_index = h * width + w; + out_sum += offset_input[input_index]; + } + } + + int index = + ((n * channels_out + c_out) * pooled_height + ph) * pooled_width + pw; + T bin_area = (hend - hstart) * (wend - wstart); + output[index] = is_empty ? static_cast<T>(0) : out_sum / bin_area; + channel_mapping[index] = c_in; + c_in++; + } + } + } + } +} + +template <typename T> +void PSROIPoolBackward( + const T* grad_output, + const int* channel_mapping, + const int num_rois, + const T spatial_scale, + const int channels, + const int height, + const int width, + const int pooled_height, + const int pooled_width, + const int channels_out, + T* grad_input, + const T* rois) { + for (int n = 0; n < num_rois; ++n) { + const T* offset_rois = rois + n * 5; + int roi_batch_ind = offset_rois[0]; + int roi_start_w = roundf(offset_rois[1] * spatial_scale); + int roi_start_h = roundf(offset_rois[2] * spatial_scale); + int roi_end_w = roundf(offset_rois[3] * spatial_scale); + int roi_end_h = roundf(offset_rois[4] * spatial_scale); + + // Force too small ROIs to be 1x1 + int roi_width = std::max(roi_end_w - roi_start_w, 1); + int roi_height = std::max(roi_end_h - roi_start_h, 1); + T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height); + T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width); + + for (int ph = 0; ph < pooled_height; ++ph) { + for (int pw = 0; pw < pooled_width; ++pw) { + + int hstart = static_cast<int>(floor(static_cast<T>(ph) * bin_size_h)); + int wstart = static_cast<int>(floor(static_cast<T>(pw) * bin_size_w)); + int hend = static_cast<int>(ceil(static_cast<T>(ph + 1) * bin_size_h)); + int wend = static_cast<int>(ceil(static_cast<T>(pw + 1) * bin_size_w)); + + // Add roi offsets and clip to input boundaries + hstart = std::min(std::max(hstart + roi_start_h, 0), height); + hend = std::min(std::max(hend + roi_start_h, 0), height); + wstart = std::min(std::max(wstart + roi_start_w, 0), width); + wend = std::min(std::max(wend + roi_start_w, 0), width); + bool is_empty = (hend <= hstart) || (wend <= wstart); + + for (int c_out = 0; c_out < channels_out; ++c_out) { + + int index = + ((n * channels_out + c_out) * pooled_height + ph) * pooled_width + pw; + int c_in = channel_mapping[index]; + + T* grad_input_offset = + grad_input + (roi_batch_ind * channels + c_in) * height * width; + T bin_area = (hend - hstart) * (wend - wstart); + T diff_val = is_empty ? 
static_cast<T>(0) : grad_output[index] / bin_area; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + int grad_input_index = h * width + w; + add(grad_input_offset + grad_input_index, diff_val); + } + } + } + } + } + } +} + +std::tuple<at::Tensor, at::Tensor> PSROIPool_forward_cpu( + const at::Tensor& input, + const at::Tensor& rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width) { + // Check if input tensors are CPU tensors + AT_ASSERTM(input.device().is_cpu(), "input must be a CPU tensor"); + AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor"); + + at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; + + at::CheckedFrom c = "PSROIPool_forward_cpu"; + at::checkAllSameType(c, {input_t, rois_t}); + + int num_rois = rois.size(0); + int channels = input.size(1); + int height = input.size(2); + int width = input.size(3); + + AT_ASSERTM( + channels % (pooled_height * pooled_width) == 0, + "input channels must be a multiple of pooling height * pooling width"); + int channels_out = channels / (pooled_height * pooled_width); + + auto output = at::zeros( + {num_rois, channels_out, pooled_height, pooled_width}, input.options()); + auto channel_mapping = + at::zeros(output.sizes(), input.options().dtype(at::kInt)); + + auto output_size = output.numel(); + if (output_size == 0) { + return std::make_tuple(output, channel_mapping); + } + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.scalar_type(), "PSROIPool_forward", [&] { + PSROIPoolForward<scalar_t>( + input.contiguous().data<scalar_t>(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + rois.contiguous().data<scalar_t>(), + channels_out, + num_rois, + output.data<scalar_t>(), + channel_mapping.data<int>()); + }); + return std::make_tuple(output, channel_mapping); + } + + at::Tensor PSROIPool_backward_cpu( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& channel_mapping, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int batch_size, + const int channels, + const int height, + const int width) { + // Check if input tensors are CPU tensors + AT_ASSERTM(grad.device().is_cpu(), "grad must be a CPU tensor"); + AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor"); + AT_ASSERTM( + channel_mapping.device().is_cpu(), + "channel_mapping must be a CPU tensor"); + + at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}, + channel_mapping_t{channel_mapping, "channel_mapping", 3}; + + at::CheckedFrom c = "PSROIPool_backward_cpu"; + at::checkAllSameType(c, {grad_t, rois_t}); + + auto num_rois = rois.size(0); + auto grad_input = + at::zeros({batch_size, channels, height, width}, grad.options()); + + // handle possibly empty gradients + if (grad.numel() == 0) { + return grad_input; + } + + int channels_out = channels / (pooled_height * pooled_width); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad.scalar_type(), "PSROIPool_backward", [&] { + PSROIPoolBackward<scalar_t>( + grad.contiguous().data<scalar_t>(), + channel_mapping.data<int>(), + num_rois, + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + channels_out, + grad_input.data<scalar_t>(), + rois.contiguous().data<scalar_t>()); + }); + return grad_input; +} diff --git a/torchvision/csrc/cpu/vision_cpu.h b/torchvision/csrc/cpu/vision_cpu.h index bd6719fdbeb..d84b172ba49 100644 --- a/torchvision/csrc/cpu/vision_cpu.h +++ b/torchvision/csrc/cpu/vision_cpu.h @@ -40,6 +40,46 @@ at::Tensor ROIAlign_backward_cpu( const int width, const int sampling_ratio); + +std::tuple<at::Tensor, at::Tensor> PSROIPool_forward_cpu( + const at::Tensor& input, + const at::Tensor& rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width); + +at::Tensor PSROIPool_backward_cpu( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& mapping_channel, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int batch_size, + const int channels, + const int height, + const int width); + +std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward_cpu( + const at::Tensor& input, + const at::Tensor& rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int sampling_ratio); + +at::Tensor PSROIAlign_backward_cpu( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& mapping_channel, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int sampling_ratio, + const int batch_size, + const int channels, + const int height, + const int width); + at::Tensor nms_cpu( const at::Tensor& dets, const at::Tensor& scores,
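One detail that is easy to miss: the CPU forward loops build the position-sensitive channel by incrementing `c_in` once per pooled bin, while the CUDA kernel below computes it in closed form as `c_in = (c_out * pooled_height + ph) * pooled_width + pw`; the two are equivalent. A quick illustration of which input channel feeds each bin:

```python
# Input channels consumed by output channel c_out = 1 on a 3x3 pooled grid:
# the bins walk channels 9..17 in row-major order.
pooled_h = pooled_w = 3
c_out = 1
mapping = [[(c_out * pooled_h + ph) * pooled_w + pw for pw in range(pooled_w)]
           for ph in range(pooled_h)]
print(mapping)  # [[9, 10, 11], [12, 13, 14], [15, 16, 17]]
```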
diff --git a/torchvision/csrc/cuda/PSROIAlign_cuda.cu b/torchvision/csrc/cuda/PSROIAlign_cuda.cu new file mode 100644 index 00000000000..3237c7f7689 --- /dev/null +++ b/torchvision/csrc/cuda/PSROIAlign_cuda.cu @@ -0,0 +1,429 @@ +#include <ATen/ATen.h> +#include <ATen/cuda/CUDAContext.h> +#include <c10/cuda/CUDAGuard.h> +#include <THC/THC.h> +#include <THC/THCAtomics.cuh> +#include <THC/THCDeviceUtils.cuh> + +#include "cuda_helpers.h" + +template <typename T> +__device__ T bilinear_interpolate( + const T* input, + const int height, + const int width, + T y, + T x, + const int index /* index for debug only*/) { + // deal with cases that inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) { + // empty + return 0; + } + + if (y <= 0) + y = 0; + if (x <= 0) + x = 0; + + int y_low = (int)y; + int x_low = (int)x; + int y_high; + int x_high; + + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = (T)y_low; + } else { + y_high = y_low + 1; + } + + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = (T)x_low; + } else { + x_high = x_low + 1; + } + + T ly = y - y_low; + T lx = x - x_low; + T hy = 1. - ly, hx = 1. 
- lx; + + // do bilinear interpolation + T v1 = input[y_low * width + x_low]; + T v2 = input[y_low * width + x_high]; + T v3 = input[y_high * width + x_low]; + T v4 = input[y_high * width + x_high]; + T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + + T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + + return val; +} + +template <typename T> +__global__ void PSROIAlignForwardCUDA( + const int nthreads, + const T* input, + const T spatial_scale, + const int channels, + const int height, + const int width, + const int pooled_height, + const int pooled_width, + const int sampling_ratio, + const T* rois, + const int channels_out, + T* output, + int* channel_mapping) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + // (n, c_out, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c_out = (index / pooled_width / pooled_height) % channels_out; + int n = index / pooled_width / pooled_height / channels_out; + + // (n, c_in, ph, pw) is the associated element in the input + int c_in = (c_out * pooled_height + ph) * pooled_width + pw; + + // [start, end) interval for spatial sampling + const T* offset_rois = rois + n * 5; + int roi_batch_ind = offset_rois[0]; + + // Do not use rounding; this implementation detail is critical + T roi_start_w = offset_rois[1] * spatial_scale - static_cast<T>(0.5); + T roi_start_h = offset_rois[2] * spatial_scale - static_cast<T>(0.5); + T roi_end_w = offset_rois[3] * spatial_scale - static_cast<T>(0.5); + T roi_end_h = offset_rois[4] * spatial_scale - static_cast<T>(0.5); + + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height); + T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width); + + // Do not use floor/ceil; this implementation detail is critical + T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h; + T wstart = static_cast<T>(pw) * bin_size_w + roi_start_w; + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_height / pooled_height); + int roi_bin_grid_w = + (sampling_ratio > 0) ? 
sampling_ratio : ceil(roi_width / pooled_width); + const T count = roi_bin_grid_h * roi_bin_grid_w; + + const T* offset_input = + input + (roi_batch_ind * channels + c_in) * height * width; + T out_sum = 0; + for (int iy = 0; iy < roi_bin_grid_h; iy++) { + const T y = hstart + + static_cast<T>(iy + .5f) * bin_size_h / + static_cast<T>(roi_bin_grid_h); + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = wstart + + static_cast<T>(ix + .5f) * bin_size_w / + static_cast<T>(roi_bin_grid_w); + T val = bilinear_interpolate(offset_input, height, width, y, x, index); + out_sum += val; + } + } + + out_sum /= count; + output[index] = out_sum; + channel_mapping[index] = c_in; + } +} + +template <typename T> +__device__ void bilinear_interpolate_gradient( + const int height, + const int width, + T y, + T x, + T& w1, + T& w2, + T& w3, + T& w4, + int& x_low, + int& x_high, + int& y_low, + int& y_high, + const int index /* index for debug only*/) { + // deal with cases that inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) { + // empty + w1 = w2 = w3 = w4 = 0.; + x_low = x_high = y_low = y_high = -1; + return; + } + + if (y <= 0) + y = 0; + if (x <= 0) + x = 0; + + y_low = (int)y; + x_low = (int)x; + + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = (T)y_low; + } else { + y_high = y_low + 1; + } + + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = (T)x_low; + } else { + x_high = x_low + 1; + } + + T ly = y - y_low; + T lx = x - x_low; + T hy = 1. - ly, hx = 1. - lx; + + // reference in forward + // T v1 = input[y_low * width + x_low]; + // T v2 = input[y_low * width + x_high]; + // T v3 = input[y_high * width + x_low]; + // T v4 = input[y_high * width + x_high]; + // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + + w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + + return; +} + +template <typename T> +__global__ void PSROIAlignBackwardCUDA( + const int nthreads, + const T* grad_output, + const int* channel_mapping, + const int num_rois, + const T spatial_scale, + const int channels, + const int height, + const int width, + const int pooled_height, + const int pooled_width, + const int sampling_ratio, + const int channels_out, + T* grad_input, + const T* rois) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + // (n, *, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int n = index / pooled_width / pooled_height / channels_out; + + const T* offset_rois = rois + n * 5; + int roi_batch_ind = offset_rois[0]; + + // Do not use rounding; this implementation detail is critical + T roi_start_w = offset_rois[1] * spatial_scale - static_cast<T>(0.5); + T roi_start_h = offset_rois[2] * spatial_scale - static_cast<T>(0.5); + T roi_end_w = offset_rois[3] * spatial_scale - static_cast<T>(0.5); + T roi_end_h = offset_rois[4] * spatial_scale - static_cast<T>(0.5); + + // Force too small ROIs to be 1x1 + T roi_width = roi_end_w - roi_start_w; + T roi_height = roi_end_h - roi_start_h; + T bin_size_h = roi_height / static_cast<T>(pooled_height); + T bin_size_w = roi_width / static_cast<T>(pooled_width); + + int c_in = channel_mapping[index]; + T* grad_input_offset = + grad_input + (roi_batch_ind * channels + c_in) * height * width; + + // Do not use floor/ceil; this implementation detail is critical + T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h; + T wstart = static_cast<T>(pw) * bin_size_w + roi_start_w; + + const T grad_output_this_bin = grad_output[index]; + + // We use roi_bin_grid to 
+
+template <typename T>
+__global__ void PSROIAlignBackwardCUDA(
+    const int nthreads,
+    const T* grad_output,
+    const int* channel_mapping,
+    const int num_rois,
+    const T spatial_scale,
+    const int channels,
+    const int height,
+    const int width,
+    const int pooled_height,
+    const int pooled_width,
+    const int sampling_ratio,
+    const int channels_out,
+    T* grad_input,
+    const T* rois) {
+  CUDA_1D_KERNEL_LOOP(index, nthreads) {
+    // (n, *, ph, pw) is an element in the pooled output
+    int pw = index % pooled_width;
+    int ph = (index / pooled_width) % pooled_height;
+    int n = index / pooled_width / pooled_height / channels_out;
+
+    const T* offset_rois = rois + n * 5;
+    int roi_batch_ind = offset_rois[0];
+
+    // Do not use rounding; this implementation detail is critical
+    T roi_start_w = offset_rois[1] * spatial_scale - static_cast<T>(0.5);
+    T roi_start_h = offset_rois[2] * spatial_scale - static_cast<T>(0.5);
+    T roi_end_w = offset_rois[3] * spatial_scale - static_cast<T>(0.5);
+    T roi_end_h = offset_rois[4] * spatial_scale - static_cast<T>(0.5);
+
+    T roi_width = roi_end_w - roi_start_w;
+    T roi_height = roi_end_h - roi_start_h;
+    T bin_size_h = roi_height / static_cast<T>(pooled_height);
+    T bin_size_w = roi_width / static_cast<T>(pooled_width);
+
+    int c_in = channel_mapping[index];
+    T* grad_input_offset =
+        grad_input + (roi_batch_ind * channels + c_in) * height * width;
+
+    // Do not use floor/ceil; this implementation detail is critical
+    T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h;
+    T wstart = static_cast<T>(pw) * bin_size_w + roi_start_w;
+
+    const T grad_output_this_bin = grad_output[index];
+
+    // We use roi_bin_grid to sample the grid and mimic integral
+    int roi_bin_grid_h = (sampling_ratio > 0)
+        ? sampling_ratio
+        : ceil(roi_height / pooled_height); // e.g., = 2
+    int roi_bin_grid_w =
+        (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
+    const T count = roi_bin_grid_h * roi_bin_grid_w;
+
+    for (int iy = 0; iy < roi_bin_grid_h; iy++) {
+      const T y = hstart +
+          static_cast<T>(iy + .5f) * bin_size_h /
+              static_cast<T>(roi_bin_grid_h);
+      for (int ix = 0; ix < roi_bin_grid_w; ix++) {
+        const T x = wstart +
+            static_cast<T>(ix + .5f) * bin_size_w /
+                static_cast<T>(roi_bin_grid_w);
+
+        T w1, w2, w3, w4;
+        int x_low, x_high, y_low, y_high;
+
+        bilinear_interpolate_gradient(
+            height,
+            width,
+            y,
+            x,
+            w1,
+            w2,
+            w3,
+            w4,
+            x_low,
+            x_high,
+            y_low,
+            y_high,
+            index);
+
+        T g1 = grad_output_this_bin * w1 / count;
+        T g2 = grad_output_this_bin * w2 / count;
+        T g3 = grad_output_this_bin * w3 / count;
+        T g4 = grad_output_this_bin * w4 / count;
+
+        if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
+          atomicAdd(grad_input_offset + y_low * width + x_low, g1);
+          atomicAdd(grad_input_offset + y_low * width + x_high, g2);
+          atomicAdd(grad_input_offset + y_high * width + x_low, g3);
+          atomicAdd(grad_input_offset + y_high * width + x_high, g4);
+        } // if
+      } // ix
+    } // iy
+  }
+}
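A serial Python analogue (illustrative only) of the atomicAdd scatter above; grad_input stands for the H*W plane selected by (roi_batch_ind, c_in):

def scatter_bilinear_grad(grad_input, width, go_bin, count, weights, corners):
    # weights/corners as produced by bilinear_interpolate_gradient
    (w1, w2, w3, w4) = weights
    (y_low, x_low, y_high, x_high) = corners
    if min(y_low, x_low, y_high, x_high) < 0:
        return  # sample fell outside the feature map; contributes nothing
    for w, (yy, xx) in zip((w1, w2, w3, w4),
                           ((y_low, x_low), (y_low, x_high),
                            (y_high, x_low), (y_high, x_high))):
        grad_input[yy * width + xx] += go_bin * w / count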
+
+std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward_cuda(
+    const at::Tensor& input,
+    const at::Tensor& rois,
+    const float spatial_scale,
+    const int pooled_height,
+    const int pooled_width,
+    const int sampling_ratio) {
+  // Check if input tensors are CUDA tensors
+  AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
+  AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor");
+
+  at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
+
+  at::CheckedFrom c = "PSROIAlign_forward_cuda";
+  at::checkAllSameGPU(c, {input_t, rois_t});
+  at::checkAllSameType(c, {input_t, rois_t});
+
+  at::cuda::CUDAGuard device_guard(input.device());
+
+  auto num_rois = rois.size(0);
+  auto channels = input.size(1);
+  auto height = input.size(2);
+  auto width = input.size(3);
+
+  AT_ASSERTM(
+      channels % (pooled_height * pooled_width) == 0,
+      "input channels must be a multiple of pooling height * pooling width");
+  int channels_out = channels / (pooled_height * pooled_width);
+
+  auto output = at::zeros(
+      {num_rois, channels_out, pooled_height, pooled_width}, input.options());
+  auto channel_mapping =
+      at::zeros(output.sizes(), input.options().dtype(at::kInt));
+
+  auto output_size = output.numel();
+  if (output_size == 0) {
+    AT_CUDA_CHECK(cudaGetLastError());
+    return std::make_tuple(output, channel_mapping);
+  }
+
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  dim3 grid(std::min(at::cuda::ATenCeilDiv(output_size, 512L), 4096L));
+  dim3 block(512);
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      input.scalar_type(), "PSROIAlign_forward", [&] {
+        PSROIAlignForwardCUDA<scalar_t><<<grid, block, 0, stream>>>(
+            output_size,
+            input.contiguous().data<scalar_t>(),
+            spatial_scale,
+            channels,
+            height,
+            width,
+            pooled_height,
+            pooled_width,
+            sampling_ratio,
+            rois.contiguous().data<scalar_t>(),
+            channels_out,
+            output.data<scalar_t>(),
+            channel_mapping.data<int>());
+      });
+  AT_CUDA_CHECK(cudaGetLastError());
+  cudaDeviceSynchronize();
+  return std::make_tuple(output, channel_mapping);
+}
+
+at::Tensor PSROIAlign_backward_cuda(
+    const at::Tensor& grad,
+    const at::Tensor& rois,
+    const at::Tensor& channel_mapping,
+    const float spatial_scale,
+    const int pooled_height,
+    const int pooled_width,
+    const int sampling_ratio,
+    const int batch_size,
+    const int channels,
+    const int height,
+    const int width) {
+  // Check if input tensors are CUDA tensors
+  AT_ASSERTM(grad.type().is_cuda(), "grad must be a CUDA tensor");
+  AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor");
+  AT_ASSERTM(
+      channel_mapping.type().is_cuda(),
+      "channel_mapping must be a CUDA tensor");
+
+  at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2},
+      channel_mapping_t{channel_mapping, "channel_mapping", 3};
+
+  at::CheckedFrom c = "PSROIAlign_backward_cuda";
+  at::checkAllSameGPU(c, {grad_t, rois_t, channel_mapping_t});
+  at::checkAllSameType(c, {grad_t, rois_t});
+
+  at::cuda::CUDAGuard device_guard(grad.device());
+
+  auto num_rois = rois.size(0);
+  auto grad_input =
+      at::zeros({batch_size, channels, height, width}, grad.options());
+
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  dim3 grid(std::min(at::cuda::ATenCeilDiv(grad.numel(), 512L), 4096L));
+  dim3 block(512);
+
+  // handle possibly empty gradients
+  if (grad.numel() == 0) {
+    AT_CUDA_CHECK(cudaGetLastError());
+    return grad_input;
+  }
+
+  int channels_out = channels / (pooled_height * pooled_width);
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      grad.scalar_type(), "PSROIAlign_backward", [&] {
+        PSROIAlignBackwardCUDA<scalar_t><<<grid, block, 0, stream>>>(
+            grad.numel(),
+            grad.contiguous().data<scalar_t>(),
+            channel_mapping.data<int>(),
+            num_rois,
+            spatial_scale,
+            channels,
+            height,
+            width,
+            pooled_height,
+            pooled_width,
+            sampling_ratio,
+            channels_out,
+            grad_input.data<scalar_t>(),
+            rois.contiguous().data<scalar_t>());
+      });
+  AT_CUDA_CHECK(cudaGetLastError());
+  return grad_input;
+}
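Both host wrappers use the same one-thread-per-output-element launch configuration; a Python sketch (illustrative only) of the grid/block math, where the grid-stride loop in CUDA_1D_KERNEL_LOOP covers any elements beyond the 4096-block cap:

def launch_config(numel, threads=512, max_blocks=4096):
    blocks = min((numel + threads - 1) // threads, max_blocks)
    return blocks, threads

print(launch_config(10 * 9 * 3 * 3))  # (2, 512) for a 10-roi, 9-channel, 3x3 output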
diff --git a/torchvision/csrc/cuda/PSROIPool_cuda.cu b/torchvision/csrc/cuda/PSROIPool_cuda.cu
new file mode 100644
index 00000000000..a2c5addd936
--- /dev/null
+++ b/torchvision/csrc/cuda/PSROIPool_cuda.cu
@@ -0,0 +1,262 @@
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <THC/THC.h>
+#include <THC/THCAtomics.cuh>
+
+#include "cuda_helpers.h"
+
+template <typename T>
+__global__ void PSROIPoolForward(
+    const int nthreads,
+    const T* input,
+    const T spatial_scale,
+    const int channels,
+    const int height,
+    const int width,
+    const int pooled_height,
+    const int pooled_width,
+    const T* rois,
+    const int channels_out,
+    T* output,
+    int* channel_mapping) {
+  CUDA_1D_KERNEL_LOOP(index, nthreads) {
+    // (n, c_out, ph, pw) is an element in the pooled output
+    int pw = index % pooled_width;
+    int ph = (index / pooled_width) % pooled_height;
+    int c_out = (index / pooled_width / pooled_height) % channels_out;
+    int n = index / pooled_width / pooled_height / channels_out;
+
+    // (n, c_in, ph, pw) is the associated element in the input
+    int c_in = (c_out * pooled_height + ph) * pooled_width + pw;
+
+    // [start, end) interval for spatial sampling
+    const T* offset_rois = rois + n * 5;
+    int roi_batch_ind = offset_rois[0];
+    int roi_start_w = roundf(offset_rois[1] * spatial_scale);
+    int roi_start_h = roundf(offset_rois[2] * spatial_scale);
+    int roi_end_w = roundf(offset_rois[3] * spatial_scale);
+    int roi_end_h = roundf(offset_rois[4] * spatial_scale);
+
+    // Force too small ROIs to be 1x1
+    int roi_width = max(roi_end_w - roi_start_w, 1);
+    int roi_height = max(roi_end_h - roi_start_h, 1);
+    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
+    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
+
+    int hstart = static_cast<int>(floor(static_cast<T>(ph) * bin_size_h));
+    int wstart = static_cast<int>(floor(static_cast<T>(pw) * bin_size_w));
+    int hend = static_cast<int>(ceil(static_cast<T>(ph + 1) * bin_size_h));
+    int wend = static_cast<int>(ceil(static_cast<T>(pw + 1) * bin_size_w));
+
+    // Add roi offsets and clip to input boundaries
+    hstart = min(max(hstart + roi_start_h, 0), height - 1);
+    hend = min(max(hend + roi_start_h, 0), height - 1);
+    wstart = min(max(wstart + roi_start_w, 0), width - 1);
+    wend = min(max(wend + roi_start_w, 0), width - 1);
+    bool is_empty = (hend <= hstart) || (wend <= wstart);
+
+    const T* offset_input =
+        input + (roi_batch_ind * channels + c_in) * height * width;
+    T out_sum = 0;
+    for (int h = hstart; h < hend; ++h) {
+      for (int w = wstart; w < wend; ++w) {
+        int input_index = h * width + w;
+        out_sum += offset_input[input_index];
+      }
+    }
+
+    T bin_area = (hend - hstart) * (wend - wstart);
+    output[index] = is_empty ? static_cast<T>(0) : out_sum / bin_area;
+    channel_mapping[index] = c_in;
+  }
+}
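A Python sketch (illustrative only) of how PSROIPoolForward derives each bin's [start, end) window: floor/ceil lets adjacent bins overlap by up to one row or column, and clipping can leave hend <= hstart, which is the is_empty case that outputs 0.

import math

def bin_bounds(ph, pw, roi_start_h, roi_start_w, bin_h, bin_w, height, width):
    hstart = int(math.floor(ph * bin_h)) + roi_start_h
    wstart = int(math.floor(pw * bin_w)) + roi_start_w
    hend = int(math.ceil((ph + 1) * bin_h)) + roi_start_h
    wend = int(math.ceil((pw + 1) * bin_w)) + roi_start_w
    clamp = lambda v, hi: min(max(v, 0), hi - 1)  # forward clips to size - 1
    hstart, hend = clamp(hstart, height), clamp(hend, height)
    wstart, wend = clamp(wstart, width), clamp(wend, width)
    return hstart, hend, wstart, wend, (hend <= hstart or wend <= wstart)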
+
+template <typename T>
+__global__ void PSROIPoolBackward(
+    const int nthreads,
+    const T* grad_output,
+    const int* channel_mapping,
+    const int num_rois,
+    const T spatial_scale,
+    const int channels,
+    const int height,
+    const int width,
+    const int pooled_height,
+    const int pooled_width,
+    const int channels_out,
+    T* grad_input,
+    const T* rois) {
+  CUDA_1D_KERNEL_LOOP(index, nthreads) {
+    // (n, *, ph, pw) is an element in the pooled output
+    int pw = index % pooled_width;
+    int ph = (index / pooled_width) % pooled_height;
+    int n = index / pooled_width / pooled_height / channels_out;
+
+    const T* offset_rois = rois + n * 5;
+    int roi_batch_ind = offset_rois[0];
+    int roi_start_w = roundf(offset_rois[1] * spatial_scale);
+    int roi_start_h = roundf(offset_rois[2] * spatial_scale);
+    int roi_end_w = roundf(offset_rois[3] * spatial_scale);
+    int roi_end_h = roundf(offset_rois[4] * spatial_scale);
+
+    // Force too small ROIs to be 1x1
+    int roi_width = max(roi_end_w - roi_start_w, 1);
+    int roi_height = max(roi_end_h - roi_start_h, 1);
+    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
+    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
+
+    int hstart = static_cast<int>(floor(static_cast<T>(ph) * bin_size_h));
+    int wstart = static_cast<int>(floor(static_cast<T>(pw) * bin_size_w));
+    int hend = static_cast<int>(ceil(static_cast<T>(ph + 1) * bin_size_h));
+    int wend = static_cast<int>(ceil(static_cast<T>(pw + 1) * bin_size_w));
+
+    // Add roi offsets and clip to input boundaries
+    hstart = min(max(hstart + roi_start_h, 0), height);
+    hend = min(max(hend + roi_start_h, 0), height);
+    wstart = min(max(wstart + roi_start_w, 0), width);
+    wend = min(max(wend + roi_start_w, 0), width);
+    bool is_empty = (hend <= hstart) || (wend <= wstart);
+
+    int c_in = channel_mapping[index];
+    T* grad_input_offset =
+        grad_input + (roi_batch_ind * channels + c_in) * height * width;
+    T bin_area = (hend - hstart) * (wend - wstart);
+    T diff_val = is_empty ? static_cast<T>(0) : grad_output[index] / bin_area;
+    for (int h = hstart; h < hend; ++h) {
+      for (int w = wstart; w < wend; ++w) {
+        int grad_input_index = h * width + w;
+        atomicAdd(grad_input_offset + grad_input_index, diff_val);
+      }
+    }
+  }
+}
+
+std::tuple<at::Tensor, at::Tensor> PSROIPool_forward_cuda(
+    const at::Tensor& input,
+    const at::Tensor& rois,
+    const float spatial_scale,
+    const int pooled_height,
+    const int pooled_width) {
+  // Check if input tensors are CUDA tensors
+  AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
+  AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor");
+
+  at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2};
+
+  at::CheckedFrom c = "PSROIPool_forward_cuda";
+  at::checkAllSameGPU(c, {input_t, rois_t});
+  at::checkAllSameType(c, {input_t, rois_t});
+
+  at::cuda::CUDAGuard device_guard(input.device());
+
+  auto num_rois = rois.size(0);
+  auto channels = input.size(1);
+  auto height = input.size(2);
+  auto width = input.size(3);
+
+  AT_ASSERTM(
+      channels % (pooled_height * pooled_width) == 0,
+      "input channels must be a multiple of pooling height * pooling width");
+  int channels_out = channels / (pooled_height * pooled_width);
+
+  auto output = at::zeros(
+      {num_rois, channels_out, pooled_height, pooled_width}, input.options());
+  auto channel_mapping =
+      at::zeros(output.sizes(), input.options().dtype(at::kInt));
+
+  auto output_size = output.numel();
+  if (output_size == 0) {
+    AT_CUDA_CHECK(cudaGetLastError());
+    return std::make_tuple(output, channel_mapping);
+  }
+
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  dim3 grid(std::min(at::cuda::ATenCeilDiv(output_size, 512L), 4096L));
+  dim3 block(512);
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      input.scalar_type(), "PSROIPool_forward", [&] {
+        PSROIPoolForward<scalar_t><<<grid, block, 0, stream>>>(
+            output_size,
+            input.contiguous().data<scalar_t>(),
+            spatial_scale,
+            channels,
+            height,
+            width,
+            pooled_height,
+            pooled_width,
+            rois.contiguous().data<scalar_t>(),
+            channels_out,
+            output.data<scalar_t>(),
+            channel_mapping.data<int>());
+      });
+  AT_CUDA_CHECK(cudaGetLastError());
+  return std::make_tuple(output, channel_mapping);
+}
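Since the forward pass averages each bin over bin_area pixels, the backward pass spreads the bin's gradient uniformly over the same pixels; a serial Python analogue (illustrative only):

def spread_bin_grad(grad_plane, width, go_bin, hstart, hend, wstart, wend):
    # grad_plane is the flat H*W buffer for the (roi_batch_ind, c_in) plane
    area = (hend - hstart) * (wend - wstart)
    for h in range(hstart, hend):
        for w in range(wstart, wend):
            grad_plane[h * width + w] += go_bin / area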
+
+at::Tensor PSROIPool_backward_cuda(
+    const at::Tensor& grad,
+    const at::Tensor& rois,
+    const at::Tensor& channel_mapping,
+    const float spatial_scale,
+    const int pooled_height,
+    const int pooled_width,
+    const int batch_size,
+    const int channels,
+    const int height,
+    const int width) {
+  // Check if input tensors are CUDA tensors
+  AT_ASSERTM(grad.type().is_cuda(), "grad must be a CUDA tensor");
+  AT_ASSERTM(rois.type().is_cuda(), "rois must be a CUDA tensor");
+  AT_ASSERTM(
+      channel_mapping.type().is_cuda(),
+      "channel_mapping must be a CUDA tensor");
+
+  at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2},
+      channel_mapping_t{channel_mapping, "channel_mapping", 3};
+
+  at::CheckedFrom c = "PSROIPool_backward_cuda";
+  at::checkAllSameGPU(c, {grad_t, rois_t, channel_mapping_t});
+  at::checkAllSameType(c, {grad_t, rois_t});
+
+  at::cuda::CUDAGuard device_guard(grad.device());
+
+  auto num_rois = rois.size(0);
+  auto grad_input =
+      at::zeros({batch_size, channels, height, width}, grad.options());
+
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  dim3 grid(std::min(at::cuda::ATenCeilDiv(grad.numel(), 512L), 4096L));
+  dim3 block(512);
+
+  // handle possibly empty gradients
+  if (grad.numel() == 0) {
+    AT_CUDA_CHECK(cudaGetLastError());
+    return grad_input;
+  }
+
+  int channels_out = channels / (pooled_height * pooled_width);
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      grad.scalar_type(), "PSROIPool_backward", [&] {
+        PSROIPoolBackward<scalar_t><<<grid, block, 0, stream>>>(
+            grad.numel(),
+            grad.contiguous().data<scalar_t>(),
+            channel_mapping.data<int>(),
+            num_rois,
+            spatial_scale,
+            channels,
+            height,
+            width,
+            pooled_height,
+            pooled_width,
+            channels_out,
+            grad_input.data<scalar_t>(),
+            rois.contiguous().data<scalar_t>());
+      });
+  AT_CUDA_CHECK(cudaGetLastError());
+  return grad_input;
+}
diff --git a/torchvision/csrc/cuda/vision_cuda.h b/torchvision/csrc/cuda/vision_cuda.h
index 12b5d70b599..b35c4c909c1 100644
--- a/torchvision/csrc/cuda/vision_cuda.h
+++ b/torchvision/csrc/cuda/vision_cuda.h
@@ -41,6 +41,46 @@ at::Tensor ROIPool_backward_cuda(
     const int height,
     const int width);
 
+std::tuple<at::Tensor, at::Tensor> PSROIPool_forward_cuda(
+    const at::Tensor& input,
+    const at::Tensor& rois,
+    const float spatial_scale,
+    const int pooled_height,
+    const int pooled_width);
+
+at::Tensor PSROIPool_backward_cuda(
+    const at::Tensor& grad,
+    const at::Tensor& rois,
+    const at::Tensor& mapping_channel,
+    const float spatial_scale,
+    const int pooled_height,
+    const int pooled_width,
+    const int batch_size,
+    const int channels,
+    const int height,
+    const int width);
+
+std::tuple<at::Tensor, at::Tensor> PSROIAlign_forward_cuda(
+    const at::Tensor& input,
+    const at::Tensor& rois,
+    const float spatial_scale,
+    const int pooled_height,
+    const int pooled_width,
+    const int sampling_ratio);
+
+at::Tensor PSROIAlign_backward_cuda(
+    const at::Tensor& grad,
+    const at::Tensor& rois,
+    const at::Tensor& mapping_channel,
+    const float spatial_scale,
+    const int pooled_height,
+    const int pooled_width,
+    const int sampling_ratio,
+    const int batch_size,
+    const int channels,
+    const int height,
+    const int width);
+
 at::Tensor nms_cuda(
     const at::Tensor& dets,
     const at::Tensor& scores,
diff --git a/torchvision/csrc/vision.cpp b/torchvision/csrc/vision.cpp
index 61a4eeee727..8a80c1f3f05 100644
--- a/torchvision/csrc/vision.cpp
+++ b/torchvision/csrc/vision.cpp
@@ -1,3 +1,5 @@
+#include "PSROIAlign.h"
+#include "PSROIPool.h"
 #include "ROIAlign.h"
 #include "ROIPool.h"
 #include "nms.h"
@@ -10,6 +12,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   // TODO: remove nms from here since it is now registered
   // and used as a PyTorch custom op
   m.def("nms", &nms, "non-maximum suppression");
+  m.def("ps_roi_align_forward", &PSROIAlign_forward, "PSROIAlign_forward");
+  m.def("ps_roi_align_backward", &PSROIAlign_backward, "PSROIAlign_backward");
+  m.def("ps_roi_pool_forward", &PSROIPool_forward, "PSROIPool_forward");
+  m.def("ps_roi_pool_backward", &PSROIPool_backward, "PSROIPool_backward");
   m.def("roi_align_forward", &ROIAlign_forward, "ROIAlign_forward");
   m.def("roi_align_backward", &ROIAlign_backward, "ROIAlign_backward");
   m.def("roi_pool_forward", &ROIPool_forward, "ROIPool_forward");
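These pybind11 bindings surface on the lazily loaded extension module that the Python wrappers below consume; roughly (illustrative only, call shown commented out since it needs the built extension and a CUDA tensor pair):

from torchvision.extension import _lazy_import

_C = _lazy_import()
# output is (K, C // (ph * pw), ph, pw); channel_mapping has the same shape
# output, channel_mapping = _C.ps_roi_pool_forward(input, rois, 1.0, 3, 3)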
diff --git a/torchvision/ops/__init__.py b/torchvision/ops/__init__.py
index fbd1181929b..06449ac9c3f 100644
--- a/torchvision/ops/__init__.py
+++ b/torchvision/ops/__init__.py
@@ -1,11 +1,16 @@
 from .boxes import nms, box_iou
+from .feature_pyramid_network import FeaturePyramidNetwork
+from .poolers import MultiScaleRoIAlign
+from .ps_roi_align import ps_roi_align, PSRoIAlign
+from .ps_roi_pool import ps_roi_pool, PSRoIPool
 from .roi_align import roi_align, RoIAlign
 from .roi_pool import roi_pool, RoIPool
-from .poolers import MultiScaleRoIAlign
-from .feature_pyramid_network import FeaturePyramidNetwork
-
 
 __all__ = [
-    'nms', 'roi_align', 'RoIAlign', 'roi_pool', 'RoIPool',
-    'MultiScaleRoIAlign', 'FeaturePyramidNetwork'
+    'nms',
+    'MultiScaleRoIAlign', 'FeaturePyramidNetwork',
+    'ps_roi_align', 'PSRoIAlign',
+    'ps_roi_pool', 'PSRoIPool',
+    'roi_align', 'RoIAlign',
+    'roi_pool', 'RoIPool',
 ]
diff --git a/torchvision/ops/poolers.py b/torchvision/ops/poolers.py
index d87da57dfd7..831b3a860ba 100644
--- a/torchvision/ops/poolers.py
+++ b/torchvision/ops/poolers.py
@@ -3,8 +3,8 @@
 import torch.nn.functional as F
 from torch import nn
 
-from torchvision.ops import roi_align
-from torchvision.ops.boxes import box_area
+from .roi_align import roi_align
+from .boxes import box_area
 
 
 class LevelMapper(object):
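Both wrappers below accept either a (K, 5) rois tensor or a per-image list of (L, 4) boxes. The conversion helper convert_boxes_to_roi_format (from torchvision.ops._utils) prepends the batch index; roughly equivalent to this sketch (illustrative only):

import torch

def to_roi_format(boxes_per_image):
    rois = []
    for i, b in enumerate(boxes_per_image):
        idx = torch.full((b.size(0), 1), i, dtype=b.dtype)
        rois.append(torch.cat([idx, b], dim=1))
    return torch.cat(rois, dim=0)  # (sum L_i, 5): [batch_idx, x1, y1, x2, y2]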
diff --git a/torchvision/ops/ps_roi_align.py b/torchvision/ops/ps_roi_align.py
new file mode 100644
index 00000000000..ab5b018ff42
--- /dev/null
+++ b/torchvision/ops/ps_roi_align.py
@@ -0,0 +1,100 @@
+import torch
+from torch import nn
+
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+
+from torch.nn.modules.utils import _pair
+
+from torchvision.extension import _lazy_import
+from ._utils import convert_boxes_to_roi_format
+
+
+class _PSRoIAlignFunction(Function):
+    @staticmethod
+    def forward(ctx, input, rois, output_size, spatial_scale, sampling_ratio):
+        ctx.output_size = _pair(output_size)
+        ctx.spatial_scale = spatial_scale
+        ctx.sampling_ratio = sampling_ratio
+        ctx.input_shape = input.size()
+        _C = _lazy_import()
+        output, channel_mapping = _C.ps_roi_align_forward(
+            input, rois, spatial_scale,
+            ctx.output_size[0], ctx.output_size[1], sampling_ratio
+        )
+        ctx.save_for_backward(rois, channel_mapping)
+        return output
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, grad_output):
+        rois, channel_mapping = ctx.saved_tensors
+        output_size = ctx.output_size
+        spatial_scale = ctx.spatial_scale
+        sampling_ratio = ctx.sampling_ratio
+        bs, ch, h, w = ctx.input_shape
+        _C = _lazy_import()
+        grad_input = _C.ps_roi_align_backward(
+            grad_output,
+            rois,
+            channel_mapping,
+            spatial_scale,
+            output_size[0],
+            output_size[1],
+            sampling_ratio,
+            bs, ch, h, w,
+        )
+        return grad_input, None, None, None, None
+
+
+def ps_roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1):
+    """
+    Performs the Position-Sensitive Region of Interest (RoI) Align operator
+    mentioned in Light-Head R-CNN.
+
+    Arguments:
+        input (Tensor[N, C, H, W]): input tensor
+        boxes (Tensor[K, 5] or List[Tensor[L, 4]]): the box coordinates in (x1, y1, x2, y2)
+            format where the regions will be taken from. If a single Tensor is passed,
+            then the first column should contain the batch index. If a list of Tensors
+            is passed, then each Tensor will correspond to the boxes for an element i
+            in a batch
+        output_size (int or Tuple[int, int]): the size of the output after the cropping
+            is performed, as (height, width)
+        spatial_scale (float): a scaling factor that maps the input coordinates to
+            the box coordinates. Default: 1.0
+        sampling_ratio (int): number of sampling points in the interpolation grid
+            used to compute the output value of each pooled output bin. If > 0,
+            then exactly sampling_ratio x sampling_ratio grid points are used.
+            If <= 0, then an adaptive number of grid points is used (computed as
+            ceil(roi_width / pooled_w), and likewise for height). Default: -1
+
+    Returns:
+        output (Tensor[K, C // (output_size[0] * output_size[1]), output_size[0], output_size[1]])
+    """
+    rois = boxes
+    if not isinstance(rois, torch.Tensor):
+        rois = convert_boxes_to_roi_format(rois)
+    return _PSRoIAlignFunction.apply(input, rois, output_size, spatial_scale, sampling_ratio)
+
+
+class PSRoIAlign(nn.Module):
+    """
+    See ps_roi_align
+    """
+    def __init__(self, output_size, spatial_scale, sampling_ratio):
+        super(PSRoIAlign, self).__init__()
+        self.output_size = output_size
+        self.spatial_scale = spatial_scale
+        self.sampling_ratio = sampling_ratio
+
+    def forward(self, input, rois):
+        return ps_roi_align(input, rois, self.output_size, self.spatial_scale, self.sampling_ratio)
+
+    def __repr__(self):
+        tmpstr = self.__class__.__name__ + "("
+        tmpstr += "output_size=" + str(self.output_size)
+        tmpstr += ", spatial_scale=" + str(self.spatial_scale)
+        tmpstr += ", sampling_ratio=" + str(self.sampling_ratio)
+        tmpstr += ")"
+        return tmpstr
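Illustrative usage of the module wrapper (not part of the patch): with a 3x3 pool size, an 18-channel input yields 18 / (3 * 3) = 2 output channels per roi.

import torch
from torchvision import ops

x = torch.rand(1, 18, 7, 7)
rois = torch.tensor([[0, 0, 0, 5, 5]], dtype=x.dtype)  # [batch_idx, x1, y1, x2, y2]
y = ops.PSRoIAlign((3, 3), spatial_scale=1, sampling_ratio=2)(x, rois)
print(y.shape)  # torch.Size([1, 2, 3, 3])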
diff --git a/torchvision/ops/ps_roi_pool.py b/torchvision/ops/ps_roi_pool.py
new file mode 100644
index 00000000000..7ee33fb70de
--- /dev/null
+++ b/torchvision/ops/ps_roi_pool.py
@@ -0,0 +1,89 @@
+import torch
+from torch import nn
+
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+
+from torch.nn.modules.utils import _pair
+
+from torchvision.extension import _lazy_import
+from ._utils import convert_boxes_to_roi_format
+
+
+class _PSRoIPoolFunction(Function):
+    @staticmethod
+    def forward(ctx, input, rois, output_size, spatial_scale):
+        ctx.output_size = _pair(output_size)
+        ctx.spatial_scale = spatial_scale
+        ctx.input_shape = input.size()
+        _C = _lazy_import()
+        output, channel_mapping = _C.ps_roi_pool_forward(
+            input, rois, spatial_scale, ctx.output_size[0], ctx.output_size[1]
+        )
+        ctx.save_for_backward(rois, channel_mapping)
+        return output
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, grad_output):
+        rois, channel_mapping = ctx.saved_tensors
+        output_size = ctx.output_size
+        spatial_scale = ctx.spatial_scale
+        bs, ch, h, w = ctx.input_shape
+        _C = _lazy_import()
+        grad_input = _C.ps_roi_pool_backward(
+            grad_output,
+            rois,
+            channel_mapping,
+            spatial_scale,
+            output_size[0],
+            output_size[1],
+            bs, ch, h, w,
+        )
+        return grad_input, None, None, None
+
+
+def ps_roi_pool(input, boxes, output_size, spatial_scale=1.0):
+    """
+    Performs the Position-Sensitive Region of Interest (RoI) Pool operator
+    described in R-FCN.
+
+    Arguments:
+        input (Tensor[N, C, H, W]): input tensor
+        boxes (Tensor[K, 5] or List[Tensor[L, 4]]): the box coordinates in (x1, y1, x2, y2)
+            format where the regions will be taken from. If a single Tensor is passed,
+            then the first column should contain the batch index. If a list of Tensors
+            is passed, then each Tensor will correspond to the boxes for an element i
+            in a batch
+        output_size (int or Tuple[int, int]): the size of the output after the cropping
+            is performed, as (height, width)
+        spatial_scale (float): a scaling factor that maps the input coordinates to
+            the box coordinates. Default: 1.0
+
+    Returns:
+        output (Tensor[K, C // (output_size[0] * output_size[1]), output_size[0], output_size[1]])
+    """
+    rois = boxes
+    if not isinstance(rois, torch.Tensor):
+        rois = convert_boxes_to_roi_format(rois)
+    return _PSRoIPoolFunction.apply(input, rois, output_size, spatial_scale)
+
+
+class PSRoIPool(nn.Module):
+    """
+    See ps_roi_pool
+    """
+    def __init__(self, output_size, spatial_scale):
+        super(PSRoIPool, self).__init__()
+        self.output_size = output_size
+        self.spatial_scale = spatial_scale
+
+    def forward(self, input, rois):
+        return ps_roi_pool(input, rois, self.output_size, self.spatial_scale)
+
+    def __repr__(self):
+        tmpstr = self.__class__.__name__ + "("
+        tmpstr += "output_size=" + str(self.output_size)
+        tmpstr += ", spatial_scale=" + str(self.spatial_scale)
+        tmpstr += ")"
+        return tmpstr
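To sanity-check the autograd wiring above end to end, a quick numerical gradient check in the spirit of the existing op tests (illustrative only; float64 is needed for gradcheck tolerances):

import torch
from torch.autograd import gradcheck
from torchvision import ops

x = torch.rand(1, 9, 5, 5, dtype=torch.float64, requires_grad=True)
rois = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float64)
# 9 channels / (3 * 3) bins -> 1 output channel per roi
assert gradcheck(lambda t: ops.ps_roi_pool(t, rois, (3, 3)), (x,))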