In [1]:
%config Completer.use_jedi = False

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
class CCM(nn.Module):
    """Channel gating block that also emits an N-way projection of the gates.

    The input is averaged over its spatial dimensions, fed through one strided
    ``Conv1d`` per entry of ``ks``, and the branch outputs are merged by a
    linear layer and squashed with a sigmoid, giving one gate per channel.
    The gates rescale the input, are projected to ``N`` values by ``fc2``, and
    the outer product of gates and projection is returned alongside.

    Args:
        in_chans: number of channels ``C`` of the input feature map.
        N: output width of ``fc2``; also the last dimension of the returned ``M``.
        ks: kernel sizes; one ``Conv1d(in_chans, in_chans, k, stride=k,
            padding=k-1)`` branch is created per entry.
    """

    def __init__(self, in_chans, N=6, ks=(9, 17, 33, 65, 129)):
        super(CCM, self).__init__()

        self.avgpool = nn.AdaptiveAvgPool2d(1)

        # nn.ModuleList (not a plain list) registers every branch as a
        # submodule, so its weights appear in parameters()/state_dict() and
        # follow .to()/.cuda(). With a plain list the branches stay at their
        # random initialisation and are silently left out of checkpoints.
        self.mce = nn.ModuleList([
            nn.Conv1d(
                in_chans,
                in_chans,
                kernel_size=k,
                stride=k,
                padding=k - 1,
            )
            for k in ks
        ])

        # fc1 merges the len(ks) branch outputs into a single value per channel.
        self.fc1 = nn.Linear(len(ks), 1, bias=True)
        # fc2 maps the C channel gates to N values.
        self.fc2 = nn.Linear(in_chans, N)

    def forward(self, x):
        """Compute the gated features and the channel/projection product.

        Args:
            x: 4-D tensor ``[B, C, H, W]`` with ``C == in_chans``.

        Returns:
            Tuple ``(M, mid_feat, P_p)``:
            ``M`` is ``[B, C, N]`` (gates times transposed projection),
            ``mid_feat`` is ``[B, C, H, W]`` (input scaled by the gates),
            ``P_p`` is ``[B, N, 1]`` (``fc2`` applied to the gates).
        """
        # Unpacking four values keeps the implicit "input must be 4-D" check.
        b, c, _, _ = x.size()
        P_x = self.avgpool(x).contiguous().view(b, c, -1)  # [B, C, 1]
        C_m = torch.sigmoid(self.mceblock(P_x))  # [B, C, 1]
        # Broadcast the per-channel gate over H and W.
        mid_feat = torch.mul(C_m.unsqueeze(-1), x)
        P_p = self.fc2(C_m.squeeze(-1)).unsqueeze(-1)  # [B, N, 1]
        # [B, C, 1] @ [B, 1, N] -> [B, C, N]
        M = torch.matmul(C_m, P_p.transpose(2, 1))

        return M, mid_feat, P_p

    def mceblock(self, p_x):
        """Run every branch in ``self.mce`` on ``p_x`` and merge them with ``fc1``.

        Args:
            p_x: pooled features of shape ``[B, C, 1]``.

        Returns:
            Tensor of shape ``[B, C, 1]`` (pre-sigmoid gate logits).
        """
        # TODO (from the original author): this design needs to be reconsidered.
        # On a length-1 input, kernel=k / stride=k / padding=k-1 yields exactly
        # one output position per branch, so the concatenation is [B, C, len(ks)].
        C_c = torch.cat([conv(p_x) for conv in self.mce], -1)
        return self.fc1(C_c)
        
if __name__ == "__main__":
    # Smoke test: push one random batch through CCM and list the output shapes.
    x = torch.randn(2, 512, 64, 64)
    module = CCM(in_chans=512, N=6)
    out = module(x)
    # Only a tuple/list result is reported; anything else prints nothing.
    if isinstance(out, (tuple, list)):
        for shape in (t.shape for t in out):
            print(shape)

torch.Size([2, 512, 6])
torch.Size([2, 512, 64, 64])
torch.Size([2, 6, 1])


In [4]:
class SCM(nn.Module):
    """Attention-style block that mixes a feature map with a ``[B, C, N]`` tensor.

    ``X_c`` is reduced to ``in_chans // r`` channels by a 1x1 convolution, and
    ``M`` is projected to the same width by two 1x1 ``Conv1d`` layers. Each
    spatial position is scored against the ``N`` columns of the first
    projection, the scores are soft-maxed over ``N``, and used to take a
    weighted sum of the second projection. A final 1x1 conv + BatchNorm + ReLU
    brings the result back to ``in_chans`` channels.

    Args:
        in_chans: channel count ``C`` shared by ``X_c`` and ``M``.
        N: not used by this module; kept so existing callers keep working.
        r: channel reduction ratio; the inner width is ``in_chans // r``.
    """

    def __init__(self, in_chans=512, N=6, r=8):
        super(SCM, self).__init__()

        # Derive every layer width from the arguments instead of hard-coding
        # 512 / 64, so non-default in_chans or r build a consistent module.
        # With the defaults this is 512 // 8 == 64, as before.
        mid_chans = in_chans // r

        self.conv2d = nn.Conv2d(in_chans, mid_chans, kernel_size=1)
        self.conv2df = nn.Sequential(
            nn.Conv2d(mid_chans, in_chans, kernel_size=1),
            nn.BatchNorm2d(in_chans),
            nn.ReLU(True)
        )
        self.conv1da = nn.Conv1d(in_chans, mid_chans, kernel_size=1)
        self.conv1db = nn.Conv1d(in_chans, mid_chans, kernel_size=1)

    def forward(self, X_c, M):
        """Apply the block.

        Args:
            X_c: feature map of shape ``[B, C, H, W]``.
            M: tensor of shape ``[B, C, N]``.

        Returns:
            Tensor of shape ``[B, C, H, W]``.
        """
        B = self.conv2d(X_c)
        b, c_s, h, w = B.shape
        B = B.contiguous().view(b, c_s, -1)  # [B, C/r, H*W]
        C = self.conv1da(M)  # [B, C/r, N]
        D = self.conv1db(M)  # [B, C/r, N]
        # Contract the reduced channel axis: here 'n' indexes the H*W positions
        # and 'm' the N columns, giving E of shape [B, H*W, N].
        E = torch.einsum('bcn, bcm -> bnm', B, C)
        E = F.softmax(E, dim=-1)  # normalise over N for every position
        # In this equation 'n' indexes N and 'm' the H*W positions:
        # weighted sum of D's columns per position -> [B, C/r, H*W].
        X_s = torch.einsum('bcn, bmn -> bcm', D, E)
        X_s = X_s.contiguous().view(b, -1, h, w)
        X_s = self.conv2df(X_s)
        return X_s

if __name__ == "__main__":
    # Smoke test: run SCM on random inputs and report the output shape(s).
    X_c = torch.randn(2, 512, 64, 64)
    M = torch.randn(2, 512, 6)
    module = SCM()
    out = module(X_c, M)
    if not isinstance(out, (tuple, list)):
        # Single tensor result.
        print(out.shape)
    else:
        for shape in (t.shape for t in out):
            print(shape)

torch.Size([2, 512, 64, 64])
