In [None]:
import torch.nn as nn
import torch

In [None]:
def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 2-D convolution with one pixel of padding.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Convolution stride (defaults to 1).

    Returns:
        A newly constructed ``nn.Conv2d`` module.
    """
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )

def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 1x1 (pointwise) 2-D convolution.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Convolution stride (defaults to 1).

    Returns:
        A newly constructed ``nn.Conv2d`` module.
    """
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )

class SE_block(nn.Module):
    """Squeeze-and-excitation style channel gate.

    The input is globally average-pooled to 1x1, passed through a
    two-layer 1x1-conv bottleneck (ReLU in between, sigmoid at the end) and
    the resulting per-channel weights rescale the original input.

    Args:
        inplanes: Number of channels of the input feature map.
        reduction: Channel reduction ratio of the bottleneck (default 16,
            the value that used to be hard-coded).

    Raises:
        ValueError: If ``reduction`` is smaller than 1.
    """

    def __init__(self, inplanes, reduction=16):
        super(SE_block, self).__init__()
        if reduction < 1:
            raise ValueError("reduction must be >= 1, got {}".format(reduction))

        # Keep at least one bottleneck channel: a plain floor division gives
        # 0 channels whenever inplanes < reduction.
        squeezed = max(1, inplanes // reduction)

        self.se_conv1 = conv1x1(inplanes, squeezed)
        self.se_conv2 = conv1x1(squeezed, inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        """Return ``x`` scaled channel-wise by the learned gate.

        Args:
            x: Feature map whose channel dimension (dim 1) has ``inplanes``
                entries.

        Returns:
            Tensor with the same shape as ``x``.
        """
        gate = self.avgpool(x)          # squeeze: spatial dims -> 1x1
        gate = self.se_conv1(gate)      # reduce channels
        gate = self.relu(gate)
        gate = self.se_conv2(gate)      # restore channels
        gate = self.sigmoid(gate)       # per-channel weights in (0, 1)
        return x * gate                 # broadcast over the spatial dims

In [None]:
'''
Copied from: https://github.com/leaderj1001/BottleneckTransformers
author: Myeongjun Kim
'''

class MHSA(nn.Module):
    """Multi-head self-attention over a 2-D feature map.

    Queries, keys and values come from 1x1 convolutions. The attention
    logits are the sum of a content-content term (``q^T k``) and a
    content-position term built from two learned tensors, ``rel_h`` and
    ``rel_w``, that are broadcast-added into one positional map.

    Args:
        n_dims: Number of input (and output) channels.
        width: Size of dim 2 of the expected input feature map.
        height: Size of dim 3 of the expected input feature map.
        heads: Number of attention heads; channels are split evenly
            across heads with ``n_dims // heads`` channels each.
    """

    def __init__(self, n_dims, width=4, height=4, heads=4):
        super().__init__()
        self.heads = heads

        # Per-position projections (1x1 convs, bias enabled by default).
        self.query = nn.Conv2d(n_dims, n_dims, kernel_size=1)
        self.key = nn.Conv2d(n_dims, n_dims, kernel_size=1)
        self.value = nn.Conv2d(n_dims, n_dims, kernel_size=1)

        # Learned positional terms, one varying along each spatial axis;
        # their broadcast sum covers the full (width, height) grid.
        head_dim = n_dims // heads
        self.rel_h = nn.Parameter(torch.randn(1, heads, head_dim, 1, height), requires_grad=True)
        self.rel_w = nn.Parameter(torch.randn(1, heads, head_dim, width, 1), requires_grad=True)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Apply self-attention across all spatial positions of ``x``.

        Args:
            x: Tensor of shape ``(batch, n_dims, width, height)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        batch, channels, width, height = x.size()
        head_dim = channels // self.heads

        # Flatten the spatial grid: (batch, heads, head_dim, width * height).
        q = self.query(x).view(batch, self.heads, head_dim, -1)
        k = self.key(x).view(batch, self.heads, head_dim, -1)
        v = self.value(x).view(batch, self.heads, head_dim, -1)

        # Content-content logits: (batch, heads, positions, positions).
        logits_content = q.transpose(-2, -1) @ k

        # Content-position logits from the combined positional map.
        position = (self.rel_h + self.rel_w).view(1, self.heads, head_dim, -1)
        logits_position = position.transpose(-2, -1) @ q

        weights = self.softmax(logits_content + logits_position)

        # Weighted sum of values, then restore the spatial layout.
        mixed = v @ weights.transpose(-2, -1)
        return mixed.view(batch, channels, width, height)

In [None]:
'''modified from BasicBlock in iResNet.ipynb'''


class IBT(nn.Module):
    """Residual block with an MHSA stage followed by a depthwise-conv/SE stage.

    Adapted from the iResNet ``BasicBlock``. The forward pass is:

    1. optional ``bn0``/ReLU pre-activation, then ``conv1`` (3x3, strided);
    2. ``bn1`` -> ReLU -> ``mhsa``, added to ``conv1(x)`` as a first skip;
    3. ``bn1`` -> PReLU -> depthwise 3x3 conv (+ norm) -> SE -> ``conv2`` (1x1);
    4. addition of the second skip (the stage-2 result, or
       ``downsample(x)`` when a ``downsample`` module is given);
    5. optional ``bn2`` (+ PReLU for end blocks).

    Args:
        inplanes: Number of input channels.
        outplanes: Number of output channels.
        stride: Stride of ``conv1``.
        downsample: Optional module applied to the block input to produce
            the second skip connection.
        nm_layer: Normalization layer factory called as
            ``nm_layer(num_channels)``; defaults to ``nn.BatchNorm2d``.
        s_block: If True, skip the pre-activation and apply ``bn2`` after
            ``conv2``.
        e_block: If True, apply ``bn2`` and PReLU after the final addition.
        exd_bn0: If True, the pre-activation is ReLU only (no ``bn0``).
        width: Size of dim 2 of the feature map entering ``mhsa`` (i.e. of
            the ``conv1`` output). Defaults to 4.
        height: Size of dim 3 of the feature map entering ``mhsa``.
            Defaults to 4.
        heads: Number of attention heads used by ``mhsa``. Defaults to 4.
    """

    # Class-level constant; it is not read anywhere inside this class.
    exp_block = 1

    def __init__(self, inplanes, outplanes, stride=1, downsample=None, nm_layer=None,
                 s_block=False, e_block=False, exd_bn0=False, width=4, height=4, heads=4):
        super(IBT, self).__init__()

        if nm_layer is None:
            nm_layer = nn.BatchNorm2d
        if not s_block and not exd_bn0:
            # bn0 normalizes the raw block input, so it must be sized by
            # inplanes (sizing it by outplanes breaks whenever the two differ).
            self.bn0 = nm_layer(inplanes)

        self.conv1 = conv3x3(inplanes, outplanes, stride)
        self.bn1 = nm_layer(outplanes)
        self.prelu = nn.PReLU()
        self.mhsa = MHSA(outplanes, width=width, height=height, heads=heads)
        self.relu = nn.ReLU()
        self.conv2 = conv1x1(outplanes, outplanes)

        # Depthwise 3x3 conv (groups == channels) followed by normalization;
        # the normalization honors nm_layer like every other norm in the block.
        self.dw_conv3x3 = nn.Sequential(
            nn.Conv2d(outplanes, outplanes, 3, padding=1, groups=outplanes, bias=False),
            nm_layer(outplanes))
        self.se = SE_block(outplanes)

        # Start and end blocks both use a single bn2 layer.
        if s_block or e_block:
            self.bn2 = nm_layer(outplanes)

        self.downsample = downsample
        self.stride = stride

        self.s_block = s_block
        self.e_block = e_block
        self.exd_bn0 = exd_bn0

    def forward(self, x):
        """Run the block on ``x`` and return the transformed feature map.

        Args:
            x: Input tensor with ``inplanes`` channels in dim 1.

        Returns:
            Tensor with ``outplanes`` channels in dim 1.
        """
        # Pre-activation variant depends on the block's position flags.
        if self.s_block:
            pre = x
        elif self.exd_bn0:
            pre = self.relu(x)
        else:
            pre = self.relu(self.bn0(x))
        out = self.conv1(pre)

        out = self.bn1(out)
        out = self.relu(out)
        out = self.mhsa(out)

        # First skip: the raw input projected by the same conv1 weights so
        # that channels and spatial size match the attention output.
        # NOTE(review): conv1 is shared between the main path and this skip;
        # confirm the weight sharing is intended.
        out = out + self.conv1(x)

        # Default second skip is the post-attention tensor.
        identity = out

        # NOTE(review): bn1 is applied a second time here, so both uses share
        # one set of affine parameters and running statistics; confirm intended.
        out = self.bn1(out)
        out = self.prelu(out)
        out = self.dw_conv3x3(out)

        out = self.se(out)

        out = self.conv2(out)

        if self.s_block:
            out = self.bn2(out)

        # When provided, downsample(x) replaces the post-attention skip.
        if self.downsample is not None:
            identity = self.downsample(x)

        out = out + identity

        if self.e_block:
            out = self.bn2(out)
            out = self.prelu(out)

        return out
