<a href="https://colab.research.google.com/github/vjsurampudi/END/blob/main/Session10/NMT_Attention_EN2DE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

The image below shows a basic seq2seq architecture in which the encoder compresses the whole source sentence into a single context vector z. A decoder RNN then generates the translation from z.

![.](https://raw.githubusercontent.com/bentrevett/pytorch-seq2seq/master/assets/seq2seq1.png)

In the architecture above, the context vector z has to carry all of the information about the source sentence. The decoder hidden state has two tasks: predicting the next word and passing the combined context on to the next time step. To ease the load on z, the architecture below feeds the z vector to every decoder time step.

![.](https://github.com/bentrevett/pytorch-seq2seq/blob/master/assets/seq2seq7.png?raw=1)

Attention allows the decoder to focus on the most relevant parts of the encoder sequence at each decoding step. The code below implements an attention mechanism in which the attention vector is calculated over all of the encoder hidden states.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torchtext.datasets import Multi30k
from torchtext.data import Field, BucketIterator

import spacy
import numpy as np

import random
import math
import time

In [None]:
SEED = 1234  # fixed seed for reproducibility; stays in effect for as long as the kernel is alive

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

In [None]:
!python -m spacy download en

In [None]:
!python -m spacy download de

In [None]:
## Loading spacy for German and English
spacy_de = spacy.load('de')
spacy_en = spacy.load('en')

In [None]:
def tokenize_de(text):
    """
      Tokenizes German text from a string into a list of strings
    """
    return [tok.text for tok in spacy_de.tokenizer(text)]

def tokenize_en(text):
    """
      Tokenizes English text from a string into a list of strings
    """
    return [tok.text for tok in spacy_en.tokenizer(text)]

In [None]:
SRC = Field(tokenize = tokenize_en,
            init_token = '<sos>',
            eos_token = '<eos>',
            lower = True)

TRG = Field(tokenize = tokenize_de,
            init_token = '<sos>',
            eos_token = '<eos>',
            lower = True)

In [None]:
train_data, valid_data, test_data = Multi30k.splits(exts = ('.en', '.de'), fields = (SRC, TRG))

In [None]:
SRC.build_vocab(train_data, min_freq = 2)
TRG.build_vocab(train_data, min_freq = 2)

In [None]:
BATCH_SIZE = 128
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
      (train_data, valid_data, test_data), 
      batch_size = BATCH_SIZE, 
      device = device)
device

device(type='cuda')

![.](https://github.com/bentrevett/pytorch-seq2seq/blob/master/assets/seq2seq8.png?raw=1)

The encoder is a bidirectional GRU. The embedded source sequence is fed to a forward RNN that reads it left to right (green), and the same embeddings are also fed to a backward RNN that reads it right to left (teal). Each RNN receives its own initial hidden state, as shown by the separate arrows. By default PyTorch initializes these hidden states to zeros unless specified otherwise.

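Indexing each direction by its own processing step, the pairs of hidden vectors $h_1^\rightarrow$ and $h_4^\leftarrow$, $h_2^\rightarrow$ and $h_3^\leftarrow$, $h_3^\rightarrow$ and $h_2^\leftarrow$, and $h_4^\rightarrow$ and $h_1^\leftarrow$ each belong to the same source token, so each pair is concatenated into one encoder output of dimension (enc_hid_dim * 2). The final hidden states of the two directions are concatenated in the same way and passed through a fully connected layer with input dimension (enc_hid_dim * 2), which produces a single vector that is used as the initial decoder hidden state.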
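
To make the layout of a bidirectional GRU's return values concrete, here is a minimal sketch with toy sizes (the dimensions and variable names are illustrative assumptions, not the values used later in this notebook). It checks where the final forward and backward hidden states sit inside `outputs`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy sizes chosen only for illustration
src_len, batch_size, emb_dim, enc_hid_dim = 4, 2, 3, 5

rnn = nn.GRU(emb_dim, enc_hid_dim, bidirectional=True)
embedded = torch.randn(src_len, batch_size, emb_dim)

outputs, hidden = rnn(embedded)
# outputs = [src len, batch size, enc hid dim * 2]
# hidden  = [2, batch size, enc hid dim]  (forward final, backward final)

# The final forward state is the forward half of the LAST source position
forward_final_matches = torch.allclose(outputs[-1, :, :enc_hid_dim], hidden[-2])
# The final backward state is the backward half of the FIRST source position
backward_final_matches = torch.allclose(outputs[0, :, enc_hid_dim:], hidden[-1])

# Concatenating both final states gives the input expected by the encoder's fc layer
fc_input = torch.cat((hidden[-2], hidden[-1]), dim=1)

print(outputs.shape, hidden.shape, fc_input.shape)
print(forward_final_matches, backward_final_matches)
```

Both checks should print `True`: the backward RNN finishes at the first token, so its final state lines up with position 0 of `outputs`, not the last position.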

In [None]:
class Encoder(nn.Module):
  def __init__(self, input_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout):
    super().__init__()
    self.embedding = nn.Embedding(input_dim, emb_dim)
    self.rnn = nn.GRU(emb_dim, enc_hid_dim, bidirectional = True)
    self.fc = nn.Linear(enc_hid_dim * 2, dec_hid_dim)
    self.dropout = nn.Dropout(dropout)

  def forward(self, src):
    # src = [src len, batch size]
    embedding = self.dropout(self.embedding(src))
    # embedding = [src len, batch size, emb dim]
    outputs, hidden = self.rnn(embedding)
    # outputs = [src len, batch size, hid dim * number of directions]
    # hidden = [n layers * number of directions, batch size, hidden dimensions]
    # hidden is stacked [forward_1, backward_1, forward_2, backward_2, ...]
    # outputs are always from the last layer
        
    # hidden [-2, :, : ] is the last of the forwards RNN 
    # hidden [-1, :, : ] is the last of the backwards RNN
        
    #initial decoder hidden is final hidden state of the forwards and backwards
    #  encoder RNNs fed through a linear layer
    hidden = torch.tanh(self.fc(torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim = 1)))
    # outputs = [src len, batch size, enc hid dim * 2]
    # hidden = [batch size, dec hid dim]
    return outputs, hidden

In [None]:
enc = Encoder(input_dim = 32,
              emb_dim = 256, 
              enc_hid_dim = 512, 
              dec_hid_dim = 512, 
              dropout = 0.1)

In [None]:
## Sample output of the forward pass with a sentence length of 12 and a batch size of 32
outputs, hidden = enc(torch.zeros(12, 32, dtype=torch.int64))
outputs.size(), hidden.size()

(torch.Size([12, 32, 1024]), torch.Size([32, 512]))

## Attention
Next up is the attention layer. This will take in the previous hidden state of the decoder, $s_{t-1}$, and all of the stacked forward and backward hidden states from the encoder, $H$. The layer will output an attention vector, $a_t$, that is the length of the source sentence; each element is between 0 and 1 and the entire vector sums to 1.

Intuitively, this layer takes what we have decoded so far, $s_{t-1}$, and all of what we have encoded, $H$, to produce a vector, $a_t$, that represents which words in the source sentence we should pay the most attention to in order to correctly predict the next word to decode, $\hat{y}_{t+1}$.

First, we calculate the energy between the previous decoder hidden state and the encoder hidden states. As our encoder hidden states are a sequence of $T$ tensors, and our previous decoder hidden state is a single tensor, the first thing we do is repeat the previous decoder hidden state $T$ times. We then calculate the energy, $E_t$, between them by concatenating them together and passing them through a linear layer (`attn`) and a $\tanh$ activation function.

$$E_t = \tanh(\text{attn}(s_{t-1}, H))$$

This can be thought of as calculating how well each encoder hidden state "matches" the previous decoder hidden state.

We currently have a [dec hid dim, src len] tensor for each example in the batch. We want this to be [src len] for each example in the batch as the attention should be over the length of the source sentence. This is achieved by multiplying the energy by a [1, dec hid dim] tensor, $v$.

$$\hat{a}_t = v E_t$$

We can think of $v$ as the weights for a weighted sum of the energy across all encoder hidden states. These weights tell us how much we should attend to each token in the source sequence. The parameters of $v$ are initialized randomly, but learned with the rest of the model via backpropagation. Note how $v$ is not dependent on time, and the same $v$ is used for each time-step of the decoding. We implement $v$ as a linear layer without a bias.

Finally, we ensure the attention vector fits the constraints of having all elements between 0 and 1 and the vector summing to 1 by passing it through a $\text{softmax}$ layer.

$$a_t = \text{softmax}(\hat{a}_t)$$

This gives us the attention over the source sentence!

Graphically, this looks something like below. This is for calculating the very first attention vector, where $s_{t-1} = s_0 = z$. The green/teal blocks represent the hidden states from both the forward and backward RNNs, and the attention computation is all done within the pink block.
![.](https://github.com/bentrevett/pytorch-seq2seq/blob/master/assets/seq2seq9.png?raw=1)
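
The three steps described in this section (energy, projection with $v$, softmax) can be sketched outside of a module with toy tensors. The sizes and the names `s_prev`, `H`, `s_rep` and `scores` are assumptions made for this illustration only.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Toy sizes chosen only for illustration
batch_size, src_len, enc_hid_dim, dec_hid_dim = 2, 5, 4, 3

attn = nn.Linear(enc_hid_dim * 2 + dec_hid_dim, dec_hid_dim)
v = nn.Linear(dec_hid_dim, 1, bias=False)

s_prev = torch.randn(batch_size, dec_hid_dim)           # s_{t-1}
H = torch.randn(batch_size, src_len, enc_hid_dim * 2)   # encoder states, batch first

# Step 1: repeat s_{t-1} once per source position and compute the energy
s_rep = s_prev.unsqueeze(1).repeat(1, src_len, 1)          # [batch, src len, dec hid dim]
energy = torch.tanh(attn(torch.cat((s_rep, H), dim=2)))    # [batch, src len, dec hid dim]

# Step 2: project each energy vector to a single score with v
scores = v(energy).squeeze(2)                              # [batch, src len]

# Step 3: normalize the scores over the source positions
a = F.softmax(scores, dim=1)                               # [batch, src len]

print(a.shape)
print(a.sum(dim=1))  # each row sums to 1 (up to floating point error)
```

The softmax is taken over `dim=1`, the source-length dimension, so each example in the batch gets its own distribution over source tokens.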

In [None]:
class Attention(nn.Module):
    def __init__(self, enc_hid_dim, dec_hid_dim):
        super().__init__()
        
        self.attn = nn.Linear((enc_hid_dim * 2) + dec_hid_dim, dec_hid_dim)
        self.v = nn.Linear(dec_hid_dim, 1, bias = False)
        
    def forward(self, hidden, encoder_outputs):
        
        #hidden = [batch size, dec hid dim]
        #encoder_outputs = [src len, batch size, enc hid dim * 2]
        
        batch_size = encoder_outputs.shape[1]
        src_len = encoder_outputs.shape[0]
        
        #repeat decoder hidden state src_len times
        hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)
        
        encoder_outputs = encoder_outputs.permute(1, 0, 2)
        
        #hidden = [batch size, src len, dec hid dim]
        #encoder_outputs = [batch size, src len, enc hid dim * 2]
        
        energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim = 2))) 
        
        #energy = [batch size, src len, dec hid dim]

        attention = self.v(energy).squeeze(2)
        
        #attention= [batch size, src len]
        
        return F.softmax(attention, dim=1)



## Decoder

The decoder contains the attention layer, `attention`, which takes the previous hidden state, $s_{t-1}$, and all of the encoder hidden states, $H$, and returns the attention vector, $a_t$.

We then use this attention vector to create a weighted source vector, $w_t$, denoted by `weighted`, which is a weighted sum of the encoder hidden states, $H$, using $a_t$ as the weights.

$$w_t = a_t H$$
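
As a small sketch of this weighted sum, the batched matrix product used for it can be compared against an explicit multiply-and-sum. The toy sizes and the names `explicit` and `same_result` are assumptions for illustration.

```python
import torch

torch.manual_seed(0)

# Toy sizes chosen only for illustration
batch_size, src_len, enc_hid_dim = 2, 3, 4

a = torch.softmax(torch.randn(batch_size, src_len), dim=1)  # attention weights
H = torch.randn(batch_size, src_len, enc_hid_dim * 2)       # encoder states, batch first

# Batched matrix product: [batch, 1, src len] x [batch, src len, enc hid dim * 2]
weighted = torch.bmm(a.unsqueeze(1), H).squeeze(1)          # [batch, enc hid dim * 2]

# The same quantity written as an explicit weighted sum over source positions
explicit = (a.unsqueeze(2) * H).sum(dim=1)

same_result = torch.allclose(weighted, explicit, atol=1e-6)
print(weighted.shape, same_result)
```

`torch.bmm` performs one matrix product per batch element, which is why both operands are put in batch-first order before the call.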

The embedded input word, $d(y_t)$, the weighted source vector, $w_t$, and the previous decoder hidden state, $s_{t-1}$, are then all passed into the decoder RNN, with $d(y_t)$ and $w_t$ being concatenated together.

$$s_t = \text{DecoderGRU}(d(y_t), w_t, s_{t-1})$$

We then pass $d(y_t)$, $w_t$ and $s_t$ through the linear layer, $f$, to make a prediction of the next word in the target sentence, $\hat{y}_{t+1}$. This is done by concatenating them all together.

$$\hat{y}_{t+1} = f(d(y_t), w_t, s_t)$$

The image below shows decoding the first word in an example translation.



The green/teal blocks show the forward/backward encoder RNNs which output $H$, the red block shows the context vector, $z = h_T = \tanh(g(h_T^\rightarrow, h_T^\leftarrow)) = \tanh(g(z^\rightarrow, z^\leftarrow)) = s_0$, the blue block shows the decoder RNN which outputs $s_t$, the purple block shows the linear layer, $f$, which outputs $\hat{y}_{t+1}$, and the orange block shows the calculation of the weighted sum over $H$ by $a_t$ and outputs $w_t$. Not shown is the calculation of $a_t$.

![.](https://github.com/bentrevett/pytorch-seq2seq/blob/master/assets/seq2seq10.png?raw=1)
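
A single decoding step can be sketched with stand-in tensors to see the shapes involved. This is only an illustration under assumed toy sizes: `embedded`, `weighted` and `s_prev` are random placeholders for $d(y_t)$, $w_t$ and $s_{t-1}$, not values produced by a trained model.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy sizes chosen only for illustration
batch_size, emb_dim, enc_hid_dim, dec_hid_dim, output_dim = 2, 3, 4, 5, 7

rnn = nn.GRU(enc_hid_dim * 2 + emb_dim, dec_hid_dim)
fc_out = nn.Linear(enc_hid_dim * 2 + dec_hid_dim + emb_dim, output_dim)

embedded = torch.randn(1, batch_size, emb_dim)           # stand-in for d(y_t)
weighted = torch.randn(1, batch_size, enc_hid_dim * 2)   # stand-in for w_t
s_prev = torch.randn(batch_size, dec_hid_dim)            # stand-in for s_{t-1}

# s_t = DecoderGRU(d(y_t), w_t, s_{t-1}), with d(y_t) and w_t concatenated
rnn_input = torch.cat((embedded, weighted), dim=2)
output, hidden = rnn(rnn_input, s_prev.unsqueeze(0))

# One time step, one layer, one direction: output and hidden hold the same values
output_equals_hidden = torch.allclose(output, hidden)

# y_hat_{t+1} = f(d(y_t), w_t, s_t)
prediction = fc_out(
    torch.cat((output.squeeze(0), weighted.squeeze(0), embedded.squeeze(0)), dim=1)
)

print(output_equals_hidden, prediction.shape)
```

The input sizes of `rnn` and `fc_out` follow directly from what is concatenated at each stage, which is a quick way to sanity-check the layer definitions.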

In [None]:
class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, enc_hid_dim, dec_hid_dim, dropout, attention):
        super().__init__()

        self.output_dim = output_dim
        self.attention = attention
        
        self.embedding = nn.Embedding(output_dim, emb_dim)
        
        self.rnn = nn.GRU((enc_hid_dim * 2) + emb_dim, dec_hid_dim)
        
        self.fc_out = nn.Linear((enc_hid_dim * 2) + dec_hid_dim + emb_dim, output_dim)
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, input, hidden, encoder_outputs):
             
        #input = [batch size]
        #hidden = [batch size, dec hid dim]
        #encoder_outputs = [src len, batch size, enc hid dim * 2]
        
        input = input.unsqueeze(0)
        
        #input = [1, batch size]
        
        embedded = self.dropout(self.embedding(input))
        
        #embedded = [1, batch size, emb dim]
        
        a = self.attention(hidden, encoder_outputs)
                
        #a = [batch size, src len]
        
        a = a.unsqueeze(1)
        
        #a = [batch size, 1, src len]
        
        encoder_outputs = encoder_outputs.permute(1, 0, 2)
        
        #encoder_outputs = [batch size, src len, enc hid dim * 2]
        
        weighted = torch.bmm(a, encoder_outputs)
        
        #weighted = [batch size, 1, enc hid dim * 2]
        
        weighted = weighted.permute(1, 0, 2)
        
        #weighted = [1, batch size, enc hid dim * 2]
        
        rnn_input = torch.cat((embedded, weighted), dim = 2)
        
        #rnn_input = [1, batch size, (enc hid dim * 2) + emb dim]
            
        output, hidden = self.rnn(rnn_input, hidden.unsqueeze(0))
        
        #output = [seq len, batch size, dec hid dim * n directions]
        #hidden = [n layers * n directions, batch size, dec hid dim]
        
        #seq len, n layers and n directions will always be 1 in this decoder, therefore:
        #output = [1, batch size, dec hid dim]
        #hidden = [1, batch size, dec hid dim]
        #this also means that output == hidden
                
        embedded = embedded.squeeze(0)
        output = output.squeeze(0)
        weighted = weighted.squeeze(0)
        
        prediction = self.fc_out(torch.cat((output, weighted, embedded), dim = 1))
        
        #prediction = [batch size, output dim]
        
        return prediction, hidden.squeeze(0)

In [None]:
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super().__init__()
        
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        
    def forward(self, src, trg, teacher_forcing_ratio = 0.5):
        
        #src = [src len, batch size]
        #trg = [trg len, batch size]
        #teacher_forcing_ratio is probability to use teacher forcing
        #e.g. if teacher_forcing_ratio is 0.75 we use teacher forcing 75% of the time
        
        batch_size = src.shape[1]
        trg_len = trg.shape[0]
        trg_vocab_size = self.decoder.output_dim
        
        #tensor to store decoder outputs
        outputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(self.device)
        
        #encoder_outputs is all hidden states of the input sequence, back and forwards
        #hidden is the final forward and backward hidden states, passed through a linear layer
        encoder_outputs, hidden = self.encoder(src)
                
        #first input to the decoder is the <sos> tokens
        input = trg[0,:]
        
        for t in range(1, trg_len):
            
            #insert input token embedding, previous hidden state and all encoder hidden states
            #receive output tensor (predictions) and new hidden state
            output, hidden = self.decoder(input, hidden, encoder_outputs)
            
            #place predictions in a tensor holding predictions for each token
            outputs[t] = output
            
            #decide if we are going to use teacher forcing or not
            teacher_force = random.random() < teacher_forcing_ratio
            
            #get the highest predicted token from our predictions
            top1 = output.argmax(1) 
            
            #if teacher forcing, use actual next token as next input
            #if not, use predicted token
            input = trg[t] if teacher_force else top1

        return outputs

In [None]:
INPUT_DIM = len(SRC.vocab)
OUTPUT_DIM = len(TRG.vocab)
ENC_EMB_DIM = 256
DEC_EMB_DIM = 256
ENC_HID_DIM = 512
DEC_HID_DIM = 512
ENC_DROPOUT = 0.5
DEC_DROPOUT = 0.5

attn = Attention(ENC_HID_DIM, DEC_HID_DIM)
enc = Encoder(INPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, ENC_DROPOUT)
dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, ENC_HID_DIM, DEC_HID_DIM, DEC_DROPOUT, attn)

model = Seq2Seq(enc, dec, device).to(device)

In [None]:
def init_weights(m):
    for name, param in m.named_parameters():
        if 'weight' in name:
            nn.init.normal_(param.data, mean=0, std=0.01)
        else:
            nn.init.constant_(param.data, 0)
            
model.apply(init_weights)

Seq2Seq(
  (encoder): Encoder(
    (embedding): Embedding(5893, 256)
    (rnn): GRU(256, 512, bidirectional=True)
    (fc): Linear(in_features=1024, out_features=512, bias=True)
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (decoder): Decoder(
    (attention): Attention(
      (attn): Linear(in_features=1536, out_features=512, bias=True)
      (v): Linear(in_features=512, out_features=1, bias=False)
    )
    (embedding): Embedding(7855, 256)
    (rnn): GRU(1280, 512)
    (fc_out): Linear(in_features=1792, out_features=7855, bias=True)
    (dropout): Dropout(p=0.5, inplace=False)
  )
)

In [None]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 24,036,783 trainable parameters


In [None]:
optimizer = optim.Adam(model.parameters())

In [None]:
TRG_PAD_IDX = TRG.vocab.stoi[TRG.pad_token]

criterion = nn.CrossEntropyLoss(ignore_index = TRG_PAD_IDX)
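
The effect of `ignore_index` can be checked on a toy example: with the default mean reduction, the loss is averaged only over the targets that are not the ignored index. The padding index, vocabulary size and tensors below are hypothetical values chosen for this sketch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

PAD_IDX = 1      # hypothetical padding index for this sketch
vocab_size = 6   # hypothetical vocabulary size

logits = torch.randn(5, vocab_size)
targets = torch.tensor([2, 4, PAD_IDX, 3, PAD_IDX])

# Loss that skips padding positions
criterion_with_ignore = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
loss_ignoring_pad = criterion_with_ignore(logits, targets)

# Equivalent: drop the padding positions by hand, then use a plain loss
keep = targets != PAD_IDX
loss_real_tokens = nn.CrossEntropyLoss()(logits[keep], targets[keep])

print(torch.allclose(loss_ignoring_pad, loss_real_tokens))
```

This is why padded positions in a batch do not dilute the reported loss: only real target tokens contribute to the average.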

In [None]:
def train(model, iterator, optimizer, criterion, clip):
    
    model.train()
    
    epoch_loss = 0
    
    for i, batch in enumerate(iterator):
        
        src = batch.src
        trg = batch.trg
        
        optimizer.zero_grad()
        
        output = model(src, trg)
        
        #trg = [trg len, batch size]
        #output = [trg len, batch size, output dim]
        
        output_dim = output.shape[-1]
        
        output = output[1:].view(-1, output_dim)
        trg = trg[1:].view(-1)
        
        #trg = [(trg len - 1) * batch size]
        #output = [(trg len - 1) * batch size, output dim]
        
        loss = criterion(output, trg)
        
        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        
        optimizer.step()
        
        epoch_loss += loss.item()
        
    return epoch_loss / len(iterator)

In [None]:
def evaluate(model, iterator, criterion):
    
    model.eval()
    
    epoch_loss = 0
    
    with torch.no_grad():
    
        for i, batch in enumerate(iterator):

            src = batch.src
            trg = batch.trg

            output = model(src, trg, 0) #turn off teacher forcing

            #trg = [trg len, batch size]
            #output = [trg len, batch size, output dim]

            output_dim = output.shape[-1]
            
            output = output[1:].view(-1, output_dim)
            trg = trg[1:].view(-1)

            #trg = [(trg len - 1) * batch size]
            #output = [(trg len - 1) * batch size, output dim]

            loss = criterion(output, trg)

            epoch_loss += loss.item()
        
    return epoch_loss / len(iterator)

In [None]:
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [None]:
N_EPOCHS = 20
CLIP = 1

best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):
    
    start_time = time.time()
    
    train_loss = train(model, train_iterator, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, valid_iterator, criterion)
    
    end_time = time.time()
    
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'en2de-attn-model.pt')
    
    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')

Epoch: 01 | Time: 1m 32s
	Train Loss: 3.393 | Train PPL:  29.746
	 Val. Loss: 3.588 |  Val. PPL:  36.147
Epoch: 02 | Time: 1m 33s
	Train Loss: 2.819 | Train PPL:  16.767
	 Val. Loss: 3.238 |  Val. PPL:  25.490
Epoch: 03 | Time: 1m 32s
	Train Loss: 2.404 | Train PPL:  11.065
	 Val. Loss: 3.114 |  Val. PPL:  22.513
Epoch: 04 | Time: 1m 33s
	Train Loss: 2.057 | Train PPL:   7.824
	 Val. Loss: 3.087 |  Val. PPL:  21.907
Epoch: 05 | Time: 1m 33s
	Train Loss: 1.797 | Train PPL:   6.033
	 Val. Loss: 3.095 |  Val. PPL:  22.091
Epoch: 06 | Time: 1m 34s
	Train Loss: 1.591 | Train PPL:   4.908
	 Val. Loss: 3.171 |  Val. PPL:  23.833
Epoch: 07 | Time: 1m 33s
	Train Loss: 1.443 | Train PPL:   4.232
	 Val. Loss: 3.136 |  Val. PPL:  23.021
Epoch: 08 | Time: 1m 32s
	Train Loss: 1.326 | Train PPL:   3.766
	 Val. Loss: 3.222 |  Val. PPL:  25.071
Epoch: 09 | Time: 1m 33s
	Train Loss: 1.212 | Train PPL:   3.361
	 Val. Loss: 3.227 |  Val. PPL:  25.193
Epoch: 10 | Time: 1m 32s
	Train Loss: 1.115 | Train PPL

In [None]:
model.load_state_dict(torch.load('en2de-attn-model.pt'))

test_loss = evaluate(model, test_iterator, criterion)

print(f'| Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f} |')

| Test Loss: 3.057 | Test PPL:  21.262 |
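
The perplexity (PPL) values printed above are the exponential of the mean cross-entropy loss. As a small sketch, the helper below (`perplexity` is a name introduced here for illustration) recomputes the test perplexity from the rounded test loss; the result differs slightly from the printed value because the loss shown is rounded to three decimals.

```python
import math


def perplexity(mean_cross_entropy: float) -> float:
    """Convert a mean cross-entropy loss (in nats) to perplexity."""
    return math.exp(mean_cross_entropy)


test_ppl = perplexity(3.057)
print(f"{test_ppl:.2f}")  # close to the 21.262 reported above
```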
