In [None]:
import torch.nn as nn

In [None]:
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with a padding of 1.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Stride forwarded to the convolution (default 1).

    Returns:
        A freshly constructed ``nn.Conv2d`` layer.
    """
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )

def conv1x1(in_planes, out_planes, stride=1):
    """Return a bias-free 1x1 ``nn.Conv2d`` (no padding).

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Stride forwarded to the convolution (default 1).

    Returns:
        A freshly constructed ``nn.Conv2d`` layer.
    """
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )

In [None]:
'''Modified from BasicBlock in iResNet.ipynb.'''

class iBasic(nn.Module):
    """Residual block built from two 3x3 convolutions with PReLU activations.

    The layers that surround the two convolutions depend on three flags:

    * ``s_block``: the input goes straight into ``conv1`` (no ``bn0`` and no
      activation first), and ``bn2`` is applied after ``conv2``.
    * ``exd_bn0``: ``bn0`` is not created; the input is only passed through
      the activation before ``conv1``. Ignored when ``s_block`` is set.
    * ``e_block``: ``bn2`` followed by the activation is applied after the
      residual addition.

    Notes:
        * One ``nn.PReLU()`` instance is reused at every activation site, so
          all sites share that module's parameters.
        * ``bn2`` is the only attribute used for both the ``s_block`` and the
          ``e_block`` normalisation; if both flags are set, the same layer is
          applied before and after the residual addition.

    Args:
        inplanes: Number of channels of the input.
        outplanes: Number of channels produced by both convolutions.
        stride: Stride of the first convolution; also stored on the instance.
        downsample: Optional callable applied to the input to produce the
            shortcut branch. When ``None`` the input itself is the shortcut.
        nm_layer: Normalisation layer factory called with a channel count.
            Defaults to ``nn.BatchNorm2d``.
        s_block: See above.
        e_block: See above.
        exd_bn0: See above.
    """

    # Not read inside this class; presumably the channel expansion factor
    # expected by the network builder -- confirm against the caller.
    exp_block = 1

    def __init__(self, inplanes, outplanes, stride=1, downsample=None, nm_layer=None, s_block=False, e_block=False, exd_bn0=False):
        super().__init__()
        norm = nn.BatchNorm2d if nm_layer is None else nm_layer

        # Plain (non-module) configuration.
        self.stride = stride
        self.s_block = s_block
        self.e_block = e_block
        self.exd_bn0 = exd_bn0

        # Sub-modules are created in a fixed order so that registration
        # order (and therefore parameter / state_dict order) stays stable.
        if not (s_block or exd_bn0):
            self.bn0 = norm(inplanes)

        self.conv1 = conv3x3(inplanes, outplanes, stride)
        self.bn1 = norm(outplanes)
        self.prelu = nn.PReLU()
        self.conv2 = conv3x3(outplanes, outplanes)

        # Both flags store their normalisation layer under the same name.
        if s_block:
            self.bn2 = norm(outplanes)
        if e_block:
            self.bn2 = norm(outplanes)

        self.downsample = downsample

    def forward(self, x):
        """Run the block on ``x`` and return the result of the residual sum.

        The shortcut is ``x`` itself, or ``self.downsample(x)`` when a
        downsample callable was supplied.
        """
        out = x

        # Pre-activation stage, skipped entirely for a start block.
        if not self.s_block:
            if not self.exd_bn0:
                out = self.bn0(out)
            out = self.prelu(out)

        # conv1 -> bn1 -> prelu -> conv2
        out = self.conv2(self.prelu(self.bn1(self.conv1(out))))

        if self.s_block:
            out = self.bn2(out)

        shortcut = x if self.downsample is None else self.downsample(x)
        out = out + shortcut

        if self.e_block:
            out = self.prelu(self.bn2(out))

        return out
