<a href="https://colab.research.google.com/github/yashika-git/Transformers_from_Scratch/blob/main/Transformers_from_Scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [35]:
import torch
import torch.nn as nn
from torch.autograd import Variable

import math

class SelfAttention(nn.Module):
  """Multi-head scaled dot-product attention.

  The embedding dimension is split into ``heads`` chunks of ``head_dim``
  features. The value/key/query projections are ``head_dim -> head_dim``
  linear layers applied to every head chunk (the same weights are shared by
  all heads), attention is computed per head, and the concatenated heads are
  mixed by ``fc_out``.

  Args:
    embed_size: Size of the last dimension of the inputs and of the output.
    heads: Number of attention heads; must divide ``embed_size`` exactly.
  """

  def __init__(self, embed_size, heads):
    super(SelfAttention, self).__init__()
    self.embed_size = embed_size
    self.heads = heads
    self.head_dim = embed_size // heads

    assert (
        self.head_dim * heads == embed_size
    ), " Embedding size needs to be divisible by number of heads"

    # Per-head projections (bias-free) and the final output projection.
    self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
    self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
    self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
    self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

  def forward(self, values, keys, query, mask):
    """Attend from ``query`` over ``keys`` / ``values``.

    Args:
      values: Tensor of shape (N, value_len, embed_size).
      keys: Tensor of shape (N, key_len, embed_size); ``key_len`` must equal
        ``value_len`` because the attention weights are contracted against
        ``values`` along that axis.
      query: Tensor of shape (N, query_len, embed_size).
      mask: ``None`` or a tensor broadcastable to
        (N, heads, query_len, key_len). Positions where ``mask == 0`` receive
        a very large negative score so their softmax weight becomes 0.

    Returns:
      Tensor of shape (N, query_len, embed_size).
    """
    # Number of examples in the batch.
    N = query.shape[0]
    value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

    # Split the embedding into heads: (N, len, embed_size) -> (N, len, heads, head_dim),
    # then project each head chunk.
    values = self.values(values.reshape(N, value_len, self.heads, self.head_dim))
    keys = self.keys(keys.reshape(N, key_len, self.heads, self.head_dim))
    queries = self.queries(query.reshape(N, query_len, self.heads, self.head_dim))

    # Dot product of every query with every key, per head.
    # energy: (N, heads, query_len, key_len)
    energy = torch.einsum('nqhd,nkhd->nhqk', queries, keys)

    # Masked positions get a huge negative score so softmax assigns them ~0 weight.
    if mask is not None:
      energy = energy.masked_fill(mask == 0, float('-1e20'))

    # Scale by sqrt(d_k), where d_k is the per-head key size (head_dim). The dot
    # product sums over head_dim terms, so this is the factor that keeps the
    # logits' variance independent of the head width. Scaling by
    # sqrt(embed_size) would over-damp the logits by an extra sqrt(heads).
    attention = torch.softmax(energy / math.sqrt(self.head_dim), dim=3)

    # Weighted sum of values over the key/value axis, then flatten the heads:
    # (N, heads, query_len, key_len) x (N, value_len, heads, head_dim)
    #   -> (N, query_len, heads, head_dim) -> (N, query_len, heads * head_dim)
    out = torch.einsum('nhql,nlhd->nqhd', attention, values).reshape(
        N, query_len, self.heads * self.head_dim
    )

    # Mix the concatenated heads; output is (N, query_len, embed_size).
    return self.fc_out(out)


class TransformerBlock(nn.Module):
  """Attention sub-layer followed by a position-wise feed-forward sub-layer.

  Each sub-layer is wrapped as ``dropout(layer_norm(sublayer(x) + x))``.

  Args:
    embed_size: Feature size of the inputs and outputs.
    heads: Number of attention heads handed to ``SelfAttention``.
    dropout: Dropout probability applied after each layer norm.
    forward_expansion: Width multiplier of the hidden feed-forward layer.
  """

  def __init__(self, embed_size, heads, dropout, forward_expansion):
    super().__init__()
    self.attention = SelfAttention(embed_size, heads)

    # LayerNorm rather than BatchNorm: batch statistics of NLP data fluctuate
    # a lot during training and batch norm is harder to parallelise.
    self.norm1 = nn.LayerNorm(embed_size)
    self.norm2 = nn.LayerNorm(embed_size)

    # Two-layer MLP: widen to forward_expansion * embed_size, ReLU, project back.
    hidden_size = forward_expansion * embed_size
    widen = nn.Linear(embed_size, hidden_size)
    activation = nn.ReLU()
    narrow = nn.Linear(hidden_size, embed_size)
    self.feed_forward = nn.Sequential(widen, activation, narrow)

    self.dropout = nn.Dropout(dropout)

  def forward(self, value, key, query, mask):
    """Run attention and the feed-forward net, each with a residual connection.

    ``query`` is the tensor added back in the first skip connection, so the
    result has the same shape as ``query``.
    """
    attended = self.attention(value, key, query, mask)

    # First residual connection -> layer norm -> dropout.
    normed_attention = self.norm1(attended + query)
    x = self.dropout(normed_attention)

    # Second residual connection around the feed-forward network.
    expanded = self.feed_forward(x)
    normed_output = self.norm2(expanded + x)
    return self.dropout(normed_output)


class Encoder(nn.Module):
  """Transformer encoder: token embedding + sinusoidal positions + N blocks.

  Args:
    src_vocab_size: Number of rows in the source embedding table.
    embed_size: Embedding / model feature size.
    num_layers: Number of stacked ``TransformerBlock`` layers.
    heads: Number of attention heads per block.
    device: Stored on the instance as ``self.device``; not otherwise used here.
    forward_expansion: Feed-forward width multiplier passed to each block.
    dropout: Dropout probability (embeddings and every block).
    max_length: Number of positions in the positional-encoding table; input
      sequences must not be longer than this.
  """

  def __init__(
      self,
      src_vocab_size,
      embed_size,
      num_layers,
      heads,
      device,
      forward_expansion,
      dropout,
      max_length
  ):
    super(Encoder, self).__init__()
    self.embed_size = embed_size
    self.device = device
    self.word_embedding = nn.Embedding(src_vocab_size, embed_size)

    # Sinusoidal positional encoding. References:
    # https://kazemnejad.com/blog/transformer_architecture_positional_encoding/
    # https://github.com/harvardnlp/annotated-transformer/blob/master/AnnotatedTransformer.ipynb
    pe = torch.zeros(max_length, embed_size)
    position = torch.arange(0, max_length).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, embed_size, 2) * -(math.log(10000.0) / embed_size))
    pe[:, 0::2] = torch.sin(position * div_term)
    # For an odd embed_size there is one fewer odd column than even column, so
    # trim div_term to match; for even sizes the slice is a no-op.
    pe[:, 1::2] = torch.cos(position * div_term[: embed_size // 2])
    # Shape (1, max_length, embed_size) so it broadcasts over the batch.
    # A buffer is saved/moved with the module but is not trained by the optimizer.
    self.register_buffer('pe', pe.unsqueeze(0))

    self.layers = nn.ModuleList(
        [
            TransformerBlock(
                embed_size,
                heads,
                dropout=dropout,
                forward_expansion=forward_expansion
            )
            for _ in range(num_layers)
        ]
    )

    self.dropout = nn.Dropout(dropout)

  def forward(self, x, mask):
    """Encode a batch of token ids.

    Args:
      x: Integer tensor of shape (N, seq_length) holding token indices.
      mask: Attention mask forwarded unchanged to every block (may be None).

    Returns:
      Tensor of shape (N, seq_length, embed_size): the output of the LAST
      block (or the embedded input when there are no blocks).
    """
    _, seq_length = x.shape

    # pe is (1, max_length, embed_size); keep only the first seq_length positions.
    out = self.dropout(self.word_embedding(x) + self.pe[:, :seq_length, :])

    # Self-attention: values, keys and queries are all the running output.
    for layer in self.layers:
      out = layer(out, out, out, mask)

    # Return only after every layer has been applied.
    return out

       
class DecoderBlock(nn.Module):
  """One decoder layer: masked self-attention, then a ``TransformerBlock``
  that attends over the supplied ``value`` / ``key`` tensors.

  Args:
    embed_size: Feature size of the inputs and outputs.
    heads: Number of attention heads.
    forward_expansion: Feed-forward width multiplier for the inner block.
    dropout: Dropout probability.
    device: Accepted for interface compatibility; not used by this class.
  """

  def __init__(self, embed_size, heads, forward_expansion, dropout, device):
    super().__init__()
    self.norm = nn.LayerNorm(embed_size)
    self.attention = SelfAttention(embed_size, heads=heads)
    self.transformer_block = TransformerBlock(
        embed_size,
        heads,
        dropout,
        forward_expansion,
    )
    self.dropout = nn.Dropout(dropout)

  def forward(self, x, value, key, src_mask, trg_mask):
    """Apply the decoder layer.

    ``trg_mask`` is used for the self-attention over ``x``; ``src_mask`` is
    handed to the inner block together with ``value`` and ``key`` (it may mask
    padded source positions).
    """
    # Masked multi-head self-attention over the decoder input.
    self_attended = self.attention(x, x, x, trg_mask)

    # Residual connection -> layer norm -> dropout gives the query for the
    # second attention stage.
    query = self.dropout(self.norm(self_attended + x))

    return self.transformer_block(value, key, query, src_mask)


class Decoder(nn.Module):
  """Transformer decoder: token embedding + sinusoidal positions + N decoder
  blocks + a linear projection onto the target vocabulary.

  Args:
    trg_vocab_size: Number of rows in the target embedding table and size of
      the output projection.
    embed_size: Embedding / model feature size.
    num_layers: Number of stacked ``DecoderBlock`` layers.
    heads: Number of attention heads per block.
    forward_expansion: Feed-forward width multiplier passed to each block.
    dropout: Dropout probability (embeddings and every block).
    device: Stored on the instance and forwarded to each ``DecoderBlock``.
    max_length: Number of positions in the positional-encoding table; input
      sequences must not be longer than this.
  """

  def __init__(
      self,
      trg_vocab_size,
      embed_size,
      num_layers,
      heads,
      forward_expansion,
      dropout,
      device,
      max_length
  ):
    super(Decoder, self).__init__()
    self.device = device
    self.word_embedding = nn.Embedding(trg_vocab_size, embed_size)

    # Sinusoidal positional encoding, shape (1, max_length, embed_size).
    pe = torch.zeros(max_length, embed_size)
    position = torch.arange(0, max_length).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, embed_size, 2) * -(math.log(10000.0) / embed_size))
    pe[:, 0::2] = torch.sin(position * div_term)
    # For an odd embed_size there is one fewer odd column than even column, so
    # trim div_term to match; for even sizes the slice is a no-op.
    pe[:, 1::2] = torch.cos(position * div_term[: embed_size // 2])
    # Registered as a buffer: saved and moved with the module, never trained.
    self.register_buffer('pe', pe.unsqueeze(0))

    self.layers = nn.ModuleList(
        [
            DecoderBlock(embed_size, heads, forward_expansion, dropout, device)
            for _ in range(num_layers)
        ]
    )
    self.fc_out = nn.Linear(embed_size, trg_vocab_size)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x, enc_out, src_mask, trg_mask):
    """Decode a batch of target token ids against the encoder output.

    Args:
      x: Integer tensor of shape (N, seq_length) holding target token indices.
      enc_out: Encoder output, used as both value and key in every block.
      src_mask: Mask forwarded to each block for the attention over ``enc_out``.
      trg_mask: Mask forwarded to each block for the target self-attention.

    Returns:
      Tensor of shape (N, seq_length, trg_vocab_size) with unnormalised scores.
    """
    _, seq_length = x.shape

    # Add the first seq_length positional encodings to the token embeddings.
    x = self.dropout(self.word_embedding(x) + self.pe[:, :seq_length])

    for layer in self.layers:
      x = layer(x, enc_out, enc_out, src_mask, trg_mask)

    return self.fc_out(x)

   
class Transformer(nn.Module):
  """Encoder-decoder transformer operating on integer token ids.

  Args:
    src_vocab_size: Source vocabulary size.
    trg_vocab_size: Target vocabulary size.
    src_pad_idx: Token id treated as padding when building the source mask.
    trg_pad_idx: Target padding id; stored on the instance only.
    embed_size: Model feature size.
    num_layers: Number of layers in both the encoder and the decoder.
    forward_expansion: Feed-forward width multiplier.
    heads: Number of attention heads.
    dropout: Dropout probability.
    device: Device the masks are moved to.
    max_length: Size of the positional-encoding tables.
  """

  def __init__(
      self,
      src_vocab_size,
      trg_vocab_size,
      src_pad_idx,
      trg_pad_idx,
      embed_size=512,
      num_layers=6,
      forward_expansion=4,
      heads=8,
      dropout=0,
      device='cpu',
      max_length=100
  ):
    super().__init__()

    self.Encoder = Encoder(
        src_vocab_size=src_vocab_size,
        embed_size=embed_size,
        num_layers=num_layers,
        heads=heads,
        device=device,
        forward_expansion=forward_expansion,
        dropout=dropout,
        max_length=max_length,
    )

    self.Decoder = Decoder(
        trg_vocab_size=trg_vocab_size,
        embed_size=embed_size,
        num_layers=num_layers,
        heads=heads,
        forward_expansion=forward_expansion,
        dropout=dropout,
        device=device,
        max_length=max_length,
    )

    self.src_pad_idx = src_pad_idx
    self.trg_pad_idx = trg_pad_idx
    self.device = device

  def make_src_mask(self, src):
    """Return a boolean mask that is True wherever ``src`` is not padding.

    Two singleton axes are inserted after the batch axis, so a (N, src_len)
    input yields a (N, 1, 1, src_len) mask.
    """
    not_padding = src != self.src_pad_idx
    src_mask = not_padding[:, None, None]
    return src_mask.to(self.device)

  def make_trg_mask(self, trg):
    """Return a lower-triangular mask of ones, shape (N, 1, trg_len, trg_len).

    Row i has ones in columns 0..i, so position i is not masked for itself
    and earlier positions only. ``expand`` repeats the single triangular
    matrix for every example in the batch.
    """
    N, trg_len = trg.shape
    lower_triangle = torch.ones((trg_len, trg_len)).tril()
    trg_mask = lower_triangle.expand(N, 1, trg_len, trg_len)
    return trg_mask.to(self.device)

  def forward(self, src, trg):
    """Encode ``src``, then decode ``trg`` against it; returns the decoder output."""
    src_mask = self.make_src_mask(src)
    trg_mask = self.make_trg_mask(trg)
    memory = self.Encoder(src, src_mask)
    return self.Decoder(trg, memory, src_mask, trg_mask)



if __name__ == '__main__':
  # Smoke test: push one tiny batch through the model and print the output shape.
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

  # Token ids used below: 0 = padding, 1 = start token, 2 = end token.
  src_pad_idx = 0
  trg_pad_idx = 0
  src_vocab_size = 10
  trg_vocab_size = 10

  # Source batch: two sentences of length 9 (the first is padded at the end).
  x = torch.tensor(
      [
          [1, 5, 6, 4, 3, 9, 5, 2, 0],
          [1, 8, 7, 3, 4, 5, 6, 7, 2],
      ],
      device=device,
  )

  # Target batch: the length need not match the source length.
  trg = torch.tensor(
      [
          [1, 7, 4, 3, 5, 9, 2, 0],
          [1, 5, 6, 2, 4, 7, 6, 2],
      ],
      device=device,
  )

  model = Transformer(
      src_vocab_size,
      trg_vocab_size,
      src_pad_idx,
      trg_pad_idx,
      device=device,
  ).to(device)

  # Feed the target without its last position, so the model has to predict
  # the final token (e.g. the end-of-sentence token) itself.
  decoder_input = trg[:, :-1]
  out = model(x, decoder_input)
  print(out.shape)


torch.Size([2, 7, 10])


In [None]:
# Reference: https://www.youtube.com/watch?v=U0s0f995w14&list=PLhhyoLH6IjfxeoooqP9rhU3HJIAVAJ3Vz&index=41, https://github.com/aladdinpersson/Machine-Learning-Collection/blob/master/ML/Pytorch/more_advanced/transformer_from_scratch/transformer_from_scratch.py