The backward calculation results of DCN in some cases are inconsistent with the DCN of mmcv. #6885

Open
mengpenghui opened this issue Nov 2, 2022 · 2 comments

@mengpenghui

🐛 Describe the bug

When I used torchvision.ops.deform_conv2d to build DCNv2, I found that the backward results were inconsistent with mmcv's DCNv2.
As far as I know, DCNv2 can be built on top of torchvision's deform_conv2d interface as follows
(adapted from https://github.com/open-mmlab/mmcv/blob/master/mmcv/ops/modulated_deform_conv.py):

import math
from typing import Tuple, Union

import torch
import torch.nn as nn
from torch.nn.modules.utils import _pair
from torchvision.ops import deform_conv2d as tv_deform_conv2d

class ModulatedDeformConv2d_tv(nn.modules.Module):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Tuple[int]],
                 stride: int = 1,
                 padding: int = 0,
                 dilation: int = 1,
                 groups: int = 1,
                 deform_groups: int = 1,
                 bias: Union[bool, str] = True):
        super(ModulatedDeformConv2d_tv, self).__init__()
        self.in_channels  = in_channels
        self.out_channels = out_channels
        self.kernel_size  = _pair(kernel_size)
        self.stride       = _pair(stride)
        self.padding      = _pair(padding)
        self.dilation     = _pair(dilation)
        self.groups   = groups
        self.deform_groups = deform_groups
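        # Grouped-convolution weight layout: deform_conv2d infers the group
        # count from in_channels // weight.shape[1] (unchanged when groups == 1).
        self.weight = nn.Parameter(torch.Tensor(
            out_channels, in_channels // groups, *self.kernel_size))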
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        self.conv_offset = nn.Conv2d(
            self.in_channels,
            self.deform_groups * 3 * self.kernel_size[0] * self.kernel_size[1],
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            bias=True)
        self.init_weights()
    def init_weights(self):
        n = self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv = 1. / math.sqrt(n)
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.zero_()
        self.conv_offset.weight.data.zero_()
        self.conv_offset.bias.data.zero_()
    def forward(self, x):
        out = self.conv_offset(x)
        o1, o2, mask = torch.chunk(out, 3, dim=1)
        offset = torch.cat((o1, o2), dim=1)
        mask = torch.sigmoid(mask)
        return tv_deform_conv2d(x, offset, self.weight,
                                bias=self.bias, stride=self.stride,
                                padding=self.padding, dilation=self.dilation,
                                mask=mask)

Next, I ran it against mmcv's test case:
https://github.com/open-mmlab/mmcv/blob/master/tests/test_ops/test_modulated_deform_conv.py

In test_modulated_deform_conv.py, I replaced ModulatedDeformConv2dPack with the ModulatedDeformConv2d_tv class defined above:

dcn = ModulatedDeformConv2d_tv(1, 1, kernel_size=(2, 2), stride=1, padding=1, deform_groups=1, bias=False)

As a result, dcn.conv_offset.weight.grad and dcn.conv_offset.bias.grad do not match the values the test expects.
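To see why this configuration can reach the boundary case, here is a minimal sketch that enumerates the row coordinates visited by the kernel. It uses the coordinate formula from torchvision's deformable_col2im_coord_kernel, y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h, and assumes all offsets are zero, which holds at initialization because init_weights zeroes conv_offset. The input height of 3 and the helper name sampling_rows are assumptions made for illustration only.

```python
from itertools import product


def sampling_rows(out_size, kernel, stride=1, padding=1, dilation=1, offset=0.0):
    """Return (out_y, i, y) for every output row / kernel row pair."""
    return [
        (out_y, i, out_y * stride - padding + i * dilation + offset)
        for out_y, i in product(range(out_size), range(kernel))
    ]


# Hypothetical input height; the other values mirror the dcn constructed above.
height, kernel, stride, padding, dilation = 3, 2, 1, 1, 1
out_size = (height + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

rows = sampling_rows(out_size, kernel, stride, padding, dilation)
# Kernel taps whose sampling row lands exactly on y == -1.
on_boundary = [(out_y, i) for out_y, i, y in rows if y == -1]
print(out_size, on_boundary)
```

With padding=1 and zero offsets, the first kernel tap of the first output row samples at exactly y = -1, so the boundary branch is exercised regardless of the input size.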

I checked torchvision's CPU implementation of DCN and found that it is based on the same paper as mmcv; the torchvision implementation also appears to be derived from open-mmlab's. https://github.com/pytorch/vision/blob/main/torchvision/csrc/ops/cpu/deform_conv2d_kernel.cpp#L67

However, the logic at https://github.com/pytorch/vision/blob/main/torchvision/csrc/ops/cpu/deform_conv2d_kernel.cpp#L384 and https://github.com/pytorch/vision/blob/main/torchvision/csrc/ops/cpu/deform_conv2d_kernel.cpp#L500 appears to differ from mmcv's.
For the mmcv version, see:
https://github.com/open-mmlab/mmcv/blob/master/mmcv/ops/csrc/pytorch/cpu/modulated_deform_conv.cpp#L294

Finally, I modified torchvision's deform_conv2d_kernel.cpp to follow the mmcv version, and the test case passed. The two implementations appear to differ when a sampling coordinate is exactly -1.

Since torchvision's implementation is derived from open-mmlab's and both reference the same paper, I would like to ask the torchvision developers whether it is intended that the two frameworks disagree in this scenario.

Below are the modifications that align torchvision with mmcv.

diff --git a/torchvision/csrc/ops/cpu/deform_conv2d_kernel.cpp b/torchvision/csrc/ops/cpu/deform_conv2d_kernel.cpp
index b1d15a1..7b3ca6b 100644
--- a/torchvision/csrc/ops/cpu/deform_conv2d_kernel.cpp
+++ b/torchvision/csrc/ops/cpu/deform_conv2d_kernel.cpp
@@ -388,6 +388,9 @@ scalar_t get_coordinate_weight(
     scalar_t y,
     scalar_t x,
     bool is_y_direction) {
+  if (y <= -1 || y >= height || x <= -1 || x >= width){
+    return 0;
+  }
   int y_l = floor(y);
   int x_l = floor(x);
   int y_h = y_l + 1;
@@ -499,6 +502,10 @@ void deformable_col2im_coord_kernel(
       scalar_t y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h;
       scalar_t x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w;
 
+      if (y <= -1 || x <= -1 || y >= height || x >= width)
+      {
+        x = y = -2;
+      }
       const scalar_t weight =
           get_coordinate_weight(im_ptr, height, width, y, x, is_y_direction);
       grad_offset_val += mask_value * weight * col_ptr[col_pos];
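The effect of the added guard can be reproduced outside the C++ build. The following is a rough pure-Python port of torchvision's get_coordinate_weight as quoted in this issue, plus a variant with the early return from the patch. The 2x2 image, the quarter-pixel grid, and the name get_coordinate_weight_guarded are assumptions for illustration; this is a sketch, not the library code.

```python
import math


def get_coordinate_weight(im, height, width, y, x, is_y_direction):
    """Pure-Python port of the unpatched routine quoted in this issue."""
    y_l, x_l = math.floor(y), math.floor(x)
    y_h, x_h = y_l + 1, x_l + 1

    def pixel(r, c):
        # Zero outside the image, like the valid_* flags in the C++ code.
        return im[r * width + c] if 0 <= r < height and 0 <= c < width else 0.0

    v_yx, v_yX = pixel(y_l, x_l), pixel(y_l, x_h)
    v_Yx, v_YX = pixel(y_h, x_l), pixel(y_h, x_h)
    if is_y_direction:
        dx = x - x_l
        return dx * (v_YX - v_yX) + (1 - dx) * (v_Yx - v_yx)
    dy = y - y_l
    return dy * (v_YX - v_Yx) + (1 - dy) * (v_yX - v_yx)


def get_coordinate_weight_guarded(im, height, width, y, x, is_y_direction):
    """Same routine with the early return added by the patch."""
    if y <= -1 or y >= height or x <= -1 or x >= width:
        return 0.0
    return get_coordinate_weight(im, height, width, y, x, is_y_direction)


im = [1.0, 2.0, 3.0, 4.0]  # hypothetical 2x2 image, row-major
height = width = 2

print(get_coordinate_weight(im, height, width, -1.0, 0.0, True))          # 1.0
print(get_coordinate_weight_guarded(im, height, width, -1.0, 0.0, True))  # 0.0

# Sweep a quarter-pixel grid and collect every point where the two disagree.
grid = [k / 4 for k in range(-8, 17)]
mismatches = [
    (y, x, d)
    for y in grid for x in grid for d in (True, False)
    if get_coordinate_weight(im, height, width, y, x, d)
    != get_coordinate_weight_guarded(im, height, width, y, x, d)
]
only_at_minus_one = all(y == -1 or x == -1 for y, x, _ in mismatches)
print(only_at_minus_one)  # True
```

On this grid the two variants disagree only where a coordinate is exactly -1; everywhere else, including coordinates at or beyond height and width, both return the same value.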

Versions

The version I use is

python==3.6.9
torch==1.9.0
torchvision==0.10.0
mmcv-full==1.6.2
@mengpenghui
Author

I recently read the CPU kernel of deform_conv2d and compared it with the original author's DCN repository (from which torchvision's implementation was adapted), specifically these two files:

torchvision:
https://github.com/pytorch/vision/blob/main/torchvision/csrc/ops/cpu/deform_conv2d_kernel.cpp
The original author's implementation:
https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu

I found that the original author treats coordinates <= -1 as out of bounds when computing mask_val and offset_val through bilinear sampling.

// mask_val computation in the original author's code
// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu#L222
        if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) {
          val = deformable_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);
        }
// offset_val computation in the original author's code
// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu#L136
__device__ DType get_coordinate_weight(DType argmax_h, DType argmax_w,
  const int height, const int width, const DType *im_data,
  const int data_width, const int bp_dir) {
  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) {
    //empty
    return 0;
  }
  ...
}

In torchvision, the mask_val computation is consistent with the original implementation, but the offset_val computation appears to be a reworked version of the original, and the new algorithm does not match the original at the boundary value -1.

// mask_val computation in torchvision
// https://github.com/pytorch/vision/blob/main/torchvision/csrc/ops/cpu/deform_conv2d_kernel.cpp#L506
      if (use_mask && is_y_direction) {
        grad_mask_val += col_ptr[col_pos] *
          bilinear_interpolate(im_ptr, height, width, y, x);
      }
//https://github.com/pytorch/vision/blob/main/torchvision/csrc/ops/cpu/deform_conv2d_kernel.cpp#L80
scalar_t bilinear_interpolate(
    const scalar_t* in,
    int height,
    int width,
    scalar_t h,
    scalar_t w) {
  if (h <= -1 || height <= h || w <= -1 || width <= w) {
    return 0;
  }
  ...
}

torchvision's get_coordinate_weight, quoted next, reworks the original. The goal seems to be a simpler routine that still covers the original's boundary checks, but it does not handle x or y = -1, or at least does not match the original there.

//https://github.com/pytorch/vision/blob/main/torchvision/csrc/ops/cpu/deform_conv2d_kernel.cpp#L384
template <typename scalar_t>
scalar_t get_coordinate_weight(
    const scalar_t* im_data,
    int height,
    int width,
    scalar_t y,
    scalar_t x,
    bool is_y_direction) {
  int y_l = floor(y);
  int x_l = floor(x);
  int y_h = y_l + 1;
  int x_h = x_l + 1;
  bool valid_y_l = 0 <= y_l && y_l < height;
  bool valid_y_h = 0 <= y_h && y_h < height;
  bool valid_x_l = 0 <= x_l && x_l < width;
  bool valid_x_h = 0 <= x_h && x_h < width;
  scalar_t zero = 0;
  scalar_t v_yx = (valid_y_l && valid_x_l) ? im_data[y_l * width + x_l] : zero;
  scalar_t v_yX = (valid_y_l && valid_x_h) ? im_data[y_l * width + x_h] : zero;
  scalar_t v_Yx = (valid_y_h && valid_x_l) ? im_data[y_h * width + x_l] : zero;
  scalar_t v_YX = (valid_y_h && valid_x_h) ? im_data[y_h * width + x_h] : zero;
  if (is_y_direction) {
    scalar_t dx = x - x_l;
    return dx * (v_YX - v_yX) + (1 - dx) * (v_Yx - v_yx);
  } else {
    scalar_t dy = y - y_l;
    return dy * (v_YX - v_Yx) + (1 - dy) * (v_yX - v_yx);
  }
}

//https://github.com/pytorch/vision/blob/main/torchvision/csrc/ops/cpu/deform_conv2d_kernel.cpp#L503
      scalar_t y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h;
      scalar_t x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w;
      const scalar_t weight = get_coordinate_weight(im_ptr, height, width, y, x, is_y_direction);

By the principle of bilinear sampling, a sampling position of exactly -1 puts zero interpolation weight on every in-bounds pixel, so nothing is actually sampled there. torchvision's get_coordinate_weight nevertheless computes a gradient from the pixels on the image boundary at that position, which doesn't seem reasonable.
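One way to look at the disagreement is as a choice between one-sided derivatives. The sketch below samples a zero-padded bilinear interpolation on either side of y = -1 and takes finite differences. Only the out-of-range guard comes from the torchvision snippet quoted above; the interpolation body is elided there, so a standard zero-padded formulation is assumed here. The 2x2 image and the step eps are also illustrative assumptions.

```python
import math


def bilinear_sample(im, height, width, h, w):
    """Zero-padded bilinear sample with the out-of-range guard quoted above."""
    if h <= -1 or height <= h or w <= -1 or width <= w:
        return 0.0
    h_low, w_low = math.floor(h), math.floor(w)
    lh, lw = h - h_low, w - w_low  # fractional parts in [0, 1)

    def pixel(r, c):
        # Assumed zero padding outside the image.
        return im[r * width + c] if 0 <= r < height and 0 <= c < width else 0.0

    return ((1 - lh) * (1 - lw) * pixel(h_low, w_low)
            + (1 - lh) * lw * pixel(h_low, w_low + 1)
            + lh * (1 - lw) * pixel(h_low + 1, w_low)
            + lh * lw * pixel(h_low + 1, w_low + 1))


im = [1.0, 2.0, 3.0, 4.0]  # hypothetical 2x2 image, row-major
height = width = 2
eps = 1e-3
y, x = -1.0, 0.0

center = bilinear_sample(im, height, width, y, x)
# One-sided finite differences with respect to y at the boundary.
left = (center - bilinear_sample(im, height, width, y - eps, x)) / eps
right = (bilinear_sample(im, height, width, y + eps, x) - center) / eps
print(left, right)
```

Under these assumptions the sampled value has a kink at y = -1: the left-hand slope is 0 and the right-hand slope is approximately im[0], the first pixel of the top row. Returning 0 at the boundary, as the original code and mmcv do, corresponds to the left-hand slope and agrees with the forward guard that treats -1 as out of range. The unguarded get_coordinate_weight corresponds to the right-hand slope.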

I don't know whether this difference is intentional, but to me it currently looks more like a bug introduced when the code was adapted.

@GuWei007
GuWei007 commented Nov 8, 2022

https://github.com/open-mmlab/mmcv/blob/master/mmcv/ops/modulated_deform_conv.py
The logic of the mmcv version seems more reasonable.
