# Exercise 02 

BatchNorm, unlike other normalization layers like LayerNorm/GroupNorm etc. has the big advantage that after training, the batchnorm gamma/beta can be "folded into" the weights of the preceeding Linear layers, effectively erasing the need to forward it at test time. Set up a small 3-layer MLP with batchnorms, train the network, then "fold" the batchnorm gamma/beta into the preceeding Linear layer's W,b by creating a new W2, b2 and erasing the batch norm. Verify that this gives the same forward pass during inference. i.e. we see that the batchnorm is there just for stabilizing the training, and can be thrown out after training is done! pretty cool.


In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt # for making figures
%matplotlib inline

## Load Data

In [2]:
# read in all the words
words = open('../names.txt', 'r').read().splitlines()
words[:8]

['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']

In [3]:
print(f"{len(words)=}")

len(words)=32033


## Helpers

In [4]:
# build the vocabulary of characters and mappings to/from integers
chars = sorted(list(set(''.join(words))))
stoi = {s:i+1 for i,s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s,i in stoi.items()}
vocab_size = len(itos)
print(itos)
print(vocab_size)

{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}
27


## Build Dataset

In [5]:
# build the dataset
block_size = 3 # context length: how many characters do we take to predict the next one?

def build_dataset(words):  
  X, Y = [], []
  
  for w in words:
    context = [0] * block_size
    for ch in w + '.':
      ix = stoi[ch]
      X.append(context)
      Y.append(ix)
      context = context[1:] + [ix] # crop and append

  X = torch.tensor(X)
  Y = torch.tensor(Y)
  print(X.shape, Y.shape)
  return X, Y

import random
random.seed(42)
random.shuffle(words)
n1 = int(0.8*len(words))
n2 = int(0.9*len(words))

Xtr,  Ytr  = build_dataset(words[:n1])     # 80%
Xdev, Ydev = build_dataset(words[n1:n2])   # 10%
Xte,  Yte  = build_dataset(words[n2:])     # 10%

torch.Size([182625, 3]) torch.Size([182625])
torch.Size([22655, 3]) torch.Size([22655])
torch.Size([22866, 3]) torch.Size([22866])


## Pytorch Layers: Linear & Batch-norm implementation

In [97]:
class Linear:
    def __init__(self, fan_in:int, fan_out:int, _generator:torch.Generator, bias:bool = True) -> None:
        self.W = torch.randn((fan_in,fan_out), generator=_generator) / fan_in**0.5
        self.b = torch.randn(fan_out, generator=_generator) if bias else None

    def __call__(self, x):
        self.out = x @ self.W
        if self.b is not None:
            self.out += self.b
        return self.out

    def parameters(self):
        return [self.W] + ([self.b] if self.b is not None else [])

class BatchNorm1d:
    def __init__(self, dim, eps=1e-5, momentum=0.1) -> None:
        self.eps = eps
        self.momentum = momentum
        self.training = True
        # parameters (trained with back-prop)
        self.gamma = torch.ones(dim)
        self.beta = torch.zeros(dim)
        # buffers (trained with running-momentum update)
        self._running_mean = torch.zeros(dim)
        self._running_var = torch.ones(dim)

    def __call__(self, x):
        # forward-pass
        if self.training:
            xmean = x.mean(0, keepdim=True) # batch mean
            xvar = x.var(0, keepdim=True, unbiased=True) # batch var
        else:
            xmean = self._running_mean
            xvar = self._running_var
        xhat = (x - xmean)/torch.sqrt(xvar + self.eps) # normalization  of batch
        self.out = self.gamma * xhat + self.beta

        # update buffers
        if self.training:
            with torch.no_grad():
                self._running_mean = (1-self.momentum) * self._running_mean + self.momentum * xmean
                self._running_var = (1-self.momentum) * self._running_var + self.momentum * xvar
        return self.out

    def parameters(self):
        return [self.gamma, self.beta]


class FusedLinear:
    def __init__(self, lin:Linear, bn:BatchNorm1d)->None:
        self.lin = lin
        self.bn = bn
        self.W, self.b = self._compute_w_b()

    def _compute_w_b(self):
        # folding gammma and beta into previous layers W & b
        W, b = self.lin.W, self.lin.b
        gamma, beta = self.bn.gamma, self.bn.beta
        mu, sigma = self.bn._running_mean, self.bn._running_var**0.5
        gs = gamma/sigma
        W2 = gs*W
        b2 = gs*b + beta - mu*gs
        return W2, b2

    def __call__(self,x):
        self.out = x @ self.W + self.b
        return self.out

    def parameters(self):
        return [self.W, self.b]

class Tanh:
    def __call__(self,x):
        self.out = torch.tanh(x)
        return self.out
    def parameters(self):
        return []

In [98]:
vocab_size = 27
n_embed = 10 # dimensions of embedding
n_hidden = 100 # number of neurons in hidden layer
g = torch.Generator().manual_seed(2147483647) # for reproducibility

C = torch.randn((vocab_size, n_embed), generator=g)
layers = [
    Linear(n_embed*block_size,  n_hidden, g), BatchNorm1d(n_hidden), Tanh(),
    Linear(n_hidden,            n_hidden, g), BatchNorm1d(n_hidden), Tanh(),
    Linear(n_hidden,          vocab_size, g), BatchNorm1d(vocab_size),
]

# initializations
with torch.no_grad():
    # make last layer weights and bias to 0
    layers[-1].gamma *= 0.1  
    # apply kaiming init to other layers
    for layer in layers[:-1]:
        if isinstance(layer, Linear):
            layer.W *= 1

parameters = [C] + [p  for layer in layers for p in layer.parameters()]
print(sum(p.nelement() for p in parameters)) # total number of parameters

# set requries grad flag
for p in parameters:
    p.requires_grad = True

16651


In [99]:
max_steps = 200_000
batch_size = 32
lossi = []
ud = []

for i in range(max_steps):
    # mini-batch
    ix = torch.randint(0,Xtr.shape[0], (batch_size,), generator=g)
    Xb,Yb = Xtr[ix], Ytr[ix]
    
    # forward pass
    emb = C[Xb]
    x = emb.view(emb.shape[0], -1)
    for layer in layers:
        x = layer(x)
    loss = F.cross_entropy(x, Yb)

    # backward pass
    for layer in layers:
        layer.out.retain_grad()
    for p in parameters:
        p.grad = None
    loss.backward()

    # update
    lr = 0.1 if i < 100_000 else 0.01 # learning rate with decay
    for p in parameters:
        p.data += -lr * p.grad

    # track stats
    if i%10_000 == 0:
        print(f"{i:7d}/{max_steps:7d}:{loss.item():4f}")
    lossi.append(loss.log10().item())
    with torch.no_grad():
        ud.append([(lr*p.grad.std()/(p.data.std() + 1e-5)).log10().item()  for p in parameters])

    if i == 1000:
        break

      0/ 200000:3.309078


### Mathetatical unfolding of linear and batch norm layers
The output of linear layer
$$x_l = W.x + b$$

This feeds into the batch-norm layer. The output is 
$$x_b = \gamma . \hat{x_l} + \beta$$
where, 
$$ \hat{x_l} = \frac{x_l - \mu}{\sigma}$$

Expanding the formula and reordering it
$$x_b = \frac{\gamma}{\sigma} . x_l + \left( \beta - \frac{\mu \gamma}{\sigma} \right)$$

Substituting $x_l$ to above equation 
$$x_b = \frac{\gamma}{\sigma} (W.x + b) + \left( \beta - \frac{\mu \gamma}{\sigma} \right)$$

rearranging to $Wx+b$ format
$$x_b = \frac{\gamma}{\sigma}.W. x + \left( \frac{\gamma}{\sigma}. b + \left( \beta - \frac{\mu \gamma}{\sigma} \right) \right)$$

So, finally
$$W2 = \frac{\gamma}{\sigma}.W ; b2 = \left( \frac{\gamma}{\sigma}. b + \left( \beta - \frac{\mu \gamma}{\sigma} \right) \right)$$

In [100]:
# set training of batchnorm layer to false
for layer in layers:
    if isinstance(layer, BatchNorm1d):
        layer.training = False

with torch.no_grad():
    emb = C[Xtr[:3]]
    x = emb.view(emb.shape[0], -1)

    # noraml forward pass on first linear-batch layers
    x0 = layers[0](x)   # First linear layer

    x1 = layers[1](x0)  # Batch norm layer

    # folding gammma and beta into previous layers W & b
    W, b = layers[0].W, layers[0].b
    gamma, beta = layers[1].gamma, layers[1].beta
    mu, sigma = layers[1]._running_mean, layers[1]._running_var**0.5
    gs = gamma/sigma
    W2 = gs*W
    b2 = gs*b + beta - mu*gs
    x1_uf = x@W2 + b2 


    # print(x[1])

In [101]:
print(mu.shape, sigma.shape, gamma.shape, beta.shape, gs.shape)
print(x.shape, W2.shape)

torch.Size([1, 100]) torch.Size([1, 100]) torch.Size([100]) torch.Size([100]) torch.Size([1, 100])
torch.Size([3, 30]) torch.Size([30, 100])


In [102]:
x1_uf[-1] 

tensor([-0.1793,  0.2994, -0.4260,  0.5528, -0.0344,  0.5147, -0.3274, -0.0393,
        -0.2055,  0.3896, -0.2481, -0.7650, -0.0507, -0.4181, -0.9867,  0.3624,
         0.6038, -1.0864, -0.7159, -1.7000, -0.3483,  0.6470, -0.4465, -0.2338,
         0.0991,  0.0341, -0.8191, -0.8012, -0.9569, -0.3122, -1.5597,  0.9804,
         0.7465, -1.3748,  0.5566,  0.1926, -0.2842,  0.7759, -0.2604, -0.4401,
        -1.4257,  1.6296, -0.9216,  1.5503,  0.8523,  0.8311,  1.8679,  0.1112,
        -0.3980,  0.5268, -0.4364, -1.4508, -1.0315, -0.0180,  0.4832,  0.7953,
         0.4846,  0.2712, -0.5988, -0.3970,  0.8105, -0.1421,  0.1642,  2.1538,
         0.5010, -0.7763,  0.9983,  0.0344,  0.9966,  0.8128, -0.3665, -0.2647,
         0.4427,  0.7612,  0.1248,  0.3893, -0.0479,  0.8442,  0.3901,  0.5991,
         0.7502, -0.3544, -1.7815, -0.6673, -1.1491,  0.3127, -0.4502, -1.4068,
        -0.2253,  1.2449, -1.5305, -0.1753,  0.2042, -0.3764,  2.3007,  0.6251,
        -1.1372, -1.2816, -0.7786, -0.84

In [103]:
x1[-1]

tensor([-0.1793,  0.2994, -0.4260,  0.5528, -0.0344,  0.5147, -0.3274, -0.0393,
        -0.2055,  0.3896, -0.2481, -0.7650, -0.0507, -0.4181, -0.9867,  0.3624,
         0.6038, -1.0863, -0.7159, -1.7000, -0.3483,  0.6470, -0.4465, -0.2338,
         0.0991,  0.0341, -0.8191, -0.8012, -0.9569, -0.3122, -1.5597,  0.9804,
         0.7465, -1.3748,  0.5566,  0.1926, -0.2842,  0.7759, -0.2604, -0.4401,
        -1.4256,  1.6296, -0.9216,  1.5503,  0.8523,  0.8311,  1.8679,  0.1112,
        -0.3980,  0.5267, -0.4364, -1.4508, -1.0315, -0.0180,  0.4832,  0.7953,
         0.4846,  0.2712, -0.5988, -0.3970,  0.8105, -0.1421,  0.1642,  2.1537,
         0.5010, -0.7763,  0.9983,  0.0344,  0.9966,  0.8128, -0.3665, -0.2647,
         0.4427,  0.7612,  0.1248,  0.3893, -0.0479,  0.8442,  0.3901,  0.5991,
         0.7502, -0.3544, -1.7814, -0.6673, -1.1491,  0.3126, -0.4502, -1.4068,
        -0.2253,  1.2449, -1.5305, -0.1753,  0.2042, -0.3764,  2.3007,  0.6251,
        -1.1372, -1.2816, -0.7786, -0.84

Folding the batchnorm layer into linear layer we obtain identical results

In [105]:
# create new layers
comb_layers = []
for i, layer in enumerate(layers):
    if isinstance(layer, BatchNorm1d):
        comb_layers.append(FusedLinear(layers[i-1], layer))
    elif isinstance(layer, Tanh):
        comb_layers.append(layer)



[layer.__class__.__name__ for layer in comb_layers]

# len(layers)

['FusedLinear', 'Tanh', 'FusedLinear', 'Tanh', 'FusedLinear']

In [106]:
# set training of batchnorm layer to false
for layer in layers:
    if isinstance(layer, BatchNorm1d):
        layer.training = False

@torch.no_grad()
def split_loss(split, layers):
    X,y = {
    'train': (Xtr, Ytr),
    'val': (Xdev, Ydev),
    'test': (Xte, Yte),
    }[split]

    # forward pass
    emb = C[X] # (N, block_size, n_embd)
    x = emb.view(emb.shape[0], -1)
    for layer in layers:
        x = layer(x)
    loss = F.cross_entropy(x, y)

    print(split, loss.item())

In [107]:
split_loss('val', layers)
split_loss('test', layers)

val 2.4326910972595215
test 2.4375462532043457


In [108]:
split_loss('val', comb_layers)
split_loss('test', comb_layers)

val 2.432690143585205
test 2.4375455379486084


The losses obtained after combining the layers are very similar to loss obtained with original model sandwidch. 