In [1]:
import torch
import torch.nn as nn

<img src="./resnet.png">

In [None]:
class ConvBlock(nn.Module):
    """Convolution followed by batch normalisation and a ReLU.

    Args:
        inchannels: number of channels of the incoming feature map.
        outchannels: number of channels produced by the convolution.
        kernel_size: kernel size forwarded to ``nn.Conv2d``.
        stride: stride forwarded to ``nn.Conv2d``.
        padding: padding forwarded to ``nn.Conv2d``.
    """

    def __init__(self, inchannels, outchannels, kernel_size, stride, padding) -> None:
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=inchannels,
            out_channels=outchannels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm2d(num_features=outchannels)

    def forward(self, x):
        """Return ``relu(bn(conv(x)))``."""
        convolved = self.conv(x)
        normalised = self.bn(convolved)
        return self.relu(normalised)

<img src="./residual_connection.png" height=300>

There are two types of residual connection:
1. If the block's input and output have the same number of channels, add them elementwise.
2. Otherwise, first apply a 1x1 conv to the input to obtain the same number of channels, then add elementwise.

In each ResLayer (three convs form a block; the block repeated `repeat_time` times forms a layer):
1. In the first block, the 3x3 conv does the downsampling.
2. In the first block, the shortcut needs a 1x1 conv to obtain the same number of channels; mind that the spatial size of the image also differs, so the shortcut must downsample as well.
3. In the later blocks, the input can be added elementwise directly.

TIPS:
In Conv2_X:
the 3x3 conv does not need to downsample.


In [2]:
class ResBlock(nn.Module):
    """One ResNet stage (Conv2_X ... Conv5_X) made of ``repeat_time`` bottleneck blocks.

    Every bottleneck block is conv1x1 -> conv3x3 -> conv1x1, each step being a
    ``ConvBlock``, and its result is added elementwise to a shortcut branch:

    * First block: the channel count changes from ``config[0]`` to
      ``config[3]``, so the shortcut is a 1x1 ``ConvBlock`` projection.
      Outside Conv2_X the 3x3 conv and the projection both use stride 2, which
      halves the spatial size.
    * Later blocks: input and output already have the same shape, so the input
      itself is added.

    Each of the ``repeat_time`` blocks owns its own layers; no weights are
    shared between blocks.

    Args:
        config: ``[in_channel, conv1_channel, conv2_channel, conv3_channel,
            repeat_time]``.
        conv2_flag: ``True`` for Conv2_X (first block keeps the spatial size),
            ``False`` for stages whose first block downsamples by 2.
    """

    def __init__(self, config, conv2_flag) -> None:
        super().__init__()
        self.conv2_flag = conv2_flag
        self.repeat_time = config[4]
        self.config = config

        in_channels = config[0]
        conv1_channels = config[1]
        conv2_channels = config[2]
        out_channels = config[3]

        # Only the first block of Conv3/4/5_X downsamples; Conv2_X never does.
        first_stride = 1 if conv2_flag else 2

        # ---- first block: changes the channel count (and maybe the size) ----
        self.first_conv1 = ConvBlock(in_channels, conv1_channels, (1, 1), 1, (0, 0))
        self.first_conv2 = ConvBlock(
            conv1_channels, conv2_channels, (3, 3), first_stride, (1, 1)
        )
        self.first_conv3 = ConvBlock(conv2_channels, out_channels, (1, 1), 1, (0, 0))

        # Projection shortcut. It is created here rather than inside forward()
        # so that its parameters are registered on the module: they keep their
        # values between calls, are seen by optimizers and state_dict(), and
        # follow .to(device).
        # A 1x1 kernel with the same stride as the 3x3/pad-1 conv yields
        # floor((H - 1) / stride) + 1 rows for both branches, so the two
        # tensors can always be added, for odd input sizes as well.
        self.shortcut = ConvBlock(
            in_channels, out_channels, (1, 1), first_stride, (0, 0)
        )

        # ---- later blocks: identity shortcut, one set of weights per block ----
        self.later_blocks = nn.ModuleList(
            [
                self._make_later_block(out_channels, conv1_channels, conv2_channels)
                for _ in range(self.repeat_time - 1)
            ]
        )

    @staticmethod
    def _make_later_block(channels, conv1_channels, conv2_channels):
        """Build one shape-preserving bottleneck (1x1 -> 3x3 -> 1x1, all stride 1)."""
        return nn.Sequential(
            ConvBlock(channels, conv1_channels, (1, 1), 1, (0, 0)),
            ConvBlock(conv1_channels, conv2_channels, (3, 3), 1, (1, 1)),
            ConvBlock(conv2_channels, channels, (1, 1), 1, (0, 0)),
        )

    def forward(self, x):
        """Run the first (projecting) block, then the identity-shortcut blocks."""
        # First block: main path plus projected input.
        main = self.first_conv3(self.first_conv2(self.first_conv1(x)))
        x = main + self.shortcut(x)

        # Remaining blocks: main path plus the unchanged input.
        for block in self.later_blocks:
            x = block(x) + x
        return x

In [3]:
class ResNet(nn.Module):
    """ResNet-50 style image classifier.

    Layout: 7x7 stride-2 conv stem (conv -> BN -> ReLU -> max-pool), four
    ``ResBlock`` stages configured by ``self.config``, global average pooling,
    then dropout and a linear classifier.

    Args:
        inchannel: number of channels of the input images.
        num_classes: number of output logits per image.
    """

    def __init__(self, inchannel, num_classes) -> None:
        super().__init__()
        self.config = {
        # in_channel conv1channel conv2channel conv3channel repeat_time
        # below: arch of res50
            'conv2_x': [64, 64, 64, 256, 3],
            'conv3_x': [256, 128, 128, 512, 4],
            'conv4_x': [512, 256, 256, 1024, 6],
            'conv5_x': [1024, 512, 512, 2048, 3]
                }
        # Stem.
        self.conv1 = nn.Conv2d(inchannel, 64, (7,7), 2, (3,3))
        self.bn = nn.BatchNorm2d(64)
        # Activation between BN and max-pool; parameter-free, so state_dict
        # keys are unaffected.
        self.relu = nn.ReLU()
        self.max_pool = nn.MaxPool2d((3,3), 2, (1,1))

        # Residual stages; only Conv2_X keeps the spatial size in its first block.
        self.Conv2_X = ResBlock(self.config["conv2_x"], True)
        self.Conv3_X = ResBlock(self.config["conv3_x"], False)
        self.Conv4_X = ResBlock(self.config["conv4_x"], False)
        self.Conv5_X = ResBlock(self.config["conv5_x"], False)

        # Global average pooling to 1x1. For a 7x7 feature map this equals a
        # 7x7 average pool, but it does not require the map to be exactly 7x7.
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))

        # Classifier width follows the last stage's output channels.
        feature_dim = self.config["conv5_x"][3]
        self.fcs = nn.Sequential(
            nn.Dropout(0.4),
            nn.Linear(feature_dim, num_classes),
        )

    def forward(self, x):
        """Map a batch of images to ``(N, num_classes)`` logits.

        The shape comments below are for a 224x224 input.
        """
        # (N, img_channel, 224, 224)
        x = self.max_pool(self.relu(self.bn(self.conv1(x))))
        # (N, 64, 56, 56)
        x = self.Conv2_X(x)
        # (N, 256, 56, 56)
        x = self.Conv3_X(x)
        # (N, 512, 28, 28)
        x = self.Conv4_X(x)
        # (N, 1024, 14, 14)
        x = self.Conv5_X(x)
        # (N, 2048, 7, 7)
        x = self.avg_pool(x)
        # (N, 2048, 1, 1)
        x = x.reshape((x.shape[0],-1))
        # (N, 2048)
        x = self.fcs(x)
        # (N, num_class)
        return x
        

In [7]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [4]:
# Smoke test: a random image batch must come out as (batch, num_classes) logits.
net = ResNet(3, 1000)
x = torch.rand((16, 3, 224, 224))
# Only the output shape is inspected, so skip building the autograd graph.
with torch.no_grad():
    print(net(x).shape)

torch.Size([16, 1000])
