In [141]:
import torch
import torch.nn as nn
import torch.nn.functional as F
class RE_BN_DPCONV(nn.Module):
    """Residual block with softmax-weighted pyramidal (multi-kernel) convolutions.

    Pipeline:
      1. A 1x1 conv projects the input from ``in_ch`` to ``out_ch`` channels.
      2. GAP -> flatten -> ReLU -> Linear -> softmax turns the projected map
         into four per-sample weights (one per pyramid branch) summing to 1.
      3. Four parallel 3x3 / 5x5 / 7x7 / 9x9 convs, all dilated by ``dirate``,
         run on the projected map. Each branch output is scaled by its weight.
      4. The scaled branches are concatenated along channels and fused back to
         ``out_ch`` channels with a final 1x1 conv.
      5. The projected map is added as a residual, then BatchNorm + ReLU.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels (also the width of every branch).
        dirate: dilation rate shared by all pyramid convs. Padding is scaled
            with it so every branch produces the same spatial size.
        stride: stride of the pyramid convs. When > 1 the residual is
            subsampled with the same stride so the shapes line up.

    Input:  (N, in_ch, H, W)
    Output: (N, out_ch, ceil(H / stride), ceil(W / stride))
    """

    def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
        super(RE_BN_DPCONV, self).__init__()
        # kernel 1x1 input projection
        self.conv1x1 = nn.Conv2d(in_ch, out_ch, 1)
        # global average pooling
        self.GAP = nn.AdaptiveAvgPool2d((1, 1))
        # MLP producing one weight per pyramid branch (4 branches)
        self.flatten = nn.Flatten()
        self.relu = nn.ReLU()
        self.linear = nn.Linear(out_ch, 4)
        self.softmax = nn.Softmax(dim=1)

        # Pyramidal convolution: padding = (k - 1) / 2 * dirate keeps every
        # branch at the same output size for a given stride.
        self.conv3x3 = nn.Conv2d(out_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride)
        self.conv5x5 = nn.Conv2d(out_ch, out_ch, 5, padding=2 * dirate, dilation=1 * dirate, stride=stride)
        self.conv7x7 = nn.Conv2d(out_ch, out_ch, 7, padding=3 * dirate, dilation=1 * dirate, stride=stride)
        self.conv9x9 = nn.Conv2d(out_ch, out_ch, 9, padding=4 * dirate, dilation=1 * dirate, stride=stride)

        # kernel 1x1 fusing the 4 concatenated branches back to out_ch
        self.conv1x1_final = nn.Conv2d(4 * out_ch, out_ch, 1)

        # BatchNorm + ReLU applied after the residual sum
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv1x1(x)

        # GAP + MLP -> (N, 4) softmax weights, one per pyramid branch
        weights = self.GAP(x)
        weights = self.flatten(weights)
        weights = self.relu(weights)
        weights = self.linear(weights)
        weights = self.softmax(weights)

        # Pyramidal convolution
        branches = (self.conv3x3(x), self.conv5x5(x), self.conv7x7(x), self.conv9x9(x))

        # Branch-wise weighting: scale each branch by its per-sample weight.
        # Done out of place with broadcasting. A per-sample in-place assignment
        # (out[i] = out[i] * w[i]) would overwrite tensors autograd saved for
        # the multiplication, making backward() fail.
        scaled = [
            branch * weights[:, k].view(-1, 1, 1, 1)
            for k, branch in enumerate(branches)
        ]

        # concatenation + conv1x1_final
        out_fuse = torch.cat(scaled, 1)
        y = self.conv1x1_final(out_fuse)

        # Residual connection. When the pyramid convs are strided, subsample
        # the shortcut with the same stride so shapes match. With the padding
        # above, both give ceil(H / stride) rows.
        stride_h, stride_w = self.conv3x3.stride
        shortcut = x if (stride_h, stride_w) == (1, 1) else x[:, :, ::stride_h, ::stride_w]

        return self.relu_s1(self.bn_s1(y + shortcut))

In [142]:
# Smoke test: a 30-channel input is projected to 15 channels at unchanged
# resolution (stride=1). Seed first so the random input and weight init are
# reproducible under Restart & Run All.
torch.manual_seed(0)
X = torch.randn((10, 30, 320, 320))
model = RE_BN_DPCONV(in_ch=30, out_ch=15, dirate=1, stride=1)
a = model(X)
assert a.shape == (10, 15, 320, 320), a.shape
a.shape

torch.Size([10, 15, 320, 320])