http://nlp.seas.harvard.edu/annotated-transformer/

Table of Contents

1. Prelims
2. Background
3. Part 1 : Model Architecture
  
  - Encoder and Decoder Stacks
  - Position-wise FFN
  - Embeddings and Softmax
  - PE
  - Full Model

4. Part 2 : Model Training
  
  - Batches and Masking
  - Training Loop
  - Hardware and Schedule
  - Optimizer
  - Regularization

#1. Prelims

In [None]:
# Install the packages listed in requirements.txt.
!pip install -r requirements.txt

In [None]:
!pip install -q torchdata == 0.3.0 torchtext == 0.12 spacy == 3.2 altair GPUtil
!python -m spacy download de_core_news_sm
!python -m spacy download en_core_web_sm

In [None]:
import os
from os.path import exists
import torch
import torch.nn as nn
from torch.nn.functional import log_softmax, pad

import math
import copy
import time

from torch.optim.lr_scheduler import LambdaLR

import pandas as pd
import altair as alt

from torchtext.data.functional import to_map_style_dataset
from torchtext.vocab import build_vocab_from_iterator

import torchtext.datasets as datasets
from torch.utils.data import DataLoader

import spacy
import GPUtil
import warnings

from torch.utils.data.distributed import DistributedSampler

import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

# Silence every warning for the rest of the notebook.
warnings.filterwarnings('ignore')
# Gate read by the show_examples / execute_examples helpers.
run_examples = True

In [None]:
# some convenience helper functions used throughout the notebook

def is_interactive_notebook():
  """Report whether this module is running as the top-level script.

  Returns:
    True when ``__name__`` equals ``'__main__'``, False otherwise.
  """
  running_as_main = '__main__' == __name__
  return running_as_main

def show_examples(fn, args=None):
  """Call ``fn(*args)`` only when run as a script with examples enabled.

  Args:
    fn: callable to invoke.
    args: optional sequence of positional arguments for ``fn``. Defaults to
      no arguments (``None`` replaces the former shared mutable ``[]``).

  Returns:
    The result of ``fn(*args)`` when ``__name__ == '__main__'`` and the
    module-level ``run_examples`` flag is truthy; otherwise ``None``.
  """
  if __name__ == '__main__' and run_examples:
    return fn(*(args or ()))
  return None

def execute_examples(fn, args=None):
  """Run ``fn(*args)`` only when executed as a script with examples enabled.

  Args:
    fn: callable to invoke.
    args: optional sequence of positional arguments for ``fn``. Defaults to
      no arguments (``None`` replaces the former shared mutable ``[]``).

  Returns:
    The result of ``fn(*args)`` when ``__name__ == '__main__'`` and the
    module-level ``run_examples`` flag is truthy; otherwise ``None``.
  """
  if __name__ == '__main__' and run_examples:
    return fn(*(args or ()))
  return None

class DummyOptimizer(torch.optim.Optimizer):
  """Optimizer stand-in whose ``step`` and ``zero_grad`` do nothing.

  The base-class constructor is deliberately not called; the only state set
  is a single parameter group with a learning rate of 0.
  """

  def __init__(self):
    self.param_groups = [dict(lr=0)]

  def step(self):
    """No-op replacement for an optimizer update."""
    return None

  def zero_grad(self, set_to_none=False):
    """No-op replacement for gradient clearing; ``set_to_none`` is ignored."""
    return None
  

class DummyScheduler:
  """Scheduler stand-in whose ``step`` does nothing."""

  def step(self):
    """No-op replacement for a learning-rate scheduler step."""
    return None

2. Model Architecture

In [None]:
class EncoderDecoder(nn.Module):
  """A standard encoder-decoder architecture.

  Holds the encoder, the decoder, the source/target embedding pipelines and
  the output generator. ``forward`` returns the decoder output; the
  generator is stored but not applied here.

  Args:
    encoder: called as ``encoder(src_embed(src), src_mask)``.
    decoder: called as ``decoder(tgt_embed(tgt), memory=..., src_mask=...,
      tgt_mask=...)``.
    src_embed: callable applied to ``src`` before encoding.
    tgt_embed: callable applied to ``tgt`` before decoding.
    generator: output projection, kept as an attribute for callers.
  """

  def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
    super().__init__()
    self.encoder = encoder
    self.decoder = decoder
    self.src_embed = src_embed
    self.tgt_embed = tgt_embed
    self.generator = generator

  def forward(self, src, tgt, src_mask, tgt_mask):
    """Take in and process masked src and tgt sequences."""
    return self.decode(self.encode(src, src_mask), src_mask, tgt, tgt_mask)

  def encode(self, src, src_mask):
    """Embed ``src`` and run it through the encoder."""
    return self.encoder(self.src_embed(src), src_mask)

  def decode(self, memory, src_mask, tgt, tgt_mask):
    """Embed ``tgt`` and run the decoder over it and the encoder ``memory``."""
    # Keyword arguments keep the two masks from being swapped by position.
    return self.decoder(
        self.tgt_embed(tgt),
        memory=memory,
        src_mask=src_mask,
        tgt_mask=tgt_mask,
    )

class Generator(nn.Module):
  """Standard linear projection followed by log-softmax.

  Args:
    d_model: size of the incoming feature dimension.
    vocab: size of the output dimension.
  """

  def __init__(self, d_model, vocab):
    super().__init__()
    self.proj = nn.Linear(d_model, vocab)

  def forward(self, x):
    """Return log-probabilities over the last dimension of ``proj(x)``."""
    return log_softmax(self.proj(x), dim=-1)


# Backward-compatible alias for the original lowercase class name.
generator = Generator

## Encoder and Decoder Stacks

### Encoder

The encoder is composed of a stack of N = 6 identical layers.

In [None]:
# N = 6
def clones(module, N):
  """Produce ``N`` independent deep copies of ``module``.

  Args:
    module: the ``nn.Module`` to replicate.
    N: number of copies to create.

  Returns:
    An ``nn.ModuleList`` holding ``N`` deep copies, so no parameters are
    shared between the copies.
  """
  return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


In [None]:
class Encoder(nn.Module):
  """Core encoder: a stack of N cloned layers followed by a final norm.

  Args:
    layer: prototype layer; it is cloned ``N`` times and must expose ``size``.
    N: number of layers in the stack.
  """

  def __init__(self, layer, N):
    super().__init__()
    self.layers = clones(layer, N)
    self.norm = LayerNorm(layer.size)

  def forward(self, x, mask):
    """Pass ``x`` (with ``mask``) through every layer, then normalise."""
    hidden = x
    for block in self.layers:
      hidden = block(hidden, mask)
    normalised = self.norm(hidden)
    return normalised

We employ a residual connection around each of the two sub-layers, followed by layer normalization.

In [None]:
class LayerNorm(nn.Module):
  """Layer normalisation over the last dimension with learnable scale/shift.

  Args:
    features: size of the last dimension of the inputs.
    eps: constant added to the standard deviation to avoid division by zero.
  """

  def __init__(self, features, eps=1e-6):
    super().__init__()
    self.a_2 = nn.Parameter(torch.ones(features))   # scale
    self.b_2 = nn.Parameter(torch.zeros(features))  # shift
    self.eps = eps

  def forward(self, x):
    """Return ``a_2 * (x - mean) / (std + eps) + b_2`` along the last axis."""
    mean = x.mean(-1, keepdim=True)
    std = x.std(-1, keepdim=True)
    return self.a_2 * (x - mean) / (std + self.eps) + self.b_2

To facilitate these residual connections, all sub-layers in the model, as well as the embedding layers, produce outputs of dimension d_model = 512.

In [None]:
class SublayerConnection(nn.Module):
  """Residual connection around a sublayer, with the norm applied first.

  Args:
    size: feature size passed to ``LayerNorm``.
    dropout: dropout probability applied to the sublayer output.
  """

  def __init__(self, size, dropout):
    super().__init__()
    self.norm = LayerNorm(size)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x, sublayer):
    """Return ``x + dropout(sublayer(norm(x)))``.

    Args:
      x: input tensor.
      sublayer: callable taking the normalised input and returning a tensor
        that can be added to ``x``.
    """
    return x + self.dropout(sublayer(self.norm(x)))

  

Each layer has two sub-layers. The first is a multi-head self-attention mechanism, and the second is a position-wise fully connected feed-forward network (FFN).

In [None]:
class EncoderLayer(nn.Module):
  """Encoder layer: self-attention followed by a feed-forward block.

  Each of the two steps is wrapped in its own ``SublayerConnection``.

  Args:
    size: feature size, also exposed as ``self.size``.
    self_attn: attention callable invoked as ``self_attn(q, k, v, mask)``.
    ff: feed-forward callable invoked on a single tensor.
    dropout: dropout probability for the sublayer connections.
  """

  def __init__(self, size, self_attn, ff, dropout):
    super().__init__()
    self.self_attn = self_attn
    self.ff = ff
    self.sublayer = clones(SublayerConnection(size, dropout), 2)
    self.size = size

  def forward(self, x, mask):
    """Apply self-attention under ``mask``, then the feed-forward block."""

    def attend(normed):
      # Query, key and value all come from the same tensor.
      return self.self_attn(normed, normed, normed, mask)

    attn_wrap, ff_wrap = self.sublayer[0], self.sublayer[1]
    attended = attn_wrap(x, attend)
    return ff_wrap(attended, self.ff)

### Decoder

The decoder is also composed of a stack of N = 6 identical layers.

In [None]:
class Decoder(nn.Module):
  """Generic decoder: a stack of N cloned layers followed by a final norm.

  Args:
    layer: prototype layer; it is cloned ``N`` times and must expose ``size``.
    N: number of layers in the stack.
  """

  def __init__(self, layer, N):
    super().__init__()
    self.layers = clones(layer, N)
    self.norm = LayerNorm(layer.size)

  def forward(self, x, memory, src_mask, tgt_mask):
    """Run ``x`` through every layer, then normalise.

    The parameter order matches the order in which the arguments are
    forwarded to each layer: ``layer(x, memory, src_mask, tgt_mask)``.
    """
    for layer in self.layers:
      x = layer(x, memory, src_mask, tgt_mask)
    return self.norm(x)

In [None]:
class DecoderLayer(nn.Module):
  """Decoder layer: self-attention, source attention, then feed-forward.

  Each of the three steps is wrapped in its own ``SublayerConnection``.

  Args:
    size: feature size, also exposed as ``self.size``.
    self_attn: attention callable used over the decoder input.
    src_attn: attention callable used over the encoder memory.
    ff: feed-forward callable invoked on a single tensor.
    dropout: dropout probability for the sublayer connections.
  """

  def __init__(self, size, self_attn, src_attn, ff, dropout):
    super().__init__()
    self.size = size
    self.self_attn = self_attn
    self.src_attn = src_attn
    self.ff = ff
    self.sublayer = clones(SublayerConnection(size, dropout), 3)

  def forward(self, x, memory, src_mask, tgt_mask):
    """Attend to ``x`` itself, then to ``memory``, then apply ``ff``."""

    def attend_self(normed):
      return self.self_attn(normed, normed, normed, tgt_mask)

    def attend_memory(normed):
      # Queries come from the decoder; keys and values from the memory.
      return self.src_attn(normed, memory, memory, src_mask)

    after_self = self.sublayer[0](x, attend_self)
    after_src = self.sublayer[1](after_self, attend_memory)
    return self.sublayer[2](after_src, self.ff)

Create a mask that prevents positions from attending to subsequent positions.

In [None]:
def subsequent_mask(size):
  """Mask out subsequent positions.

  Args:
    size: sequence length.

  Returns:
    A boolean tensor of shape ``(1, size, size)`` that is True where the
    column index is less than or equal to the row index and False above the
    main diagonal.
  """
  attn_shape = (1, size, size)
  # Ones strictly above the diagonal mark the disallowed (future) positions.
  future = torch.triu(torch.ones(attn_shape), diagonal=1).type(torch.uint8)
  return future == 0

## Attention

An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors.

The output is computed as a weighted sum of the values.

We call our particular attention 'Scaled dot product attention'.

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$


In [None]:
def attention(query, key, value, mask=None, dropout=None):
  """Compute scaled dot-product attention.

  Args:
    query: tensor whose last dimension is ``d_k``.
    key: tensor; its last two dimensions are transposed before the product.
    value: tensor combined with the attention weights.
    mask: optional tensor; positions where ``mask == 0`` receive a score of
      ``-1e9`` before the softmax.
    dropout: optional callable applied to the attention weights.

  Returns:
    A tuple ``(output, p_attn)`` of the weighted values and the attention
    weights (after dropout, when given).
  """
  d_k = query.size(-1)
  scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

  if mask is not None:
    scores = scores.masked_fill(mask == 0, -1e9)

  p_attn = scores.softmax(dim=-1)

  if dropout is not None:
    p_attn = dropout(p_attn)

  return torch.matmul(p_attn, value), p_attn

### Multihead

$$\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)W^O$$

where $\mathrm{head}_i = \mathrm{Attention}(QW_i^Q, KW_i^K, VW_i^V)$



In [None]:
class MultiHeadedAttention(nn.Module):
  """Multi-head attention built from four cloned ``d_model x d_model`` linears.

  The first three linear layers project query, key and value; the last one
  projects the concatenated heads.

  Args:
    h: number of heads; must divide ``d_model``.
    d_model: model feature size.
    dropout: dropout probability applied to the attention weights.
  """

  def __init__(self, h, d_model, dropout=0.1):
    super().__init__()
    assert d_model % h == 0

    # Integer division: d_k is used as a dimension in ``view`` below.
    self.d_k = d_model // h
    self.h = h
    self.linears = clones(nn.Linear(d_model, d_model), 4)
    self.attn = None  # attention weights of the most recent forward pass
    self.dropout = nn.Dropout(dropout)

  def forward(self, query, key, value, mask=None):
    """Project the inputs into ``h`` heads, attend, and merge the heads."""
    if mask is not None:
      # Same mask applied to all h heads.
      mask = mask.unsqueeze(1)

    nbatches = query.size(0)

    # 1) Linear projections, reshaped to (nbatches, h, -1, d_k).
    #    zip stops after three pairs, leaving the fourth linear for step 3.
    query, key, value = [
        lin(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
        for lin, x in zip(self.linears, (query, key, value))
    ]

    # 2) Attention over all heads at once; mask and dropout are passed
    #    positionally (fourth and fifth arguments).
    x, self.attn = attention(query, key, value, mask, self.dropout)

    # 3) Concatenate the heads back to (nbatches, -1, h * d_k).
    x = (
        x.transpose(1, 2)
        .contiguous()
        .view(nbatches, -1, self.h * self.d_k)
    )

    del query
    del key
    del value

    return self.linears[-1](x)

Applications of Attention in the Model

1. Encoder-Decoder attention

2. Self-attention layers in Encoder

3. self attention layers in Decoder with masking

## Position-wise FFN

In [None]:
class PositionWiseFF(nn.Module):
  """Two-layer feed-forward block: linear, ReLU, dropout, linear.

  Args:
    d_model: input and output feature size.
    d_ff: hidden feature size.
    dropout: dropout probability applied after the ReLU.
  """

  def __init__(self, d_model, d_ff, dropout=0.1):
    super().__init__()
    self.w_1 = nn.Linear(d_model, d_ff)
    self.w_2 = nn.Linear(d_ff, d_model)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x):
    """Return ``w_2(dropout(relu(w_1(x))))``."""
    hidden = torch.relu(self.w_1(x))
    hidden = self.dropout(hidden)
    return self.w_2(hidden)

## Embedding and Softmax

In [None]:
class Embedding(nn.Module):
  """Token embedding lookup scaled by ``sqrt(d_model)``.

  Args:
    d_model: embedding dimension.
    vocab: number of rows in the lookup table.
  """

  def __init__(self, d_model, vocab):
    super().__init__()
    self.lut = nn.Embedding(vocab, d_model)
    self.d_model = d_model

  def forward(self, x):
    """Look up ``x`` and multiply the embeddings by ``sqrt(d_model)``."""
    return self.lut(x) * math.sqrt(self.d_model)

## Position Encoding

$$PE_{(pos,\,2i)} = \sin\left(pos / 10000^{2i/d_{\text{model}}}\right)$$

$$PE_{(pos,\,2i+1)} = \cos\left(pos / 10000^{2i/d_{\text{model}}}\right)$$



In [None]:
class PE(nn.Module):
  """Sinusoidal positional encoding added to the input, followed by dropout.

  Args:
    d_model: feature size of the encoding.
    dropout: dropout probability applied after the encoding is added.
    max_len: number of positions precomputed in the table.
  """

  def __init__(self, d_model, dropout, max_len=5000):
    super().__init__()
    self.dropout = nn.Dropout(dropout)

    # Table of shape (max_len, d_model): sine in even columns, cosine in odd.
    pe = torch.zeros(max_len, d_model)
    position = torch.arange(0, max_len).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, d_model, 2) * -(math.log(10000) / d_model)
    )

    pe[:, 0::2] = torch.sin(position * div_term)
    # With an odd d_model there is one fewer odd column than even columns.
    pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

    # Stored as a buffer (not a parameter) with a leading batch dimension.
    pe = pe.unsqueeze(0)
    self.register_buffer('pe', pe)

  def forward(self, x):
    """Add the first ``x.size(1)`` encodings to ``x`` and apply dropout."""
    x = x + self.pe[:, : x.size(1)].requires_grad_(False)
    return self.dropout(x)

## Full Model

In [None]:
def make_model(src_vocab, tgt_vocab, N=6, d_model=512, d_ff=2048, h=8, dropout=0.1):
  """Construct an encoder-decoder model from hyperparameters.

  Args:
    src_vocab: source vocabulary size.
    tgt_vocab: target vocabulary size.
    N: number of layers in both the encoder and the decoder stacks.
    d_model: model feature size.
    d_ff: hidden size of the position-wise feed-forward blocks.
    h: number of attention heads.
    dropout: dropout probability passed to the feed-forward, positional
      encoding and layer constructors.

  Returns:
    An ``EncoderDecoder`` whose parameters with more than one dimension have
    been initialised with Xavier-uniform.
  """
  c = copy.deepcopy  # each use below gets its own independent copy
  # NOTE(review): attention is built with MultiHeadedAttention's default
  # dropout rather than the `dropout` argument -- confirm this is intended.
  attn = MultiHeadedAttention(h, d_model)
  ff = PositionWiseFF(d_model, d_ff, dropout)
  position = PE(d_model, dropout)

  model = EncoderDecoder(
      Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
      # DecoderLayer takes self-attention, source attention AND feed-forward.
      Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout), N),
      nn.Sequential(Embedding(d_model, src_vocab), c(position)),
      nn.Sequential(Embedding(d_model, tgt_vocab), c(position)),
      generator(d_model, tgt_vocab),
  )

  # This was important from their code.
  # Initialize parameters with Glorot / fan_avg.
  for p in model.parameters():
    if p.dim() > 1:
      nn.init.xavier_uniform_(p)

  return model

#Batches and Masking

In [None]:
class Batch:
  """Object for holding a batch of data with masks during training.

  Args:
    src: source token tensor.
    tgt: optional target token tensor. When given, it is split into decoder
      input ``tgt`` (all but the last position) and expected output
      ``tgt_y`` (all but the first position).
    pad: token id treated as padding.

  Attributes set: ``src``, ``src_mask`` and, when ``tgt`` is given, ``tgt``,
  ``tgt_y``, ``tgt_mask`` and ``ntokens`` (count of non-pad ``tgt_y``
  entries).
  """

  def __init__(self, src, tgt=None, pad=2):
    self.src = src
    self.src_mask = (src != pad).unsqueeze(-2)

    if tgt is not None:
      self.tgt = tgt[:, :-1]
      self.tgt_y = tgt[:, 1:]
      self.tgt_mask = self.make_std_mask(self.tgt, pad)
      self.ntokens = (self.tgt_y != pad).data.sum()

  @staticmethod
  def make_std_mask(tgt, pad):
    """Create a mask that hides padding and subsequent positions."""
    # unsqueeze(-2), as for src_mask, so the padding mask broadcasts over the
    # attended-to (last) axis rather than over the query axis.
    tgt_mask = (tgt != pad).unsqueeze(-2)
    tgt_mask = tgt_mask & subsequent_mask(tgt.size(-1)).type_as(tgt_mask.data)
    return tgt_mask


# Backward-compatible alias for the original lowercase class name.
batch = Batch

# Training Loop

In [None]:
class TrainState:
  """Counters describing training progress; every counter starts at 0."""

  # Number of batches processed in a training mode.
  step: int = 0
  # Number of optimizer updates performed.
  accum_step: int = 0
  # Running total of batch.src.shape[0] over processed batches.
  samples: int = 0
  # Running total of batch.ntokens over processed batches.
  tokens: int = 0


In [None]:
def run_epoch(
    data_iter,
    model,
    loss_compute,
    optimizer,
    scheduler,
    mode='train',
    accum_iter=1,
    train_state=None,
):
  """Run one pass over ``data_iter``, training when ``mode`` asks for it.

  Args:
    data_iter: iterable of batches exposing ``src``, ``tgt``, ``src_mask``,
      ``tgt_mask``, ``tgt_y`` and ``ntokens``.
    model: object whose ``forward(src, tgt, src_mask, tgt_mask)`` is called.
    loss_compute: callable ``(out, tgt_y, ntokens) -> (loss, loss_node)``;
      ``loss_node.backward()`` is called in training modes.
    optimizer: stepped and zeroed every ``accum_iter`` batches when training.
    scheduler: stepped once per batch when training.
    mode: ``'train'`` or ``'train+log'`` enables back-propagation, updates
      and progress logging; any other value only accumulates the loss.
    accum_iter: number of batches between optimizer updates.
    train_state: counters object to update; a fresh ``TrainState`` is
      created when omitted, so separate calls do not share one default.

  Returns:
    ``(total_loss / total_tokens, train_state)``.
  """
  if train_state is None:
    train_state = TrainState()

  training = mode in ('train', 'train+log')
  start = time.time()
  total_tokens = 0
  total_loss = 0
  tokens = 0   # tokens seen since the last progress line
  n_accum = 0  # optimizer updates performed in this epoch

  for i, batch in enumerate(data_iter):
    out = model.forward(batch.src, batch.tgt, batch.src_mask, batch.tgt_mask)
    loss, loss_node = loss_compute(out, batch.tgt_y, batch.ntokens)

    if training:
      loss_node.backward()
      train_state.step += 1
      train_state.samples += batch.src.shape[0]
      train_state.tokens += batch.ntokens

      if i % accum_iter == 0:
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        n_accum += 1
        train_state.accum_step += 1

      # The learning-rate schedule only advances while training.
      scheduler.step()

    total_loss += loss
    total_tokens += batch.ntokens
    tokens += batch.ntokens

    if i % 40 == 1 and training:
      lr = optimizer.param_groups[0]['lr']
      # Guard against a zero interval on coarse clocks.
      elapsed = max(time.time() - start, 1e-9)
      print(
          'Epoch Step: %6d | Accumulation Step: %3d | Loss: %6.2f '
          '| Tokens / Sec: %7.1f | Learning Rate: %6.1e'
          % (i, n_accum, loss / batch.ntokens, tokens / elapsed, lr)
      )
      start = time.time()
      tokens = 0

    del loss
    del loss_node

  return total_loss / total_tokens, train_state



# Optimizer

In [None]:
def rate(step, model_size, factor, warmup):
  """Compute the learning-rate multiplier for a given step.

  Step 0 is treated as step 1, because 0 cannot be raised to a negative
  power.

  Returns:
    ``factor * model_size**-0.5 * min(step**-0.5, step * warmup**-1.5)``.
  """
  effective_step = 1 if step == 0 else step
  scale = model_size ** (-0.5)
  decay = effective_step ** (-0.5)
  ramp = effective_step * warmup ** (-1.5)
  return factor * (scale * min(decay, ramp))

# Regularization

In [None]:
class LabelSmoothing(nn.Module):
  """Label smoothing implemented as a KL-divergence loss (summed).

  Args:
    size: number of classes; must equal ``x.size(1)`` in ``forward``.
    padding_idx: class index treated as padding.
    smoothing: probability mass spread over the non-target classes.
  """

  def __init__(self, size, padding_idx, smoothing=0.0):
    super().__init__()
    self.criterion = nn.KLDivLoss(reduction='sum')
    self.padding_idx = padding_idx
    self.confidence = 1.0 - smoothing
    self.smoothing = smoothing
    self.size = size
    self.true_dist = None  # smoothed targets from the most recent forward

  def forward(self, x, target):
    """Return the summed KL divergence between ``x`` and smoothed targets.

    Args:
      x: tensor of shape ``(rows, size)`` passed to ``nn.KLDivLoss`` as input.
      target: tensor of ``rows`` class indices.
    """
    assert x.size(1) == self.size

    # Spread the smoothing mass over size - 2 classes: every class except the
    # target and the padding class.
    true_dist = torch.full_like(x, self.smoothing / (self.size - 2))
    true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
    true_dist[:, self.padding_idx] = 0

    # Zero out, in place, every row whose target is padding.
    pad_rows = (target == self.padding_idx).unsqueeze(1)
    true_dist.masked_fill_(pad_rows, 0.0)

    self.true_dist = true_dist
    return self.criterion(x, true_dist.clone().detach())