<a href="https://colab.research.google.com/github/saurav997/transformer_translator/blob/master/Translator_Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# prompt: import pytorch and .nn
!pip install datasets
import torch
import torch.nn as nn
import math


class InputEmbeddings(nn.Module):
  """Look up token embeddings and scale them by sqrt(d_model).

  Args:
    d_model: embedding width.
    vocab_size: number of rows in the embedding table.
  """

  def __init__(self, d_model: int, vocab_size: int):
    super().__init__()
    self.d_model = d_model
    self.vocab_size = vocab_size
    self.embedding = nn.Embedding(vocab_size, d_model)

  def forward(self, x):
    # Scaling keeps the embedding magnitude comparable to the positional encodings added next.
    scale = math.sqrt(self.d_model)
    looked_up = self.embedding(x)
    return looked_up * scale

class PositionalEncoding(nn.Module):
  """Add fixed sinusoidal position encodings to a batch of embeddings, then apply dropout.

  The table is precomputed for ``seq_len`` positions and stored as the buffer
  ``pe`` with shape (1, seq_len, d_model). Even feature columns hold
  sin(pos * w_i) and odd columns cos(pos * w_i), where w_i = 10000^(-2i / d_model).

  Args:
    d_model: feature width (odd widths are supported).
    seq_len: maximum number of positions the table covers.
    dropout: dropout probability applied after the addition.
  """

  def __init__(self, d_model: int, seq_len: int, dropout: float) -> None:
    super().__init__()
    self.d_model = d_model
    self.seq_len = seq_len
    self.dropout = nn.Dropout(dropout)
    pe = torch.zeros(seq_len, d_model)
    position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)  # (seq_len, 1)
    div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
    pe[:, 0::2] = torch.sin(position * div_term)
    # With an odd d_model there is one fewer odd column than frequencies, so trim
    # div_term to match; for even d_model this slice is the whole tensor.
    pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
    self.register_buffer('pe', pe.unsqueeze(0))  # (1, seq_len, d_model)

  def forward(self, x):
    """Add the encodings for the first ``x.shape[1]`` positions to ``x`` and apply dropout.

    Raises:
      ValueError: if ``x`` has more positions than the precomputed table.
    """
    length = x.shape[1]
    if length > self.seq_len:
      raise ValueError(
          f"sequence length {length} exceeds the positional-encoding table size {self.seq_len}")
    # ``pe`` is a registered buffer, so it never requires grad and needs no detach().
    x = x + self.pe[:, :length, :]
    return self.dropout(x)

class LayerNormalization(nn.Module):
  """Normalize over the last dimension, then apply a learnable per-feature scale and shift.

  Statistics are the mean and (unbiased) standard deviation of the last
  dimension, computed in float32. The output is ``alpha * (x - mean) / (std + eps) + bias``.

  Args:
    features: size of the last dimension of the inputs; alpha and bias have this length.
    eps: added to the standard deviation to avoid division by (near) zero.
  """

  def __init__(self, features: int, eps: float = 10**-6) -> None:
    super().__init__()
    self.eps = eps
    # One learnable scale and shift per feature (previously `features` was ignored
    # and a single scalar was shared across all features).
    self.alpha = nn.Parameter(torch.ones(features))
    self.bias = nn.Parameter(torch.zeros(features))

  def forward(self, x):
    # x: (..., features). keepdim=True keeps the reduced axis so the stats broadcast back over x.
    mean = x.float().mean(dim=-1, keepdim=True)
    std = x.float().std(dim=-1, keepdim=True)
    return self.alpha * (x - mean) / (std + self.eps) + self.bias

class FeedForwardLayer(nn.Module):
  """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear.

  Maps (batch, seq_len, d_model) -> (batch, seq_len, d_ff) -> (batch, seq_len, d_model).
  """

  def __init__(self, d_model: int, d_ff: int, dropout: float) -> None:
    super().__init__()
    self.linear_1 = nn.Linear(d_model, d_ff)
    self.dropout = nn.Dropout(dropout)
    self.linear_2 = nn.Linear(d_ff, d_model)

  def forward(self, x):
    hidden = torch.relu(self.linear_1(x))
    hidden = self.dropout(hidden)
    return self.linear_2(hidden)

class MultiHeadAttention(nn.Module):
  """Multi-head scaled dot-product attention with bias-free projections.

  Inputs are projected with W_q / W_k / W_v, split into ``h`` heads of width
  ``d_k = d_model // h``, attended per head, merged back and projected with W_o.
  The attention weights from the latest forward call are kept on
  ``self.attention_scores``.
  """

  def __init__(self, d_model: int, h: int, dropout: float) -> None:
    super().__init__()
    self.d_model = d_model
    self.h = h
    assert d_model % h == 0, "please check d_model divisibility with h"
    self.d_k = d_model // h
    # Created in the order W_q, W_k, W_v, W_o so parameter registration order is stable.
    self.W_q = nn.Linear(d_model, d_model, bias=False)
    self.W_k = nn.Linear(d_model, d_model, bias=False)
    self.W_v = nn.Linear(d_model, d_model, bias=False)
    self.W_o = nn.Linear(d_model, d_model, bias=False)
    self.dropout = nn.Dropout(dropout)

  @staticmethod
  def attention(Q, K, V, mask, dropout: nn.Dropout):
    """Scaled dot-product attention.

    Scores where ``mask == 0`` are set to -1e9 before the softmax. If ``dropout``
    is not None it is applied to the attention weights. Returns
    ``(weights @ V, weights)``.
    """
    scale = math.sqrt(Q.shape[-1])
    scores = (Q @ K.transpose(-2, -1)) / scale
    if mask is not None:
      scores.masked_fill_(mask == 0, -1e9)
    weights = scores.softmax(dim=-1)
    if dropout is not None:
      weights = dropout(weights)
    return weights @ V, weights

  def _split_heads(self, t):
    """Reshape (batch, seq, d_model) into (batch, h, seq, d_k)."""
    return t.view(t.shape[0], t.shape[1], self.h, self.d_k).transpose(1, 2)

  def forward(self, q, k, v, mask):
    query = self._split_heads(self.W_q(q))
    key = self._split_heads(self.W_k(k))
    value = self._split_heads(self.W_v(v))
    context, self.attention_scores = MultiHeadAttention.attention(query, key, value, mask, self.dropout)
    batch = context.shape[0]
    # (batch, h, seq, d_k) -> (batch, seq, h * d_k)
    merged = context.transpose(1, 2).contiguous().view(batch, -1, self.h * self.d_k)
    return self.W_o(merged)

class ResidualConnection(nn.Module):
  """Pre-norm residual wrapper: ``x + dropout(sublayer(norm(x)))``.

  Args:
    features: width passed to the LayerNormalization.
    dropout: dropout probability applied to the sublayer output.
  """

  def __init__(self, features, dropout: float) -> None:
    super().__init__()
    self.dropout = nn.Dropout(dropout)
    self.norm = LayerNormalization(features)

  def forward(self, x, sublayer):
    normalized = self.norm(x)
    residual = self.dropout(sublayer(normalized))
    return x + residual


In [None]:
class EncoderBlock(nn.Module):
  """One encoder layer: self-attention, then a feed-forward network.

  Each sublayer is wrapped in its own ResidualConnection (pre-norm + dropout + skip).
  """

  def __init__(self,features, self_attention_block : MultiHeadAttention, feed_forward_block : FeedForwardLayer,dropout:float):
    super().__init__()
    self.self_attention_block = self_attention_block
    self.feed_forward_block = feed_forward_block
    self.residual_connections = nn.ModuleList(
        ResidualConnection(features, dropout) for _ in range(2))

  def forward(self, x, src_mask):
    attention_residual, feed_forward_residual = self.residual_connections
    x = attention_residual(x, lambda t: self.self_attention_block(t, t, t, src_mask))
    return feed_forward_residual(x, self.feed_forward_block)

class Encoder(nn.Module):
  """Apply a stack of encoder layers in order, then a final LayerNormalization."""

  def __init__(self, features, layers: nn.ModuleList) -> None:
    super().__init__()
    self.layers = layers
    self.norm = LayerNormalization(features)

  def forward(self, x, mask):
    hidden = x
    for encoder_layer in self.layers:
      hidden = encoder_layer(hidden, mask)
    return self.norm(hidden)


In [None]:
class DecoderBlock(nn.Module):
  """One decoder layer: masked self-attention, cross-attention over the encoder output, feed-forward.

  Each of the three sublayers is wrapped in its own ResidualConnection.
  """

  def __init__(self,features,self_attention_block:MultiHeadAttention, cross_attention_block:MultiHeadAttention, feed_forward_block:FeedForwardLayer,dropout:float)->None:
    super().__init__()
    self.self_attention_block = self_attention_block
    self.cross_attention_block = cross_attention_block
    self.feed_forward_block = feed_forward_block
    self.residual_connections = nn.ModuleList(
        ResidualConnection(features, dropout) for _ in range(3))

  def forward(self, x, encoder_output, src_mask, tgt_mask):
    self_attn_res, cross_attn_res, ff_res = self.residual_connections
    x = self_attn_res(x, lambda y: self.self_attention_block(y, y, y, tgt_mask))
    # Queries come from the decoder; keys and values come from the encoder output.
    x = cross_attn_res(x, lambda y: self.cross_attention_block(y, encoder_output, encoder_output, src_mask))
    return ff_res(x, self.feed_forward_block)


class Decoder(nn.Module):
  """Apply a stack of decoder layers in order, then a final LayerNormalization."""

  def __init__(self, features, layers: nn.ModuleList) -> None:
    super().__init__()
    self.layers = layers
    self.norm = LayerNormalization(features)

  def forward(self, x, encoder_output, src_mask, tgt_mask):
    out = x
    for decoder_layer in self.layers:
      out = decoder_layer(out, encoder_output, src_mask, tgt_mask)
    return self.norm(out)

class ProjectionLayer(nn.Module):
  """Project hidden states to log-probabilities over the target vocabulary.

  Shape: (batch, seq, d_model) -> (batch, seq, vocab_size).
  Note: the training loss (nn.CrossEntropyLoss) applies log_softmax again. log_softmax
  is idempotent, so the loss value is unaffected.
  """

  def __init__(self, d_model: int, vocab_size: int) -> None:
    super().__init__()
    self.proj = nn.Linear(d_model, vocab_size)

  def forward(self, x):
    logits = self.proj(x)
    return torch.log_softmax(logits, dim=-1)



In [None]:
class Transformer(nn.Module):
  """Encoder-decoder Transformer that wires embeddings, positional encodings and the projection together.

  ``encode`` runs source ids through embedding, positional encoding and the encoder.
  ``decode`` does the same for target ids and feeds them, together with the
  encoder output, to the decoder. ``project`` maps decoder states to vocabulary
  scores.
  """

  def __init__(self,encoder:Encoder,decoder:Decoder, src_embed:InputEmbeddings, tgt_embed:InputEmbeddings, src_pos:PositionalEncoding, tgt_pos:PositionalEncoding, projection_layer:ProjectionLayer)->None:
    super().__init__()
    # Registration order is preserved so parameter/state_dict ordering stays the same.
    self.encoder = encoder
    self.decoder = decoder
    self.src_embed = src_embed
    self.tgt_embed = tgt_embed
    self.src_pos = src_pos
    self.tgt_pos = tgt_pos
    self.projection_layer = projection_layer

  def encode(self, src, src_mask):
    positioned = self.src_pos(self.src_embed(src))
    return self.encoder(positioned, src_mask)

  def decode(self, encoder_output, src_mask, tgt, tgt_mask):
    positioned = self.tgt_pos(self.tgt_embed(tgt))
    return self.decoder(positioned, encoder_output, src_mask, tgt_mask)

  def project(self, x):
    return self.projection_layer(x)

In [None]:
# prompt: given all the above code blocks write a function named build_transformer that builds and returns a transformer object of the class transformer after initializing all hyper parameters involved

def build_transformer(src_vocab_size: int, tgt_vocab_size: int, src_seq_len: int, tgt_seq_len: int, d_model: int = 512, N: int = 6, h: int = 8, dropout: float = .1, d_ff: int = 2048) -> Transformer:
    """Build an encoder-decoder Transformer and Xavier-initialize its weight matrices.

    Args:
        src_vocab_size / tgt_vocab_size: sizes of the source/target vocabularies.
        src_seq_len / tgt_seq_len: maximum sequence length for each side's positional table.
        d_model: model width.
        N: number of encoder layers and of decoder layers.
        h: number of attention heads (must divide d_model).
        dropout: dropout probability used throughout.
        d_ff: hidden width of the feed-forward sublayers.

    Returns:
        The assembled Transformer. Every parameter with more than one dimension is
        re-initialized with ``nn.init.xavier_uniform_``.
    """
    src_embed = InputEmbeddings(d_model, src_vocab_size)
    tgt_embed = InputEmbeddings(d_model, tgt_vocab_size)
    # Each side gets a table sized for its own maximum length. Reusing the source
    # table for the target breaks whenever tgt_seq_len > src_seq_len.
    src_pos = PositionalEncoding(d_model, src_seq_len, dropout)
    tgt_pos = PositionalEncoding(d_model, tgt_seq_len, dropout)

    encoder_blocks = []
    for _ in range(N):
        encoder_self_attention_block = MultiHeadAttention(d_model, h, dropout)
        feed_forward_block = FeedForwardLayer(d_model, d_ff, dropout)
        encoder_blocks.append(EncoderBlock(d_model, encoder_self_attention_block, feed_forward_block, dropout))

    decoder_blocks = []
    for _ in range(N):
        decoder_self_attention_block = MultiHeadAttention(d_model, h, dropout)
        decoder_cross_attention_block = MultiHeadAttention(d_model, h, dropout)
        feed_forward_block = FeedForwardLayer(d_model, d_ff, dropout)
        decoder_blocks.append(DecoderBlock(d_model, decoder_self_attention_block, decoder_cross_attention_block, feed_forward_block, dropout))

    encoder = Encoder(d_model, nn.ModuleList(encoder_blocks))
    decoder = Decoder(d_model, nn.ModuleList(decoder_blocks))
    projection_layer = ProjectionLayer(d_model, tgt_vocab_size)
    transformer = Transformer(encoder, decoder, src_embed, tgt_embed, src_pos, tgt_pos, projection_layer)

    # Xavier init for weight matrices only; 1-D parameters (biases, norm scales) keep their defaults.
    for p in transformer.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    return transformer


In [None]:
from torch.utils.data import random_split, Dataset, DataLoader
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace
from pathlib import Path

def get_all_sentences(ds, lang):
  """Lazily yield the ``lang`` side of each item's ``'translation'`` dict in ``ds``."""
  yield from (example['translation'][lang] for example in ds)

def get_or_build_tokenizer(config, ds, lang):
  """Load the word-level tokenizer for ``lang`` from disk, or train and save one.

  The file path is ``config['tokenizer_file']`` formatted with ``lang``. A new
  tokenizer is trained on every ``lang`` sentence in ``ds``. It uses whitespace
  pre-tokenization, the special tokens [UNK], [PAD], [SOS], [EOS], and
  ``min_frequency=2``.
  """
  tokenizer_path = Path(config['tokenizer_file'].format(lang))
  if tokenizer_path.exists():
    return Tokenizer.from_file(str(tokenizer_path))

  tokenizer = Tokenizer(WordLevel(unk_token='[UNK]'))
  tokenizer.pre_tokenizer = Whitespace()
  trainer = WordLevelTrainer(special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"], min_frequency=2)
  tokenizer.train_from_iterator(get_all_sentences(ds, lang), trainer=trainer)
  tokenizer.save(str(tokenizer_path))
  return tokenizer




In [None]:
class BilingualDataset(Dataset):
  """Map-style dataset that turns one translation pair into fixed-length model tensors.

  Each raw item must provide ``item['translation'][src_lang]`` and
  ``item['translation'][tgt_lang]``. Every returned sequence has length ``seq_len``:

  * ``encoder_input``: [SOS] src ids [EOS] [PAD]... (source tokenizer ids)
  * ``decoder_input``: [SOS] tgt ids [PAD]...       (target tokenizer ids)
  * ``label``:         tgt ids [EOS] [PAD]...       (target tokenizer ids)

  Raises:
    ValueError: from ``__getitem__`` when a sentence does not fit in ``seq_len``.
  """

  def __init__(self, ds, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, seq_len) -> None:
    super().__init__()
    self.ds = ds
    self.tokenizer_src = tokenizer_src
    self.tokenizer_tgt = tokenizer_tgt
    self.tgt_lang = tgt_lang
    self.src_lang = src_lang
    self.seq_len = seq_len
    # Source-side special tokens (kept under the original attribute names).
    self.sos_token = torch.tensor([tokenizer_src.token_to_id('[SOS]')], dtype=torch.int64)
    self.eos_token = torch.tensor([tokenizer_src.token_to_id('[EOS]')], dtype=torch.int64)
    self.pad_token = torch.tensor([tokenizer_src.token_to_id('[PAD]')], dtype=torch.int64)
    # Target-side sequences must use the target tokenizer's own ids, because the
    # two vocabularies are separate tokenizer objects.
    self.tgt_sos_token = torch.tensor([tokenizer_tgt.token_to_id('[SOS]')], dtype=torch.int64)
    self.tgt_eos_token = torch.tensor([tokenizer_tgt.token_to_id('[EOS]')], dtype=torch.int64)
    self.tgt_pad_token = torch.tensor([tokenizer_tgt.token_to_id('[PAD]')], dtype=torch.int64)

  def __len__(self):
    return len(self.ds)

  def __getitem__(self, index):
    pair = self.ds[index]['translation']
    tgt_text = pair[self.tgt_lang]
    src_text = pair[self.src_lang]
    enc_input_tokens = self.tokenizer_src.encode(src_text).ids
    dec_input_tokens = self.tokenizer_tgt.encode(tgt_text).ids
    # The encoder input carries both [SOS] and [EOS]; the decoder input gets only [SOS]
    # and the label only [EOS].
    enc_num_padding_tokens = self.seq_len - len(enc_input_tokens) - 2
    dec_num_padding_tokens = self.seq_len - len(dec_input_tokens) - 1

    if dec_num_padding_tokens < 0 or enc_num_padding_tokens < 0:
      raise ValueError("Sentence is too long!")

    enc_ids = torch.tensor(enc_input_tokens, dtype=torch.int64)
    dec_ids = torch.tensor(dec_input_tokens, dtype=torch.int64)
    enc_padding = self.pad_token.repeat(enc_num_padding_tokens)
    dec_padding = self.tgt_pad_token.repeat(dec_num_padding_tokens)

    encoder_input = torch.cat([self.sos_token, enc_ids, self.eos_token, enc_padding], dim=0)
    decoder_input = torch.cat([self.tgt_sos_token, dec_ids, dec_padding], dim=0)
    label = torch.cat([dec_ids, self.tgt_eos_token, dec_padding], dim=0)

    assert encoder_input.size(0) == self.seq_len
    assert decoder_input.size(0) == self.seq_len
    assert label.size(0) == self.seq_len

    size = decoder_input.size(0)
    # (1, size, size) boolean mask: position i may attend to positions <= i.
    lower_triangular = torch.tril(torch.ones((1, size, size), dtype=torch.bool))
    return {"encoder_input": encoder_input,
            "decoder_input": decoder_input,
            # (1, 1, seq_len): 1 for real tokens, 0 for padding.
            "encoder_mask": (encoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int(),
            # (1, seq_len, seq_len): non-padding keys AND causal ordering.
            "decoder_mask": (decoder_input != self.tgt_pad_token).unsqueeze(0).unsqueeze(0).int() & lower_triangular,
            "label": label,
            "src_text": src_text,
            "tgt_text": tgt_text
            }


In [None]:
def causal_mask(size):
  """Return a (1, size, size) boolean mask that is True on and below the main diagonal.

  True marks positions a query may attend to (itself and earlier positions).
  """
  lower = torch.tril(torch.ones((1, size, size), dtype=torch.int64))
  return lower == 1

def get_ds(config):
  """Load opus_books, build tokenizers, split 90/10 and wrap both parts in DataLoaders.

  Also prints the longest tokenized source and target sentence in the full split.

  Returns:
    (train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt). The
    validation loader uses batch_size=1 and shuffle=True.
  """
  src_lang = config["lang_src"]
  tgt_lang = config["lang_tgt"]
  ds_raw = load_dataset('opus_books', f'{src_lang}-{tgt_lang}', split='train')
  tokenizer_src = get_or_build_tokenizer(config, ds_raw, src_lang)
  tokenizer_tgt = get_or_build_tokenizer(config, ds_raw, tgt_lang)

  n_train = int(0.9 * len(ds_raw))
  train_ds_raw, val_ds_raw = random_split(ds_raw, [n_train, len(ds_raw) - n_train])

  def _wrap(subset):
    """Wrap a raw split in a BilingualDataset using the shared tokenizers and config."""
    return BilingualDataset(subset, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, config['seq_len'])

  train_ds = _wrap(train_ds_raw)
  val_ds = _wrap(val_ds_raw)

  max_len_src = max_len_tgt = 0
  for example in ds_raw:
    sentences = example['translation']
    max_len_src = max(max_len_src, len(tokenizer_src.encode(sentences[src_lang]).ids))
    max_len_tgt = max(max_len_tgt, len(tokenizer_tgt.encode(sentences[tgt_lang]).ids))

  print(f'Max length of source sentence:{max_len_src}')
  print(f'Max length of target sentence:{max_len_tgt}')

  return (
      DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True),
      DataLoader(val_ds, batch_size=1, shuffle=True),
      tokenizer_src,
      tokenizer_tgt,
  )

In [None]:
def get_model(config, vocab_src_len, vocab_tgt_len):
  """Build a Transformer that uses ``config['seq_len']`` for both sides and ``config['d_model']``."""
  seq_len = config['seq_len']
  return build_transformer(vocab_src_len, vocab_tgt_len, seq_len, seq_len, d_model=config['d_model'])

def get_config():
  """Return the default training hyper-parameters and file locations.

  ``model_basename`` is the checkpoint filename prefix read by
  ``get_weights_file_path``. ``model_filename`` carries the same value and is
  kept for backward compatibility.
  """
  return {
      "batch_size": 8,
      "num_epochs": 20,
      "lr": 10**-4,
      "seq_len": 350,
      "d_model": 512,
      "lang_src": "en",
      "lang_tgt": "it",
      "model_folder": "weights",
      "model_filename": "tmodel_",
      "model_basename": "tmodel_",
      "preload": None,
      "tokenizer_file": "tokenizer_{0}.json",
      "experiment_name": "runs/model"
  }

def get_weights_file_path(config, epoch: str):
  """Return the checkpoint path ``<model_folder>/<basename><epoch>.pt`` as a string.

  The basename is ``config['model_basename']``. If that key is absent,
  ``config['model_filename']`` is used instead.

  Raises:
    KeyError: if ``model_folder``, or both basename keys, are missing.
  """
  model_folder = config['model_folder']
  if 'model_basename' in config:
    model_basename = config['model_basename']
  else:
    model_basename = config['model_filename']
  model_filename = f"{model_basename}{epoch}.pt"
  return str(Path('.') / model_folder / model_filename)


In [None]:
def greedy_decode(model, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device):
  """Greedily decode a translation of a single source sequence.

  Starts from the target [SOS] id and repeatedly appends the arg-max token until
  [EOS] is produced or the output reaches ``max_len`` tokens (including [SOS]).
  ``tokenizer_src`` is unused and kept only for call compatibility.

  Requires a batch of one, because the chosen token is read with ``.item()``.

  Returns:
    1-D tensor of target token ids, starting with [SOS].
  """
  sos_idx = tokenizer_tgt.token_to_id('[SOS]')
  eos_idx = tokenizer_tgt.token_to_id('[EOS]')
  # The encoder output does not depend on the decoded prefix, so compute it once.
  encoder_output = model.encode(source, source_mask)
  decoder_input = torch.full((1, 1), sos_idx, dtype=source.dtype, device=device)
  # Checking the length up front also covers max_len < 1, which previously looped until EOS.
  while decoder_input.size(1) < max_len:
    # The mask must live on the same device as the attention scores it is applied to.
    decoder_mask = causal_mask(decoder_input.size(1)).to(device)
    output = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)
    prob = model.project(output[:, -1])
    _, next_word = torch.max(prob, dim=1)
    next_token = torch.full((1, 1), next_word.item(), dtype=source.dtype, device=device)
    decoder_input = torch.cat([decoder_input, next_token], dim=1)
    if next_word.item() == eos_idx:
      break
  return decoder_input.squeeze(0)


def validate_model(model, validation_ds, tokenizer_src, tokenizer_tgt, max_len, device, print_msg, global_state, writer, num_examples=2):
  """Greedy-decode up to ``num_examples`` validation items and report them via ``print_msg``.

  Puts the model in eval mode and runs under ``torch.no_grad()``. ``global_state``
  and ``writer`` are accepted but not used here.

  Raises:
    AssertionError: if a validation batch does not have batch size 1.
  """
  model.eval()
  source_texts, expected, predicted = [], [], []
  separator = '-' * 80
  with torch.no_grad():
    for count, batch in enumerate(validation_ds, start=1):
      encoder_input = batch['encoder_input'].to(device)
      encoder_mask = batch['encoder_mask'].to(device)
      assert encoder_input.size(0) == 1, "batch size must be 1 for validation!"
      model_out = greedy_decode(model, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, max_len, device)
      source_texts.append(batch['src_text'][0])
      expected.append(batch['tgt_text'][0])
      predicted.append(tokenizer_tgt.decode(model_out.detach().cpu().numpy()))
      report = (
          separator,
          f'source: {source_texts[-1]}',
          f'Target: {expected[-1]}',
          f'prediction: {predicted[-1]}',
      )
      for line in report:
        print_msg(line)
      if count == num_examples:
        break


In [None]:
#training loop starts here:
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

def train_model(config):
  """Train the translation Transformer described by ``config``.

  Builds tokenizers and data loaders via ``get_ds``. If ``config['preload']``
  is set, resumes from that checkpoint. Logs the per-step training loss to
  TensorBoard and runs a short greedy-decoding validation after every epoch.
  A checkpoint is saved every 4th epoch (excluding epoch 0) and after the
  final epoch.
  """
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  print(f'using device {device}')
  Path(config['model_folder']).mkdir(parents=True, exist_ok=True)
  train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
  tgt_vocab_size = tokenizer_tgt.get_vocab_size()
  model = get_model(config, tokenizer_src.get_vocab_size(), tgt_vocab_size).to(device)
  writer = SummaryWriter(config['experiment_name'])
  optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], eps=1e-9)

  initial_epoch = 0
  global_step = 0
  if config['preload']:
    model_filename = get_weights_file_path(config, config['preload'])
    print(f'Preloading Model from {model_filename}')
    # map_location lets a checkpoint written on GPU be resumed on a CPU-only machine.
    state = torch.load(model_filename, map_location=device)
    # Restore the weights as well; resuming only the optimizer would continue
    # from a freshly initialized model.
    model.load_state_dict(state['model_state_dict'])
    # Checkpoints written by earlier versions of this function used the misspelled
    # key 'optmizer_state_dict'.
    optimizer_key = 'optimizer_state_dict' if 'optimizer_state_dict' in state else 'optmizer_state_dict'
    optimizer.load_state_dict(state[optimizer_key])
    initial_epoch = state['epoch'] + 1
    global_step = state['global_step']

  # Labels hold target-language ids, so padding is identified by the target tokenizer's [PAD].
  loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer_tgt.token_to_id('[PAD]'), label_smoothing=.1).to(device)
  last_epoch = config['num_epochs'] - 1

  for epoch in range(initial_epoch, config['num_epochs']):
    # validate_model switches to eval mode at the end of each epoch, so switch back here.
    model.train()
    batch_iterator = tqdm(train_dataloader, desc=f'Processing epoch{epoch:02d}')
    for batch in batch_iterator:
      encoder_input = batch['encoder_input'].to(device)
      decoder_input = batch['decoder_input'].to(device)
      encoder_mask = batch['encoder_mask'].to(device)
      decoder_mask = batch['decoder_mask'].to(device)
      encoder_output = model.encode(encoder_input, encoder_mask)
      decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask)
      proj_output = model.project(decoder_output)
      label = batch['label'].to(device)
      loss = loss_fn(proj_output.view(-1, tgt_vocab_size), label.view(-1))
      batch_iterator.set_postfix({"loss": f"{loss.item():6.3f}"})
      writer.add_scalar('train_loss', loss.item(), global_step)
      writer.flush()

      # Backpropagation step
      loss.backward()
      optimizer.step()
      optimizer.zero_grad()
      global_step += 1

    validate_model(model, val_dataloader, tokenizer_src, tokenizer_tgt, config['seq_len'], device, lambda msg: batch_iterator.write(msg), global_step, writer)

    # Save every 4th epoch (skipping epoch 0) and always after the last epoch.
    if (epoch % 4 == 0 and epoch != 0) or epoch == last_epoch:
      model_filename = get_weights_file_path(config, f'{epoch:02d}')
      torch.save({
          'epoch': epoch,
          'model_state_dict': model.state_dict(),
          'optimizer_state_dict': optimizer.state_dict(),
          'global_step': global_step
      }, model_filename)

  writer.close()








In [None]:
import warnings
import pandas as pd
import numpy as np
import altair as alt
# Suppress all warnings for the rest of the session, then train with the default configuration.
warnings.filterwarnings(action='ignore')
config = get_config()
train_model(config=config)


In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'using device {device}')

In [None]:
config = get_config()
train_dataloader, val_dataloader, vocab_src, vocab_tgt = get_ds(config)
model_viz = get_model(config, vocab_src.get_vocab_size(), vocab_tgt.get_vocab_size()).to(device)
# Disable dropout so the decodes used for visualization are deterministic.
model_viz.eval()

# NOTE(review): epoch '29' is hard-coded, but with the default num_epochs of 20,
# train_model never reaches epoch 29 — confirm this checkpoint actually exists.
model_filename_viz = get_weights_file_path(config, '29')
# map_location lets a GPU-written checkpoint load on a CPU-only runtime.
state = torch.load(model_filename_viz, map_location=device)
model_viz.load_state_dict(state['model_state_dict'])

In [None]:
def load_next_batch():
  """Fetch one validation example and greedily decode it with ``model_viz``.

  Relies on the notebook globals ``val_dataloader``, ``vocab_src``, ``vocab_tgt``,
  ``model_viz``, ``config`` and ``device``.

  Returns:
    (batch, encoder_input, decoder_input_tokenized):
      * batch: the raw batch dict.
      * encoder_input: the encoder input ids, already on ``device``.
      * decoder_input_tokenized: the decoder input ids mapped through the target vocabulary.
  """
  batch = next(iter(val_dataloader))
  encoder_input = batch['encoder_input'].to(device)
  encoder_mask = batch['encoder_mask'].to(device)
  decoder_input = batch['decoder_input'].to(device)
  # Decoder tokens come from the decoder input and are looked up in the target vocabulary.
  decoder_input_tokenized = [vocab_tgt.id_to_token(idx) for idx in decoder_input[0].cpu().numpy()]
  # The decoded ids are not returned. MultiHeadAttention stores its latest weights on
  # `attention_scores`, so presumably this call is kept to refresh them for plotting — TODO confirm.
  greedy_decode(model_viz, encoder_input, encoder_mask, vocab_src, vocab_tgt, config['seq_len'], device)
  # NOTE(review): the second value is the id tensor, not token strings; map it with
  # vocab_src.id_to_token if source tokens are needed.
  return batch, encoder_input, decoder_input_tokenized