In [1]:
!pip install einops



In [2]:
import torch 
import torchvision
import torchvision.models as models
from torchvision import datasets
from torchvision import transforms
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from einops import rearrange, reduce, repeat

from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import os

In [3]:
# Example of how to use the 1D Convolution 

import torch
from torch import nn
from torch.nn.parameter import Parameter

class eca_layer(nn.Module):
    """Constructs a ECA module.
    Args:
        channel: Number of channels of the input feature map
        k_size: Adaptive selection of kernel size
    """
    def __init__(self, channel, k_size=3):
        super(eca_layer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False) 
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # x: input features with shape [b, c, h, w]
        b, c, h, w = x.size()
        # feature descriptor on the global spatial information
        y = self.avg_pool(x)
        # Two different branches of ECA module
        y = rearrange(y, 'b c h w -> b c (h w)')
        y = rearrange(y, 'b c n -> b n c')

        y = self.conv(y)
        y = rearrange(y, 'b n c -> b c n')
        y = rearrange(y, 'b c (h w) -> b c h w', h=1, w=1)
        
        # Multi-scale information fusion
        y = self.sigmoid(y)
        return x * y.expand_as(x)

In [4]:
class DeformConv2D(nn.Module):
    def __init__(self, inc, outc, kernel_size=3, padding=1, bias=None):
        super(DeformConv2D, self).__init__()
        self.kernel_size = kernel_size
        self.padding = padding
        self.zero_padding = nn.ZeroPad2d(padding)
        self.conv_kernel = nn.Conv2d(inc, outc, kernel_size=kernel_size, stride=kernel_size, bias=bias)

    def forward(self, x, offset):
        dtype = offset.data.type()
        ks = self.kernel_size
        N = offset.size(1) // 2
    
        print('offset: ', offset.shape, x.shape)
        
        # Change offset's order from [x1, x2, ..., y1, y2, ...] to [x1, y1, x2, y2, ...]
        # Codes below are written to make sure same results of MXNet implementation.
        # You can remove them, and it won't influence the module's performance.
        offsets_index = Variable(torch.cat([torch.arange(0, 2*N, 2), torch.arange(1, 2*N+1, 2)]), requires_grad=False).type_as(x).long()
        offsets_index = offsets_index.unsqueeze(dim=0).unsqueeze(dim=-1).unsqueeze(dim=-1).expand(*offset.size())
        offset = torch.gather(offset, dim=1, index=offsets_index)
        # ------------------------------------------------------------------------

        if self.padding:
            x = self.zero_padding(x)

        # (b, 2N, h, w)
        p = self._get_p(offset, dtype)

        # (b, h, w, 2N)
        p = p.contiguous().permute(0, 2, 3, 1)
        q_lt = Variable(p.data, requires_grad=False).floor()
        q_rb = q_lt + 1

        print(q_lt.shape)
        print(q_rb.shape)
        
        
        q_lt = torch.cat([torch.clamp(q_lt[..., :N], 0, x.size(2)-1), torch.clamp(q_lt[..., N:], 0, x.size(3)-1)], dim=-1).long()
        q_rb = torch.cat([torch.clamp(q_rb[..., :N], 0, x.size(2)-1), torch.clamp(q_rb[..., N:], 0, x.size(3)-1)], dim=-1).long()
        q_lb = torch.cat([q_lt[..., :N], q_rb[..., N:]], -1)
        q_rt = torch.cat([q_rb[..., :N], q_lt[..., N:]], -1)
        
        print(q_lt.shape, q_rb.shape, q_lb.shape, q_rt.shape)

        # (b, h, w, N)
        mask = torch.cat([p[..., :N].lt(self.padding)+p[..., :N].gt(x.size(2)-1-self.padding),
                          p[..., N:].lt(self.padding)+p[..., N:].gt(x.size(3)-1-self.padding)], dim=-1).type_as(p)
        mask = mask.detach()
        floor_p = p - (p - torch.floor(p))
        p = p*(1-mask) + floor_p*mask
        p = torch.cat([torch.clamp(p[..., :N], 0, x.size(2)-1), torch.clamp(p[..., N:], 0, x.size(3)-1)], dim=-1)

        # bilinear kernel (b, h, w, N)
        g_lt = (1 + (q_lt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_lt[..., N:].type_as(p) - p[..., N:]))
        g_rb = (1 - (q_rb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_rb[..., N:].type_as(p) - p[..., N:]))
        g_lb = (1 + (q_lb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_lb[..., N:].type_as(p) - p[..., N:]))
        g_rt = (1 - (q_rt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_rt[..., N:].type_as(p) - p[..., N:]))

        # (b, c, h, w, N)
        x_q_lt = self._get_x_q(x, q_lt, N)
        x_q_rb = self._get_x_q(x, q_rb, N)
        x_q_lb = self._get_x_q(x, q_lb, N)
        x_q_rt = self._get_x_q(x, q_rt, N)

        # (b, c, h, w, N)
        x_offset = g_lt.unsqueeze(dim=1) * x_q_lt + \
                   g_rb.unsqueeze(dim=1) * x_q_rb + \
                   g_lb.unsqueeze(dim=1) * x_q_lb + \
                   g_rt.unsqueeze(dim=1) * x_q_rt

        x_offset = self._reshape_x_offset(x_offset, ks)
        out = self.conv_kernel(x_offset)

        return out

    def _get_p_n(self, N, dtype):
        p_n_x, p_n_y = np.meshgrid(range(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1),
                          range(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1), indexing='ij')
        # (2N, 1)
        p_n = np.concatenate((p_n_x.flatten(), p_n_y.flatten()))
        p_n = np.reshape(p_n, (1, 2*N, 1, 1))
        p_n = Variable(torch.from_numpy(p_n).type(dtype), requires_grad=False)

        print('p_n: ', p_n.shape)
        return p_n

    @staticmethod
    def _get_p_0(h, w, N, dtype):
        p_0_x, p_0_y = np.meshgrid(range(1, h+1), range(1, w+1), indexing='ij')
        p_0_x = p_0_x.flatten().reshape(1, 1, h, w).repeat(N, axis=1)
        p_0_y = p_0_y.flatten().reshape(1, 1, h, w).repeat(N, axis=1)
        p_0 = np.concatenate((p_0_x, p_0_y), axis=1)
        p_0 = Variable(torch.from_numpy(p_0).type(dtype), requires_grad=False)

        print('p_0: ', p_0.shape)
        return p_0

    def _get_p(self, offset, dtype):
        N, h, w = offset.size(1)//2, offset.size(2), offset.size(3)

        # (1, 2N, 1, 1)
        p_n = self._get_p_n(N, dtype)
        # (1, 2N, h, w)
        p_0 = self._get_p_0(h, w, N, dtype)
        p = p_0 + p_n + offset
        print('p: ', p.shape)
        return p

    def _get_x_q(self, x, q, N):
        b, h, w, _ = q.size()
        padded_w = x.size(3)
        c = x.size(1)
        # (b, c, h*w)
        print(x.shape)
        x = x.contiguous().view(b, c, -1)

        print('q: ', q.shape)
        
        print('x: ', x.shape)
        
        # (b, h, w, N)
        index = q[..., :N]*padded_w + q[..., N:]  # offset_x*w + offset_y
        print('index 1: ', index.shape)
        # (b, c, h*w*N)
        index = index.contiguous().unsqueeze(dim=1).expand(-1, c, -1, -1, -1).contiguous().view(b, c, -1)
        print('index 2: ', index.shape)
        
        x_offset = x.gather(dim=-1, index=index).contiguous().view(b, c, h, w, N)
        print('x_offset: ', x_offset.shape)
        
        print(q)
        print(index)
                
        return x_offset

    @staticmethod
    def _reshape_x_offset(x_offset, ks):
        b, c, h, w, N = x_offset.size()
        x_offset = torch.cat([x_offset[..., s:s+ks].contiguous().view(b, c, h, w*ks) for s in range(0, N, ks)], dim=-1)
        x_offset = x_offset.contiguous().view(b, c, h*ks, w*ks)

        return x_offset

In [5]:
class DeformNet(nn.Module):
    def __init__(self):
        super(DeformNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)

        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)

        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)

        self.offsets = nn.Conv2d(128, 18, kernel_size=3, padding=1)
        self.conv4 = DeformConv2D(128, 128, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(128)

        self.classifier = nn.Linear(128, 10)

    def forward(self, x):
        # convs
        x = F.relu(self.conv1(x))
        x = self.bn1(x)
        x = F.relu(self.conv2(x))
        x = self.bn2(x)
        x = F.relu(self.conv3(x))
        x = self.bn3(x)
        # deformable convolution
        offsets = self.offsets(x)
        x = F.relu(self.conv4(x, offsets))
        x = self.bn4(x)

        x = F.avg_pool2d(x, kernel_size=28, stride=1).view(x.size(0), -1)
        x = self.classifier(x)

        return F.log_softmax(x, dim=1)

In [6]:
model = DeformNet()

In [7]:
input_feat = torch.rand([10, 1, 28, 28]) 

out = model(input_feat)
out.shape

offset:  torch.Size([10, 18, 28, 28]) torch.Size([10, 128, 28, 28])
p_n:  torch.Size([1, 18, 1, 1])
p_0:  torch.Size([1, 18, 28, 28])
p:  torch.Size([10, 18, 28, 28])
torch.Size([10, 28, 28, 18])
torch.Size([10, 28, 28, 18])
torch.Size([10, 28, 28, 18]) torch.Size([10, 28, 28, 18]) torch.Size([10, 28, 28, 18]) torch.Size([10, 28, 28, 18])
torch.Size([10, 128, 30, 30])
q:  torch.Size([10, 28, 28, 18])
x:  torch.Size([10, 128, 900])
index 1:  torch.Size([10, 28, 28, 9])
index 2:  torch.Size([10, 128, 7056])
x_offset:  torch.Size([10, 128, 28, 28, 9])
tensor([[[[ 0,  0,  0,  ...,  0,  1,  2],
          [ 0,  0,  0,  ...,  0,  2,  2],
          [ 0,  0,  0,  ...,  2,  3,  4],
          ...,
          [ 0,  0,  0,  ..., 25, 26, 27],
          [ 0,  0,  0,  ..., 25, 27, 28],
          [ 0,  0,  0,  ..., 26, 28, 29]],

         [[ 0,  0,  0,  ...,  0,  0,  1],
          [ 0,  0,  0,  ...,  0,  1,  2],
          [ 0,  0,  0,  ...,  1,  3,  3],
          ...,
          [ 1,  0,  0,  ..., 25, 25

torch.Size([10, 10])

In [8]:
h = 1
w = 512 
N = 3 

p_0_x = np.meshgrid(range(1, w+1))[0]
p_0 = p_0_x.flatten().reshape(1, 1, w).repeat(N, axis=1)
p_0 = Variable(torch.from_numpy(p_0).type(torch.FloatTensor), requires_grad=False)

print('p_0: ', p_0.shape)


p_0:  torch.Size([1, 3, 512])


In [9]:
kernel_size = 3 

p_n_x = np.meshgrid(range(-(kernel_size-1)//2, (kernel_size-1)//2+1))[0]
# (N, 1)
p_n = p_n_x.flatten()
p_n = np.reshape(p_n, (1, N, 1))
p_n = Variable(torch.from_numpy(p_n).type(torch.FloatTensor), requires_grad=False)
p_n.shape, p_n

(torch.Size([1, 3, 1]),
 tensor([[[-1.],
          [ 0.],
          [ 1.]]]))

# D1D V1 

This version of the Deformable Convolution works on a 1x1xC Average-Pooled Tensor. 

* Apply a k-sized 1D Convolution to generate offsets
* Linear Interpolation of Offsets 
* Apply offsets 
* Apply convolution on new weights 


In [10]:
b = 10 
c = 512 

input_feat = torch.rand([b, c, 1, 1]) 
input_feat.shape

torch.Size([10, 512, 1, 1])

In [11]:
# Example of how to use the 1D Convolution 

import torch
from torch import nn
from torch.nn.parameter import Parameter

class d1d_offset(nn.Module):
    """Constructs a offsets.
    Args:
        channel: Number of channels of the input feature map
        k_size: Adaptive selection of kernel size
    """
    def __init__(self, channel, k_size=3):
        super(d1d_offset, self).__init__()
        self.offsets = nn.Conv1d(1, k_size, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False) 

    def forward(self, x):
        # x: input features with shape [b, c, h, w]
        b, c, h, w = x.size()
        y = rearrange(x, 'b c h w -> b c (h w)')
        y = rearrange(y, 'b c n -> b n c')
        offsets = self.offsets(y)        
        return offsets 

In [12]:
offset = d1d_offset(channel=c)(input_feat)
offset.shape

torch.Size([10, 3, 512])

In [13]:
class DeformConv1D(nn.Module):
    def __init__(self, inc=1, outc=1, kernel_size=3, padding=1, bias=None):
        super(DeformConv1D, self).__init__()
        
        self.kernel_size = kernel_size
        self.padding = padding
        self.zero_padding = nn.ConstantPad1d(kernel_size // 2, value=0)        
        self.conv_kernel = nn.Conv1d(inc, outc, kernel_size=kernel_size, stride=kernel_size, bias=bias)

    def forward(self, x, offset):
        # x - b, c, 1, 1  
        # offset - b k c (k = kernel_size)
                
        x = rearrange(x, 'b c h w -> b c (h w)')
        x = rearrange(x, 'b c n -> b n c')

        dtype = offset.data.type()
        ks = self.kernel_size
        N = offset.size(1) # Kernel size (only x direction)

        assert N == ks, "Offset size is wrong!"
        
        if self.padding:
            x = self.zero_padding(x)
            
        # print('x: ', x.shape)

        # (b, N,  w) - w = c 
        p = self._get_p(offset, dtype)
        
        # (b, w, N)
        p = p.contiguous().permute(0, 2, 1)
                
        # print('p: ', p.shape)
        
        q_left = Variable(p.data, requires_grad=False).floor()
        q_right = q_left + 1  
        
        q_left = torch.clamp(q_left[..., :N], 0, x.size(2)-1).long()
        q_right = torch.clamp(q_right[..., :N], 0, x.size(2)-1).long()
        
        # print('q_left: ', q_left.shape)
        # print('q_right: ', q_right.shape)

        
        # (b, h, w, N)
        mask = (p[..., :N].lt(self.padding)+p[..., :N].gt(x.size(2)-1-self.padding)).type_as(p)
        mask = mask.detach()
        floor_p = p - (p - torch.floor(p))
        p = p*(1-mask) + floor_p*mask
        p = torch.clamp(p[..., :N], 0, x.size(2)-1)

        # print('pnew: ', p.shape)
                
        # linear kernel (b, h, w, N) 
        g_left = (1 + (q_left[..., :N].type_as(p) - p[..., :N])) 
        g_right = (1 - (q_right[..., :N].type_as(p) - p[..., :N])) 

        # print('g_left: ', g_left.shape)
        # print('g_right: ', g_right.shape)
        
        # (b, c, h, w, N)
        x_q_left = self._get_x_q(x, q_left, N)
        x_q_right = self._get_x_q(x, q_right, N)

        # (b, c, h, w, N)
        x_offset = g_left.unsqueeze(dim=1) * x_q_left + \
                   g_right.unsqueeze(dim=1) * x_q_right
          
        # print('x_offset: ', x_offset.shape)
    
        x_offset = self._reshape_x_offset(x_offset, ks)
        out = self.conv_kernel(x_offset)

        return out

    def _get_p_n(self, N, dtype):
        p_n_x = np.meshgrid(range(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1))[0]
        # (N, 1)
        p_n = p_n_x.flatten()
        p_n = np.reshape(p_n, (1, N, 1))
        p_n = Variable(torch.from_numpy(p_n).type(dtype), requires_grad=False)
        return p_n

    @staticmethod
    def _get_p_0(w, N, dtype):
        p_0_x = np.meshgrid(range(1, w+1))[0]
        p_0 = p_0_x.flatten().reshape(1, 1, w).repeat(N, axis=1)
        p_0 = Variable(torch.from_numpy(p_0).type(torch.FloatTensor), requires_grad=False)
        return p_0

    def _get_p(self, offset, dtype):
        N, w = offset.size(1), offset.size(2)

        # (1, N, 1, 1)
        p_n = self._get_p_n(N, dtype)
        # (1, N, h, w)
        p_0 = self._get_p_0(w, N, dtype)
        
        # print(p_n.shape, p_0.shape)
        
        p = p_0 + p_n + offset
    
        return p

    def _get_x_q(self, x, q, N):
        b, channels, kernel_size = q.size()

        x = x.contiguous()

        # (b, h, w, N)
        index = q[..., :N] # offset_x
        
        # (b, c, h*w*N)        
        index = index.contiguous().unsqueeze(dim=1).expand(-1, 1, -1, -1).contiguous().view(b, 1, -1)
        
        # print('x: ', x.shape, ' index: ', index.shape, ' q: ', q.shape)
        # print(index)

        x_offset = x.gather(dim=-1, index=index).contiguous().view(b, 1, channels, N)
                
        return x_offset

    @staticmethod
    def _reshape_x_offset(x_offset, ks):
        b, c, w, N = x_offset.size()
        x_offset = torch.cat([x_offset[..., s:s+ks].contiguous().view(b, c, w*ks) for s in range(0, N, ks)], dim=-1)
        # print(x_offset.shape)
        x_offset = x_offset.contiguous().view(b, c, w*ks)

        return x_offset

In [14]:
conv = DeformConv1D(1, 1, kernel_size=3, padding=1)

In [30]:
sum([p.data.nelement() for p in conv.parameters()])

3

In [15]:
b = 10
c = 512 

input_feat = torch.zeros([b, c, 1, 1]) 
offset = d1d_offset(channel=c)(input_feat)
out = conv(input_feat, offset)
out.shape

torch.Size([10, 1, 512])

# Deformable Channel Attention 

In [16]:
class DeformConv1D(nn.Module):
    def __init__(self, inc=1, outc=1, kernel_size=3, padding=1, bias=None):
        super(DeformConv1D, self).__init__()
        
        self.kernel_size = kernel_size
        self.padding = padding
        self.zero_padding = nn.ConstantPad1d(kernel_size // 2, value=0)        
        self.conv_kernel = nn.Conv1d(inc, outc, kernel_size=kernel_size, stride=kernel_size, bias=bias)

    def forward(self, x, offset):
        # x - b, c, 1, 1  
        # offset - b k c (k = kernel_size)
                
        x = rearrange(x, 'b c h w -> b c (h w)')
        x = rearrange(x, 'b c n -> b n c')

        dtype = offset.data.type()
        ks = self.kernel_size
        N = offset.size(1) # Kernel size (only x direction)

        assert N == ks, "Offset size is wrong!"
        
        if self.padding:
            x = self.zero_padding(x)
            
        # print('x: ', x.shape)

        # (b, k,  w) - w = c 
        p = self._get_p(offset, dtype)
                
        # (b, w, k)
        p = p.contiguous().permute(0, 2, 1)
                
        # print('p: ', p.shape)
        
        q_left = Variable(p.data, requires_grad=False).floor()
        q_right = q_left + 1  
        
        q_left = torch.clamp(q_left[..., :N], 0, x.size(2)-1).long()
        q_right = torch.clamp(q_right[..., :N], 0, x.size(2)-1).long()
        
        # print('q_left: ', q_left.shape)
        # print('q_right: ', q_right.shape)
        
        # (b, h, w, N)
        mask = (p[..., :N].lt(self.padding)+p[..., :N].gt(x.size(2)-1-self.padding)).type_as(p)
        mask = mask.detach()
        floor_p = p - (p - torch.floor(p)) # floor_p = torch.floor(p)
        p = p*(1-mask) + floor_p*mask
        p = torch.clamp(p[..., :N], 0, x.size(2)-1)

        # print('pnew: ', p.shape)
                
        # linear kernel (b, h, w, N) 
        g_left = (1 + (q_left[..., :N].type_as(p) - p[..., :N])) # q_right - p (q_left = q_right - 1)
        g_right = (1 - (q_right[..., :N].type_as(p) - p[..., :N])) # q_left - p (q_right = q_left + 1) 

        # print('g_left: ', g_left.shape)
        # print('g_right: ', g_right.shape)
        
        # (b, c, h, w, N)
        x_q_left = self._get_x_q(x, q_left, N)
        x_q_right = self._get_x_q(x, q_right, N)
        
        # (b, c, h, w, N)
        x_offset = g_left.unsqueeze(dim=1) * x_q_left + \
                   g_right.unsqueeze(dim=1) * x_q_right
          
#         print('x_offset: ', x_offset.shape)
    
        x_offset = self._reshape_x_offset(x_offset, ks)
        
#         print('x_offset: ', x_offset.shape)
        
        out = self.conv_kernel(x_offset)

        return out

    def _get_p_n(self, N, dtype):
        p_n_x = np.meshgrid(range(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1))[0]
        # (N, 1)
        p_n = p_n_x.flatten()
        p_n = np.reshape(p_n, (1, N, 1))
        p_n = Variable(torch.from_numpy(p_n).type(dtype), requires_grad=False)
        return p_n

    @staticmethod
    def _get_p_0(w, N, dtype):
        p_0_x = np.meshgrid(range(1, w+1))[0]
        p_0 = p_0_x.flatten().reshape(1, 1, w).repeat(N, axis=1)
        p_0 = Variable(torch.from_numpy(p_0).type(torch.FloatTensor), requires_grad=False)
        return p_0

    def _get_p(self, offset, dtype):
        N, w = offset.size(1), offset.size(2)

        # (1, N, 1, 1)
        p_n = self._get_p_n(N, dtype)
        # (1, N, h, w)
        p_0 = self._get_p_0(w, N, dtype)
                
        # print(p_n.shape, p_0.shape)
        
        p = p_0 + p_n + offset
    
        return p

    def _get_x_q(self, x, q, N):
        b, channels, kernel_size = q.size()

        x = x.contiguous()

        # (b, h, w, N)
        index = q[..., :N] # offset_x
        
        # (b, c, h*w*N)        
        index = index.contiguous().unsqueeze(dim=1).expand(-1, 1, -1, -1).contiguous().view(b, 1, -1)
        
        # print('x: ', x.shape, ' index: ', index.shape, ' q: ', q.shape)
        # print(index)

        x_value_at_int_location = x.gather(dim=-1, index=index).contiguous().view(b, 1, channels, N)
                
        return x_value_at_int_location

    @staticmethod
    def _reshape_x_offset(x_offset, ks):
        b, c, w, N = x_offset.size()
        # x_offset = torch.cat([x_offset[..., s:s+ks].contiguous().view(b, c, w*ks) for s in range(0, N, ks)], dim=-1)
        # print(x_offset.shape)
               
        x_offset = x_offset.contiguous().view(b, c, w*ks)
                
        return x_offset

In [17]:
import torch
from torch import nn
from torch.nn.parameter import Parameter

class dca_offsets_layer(nn.Module):
    """Constructs a Offset Generation module.
    """
    def __init__(self):
        super(dca_offsets_layer, self).__init__()
        
        self.conv = nn.Conv2d(512, 3, kernel_size=1)
    
    def covariance_features(self, x):
        """
        Takes in a feature map and returns the unnormalized covariance matrix 
        """
        m_batchsize, C, height, width = x.size()
        proj_query = x.view(m_batchsize, C, -1)
        proj_key = x.view(m_batchsize, C, -1).permute(0, 2, 1)
        energy = torch.bmm(proj_query, proj_key)
        return energy
    
    def forward(self, x):
        m_batchsize, C, height, width = x.size()
        cov_matrix = self.covariance_features(x).reshape(m_batchsize, C, 1, C)
        offsets = self.conv(cov_matrix)
        return offsets 

In [18]:
x = torch.rand([b, 512, 1, 1])
dca_offsets_conv = dca_offsets_layer()
offsets = dca_offsets_conv(x)
offsets.shape

torch.Size([10, 3, 1, 512])

In [19]:

def cov_feature(x):
    m_batchsize, C, height, width = x.size()
    proj_query = x.view(m_batchsize, C, -1)
    proj_key = x.view(m_batchsize, C, -1).permute(0, 2, 1)
    energy = torch.bmm(proj_query, proj_key)
    return energy

out = cov_feature(x)
out.shape

torch.Size([10, 512, 512])

In [20]:
import torch
from torch import nn
from torch.nn.parameter import Parameter

class dca_offsets_layer(nn.Module):
    """Constructs a Offset Generation module.
    """
    def __init__(self, channel, n_offsets):
        super(dca_offsets_layer, self).__init__()
        
        self.channel = channel 
        self.n_offsets = n_offsets 
        
        self.conv = nn.Conv2d(channel, n_offsets, kernel_size=1)
            
    def covariance_features(self, x):
        """
        Takes in a feature map and returns the unnormalized covariance matrix 
        """
        m_batchsize, C, height, width = x.size()
        proj_query = x.view(m_batchsize, C, -1)
        proj_key = x.view(m_batchsize, C, -1).permute(0, 2, 1)
        energy = torch.bmm(proj_query, proj_key)
        return energy
    
    def forward(self, x):
        m_batchsize, C, height, width = x.size()
        cov_matrix = self.covariance_features(x).reshape(m_batchsize, C, 1, C)
        offsets = self.conv(cov_matrix).squeeze()
        return offsets 
    
x = torch.rand([b, 512, 1, 1])
dca_offsets_conv = dca_offsets_layer(512, 3)
offsets = dca_offsets_conv(x)
offsets.shape

torch.Size([10, 3, 512])

In [21]:
# Example of how to use the 1D Convolution 

import torch
from torch import nn
from torch.nn.parameter import Parameter

class dca_layer(nn.Module):
    """Constructs a Deformable ECA module.
    Args:
        channel: Number of channels of the input feature map
        k_size: Adaptive selection of kernel size
    """
    def __init__(self, channel=512, k_size=3, use_cov=False):
        super(dca_layer, self).__init__()

        self.use_cov = use_cov
        
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        
        if use_cov:
            self.conv_offset = dca_offsets_layer(channel, k_size)
        else:
            self.conv_offset = nn.Conv1d(1, k_size, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False) 
        
        self.deform_conv = DeformConv1D(1, 1, kernel_size=k_size)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # x: input features with shape [b, c, h, w]
        b, c, h, w = x.size()
        # feature descriptor on the global spatial information
        y = self.avg_pool(x)
        
        y_reshaped = rearrange(y, 'b c h w -> b c (h w)')
        y_reshaped = rearrange(y_reshaped, 'b c n -> b n c')
        
        if self.use_cov:
            offset = self.conv_offset(x)
        else:
            offset = self.conv_offset(y_reshaped)
            
#         offset = repeat(torch.tensor([0, 100, 212]), 'k -> b c k', c=512, b=b)
#         offset = rearrange(offset, 'b c k -> b k c')
                
        y = self.deform_conv(y, offset)
        y = rearrange(y, 'b n c -> b c n')
        y = rearrange(y, 'b c (h w) -> b c h w', h=1, w=1)
        
        # Multi-scale information fusion
        y = self.sigmoid(y)
        return x * y.expand_as(x)

In [39]:
b = 10
c = 512
input_feat = torch.ones([b, c, 16 ,16]) 
de_eca = dca_layer(use_cov=True)
out = de_eca(input_feat)

In [40]:
out.shape

torch.Size([10, 512, 16, 16])

In [41]:
sum([p.data.nelement() for p in de_eca.parameters() if p.requires_grad])

1542

In [42]:
eca = eca_layer(channel=512)

In [29]:
sum([p.data.nelement() for p in eca.parameters()])

3

In [33]:
k_size = 3
nn.Conv1d(1, k_size, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False).weight.shape

torch.Size([3, 1, 3])