# ResNet from scatch in PyTorch
> A tutorial on resnet on cifar-10 image classification

- toc: true 
- badges: true
- comments: true
- categories: [jupyter]

### Imports and Notebook setup

In [2]:
from tqdm.notebook import tqdm

In [3]:
import os

if os.environ.get('KAGGLE_KERNEL_RUN_TYPE', 'Localhost') != 'Losthost':
    # library with nice formating
    os.system("pip install rich")

In [4]:
import rich

from collections import OrderedDict

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [7]:
from torch.utils.data import dataset, dataloader, random_split

from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor, Lambda, Compose

In [8]:
dataset = CIFAR10(root="/kaggle/working/data/",
                        download=True,
                        train=True, transform=ToTensor())
test_dataset = CIFAR10(root="/kaggle/working/data/",
                       download=True, train=False,
                       transform=ToTensor())

# create train, valid datasets
valid_size = int(dataset.data.shape[0]*0.1)
train_size = dataset.data.shape[0] - valid_size

train_dataset, valid_dataset = random_split(dataset, [train_size, valid_size])

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /kaggle/working/data/cifar-10-python.tar.gz


0it [00:00, ?it/s]

Extracting /kaggle/working/data/cifar-10-python.tar.gz to /kaggle/working/data/
Files already downloaded and verified


In [9]:
config = OrderedDict()
config['batch_size'] = 128

In [10]:
train_dataloader = dataloader.DataLoader(train_dataset, batch_size=config['batch_size'])
valid_dataloader = dataloader.DataLoader(valid_dataset, batch_size=config['batch_size'])
test_dataloader = dataloader.DataLoader(test_dataset, batch_size=config['batch_size'])

In [11]:
for x, y in train_dataloader:
    rich.print(x.max(), x.min())
    rich.print(y)
    break

### ResNet Block

In [12]:
class ResNetBlock(nn.Module):
    
    def __init__(self, in_channels: int = 32, num_filters: int = 64):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channels,
                               out_channels=num_filters,
                               kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(num_filters)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=num_filters,
                               out_channels=num_filters, 
                               kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(num_filters)
        self.relu2 = nn.ReLU()
    
    def forward(self, x):
        z = self.conv1(x)
        z = self.relu1(z)
        z = self.bn1(z)
        z = self.conv2(z)
        y = self.relu2(z + x)
        return self.bn2(y)

Check that `ResNetBlock` has same input and output sizes

In [13]:
rblock = ResNetBlock(in_channels=64)
x = torch.randint(0, 100, size=(128, 64, 32, 32), dtype=torch.float32)
y = rblock(x)
assert x.shape == y.shape

`ResNetChangeBlock` implements the ResNet with skip connections when the input and output have different shape

In [14]:
class ResNetChangeBlock(nn.Module):
    
    def __init__(self, in_channels: int = 32, num_filters: int = 64):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = num_filters
        self.conv1 = nn.Conv2d(in_channels=in_channels,
                               out_channels=num_filters,
                               stride=2,
                               kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(num_filters)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=num_filters,
                               out_channels=num_filters, 
                               kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(num_filters)
        self.relu2 = nn.ReLU()
    
    def forward(self, x):
        n, c, w, h = x.shape
        z = self.conv1(x)
        z = self.relu1(z)
        z = self.bn1(z)
        z = self.conv2(z)
        self.zero_identity = nn.Parameter(
            torch.zeros(n, self.out_channels-self.in_channels,
                        int(w//2), int(h//2)),
            requires_grad=False
        )
        x_expand = F.pad(x[:, :, :int(w//2), :int(h//2)], pad=(
            0, 0, 0, 0, self.out_channels-self.in_channels, 0, 0, 0))
        y = self.relu2(z + x_expand)
        return self.bn2(y)

Check the padding operation works as intended, the pad dimensions are left-to-right last-to-first dimension

In [15]:
x = torch.ones([128, 16, 16, 16])
y = torch.ones([128, 32, 16, 16])
z = F.pad(x, (0, 0, 0, 0, 16, 0, 0, 0))
assert z.shape == y.shape

Check that `ResNetChangeBlock` has half receptor field

In [16]:
rblock2 = ResNetChangeBlock(in_channels=64, num_filters=128)
x = torch.randint(0, 100, size=(128, 64, 32, 32), dtype=torch.float32)
y = rblock2(x)
assert y.shape == (128, 128, 16, 16)

`ResNet` with variable `6n+2` number of layers and skip connections `2n` skip connections, Similar to CIFAR10 network from the [resnet paper](https://arxiv.org/abs/1512.03385#).

In [17]:
class ResNet(nn.Module):
    
    def __init__(self, n: int = 3, num_classes: int = 10):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16,
                               kernel_size=3, padding=1)
        self.resblock1_list = nn.ModuleList([
            ResNetBlock(in_channels=16, num_filters=16) for _ in range(n)
        ])
        self.resblock2_list = nn.ModuleList([
                ResNetChangeBlock(in_channels=16, num_filters=32)
            ] + [
            ResNetBlock(in_channels=32, num_filters=32) for _ in range(n-1)
        ])
        self.resblock3_list = nn.ModuleList([
            ResNetChangeBlock(in_channels=32, num_filters=64)
        ] + [
            ResNetBlock(in_channels=64, num_filters=64) for _ in range(n-1)
        ])
        self.linear1 = nn.Linear(in_features=64,
                                 out_features=num_classes)
        
    def forward(self, x):
        x = self.conv1(x)
        for layer1 in self.resblock1_list:
            x = layer1(x)
        for layer2 in self.resblock2_list:
            x = layer2(x)
        for layer3 in self.resblock3_list:
            x = layer3(x)
        x = torch.mean(x, dim=[2, 3])
        x = self.linear1(x)
        return x

Check `ResNet` has last dimension as `num_classes`

In [18]:
n = 3
num_classes = 10
batch_size = 128
res_net = ResNet(n=n, num_classes=num_classes)
x = torch.randint(0, 100, size=(batch_size, 3, 32, 32), dtype=torch.float32)
y = res_net(x)
assert y.shape == (batch_size, num_classes)

### Define Train and Eval loop

In [19]:
config['device'] = "cuda" if torch.cuda.is_available() else "cpu"
print(f"device: {config['device']}")

# define the network
config['n'] = 3
num_classes = 10
res_net = ResNet(n=config['n'], num_classes=num_classes)
res_net = res_net.to(config['device'])

loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=res_net.parameters())

device: cuda


In [20]:
rich.print(res_net)

In [21]:
def train_fn(model: nn.Module, 
             train_dataloader: dataloader.DataLoader,
             loss_fn, optimizer, epoch: int, device: str):
    
    model.train()
    
    for batch_index, (X, y) in tqdm(enumerate(train_dataloader),
                                    total=len(train_dataloader.dataset.indices)//config['batch_size']):
        
        X_data, y_data = X.to(device), y.to(device)
        y_pred = model(X_data)
        loss = loss_fn(y_pred, y_data)
        
        # backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    loss = loss.item()
    accu = torch.argmax(y_pred, dim=-1) == y_data
    accu = sum(accu)/len(accu)
    accu = accu.item()
    print("loss="+str(loss), end=' ')
    print("accu="+str(accu), end=' ')


In [22]:
def test_fn(model: nn.Module,
            test_dataloader: dataloader.DataLoader,
            epoch: int, device: str, dataset: str):
    
    loss = 0
    results = []
    model.eval()
    with torch.no_grad():
        for batch_index, (X, y) in tqdm(enumerate(test_dataloader)):
            X_data, y_data = X.to(device), y.to(device)
            y_pred = model(X_data)
            loss += loss_fn(y_pred, y_data).item()
            results.extend(torch.argmax(y_pred, dim=-1) == y_data)
    
    test_accu = sum(results)/len(results)
    test_accu = test_accu.item()
    test_loss = loss/len(results)
    print(dataset + "_loss: " + str(test_loss), end=' ')
    print(dataset + "_accu: " + str(test_accu), end=' ')
    return False

### Train and Track the Model

In [23]:
config['num_epochs'] = 25

for epoch in tqdm(range(config['num_epochs'])):
    print("epoch: " + str(epoch), end=' ')
    train_fn(res_net, train_dataloader, loss_fn, optimizer, epoch, config['device'])
    test_fn(res_net, valid_dataloader, epoch, config['device'], dataset="valid")    
    print()

  0%|          | 0/25 [00:00<?, ?it/s]

epoch: 0 

  0%|          | 0/351 [00:00<?, ?it/s]

loss=1.1741900444030762 accu=0.5 

0it [00:00, ?it/s]

valid_loss: 0.009578690445423126 valid_accu: 0.5667999982833862 
epoch: 1 

  0%|          | 0/351 [00:00<?, ?it/s]

loss=0.9159127473831177 accu=0.6805555820465088 

0it [00:00, ?it/s]

valid_loss: 0.007797019159793854 valid_accu: 0.6571999788284302 
epoch: 2 

  0%|          | 0/351 [00:00<?, ?it/s]

loss=0.7329821586608887 accu=0.7083333134651184 

0it [00:00, ?it/s]

valid_loss: 0.006730758094787598 valid_accu: 0.7023999691009521 
epoch: 3 

  0%|          | 0/351 [00:00<?, ?it/s]

loss=0.5593144297599792 accu=0.7777777910232544 

0it [00:00, ?it/s]

valid_loss: 0.006174366390705109 valid_accu: 0.7310000061988831 
epoch: 4 

  0%|          | 0/351 [00:00<?, ?it/s]

loss=0.500760018825531 accu=0.7777777910232544 

0it [00:00, ?it/s]

valid_loss: 0.006064797937870026 valid_accu: 0.7387999892234802 
epoch: 5 

  0%|          | 0/351 [00:00<?, ?it/s]

loss=0.3587251901626587 accu=0.8611111044883728 

0it [00:00, ?it/s]

valid_loss: 0.00611595823764801 valid_accu: 0.7457999587059021 
epoch: 6 

  0%|          | 0/351 [00:00<?, ?it/s]

loss=0.2707603871822357 accu=0.9166666865348816 

0it [00:00, ?it/s]

valid_loss: 0.006046025335788727 valid_accu: 0.7501999735832214 
epoch: 7 

  0%|          | 0/351 [00:00<?, ?it/s]

loss=0.2549777626991272 accu=0.9027777910232544 

0it [00:00, ?it/s]

valid_loss: 0.00602523101568222 valid_accu: 0.7601999640464783 
epoch: 8 

  0%|          | 0/351 [00:00<?, ?it/s]

loss=0.255811482667923 accu=0.9305555820465088 

0it [00:00, ?it/s]

valid_loss: 0.005957237935066223 valid_accu: 0.7705999612808228 
epoch: 9 

  0%|          | 0/351 [00:00<?, ?it/s]

loss=0.1865358203649521 accu=0.9166666865348816 

0it [00:00, ?it/s]

valid_loss: 0.006287951052188873 valid_accu: 0.7590000033378601 
epoch: 10 

  0%|          | 0/351 [00:00<?, ?it/s]

loss=0.2165629267692566 accu=0.9166666865348816 

0it [00:00, ?it/s]

valid_loss: 0.006590583300590515 valid_accu: 0.7545999884605408 
epoch: 11 

  0%|          | 0/351 [00:00<?, ?it/s]

loss=0.23004494607448578 accu=0.875 

0it [00:00, ?it/s]

valid_loss: 0.007091660153865814 valid_accu: 0.7545999884605408 
epoch: 12 

  0%|          | 0/351 [00:00<?, ?it/s]

loss=0.1190534457564354 accu=0.9583333134651184 

0it [00:00, ?it/s]

valid_loss: 0.007388375395536423 valid_accu: 0.7513999938964844 
epoch: 13 

  0%|          | 0/351 [00:00<?, ?it/s]

loss=0.08894743770360947 accu=0.9722222089767456 

0it [00:00, ?it/s]

valid_loss: 0.007570652145147324 valid_accu: 0.7551999688148499 
epoch: 14 

  0%|          | 0/351 [00:00<?, ?it/s]

loss=0.13272982835769653 accu=0.944444477558136 

0it [00:00, ?it/s]

valid_loss: 0.007181078696250916 valid_accu: 0.7703999876976013 
epoch: 15 

  0%|          | 0/351 [00:00<?, ?it/s]

loss=0.1256604641675949 accu=0.9583333134651184 

0it [00:00, ?it/s]

valid_loss: 0.007560571193695068 valid_accu: 0.76419997215271 
epoch: 16 

  0%|          | 0/351 [00:00<?, ?it/s]

loss=0.04683852195739746 accu=0.9861111044883728 

0it [00:00, ?it/s]

valid_loss: 0.007569316947460175 valid_accu: 0.7784000039100647 
epoch: 17 

  0%|          | 0/351 [00:00<?, ?it/s]

loss=0.0941305011510849 accu=0.944444477558136 

0it [00:00, ?it/s]

valid_loss: 0.0074231725215911865 valid_accu: 0.777999997138977 
epoch: 18 

  0%|          | 0/351 [00:00<?, ?it/s]

loss=0.10060694068670273 accu=0.9722222089767456 

0it [00:00, ?it/s]

valid_loss: 0.00766129287481308 valid_accu: 0.7789999842643738 
epoch: 19 

  0%|          | 0/351 [00:00<?, ?it/s]

loss=0.08288509398698807 accu=0.9583333134651184 

0it [00:00, ?it/s]

valid_loss: 0.00796333611011505 valid_accu: 0.7759999632835388 
epoch: 20 

  0%|          | 0/351 [00:00<?, ?it/s]

loss=0.03118015266954899 accu=1.0 

0it [00:00, ?it/s]

valid_loss: 0.008388083469867707 valid_accu: 0.7735999822616577 
epoch: 21 

  0%|          | 0/351 [00:00<?, ?it/s]

loss=0.0689447894692421 accu=0.9722222089767456 

0it [00:00, ?it/s]

valid_loss: 0.008306311309337617 valid_accu: 0.7755999565124512 
epoch: 22 

  0%|          | 0/351 [00:00<?, ?it/s]

loss=0.10265751928091049 accu=0.944444477558136 

0it [00:00, ?it/s]

valid_loss: 0.00867701324224472 valid_accu: 0.7789999842643738 
epoch: 23 

  0%|          | 0/351 [00:00<?, ?it/s]

loss=0.06691690534353256 accu=0.9583333134651184 

0it [00:00, ?it/s]

valid_loss: 0.008344750368595123 valid_accu: 0.777999997138977 
epoch: 24 

  0%|          | 0/351 [00:00<?, ?it/s]

loss=0.013189314864575863 accu=1.0 

0it [00:00, ?it/s]

valid_loss: 0.008692999589443208 valid_accu: 0.7698000073432922 


### Test and Save the network

In [24]:
def predict_fn(model, test_dataloader, device):

    y_pred_list = []
    with torch.no_grad():
        for batch_index, (X, y) in enumerate(test_dataloader):
            X_data, y_data = X.to(device), y.to(device)
            y_pred = model(X_data)
            y_pred_list.append(y_pred)
    return torch.cat(y_pred_list)

In [25]:
test_fn(res_net, test_dataloader, epoch, config['device'], dataset="test")

0it [00:00, ?it/s]

test_loss: 0.0094959352850914 test_accu: 0.7688999772071838 

False

In [26]:
torch.save(res_net.state_dict(), "resnet_cifar10_n3_network.pth")

### Conclusion and Future work