PAPER BY  : https://arxiv.org/abs/1409.0473

In [1]:
import torch
import spacy
import random
import numpy as np

import torch.nn as nn
import torch.optim as optim

from tqdm.notebook import tqdm
from torchtext.datasets import Multi30k
from torchtext.data import Field,BucketIterator
from torch.utils.tensorboard import SummaryWriter
from helper_utils import translate_sentence,bleu,save_checkpoint,load_checkpoint

In [2]:
def get_gpu_details():
    t = torch.cuda.get_device_properties(0).total_memory
    c = torch.cuda.memory_cached(0)
    a = torch.cuda.memory_allocated(0)
    print(torch.cuda.get_device_name())
    print(f'Total GPU Memory {t} B , Cached GPU Memory {c} B, Allocated GPU Memory {a} B')
    
    
    
if torch.cuda.is_available():
    device='cuda:0'
else:
    device='cpu'
print(f'Current Device: {device}')
if device=='cuda:0':
    torch.cuda.empty_cache()
    get_gpu_details()

Current Device: cuda:0
NVIDIA GeForce RTX 2060
Total GPU Memory 6442450944 B , Cached GPU Memory 0 B, Allocated GPU Memory 0 B


In [3]:
spacy_german = spacy.load('de_core_news_md')
spacy_english = spacy.load('en')

In [4]:
# spacy_german = spacy.load('de_core_news_sm')
# spacy_english = spacy.load('en_core_web_sm')

In [5]:
def german_tokenizer(sentence):
        return [token.text for token in spacy_german.tokenizer(sentence)]
    
def english_tokenizer(sentence):
    return [token.text for token in spacy_english.tokenizer(sentence)]

In [6]:
german = Field(tokenize = german_tokenizer,lower = True,init_token = '<sos>',eos_token = '<eos>')

english = Field(tokenize = english_tokenizer,lower = True,init_token = '<sos>',eos_token = '<eos>')


In [7]:
train_data , val_data , test_data = Multi30k.splits(exts = ('.de','.en'),
                                                   fields = (german,english))


german.build_vocab(train_data,max_size = 10000,min_freq = 2)
english.build_vocab(train_data,max_size = 10000,min_freq = 2)

In [8]:
class Encoder(nn.Module):
    def __init__(self,size_of_vocab1, embedding_size, hidden_size ,num_layers, dropout_rate):
        
        super(Encoder,self).__init__()
        
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        
        self.embedding = nn.Embedding(size_of_vocab1 , embedding_size)
        self.dropout = nn.Dropout(dropout_rate)
        
        self.encoderRNN = nn.LSTM(embedding_size, hidden_size, num_layers , bidirectional=True)
        self.full_connected_hidden = nn.Linear(hidden_size * 2, hidden_size)
        self.full_connected_cell = nn.Linear(hidden_size * 2, hidden_size)
        
        
    def forward(self, sentence):
        #sentence shape = (sequence_length , batch_size)
        
        embedding = self.embedding(sentence)
        #for each word in the seq there's a mapping to a embedding dim space (seq_len, batch_size,embed_dim)
        
        embedding = self.dropout(embedding)
        
        encoder_states, (hidden_state,cell_state) = self.encoderRNN(embedding)
        
        #shape = (2,batchsize,hiddensize)
        hidden_state = self.full_connected_hidden(torch.cat((hidden_state[0:1],hidden_state[1:2]),dim = 2))
        
        cell_state = self.full_connected_cell(torch.cat((cell_state[0:1],cell_state[1:2]),dim = 2))
    
        return encoder_states,hidden_state,cell_state
        

In [9]:
class Decoder(nn.Module):
    def __init__(self,size_of_vocab2,embedding_size,hidden_size,num_layers, dropout_rate):
        
        super(Decoder,self).__init__()
        
        
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        
        self.dropout = nn.Dropout(dropout_rate)
        self.embedding = nn.Embedding(size_of_vocab2, embedding_size)
        
        self.decoderRNN = nn.LSTM(hidden_size*2 + embedding_size,hidden_size,
                                  num_layers)
        
        self.energy = nn.Linear(hidden_size*3 , 1)
        
        self.fully_connected = nn.Linear(hidden_size,size_of_vocab2)
        
        self.softmax = nn.Softmax(dim = 0)
        self.relu = nn.ReLU()
        
    def forward(self,inp, encoder_states,hidden_state, cell_state):
        
        #shape of inp is batch_size, we want 1,batch_size so unsqueeze to add one dimension
        
        inp = inp.unsqueeze(0)
        
        embedding = self.embedding(inp)
        embedding  = self.dropout(embedding)
        
        #find energy states
        sequence_length = encoder_states.shape[0]
        hidden_reshaped = hidden_state.repeat(sequence_length,1,1)
        
        energy = self.relu(self.energy(torch.cat((hidden_reshaped,encoder_states),dim =2 )))
        
        attention = self.softmax(energy)
        
        #(batch_size,1,hidden_size*2) ---> (1,batch_size,hidden_size*2)
    
        context_vector = torch.einsum("snk,snl->knl", attention, encoder_states)
        
        rnn_input = torch.cat((context_vector,embedding),dim = 2)
        
        
        
        outputs, (hidden_state,cell_state) = self.decoderRNN(rnn_input ,(hidden_state,cell_state))
        #shape : 1,batch_size,hidden_size
        #hidden_state and cell_state will be used for next word prediction in the sequence
        
        predictions = self.fully_connected(outputs)
        #predictions shape is (1,batch_size,length_of_vocab) so squeeze again
        
        predictions = predictions.squeeze(0)
        
        return predictions,hidden_state,cell_state

In [10]:
class seq2seq(nn.Module):
    def __init__(self,encoder,decoder):
            super(seq2seq,self).__init__()
            self.encoder_block = encoder
            self.decoder_block = decoder
            
    def forward(self,source,target, teaching_force_ratio = 0.5):
        #teaching_force_ratio helps in sometimes using the actual word as next inp and sometimes
        #the predicted word as the next input
        
        #source = (target_len,batch_size)
        batch_size = source.shape[1]
        
        target_len = target.shape[0]
                
        target_vocab_size = len(english.vocab)
        
        outputs = torch.zeros(target_len,batch_size,target_vocab_size).to(device)
        
        encoder_states,encoder_hidden , encoder_cell = self.encoder_block(source)
        
        #get start token
        next_inp = target[0]
        
        for idx in range(1,target_len):
            
            output,encoder_hidden,encoder_cell = self.decoder_block(next_inp,encoder_states,
                                                                    encoder_hidden,encoder_cell)
            outputs[idx] = output
            
            best_guess = output.argmax(1)
            
            next_inp = target[idx] if random.random() < teaching_force_ratio else best_guess
            
        return outputs

In [11]:
num_epochs = 25
learning_rate = 1e-3
batch_size = 64

In [12]:
load_model = False

encoder_inp_size = len(german.vocab)
decoder_inp_size = len(english.vocab)

output_size = len(english.vocab)

encoder_embedding_dim = 300
decoder_embedding_dim = 300

hidden_size = 1024
num_layers = 1
encoder_dropout = 0.5
decoder_dropout = 0.5

writer = SummaryWriter(f'runs/loss_plot')
step = 0 

In [13]:
train_iterator, val_iterator, test_iterator = BucketIterator.splits((train_data,val_data,test_data),
                                                                   batch_size = batch_size,
                                                                  sort_within_batch = True,
                                                                  sort_key = lambda x : len(x.src),
                                                                  device = device)

In [14]:
encoderNN = Encoder(encoder_inp_size,encoder_embedding_dim,hidden_size,num_layers,encoder_dropout).to(device)

In [15]:
decoderNN = Decoder(decoder_inp_size,decoder_embedding_dim,hidden_size,num_layers,decoder_dropout).to(device)

In [16]:
seq2seqModel = seq2seq(encoderNN,decoderNN).to(device)

In [17]:
optimizer = optim.Adam(seq2seqModel.parameters(),lr = learning_rate)

In [18]:
pad_idx = english.vocab.stoi['<pad>']

criterion = nn.CrossEntropyLoss(ignore_index  = pad_idx)

if load_model:
    load_checkpoint(torch.load('seq2seq_chkpt.pth.pt'),model,optimizer)
    

loss = 10000
sentence = "ein boot mit mehreren männern darauf wird von einem großen pferdegespann ans ufer gezogen."
real_translation = 'a boat with several men on it is pulled ashore by a large team of horses.'
for epoch in tqdm(range(num_epochs)):
    
    checkpoint = {'state_dict':seq2seqModel.state_dict(),'optimizer':optimizer.state_dict()}
    save_checkpoint(checkpoint)
    
    seq2seqModel.eval()
    translated_sentence = translate_sentence(seq2seqModel,sentence,german,english,device,max_length = 50)

    print(f'Translated example sentence {translated_sentence}')
    print(f'Real Translation {real_translation.split()}')
    
    seq2seqModel.train()
    
    
    
    for batch_idx,batch in tqdm(enumerate(train_iterator)):
        
        input_data = batch.src.to(device)
        target = batch.trg.to(device)
        
        output = seq2seqModel(input_data, target)
        
        #output shape : targetlen,batchsize,outputdim
        
        output = output[1:].reshape(-1,output.shape[2])
        target = target[1:].reshape(-1)
        
        
        optimizer.zero_grad()
        loss = criterion(output,target)
        
        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(seq2seqModel.parameters(),max_norm = 1)
        
        
        optimizer.step()
        
        writer.add_scalar('Training_loss',loss,global_step = step)
        
        step+=1
        
    if epoch % 2 == 0:
            print(loss.item())
        
    

HBox(children=(FloatProgress(value=0.0, max=25.0), HTML(value='')))

=> Saving checkpoint
Translated example sentence ['glancing', 'covering', 'winning', 'sweaters', 'finishes', 'vacant', 'waterproof', 'mom', 'paddling', 'khakis', 'off', 'care', 'attaching', 'pilot', 'attacking', 'across', 'yarn', 'routine', 'guitar', 'campfire', 'presents', 'tipped', 'manicured', 'bullet', 'hotel', 'sweatsuit', 'flag', 'him', 'flights', 'equipment', 'tv', 'tv', 'tv', 'tv', 'tv', 'tv', 'loaded', 'snorkel', 'tin', 'laundry', 'cyclists', 'magician', 'parade', 'gown', 'align', 'prepared', 'directly', 'photography', 'watery', 'tye']
Real Translation ['a', 'boat', 'with', 'several', 'men', 'on', 'it', 'is', 'pulled', 'ashore', 'by', 'a', 'large', 'team', 'of', 'horses.']


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


3.2756073474884033
=> Saving checkpoint
Translated example sentence ['a', 'toddler', 'player', 'wearing', 'a', 'black', 'shirt', 'is', 'to', 'a', 'a', 'of', 'a', '.', '<eos>']
Real Translation ['a', 'boat', 'with', 'several', 'men', 'on', 'it', 'is', 'pulled', 'ashore', 'by', 'a', 'large', 'team', 'of', 'horses.']


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


=> Saving checkpoint
Translated example sentence ['a', 'white', 'man', 'with', 'a', 'is', 'being', 'by', 'from', 'a', 'large', 'large', 'large', '.', '.', '<eos>']
Real Translation ['a', 'boat', 'with', 'several', 'men', 'on', 'it', 'is', 'pulled', 'ashore', 'by', 'a', 'large', 'team', 'of', 'horses.']


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


2.337797164916992
=> Saving checkpoint
Translated example sentence ['a', 'man', 'with', 'no', 'is', 'being', 'pulled', 'by', 'a', 'large', 'large', 'large', '.', '.', '<eos>']
Real Translation ['a', 'boat', 'with', 'several', 'men', 'on', 'it', 'is', 'pulled', 'ashore', 'by', 'a', 'large', 'team', 'of', 'horses.']


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


=> Saving checkpoint
Translated example sentence ['a', 'cowboy', 'with', 'no', 'pulled', 'pulled', 'by', 'a', 'large', 'large', 'large', 'large', '.', '.', '<eos>']
Real Translation ['a', 'boat', 'with', 'several', 'men', 'on', 'it', 'is', 'pulled', 'ashore', 'by', 'a', 'large', 'team', 'of', 'horses.']


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


3.077989101409912
=> Saving checkpoint
Translated example sentence ['a', 'shirtless', 'man', 'with', 'no', 'pulled', 'is', 'being', 'pulled', 'by', 'a', 'large', 'bull', '.', '<eos>']
Real Translation ['a', 'boat', 'with', 'several', 'men', 'on', 'it', 'is', 'pulled', 'ashore', 'by', 'a', 'large', 'team', 'of', 'horses.']


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


=> Saving checkpoint
Translated example sentence ['a', 'shirtless', 'with', 'a', 'cowboy', 'men', 'pulled', 'pulled', 'by', 'a', 'large', 'bull', 'by', 'a', 'large', 'bull', '.', '<eos>']
Real Translation ['a', 'boat', 'with', 'several', 'men', 'on', 'it', 'is', 'pulled', 'ashore', 'by', 'a', 'large', 'team', 'of', 'horses.']


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


2.273533344268799
=> Saving checkpoint
Translated example sentence ['a', 'cowboy', 'cowboy', 'with', 'a', 'pulled', 'pulled', 'pulled', 'by', 'a', 'large', 'bull', 'by', 'a', 'large', 'bull', '.', '<eos>']
Real Translation ['a', 'boat', 'with', 'several', 'men', 'on', 'it', 'is', 'pulled', 'ashore', 'by', 'a', 'large', 'team', 'of', 'horses.']


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


=> Saving checkpoint
Translated example sentence ['a', 'boat', 'with', 'pulled', 'pulled', 'pulled', 'pulled', 'pulled', 'by', 'a', 'large', 'large', 'bull', 'by', 'large', 'large', '.', '<eos>']
Real Translation ['a', 'boat', 'with', 'several', 'men', 'on', 'it', 'is', 'pulled', 'ashore', 'by', 'a', 'large', 'team', 'of', 'horses.']


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


1.3996968269348145
=> Saving checkpoint
Translated example sentence ['a', 'boat', 'with', 'several', 'men', 'being', 'pulled', 'by', 'a', 'large', 'by', 'large', 'bull', '.', '<eos>']
Real Translation ['a', 'boat', 'with', 'several', 'men', 'on', 'it', 'is', 'pulled', 'ashore', 'by', 'a', 'large', 'team', 'of', 'horses.']


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


=> Saving checkpoint
Translated example sentence ['a', 'cowboy', 'with', 'several', 'men', 'being', 'pulled', 'by', 'a', 'large', 'bull', 'by', 'by', 'horses', '.', '<eos>']
Real Translation ['a', 'boat', 'with', 'several', 'men', 'on', 'it', 'is', 'pulled', 'ashore', 'by', 'a', 'large', 'team', 'of', 'horses.']


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


1.8416991233825684
=> Saving checkpoint
Translated example sentence ['a', 'cowboy', 'cowboy', 'several', 'men', 'being', 'pulled', 'by', 'a', 'large', 'bull', 'by', 'horses', '.', '<eos>']
Real Translation ['a', 'boat', 'with', 'several', 'men', 'on', 'it', 'is', 'pulled', 'ashore', 'by', 'a', 'large', 'team', 'of', 'horses.']


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


=> Saving checkpoint
Translated example sentence ['a', 'cowboy', 'cowboy', 'with', 'several', 'men', 'being', 'pulled', 'by', 'by', 'a', 'large', 'cable', 'horses', 'horses', 'horses', '.', '<eos>']
Real Translation ['a', 'boat', 'with', 'several', 'men', 'on', 'it', 'is', 'pulled', 'ashore', 'by', 'a', 'large', 'team', 'of', 'horses.']


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


1.9940751791000366
=> Saving checkpoint
Translated example sentence ['a', 'boat', 'with', 'several', 'men', 'pulled', 'pulled', 'by', 'a', 'large', 'german', 'pulled', 'by', 'the', 'shore', '.', '<eos>']
Real Translation ['a', 'boat', 'with', 'several', 'men', 'on', 'it', 'is', 'pulled', 'ashore', 'by', 'a', 'large', 'team', 'of', 'horses.']


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


=> Saving checkpoint
Translated example sentence ['a', 'boat', 'with', 'several', 'men', 'pulled', 'pulled', 'by', 'a', 'large', 'by', 'a', 'large', 'brown', 'horses', '.', '<eos>']
Real Translation ['a', 'boat', 'with', 'several', 'men', 'on', 'it', 'is', 'pulled', 'ashore', 'by', 'a', 'large', 'team', 'of', 'horses.']


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


1.1212990283966064
=> Saving checkpoint
Translated example sentence ['a', 'boat', 'with', 'several', 'men', 'pulled', 'pulled', 'pulled', 'by', 'a', 'large', 'cable', 'horses', '.', '<eos>']
Real Translation ['a', 'boat', 'with', 'several', 'men', 'on', 'it', 'is', 'pulled', 'ashore', 'by', 'a', 'large', 'team', 'of', 'horses.']


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


=> Saving checkpoint
Translated example sentence ['a', 'boat', 'with', 'several', 'men', 'is', 'pulled', 'to', 'shore', 'by', 'a', 'large', 'donkey', '.', '<eos>']
Real Translation ['a', 'boat', 'with', 'several', 'men', 'on', 'it', 'is', 'pulled', 'ashore', 'by', 'a', 'large', 'team', 'of', 'horses.']


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


1.122049331665039
=> Saving checkpoint
Translated example sentence ['a', 'boat', 'with', 'several', 'men', 'is', 'pulled', 'pulled', 'by', 'shore', 'by', 'a', 'large', 'cable', 'horses', '.', '<eos>']
Real Translation ['a', 'boat', 'with', 'several', 'men', 'on', 'it', 'is', 'pulled', 'ashore', 'by', 'a', 'large', 'team', 'of', 'horses.']


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


=> Saving checkpoint
Translated example sentence ['a', 'cowboy', 'with', 'several', 'men', 'pulled', 'pulled', 'pulled', 'by', 'by', 'horses', 'by', 'a', 'large', 'wooden', 'horses', '.', '<eos>']
Real Translation ['a', 'boat', 'with', 'several', 'men', 'on', 'it', 'is', 'pulled', 'ashore', 'by', 'a', 'large', 'team', 'of', 'horses.']


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


0.8255928754806519
=> Saving checkpoint
Translated example sentence ['a', 'boat', 'carrying', 'several', 'men', 'is', 'pulled', 'pulled', 'pulled', 'by', 'by', 'a', 'large', 'wooden', 'horses', '.', '<eos>']
Real Translation ['a', 'boat', 'with', 'several', 'men', 'on', 'it', 'is', 'pulled', 'ashore', 'by', 'a', 'large', 'team', 'of', 'horses.']


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


=> Saving checkpoint
Translated example sentence ['a', 'boat', 'carrying', 'several', 'men', 'men', 'pulled', 'out', 'of', 'a', 'large', 'bull', 'by', 'horses', '.', '<eos>']
Real Translation ['a', 'boat', 'with', 'several', 'men', 'on', 'it', 'is', 'pulled', 'ashore', 'by', 'a', 'large', 'team', 'of', 'horses.']


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


2.612826108932495
=> Saving checkpoint
Translated example sentence ['a', 'boat', 'carrying', 'several', 'men', 'pulled', 'pulled', 'pulled', 'to', 'shore', 'by', 'a', 'large', 'horses', 'horses', 'horses', 'horses', '.', '<eos>']
Real Translation ['a', 'boat', 'with', 'several', 'men', 'on', 'it', 'is', 'pulled', 'ashore', 'by', 'a', 'large', 'team', 'of', 'horses.']


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


=> Saving checkpoint
Translated example sentence ['a', 'boat', 'carrying', 'several', 'men', 'is', 'pulled', 'to', 'shore', 'by', 'a', 'large', 'cable', 'horses', 'horses', '.', '<eos>']
Real Translation ['a', 'boat', 'with', 'several', 'men', 'on', 'it', 'is', 'pulled', 'ashore', 'by', 'a', 'large', 'team', 'of', 'horses.']


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


1.9882984161376953
=> Saving checkpoint
Translated example sentence ['a', 'boat', 'carrying', 'several', 'men', 'is', 'pulled', 'by', 'shore', 'by', 'a', 'large', 'brown', 'horses', 'horses', '.', '<eos>']
Real Translation ['a', 'boat', 'with', 'several', 'men', 'on', 'it', 'is', 'pulled', 'ashore', 'by', 'a', 'large', 'team', 'of', 'horses.']


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


=> Saving checkpoint
Translated example sentence ['a', 'boat', 'carrying', 'several', 'men', 'is', 'pulled', 'to', 'shore', 'by', 'a', 'large', 'team', 'of', 'horses', '.', '<eos>']
Real Translation ['a', 'boat', 'with', 'several', 'men', 'on', 'it', 'is', 'pulled', 'ashore', 'by', 'a', 'large', 'team', 'of', 'horses.']


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


0.511240541934967



In [19]:
score = bleu(test_data[1:100], seq2seqModel, german, english, device)
print(f"Bleu score {score * 100:.2f}")

Bleu score 20.82
