# 05. Batch Normalization and Residual Blocks

<div style="margin:.3rem 0 1rem;font-size:.9em;color:#555;display:flex;align-items:center;gap:.35rem;font-family:monospace">
  <time datetime="2025-04-23">23 Mar 2025</time> /
  <time datetime="2026-02-23">23 Feb 2026</time> 
</div>

<a href="https://colab.research.google.com/github/shahaliyev/csci4701/blob/main/docs/notebooks/05_batchnorm_resnet.ipynb"
   target="_blank" rel="noopener">
  <img
    src="https://colab.research.google.com/assets/colab-badge.svg"
    alt="Open in Colab"
  />
</a>

<div class="admonition warning">
  <p class="admonition-title">Important</p>
  <p style="margin: 1em 0;">
    The notebook is currently under revision.
  </p>
</div>

Increasing the number of layers in neural networks to learn more advanced functions is challenging due to issues like [vanishing gradients](../04_regul_optim). [VGGNet](https://arxiv.org/pdf/1409.1556) partially addressed this problem by using repetitive _blocks_ that stack multiple convolutional layers before downsampling with max-pooling. For instance, two consecutive $3 \times 3$ convolutional layers achieve the same receptive field as a single $5 \times 5$ convolution, while using fewer parameters and adding an extra non-linearity between them. In simpler terms, repeating a smaller kernel allows the network to access the same input pixels while retaining more detail for subsequent processing. Larger kernels aggregate (blur) the image more aggressively in a single step, which can lead to the loss of important details and force the network to reduce resolution earlier in the architecture.
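To make the receptive-field comparison concrete, here is a minimal sketch that contrasts a stack of two $3 \times 3$ convolutions with a single $5 \times 5$ convolution. The channel count (`64`) and the input size are arbitrary choices for illustration, not values taken from the VGGNet paper, and `count_params` is a helper introduced only for this example.

```python
import torch
import torch.nn as nn

channels = 64  # arbitrary illustrative width (assumption)

# Two stacked 3x3 convolutions (no padding) with a non-linearity in between
stacked = nn.Sequential(
    nn.Conv2d(channels, channels, kernel_size=3, bias=False),
    nn.ReLU(),
    nn.Conv2d(channels, channels, kernel_size=3, bias=False),
)
# One 5x5 convolution covering the same 5x5 input window
single = nn.Conv2d(channels, channels, kernel_size=5, bias=False)

def count_params(module):
    # Hypothetical helper: total number of learnable values in a module
    return sum(p.numel() for p in module.parameters())

x = torch.randn(1, channels, 32, 32)
stacked_shape = tuple(stacked(x).shape)
single_shape = tuple(single(x).shape)

print(stacked_shape, single_shape)                  # both shrink 32 -> 28
print(count_params(stacked), count_params(single))  # 73728 102400
```

Both variants see the same $5 \times 5$ input window per output pixel, but the stack needs $2 \cdot 9 \cdot C^2$ weights instead of $25 \cdot C^2$.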

Despite this breakthrough, VGGNet was still limited and showed diminishing returns beyond 19 layers (hence, the VGG19 architecture). The [Inception](https://arxiv.org/pdf/1409.4842) architecture was introduced in a paper published the same year.<span class="fn"><span class="fn-body">It was named <em>Inception</em> because of the <a href='https://knowyourmeme.com/memes/we-need-to-go-deeper'>internet meme</a> from the famous <em>Inception</em> movie. If you don't believe this, scroll down to the references section of the paper and check out the very first reference.</span></span> Its implementation, the GoogLeNet model<span class="fn"><span class="fn-body">A play on words: GoogLeNet 1) was developed by Google researchers, and 2) pays homage to the <a href='../03_cnn_torch'>LeNet architecture</a>.</span></span>, significantly reduced the parameter count and leveraged the advantages of the $1 \times 1$ convolution kernel (see the [Network in Network](https://arxiv.org/pdf/1312.4400) paper, which also introduced the _Global Average Pooling (GAP)_ layer). Despite enabling deeper networks with far fewer parameters, Inception did not fully resolve the core training and convergence problems faced by very deep models.
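As a rough illustration of the two building blocks mentioned above, the sketch below applies a $1 \times 1$ convolution to reduce the number of channels and then global average pooling to collapse each feature map to a single value. The tensor sizes are made up for the example and do not correspond to any specific GoogLeNet layer.

```python
import torch
import torch.nn as nn

x = torch.randn(8, 256, 14, 14)   # (N, C, H, W), sizes chosen arbitrarily

# A 1x1 convolution mixes channels at each pixel without looking at neighbours
reduce = nn.Conv2d(256, 64, kernel_size=1)
reduced = reduce(x)

# Global average pooling collapses each feature map to a single number
gap = nn.AdaptiveAvgPool2d(1)
pooled = gap(reduced).flatten(1)

print(reduced.shape)  # torch.Size([8, 64, 14, 14])
print(pooled.shape)   # torch.Size([8, 64])
```

The spatial resolution is untouched by the $1 \times 1$ convolution; only the channel dimension changes, which is what makes it a cheap way to cut parameters in later layers.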

As a consequence, [Batch Normalization](https://arxiv.org/pdf/1502.03167) and [Residual Networks](https://arxiv.org/pdf/1512.03385) emerged as two major solutions for efficiently training neural networks with 100 layers or more. We will now set up the data environment and then discuss the core ideas and implementations of both papers. We already introduced the [CIFAR dataset](https://www.cs.toronto.edu/~kriz/cifar.html) in the previous notebook.

In [19]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

DATA_PATH = './data'
BATCH_SIZE = 32

cifar_mean = (0.4914, 0.4822, 0.4465)
cifar_std  = (0.2470, 0.2435, 0.2616)

train_tfms = transforms.Compose([
  transforms.RandomCrop(32, padding=4),
  transforms.RandomHorizontalFlip(),
  transforms.ToTensor(),
  transforms.Normalize(cifar_mean, cifar_std),
])

test_tfms = transforms.Compose([
  transforms.ToTensor(),
  transforms.Normalize(cifar_mean, cifar_std),
])

train_data = datasets.CIFAR10(root=DATA_PATH, train=True,  download=True, transform=train_tfms)
test_data  = datasets.CIFAR10(root=DATA_PATH, train=False, download=True, transform=test_tfms)

train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True,  pin_memory=True, num_workers=2)
test_loader  = DataLoader(test_data,  batch_size=BATCH_SIZE, shuffle=False, pin_memory=True, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


<div class="admonition note">
  <p class="admonition-title">Note</p>
  <p style="margin: 1em 0;"> 
    From machine learning, we know that it is encouraged to split the data into <em>training</em>, <em>validation</em> (also called <em>dev</em>), and <em>test</em> sets. When the dataset is not large, an <code>80 : 10 : 10</code> split is a reasonable ratio for allocation. For larger datasets (e.g. with one million images), it is fine to allocate 90% or more of your data for training. The training set is used to update the model's parameters. The validation set is used for tuning hyperparameters (e.g. testing different learning rates, regularization strengths, etc.). The test split should ideally be used only <em>once</em> to report the final performance of the selected model (e.g. for inclusion in a research paper).
  </p>
</div>
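One possible way to carve a validation set out of the training data is `torch.utils.data.random_split`. The sketch below uses a small stand-in `TensorDataset` with random values instead of CIFAR so that it runs without downloads; the 90 : 10 ratio and the seed are assumptions made for the example.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Stand-in dataset: 1,000 fake "images" with labels (not real CIFAR data)
images = torch.randn(1000, 3, 32, 32)
labels = torch.randint(0, 10, (1000,))
full_train = TensorDataset(images, labels)

# Hold out 10% of the training data for validation
val_size = len(full_train) // 10
train_size = len(full_train) - val_size
generator = torch.Generator().manual_seed(0)  # fixed seed for a reproducible split
train_subset, val_subset = random_split(
    full_train, [train_size, val_size], generator=generator
)

print(len(train_subset), len(val_subset))  # 900 100
```

The same call could be applied to `train_data` from the cell above. Note, however, that both subsets would then share the training transforms, so the validation images would also be randomly cropped and flipped.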

## Batch Normalization

<div class="admonition info">
  <p class="admonition-title">Info</p>
  <p style="margin: 0.5em 0;">
    The following source was consulted in preparing this material: Zhang, A., Lipton, Z. C., Li, M., & Smola, A. J. <a href="https://d2l.ai/">Dive into Deep Learning</a>. Cambridge University Press. <a href='https://d2l.ai/chapter_convolutional-modern/batch-norm.html'>Chapter 8.5: Batch Normalization</a>
  </p>
</div>

Batch normalization standardizes the hidden layer activations of a neural network during training. Instead of allowing the distribution of activations to vary freely from batch to batch, the layer normalizes them using statistics computed from the current mini-batch.

<div class="admonition note">
  <p class="admonition-title">Note</p>
  <p style="margin: 1em 0;">
  The terms <em>batch</em> and <em>mini-batch</em> are often used interchangeably in deep learning, although they are not exactly the same. In the strict sense, a <em>batch</em> refers to the entire training dataset processed in a single update of the model parameters. A <em>mini-batch</em> refers to a smaller subset of the dataset processed together before computing the gradient and updating the model. In practice, most deep learning libraries use the word <em>batch</em> to mean <em>mini-batch</em>. For instance, the parameter <code>batch_size</code> in PyTorch specifies how many samples are processed together in one forward and backward pass, not the entire dataset.
  
  For example, if a dataset contains 50,000 training images and the batch size is set to 128, the model will process 128 images at a time and update the parameters after each group. In this case, the algorithm performs many updates during one pass through the dataset (one epoch). This approach is called <em>mini-batch gradient descent</em>.
  </p>
</div>
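The arithmetic behind the example in the note can be checked in a few lines. The sketch assumes that the last, smaller group of images is kept rather than dropped, which is the default behaviour of a PyTorch `DataLoader` (`drop_last=False`).

```python
import math

num_samples = 50_000
batch_size = 128

# Number of parameter updates in one epoch (the final batch is smaller)
updates_per_epoch = math.ceil(num_samples / batch_size)
last_batch_size = num_samples - (updates_per_epoch - 1) * batch_size

print(updates_per_epoch)  # 391
print(last_batch_size)    # 80
```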

For each feature value $x_i$ in the mini-batch, we compute the batch mean $\mu_B$ and batch variance $\sigma_B^2$, and normalize the value by subtracting the mean and dividing by the standard deviation. A small constant $\epsilon$ is added for numerical stability so that the denominator never becomes zero:

$$
\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}
$$

After this transformation, the normalized values $\hat{x}_i$ have a mean close to $0$ and a variance close to $1$ within that mini-batch. If this normalization were applied alone, it could restrict the representational flexibility of the network. To allow the model to learn the appropriate scale and offset of the activations, batch normalization introduces two learnable parameters: a scaling parameter $\gamma$ and a shifting parameter $\beta$.

$$
\textrm{BN}(x_i) = \gamma \hat{x}_i + \beta
$$

These parameters are learned together with the rest of the model during training. If necessary, the network can recover the original distribution of activations by choosing appropriate values for $\gamma$ and $\beta$.
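The two formulas can be traced by hand on a single feature. The following sketch uses made-up values for one feature across a mini-batch of six examples and shows that choosing $\gamma = \sqrt{\sigma_B^2 + \epsilon}$ and $\beta = \mu_B$ undoes the normalization, which is the sense in which the network can recover the original distribution.

```python
import torch

eps = 1e-5

# One feature observed across a mini-batch of 6 examples (made-up values)
x = torch.tensor([2.0, 4.0, 4.0, 4.0, 5.0, 7.0])

mu = x.mean()
var = x.var(unbiased=False)               # population variance, as in the formula
x_hat = (x - mu) / torch.sqrt(var + eps)  # standardized values

# gamma = 1, beta = 0 would leave x_hat unchanged;
# gamma = sqrt(var + eps), beta = mu maps x_hat back to the original values
gamma = torch.sqrt(var + eps)
beta = mu
recovered = gamma * x_hat + beta

print(x_hat.mean(), x_hat.var(unbiased=False))  # close to 0 and 1
print(recovered)                                # close to the original x
```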

<div class="admonition note">
  <p class="admonition-title">Note</p>
  <p style="margin: 1em 0;">
  When a mini-batch passes through the network, each neuron produces one value (feature) for every example. Batch normalization looks at these values together. The layer first computes the average value of that feature in the mini-batch and measures how much the values vary. It then shifts and rescales them so that they stay in a similar numerical range. This keeps the activations stable while the network is learning.
  
  If the process stopped here, every feature would always remain normalized, which could make the network too restrictive. Neural networks need the ability to adjust how strongly signals pass through a layer. For this reason batch normalization adds two learnable parameters. One allows the network to stretch the values and the other allows it to shift them. During training the model learns how much normalization is useful and how much it should modify it.
  </p>
</div>

Batch normalization is typically placed after the [affine transformation](../../mathematics/02_linear_algebra) of a layer and before the non-linear activation function. In other words, the linear mapping $Wx + b$ is applied first, the result is normalized, and only then is the activation function $\phi$ evaluated:

$$
z = \phi(\textrm{BN}(Wx + b)).
$$

<div class="admonition warning">
  <p class="admonition-title">Important</p>
  <p style="margin: 1em 0;">
    Note that the bias term $b$ is often omitted when batch normalization is used. The reason is that subtracting the batch mean cancels any constant offset, and the shifting role of the bias is already provided by the learnable parameter $\beta$ of the batch normalization. In practice, many implementations therefore disable the bias parameter of the preceding layer with <code>bias=False</code>.
  </p>
</div>
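The claim in the note above can be checked numerically. The sketch below feeds the same random input through a convolution followed by batch normalization, once with a deliberately large bias and once with the bias set to zero, and then builds the usual bias-free ordering of affine transformation, normalization, and activation. All sizes are arbitrary choices for the example.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(16, 3, 8, 8)

conv = nn.Conv2d(3, 4, kernel_size=3, padding=1, bias=True)
bn = nn.BatchNorm2d(4)   # modules start in training mode, so batch statistics are used

with torch.no_grad():
    conv.bias.fill_(5.0)          # a deliberately large bias
    out_with_bias = bn(conv(x))
    conv.bias.zero_()             # same weights, bias removed
    out_without_bias = bn(conv(x))

max_diff = (out_with_bias - out_without_bias).abs().max().item()
print(max_diff)  # tiny: the per-channel mean subtraction cancels the bias

# The usual pattern therefore drops the bias: affine -> BN -> activation
block = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(4),
    nn.ReLU(),
)
print(block(x).shape)  # torch.Size([16, 4, 8, 8])
```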

Training very deep neural networks is difficult because the scale of activations can change significantly from layer to layer during learning. As parameters are updated, the distribution of intermediate activations also shifts. Each layer must continuously adapt to these changes, which slows down optimization and can make training unstable.

We had already seen parameter initialization in our previous notebook. Methods such as He initialization choose weight variances so that signals neither explode nor vanish as they propagate through the network. While these techniques help at the start of training, the distributions of activations can still drift as learning progresses. Batch normalization stabilizes these intermediate activations during training.

Keeping activations within a predictable numerical range makes gradient-based optimization more reliable. Normalization reduces the risk of exploding or vanishing gradients, allowing deeper networks to train effectively. Because the scale of inputs to each layer is controlled, larger learning rates can often be used. Since the statistics are computed from a random mini-batch, a small amount of noise is introduced into the activations, which can improve generalization.

<div class="admonition note">
  <p class="admonition-title">Note</p>
  <p style="margin: 1em 0;">
  The original paper argued that the method improves training by reducing <strong>internal covariate shift</strong>. This term refers to the phenomenon where the distribution of activations inside a network changes as the parameters of earlier layers are updated. If the input distribution to a layer keeps shifting, the layer must constantly adapt, which can slow down learning.
  
  Later research has suggested that the primary benefit of batch normalization may not be the reduction of this distribution shift itself. Instead, many studies indicate that normalization improves the geometry of the optimization problem, producing smoother loss surfaces and better-conditioned gradients. This makes gradient-based optimization more stable and allows larger learning rates. As a result, the exact mechanism behind the success of batch normalization is still discussed in the literature, although its practical effectiveness is well established.
  </p>
</div>


We will implement a batch normalization function and compare it with the [`torch.nn.BatchNorm2d`](https://docs.pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html#torch.nn.BatchNorm2d) module of PyTorch.

In [None]:
import torch

def BatchNorm2d(X, gamma=None, beta=None, eps=1e-5):
    """
    Batch normalization for input (N, C, H, W).
    Statistics are computed per-channel over (N, H, W).
    Variance uses the population estimate (divide by N, not N-1).
    """
    mean = X.mean(dim=(0, 2, 3), keepdim=True)
    var = X.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
    X_hat = (X - mean) / torch.sqrt(var + eps)
    if gamma is not None and beta is not None:
        X_hat = gamma * X_hat + beta
    return X_hat

We will now apply this batch normalization layer to the [`vgg11`](https://docs.pytorch.org/vision/main/models/generated/torchvision.models.vgg11.html) model of PyTorch, which is the smallest of the VGG architectures. Note that PyTorch also has a [`vgg11_bn`](https://docs.pytorch.org/vision/main/models/generated/torchvision.models.vgg11_bn.html) implementation of the same model, which applies batch normalization internally.


<div class="admonition success">
  <p class="admonition-title">Exercise</p>
  <p style="margin: 1em 0;">
    Use <code>vgg11_bn</code> model and explore its features.
  </p>
</div>

In [32]:
from torchvision import models

model = models.vgg11(weights=None)
model.features[:6]

Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (4): ReLU(inplace=True)
  (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)

<div class="admonition warning">
  <p class="admonition-title">Important</p>
  <p style="margin: 1em 0;">
    The terms <em>activation</em>, <em>feature</em>, and <em>feature map</em> are often used interchangeably. An <em>activation</em> is the numerical output of a layer, while a <em>feature</em> refers to the same values interpreted as learned representations useful for a task. For example, <code>model.features</code> returns intermediate layer outputs that can be described both as activations and feature maps.
  </p>
</div>

In [42]:
X, _ = next(iter(train_loader))

# obtaining the layer activations
with torch.no_grad():
    A = model.features[:3](X) 

A.shape

torch.Size([32, 64, 16, 16])

In [43]:
gamma = torch.ones(1, A.shape[1], 1, 1)
beta  = torch.zeros(1, A.shape[1], 1, 1)

BN = BatchNorm2d(A, gamma, beta)

BN.shape

torch.Size([32, 64, 16, 16])

In [46]:
f"Input stats: {A.mean((0,2,3))[:3]} {A.var((0,2,3), unbiased=False)[:3]}"

'Input stats: tensor([0.1193, 0.0813, 0.0950]) tensor([0.0292, 0.0169, 0.0186])'

In [47]:
f"Input stats after BN: {BN.mean((0,2,3))[:3]}, {BN.var((0,2,3), unbiased=False)[:3]}"

'Input stats after BN: tensor([-6.7055e-08, -9.4995e-08, -7.9162e-08]), tensor([0.9997, 0.9994, 0.9995])'

In [48]:
import torch.nn as nn

BN_torch = nn.BatchNorm2d(A.shape[1], affine=False)(A)

f"Input stats after PyTorch BN: {BN_torch.mean((0,2,3))[:3]}, {BN_torch.var((0,2,3), unbiased=False)[:3]}"

'Input stats after PyTorch BN: tensor([ 6.5193e-09, -5.5879e-09,  1.5832e-08]), tensor([0.9997, 0.9994, 0.9995])'

## Running Statistics in Batch Normalization

During training, batch normalization computes the mean and variance from the current mini-batch. These statistics are then used to normalize the activations. However, during inference the model may process a single example or a batch that does not represent the full training distribution. If normalization relied only on the current batch, the output of the network could change depending on which samples appear together.

For this reason, batch normalization layers maintain _running estimates_ of the mean and variance observed during training. These estimates approximate the statistics of the full dataset and are used during evaluation. When a model is switched to evaluation mode (for example with `model.eval()` in PyTorch), the stored **running statistics** are used instead of the statistics of the current batch. This ensures stable and deterministic predictions.

PyTorch implementations automatically maintain running statistics (controlled by the `track_running_stats` argument). The global mean and variance values are updated during training using an exponential moving average:

$$
\mu_{\text{running}} = (1 - \alpha)\,\mu_{\text{running}} + \alpha\,\mu_{\text{batch}}
$$

$$
\sigma^2_{\text{running}} = (1 - \alpha)\,\sigma^2_{\text{running}} + \alpha\,\sigma^2_{\text{batch}}.
$$


<div class="admonition note">
  <p class="admonition-title">Note</p>
  <p style="margin: 1em 0;">
  An <a href='https://en.wikipedia.org/wiki/Exponential_smoothing'>Exponential Moving Average (EMA)</a> is a method for smoothing a sequence of values over time. Instead of keeping all past observations, EMA maintains a running estimate that is updated whenever a new value is observed. If $x_t$ is the new value at step $t$ and $m_{t-1}$ is the previous estimate, the updated value is
  $$
  m_t = (1 - \alpha)m_{t-1} + \alpha x_t
  $$
  where $0 < \alpha \le 1$ controls how quickly the estimate reacts to new data. A larger $\alpha$ makes the average respond more strongly to recent values, while a smaller $\alpha$ produces a smoother estimate that changes more slowly. Because older values are repeatedly multiplied by $1-\alpha$, their influence decays exponentially over time, which is why the method is called an exponential moving average.
  </p>
</div>
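The EMA update from the note can be written as a short plain-Python function. The function name `ema`, the starting estimate of `0.0`, and the constant input sequence are assumptions chosen for illustration.

```python
def ema(values, alpha, initial=0.0):
    """Return the estimates m_1..m_T for m_t = (1 - alpha) * m_{t-1} + alpha * x_t."""
    estimate = initial
    history = []
    for x in values:
        estimate = (1 - alpha) * estimate + alpha * x
        history.append(estimate)
    return history

values = [1.0, 1.0, 1.0, 1.0]
slow = ema(values, alpha=0.1)   # reacts slowly to the new level
fast = ema(values, alpha=0.5)   # reacts quickly to the new level

print(slow)  # approximately [0.1, 0.19, 0.271, 0.3439]
print(fast)  # [0.5, 0.75, 0.875, 0.9375]
```

With the smaller $\alpha$, the estimate is still far from the true level of `1.0` after four steps, which mirrors how running statistics need many batches to settle.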

In PyTorch, the parameter $\alpha$ controlling this update is called **momentum**. Despite the name, it is unrelated to the momentum used in [optimization](../04_regul_optim) algorithms. Instead, it determines how quickly the running statistics adapt to new batches. A larger value updates the statistics more aggressively using recent batches. A smaller value averages information over a longer history of batches. The default value in PyTorch is `0.1`. During evaluation, batch normalization uses these stored statistics instead of recomputing them from the input batch.

In [None]:
def BatchNorm2d(X, running_mean, running_var, training=True, momentum=0.1, eps=1e-5):
  if training:
    mean = X.mean((0,2,3))
    var = X.var((0,2,3), unbiased=False)
    running_mean = (1 - momentum) * running_mean + momentum * mean
    running_var  = (1 - momentum) * running_var  + momentum * var
  else:
    mean = running_mean
    var = running_var
    
  X_hat = (X - mean[None,:,None,None]) / torch.sqrt(var[None,:,None,None] + eps)
  return X_hat, running_mean, running_var

In [52]:
running_mean = torch.zeros(X.shape[1])
running_var = torch.ones(X.shape[1])

for step, (X, _) in enumerate(train_loader):
  batch_mean = X.mean((0,2,3))
  _, running_mean, running_var = BatchNorm2d(X, running_mean, running_var, training=True, momentum=0.1)
  print(step + 1, "batch:", batch_mean.tolist(), "running:", running_mean.tolist())
  if step == 4:
    break

1 batch: [-0.24944937229156494, -0.2408471703529358, -0.2193523496389389] running: [-0.024944936856627464, -0.02408471703529358, -0.02193523570895195]
2 batch: [-0.19832071661949158, -0.2023598849773407, -0.23762772977352142] running: [-0.04228251427412033, -0.04191223531961441, -0.04350448399782181]
3 batch: [-0.32184067368507385, -0.32998546957969666, -0.25827017426490784] running: [-0.07023832947015762, -0.07071955502033234, -0.06498105078935623]
4 batch: [-0.2579638659954071, -0.29232633113861084, -0.2364785522222519] running: [-0.08901087939739227, -0.0928802341222763, -0.0821308046579361]
5 batch: [-0.25047898292541504, -0.26008734107017517, -0.24105609953403473] running: [-0.10515768826007843, -0.10960093885660172, -0.09802333265542984]


In [None]:
X, _ = next(iter(test_loader))

y_train, _, _ = BatchNorm2d(X, running_mean, running_var, training=True)
y_eval,  _, _ = BatchNorm2d(X, running_mean, running_var, training=False)

print("train mean:", y_train.mean((0,2,3)))
print("eval  mean:", y_eval.mean((0,2,3)))

train mean: tensor([-2.7940e-08, -6.5193e-09,  7.4506e-09])
eval  mean: tensor([0.0746, 0.0722, 0.0453])


In [None]:
batch_mean = X.mean((0,2,3))
print("batch mean :", batch_mean)
print("running mean:", running_mean)

batch mean : tensor([-0.0247, -0.0320, -0.0507])
running mean: tensor([-0.1052, -0.1096, -0.0980])
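For comparison, the built-in `nn.BatchNorm2d` module shows the same difference between training and evaluation mode. The sketch below uses a synthetic batch with a mean of roughly 5 (not CIFAR data) so that the gap between batch statistics and barely updated running statistics is easy to see.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(3, momentum=0.1)      # running_mean starts at 0, running_var at 1
x = torch.randn(32, 3, 8, 8) * 2.0 + 5.0  # synthetic batch with mean ~5, std ~2

bn.train()
y_train = bn(x)                            # uses batch statistics, updates running stats
expected_running_mean = 0.1 * x.mean((0, 2, 3))

bn.eval()
y_eval = bn(x)                             # uses the stored running statistics

print(bn.running_mean)                     # roughly 0.5 per channel after one update
print(y_train.mean((0, 2, 3)))             # close to 0
print(y_eval.mean((0, 2, 3)))              # far from 0: running stats have not converged
```

One difference from the hand-written function above is that PyTorch stores the unbiased batch variance in `running_var`, so the two variance estimates can differ slightly.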


### Layer Normalization

A rule of thumb is that batch sizes between `50` and `100` generally work well for batch normalization: the batch is large enough to return reliable statistics but not so large that it causes memory issues or slows down training unnecessarily. A batch size of `32` is usually the lower bound where batch normalization still provides relatively stable estimates. A batch size of `128` is also effective if the hardware allows, and can produce even smoother estimates. Beyond that, the benefit often diminishes.

If the batch size is very small due to memory limitations, batch normalization may lose its effectiveness. In such cases, it's better to consider alternatives like [Layer Normalization](https://arxiv.org/abs/1607.06450), which does not depend on the batch dimension.

Layer normalization normalizes across the features of each individual sample, not across the batch, and works well for _transformers_, where batch sizes may be small or variable. Basically, batch normalization depends on the batch, but layer normalization does not.

Furthermore, in fully connected layers, each feature is just a single number per sample, so batch normalization computes the mean and variance across the batch for each feature. Fully connected layers don't have spatial structure, so there's nothing to average across except the batch. In convolutional layers, each feature channel has a height and a width and is a 2D map (hence, `nn.BatchNorm2d`), so batch normalization uses not just the batch dimension, but also all the spatial positions to compute statistics. This gives more stable estimates because there are more values per channel.
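The difference in normalization axes can be seen on a small fully connected example. The sketch below uses synthetic data of shape `(batch, features)` and disables the learnable scale and shift in both layers so that only the normalization itself is visible.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(4, 6) * 3.0 + 1.0   # (batch, features), synthetic values

batch_norm = nn.BatchNorm1d(6, affine=False)            # statistics over the batch, per feature
layer_norm = nn.LayerNorm(6, elementwise_affine=False)  # statistics over features, per sample

bn_out = batch_norm(x)
ln_out = layer_norm(x)

print(bn_out.mean(dim=0))  # close to 0 for every feature (column)
print(ln_out.mean(dim=1))  # close to 0 for every sample (row)

# Layer normalization of one sample does not depend on the rest of the batch
single = layer_norm(x[:1])
print(torch.allclose(single, ln_out[:1], atol=1e-6))
```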

## Residual Block

**Residual Network (ResNet)** consists of repeated _residual blocks_, in the style of the VGGNet architecture. Each residual block contains a _residual (skip/shortcut) connection_. We will first see what it does and then attempt to understand the reasoning behind this simple breakthrough idea.

### Implementation

![Residual Block](https://d2l.ai/_images/residual-block.svg)

_Figure 8.6.2_ of [Dive into Deep Learning (Chapter 8)](https://d2l.ai/chapter_convolutional-modern/resnet.html) by [d2l.ai](https://d2l.ai/) authors and contributors. Licensed under [Apache 2.0](https://www.apache.org/licenses/LICENSE-2.0)

The idea of the residual connection is very simple. Before the second activation function, we add the block's input to the output of the second affine transformation. You can imagine the simplified code as below:

In [13]:
def residual_block(X):
  act = torch.relu(X @ params['W1'] + params['b1'])
  out = act @ params['W2'] + params['b2']
  return torch.relu(out + X)

However, if we attempt to directly run the code above, we will see a shape mismatch, as our final layer returns a matrix of dimension `VOCAB_SIZE`, which is not equal to the input dimension `BLOCK_SIZE * EMBED_SIZE`.

**Exercise:** Modify the `forward` function by adding a residual connection.

In [14]:
def forward(X, params, batch_norm=False, bn_stats=None, residual=True):
  emb = params['C'][X]
  out = emb.view(emb.shape[0], -1) @ params['W1'] + params['b1']

  if batch_norm:
    mean, std = bn_stats if bn_stats else (out.mean(0, keepdim=True), out.std(0, keepdim=True) + 1e-5)
    out = (out - mean) / std
    out = params['gamma'] * out + params['beta']

  act = torch.tanh(out + emb) if residual else torch.tanh(out)
  logits = act @ params['W2'] + params['b2']
  return logits

In [15]:
X = params['C'][X_train].view(X_train.shape[0], -1)
X.shape

torch.Size([182535, 30])

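What to do? The flattened embedding has 30 features, which matches neither the hidden layer size nor the output size, so the shortcut cannot be added directly. For demonstration purposes, we will have to add another layer that maps back to the input dimension.

**Exercise (Advanced)**: Train a three-layer model with batch normalization and residual connections.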

In [16]:
def get_params(batch_norm=True):
  C = torch.randn((VOCAB_SIZE, EMBED_SIZE), requires_grad=True)

  in_features = BLOCK_SIZE * EMBED_SIZE

  W1 = torch.randn((in_features, LAYER_SIZE), requires_grad = True)
  b1 = torch.zeros(LAYER_SIZE, requires_grad=True)

  W2 = torch.randn((LAYER_SIZE, in_features), requires_grad = True)
  b2 = torch.zeros(in_features, requires_grad=True)

  W3 = torch.randn((in_features, VOCAB_SIZE), requires_grad = True)
  b3 = torch.zeros(VOCAB_SIZE, requires_grad=True)

  params = {'C': C, 'W1': W1, 'b1': b1, 'W2': W2, 'b2': b2, 'W3': W3, 'b3': b3}

  if batch_norm:
    gamma = torch.ones((1, LAYER_SIZE), requires_grad=True)
    beta = torch.zeros((1, LAYER_SIZE), requires_grad=True)
    params['gamma'] = gamma
    params['beta'] = beta

  return params

In [17]:
def forward(X, params, batch_norm=False, bn_stats=None, residual=True):
  emb = params['C'][X].view(X.shape[0], -1)
  out = emb @ params['W1'] + params['b1']

  if batch_norm:
    mean, std = bn_stats if bn_stats else (out.mean(0, keepdim=True), out.std(0, keepdim=True) + 1e-5)
    out = (out - mean) / std
    out = params['gamma'] * out + params['beta']

  act = torch.relu(out)
  out2 = act @ params['W2'] + params['b2']

  if residual:
    out2 = out2 + emb

  logits = torch.tanh(out2) @ params['W3'] + params['b3']
  return logits

In [18]:
params = get_params()
params.keys()

dict_keys(['C', 'W1', 'b1', 'W2', 'b2', 'W3', 'b3', 'gamma', 'beta'])

In [19]:
# we are using relu in intermediate layer
if init:
  nn.init.kaiming_uniform_(params['W1'])
  nn.init.kaiming_uniform_(params['W2']);

In [20]:
train(X_train, Y_train, params, num_epochs=epochs, lr=lr, batch_size=batch_size, batch_norm=batch_norm)

Epoch 1000, Loss: 2.5227
Epoch 2000, Loss: 2.9970
Epoch 3000, Loss: 2.5845
Epoch 4000, Loss: 2.3321
Epoch 5000, Loss: 2.2630
Epoch 6000, Loss: 2.5062
Epoch 7000, Loss: 2.8853
Epoch 8000, Loss: 2.3080
Epoch 9000, Loss: 2.7023
Epoch 10000, Loss: 2.8854


In [21]:
bn_stats = get_bn_stats(X_train, params) if batch_norm else None

print('Train and Validation losses:')
evaluate(X_train, Y_train, params, batch_norm=batch_norm, bn_stats=bn_stats)
evaluate(X_val, Y_val, params, batch_norm=batch_norm, bn_stats=bn_stats)

Train and Validation losses:
Loss: 2.3690
Loss: 2.3698


### Reasoning

As our model implements only a single residual block, we don't see any performance improvement. However, similar to batch normalization, the advantages become apparent with 50 layers or more, built from repeated residual blocks. But why does adding the input of the block to the second affine transformation boost training?

Let's take any deep learning model. The types of functions this model can learn depend on its design (e.g. number of layers, activation functions, etc.). We can denote all these possible functions as the class $\mathcal{F}$. If we cannot learn a perfect function for our data, which is usually the case, we can at least try to approximate this function as closely as possible by minimizing a loss. We may assume that a more powerful model can learn more types of functions and show better performance. But that's not always the case. To achieve better performance than a simpler model, our model must be capable of learning not only more functions but also all the functions the simpler model can learn. Simply put, the function class $\mathcal{F}'$ of the more powerful model should be a superset of the simpler model's function class: $\mathcal{F} \subseteq \mathcal{F}'$. If $\mathcal{F}'$ isn't an expanded version of $\mathcal{F}$, the new model might actually learn a function that is farther from the truth, and even show worse performance.

Refer to the figure above, where our residual output is $f(x) = g(x) + x$. One advantage of residual blocks is their regularization effect. What if some activation nodes in our network are unnecessary and increase complexity or learn bad representations? Instead of learning useful weights and biases for them, our residual block can now learn an identity function $f(x) = x$ by simply setting those nodes' parameters to zero, so that $g(x) = 0$. As a result, our inputs will propagate faster while ensuring that the learned functions are within the biggest function domain. Residual blocks not only act as a regularizer, but also, unlike, say, _dropout_, which stops input from propagating, allow the network to learn more functions by helping inputs to "jump over" (skip) the nodes. And it is very important that the function class of the model with residual blocks is a superset of that of the same model without such blocks. Finally, along the way, it mitigates the vanishing gradient problem: because the derivative of $g(x) + x$ with respect to $x$ is $g'(x) + 1$, the shortcut gives gradients a direct path to earlier layers even when $g'(x)$ is small. To sum up, the residual connection allows the model to learn more complex functions while allowing it to easily learn simpler ones, which tackles the vanishing gradient problem and has a regularizing effect.
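Both properties, the easy identity mapping and the direct gradient path, can be checked on a toy block. `TinyResidual` below is a hypothetical module written for this illustration; for simplicity it omits the activation that the figure applies after the addition.

```python
import torch
import torch.nn as nn

class TinyResidual(nn.Module):
    """f(x) = g(x) + x with a small two-layer g (hypothetical, for illustration)."""
    def __init__(self, dim):
        super().__init__()
        self.fc1 = nn.Linear(dim, dim)
        self.fc2 = nn.Linear(dim, dim)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x))) + x

torch.manual_seed(0)
block = TinyResidual(5)

# Setting the last layer of g to zero makes g(x) = 0, so the block is the identity
nn.init.zeros_(block.fc2.weight)
nn.init.zeros_(block.fc2.bias)

x = torch.randn(3, 5, requires_grad=True)
out = block(x)
out.sum().backward()

print(torch.equal(out, x))  # True
print(x.grad)               # all ones: the gradient passes through the shortcut unchanged
```

Without the `+ x` term, the same zeroed layer would output zeros and pass no gradient back to the input at all.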

## Residual Network for NLP in PyTorch

Originally, the complete Residual Network was developed for image classification tasks, winning the _ImageNet_ competition. Each of its residual blocks consisted of two `3x3` convolutions (inspired by _VGGNet_), both integrating batch normalization, followed by a skip connection. Even though the ResNet model relies on convolutional layers, the concept of residual connections has been adapted for NLP models as well. The famous **Transformer** model, introduced in the paper titled [Attention is All You Need](https://arxiv.org/pdf/1706.03762), incorporates residual connections heavily in its design, in a way very similar to ResNet.
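Before turning to the NLP variant, the convolutional block described above can be sketched roughly as follows. `ConvResidualBlock` is a hypothetical name, and the block keeps the number of channels and the resolution fixed so that the identity shortcut can be added directly. The original ResNet also contains blocks that change these dimensions, which are not covered here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvResidualBlock(nn.Module):
    """Two 3x3 convolutions with batch normalization and an identity shortcut."""
    def __init__(self, channels):
        super().__init__()
        # bias=False because each convolution is followed by batch normalization
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.relu(out + x)   # add the input before the second activation

block = ConvResidualBlock(16)
x = torch.randn(2, 16, 32, 32)
y = block(x)
print(y.shape)  # torch.Size([2, 16, 32, 32])
```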

In [22]:
import torch.nn.functional as F

class ResidualBlock(nn.Module):
  def __init__(self):
    super().__init__()
    self.emb = nn.Embedding(VOCAB_SIZE, EMBED_SIZE)
    self.fc1 = nn.Linear(in_features=EMBED_SIZE, out_features=LAYER_SIZE, bias=False)
    self.fc2 = nn.Linear(in_features=LAYER_SIZE, out_features=EMBED_SIZE, bias=False)
    self.fc3 = nn.Linear(in_features=EMBED_SIZE, out_features=VOCAB_SIZE, bias=True)
    self.bn1 = nn.LazyBatchNorm1d()
    self.bn2 = nn.LazyBatchNorm1d()
    nn.init.kaiming_uniform_(self.fc1.weight, nonlinearity='relu')

  # nn.LazyBatchNorm1d with 3D input expects shape (batch, channels, length) = (B, C, T)
  # it normalizes across the batch and time (token, block) dimensions for each channel independently
  # our tensors are (B, T, C), so we move the channel dimension to the middle (axis 1) with transpose(1, 2)
  def forward(self, X):
    emb = self.emb(X)                     # (BATCH_SIZE, BLOCK_SIZE, EMBED_SIZE)
    out = self.fc1(emb).transpose(1, 2)   # (BATCH_SIZE, LAYER_SIZE, BLOCK_SIZE) for BatchNorm1d
    out = self.bn1(out).transpose(1, 2)   # back to our dimensions
    act = F.relu(out)
    out = self.fc2(act).transpose(1, 2)
    out = self.bn2(out).transpose(1, 2)
    out += emb                            # shortcut connection
    logits = self.fc3(out)                # (BATCH_SIZE, BLOCK_SIZE, VOCAB_SIZE)
    return logits

In [23]:
model = ResidualBlock()
cel = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
model.train()

ResidualBlock(
  (emb): Embedding(28, 10)
  (fc1): Linear(in_features=10, out_features=100, bias=False)
  (fc2): Linear(in_features=100, out_features=10, bias=False)
  (fc3): Linear(in_features=10, out_features=28, bias=True)
  (bn1): LazyBatchNorm1d(0, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn2): LazyBatchNorm1d(0, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

In [24]:
num_epochs = 10_000
batch_size = 32

for epoch in range(1, num_epochs+1):
  model.train()
  idx = torch.randint(0, X_train.size(0), (batch_size,))
  batch_X, batch_Y = X_train[idx], Y_train[idx]
  optimizer.zero_grad()
  logits = model(batch_X)     # (BATCH_SIZE, BLOCK_SIZE, VOCAB_SIZE)
  logits = logits[:, -1, :]   # (BATCH_SIZE, VOCAB_SIZE)
  loss = cel(logits, batch_Y)
  loss.backward()
  optimizer.step()
  if epoch % 1000 == 0 or epoch == 1:
    print(f'Epoch {epoch}, Loss: {loss.item()}')

Epoch 1, Loss: 3.6320574283599854
Epoch 1000, Loss: 2.374105930328369
Epoch 2000, Loss: 2.6409666538238525
Epoch 3000, Loss: 2.6358656883239746
Epoch 4000, Loss: 2.36672043800354
Epoch 5000, Loss: 2.696502208709717
Epoch 6000, Loss: 2.4992451667785645
Epoch 7000, Loss: 2.413964033126831
Epoch 8000, Loss: 2.83028507232666
Epoch 9000, Loss: 2.3721745014190674
Epoch 10000, Loss: 2.6832263469696045


In [25]:
model.eval()
with torch.no_grad():
  logits_train = model(X_train)[:, -1, :]
  logits_val   = model(X_val)[:, -1, :]

  full_loss_train = cel(logits_train, Y_train)
  full_loss_val   = cel(logits_val, Y_val)

  print(f'Train loss: {full_loss_train.item()}')
  print(f'Validation loss: {full_loss_val.item()}')

Train loss: 2.4901065826416016
Validation loss: 2.4812421798706055


In [26]:
# modifying code to suit our needs
def sample(model, n=10, block_size=3):
  model.eval()
  names = []
  for _ in range(n):
    context = ['<START>'] * block_size
    name = ''
    while True:
      idx = [stoi[c] for c in context]
      X = torch.tensor([idx], dtype=torch.long)
      with torch.no_grad():
        logits = model(X)[0, -1] # VOCAB_SIZE
      probs = F.softmax(logits, dim=0)
      idx_next = torch.multinomial(probs, num_samples=1).item()
      char = itos[idx_next]
      if char == '<END>':
        break
      name += char
      context = context[1:] + [char]
    names.append(name)
  return names

In [27]:
sample(model)

['kelifo',
 'ja',
 'tha',
 'elarhncasoria',
 'ka',
 'voratte',
 'eniysh',
 'th',
 'kelld',
 'edm']