# Practical Session 6

In [1]:
# Imports
import torch

from torch import nn
from torch.nn import functional as F
from torch import optim
from torchvision import datasets

# Matplot
import matplotlib.pyplot as plt

In [2]:
# Import data: CIFAR-10 training split as float tensors in (N, C, H, W) order,
# standardised with a single scalar mean / std taken over every value.
cifar_train_set = datasets.CIFAR10(root='./data/cifar10/', train=True, download=True)

# `.data` is stored channels-last (N, H, W, C); move the channel axis to position 1.
train_input = torch.from_numpy(cifar_train_set.data)
train_input = train_input.permute(0, 3, 1, 2).float()
train_targets = torch.tensor(cifar_train_set.targets, dtype=torch.int64)

# Global statistics, applied in place to avoid a second 50000x3x32x32 copy.
mu = train_input.mean()
std = train_input.std()
train_input.sub_(mu)
train_input.div_(std)

train_input.size()

Files already downloaded and verified


torch.Size([50000, 3, 32, 32])

## Modification of the ResNet implementation

Edit the implementations of `ResNet` and `ResNetBlock` so that two Boolean flags,
`skip_connections` and `batch_normalization`, control whether each of these features is enabled.

In [3]:
class ResNetBlock(nn.Module):
    """Residual block: conv -> [BN] -> ReLU -> conv -> [BN] -> [+ input] -> ReLU.

    Both convolutions keep the channel count, and with stride 1 and padding
    (kernel_size - 1) // 2 they keep the spatial size for odd kernel sizes,
    so the identity shortcut can be added directly.

    Args:
        nb_channels: number of input and output channels.
        kernel_size: side of the square convolution kernels.
        skip_connections: if True, add the block input before the final ReLU.
        batch_normalization: if True, apply BatchNorm2d after each convolution.
            The BatchNorm layers are created either way; they are simply
            bypassed in forward() when this flag is False.
    """

    def __init__(self, nb_channels, kernel_size, skip_connections=True, batch_normalization=True):
        super().__init__()

        self.skip_connections = skip_connections
        self.batch_normalization = batch_normalization

        pad = (kernel_size - 1) // 2

        # Registration order (conv1, bn1, conv2, bn2) is intentional: it fixes
        # the order of random initialisation and of the state_dict keys.
        self.conv1 = nn.Conv2d(nb_channels, nb_channels, kernel_size=kernel_size, padding=pad)
        self.bn1 = nn.BatchNorm2d(nb_channels)
        self.conv2 = nn.Conv2d(nb_channels, nb_channels, kernel_size=kernel_size, padding=pad)
        self.bn2 = nn.BatchNorm2d(nb_channels)

    def forward(self, x):
        """Apply the block to `x` and return a tensor of the same shape."""
        h = self.conv1(x)
        h = F.relu(self.bn1(h) if self.batch_normalization else h)

        h = self.conv2(h)
        if self.batch_normalization:
            h = self.bn2(h)

        # Shortcut is added before the final non-linearity.
        return F.relu(h + x if self.skip_connections else h)

In [4]:
class ResNet(nn.Module):
    """Small ResNet for 3-channel images with optional skip connections / batch norm.

    Architecture: conv -> [BN] -> ReLU -> `nb_residual_blocks` x ResNetBlock
    -> global average pooling -> linear classifier.

    Args:
        nb_residual_blocks: number of ResNetBlock modules stacked in sequence.
        nb_channels: number of feature channels used throughout the network.
        kernel_size: side of the square convolution kernels (odd sizes keep
            the spatial resolution unchanged).
        nb_classes: number of output logits.
        batch_normalization: if True, apply BatchNorm2d after the stem
            convolution and inside every residual block.
        skip_connections: if True, residual blocks add their identity shortcut.
    """

    def __init__(self, nb_residual_blocks, nb_channels,
                 kernel_size = 3, nb_classes = 10,
                 batch_normalization=True,
                 skip_connections=True):
        super().__init__()

        self.batch_normalization = batch_normalization
        self.skip_connections = skip_connections

        # Stem: map the 3 input channels to nb_channels feature maps.
        self.conv = nn.Conv2d(3, nb_channels,
                              kernel_size = kernel_size,
                              padding = (kernel_size - 1) // 2)
        # Created unconditionally; only used in forward() when batch_normalization is True.
        self.bn = nn.BatchNorm2d(nb_channels)

        self.resnet_blocks = nn.Sequential(
            *(ResNetBlock(nb_channels, kernel_size,
                          skip_connections=skip_connections,
                          batch_normalization=batch_normalization)
              for _ in range(nb_residual_blocks))
        )

        self.fc = nn.Linear(nb_channels, nb_classes)

    def forward(self, x, minibatch_size=32):
        """Return class logits of shape (batch, nb_classes).

        Args:
            x: image batch, presumably (batch, 3, H, W) -- the stem conv
                requires 3 input channels.
            minibatch_size: unused; kept only so existing callers keep working.
        """
        x = self.conv(x)
        if self.batch_normalization:
            x = self.bn(x)
        x = F.relu(x)
        x = self.resnet_blocks(x)
        # Global average pooling over the spatial dimensions. A fixed
        # avg_pool2d(x, 32) only works for 32x32 inputs; adaptive pooling gives
        # the same result there and also handles any other resolution.
        x = F.adaptive_avg_pool2d(x, 1).flatten(1)
        x = self.fc(x)
        return x

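## Monitoring the gradient norm

For a freshly initialised 30-block network, we back-propagate the loss of each of the first 100 training samples individually and record, for every residual block, the norm of the gradient with respect to the weights of its first convolution.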

In [10]:
def get_stats(skip_connections, batch_normalization,
              nb_residual_blocks=30, nb_channels=10, nb_samples=100,
              inputs=None, targets=None):
    """Per-block gradient norms of a freshly initialised ResNet.

    A new ResNet is built. Then, for each of the first `nb_samples` samples,
    the cross-entropy loss of that single sample is back-propagated and the
    norm of the gradient w.r.t. `conv1.weight` of every residual block is
    recorded. No optimiser step is taken, so every sample is evaluated at the
    same initial weights.

    Args:
        skip_connections: forwarded to ResNet.
        batch_normalization: forwarded to ResNet.
        nb_residual_blocks: depth of the network (default 30).
        nb_channels: channel width of the network (default 10).
        nb_samples: number of samples to process, one at a time (default 100).
        inputs: images to draw samples from; defaults to the notebook-level
            `train_input`.
        targets: matching class labels; defaults to `train_targets`.

    Returns:
        Tensor of shape (nb_residual_blocks, nb_samples); entry [d, i] is the
        gradient norm of block d for sample i.

    Raises:
        ValueError: if fewer than `nb_samples` samples are available.
    """
    if inputs is None:
        inputs = train_input
    if targets is None:
        targets = train_targets

    available = min(len(inputs), len(targets))
    if nb_samples > available:
        raise ValueError(
            f"nb_samples={nb_samples} exceeds the {available} available samples")

    # Create model
    model = ResNet(nb_residual_blocks=nb_residual_blocks, nb_channels=nb_channels, kernel_size=3,
                   skip_connections=skip_connections, batch_normalization=batch_normalization)

    # Loss criteria
    criterion = nn.CrossEntropyLoss()

    # Sized from the actual configuration rather than a hard-coded (30, 100).
    grad_norm = torch.zeros(nb_residual_blocks, nb_samples)

    for i in range(nb_samples):
        # Reset gradients so each column reflects a single sample only.
        model.zero_grad()

        # Slice (rather than index) to keep a batch dimension of size 1.
        image = inputs[i:i+1]
        target = targets[i:i+1]

        loss = criterion(model(image), target)
        loss.backward()

        grad_norm[:, i] = torch.tensor(
            [b.conv1.weight.grad.norm().item() for b in model.resnet_blocks])

    return grad_norm

## Average gradient norm versus depth

In [11]:
skip_connections = [True, False]
batch_normalization = [True, False]

# The models are randomly initialised; fix the seed so the curves are reproducible.
torch.manual_seed(0)

fig, ax = plt.subplots(figsize=(15, 10))
ax.set_yscale('log')

for sc in skip_connections:
    for bn in batch_normalization:
        # Gradient norms, shape (number of blocks, number of samples)
        grad_norm = get_stats(sc, bn)

        # Calculate the average over samples -> one value per block
        grad_mean = grad_norm.mean(dim=1)

        # Plot against 1-based block depth, derived from the data instead of hard-coded
        depth = len(grad_mean)
        ax.plot(range(1, depth + 1), grad_mean.numpy(), label = "skip={}, norm={}".format(sc, bn))

# Styling
ax.set_xlabel("depth")
ax.set_ylabel("average grad norm")
ax.set_xlim(1, depth)
ax.set_title("Average gradient norm of conv1.weight per residual block")
ax.legend()
plt.show()

tensor([[-2.4543,  2.2286,  0.4680,  1.6983, -2.7910, -1.2428,  3.5782, -0.6584,
          2.3081,  0.2355]], grad_fn=<AddmmBackward>)
tensor([[-2.5272,  2.4931,  0.5515,  1.9310, -2.9782, -1.2644,  3.7377, -0.5942,
          2.4302,  0.3174]], grad_fn=<AddmmBackward>)
tensor([[-2.4883,  2.6449,  0.5420,  1.8979, -2.9956, -1.2168,  3.6885, -0.6082,
          2.4468,  0.3590]], grad_fn=<AddmmBackward>)
tensor([[-2.3910,  2.2913,  0.4873,  1.7617, -2.7692, -1.0447,  3.4525, -0.5127,
          2.1491,  0.3367]], grad_fn=<AddmmBackward>)
tensor([[-2.5028,  2.5021,  0.4732,  1.8700, -2.9078, -1.2410,  3.6518, -0.6218,
          2.3220,  0.3117]], grad_fn=<AddmmBackward>)
tensor([[-2.5928,  2.3989,  0.5092,  1.9539, -2.9474, -1.2832,  3.7891, -0.6107,
          2.3410,  0.2483]], grad_fn=<AddmmBackward>)
tensor([[-2.5856,  2.5118,  0.4913,  1.9759, -2.9961, -1.3487,  3.8890, -0.6264,
          2.5003,  0.3243]], grad_fn=<AddmmBackward>)
tensor([[-2.6781,  2.5175,  0.4721,  1.9460, -2.9984, -

KeyboardInterrupt: 