In [None]:
import torch.nn as nn
import torch

In [None]:
def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with padding 1.

    With ``stride=1`` the spatial size is preserved; a larger stride
    downsamples it.
    """
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )

def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free pointwise (1x1) ``nn.Conv2d`` without padding."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )

class FReLU(nn.Module):
    """Funnel activation: ``max(x, T(x))`` element-wise.

    ``T`` is a depthwise 3x3 convolution (``groups=dim_in``, padding 1,
    no bias) followed by ``BatchNorm2d``. It keeps the channel count, and
    with stride 1 it keeps the spatial size, so ``T(x)`` has the same
    shape as ``x``.
    """

    def __init__(self, dim_in):
        super(FReLU, self).__init__()
        depthwise = nn.Conv2d(
            in_channels=dim_in,
            out_channels=dim_in,
            kernel_size=3,
            padding=1,
            groups=dim_in,
            bias=False,
        )
        self.dw_conv3x3 = nn.Sequential(depthwise, nn.BatchNorm2d(dim_in))

    def forward(self, input):
        # Take the per-element maximum of the input and its spatial condition.
        return torch.max(input, self.dw_conv3x3(input))

In [None]:
class SE_block(nn.Module):
    """Squeeze-and-Excitation channel gate.

    The input is global-average-pooled to one value per channel. That
    vector passes through a 1x1-conv bottleneck
    (``inplanes -> hidden -> inplanes``) with a ReLU in between and then
    a sigmoid. The result rescales the input channel-wise.

    Args:
        inplanes: number of channels of the input (and of the output).
        reduction: bottleneck reduction ratio; ``hidden`` is
            ``inplanes // reduction``. Defaults to 16, the value that was
            previously hard-coded, so existing callers are unaffected.

    Raises:
        ValueError: if ``reduction < 1``.
    """

    def __init__(self, inplanes, reduction=16):
        super(SE_block, self).__init__()
        if reduction < 1:
            raise ValueError(f"reduction must be >= 1, got {reduction}")
        # Clamp to >= 1. Otherwise ``inplanes < reduction`` gives a
        # zero-width bottleneck that cannot carry any channel information.
        hidden = max(1, inplanes // reduction)
        self.se_conv1 = conv1x1(inplanes, hidden)
        self.se_conv2 = conv1x1(hidden, inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        # Squeeze: one descriptor per channel, spatial dims reduced to 1x1.
        out = self.avgpool(x)
        # Excitation: bottleneck MLP (as 1x1 convs) producing gates in (0, 1).
        out = self.se_conv1(out)
        out = self.relu(out)
        out = self.se_conv2(out)
        out = self.sigmoid(out)
        # The 1x1 gate map broadcasts over the spatial dims of ``x``.
        return x * out

class SE_iBasic_F(nn.Module):
    """Residual basic block with FReLU activations and an SE gate.

    The flags select one of three layouts for the residual branch. If
    ``s_block`` is set, it takes precedence over ``exd_bn0``.

    * ``s_block=True``: conv1 -> bn1 -> frelu2 -> conv2 -> bn2.
    * ``exd_bn0=True``: frelu1 -> conv1 -> bn1 -> frelu2 -> conv2
      (no ``bn0`` is created or applied).
    * otherwise: bn0 -> frelu1 -> conv1 -> bn1 -> frelu2 -> conv2.

    In every layout the branch output is rescaled by ``SE_block`` and
    added to the identity, which is ``x`` or ``downsample(x)``. If
    ``e_block=True``, the sum is then passed through ``bn2`` and
    ``frelu2``.

    Args:
        inplanes: channels of the input ``x``.
        outplanes: channels produced by ``conv1`` and ``conv2``.
        stride: stride of ``conv1``.
        downsample: optional module applied to ``x`` to build the identity.
            It must make the identity's shape match the residual branch,
            otherwise the addition fails.
        nm_layer: normalization-layer factory called with a channel count.
            Defaults to ``nn.BatchNorm2d``.
        s_block, e_block, exd_bn0: layout flags described above.
    """

    # Channel expansion factor of this block type; presumably read by a
    # network builder outside this view -- confirm.
    exp_block = 1

    def __init__(self, inplanes, outplanes, stride=1, downsample=None, nm_layer=None, s_block=False, e_block=False, exd_bn0=False):
        super(SE_iBasic_F, self).__init__()
        if nm_layer is None:
            nm_layer = nn.BatchNorm2d
        if not s_block and not exd_bn0:
            # Input normalization is only used by the default layout.
            self.bn0 = nm_layer(inplanes)

        # Submodules are created in the original order so that parameter
        # initialization (RNG consumption) and state_dict keys are unchanged.
        self.conv1 = conv3x3(inplanes, outplanes, stride)
        self.bn1 = nm_layer(outplanes)
        # NOTE(review): frelu1 is never used when s_block=True, so those
        # parameters receive no gradient (relevant e.g. for DDP's
        # find_unused_parameters). It is kept to preserve state_dict keys.
        self.frelu1 = FReLU(inplanes)
        self.frelu2 = FReLU(outplanes)
        self.conv2 = conv3x3(outplanes, outplanes)
        self.se = SE_block(outplanes)

        # The start block normalizes the branch output; the end block
        # normalizes the sum. Both use the same attribute name.
        if s_block or e_block:
            self.bn2 = nm_layer(outplanes)

        self.downsample = downsample
        self.stride = stride

        self.s_block = s_block
        self.e_block = e_block
        self.exd_bn0 = exd_bn0

    def forward(self, x):
        identity = x

        if self.s_block:
            out = self.conv1(x)
        elif self.exd_bn0:
            out = self.frelu1(x)
            out = self.conv1(out)
        else:
            out = self.bn0(x)
            # The pre-activation acts on ``inplanes`` channels, so it must be
            # frelu1 (FReLU(inplanes)). The previous use of frelu2
            # (FReLU(outplanes)) crashed whenever inplanes != outplanes and
            # left frelu1 unused.
            out = self.frelu1(out)
            out = self.conv1(out)

        out = self.bn1(out)
        out = self.frelu2(out)

        out = self.conv2(out)

        if self.s_block:
            out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.se(out)
        out = out + identity

        if self.e_block:
            out = self.bn2(out)
            # NOTE(review): this reuses frelu2 (also applied after bn1), so
            # both activations share parameters. Giving them separate
            # modules would add new state_dict keys, so it is left as is.
            out = self.frelu2(out)

        return out
