In [None]:
# default_exp models.dcn

# DCN: Deformable Convolution
- DCNv2 reference implementation: https://github.com/CharlesShang/DCNv2
- Paper: https://arxiv.org/abs/1707.02069

In [None]:
#export
from fastai2.vision.all import *
from moving_mnist.models.conv_rnn import *
try:
    # `DeformConv2d` (used by `DCN` below) comes from mmcv's compiled ops.
    from mmcv.ops import *
except ImportError as e:
    # Best effort: keep the module importable, but tell the user why `DCN` will not work.
    # (Previously the message was a bare string expression and was never shown.)
    import warnings
    warnings.warn(f'please install MMCV (mmcv-full) for your PyTorch version: {e}')

In [None]:
# Pin the session to the first GPU when CUDA is present, then report that card's name.
if torch.cuda.is_available():
    torch.cuda.set_device('cuda:0')
    print(torch.cuda.get_device_name(0))

Quadro RTX 8000


In [None]:
#export
class DCN(Module):
    """A Deformable Convolutional Kernel whose sampling offsets are predicted from its input.

    A regular `nn.Conv2d` (`conv_offset`) predicts ``2 * deformable_groups * kh * kw``
    offset channels. Those offsets and the input are passed to mmcv's `DeformConv2d`
    (`dconv`). `conv_offset` is zero-initialised, so a new layer starts with all-zero offsets.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the deformable convolution.
        kernel_size: an int or a ``(kh, kw)`` pair.
        stride, padding, dilation: applied identically to both convolutions so the
            offset map has the same spatial size as the deformable conv output.
        deformable_groups: number of offset groups forwarded to `DeformConv2d`.
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 padding,
                 dilation=1,
                 deformable_groups=1):
        # No-op for fastai's `Module` (which already initialises `nn.Module` itself),
        # but required if `Module` is a plain `torch.nn.Module`.
        super().__init__()
        # Accept an int kernel size as well as a (kh, kw) pair.
        ks = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
        # Two offset values per kernel position, per deformable group.
        n_offset_channels = deformable_groups * 2 * ks[0] * ks[1]
        # The offset conv must share stride/padding/dilation with `dconv`. Otherwise
        # (e.g. dilation > 1) its output would not match the deformable conv's spatial size.
        self.conv_offset = nn.Conv2d(in_channels,
                                     n_offset_channels,
                                     kernel_size=ks,
                                     stride=stride,
                                     padding=padding,
                                     dilation=dilation,
                                     bias=True)
        # Positional args after `dilation`: groups=1, then deformable_groups.
        self.dconv = DeformConv2d(in_channels, out_channels, ks, stride, padding,
                                  dilation, 1, deformable_groups)
        self.init_offset()

    def init_offset(self):
        "Zero the offset predictor's weight and bias so every predicted offset starts at 0."
        self.conv_offset.weight.data.zero_()
        self.conv_offset.bias.data.zero_()

    def forward(self, input):
        # Offsets are predicted from the same tensor they deform and used as-is.
        offset = self.conv_offset(input)
        return self.dconv(input, offset)

In [None]:
# Scratch look at the truncated backbone (first six child modules; 2 input channels here).
# `Net` below builds its own copy, so `body` is not used again in this notebook.
body = create_body(resnet34, 2)[:6]

In [None]:
class Net(Module):
    "Demo model: a truncated `arch` backbone followed by a head of two `DCN` layers."
    def __init__(self, arch=resnet34, n_in=1, n_out=2):
        # NOTE(review): `n_out` is accepted but unused; the head always emits 256 channels.
        # Keep only the first six child modules of the backbone. For resnet34 this yields the
        # 128-channel features the first DCN expects (TODO confirm for other archs).
        self.body = create_body(arch, n_in)[0:6]
        widen = DCN(128, 256, (5, 5), 1, 2)   # 5x5, stride 1, padding 2: keeps spatial size
        shrink = DCN(256, 256, (3, 3), 2, 1)  # 3x3, stride 2, padding 1: downsamples
        self.head = nn.Sequential(widen, shrink)

    def forward(self, x):
        features = self.body(x)
        return self.head(features)

In [None]:
# Build the demo model with default arguments and move it to the current CUDA device.
m = Net().to('cuda')

In [None]:
# Smoke test on a random batch: the backbone downsamples 128x128 inputs, and the
# stride-2 DCN halves the resolution again, giving 256 channels at 8x8.
out = m(torch.rand(2, 1, 128, 128).cuda())
assert out.shape == (2, 256, 8, 8), out.shape
out.shape

torch.Size([2, 256, 8, 8])

# Export -

In [None]:
# hide
# nbdev export step: converts the project's notebooks (see the `Converted ...` output),
# presumably writing their `#export` cells into the `moving_mnist` package.
from nbdev.export import *
notebook2script()

Converted 00_data.ipynb.
Converted 01_models.conv_rnn.ipynb.
Converted 02_models.dcn.ipynb.
Converted 02_models.transformer.ipynb.
Converted index.ipynb.
