# Computer Vision (911.908)

## <font color='crimson'>Residual Learning</font>

**Changelog**:
- *Sep. 2021*: initial version
- *Jan. 2023*: PyTorch 1.13 adaptations and fixes

---

In this part of the lecture, we cover one of the current state-of-the-art neural network architectures for visual recognition, *ResNets*, introduced by He et al., https://arxiv.org/abs/1512.03385 (please read the paper).

## Content

- [ResNet components](#ResNet-components)
    - [Basic residual block](#Basic-residual-block)
    - [Bottleneck block](#Bottleneck-block)
- [Full ResNet implementation](#Full-ResNet-implementation)

---

## ResNet components

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import os
import argparse

### Basic residual block

Let's implement a simple (**basic**) **residual block** consisting of two 3x3 convolutions, each followed by batch normalization, with ReLU activations (the second ReLU is applied only after the shortcut has been added).

<img src="BasicBlock.png" width="180"/>

In [2]:
def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, 
                     out_planes, 
                     kernel_size=3, 
                     stride=stride,
                     padding=1, 
                     bias=False)

Let's quickly test this helper function with 10 input channels, 20 output channels and a stride of 1.

In [3]:
l = conv3x3(10, 20, stride=1)
print(l(torch.randn(4,10,32,32)).size())

torch.Size([4, 20, 32, 32])
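
The spatial size stays at 32x32 because of `padding=1`. As a quick sanity check, here is a small pure-Python helper (a hypothetical name, not part of the lecture code) that evaluates the output-size formula given in the `torch.nn.Conv2d` documentation, $H_{out} = \lfloor (H_{in} + 2p - d(k-1) - 1)/s \rfloor + 1$:

```python
def conv_out_size(size, kernel_size, stride=1, padding=0, dilation=1):
    """Output height/width of a convolution (formula from the nn.Conv2d docs)."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# conv3x3 with padding 1 keeps the size at stride 1 ...
print(conv_out_size(32, kernel_size=3, stride=1, padding=1))  # 32
# ... and halves an even size at stride 2
print(conv_out_size(12, kernel_size=3, stride=2, padding=1))  # 6
# a 1x1 convolution (padding 0) with stride 2 yields the same size
print(conv_out_size(12, kernel_size=1, stride=2, padding=0))  # 6
```

For a 3x3 kernel with padding 1 and a 1x1 kernel with padding 0, the formula reduces to the same expression $\lfloor (H_{in}-1)/s \rfloor + 1$, so both produce the same spatial size for any input size and stride.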


We are now ready to create a class that implements the residual block from the figure above.

In [4]:
class BasicBlock_v1(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1):
        super(BasicBlock_v1, self).__init__()
        
        """
        3x3 convolution inplanes->outplanes 
        (spatial size maintained), BN + ReLU
        """ 
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1   = nn.BatchNorm2d(planes)
        self.relu  = nn.ReLU(inplace=True)
        
        """
        3x3 convolution inplanes->outplanes 
        (spatial size maintained) + BN
        """ 
        self.conv2 = conv3x3(planes, planes)
        self.bn2   = nn.BatchNorm2d(planes)
        
    def forward(self, x):
        residual = x 
    
        # conv->bn->relu
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        # conv->bn
        out = self.conv2(out)
        out = self.bn2(out)

        # x + F(x): this realizes the shortcut connection
        # (shapes are printed to make dimension mismatches visible)
        print(out.size())
        print(residual.size())
        out += residual
        out = self.relu(out) # final relu (see figure)
        return out

Let's test the full block:

In [5]:
# simple test with batch-size 4, 64 channels, spatial dim. 12x12 
x = torch.randn(4,64,12,12)

# Push data through a basic block
bb = BasicBlock_v1(64,64)
print(bb(x).size())

torch.Size([4, 64, 12, 12])
torch.Size([4, 64, 12, 12])
torch.Size([4, 64, 12, 12])


OK, but what about strides $>1$?

In [6]:
x = torch.randn(4,64,12,12)
try:
    bb = BasicBlock_v1(64, 64, stride=2)
    out = bb(x)
except RuntimeError:  # raised by `out += residual` (shape mismatch)
    print('Error')

torch.Size([4, 64, 6, 6])
torch.Size([4, 64, 12, 12])
Error


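This fails because the shapes are incompatible at the $F(x)+x$ operation: with stride 2, $F(x)$ is $6\times 6$ spatially, while $x$ is still $12\times 12$. We can easily fix this with a **1x1 convolution** with appropriate stride applied to $x$, i.e., a projection of the shortcut. The same projection also handles a change in the number of channels.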
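
The rule for when the shortcut needs a projection can be sketched as follows. This is a minimal sketch: `make_shortcut` is a hypothetical helper, and adding batch normalization after the 1x1 convolution is an assumption here (it matches what the full implementation later in this notebook does).

```python
import torch
import torch.nn as nn


def make_shortcut(inplanes, outplanes, stride=1):
    """Return None for an identity shortcut, else a 1x1 projection (hypothetical helper)."""
    if stride == 1 and inplanes == outplanes:
        return None  # shapes of x and F(x) already match
    return nn.Sequential(
        # the 1x1 convolution fixes both the channel count and the spatial size
        nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=stride, bias=False),
        nn.BatchNorm2d(outplanes),
    )


x = torch.randn(4, 64, 12, 12)
identity = make_shortcut(64, 64, stride=1)
proj = make_shortcut(64, 128, stride=2)
print(identity)        # None
print(proj(x).size())  # torch.Size([4, 128, 6, 6])
```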

In [7]:
def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(
        in_planes, 
        out_planes, 
        kernel_size=1, 
        stride=stride, 
        bias=False)

In [8]:
class BasicBlock_v2(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock_v2, self).__init__()
        
        self.conv1      = conv3x3(inplanes, planes, stride)
        self.bn1        = nn.BatchNorm2d(planes)
        self.relu       = nn.ReLU(inplace=True)
        self.conv2      = conv3x3(planes, planes)
        self.bn2        = nn.BatchNorm2d(planes)        
        self.stride     = stride
        self.downsample = downsample
        
    def forward(self, x):
        residual = x
    
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)

        """
        At that point, we check if a downsampling needs to be 
        applied on x, so that F(x)+x can be computed.
        """
        if self.downsample is not None:
            residual = self.downsample(x)
        
        out += residual
        out = self.relu(out)

        return out

**Note**: If we have multiple input channels, say 32, then each output channel of a 1x1 convolution has its own kernel of size 32x1x1, which is slid over the full feature map with the given stride. Mathematically, at every (strided) spatial location this corresponds to taking the dot product between the kernel and the feature vector at that location (i.e., again a 32x1x1 tensor).
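
To check this numerically, the sketch below compares a 1x1 convolution with an explicit per-pixel dot product over the channel dimension, written with `torch.einsum`. The channel counts and tensor sizes are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
conv = nn.Conv2d(32, 16, kernel_size=1, bias=False)  # weight shape: (16, 32, 1, 1)
x = torch.randn(2, 32, 5, 5)

with torch.no_grad():
    out_conv = conv(x)                  # (2, 16, 5, 5)
    W = conv.weight[:, :, 0, 0]         # (16, 32): one 32-dim kernel per output channel
    # dot product between each kernel and the 32-dim feature vector at every pixel
    out_dot = torch.einsum('oc,nchw->nohw', W, x)
    # with stride 2, the same dot products are evaluated only at every second pixel
    out_s2 = F.conv2d(x, conv.weight, stride=2)  # (2, 16, 3, 3)

print(torch.allclose(out_conv, out_dot, atol=1e-5))                # True
print(torch.allclose(out_s2, out_dot[:, :, ::2, ::2], atol=1e-5))  # True
```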

In [9]:
# 1x1 convolution from 64 to 64 channels using a stride of 2
# This decreases the spatial size by a factor of 2 (obviously)!
down_fn = conv1x1(
    64,
    64,
    stride=2)

# Input is a tensor of size 4x64x12x12 (so 64 channels)
x = torch.randn(4,64,12,12)

bb = BasicBlock_v2(
    64,
    64,
    stride=2,
    downsample=down_fn)

print(bb(x).size()) # Voila!

torch.Size([4, 64, 6, 6])


### Bottleneck block

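With our previous implementation, we can handle different strides. Next, we look at the **bottleneck** design of a residual block: a 1x1 convolution first reduces the number of channels, a 3x3 convolution operates on this reduced representation, and a final 1x1 convolution expands the channels again (by the factor `expansion`). We can reuse our functions that return a 1x1 and a 3x3 convolution layer.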

<img src="BottleneckBlock.png" width="180"/>



In [10]:
class Bottleneck(nn.Module):
    """
    The expansion factor determines the number of output 
    channels of the last 1x1 convolution layer.
    """
    expansion = 4  

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        
        self.conv1    = conv1x1(inplanes, planes)
        self.bn1      = nn.BatchNorm2d(planes)
        self.conv2    = conv3x3(planes, planes, stride)
        self.bn2      = nn.BatchNorm2d(planes)
        self.conv3    = conv1x1(planes, planes * self.expansion)
        self.bn3      = nn.BatchNorm2d(planes * self.expansion)
        self.relu     = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride   = stride
        

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

In [11]:
down_fn = conv1x1(256,256,stride=2)
x = torch.randn(4,256,12,12)
bb = Bottleneck(256,64,stride=2,downsample=down_fn)
print(bb(x).size())

torch.Size([4, 256, 6, 6])


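## Full ResNet implementation

The following cell implements a complete ResNet for CIFAR-10 (32x32 inputs). Unlike the ImageNet models in the paper, the stem is a single 3x3 convolution without max-pooling, and the shortcut becomes a 1x1 convolution followed by batch normalization whenever the stride or the number of channels changes (otherwise it is the identity).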

In [12]:
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion *
                               planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

def ResNet18():
    return ResNet(BasicBlock, [2, 2, 2, 2])
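
Only `ResNet18` is defined above. The number in the name counts the weighted layers: the stem convolution, the convolutions inside all residual blocks, and the final linear layer (the 1x1 shortcut projections are not counted). The sketch below (pure Python; `CONFIGS` and the helper names are hypothetical) checks this for the standard block configurations from the paper and traces the spatial size through the four stages for a 32x32 CIFAR-10 input.

```python
# blocks per stage for the standard configurations
CONFIGS = {
    'ResNet18':  ('basic', [2, 2, 2, 2]),
    'ResNet34':  ('basic', [3, 4, 6, 3]),
    'ResNet50':  ('bottleneck', [3, 4, 6, 3]),
    'ResNet101': ('bottleneck', [3, 4, 23, 3]),
    'ResNet152': ('bottleneck', [3, 8, 36, 3]),
}
CONVS_PER_BLOCK = {'basic': 2, 'bottleneck': 3}


def depth(block, num_blocks):
    # stem conv + convolutions in all blocks + final linear layer
    return 1 + CONVS_PER_BLOCK[block] * sum(num_blocks) + 1


def stage_sizes(input_size=32, strides=(1, 2, 2, 2)):
    # spatial size after each stage; only the first block of a stage is strided
    sizes, size = [], input_size
    for stride in strides:
        size = (size - 1) // stride + 1  # 3x3 conv with padding 1
        sizes.append(size)
    return sizes


for name, (block, num_blocks) in CONFIGS.items():
    print(name, depth(block, num_blocks))
print(stage_sizes(32))  # [32, 16, 8, 4]
```

The final 4x4 feature map is why `forward` uses `F.avg_pool2d(out, 4)`; this pooling size assumes 32x32 inputs. With the classes above, a ResNet-50-style network for CIFAR-10 would be `ResNet(Bottleneck, [3, 4, 6, 3])`.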

In [13]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0      # best test accuracy
start_epoch = 0   # start from epoch 0 or last checkpoint epoch
print(device)

cpu


In [14]:
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), 
                         (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), 
                         (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(root='../data', 
                                        train=True, 
                                        download=True, 
                                        transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, 
                                          batch_size=128, 
                                          shuffle=True,
                                          num_workers=2)

testset = torchvision.datasets.CIFAR10(root='../data', 
                                       train=False, 
                                       download=True, 
                                       transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, 
                                         batch_size=100, 
                                         shuffle=False, 
                                         num_workers=2)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ../data/cifar-10-python.tar.gz
Extracting ../data/cifar-10-python.tar.gz to ../data
Files already downloaded and verified
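
`transforms.Normalize(mean, std)` maps each channel to $(x - \mu_c)/\sigma_c$, where the constants above are meant to be per-channel statistics of the CIFAR-10 training images. A minimal sketch of how such statistics can be computed from a tensor of images (random data stands in for the dataset here; `channel_mean_std` is a hypothetical helper):

```python
import torch


def channel_mean_std(images):
    """Per-channel mean and std of a float tensor of shape (N, C, H, W)."""
    # statistics over all images and all pixel positions, separately per channel
    return images.mean(dim=(0, 2, 3)), images.std(dim=(0, 2, 3))


torch.manual_seed(0)
images = torch.rand(8, 3, 4, 4)  # stand-in for images already scaled to [0, 1] by ToTensor
mean, std = channel_mean_std(images)

# what Normalize does, written out with broadcasting over (N, C, H, W)
normalized = (images - mean[None, :, None, None]) / std[None, :, None, None]
print(normalized.mean(dim=(0, 2, 3)))  # close to 0 per channel
print(normalized.std(dim=(0, 2, 3)))   # close to 1 per channel
```

Computed this way over the full CIFAR-10 training set, the standard deviations are commonly reported as roughly (0.247, 0.243, 0.262), somewhat larger than the (0.2023, 0.1994, 0.2010) used above. Since normalization only rescales the inputs by fixed constants, this difference should matter little in practice.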


In [17]:
cifar_small = torch.utils.data.Subset(trainset, torch.tensor([0,500,1000,1001]))
len(cifar_small)

4

In [None]:
classes = ('plane', 
           'car', 
           'bird', 
           'cat', 
           'deer',
           'dog', 
           'frog', 
           'horse', 
           'ship', 
           'truck')

In [None]:
net = ResNet18()
net = net.to(device) # move the network to the GPU (if available)

if device == 'cuda':
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), 
                      lr=0.01,
                      momentum=0.9, 
                      weight_decay=5e-4)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
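
`CosineAnnealingLR` decays the learning rate from its initial value (here 0.01) to `eta_min` (default 0) over `T_max` scheduler steps. Because `scheduler.step()` is called once per epoch in the training loop, `T_max=200` matches the 200 training epochs. The closed form given in the PyTorch documentation can be evaluated directly; `cosine_lr` is a hypothetical helper:

```python
import math


def cosine_lr(step, base_lr=0.01, t_max=200, eta_min=0.0):
    """Cosine annealing without restarts:
    eta_min + (base_lr - eta_min) * (1 + cos(pi * step / t_max)) / 2
    """
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max)) / 2


for step in (0, 50, 100, 150, 200):
    print(step, round(cosine_lr(step), 5))
```

With one scheduler step per epoch, epoch `k` of the training loop runs with `cosine_lr(k - 1)`: the learning rate starts at 0.01, is halved around epoch 100, and approaches 0 by the last epoch.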

In [None]:
# Training
def train(epoch):
    net.train()
    train_loss = 0
    correct = 0
    total = 0

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    # epoch | mean training loss | training accuracy (%)
    print('{:03d} | {:.5f} | {:.3f}'.format(
        epoch,
        train_loss/len(trainloader),
        100.*correct/total))

In [None]:
def test(epoch):
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    acc = 100.*correct/total
    print('=> Testing loss: {:.5f} | accuracy: {:.3f}'.format(
        test_loss/len(testloader), acc))
    if acc > best_acc:
        best_acc = acc
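
`test` keeps track of `best_acc`, and `start_epoch` above mentions resuming from a checkpoint, but no checkpoint is ever written. A minimal sketch of saving and restoring one with `torch.save`/`torch.load`; the helper names, the dictionary keys, and the default path are hypothetical choices:

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_checkpoint(model, acc, epoch, path='checkpoint/ckpt.pth'):
    # store the weights together with the bookkeeping needed to resume
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    torch.save({'net': model.state_dict(), 'acc': acc, 'epoch': epoch}, path)


def load_checkpoint(model, path='checkpoint/ckpt.pth', device='cpu'):
    state = torch.load(path, map_location=device)
    model.load_state_dict(state['net'])
    return state['acc'], state['epoch']


# usage with a tiny stand-in model and a temporary directory
model = nn.Linear(4, 2)
with tempfile.TemporaryDirectory() as tmp:
    ckpt = os.path.join(tmp, 'ckpt.pth')
    save_checkpoint(model, acc=91.5, epoch=42, path=ckpt)
    restored = nn.Linear(4, 2)
    best, last_epoch = load_checkpoint(restored, path=ckpt)

print(best, last_epoch)                            # 91.5 42
print(torch.equal(restored.weight, model.weight))  # True
```

In `test`, one would call `save_checkpoint(net, acc, epoch)` inside the `if acc > best_acc:` branch. When `net` is wrapped in `torch.nn.DataParallel`, the keys of its `state_dict` carry a `module.` prefix, so the model has to be wrapped the same way (or the prefix stripped) when the checkpoint is loaded.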

In [None]:
for epoch in range(1, 1+200):
    train(epoch)
    test(epoch)
    scheduler.step()