> Batch normalization normalizes across all data points in the batch, while layer normalization normalizes on a single data point across the hidden units.

> Why transformer/LLM uses layernorm rather than batchnorm?
> - Batch stats vary with the time-layer of the neural network and they need to be maintained for different timestamps
> - The number of layers is input-dependent, so in some mini-batches stats may not be available.
>- Batch normalization requires all data in the batch to compute, which could be an issue when perform efficient training such as data parallelism on different GPUs (require sync for the entire batch at each BN node)

In [2]:
import torch
from torch import nn

In [3]:
class BatchNorm(nn.Module):
    def __init__(self, n_features:int, eps=1e-5, trainable=False):
        super().__init__()
        self.n_features = n_features
        self.eps = eps
        self.trainable = trainable
        self.beta = nn.Parameter(torch.zeros(self.n_features))
        self.gamma = nn.Parameter(torch.ones(self.n_features))
        if not self.trainable:
            self.beta.requires_grad = False
            self.gamma.requires_grad = False
            
    def forward(self, x):
        # x has shape (B, n_features)
        mu = x.mean(axis=0)
        sigma = ((x-mu)**2).mean(axis=0)+self.eps
        x = (x - mu)/torch.sqrt(sigma) 
        x = self.gamma*x + self.beta
        return x

In [4]:
bn = BatchNorm(12)
x = torch.rand((16,12))*torch.arange(1,13)
print(bn(x))
print(bn(x).mean(axis=0))
print(bn(x).std(axis=0))

tensor([[-0.9681,  1.2817,  0.2988,  1.0407,  1.4796,  1.0418,  1.4976, -1.8626,
          1.3309,  0.0239, -1.7359,  0.5197],
        [ 1.1636, -0.7685,  1.0525,  0.9157, -1.0946, -0.4234,  0.7327, -1.9050,
         -1.0799,  1.5753, -1.3696, -0.6924],
        [-0.1481, -0.6816,  1.1876, -0.8209,  1.1980, -0.8677, -1.4088,  0.9827,
          0.9965, -0.9997,  0.1709,  0.0177],
        [-1.8566, -0.3192,  1.1693, -0.1286,  1.1690, -0.2256,  0.5283, -0.1031,
         -0.1562, -0.6333, -1.5097,  0.2571],
        [ 0.8689, -1.0719, -0.7818,  2.3533, -0.0615,  0.6098,  0.4189, -0.2295,
         -1.2368, -1.2044,  0.8641,  1.5173],
        [-1.4383, -1.4673,  1.4675, -1.0227, -1.1297,  1.4058, -0.7517,  0.1495,
         -1.1174,  1.2494,  1.0894,  0.4233],
        [-0.4150,  0.8856, -0.2005,  0.3491, -0.8582,  1.5819, -0.5425,  1.2694,
          0.9575, -0.8455, -1.1827, -1.5559],
        [ 0.7443, -1.4563, -0.1399,  1.7998, -0.3497, -1.6391,  1.5039,  1.3012,
          0.9346,  0.0249, -0.

In [5]:
(1,12)+(3,4)

(1, 12, 3, 4)

In [88]:
from typing import Tuple

class LayerNorm(nn.Module):
    def __init__(self, normalized_shape: Tuple[int], trainable: bool=False, eps: float=1e-7):
        super().__init__()
        self.normalized_shape = normalized_shape
        self.trainable = trainable
        self.eps = eps
        self.beta = nn.Parameter(torch.zeros(self.normalized_shape))
        self.gamma = nn.Parameter(torch.ones(self.normalized_shape))
        if not self.trainable:
            self.beta.requires_grad = False
            self.gamma.requires_grad = False
        
    def forward(self, x):
        d = len(self.normalized_shape)
        assert tuple(x.shape[-d:]) == self.normalized_shape, f"Expected the last dimensions of input shape match normalized shape {self.normalized_shape}, got {tuple(x.shape[-d:])} instead."
        
        t = tuple([-i for i in range(1,d+1)])
        mu = x.mean(t)
        sigma = (x**2).mean(t) - mu**2 + self.eps
        
        for i in range(d):
            mu = torch.stack(tuple([mu for _ in range(self.normalized_shape[i])]), axis=-1)
            sigma = torch.stack(tuple([sigma for _ in range(self.normalized_shape[i])]), axis=-1)
        
        x = (x - mu) / torch.sqrt(sigma)
        x = self.gamma * x + self.beta
        return x

In [89]:
from transformers import BertTokenizer, BertModel
from typing import List

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')

def get_word_embeddings(sentences: List[str]):

    # Tokenize and encode text using batch_encode_plus
    # The function returns a dictionary containing the token IDs and attention masks
    encoding = tokenizer.batch_encode_plus(
        sentences,                    # List of input texts
        padding=True,              # Pad to the maximum sequence length
        truncation=True,           # Truncate to the maximum sequence length if necessary
        return_tensors='pt',      # Return PyTorch tensors
        add_special_tokens=True    # Add special tokens CLS and SEP
    )

    input_ids = encoding['input_ids']  # Token IDs
    # print input IDs
    print(f"Input ID: {input_ids}")
    attention_mask = encoding['attention_mask']  # Attention mask
    # print attention mask
    print(f"Attention mask: {attention_mask}")

    # Generate embeddings using BERT model
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
        word_embeddings = outputs.last_hidden_state  # This contains the embeddings

    # Output the shape of word embeddings
    print(f"Shape of Word Embeddings: {word_embeddings.shape}")
    return attention_mask, word_embeddings

In [90]:
sentences = ["I would like to sell bitcoin", "Today is a beautiful day.", "Spain has just won the 2024 Euro cup!", "Damn, I am late for the doctor's appointment!", "You just keep on trying", "Stay foolish"]

att_mask, word_embeddings = get_word_embeddings(sentences)

Input ID: tensor([[  101,  1045,  2052,  2066,  2000,  5271,  2978,  3597,  2378,   102,
             0,     0,     0,     0],
        [  101,  2651,  2003,  1037,  3376,  2154,  1012,   102,     0,     0,
             0,     0,     0,     0],
        [  101,  3577,  2038,  2074,  2180,  1996, 16798,  2549,  9944,  2452,
           999,   102,     0,     0],
        [  101,  4365,  1010,  1045,  2572,  2397,  2005,  1996,  3460,  1005,
          1055,  6098,   999,   102],
        [  101,  2017,  2074,  2562,  2006,  2667,   102,     0,     0,     0,
             0,     0,     0,     0],
        [  101,  2994, 13219,   102,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0]])
Attention mask: tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
        [1,

In [91]:
ln = LayerNorm((14,768))
x = ln(word_embeddings)

In [92]:
x.shape

torch.Size([6, 14, 768])

In [93]:
x.mean((-1,-2))

tensor([-7.9828e-10,  8.3376e-09,  5.6766e-09,  7.0958e-10, -1.4192e-09,
        -9.9341e-09])

In [94]:
x.var((-1,-2))

tensor([1.0001, 1.0001, 1.0001, 1.0001, 1.0001, 1.0001])