# Machine Translation + Attention

<img src="../figures/attention1.jpg" width="800">
<img src="../figures/attention2.jpg" width="800">
<img src="../figures/attention3.jpg" width="800">

In [1]:
import torch, torchdata, torchtext
from torch import nn
import torch.nn.functional as F

import random, math, time

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Make our work comparable if the kernel is restarted.
SEED = 1234
# Teacher forcing in the Seq2Seq model draws from Python's `random` module,
# so it has to be seeded alongside torch for runs to be repeatable.
random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

In [2]:
# torch.cuda.get_device_name(0)

In [3]:
torch.__version__

In [4]:
torchtext.__version__

## 1. ETL: Loading the dataset

In [5]:
from datasets import load_dataset

EN_LANGUAGE = 'en'
DE_LANGUAGE = 'de'
# Use "de-en" as dataset doesn't have en-de and treat English as source, German as target.
LANG_PAIR = f"{DE_LANGUAGE}-{EN_LANGUAGE}"
print("Translation Language Pair:", LANG_PAIR)

dataset = load_dataset("wmt14", LANG_PAIR)

In [6]:
# `load_dataset` returns a Hugging Face DatasetDict holding the train / validation / test splits
dataset

In [7]:
train = dataset['train']
valid = dataset['validation']
test = dataset['test']

## 2. EDA - simple investigation

In [8]:
#let's take a look at one example of train
sample = next(iter(dataset['train']))
# WMT14 returns {"translation": {"de": "...", "en": "..."}}
print("Sample structure:", sample)
print("English:", sample["translation"]["en"])
print("German:", sample["translation"]["de"])

WMT14 already ships with predefined train, validation and test splits, so we use them directly instead of calling `random_split`.

In [9]:
train = dataset["train"]
val = dataset["validation"]
test = dataset["test"]

In [10]:
# The dataset reports its own size; building a list of every sample just to
# count it would read the whole training split into memory.
train_size = len(train)
train_size

In [11]:
# Ask the validation split for its size instead of materialising every sample.
val_size = len(val)
val_size

In [12]:
# Ask the test split for its size instead of materialising every sample.
test_size = len(test)
test_size

## 3. Preprocessing 

### Tokenizing

**Note**: the models must first be downloaded using the following on the command line: 
```
python3 -m spacy download en_core_web_sm
python3 -m spacy download de_core_news_sm
```

uv command

```
uv add spacy
uv add pip

uv run python3 -m spacy download en_core_web_sm

uv run python3 -m spacy download de_core_news_sm


```

First, since we have two languages, let's create some constants to represent that.  Also, let's create two dicts: one for holding our tokenizers and one for holding all the vocabs with assigned numbers for each unique word

In [13]:
# Place-holders
token_transform = {}
vocab_transform = {}

In [14]:
from torchtext.data.utils import get_tokenizer
token_transform[EN_LANGUAGE] = get_tokenizer('spacy', language='en_core_web_sm')
token_transform[DE_LANGUAGE] = get_tokenizer('spacy', language='de_core_news_sm')

In [15]:
#example of tokenization of the english part
print("Sentence: ", sample["translation"]["en"])
print("Tokenization: ", token_transform[EN_LANGUAGE](sample["translation"]["en"]))

A function to tokenize our input.

In [16]:
# Faster tokenization using multiprocessing
from tqdm import tqdm
from multiprocessing import Pool, cpu_count
import spacy

# Load spaCy models directly (faster than get_tokenizer for batch processing)
spacy_en = spacy.load("en_core_web_sm", disable=["tagger", "parser", "ner", "lemmatizer"])
spacy_de = spacy.load("de_core_news_sm", disable=["tagger", "parser", "ner", "lemmatizer"])

def tokenize_batch_en(texts):
    """Tokenize English texts with `spacy_en.pipe`; returns one list of token strings per text."""
    token_lists = []
    for doc in spacy_en.pipe(texts, batch_size=1000):
        token_lists.append([token.text for token in doc])
    return token_lists

def tokenize_batch_de(texts):
    """Tokenize German texts with `spacy_de.pipe`; returns one list of token strings per text."""
    token_lists = []
    for doc in spacy_de.pipe(texts, batch_size=1000):
        token_lists.append([token.text for token in doc])
    return token_lists

def yield_tokens_fast(data, language, max_samples=None):
    """Yield one list of token strings per sentence of `language` in `data`.

    Args:
        data: iterable of samples shaped like {"translation": {<language>: text}}.
        language: key into the "translation" dict; "en" selects the English
            spaCy model, anything else selects the German one.
        max_samples: stop collecting after this many samples; a falsy value
            (None or 0) means no limit.

    The texts are gathered into a list first and then tokenized through the
    spaCy pipeline in batches.
    """
    # Stage 1: pull the raw sentences out of the dataset
    sentences = []
    progress = tqdm(data, desc=f"Collecting {language}", total=max_samples)
    for index, record in enumerate(progress):
        if max_samples and index >= max_samples:
            break
        sentences.append(record["translation"][language])

    # Stage 2: batch tokenize
    print(f"Batch tokenizing {len(sentences)} {language} sentences...")
    nlp = spacy_de if language != "en" else spacy_en
    docs = nlp.pipe(sentences, batch_size=1000, n_process=1)
    for doc in tqdm(docs, total=len(sentences), desc=f"Tokenizing {language}"):
        yield [tok.text for tok in doc]

Before we tokenize, let's define some special symbols so that our neural network can learn embeddings for them, namely the unknown token, the padding token, the start of sentence, and the end of sentence.

In [17]:
# Define special symbols and indices
UNK_IDX, PAD_IDX, SOS_IDX, EOS_IDX = 0, 1, 2, 3
# Make sure the tokens are in order of their indices to properly insert them in vocab
special_symbols = ['<unk>', '<pad>', '<sos>', '<eos>']

### Text to integers (Numericalization)

Next we create a vocabulary: a callable (torchtext calls these vocabs) that turns tokens into integers.  torchtext provides the factory function <code>build_vocab_from_iterator</code> for this; here we implement a small replacement, <code>build_vocab</code>, which likewise accepts an iterator that yields lists of tokens.

In [18]:
# torchtext.vocab replacement - Vocab class to mimic torchtext API
from collections import Counter

special_symbols = ['<unk>', '<pad>', '<sos>', '<eos>']
UNK_IDX, PAD_IDX, SOS_IDX, EOS_IDX = 0, 1, 2, 3

class Vocab:
    """Minimal token <-> index lookup table standing in for torchtext.vocab.

    Tokens that are not present in `stoi` resolve to `default_index`.
    """

    def __init__(self, stoi, itos, default_index=0):
        self.stoi = stoi                    # mapping: token string -> index
        self.itos = itos                    # list: index -> token string
        self.default_index = default_index  # index returned for unseen tokens

    def __getitem__(self, token):
        """Look up a single token, falling back to the default index."""
        return self.stoi.get(token, self.default_index)

    def __call__(self, tokens):
        """Numericalize a list of tokens into a list of indices."""
        fallback = self.default_index
        lookup = self.stoi
        return [lookup.get(tok, fallback) for tok in tokens]

    def __len__(self):
        """Number of entries in the index-to-string table."""
        return len(self.itos)

    def set_default_index(self, index):
        """Change the index returned for tokens missing from the vocabulary."""
        self.default_index = index

    def get_itos(self):
        """Return the index-to-string list."""
        return self.itos

def build_vocab(token_iterator, min_freq=2, specials=None, special_first=True):
    """Build a `Vocab` from an iterator that yields lists of tokens.

    Args:
        token_iterator: iterable of token sequences (one per sentence).
        min_freq: a token is kept only if it occurs at least this many times.
        specials: special symbols to include regardless of frequency.
        special_first: place the specials at the start of the index table
            when True, at the end when False.

    Returns:
        A `Vocab` whose regular tokens appear in first-seen order.
    """
    specials = [] if specials is None else specials

    # Tally how often every token occurs across all sentences
    frequencies = Counter()
    for sentence_tokens in token_iterator:
        frequencies.update(sentence_tokens)

    # Keep the frequent tokens; specials are placed separately below
    regular_tokens = [
        tok for tok, count in frequencies.items()
        if count >= min_freq and tok not in specials
    ]

    if special_first:
        itos = [*specials, *regular_tokens]
    else:
        itos = [*regular_tokens, *specials]

    # Invert the list to get the string -> index mapping
    stoi = {tok: position for position, tok in enumerate(itos)}

    return Vocab(stoi, itos)

In [19]:
# Build one vocabulary per language from a subset of the training data
# (a subset keeps vocabulary construction fast).
SRC_LANGUAGE = EN_LANGUAGE
TRG_LANGUAGE = DE_LANGUAGE

VOCAB_MAX_SAMPLES = 50_000  # 50k is usually enough, faster than 100k

for ln in [SRC_LANGUAGE, TRG_LANGUAGE]:
    # Tokenize the subset in batches and count tokens into a Vocab
    token_stream = yield_tokens_fast(train, ln, max_samples=VOCAB_MAX_SAMPLES)
    vocab_transform[ln] = build_vocab(token_stream,
                                      min_freq=2,
                                      specials=special_symbols,
                                      special_first=True)

# Unknown tokens map to UNK_IDX
for ln in [SRC_LANGUAGE, TRG_LANGUAGE]:
    vocab_transform[ln].set_default_index(UNK_IDX)

src_vocab_size = len(vocab_transform[SRC_LANGUAGE])
trg_vocab_size = len(vocab_transform[TRG_LANGUAGE])
print(f"English vocab size: {src_vocab_size}")
print(f"German vocab size: {trg_vocab_size}")

In [20]:
#see some example
vocab_transform[SRC_LANGUAGE](['here', 'is', 'a', 'unknownword', 'a'])

In [21]:
#we can reverse it....
mapping = vocab_transform[SRC_LANGUAGE].get_itos()

#print 1891, for example
mapping[1891]

In [22]:
#let's try unknown vocab
mapping[0]
#they will all map to <unk> which has 0 as integer

In [23]:
#let's try special symbols
mapping[1], mapping[2], mapping[3]

In [24]:
#check unique vocabularies
len(mapping)

## 4. Preparing the dataloader

One thing we change here is the <code>collate_fn</code>, which now also returns the length of each source sentence.  This is required for <code>pack_padded_sequence</code>.

In [25]:
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

BATCH_SIZE = 64

# helper function to club together sequential operations
def sequential_transforms(*transforms):
    """Chain callables left to right into one callable.

    The returned function feeds its argument to the first transform, passes
    each result on to the next one, and returns the final result. With no
    transforms it returns its input unchanged.
    """
    def func(txt_input):
        result = txt_input
        for step in transforms:
            result = step(result)
        return result
    return func

# function to add BOS/EOS and create tensor for input sequence indices
def tensor_transform(token_ids):
    """Wrap token indices with SOS/EOS markers and return them as a 1-D tensor.

    Args:
        token_ids: list of integer token indices (may be empty).

    Returns:
        A tensor of shape [len(token_ids) + 2] laid out as
        [SOS_IDX, *token_ids, EOS_IDX].
    """
    # The explicit dtype matters for empty input: torch.tensor([]) defaults to
    # float32, which would turn the whole concatenation into a float tensor
    # that an embedding layer cannot index with.
    return torch.cat((torch.tensor([SOS_IDX]),
                      torch.tensor(token_ids, dtype=torch.long),
                      torch.tensor([EOS_IDX])))

# src and trg language text transforms to convert raw strings into tensors indices
text_transform = {}
for ln in [SRC_LANGUAGE, TRG_LANGUAGE]:
    text_transform[ln] = sequential_transforms(token_transform[ln], #Tokenization
                                               vocab_transform[ln], #Numericalization
                                               tensor_transform) # Add BOS/EOS and create tensor


# function to collate data samples into batch tensors
# Updated for Hugging Face WMT14 dataset structure
def collate_batch(batch):
    """Turn a list of WMT14 samples into padded source/target index tensors.

    Each sample is expected to look like {"translation": {"de": ..., "en": ...}}.

    Returns:
        (src_batch, src_lengths, trg_batch) where the two batches are padded
        with PAD_IDX by `pad_sequence` (sequence dimension first) and
        src_lengths is an int64 tensor of unpadded source lengths, as needed
        for packing in the encoder.
    """
    src_tensors, src_lengths, trg_tensors = [], [], []
    for example in batch:
        pair = example["translation"]
        src_sentence = pair[SRC_LANGUAGE]
        trg_sentence = pair[TRG_LANGUAGE]

        # Strip trailing newlines, then tokenize -> numericalize -> add SOS/EOS
        src_tensor = text_transform[SRC_LANGUAGE](src_sentence.rstrip("\n"))
        src_tensors.append(src_tensor)
        trg_tensors.append(text_transform[TRG_LANGUAGE](trg_sentence.rstrip("\n")))
        src_lengths.append(src_tensor.size(0))

    padded_src = pad_sequence(src_tensors, padding_value=PAD_IDX)
    padded_trg = pad_sequence(trg_tensors, padding_value=PAD_IDX)
    lengths = torch.tensor(src_lengths, dtype=torch.int64)
    return padded_src, lengths, padded_trg

Create train, val, and test dataloaders

In [26]:
# Reuse the BATCH_SIZE constant from the collate cell so the batch size is
# configured in one place only.
batch_size = BATCH_SIZE

# Only the training split is shuffled; validation and test keep their order.
train_loader = DataLoader(train, batch_size=batch_size, shuffle=True,  collate_fn=collate_batch)
valid_loader = DataLoader(val,   batch_size=batch_size, shuffle=False, collate_fn=collate_batch)
test_loader  = DataLoader(test,  batch_size=batch_size, shuffle=False, collate_fn=collate_batch)

Let's test the train loader.

In [27]:
for en, _, de in train_loader:
    break

In [28]:
print("English shape: ", en.shape)  # (seq len, batch_size)
print("German shape: ", de.shape)   # (seq len, batch_size)

## 5. Design the model

### Seq2Seq

In [29]:
class Seq2SeqPackedAttention(nn.Module):
    """Encoder-decoder wrapper that decodes the target one step at a time.

    The encoder is run once over the (length-annotated) source batch; the
    decoder is then called once per target position with a padding mask so
    that attention can ignore padded source positions.
    """

    def __init__(self, encoder, decoder, src_pad_idx, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_pad_idx = src_pad_idx
        self.device  = device

    def create_mask(self, src):
        """Return a [batch_size, src len] boolean mask, True at padding positions."""
        #src: [src len, batch_size]
        is_padding = src == self.src_pad_idx
        #transpose so the mask lines up with the attention scores, e.g. (0, 0, 0, 0, 0, 1, 1)
        return is_padding.permute(1, 0)

    def forward(self, src, src_len, trg, teacher_forcing_ratio = 0.5):
        """Decode `trg` step by step.

        Args:
            src: [src len, batch_size] source indices.
            src_len: unpadded source lengths, forwarded to the encoder.
            trg: [trg len, batch_size] target indices.
            teacher_forcing_ratio: probability, per step, of feeding the
                ground-truth token instead of the model's own prediction.

        Returns:
            outputs: [trg len, batch_size, trg vocab size]; row 0 stays zero.
            attentions: [trg len, batch_size, src len]; row 0 stays zero.
        """
        n_src_steps, batch_size = src.shape[0], src.shape[1]
        n_trg_steps = trg.shape[0]
        trg_vocab_size = self.decoder.output_dim

        #buffers for the per-step predictions and attention weights
        outputs    = torch.zeros(n_trg_steps, batch_size, trg_vocab_size, device=self.device)
        attentions = torch.zeros(n_trg_steps, batch_size, n_src_steps, device=self.device)

        #encoder_outputs: all hidden states of the last layer
        #hidden: the encoder's summary state used to start the decoder
        encoder_outputs, hidden = self.encoder(src, src_len)

        #the first decoder input is the first row of the target (the <sos> tokens)
        decoder_input = trg[0]

        pad_mask = self.create_mask(src)

        for step in range(1, n_trg_steps):
            #logits: [batch_size, output_dim]
            #hidden: [batch_size, hid_dim]
            #step_attention: [batch_size, src len]
            logits, hidden, step_attention = self.decoder(decoder_input, hidden, encoder_outputs, pad_mask)

            outputs[step]    = logits
            attentions[step] = step_attention

            #decide whether to feed the ground truth or the model's own guess next
            use_teacher  = random.random() < teacher_forcing_ratio
            greedy_token = logits.argmax(1)

            decoder_input = trg[step] if use_teacher else greedy_token

        return outputs, attentions

### Encoder

In [30]:
class Encoder(nn.Module):
    """Bidirectional GRU encoder operating on packed (length-aware) sequences."""

    def __init__(self, input_dim, emb_dim, hid_dim, dropout):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, emb_dim)
        self.rnn       = nn.GRU(emb_dim, hid_dim, bidirectional=True)
        #projects the two concatenated directional states down to hid_dim
        self.fc        = nn.Linear(hid_dim * 2, hid_dim)
        self.dropout   = nn.Dropout(dropout)

    def forward(self, src, src_len):
        """Encode a padded source batch.

        Args:
            src: [src len, batch_size] source indices.
            src_len: unpadded length of every sequence in the batch.

        Returns:
            outputs: [src len, batch_size, hid dim * 2]
            hidden:  [batch_size, hid_dim]
        """
        embedded = self.dropout(self.embedding(src))

        #pack so the GRU skips padded positions; lengths must live on the CPU
        packed_input = nn.utils.rnn.pack_padded_sequence(
            embedded, src_len.to('cpu'), enforce_sorted=False
        )
        packed_outputs, final_states = self.rnn(packed_input)

        #back to a regular padded tensor
        outputs, _ = nn.utils.rnn.pad_packed_sequence(packed_outputs)

        #join the last two entries of the final states and squash them to hid_dim
        joined = torch.cat((final_states[-2], final_states[-1]), dim = 1)
        hidden = torch.tanh(self.fc(joined))

        return outputs, hidden
        

### Attention

The attention used here is additive attention which is defined by:

$$e = v\text{tanh}(Uh + Ws + b)$$

The `forward` method now takes a `mask` input. This is a `[batch size, source sentence length]` boolean tensor that is 1 (`True`) when the source sentence token is a padding token, and 0 (`False`) when it is not. For example, if the source sentence is `["hello", "how", "are", "you", "?", "<pad>", "<pad>"]`, then the mask would be `[0, 0, 0, 0, 0, 1, 1]`.

We apply the mask after the attention has been calculated, but before it has been normalized by the `softmax` function. It is applied using `masked_fill`. This fills the tensor at each element where the first argument (`mask`) is true, with the value given by the second argument (`-1e10`). In other words, it will take the un-normalized attention values, and change the attention values over padded elements to be `-1e10`. As these numbers will be minuscule compared to the other values they will become zero when passed through the `softmax` layer, ensuring no attention is paid to padding tokens in the source sentence.

In [31]:
class Attention(nn.Module):
    """Additive attention: score = v . tanh(W * decoder_hidden + U * encoder_output)."""

    def __init__(self, hid_dim):
        super().__init__()
        self.v = nn.Linear(hid_dim, 1, bias = False)
        self.W = nn.Linear(hid_dim, hid_dim)      #applied to the decoder hidden state
        self.U = nn.Linear(hid_dim * 2, hid_dim)  #applied to the encoder outputs

    def forward(self, hidden, encoder_outputs, mask):
        """Return attention weights over the source positions.

        Args:
            hidden: [batch_size, hid_dim] current decoder state.
            encoder_outputs: [src len, batch_size, hid_dim * 2].
            mask: [batch_size, src len], True where the source is padding.

        Returns:
            [batch_size, src len] weights; each row is a softmax over dim 1.
        """
        src_len = encoder_outputs.shape[0]

        #copy the decoder state once per source position
        #[batch_size, hid_dim] ==> [batch_size, src_len, hid_dim]
        query = hidden.unsqueeze(1).repeat(1, src_len, 1)

        #batch-first so it lines up with the query
        #[src len, batch_size, hid_dim * 2] ==> [batch_size, src_len, hid_dim * 2]
        keys = encoder_outputs.permute(1, 0, 2)

        combined = torch.tanh(self.W(query) + self.U(keys))
        #(batch_size, src len, 1) ==> (batch_size, src len)
        scores = self.v(combined).squeeze(2)

        #padded positions get a hugely negative score so softmax sends them to zero
        scores = scores.masked_fill(mask, -1e10)

        return F.softmax(scores, dim = 1)

### Decoder

In [32]:
class Decoder(nn.Module):
    """Single-step GRU decoder that attends over the encoder outputs."""

    def __init__(self, output_dim, emb_dim, hid_dim, dropout, attention):
        super().__init__()
        self.output_dim = output_dim
        self.attention  = attention
        self.embedding  = nn.Embedding(output_dim, emb_dim)
        #the GRU consumes [attention context ; embedded token]
        self.rnn        = nn.GRU((hid_dim * 2) + emb_dim, hid_dim)
        #the classifier sees the context, the GRU output and the embedded token
        self.fc         = nn.Linear((hid_dim * 2) + hid_dim + emb_dim, output_dim)
        self.dropout    = nn.Dropout(dropout)

    def forward(self, input, hidden, encoder_outputs, mask):
        """Predict the next token for every sequence in the batch.

        Args:
            input: [batch_size] previous target tokens.
            hidden: [batch_size, hid_dim] previous decoder state.
            encoder_outputs: [src len, batch_size, hid_dim * 2].
            mask: [batch_size, src len] padding mask passed to the attention.

        Returns:
            prediction: [batch_size, output_dim] scores over the target vocab.
            hidden: [batch_size, hid_dim] updated decoder state.
            attention weights: [batch_size, src len].
        """
        #embedded: [1, batch_size, emb_dim]
        embedded = self.dropout(self.embedding(input.unsqueeze(0)))

        #attn_weights: [batch_size, src len] ==> [batch_size, 1, src len]
        attn_weights = self.attention(hidden, encoder_outputs, mask).unsqueeze(1)

        #weighted sum of encoder outputs:
        #[batch_size, 1, src len] x [batch_size, src len, hid_dim * 2] ==> [batch_size, 1, hid_dim * 2]
        context = torch.bmm(attn_weights, encoder_outputs.permute(1, 0, 2))
        #context: [1, batch_size, hid_dim * 2]
        context = context.permute(1, 0, 2)

        #rnn_input: [1, batch_size, emb_dim + hid_dim * 2]
        rnn_input = torch.cat((embedded, context), dim = 2)
        output, hidden = self.rnn(rnn_input, hidden.unsqueeze(0))

        #drop the length-1 step dimension and classify on (embed, output, context)
        features = torch.cat(
            (embedded.squeeze(0), output.squeeze(0), context.squeeze(0)), dim = 1
        )
        #prediction: [batch_size, output_dim]
        prediction = self.fc(features)

        return prediction, hidden.squeeze(0), attn_weights.squeeze(1)

## 6. Training

We use a simplified version of the weight initialization scheme used in the paper. Here, we will initialize all biases to zero and draw all weights from a normal distribution with mean 0 and standard deviation 0.01.

In [33]:
def initialize_weights(m):
    """Re-initialize every parameter of module `m` in place.

    Parameters whose name contains 'weight' are drawn from a normal
    distribution (mean 0, std 0.01); every other parameter is set to 0.
    """
    for name, param in m.named_parameters():
        if 'weight' not in name:
            nn.init.constant_(param.data, 0)
            continue
        nn.init.normal_(param.data, mean=0, std=0.01)

In [34]:
# Vocabulary sizes fix the embedding table size (encoder) and the output layer size (decoder)
input_dim   = len(vocab_transform[SRC_LANGUAGE])
output_dim  = len(vocab_transform[TRG_LANGUAGE])
emb_dim     = 128 # alternative setting: 256
hid_dim     = 256 # alternative setting: 512
dropout     = 0.1
# The source padding index is used by the model to build the attention mask
SRC_PAD_IDX = PAD_IDX

# The attention module is shared with the decoder, which calls it at every step
attn = Attention(hid_dim)
enc  = Encoder(input_dim,  emb_dim,  hid_dim, dropout)
dec  = Decoder(output_dim, emb_dim,  hid_dim, dropout, attn)

model = Seq2SeqPackedAttention(enc, dec, SRC_PAD_IDX, device).to(device)
# Applies initialize_weights to every submodule; as the last expression it also displays the model
model.apply(initialize_weights)

In [35]:
#we can print the complexity by the number of parameters
def count_parameters(model):
    """Print the element count of each trainable parameter tensor, then the total."""
    sizes = []
    for parameter in model.parameters():
        if parameter.requires_grad:
            sizes.append(parameter.numel())

    for size in sizes:
        print(f'{size:>6}')

    total = sum(sizes)
    print(f'______\n{total:>6}')
    
count_parameters(model)

Our loss function calculates the average loss per token, however by passing the index of the `<pad>` token as the `ignore_index` argument we ignore the loss whenever the target token is a padding token. 

In [36]:
import torch.optim as optim

lr = 0.001

#training hyperparameters
optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss(ignore_index = PAD_IDX) #combine softmax with cross entropy

The training is very similar to part 1.

In [37]:
def train(model, loader, optimizer, criterion, clip, loader_length):
    """Run one training epoch and return the mean loss per batch.

    NOTE: defining this function rebinds the notebook-level name `train`,
    which earlier cells use for the training split of the dataset.

    Args:
        model: seq2seq model called as model(src, src_length, trg).
        loader: yields (src, src_length, trg) batches.
        optimizer: optimizer stepping the model parameters.
        criterion: loss taking 2-D predictions and 1-D targets.
        clip: maximum gradient norm.
        loader_length: number of batches, used to average the loss.
    """
    model.train()
    running_loss = 0

    for src, src_length, trg in loader:
        src, trg = src.to(device), trg.to(device)

        optimizer.zero_grad()

        #logits = [trg len, batch size, output dim]
        #trg    = [trg len, batch size]
        logits, _ = model(src, src_length, trg)
        vocab_size = logits.shape[-1]

        #skip position 0 (never predicted) and flatten, since the loss
        #expects 2d predictions with 1d targets
        flat_logits  = logits[1:].view(-1, vocab_size)  #[(trg len - 1) * batch size, output dim]
        flat_targets = trg[1:].view(-1)                 #[(trg len - 1) * batch size]

        loss = criterion(flat_logits, flat_targets)
        loss.backward()

        #clip the gradients to prevent them from exploding (a common issue in RNNs)
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        running_loss += loss.item()

    return running_loss / loader_length

Our evaluation loop is similar to our training loop, however as we aren't updating any parameters we don't need to pass an optimizer or a clip value.

In [None]:
def evaluate(model, loader, criterion, loader_length):
    """Compute the mean loss per batch over `loader` without updating the model.

    Teacher forcing is switched off (ratio 0), so the decoder always consumes
    its own predictions.
    """
    #eval mode turns off dropout (and batch norm if used)
    model.eval()
    running_loss = 0

    with torch.no_grad():
        for src, src_length, trg in loader:
            src, trg = src.to(device), trg.to(device)

            #logits = [trg len, batch size, output dim]
            #trg    = [trg len, batch size]
            logits, _ = model(src, src_length, trg, 0)
            vocab_size = logits.shape[-1]

            #skip position 0 and flatten for the loss
            flat_logits  = logits[1:].view(-1, vocab_size)  #[(trg len - 1) * batch size, output dim]
            flat_targets = trg[1:].view(-1)                 #[(trg len - 1) * batch size]

            running_loss += criterion(flat_logits, flat_targets).item()

    return running_loss / loader_length

### Putting everything together

In [None]:
# len() on a DataLoader reports the number of batches directly. Building a
# list of every batch would tokenize and collate each split once just to
# count it.
train_loader_length = len(train_loader)
val_loader_length   = len(valid_loader)
test_loader_length  = len(test_loader)

In [None]:
def epoch_time(start_time, end_time):
    """Split the elapsed time between two timestamps (in seconds) into
    whole minutes and the remaining whole seconds."""
    duration = end_time - start_time
    whole_minutes = int(duration / 60)
    leftover_seconds = int(duration - (whole_minutes * 60))
    return whole_minutes, leftover_seconds

In [None]:
import os

best_valid_loss = float('inf')
num_epochs = 10
clip       = 1

save_path = f'models/{model.__class__.__name__}.pt'
# torch.save does not create missing directories, so make sure the target
# folder exists before the first checkpoint is written.
os.makedirs(os.path.dirname(save_path), exist_ok=True)

# per-epoch losses, kept for plotting
train_losses = []
valid_losses = []

for epoch in range(num_epochs):
    
    start_time = time.time()

    train_loss = train(model, train_loader, optimizer, criterion, clip, train_loader_length)
    valid_loss = evaluate(model, valid_loader, criterion, val_loader_length)
    
    #for plotting
    train_losses.append(train_loss)
    valid_losses.append(valid_loss)
    
    end_time = time.time()
    
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    # checkpoint only when the validation loss improves
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), save_path)
    
    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')
    
    #lower perplexity is better

In [None]:
import matplotlib.pyplot as plt
fig = plt.figure(figsize=(5, 3))
ax = fig.add_subplot(1, 1, 1)
ax.plot(train_losses, label = 'train loss')
ax.plot(valid_losses, label = 'valid loss')
plt.legend()
# One loss value is recorded per epoch, so the x-axis counts epochs, not
# optimizer updates.
ax.set_xlabel('epoch')
ax.set_ylabel('loss')

In [None]:
# map_location lets a checkpoint written on a GPU machine be restored on
# whatever device this run is using.
model.load_state_dict(torch.load(save_path, map_location=device))
test_loss = evaluate(model, test_loader, criterion, test_loader_length)

print(f'| Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f} |')

In [None]:
model

## 7. Test on some random news

In [None]:
# `sample` is a dict of the form {"translation": {"de": ..., "en": ...}},
# so it cannot be indexed by position.
sample["translation"][SRC_LANGUAGE]

In [None]:
# The reference translation lives under the target-language key of the sample.
sample["translation"][TRG_LANGUAGE]

In [None]:
# Look the source sentence up by key (`sample` is a dict, not a tuple), then
# tokenize -> numericalize -> add SOS/EOS.
src_text = text_transform[SRC_LANGUAGE](sample["translation"][SRC_LANGUAGE]).to(device)
src_text

In [None]:
# Look the target sentence up by key (`sample` is a dict, not a tuple), then
# tokenize -> numericalize -> add SOS/EOS.
trg_text = text_transform[TRG_LANGUAGE](sample["translation"][TRG_LANGUAGE]).to(device)
trg_text

In [None]:
src_text = src_text.reshape(-1, 1)  #because batch_size is 1

In [None]:
trg_text = trg_text.reshape(-1, 1)

In [None]:
src_text.shape, trg_text.shape

In [None]:
text_length = torch.tensor([src_text.size(0)]).to(dtype=torch.int64)

In [None]:
# map_location lets a checkpoint written on a GPU machine be restored on
# whatever device this run is using.
model.load_state_dict(torch.load(save_path, map_location=device))

model.eval()
with torch.no_grad():
    output, attentions = model(src_text, text_length, trg_text, 0) #turn off teacher forcing

In [None]:
output.shape #trg_len, batch_size, trg_output_dim

Since batch size is 1, we just take off that dimension

In [None]:
output = output.squeeze(1)

In [None]:
output.shape

We shall remove the first token since it's zeroes anyway

In [None]:
output = output[1:]
output.shape #trg_len, trg_output_dim

Then we just take the top token with highest probabilities

In [None]:
output_max = output.argmax(1) #returns max indices

In [None]:
output_max

Get the mapping of the target language

In [None]:
mapping = vocab_transform[TRG_LANGUAGE].get_itos()

In [None]:
for token in output_max:
    print(mapping[token.item()])

## 8. Attention

Let's display the attentions to understand how the source text links with the generated text

In [None]:
attentions.shape

In [None]:
# Tokenize the source sentence of the sample (looked up by key, since `sample`
# is a dict) and add the same SOS/EOS markers the model saw.
src_tokens = ['<sos>'] + token_transform[SRC_LANGUAGE](sample["translation"][SRC_LANGUAGE]) + ['<eos>']
src_tokens

In [None]:
trg_tokens = ['<sos>'] + [mapping[token.item()] for token in output_max]
trg_tokens

In [None]:
import matplotlib.ticker as ticker

def display_attention(sentence, translation, attention):
    """Show attention weights as a heat map.

    Args:
        sentence: source tokens, used as column (x-axis) labels.
        translation: generated tokens, used as row (y-axis) labels.
        attention: tensor whose dimension 1 is squeezed away before plotting;
            presumably [trg len, 1, src len] for a batch of one -- verify
            against the caller.
    """
    fig = plt.figure(figsize=(10,10))
    ax = fig.add_subplot(111)

    attention = attention.squeeze(1).cpu().detach().numpy()
    n_rows, n_cols = attention.shape

    ax.matshow(attention, cmap='bone')

    ax.tick_params(labelsize=10)

    # Pin exactly one tick per cell and label it explicitly. Assigning labels
    # before fixing the tick positions relies on where the locator happens to
    # put its first tick and triggers a matplotlib warning.
    # Labels are padded/truncated so their count always matches the ticks.
    x_labels = (list(sentence) + [''] * n_cols)[:n_cols]
    y_labels = (list(translation) + [''] * n_rows)[:n_rows]

    ax.set_xticks(range(n_cols))
    ax.set_xticklabels(x_labels, rotation=45)
    ax.set_yticks(range(n_rows))
    ax.set_yticklabels(y_labels)

    plt.show()
    plt.close()

In [None]:
display_attention(src_tokens, trg_tokens, attentions)