In [14]:
#|default_exp resnet

In [2]:
#|export

import torch
from torch import nn
from typing import Optional, List

In [None]:
class NaiveLayer(nn.Module):
    def __init__(self, F, in_chan, out_chan, stride, padding, block:"NaiveBlock"=None, set_residuals=False):
        super().__init__()

        self.conv = nn.Conv2d(in_chan, out_chan, F, stride, padding)
        self.batch_norm = nn.BatchNorm2d(out_chan, 1e-4, 0.1)
        self.relu = nn.ReLU()
        self.block = block
        self.set_residuals = set_residuals

    def forward(self, X):
        X = self.conv(X)
        X = self.batch_norm(X)

        if self.block is not None and self.block.residual is not None:
            X += self.block.residual

        if self.set_residuals == True:
            self.block.residual = X
        
        X = self.relu(X)

        return X
    

class NaiveBlock(nn.Module):
    residual: Optional[torch.Tensor]
    def __init__(self,  init_chan, in_chan, Fs=[], stride=1, paddings=[], expansion=4, num_repeats=1):
        super().__init__()

        self.init_chan = init_chan
        
        sequence = []

        self.register_buffer("residual", None)
        # TODO: ensure that the residual from the previous block is passed onto this block

        for i in range(num_repeats):
            for idx, F in enumerate(Fs):
                if expansion is not None and idx == len(Fs) - 1:
                    sequence.append(NaiveLayer(F, in_chan, in_chan*expansion, stride, paddings[idx], block=self))
                    self.init_chan = in_chan*expansion
                elif idx == 0 and i == 0:
                    sequence.append(NaiveLayer(F, self.init_chan, in_chan, stride, paddings[idx], self, set_residuals=True))
                elif idx == 0:
                    sequence.append(NaiveLayer(F, self.init_chan, in_chan, stride, paddings[idx]))
                else:
                    sequence.append(NaiveLayer(F, in_chan, in_chan, stride, paddings[idx]))

        self.sequence = nn.Sequential(*sequence)

    def forward(self, X):
        return self.sequence(X)


class NaiveResNet50(nn.Module):

    def __init__(self):
        super().__init__()

        self.register_buffer("global_residual", None)

        self.conv1 = NaiveLayer(7, in_chan=3, out_chan=64, stride=2, padding=3)
        self.conv2_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.conv2 = NaiveBlock(init_chan=64, in_chan=64, Fs=[1,3,1], stride=1, paddings=[0,1,0], expansion=4, num_repeats=3)
        self.conv3 = NaiveBlock(init_chan=256, in_chan=128, Fs=[1,3,1], stride=1, paddings=[0,1,0], expansion=4, num_repeats=4)
        self.conv4 = NaiveBlock(init_chan=512, in_chan=256, Fs=[1,3,1], stride=1, paddings=[0,1,0], expansion=4, num_repeats=6)
        self.conv5 = NaiveBlock(init_chan=1024, in_chan=512, Fs=[1,3,1], stride=1, paddings=[0,1,0], expansion=4, num_repeats=3)
        self.avg_pool = nn.AdaptiveAvgPool2d((1,1))
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(512*4, 1000) 

    
    def forward(self, X):
        X = self.conv1(X)
        X = self.conv2_pool(X)
        X = self.conv2(X)
        X = self.conv3(X)
        X = self.conv4(X)
        X = self.conv5(X)
        X = self.avg_pool(X)
        X = self.flatten(X)
        X = self.fc(X)
        return X


In [None]:

class block(nn.Module):
    def __init__(self, out_channels, repeats):
        super().__init__()

        self.sequence = nn.Sequential(
            *[nn.Conv2d(56 * 56 * out_channels, 56 * 56 * out_channels, kernel_size=1),
            nn.Conv2d(56 * 56 * out_channels, 56 * 56 * out_channels, kernel_size=3, padding=1),
            nn.Conv2d(56 * 56 * out_channels, 56 * 56 * out_channels*4, kernel_size=1) for i in range(repeats)],
        )

    def forward(self, X):
        return self.sequence(X)

class ResNet50(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(
                224 * 224 * 3, 112 * 112 * 64, kernel_size=7, stride=2, padding=3
            )
        self.bridge_pool = nn.MaxPool2d(kernel_size=3, stride=2)

        self.conv2 = block(64, 3)
        self.conv3 = block(128, 4)
        self.conv4 = block(256, 6)
        self.conv5 = block(512, 3)
        self.avg_pool = nn.AvgPool2d()

    def forward(self, X):
        return (chain(X)
                (self.conv1)
                (self.bridge_pool)
                (self.conv1)
                (self.conv2)
                (self.conv3)
                (self.conv4)
                (self.conv5)
                (self.avg_pool)())

In [4]:
max([1,2,1])

2

In [160]:
# |export


class Layer:
    def __init__(
        self,
        repeats: int,
        # inner loop
        out_chans: List[int],
        kernel_sizes: List[int],
        strides: List[int],
        paddings: List[int],
    ):
        self.repeats = repeats

        # inner loop
        self.out_chans = out_chans
        self.kernel_sizes = kernel_sizes
        self.strides = strides
        self.paddings = paddings


class ResNet(nn.Module):
    def __init__(self, layers: List[Layer]):
        super().__init__()

        self.relu = nn.ReLU()

        # Initial
        self.conv_i = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.bn_conv_i = nn.BatchNorm2d(64)
        self.max_pool_i = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.bn_max_i = nn.BatchNorm2d(64)

        resnet_layers = []

        curr_chan = 64

        # Resnet
        for layer_idx, layer in enumerate(layers):
            for i in range(layer.repeats):
                strides = (
                    [1, 1, 1] if layer_idx == 0 or i != 0 else [1, 2, 1]
                )  # TODO: remove hardcoding like this
                
                identity_chans = None

                resnet_block = []
                # inner loop
                for j, _ in enumerate(layer.out_chans):
                    if j == 0:
                        in_chan = curr_chan  # input is the output of previous layer （from prev block)
                        identity_chans = curr_chan
                    else:
                        in_chan = layer.out_chans[
                            j - 1
                        ]  # input is the output of previous layer

                    conv_name = f"{layer_idx}_{i}_{j}_conv"
                    bn_name = f"{layer_idx}_{i}_{j}_bn"

                    setattr(self, conv_name, nn.Conv2d(
                            in_chan,  # [64,64,64], [256,64,64] | [256,128,128], [512,128,128]
                            layer.out_chans[j],  # [64,64,256] | [128,128,512]
                            layer.kernel_sizes[j],  # [1,3,1]
                            strides[j],  # [1,1,1] | [1,2,1]
                            layer.paddings[j],  # [0,1,0]
                        ))
                    resnet_block.append(
                        conv_name
                    )
                    setattr(self, bn_name, nn.BatchNorm2d(layer.out_chans[j]))
                    resnet_block.append(bn_name)

                    if j < len(layer.out_chans) - 1:
                        resnet_block.append("relu")

                    curr_chan = layer.out_chans[j]

                #   output         residuals
                if curr_chan != identity_chans:
                    identity_downsample = [
                        nn.Conv2d(
                            in_channels=identity_chans,
                            out_channels=curr_chan,
                            kernel_size=1,
                            stride=max(
                                strides
                            ),  # if stride 2, then output is halved, so need to halve identity too
                        ),
                        nn.BatchNorm2d(curr_chan),
                    ]

                    ds_name = f"{layer_idx}_{i}_ds"
                    setattr(self, ds_name, nn.Sequential(*identity_downsample))
                    resnet_block.append(ds_name)
                else:
                    resnet_block.append(f"{layer_idx}_{i}_identity")

                resnet_block.append("relu")
                resnet_layers.append(resnet_block)

        self.resnet_layers = resnet_layers

        # Ending
        self.avg_pool_e = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * 4, 1000)

    def forward(self, X: torch.Tensor):
        X = self.conv_i(X)
        X = self.bn_conv_i(X)
        X = self.max_pool_i(X)
        X = self.bn_max_i(X)

        identity = None
        for b_i, block in enumerate(self.resnet_layers):
            for i, net_name in enumerate(block):
                net = getattr(self, net_name, None)

                if i == 0:
                    identity = X

                if "ds" in net_name:
                    downsampled_identity = net(identity)
                    X += downsampled_identity
                    continue

                if "identity" in net_name and net is None:
                    X += identity
                    continue

                print(b_i, i, X.shape, net)
                X = net(X)

        X = self.avg_pool_e(X)
        # [N, 2048, 1, 1]

        X = X.flatten(1,3)
        # [N, 2048]

        X = self.fc(X)
        # [N, 1000]

        return X

In [161]:
#|export

model = ResNet(
    [
        Layer(3, [64, 64, 256], [1, 3, 1], [1, 2, 1], [0, 1, 0]),
        Layer(4, [128, 128, 512], [1, 3, 1], [1, 2, 1], [0, 1, 0]),
        Layer(6, [256, 256, 1024], [1, 3, 1], [1, 2, 1], [0, 1, 0]),
        Layer(3, [512, 512, 2048], [1, 3, 1], [1, 2, 1], [0, 1, 0]),
    ]
)

In [162]:
model.resnet_layers

[['0_0_0_conv',
  '0_0_0_bn',
  'relu',
  '0_0_1_conv',
  '0_0_1_bn',
  'relu',
  '0_0_2_conv',
  '0_0_2_bn',
  '0_0_ds',
  'relu'],
 ['0_1_0_conv',
  '0_1_0_bn',
  'relu',
  '0_1_1_conv',
  '0_1_1_bn',
  'relu',
  '0_1_2_conv',
  '0_1_2_bn',
  '0_1_identity',
  'relu'],
 ['0_2_0_conv',
  '0_2_0_bn',
  'relu',
  '0_2_1_conv',
  '0_2_1_bn',
  'relu',
  '0_2_2_conv',
  '0_2_2_bn',
  '0_2_identity',
  'relu'],
 ['1_0_0_conv',
  '1_0_0_bn',
  'relu',
  '1_0_1_conv',
  '1_0_1_bn',
  'relu',
  '1_0_2_conv',
  '1_0_2_bn',
  '1_0_ds',
  'relu'],
 ['1_1_0_conv',
  '1_1_0_bn',
  'relu',
  '1_1_1_conv',
  '1_1_1_bn',
  'relu',
  '1_1_2_conv',
  '1_1_2_bn',
  '1_1_identity',
  'relu'],
 ['1_2_0_conv',
  '1_2_0_bn',
  'relu',
  '1_2_1_conv',
  '1_2_1_bn',
  'relu',
  '1_2_2_conv',
  '1_2_2_bn',
  '1_2_identity',
  'relu'],
 ['1_3_0_conv',
  '1_3_0_bn',
  'relu',
  '1_3_1_conv',
  '1_3_1_bn',
  'relu',
  '1_3_2_conv',
  '1_3_2_bn',
  '1_3_identity',
  'relu'],
 ['2_0_0_conv',
  '2_0_0_bn',
  'relu',

In [163]:
#|export

# testing shapes
img = torch.randn(1, 3, 224, 224)
res = model(img)
res.shape


0 0 torch.Size([1, 64, 56, 56]) Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
0 1 torch.Size([1, 64, 56, 56]) BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
0 2 torch.Size([1, 64, 56, 56]) ReLU()
0 3 torch.Size([1, 64, 56, 56]) Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
0 4 torch.Size([1, 64, 56, 56]) BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
0 5 torch.Size([1, 64, 56, 56]) ReLU()
0 6 torch.Size([1, 64, 56, 56]) Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
0 7 torch.Size([1, 256, 56, 56]) BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
0 9 torch.Size([1, 256, 56, 56]) ReLU()
1 0 torch.Size([1, 256, 56, 56]) Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
1 1 torch.Size([1, 64, 56, 56]) BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
1 2 torch.Size([1, 64, 56, 56]) ReLU()
1 3 torch.Size([1, 64, 56, 56]) Conv2d(64, 64,

torch.Size([1, 1000])

In [106]:
from nbdev.export import nb_export 

nb_export('resnet.ipynb', '../src/')