This is an implementation of ResNet-50 written to be as verbose as possible. I wrote this code while trying to understand ResNet. To make sure it is correct, the pretrained ResNet-50 weights from PyTorch (torchvision) are loaded into this model and its output is compared against torchvision's model.

In [1]:
# imports
import numpy as np
import torch
import torch.nn as nn
from collections import OrderedDict

from torchvision import models

In [2]:
# select device: the GPU when PyTorch can see one, otherwise the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
# define res1 block
class Layer1(nn.Module):
    """First residual stage: three bottleneck blocks, 64 -> 256 channels, stride 1.

    Mirrors the parameter layout of torchvision's ResNet-50 ``layer1`` so pretrained
    weights can be copied into it by position. This class only *holds* the sub-modules;
    it defines no ``forward`` -- ``ResNet50.forward`` calls ``block1``..``block3`` and
    ``downsample1`` directly and adds the shortcuts itself.
    """

    def __init__(self):
        super(Layer1, self).__init__()
        # Attribute assignment order fixes the state_dict key order, which the positional
        # weight copy depends on -- keep: block1, downsample1, block2, block3.
        self.block1 = self._bottleneck(64, 64, 256)

        # downsample 1: 1x1 projection so the shortcut has block1's 256 output channels
        self.downsample1 = nn.Sequential(
                nn.Conv2d(64, 256, kernel_size=(1, 1), bias=False),
                nn.BatchNorm2d(256)
        )

        self.block2 = self._bottleneck(256, 64, 256)
        self.block3 = self._bottleneck(256, 64, 256)

    @staticmethod
    def _bottleneck(in_channels, mid_channels, out_channels, stride=1):
        """Build one bottleneck branch: 1x1 reduce -> 3x3 (given stride) -> 1x1 expand.

        Every conv is bias-free and followed by BatchNorm. ReLU follows only the first two
        conv/BN pairs; the last one is left un-activated because ResNet50.forward applies
        the ReLU after adding the shortcut. Sequential order: conv, bn, relu, conv, bn,
        relu, conv, bn (indices 0..7).
        """
        return nn.Sequential(
                nn.Conv2d(in_channels, mid_channels, kernel_size=(1, 1), bias=False),
                nn.BatchNorm2d(mid_channels),
                nn.ReLU(),

                nn.Conv2d(mid_channels, mid_channels, kernel_size=(3, 3), stride=(stride, stride), padding=(1, 1), bias=False),
                nn.BatchNorm2d(mid_channels),
                nn.ReLU(),

                nn.Conv2d(mid_channels, out_channels, kernel_size=(1, 1), bias=False),
                nn.BatchNorm2d(out_channels),
            )

In [4]:
# define res2 block
class Layer2(nn.Module):
    """Second residual stage: four bottleneck blocks, 256 -> 512 channels.

    ``block1`` halves the spatial size with a stride-2 3x3 conv, and ``downsample2`` (a
    stride-2 1x1 projection) brings the shortcut to the same shape. Mirrors the parameter
    layout of torchvision's ResNet-50 ``layer2`` so pretrained weights can be copied by
    position. No ``forward`` is defined; ``ResNet50.forward`` calls the blocks directly.
    """

    def __init__(self):
        super(Layer2, self).__init__()
        # Attribute assignment order fixes the state_dict key order, which the positional
        # weight copy depends on -- keep: block1, downsample2, block2, block3, block4.
        self.block1 = self._bottleneck(256, 128, 512, stride=2)

        # downsample 2: strided 1x1 projection matching block1's output shape
        self.downsample2 = nn.Sequential(
                nn.Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False),
                nn.BatchNorm2d(512)
        )

        self.block2 = self._bottleneck(512, 128, 512)
        self.block3 = self._bottleneck(512, 128, 512)
        self.block4 = self._bottleneck(512, 128, 512)

    @staticmethod
    def _bottleneck(in_channels, mid_channels, out_channels, stride=1):
        """Build one bottleneck branch: 1x1 reduce -> 3x3 (given stride) -> 1x1 expand.

        Every conv is bias-free and followed by BatchNorm. ReLU follows only the first two
        conv/BN pairs; the last one is left un-activated because ResNet50.forward applies
        the ReLU after adding the shortcut. Sequential order: conv, bn, relu, conv, bn,
        relu, conv, bn (indices 0..7).
        """
        return nn.Sequential(
                nn.Conv2d(in_channels, mid_channels, kernel_size=(1, 1), bias=False),
                nn.BatchNorm2d(mid_channels),
                nn.ReLU(),

                nn.Conv2d(mid_channels, mid_channels, kernel_size=(3, 3), stride=(stride, stride), padding=(1, 1), bias=False),
                nn.BatchNorm2d(mid_channels),
                nn.ReLU(),

                nn.Conv2d(mid_channels, out_channels, kernel_size=(1, 1), bias=False),
                nn.BatchNorm2d(out_channels),
            )

In [5]:
# define res3 block
class Layer3(nn.Module):
    """Third residual stage: six bottleneck blocks, 512 -> 1024 channels.

    ``block1`` halves the spatial size with a stride-2 3x3 conv, and ``downsample3`` (a
    stride-2 1x1 projection) brings the shortcut to the same shape. Mirrors the parameter
    layout of torchvision's ResNet-50 ``layer3`` so pretrained weights can be copied by
    position. No ``forward`` is defined; ``ResNet50.forward`` calls the blocks directly.
    """

    def __init__(self):
        super(Layer3, self).__init__()
        # Attribute assignment order fixes the state_dict key order, which the positional
        # weight copy depends on -- keep: block1, downsample3, block2 ... block6.
        self.block1 = self._bottleneck(512, 256, 1024, stride=2)

        # downsample 3: strided 1x1 projection matching block1's output shape
        self.downsample3 = nn.Sequential(
                nn.Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False),
                nn.BatchNorm2d(1024)
        )

        self.block2 = self._bottleneck(1024, 256, 1024)
        self.block3 = self._bottleneck(1024, 256, 1024)
        self.block4 = self._bottleneck(1024, 256, 1024)
        self.block5 = self._bottleneck(1024, 256, 1024)
        self.block6 = self._bottleneck(1024, 256, 1024)

    @staticmethod
    def _bottleneck(in_channels, mid_channels, out_channels, stride=1):
        """Build one bottleneck branch: 1x1 reduce -> 3x3 (given stride) -> 1x1 expand.

        Every conv is bias-free and followed by BatchNorm. ReLU follows only the first two
        conv/BN pairs; the last one is left un-activated because ResNet50.forward applies
        the ReLU after adding the shortcut. Sequential order: conv, bn, relu, conv, bn,
        relu, conv, bn (indices 0..7).
        """
        return nn.Sequential(
                nn.Conv2d(in_channels, mid_channels, kernel_size=(1, 1), bias=False),
                nn.BatchNorm2d(mid_channels),
                nn.ReLU(),

                nn.Conv2d(mid_channels, mid_channels, kernel_size=(3, 3), stride=(stride, stride), padding=(1, 1), bias=False),
                nn.BatchNorm2d(mid_channels),
                nn.ReLU(),

                nn.Conv2d(mid_channels, out_channels, kernel_size=(1, 1), bias=False),
                nn.BatchNorm2d(out_channels),
            )

In [6]:
# define res4 block
class Layer4(nn.Module):
    """Fourth residual stage: three bottleneck blocks, 1024 -> 2048 channels.

    ``block1`` halves the spatial size with a stride-2 3x3 conv, and ``downsample4`` (a
    stride-2 1x1 projection) brings the shortcut to the same shape. Mirrors the parameter
    layout of torchvision's ResNet-50 ``layer4`` so pretrained weights can be copied by
    position. No ``forward`` is defined; ``ResNet50.forward`` calls the blocks directly.
    """

    def __init__(self):
        super(Layer4, self).__init__()
        # Attribute assignment order fixes the state_dict key order, which the positional
        # weight copy depends on -- keep: block1, downsample4, block2, block3.
        self.block1 = self._bottleneck(1024, 512, 2048, stride=2)

        # downsample 4: strided 1x1 projection matching block1's output shape
        self.downsample4 = nn.Sequential(
                nn.Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False),
                nn.BatchNorm2d(2048)
        )

        self.block2 = self._bottleneck(2048, 512, 2048)
        self.block3 = self._bottleneck(2048, 512, 2048)

    @staticmethod
    def _bottleneck(in_channels, mid_channels, out_channels, stride=1):
        """Build one bottleneck branch: 1x1 reduce -> 3x3 (given stride) -> 1x1 expand.

        Every conv is bias-free and followed by BatchNorm. ReLU follows only the first two
        conv/BN pairs; the last one is left un-activated because ResNet50.forward applies
        the ReLU after adding the shortcut. Sequential order: conv, bn, relu, conv, bn,
        relu, conv, bn (indices 0..7).
        """
        return nn.Sequential(
                nn.Conv2d(in_channels, mid_channels, kernel_size=(1, 1), bias=False),
                nn.BatchNorm2d(mid_channels),
                nn.ReLU(),

                nn.Conv2d(mid_channels, mid_channels, kernel_size=(3, 3), stride=(stride, stride), padding=(1, 1), bias=False),
                nn.BatchNorm2d(mid_channels),
                nn.ReLU(),

                nn.Conv2d(mid_channels, out_channels, kernel_size=(1, 1), bias=False),
                nn.BatchNorm2d(out_channels),
            )

In [7]:
# define the ResNet50 model using the above blocks
class ResNet50(nn.Module):
    """ResNet-50 assembled from the Layer1..Layer4 stages, with an explicit forward pass.

    The forward pass spells out every bottleneck block and every shortcut addition so the
    residual structure is visible line by line. Sub-modules are registered in the order
    stem -> layer1..layer4 -> pooling/classifier; the notebook relies on this order
    matching torchvision's ``resnet50`` when it copies the pretrained state_dict by
    position.

    Args:
        num_classes: output size of the final linear layer. Defaults to 1000, the size of
            torchvision's pretrained classifier head.
    """

    def __init__(self, num_classes=1000):
        super(ResNet50, self).__init__()

        # stem: 7x7 stride-2 conv + BN + ReLU, followed by a stride-2 max-pool
        self.conv1 = nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.act1 = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # the four residual stages (parameter containers only; used in forward below)
        self.layer1 = Layer1()
        self.layer2 = Layer2()
        self.layer3 = Layer3()
        self.layer4 = Layer4()

        # head: global average pool over H, W, then a linear classifier on 2048 features
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.linear = nn.Linear(2048, num_classes)

        # ReLU applied after each "block output + shortcut" sum
        self.relu = nn.ReLU()

    def forward(self, x):
        """Run the network on an image batch.

        Args:
            x: input batch with 3 channels -- presumably (N, 3, H, W); TODO confirm
               other layouts are not needed.

        Returns:
            (N, num_classes) unnormalized class scores (no softmax is applied).
        """
        # stem
        x = self.act1(self.bn1(self.conv1(x)))

        # maxpool
        x = self.maxpool(x)

        # layer 1_1: first block's shortcut needs a 1x1 projection to 256 channels
        x1_1 = self.layer1.block1(x)
        x1_1 = self.relu(x1_1 + self.layer1.downsample1(x)) # shortcut

        # layer 1_2
        x1_2 = self.layer1.block2(x1_1)
        x1_2 = self.relu(x1_2 + x1_1) # shortcut

        # layer 1_3
        x1_3 = self.layer1.block3(x1_2)
        x1_3 = self.relu(x1_3 + x1_2) # shortcut

        # layer 2_1: strided block, shortcut projected with a strided 1x1 conv
        x2_1 = self.layer2.block1(x1_3)
        x2_1 = self.relu(x2_1 + self.layer2.downsample2(x1_3)) # shortcut

        # layer 2_2
        x2_2 = self.layer2.block2(x2_1)
        x2_2 = self.relu(x2_2 + x2_1) # shortcut

        # layer 2_3
        x2_3 = self.layer2.block3(x2_2)
        x2_3 = self.relu(x2_3 + x2_2) # shortcut

        # layer 2_4
        x2_4 = self.layer2.block4(x2_3)
        x2_4 = self.relu(x2_4 + x2_3) # shortcut

        # layer 3_1: strided block, shortcut projected with a strided 1x1 conv
        x3_1 = self.layer3.block1(x2_4)
        x3_1 = self.relu(x3_1 + self.layer3.downsample3(x2_4)) # shortcut

        # layer 3_2
        x3_2 = self.layer3.block2(x3_1)
        x3_2 = self.relu(x3_2 + x3_1) # shortcut

        # layer 3_3
        x3_3 = self.layer3.block3(x3_2)
        x3_3 = self.relu(x3_3 + x3_2) # shortcut

        # layer 3_4
        x3_4 = self.layer3.block4(x3_3)
        x3_4 = self.relu(x3_4 + x3_3) # shortcut

        # layer 3_5
        x3_5 = self.layer3.block5(x3_4)
        x3_5 = self.relu(x3_5 + x3_4) # shortcut

        # layer 3_6
        x3_6 = self.layer3.block6(x3_5)
        x3_6 = self.relu(x3_6 + x3_5) # shortcut

        # layer 4_1: strided block, shortcut projected with a strided 1x1 conv
        x4_1 = self.layer4.block1(x3_6)
        x4_1 = self.relu(x4_1 + self.layer4.downsample4(x3_6)) # shortcut

        # layer 4_2
        x4_2 = self.layer4.block2(x4_1)
        x4_2 = self.relu(x4_2 + x4_1) # shortcut

        # layer 4_3
        x4_3 = self.layer4.block3(x4_2)
        x4_3 = self.relu(x4_3 + x4_2) # shortcut

        # pool to (N, 2048, 1, 1) and flatten everything after the batch dim -> (N, 2048);
        # flatten keeps the batch size from x4_3 itself rather than from the earlier stem tensor
        pooled = torch.flatten(self.avg_pool(x4_3), 1)

        return self.linear(pooled)

In [8]:
# test: build torchvision's pretrained ResNet-50 (reference) and our custom model
if hasattr(models, "ResNet50_Weights"):
    # torchvision >= 0.13: `pretrained=True` is deprecated in favour of the weights enum;
    # IMAGENET1K_V1 is the checkpoint that `pretrained=True` maps to
    r50 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
else:
    # older torchvision without the weights enums
    r50 = models.resnet50(pretrained=True)
r50 = r50.to(device) # model from PyTorch
model = ResNet50().to(device) # our custom model

r50_state = r50.state_dict()
model_state = model.state_dict()

# our model's key names, in registration order; used to rename r50's tensors by position
model_state_keys = list(model_state.keys())

# filled in a later cell with r50's tensors under our model's key names
model_state_new = OrderedDict()

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth


  0%|          | 0.00/97.8M [00:00<?, ?B/s]

In [9]:
def count_params(model):
    """Return the number of trainable scalar parameters in ``model``.

    Only parameters with ``requires_grad=True`` are counted. ``Tensor.numel()`` is used
    rather than ``np.prod(p.size())`` so the result is always a plain Python ``int``:
    ``np.prod`` yields numpy scalars, and the float ``1.0`` for a 0-dim parameter.
    An ``nn.Module`` with no trainable parameters gives 0.
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [10]:
n_params_r50, n_params_custom = count_params(r50), count_params(model)
# fail loudly instead of relying on the reader to compare the two numbers
assert n_params_r50 == n_params_custom, "custom ResNet50 and torchvision resnet50 have different parameter counts"
n_params_r50, n_params_custom # both have equal number of parameters

(25557032, 25557032)

In [11]:
# Copy r50's tensors into our model's key names *by position*. This relies on both
# state_dicts listing their tensors in the same order; the length and shape checks make a
# misalignment fail here with a readable message instead of producing a wrong model.
assert len(r50_state) == len(model_state_keys), (
    f"state_dict sizes differ: torchvision has {len(r50_state)}, custom has {len(model_state_keys)}"
)
for custom_key, (r50_key, r50_tensor) in zip(model_state_keys, r50_state.items()):
    assert model_state[custom_key].shape == r50_tensor.shape, (
        f"shape mismatch: {custom_key} {tuple(model_state[custom_key].shape)} "
        f"vs {r50_key} {tuple(r50_tensor.shape)}"
    )
    model_state_new[custom_key] = r50_tensor

In [12]:
# load the remapped PyTorch (torchvision) weights into our model; strict=True is
# PyTorch's default, spelled out so mismatched keys raise instead of being skipped
model.load_state_dict(model_state_new, strict=True)

<All keys matched successfully>

In [13]:
# fixed seed so the random probe input (and therefore both outputs) is reproducible
torch.manual_seed(0)
input_tensor = torch.randn(1, 3, 224, 224).to(device)

# eval mode for both models so BatchNorm behaves identically (no batch statistics)
r50.eval()
model.eval()

with torch.no_grad():
    resnet_output = r50(input_tensor)
    custom_model_output = model(input_tensor)

In [14]:
# check that both models generate the same predictions: exact float equality is brittle
# (e.g. across devices/kernels), so the pass/fail check uses a tolerance and compares the
# predicted classes
assert torch.allclose(resnet_output, custom_model_output, atol=1e-5), "custom model output differs from torchvision resnet50"
assert torch.equal(resnet_output.argmax(dim=1), custom_model_output.argmax(dim=1)), "predicted classes differ"
# informational: number of output elements that match exactly
(resnet_output == custom_model_output).sum()

tensor(1000)