In [1]:
import torch
import torch.nn as nn

In [2]:
class ResBlock(nn.Module):
    """One ResNet stage (``convN_x``) assembled from bottleneck sub-blocks.

    Every sub-block applies conv1x1 -> conv3x3 -> conv1x1 and adds a shortcut.
    The first sub-block changes the channel count (and, for every stage except
    ``conv2_x``, halves the spatial size), so its shortcut is a learned 1x1
    projection. The remaining ``repeat_time - 1`` sub-blocks add their input
    unchanged.

    Args:
        config: Indexable ``[in_channels, conv1_channels, conv2_channels,
            out_channels, repeat_time]``.
        conv2_flag: Truthy for the ``conv2_x`` stage, which keeps the spatial
            size; falsy for stages whose first sub-block downsamples with a
            stride-2 3x3 convolution.

    NOTE(review): all sub-blocks after the first reuse the same
    ``later_conv1`` / ``conv2`` / ``conv3`` weights (``conv3`` is also used by
    the first sub-block), and there is no normalization or activation between
    the convolutions. That looks unintended for a ResNet, but it is left as-is
    here so the existing attribute names keep their meaning -- confirm.
    """

    def __init__(self, config, conv2_flag) -> None:
        super().__init__()
        self.conv2_flag = conv2_flag
        self.repeat_time = config[4]
        self.config = config

        in_ch, squeeze_ch, mid_ch, out_ch = (
            config[0], config[1], config[2], config[3]
        )

        # 1x1 reduction for the first sub-block (its input has in_ch channels).
        self.first_conv1 = nn.Conv2d(in_ch, squeeze_ch, (1, 1), 1, (0, 0))
        # 1x1 reduction for the later sub-blocks (their input has out_ch channels).
        self.later_conv1 = nn.Conv2d(out_ch, squeeze_ch, (1, 1), 1, (0, 0))

        # Only stages other than conv2_x downsample; they do so with a
        # stride-2 3x3 convolution in the first sub-block.
        if not self.conv2_flag:
            self.conv2_downsample = nn.Conv2d(
                squeeze_ch, mid_ch, (3, 3), 2, (1, 1)
            )
        self.conv2 = nn.Conv2d(squeeze_ch, mid_ch, (3, 3), 1, (1, 1))
        self.conv3 = nn.Conv2d(mid_ch, out_ch, (1, 1), 1, (0, 0))

        # Shortcut projection for the first sub-block. It is created here, not
        # inside forward, so that it is a registered sub-module: its weights
        # are trained, appear in state_dict(), and follow .to(device).
        # A 1x1 kernel with the same stride as the main path yields the same
        # spatial size as the 3x3/pad-1 convolution for odd and even inputs.
        shortcut_stride = 1 if self.conv2_flag else 2
        self.shortcut = nn.Conv2d(
            in_ch, out_ch, (1, 1), shortcut_stride, (0, 0)
        )

    def _first_subblock(self, x, inchannel=None, outchannel=None):
        """Run the downsampling first sub-block (stages other than conv2_x).

        ``inchannel`` and ``outchannel`` are accepted only for backward
        compatibility; the shortcut projection is sized from ``config`` in
        ``__init__``.
        """
        out = self.first_conv1(x)
        out = self.conv2_downsample(out)
        out = self.conv3(out)
        return out + self.shortcut(x)

    def _later_subblock(self, x):
        """Run one identity-shortcut sub-block; channels and size are kept."""
        out = self.later_conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        return out + x

    def _conv2_first_subblock(self, x, inchannel=None, outchannel=None):
        """Run the first sub-block of conv2_x (channel change, no downsampling).

        ``inchannel`` and ``outchannel`` are accepted only for backward
        compatibility; the shortcut projection is sized from ``config`` in
        ``__init__``.
        """
        out = self.first_conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        return out + self.shortcut(x)

    def forward(self, x):
        """Apply the projection sub-block, then ``repeat_time - 1`` identity ones."""
        if self.conv2_flag:
            x = self._conv2_first_subblock(x)
        else:
            x = self._first_subblock(x)
        # The first sub-block already counts towards repeat_time.
        for _ in range(self.repeat_time - 1):
            x = self._later_subblock(x)
        return x

In [3]:
class ResNet(nn.Module):
    """ResNet-50-style classifier: stem conv + max-pool, four stages, linear head.

    Args:
        inchannel: Number of channels of the input images.
        num_classes: Size of the output logits vector.

    ``forward`` takes a 4-D image batch (N, inchannel, H, W) and returns a
    (N, num_classes) tensor. Global pooling is adaptive, so the spatial size
    of the input does not have to be 224x224.
    """

    def __init__(self, inchannel, num_classes) -> None:
        super().__init__()
        self.config = {
            # in_channel, conv1 channels, conv2 channels, conv3 channels, repeat_time
            # (layout of ResNet-50)
            'conv2_x': [64, 64, 64, 256, 3],
            'conv3_x': [256, 128, 128, 512, 4],
            'conv4_x': [512, 256, 256, 1024, 6],
            'conv5_x': [1024, 512, 512, 2048, 3],
        }
        stem_channels = self.config['conv2_x'][0]
        feature_channels = self.config['conv5_x'][3]

        # Stem: stride-2 7x7 convolution followed by stride-2 3x3 max-pool.
        self.conv1 = nn.Conv2d(inchannel, stem_channels, (7, 7), 2, (3, 3))
        self.max_pool = nn.MaxPool2d((3, 3), 2, (1, 1))

        # Only conv2_x keeps the spatial size; the other stages downsample.
        self.Conv2_X = ResBlock(self.config["conv2_x"], True)
        self.Conv3_X = ResBlock(self.config["conv3_x"], False)
        self.Conv4_X = ResBlock(self.config["conv4_x"], False)
        self.Conv5_X = ResBlock(self.config["conv5_x"], False)

        # Pool every channel down to 1x1 whatever the feature-map size is, so
        # the flattened vector always has feature_channels entries.
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fcs = nn.Sequential(
            nn.Dropout(0.4),
            nn.Linear(feature_channels, num_classes),
        )

    def forward(self, x):
        """Return class logits of shape (N, num_classes) for an image batch."""
        x = self.max_pool(self.conv1(x))
        x = self.Conv2_X(x)
        x = self.Conv3_X(x)
        x = self.Conv4_X(x)
        x = self.Conv5_X(x)
        x = self.avg_pool(x)
        # (N, C, 1, 1) -> (N, C)
        x = x.reshape((x.shape[0], -1))
        x = self.fcs(x)
        return x
        

In [7]:
# Pick the compute device: CUDA when PyTorch reports it as available, else CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

In [4]:
# Smoke test: push one random batch of 16 RGB 224x224 images through the
# network and show the shape of the returned logits.
net = ResNet(inchannel=3, num_classes=1000)
x = torch.rand(16, 3, 224, 224)
print(net(x).shape)

torch.Size([16, 1000])
