In [1]:
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

# ResNet
ResNet is the models that first introduced residual connections (a form of skip connections). It is a rather simple, but successful and very popular architecture. In this demo the [original version](https://arxiv.org/abs/1512.03385) for CIFAR-10 is re-implemented step by step. 

This is just a convenience function to make e.g. `nn.Sequential` more flexible. It is e.g. useful in combination with `x.squeeze()`.

In [2]:
class Lambda(nn.Module):
    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, x):
        return self.func(x)

We begin by implementing the residual blocks. 

Note that we use 'SAME' padding, no bias, and batch normalization after each convolution. 

In [3]:
class ResidualBlock(nn.Module):
    """
    The residual block used by ResNet.
    
    Args:
        in_channels: The number of channels (feature maps) of the incoming embedding
        out_channels: The number of channels after the first convolution
        stride: Stride size of the first convolution, used for downsampling
    """
    
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()        
        if stride > 1 or in_channels != out_channels:
            # Add strides in the skip connection and zeros for the new channels.
            self.skip = Lambda(lambda x: F.pad(x[:, :, ::stride, ::stride],
                                               (0, 0, 0, 0, 0, out_channels - in_channels),
                                               mode="constant", value=0))
        else:
            self.skip = nn.Sequential()
            
        # Initialize the required layers
        # 2 CNNs with Kernel = 3 and SAME padding 3%2=1 (K%2)
        self.stride = stride
        self.conv1 = nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size=3,  
                              stride = self.stride, padding=1)
     
        #second convolution doesn't have a stride size
        self.conv2 = nn.Conv2d(in_channels = out_channels, out_channels = out_channels, 
                               kernel_size=3, padding=1)
        
        # initialize batch normalization for each convolution
        # number of interesting features equal to the outcoming channels 
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        
        
    def forward(self, input):
        # Execute the required layers and functions
        
        # initialize skip connection
        skip = self.skip(input)
        
        # save forward pass of inner residual block in output 
        output = self.conv1(input)
        output = self.bn1(output)
        output = F.relu(output)
        output = self.conv2(output)
        output = self.bn2(output)
        output = output + skip
        output = F.relu(output)
        return output
        

Next we implement a stack of residual blocks for convenience. The first layer in the block is the one changing the number of channels and downsampling. 

In [4]:
class ResidualStack(nn.Module):
    """
    A stack of residual blocks.
    
    Args:
        in_channels: The number of channels (feature maps) of the incoming embedding
        out_channels: The number of channels after the first layer
        stride: Stride size of the first layer, used for downsampling
        num_blocks: Number of residual blocks
    """
    
    def __init__(self, in_channels, out_channels, stride, num_blocks):
        super().__init__()
        
        # Initialize the required layers (blocks)
        self.ResStack = nn.ModuleList()
        self.first_block = True
        self.stride = stride
        
        for i in range(num_blocks):
            if i == 0 and self.first_block==True:
                self.ResStack.append(ResidualBlock(in_channels, out_channels,
                self.stride))
                
                self.first_block = False
            else:
                self.ResStack.append(ResidualBlock(out_channels, out_channels))
      
        
        
    def forward(self, input):
        # Execute the layers (blocks)
        # ModuleList can act as an iterable, or be indexed using ints

        for i,l in enumerate(self.ResStack):
            input = l(input)
        
        return input
        

Now we implement the full model.

In [5]:
n = 5
num_classes = 10

# Implement ResNet via nn.Sequential
stride = 2 

# used the information about the filters and the 10-way fully connected layer
# to assume the input/output channels sizes after each residual stack
resnet = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(),
    ResidualStack(16,32, stride,n),
    ResidualStack(32,64, stride,n),
    ResidualStack(64,10, stride,n), 
    nn.AdaptiveAvgPool2d(1), #this is the size of the window, if set to 1 we are basically done with convolutions
    Lambda(lambda x: x.squeeze()), 
    nn.Softmax(),
    Lambda(lambda x: x.view(x.size(0), -1))
)

Next we initialize the weights of our model.

In [6]:
def initialize_weight(module):
    if isinstance(module, (nn.Linear, nn.Conv2d)):
        nn.init.kaiming_normal_(module.weight, nonlinearity='relu')
    elif isinstance(module, nn.BatchNorm2d):
        nn.init.constant_(module.weight, 1)
        nn.init.constant_(module.bias, 0)
        
resnet.apply(initialize_weight);

# Training
Now it is time to train the model.

First we load the data and split them into train and validation set, so that we only use the test set when we are completely done developing our model. 

In [7]:
class CIFAR10Subset(torchvision.datasets.CIFAR10):
    """
    Get a subset of the CIFAR10 dataset, according to the passed indices.
    """
    def __init__(self, *args, idx=None, **kwargs):
        super().__init__(*args, **kwargs)
        
        if idx is None:
            return
        
        self.data = self.data[idx]
        targets_np = np.array(self.targets)
        self.targets = targets_np[idx].tolist()

We next define transformations that change the images into PyTorch tensors, standardize the values according to the precomputed mean and standard deviation, and provide data augmentation for the training set.

In [8]:
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, 4),
    transforms.ToTensor(),
    normalize,
])
transform_eval = transforms.Compose([
    transforms.ToTensor(),
    normalize
])

In [9]:
ntrain = 45_000
train_set = CIFAR10Subset(root='./data', train=True, idx=range(ntrain),
                          download=True, transform=transform_train)
val_set = CIFAR10Subset(root='./data', train=True, idx=range(ntrain, 50_000),
                        download=True, transform=transform_eval)
test_set = torchvision.datasets.CIFAR10(root='./data', train=False,
                                        download=True, transform=transform_eval)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
Files already downloaded and verified


In [10]:
dataloaders = {}
dataloaders['train'] = torch.utils.data.DataLoader(train_set, batch_size=128,
                                                   shuffle=True, num_workers=2,
                                                   pin_memory=True)
dataloaders['val'] = torch.utils.data.DataLoader(val_set, batch_size=128,
                                                 shuffle=False, num_workers=2,
                                                 pin_memory=True)
dataloaders['test'] = torch.utils.data.DataLoader(test_set, batch_size=128,
                                                  shuffle=False, num_workers=2,
                                                  pin_memory=True)

Next we push the model to our GPU (if there is one).

In [11]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
resnet.to(device);

Next we define a helper method that does one epoch of training or evaluation. We have only defined training here, so you need to implement the necessary changes for evaluation!

In [12]:
def run_epoch(model, optimizer, dataloader, train):
    """
    Run one epoch of training or evaluation.
    
    Args:
        model: The model used for prediction
        optimizer: Optimization algorithm for the model
        dataloader: Dataloader providing the data to run our model on
        train: Whether this epoch is used for training or evaluation
        
    Returns:
        Loss and accuracy in this epoch.
    """
    
    device = next(model.parameters()).device
    
    # Set model to training mode (for e.g. batch normalization, dropout)
    if train is True:
        model.train()
    else:
        model.eval()

    epoch_loss = 0.0
    epoch_acc = 0.0

    # Iterate over data
    for xb, yb in dataloader:
        xb, yb = xb.to(device), yb.to(device)

        # zero the parameter gradients
        if train ==True:
            optimizer.zero_grad()

        # forward
        with torch.set_grad_enabled(True):
            pred = model(xb)
            loss = F.cross_entropy(pred, yb)
            
            top1 = torch.argmax(pred, dim=1)
            ncorrect = torch.sum(top1 == yb)
            
            if train ==True:

                loss.backward()
                optimizer.step()

        # statistics
        epoch_loss += loss.item()
        epoch_acc += ncorrect.item()
    
    epoch_loss /= len(dataloader.dataset)
    epoch_acc /= len(dataloader.dataset)
    return epoch_loss, epoch_acc

Next we implement a method for fitting (training) our model with early stopping.

In [13]:
def fit(model, optimizer, lr_scheduler, dataloaders, max_epochs, patience):
    """
    Fit the given model on the dataset.
    
    Args:
        model: The model used for prediction
        optimizer: Optimization algorithm for the model
        lr_scheduler: Learning rate scheduler that improves training
                      in late epochs with learning rate decay
        dataloaders: Dataloaders for training and validation
        max_epochs: Maximum number of epochs for training
        patience: Number of epochs to wait with early stopping the
                  training if validation loss has decreased
                  
    Returns:
        Loss and accuracy in this epoch.
    """
    
    best_acc = 0
    curr_patience = 0
    
    for epoch in range(max_epochs):
        train_loss, train_acc = run_epoch(model, optimizer, dataloaders['train'], train=True)
        lr_scheduler.step()
        print(f"Epoch {epoch + 1: >3}/{max_epochs}, train loss: {train_loss:.2e}, accuracy: {train_acc * 100:.2f}%")
        
        val_loss, val_acc = run_epoch(model, None, dataloaders['val'], train=False)
        print(f"Epoch {epoch + 1: >3}/{max_epochs}, val loss: {val_loss:.2e}, accuracy: {val_acc * 100:.2f}%")
        
        # Add early stopping and save the best weights (in best_model_weights)
        if np.isclose(best_acc,val_acc,0.5) is True:
            curr_patience += 1
        best_acc = val_acc
        best_model_weights = copy.deepcopy(model.state_dict())
        
        if curr_patience >= patience:
          #get best model weights but how
          break


    model.load_state_dict(best_model_weights)

In [14]:
optimizer = torch.optim.SGD(resnet.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1)

# Fit model
fit(resnet, optimizer, lr_scheduler, dataloaders, max_epochs=200, patience=50)

  input = module(input)


Epoch   1/200, train loss: 1.68e-02, accuracy: 29.89%
Epoch   1/200, val loss: 1.67e-02, accuracy: 37.28%
Epoch   2/200, train loss: 1.61e-02, accuracy: 39.86%
Epoch   2/200, val loss: 1.63e-02, accuracy: 42.88%
Epoch   3/200, train loss: 1.57e-02, accuracy: 44.40%
Epoch   3/200, val loss: 1.60e-02, accuracy: 45.96%
Epoch   4/200, train loss: 1.55e-02, accuracy: 47.60%
Epoch   4/200, val loss: 1.56e-02, accuracy: 51.00%
Epoch   5/200, train loss: 1.53e-02, accuracy: 49.47%
Epoch   5/200, val loss: 1.56e-02, accuracy: 49.86%
Epoch   6/200, train loss: 1.52e-02, accuracy: 51.43%
Epoch   6/200, val loss: 1.53e-02, accuracy: 54.54%
Epoch   7/200, train loss: 1.51e-02, accuracy: 52.79%
Epoch   7/200, val loss: 1.55e-02, accuracy: 51.66%
Epoch   8/200, train loss: 1.50e-02, accuracy: 54.29%
Epoch   8/200, val loss: 1.51e-02, accuracy: 57.16%
Epoch   9/200, train loss: 1.48e-02, accuracy: 56.06%
Epoch   9/200, val loss: 1.51e-02, accuracy: 58.16%
Epoch  10/200, train loss: 1.48e-02, accuracy:

Once the model is trained we run it on the test set to obtain our final accuracy.
Note that we can only look at the test set once, everything else would lead to overfitting. 

In [15]:
test_loss, test_acc = run_epoch(resnet, None, dataloaders['test'], train=False)
print(f"Test loss: {test_loss:.1e}, accuracy: {test_acc * 100:.2f}%")

  input = module(input)


Test loss: 1.3e-02, accuracy: 83.74%


That's almost what was reported in the paper (92.49%) and we didn't even train on the full training set.