# [AI 504] Programming for AI, Fall 2021
# Practice 10: Transformers
----- 

#### [Notifications]
- If you have any questions, feel free to ask
- For additional questions, send an email to yeonsu.k@kaist.ac.kr
      

     
     
# Table of contents
1. [Prepare input](#1)
2. [Implement Transformer](#2)
3. [Train and Evaluate](#3)
4. [Visualize attention](#4)


# Prepare essential packages

In [None]:
%matplotlib inline
!pip install torchtext==0.10.0
!git clone https://github.com/sjpark9503/attentionviz.git
!python -m spacy download de_core_news_sm
!python -m spacy download en_core_web_sm

# I. Prepare input
<a id='1'></a>

In [None]:
!git clone --recursive https://github.com/multi30k/dataset.git multi30k-datase

In [None]:
!find multi30k-datase/ -name '*.gz' -exec gunzip {} \;

We already covered how to preprocess text data in weeks 8, 9, and 10.

For a more detailed explanation of translation datasets, see [torchtext](https://pytorch.org/text/), the [practice session, week 9](https://classum.com/main/course/7726/103), and the [PyTorch NMT tutorial](https://pytorch.org/tutorials/beginner/torchtext_translation_tutorial.html).

In [None]:
import spacy
import torch
from torch.utils.data import Dataset, DataLoader
from collections import Counter
import os
import io
from torch.nn.utils.rnn import pad_sequence

spacy_de = spacy.load('de_core_news_sm')
spacy_en = spacy.load('en_core_web_sm')

def tokenize_de(text):
    return [tok.text.lower() for tok in spacy_de.tokenizer(text)]

def tokenize_en(text):
    return [tok.text.lower() for tok in spacy_en.tokenizer(text)]

class TranslationDataset(Dataset):
    def __init__(self, root_dir, split):
        self.root_dir = root_dir
        self.split = split

        self.data_files = {
            'train': ('train.de', 'train.en'),
            'valid': ('val.de', 'val.en'),
            'test': ('test_2016_flickr.de', 'test_2016_flickr.en')
        }

        self.de_file_path = os.path.join(self.root_dir, self.data_files[self.split][0])
        self.en_file_path = os.path.join(self.root_dir, self.data_files[self.split][1])

        with io.open(self.de_file_path, mode='r', encoding='utf-8') as de_file, \
             io.open(self.en_file_path, mode='r', encoding='utf-8') as en_file:
            self.de_sentences = de_file.readlines()
            self.en_sentences = en_file.readlines()

    def __len__(self):
        return len(self.de_sentences)

    def __getitem__(self, idx):
        de_sentence = tokenize_de(self.de_sentences[idx].strip())
        en_sentence = tokenize_en(self.en_sentences[idx].strip())
        return {'SRC': de_sentence, 'TRG': en_sentence}

class Vocab:
    def __init__(self, counter, min_freq):
        self.itos = ['<pad>', '<sos>', '<eos>', '<unk>']
        self.stoi = {token: i for i, token in enumerate(self.itos)}
        self.min_freq = min_freq
        self.build_vocab(counter)

    def build_vocab(self, counter):
        for word, freq in counter.items():
            if freq >= self.min_freq and word not in self.stoi:
                self.stoi[word] = len(self.itos)
                self.itos.append(word)

    def numericalize(self, tokens):
        return [self.stoi.get(token, self.stoi['<unk>']) for token in tokens]

def build_counter(dataset):
    counter = Counter()
    for i in range(len(dataset)):
        example = dataset[i]
        counter.update(example['SRC'])
        counter.update(example['TRG'])
    return counter

train_data = TranslationDataset(root_dir='/content/multi30k-datase/data/task1/raw', split='train')
valid_data = TranslationDataset(root_dir='/content/multi30k-datase/data/task1/raw', split='valid')
test_data = TranslationDataset(root_dir='/content/multi30k-datase/data/task1/raw', split='test')

counter = build_counter(train_data)
SRC_vocab = Vocab(counter, min_freq=2)
TRG_vocab = Vocab(counter, min_freq=2)

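def collate_fn(batch):
    # Wrap each sentence in <sos> ... <eos> so that the shifted targets used during
    # training (output[:, :-1] vs. trg[:, 1:]) cover every token, including the end marker.
    def encode(tokens, vocab):
        ids = [vocab.stoi['<sos>']] + vocab.numericalize(tokens) + [vocab.stoi['<eos>']]
        return torch.tensor(ids)

    src_batch = [encode(item['SRC'], SRC_vocab) for item in batch]
    trg_batch = [encode(item['TRG'], TRG_vocab) for item in batch]

    # Pad every sequence in the batch to the length of the longest one
    src_batch_padded = pad_sequence(src_batch, padding_value=SRC_vocab.stoi['<pad>'], batch_first=True)
    trg_batch_padded = pad_sequence(trg_batch, padding_value=TRG_vocab.stoi['<pad>'], batch_first=True)

    return {'SRC': src_batch_padded, 'TRG': trg_batch_padded}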

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

BATCH_SIZE = 128


train_iterator = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
# Evaluation does not need shuffling; a fixed order keeps runs comparable
valid_iterator = DataLoader(valid_data, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)
test_iterator = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)
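
To make the batching step concrete, here is a minimal sketch in plain Python of how tokens turn into padded index rows. It mirrors the logic of `Vocab` and `collate_fn` above without spaCy or PyTorch; the helper names `build_stoi` and `pad_batch` and the toy corpus are hypothetical and only serve as illustration.

```python
from collections import Counter

SPECIALS = ['<pad>', '<sos>', '<eos>', '<unk>']

def build_stoi(token_lists, min_freq):
    """Map tokens seen at least min_freq times to indices after the special tokens."""
    counter = Counter(tok for toks in token_lists for tok in toks)
    stoi = {tok: i for i, tok in enumerate(SPECIALS)}
    for tok, freq in counter.items():
        if freq >= min_freq and tok not in stoi:
            stoi[tok] = len(stoi)
    return stoi

def numericalize(tokens, stoi):
    """Replace each token by its index; rare or unseen tokens fall back to <unk>."""
    return [stoi.get(tok, stoi['<unk>']) for tok in tokens]

def pad_batch(rows, pad_idx):
    """Right-pad every row with pad_idx up to the length of the longest row."""
    max_len = max(len(r) for r in rows)
    return [r + [pad_idx] * (max_len - len(r)) for r in rows]

# Toy corpus: 'a' appears 3 times, 'dog' twice, 'runs' and 'cat' once each
corpus = [['a', 'dog', 'runs'], ['a', 'dog'], ['a', 'cat']]
stoi = build_stoi(corpus, min_freq=2)
batch = pad_batch([numericalize(s, stoi) for s in corpus], stoi['<pad>'])
print(batch)  # [[4, 5, 3], [4, 5, 0], [4, 3, 0]]
```

With `min_freq=2`, the words `runs` and `cat` map to `<unk>` (index 3), and the shorter sentences are filled with `<pad>` (index 0).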

# II. Implement Transformer
<a id='2'></a>
In this practice, we will learn how to implement the Transformer from __[Attention is all you need](https://papers.nips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf) (Vaswani et al., 2017)__.

The overall architecture is as follows:
![picture](http://incredible.ai/assets/images/transformer-architecture.png)

## 1. Basic building blocks

In this section, we will implement the building blocks of the Transformer: [Multi-head attention](#1a), [Position-wise feed-forward network](#1b) and [Positional encoding](#1c).

### a. Attention
<a id='1a'></a>
In this section, you will implement scaled dot-product attention and multi-head attention.

__Scaled dot product:__

![picture](http://incredible.ai/assets/images/transformer-scaled-dot-product.png)
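
For reference, the figure corresponds to the following formula from Vaswani et al. (2017), where $d_k$ is the dimension of the key vectors (in the multi-head setting, the per-head dimension, called `head_dim` in the code):

$$\text{Attention}(Q, K, V) = \text{softmax} \left( \frac{Q K^\top}{\sqrt{d_k}} \right) V$$

Dividing by $\sqrt{d_k}$ keeps the dot products from growing with the dimension, which would otherwise push the softmax into regions with very small gradients.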

__Multi-head attention:__

![picture](http://jalammar.github.io/images/t/transformer_multi-headed_self-attention-recap.png)
Equation:

$$\begin{align} \text{MultiHead}(Q, K, V) &= \text{Concat}(\text{head}_1, \ldots, \text{head}_h) W^O \\
\text{where head}_i &= \text{Attention} \left( QW^Q_i, K W^K_i, VW^V_i \right)
\end{align}$$

__Query, Key and Value projection:__

![picture](http://jalammar.github.io/images/t/self-attention-matrix-calculation.png)


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import time

class MultiHeadAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        emb_dim,
        num_heads,
        dropout=0.0,
        bias=False,
        encoder_decoder_attention=False,  # otherwise self_attention
        causal = False
    ):
        super().__init__()
        self.emb_dim = emb_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = emb_dim // num_heads
        assert self.head_dim * num_heads == self.emb_dim, "emb_dim must be divisible by num_heads"

        self.encoder_decoder_attention = encoder_decoder_attention
        self.causal = causal
        self.k_proj = nn.Linear(emb_dim, emb_dim, bias=bias)
        self.v_proj = nn.Linear(emb_dim, emb_dim, bias=bias)
        self.q_proj = nn.Linear(emb_dim, emb_dim, bias=bias)
        self.out_proj = nn.Linear(emb_dim, emb_dim, bias=bias)

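    def transpose_for_scores(self, x):
        """
        To-Do : Reshape input
          Args : batch_size X sequence_length X embedding dimension
          Return : batch_size X # attention head X sequence_length X head dimension
        """
        # Split the embedding axis into heads:
        # (batch, seq_len, emb_dim) -> (batch, seq_len, num_heads, head_dim)
        x = x.view(x.size(0), x.size(1), self.num_heads, self.head_dim)
        # Move the head axis in front of the sequence axis.
        # On this 4-D tensor, this is equivalent to x.transpose(1, 2)
        return x.permute(0, 2, 1, 3)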
    
    def scaled_dot_product(self, 
                           query: torch.Tensor, 
                           key: torch.Tensor, 
                           value: torch.Tensor,
                           attention_mask: torch.BoolTensor):
        """
        To-Do : Implement scaled dot product
          Args:
            Query (Tensor): shape `(batch, seq_len, emb_dim)`
            Key (Tensor): shape `(batch, seq_len, emb_dim)`
            Value (Tensor): shape `(batch, seq_len, emb_dim)`
            attention_mask: binary BoolTensor of shape `(batch, seq_len)` or `(seq_len, seq_len)`

          Returns:
            attn_output : attended output (result of attention mechanism)
            attn_weights: value of each attention
        """
        return attn_output, attn_weights
    
    def MultiHead_scaled_dot_product(self, 
                       query: torch.Tensor, 
                       key: torch.Tensor, 
                       value: torch.Tensor,
                       attention_mask: torch.BoolTensor):
        """
        To-Do : Implement Multi-head version of scaled dot product, please also take the causal masking into account.
          Args:
            Query (Tensor): shape `(batch,# attention head, seq_len, head_dim)`
            Key (Tensor): shape `(batch,# attention head, seq_len, head_dim)`
            Value (Tensor): shape `(batch,# attention head, seq_len, head_dim)`
            attention_mask: binary BoolTensor of shape `(batch, src_len)` or `(seq_len, seq_len)`

          Returns:
            attn_output : attended output (result of attention mechanism)
            attn_weights: value of each attention
        """

        return attn_output, attn_weights

        
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        attention_mask: torch.Tensor = None,
        ):
        q = self.q_proj(query)
        # Enc-Dec attention
        if self.encoder_decoder_attention:
            k = self.k_proj(key)
            v = self.v_proj(key)
        # Self attention
        else:
            k = self.k_proj(query)
            v = self.v_proj(query)

        q = self.transpose_for_scores(q)
        k = self.transpose_for_scores(k)
        v = self.transpose_for_scores(v)

        attn_output, attn_weights = self.MultiHead_scaled_dot_product(q,k,v,attention_mask)
        return attn_output, attn_weights
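
As a reference for the To-Do above, the following is a minimal sketch of multi-head scaled dot-product attention written as a standalone function. It is one possible solution, not necessarily the intended one: the function name, the convention that `True` in the mask marks positions to hide, and the requirement that the mask is already broadcastable to `(batch, heads, tgt_len, src_len)` are all assumptions of this sketch. Dropout and the output projection `out_proj` are left out.

```python
import math
import torch
import torch.nn.functional as F

def multihead_scaled_dot_product(q, k, v, mask=None):
    """q: (batch, heads, tgt_len, head_dim); k, v: (batch, heads, src_len, head_dim).
    mask: BoolTensor broadcastable to (batch, heads, tgt_len, src_len),
          True = hide this position (assumed convention)."""
    head_dim = q.size(-1)
    # Similarity of every query with every key, scaled by sqrt(d_k)
    scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(head_dim)
    if mask is not None:
        # -inf becomes exactly 0 after the softmax
        scores = scores.masked_fill(mask, float('-inf'))
    weights = F.softmax(scores, dim=-1)          # (batch, heads, tgt_len, src_len)
    context = torch.matmul(weights, v)           # (batch, heads, tgt_len, head_dim)
    # Concatenate the heads again: (batch, tgt_len, heads * head_dim)
    bsz, n_heads, tgt_len, _ = context.shape
    output = context.transpose(1, 2).reshape(bsz, tgt_len, n_heads * head_dim)
    return output, weights

torch.manual_seed(0)
q = torch.randn(2, 4, 5, 16)
k = torch.randn(2, 4, 7, 16)
v = torch.randn(2, 4, 7, 16)
# Pretend the last two source tokens of the first sentence are padding
pad_mask = torch.zeros(2, 1, 1, 7, dtype=torch.bool)
pad_mask[0, :, :, 5:] = True
out, attn = multihead_scaled_dot_product(q, k, v, pad_mask)
```

A `(batch, src_len)` padding mask can be brought into the broadcastable form with `mask[:, None, None, :]`, and a `(seq_len, seq_len)` causal mask with `mask[None, None, :, :]`.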


### b. Position-wise feed forward network
<a id='1b'></a>
In this section, we will implement the position-wise feed-forward network. It applies the same two-layer MLP to every position independently:

$$\text{FFN}(x) = \max \left(0, x W_1 + b_1 \right) W_2 + b_2$$

In [None]:
class PositionWiseFeedForward(nn.Module):

    def __init__(self, emb_dim: int, d_ff: int, dropout: float = 0.1):
        super(PositionWiseFeedForward, self).__init__()

        self.activation = nn.ReLU()
        self.w_1 = nn.Linear(emb_dim, d_ff)
        self.w_2 = nn.Linear(d_ff, emb_dim)
        self.dropout = dropout

    def forward(self, x):
        """
        To-Do : Implement position-wise feed forward network
          Args:
            x (Tensor): input to the layer of shape `(batch, seq_len, emb_dim)`
        """
        return x + residual
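
One way to fill in the To-Do above is sketched below as a separate class, so it does not overwrite the exercise. The class name `FeedForwardSketch` and the choice to apply dropout after both linear layers are assumptions; the residual connection follows the `x + residual` return statement of the stub.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FeedForwardSketch(nn.Module):
    """FFN(x) = max(0, x W_1 + b_1) W_2 + b_2, followed by a residual connection."""

    def __init__(self, emb_dim, d_ff, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(emb_dim, d_ff)
        self.w_2 = nn.Linear(d_ff, emb_dim)
        self.dropout = dropout

    def forward(self, x):
        residual = x
        x = F.relu(self.w_1(x))                                    # expand to d_ff
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.w_2(x)                                            # project back to emb_dim
        x = F.dropout(x, p=self.dropout, training=self.training)
        return x + residual

ffn = FeedForwardSketch(emb_dim=64, d_ff=256).eval()  # eval() disables dropout
x = torch.randn(2, 5, 64)
y = ffn(x)
# "Position-wise": feeding a single position alone gives the same result for it
y_single = ffn(x[:, 2:3])
```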

### c. Sinusoidal Positional Encoding
<a id='1c'></a>
In this section, we will implement the sinusoidal positional encoding. Even embedding dimensions use a sine and odd dimensions a cosine of the same angle:

$$\begin{align}
PE(pos, 2i) &= \sin \left( pos / 10000^{2i / d_{model}} \right)  \\
PE(pos, 2i+1) &= \cos \left( pos / 10000^{2i / d_{model}} \right)  
\end{align}$$

In [None]:
import numpy as np

class SinusoidalPositionalEmbedding(nn.Embedding):
    def __init__(self, num_positions, embedding_dim, padding_idx=None):
        super().__init__(num_positions, embedding_dim)
        self.weight = self._init_weight(self.weight)
    
    @staticmethod
    def _init_weight(out: nn.Parameter):
        n_pos, embed_dim = out.shape
        pe = nn.Parameter(torch.zeros(out.shape))
        for pos in range(n_pos):
            for i in range(0, embed_dim, 2):
              """
              To-Do : Implement sinusoidal positional encoding
              """
        pe.detach_()
                
        return pe

    @torch.no_grad()
    def forward(self, input_ids):
        bsz, seq_len = input_ids.shape[:2]
        positions = torch.arange(seq_len, dtype=torch.long, device=self.weight.device)
        return super().forward(positions)
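
The double loop in `_init_weight` can be sketched in plain Python as follows. The function name `sinusoidal_table` is hypothetical and the result is a nested list rather than an `nn.Parameter`; an even `embed_dim` is assumed. Note that the loop variable `i` steps by 2, so it already plays the role of $2i$ in the formula.

```python
import math

def sinusoidal_table(n_pos, embed_dim):
    """Return an n_pos x embed_dim nested list with the sinusoidal encoding."""
    table = [[0.0] * embed_dim for _ in range(n_pos)]
    for pos in range(n_pos):
        for i in range(0, embed_dim, 2):
            # i is the even column index, i.e. the "2i" of the formula
            angle = pos / (10000 ** (i / embed_dim))
            table[pos][i] = math.sin(angle)       # PE(pos, 2i)
            table[pos][i + 1] = math.cos(angle)   # PE(pos, 2i+1)
    return table

pe = sinusoidal_table(n_pos=50, embed_dim=8)
print(pe[0])  # [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0]
```

Position 0 always encodes to alternating 0 and 1 because $\sin(0) = 0$ and $\cos(0) = 1$. Inside the class above, the same values would be copied into `pe` instead of a list.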

## 2. Transformer Encoder

We now have all the basic building blocks needed to build the Transformer.

Let's implement the Transformer step by step.

### a. Encoder layer
In this section, we will implement a single layer of the Transformer encoder.
![picture](https://www.researchgate.net/publication/334288604/figure/fig1/AS:778232232148992@1562556431066/The-Transformer-encoder-structure.ppm)

In [None]:
class EncoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.emb_dim = config.emb_dim
        self.ffn_dim = config.ffn_dim
        self.self_attn = MultiHeadAttention(            
            emb_dim=self.emb_dim,
            num_heads=config.attention_heads, 
            dropout=config.attention_dropout)
        self.self_attn_layer_norm = nn.LayerNorm(self.emb_dim)
        self.dropout = config.dropout
        self.activation_fn = nn.ReLU()
        self.PositionWiseFeedForward = PositionWiseFeedForward(self.emb_dim, self.ffn_dim, config.dropout)
        self.final_layer_norm = nn.LayerNorm(self.emb_dim)

    def forward(self, x, encoder_padding_mask):
        """
        To-Do : Implement transformer encoder layer
          Args:
            x (Tensor): input to the layer of shape `(batch, seq_len, emb_dim)`
            encoder_padding_mask: binary BoolTensor of shape `(batch, src_len)`

          Returns:
            x : encoded output of shape `(batch, seq_len, emb_dim)`
            self_attn_weights: self attention score
        """
        return x, attn_weights
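
The data flow of one encoder layer (self-attention, residual connection plus layer normalization, feed-forward network, residual connection plus layer normalization) can be sketched as below. To stay self-contained, this sketch swaps in PyTorch's built-in `nn.MultiheadAttention` for the `MultiHeadAttention` class of this notebook and inlines the feed-forward network, so it illustrates the structure rather than the expected solution. The class name and default sizes are assumptions, and the padding mask is assumed to be `True` at `<pad>` positions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class EncoderLayerSketch(nn.Module):
    def __init__(self, emb_dim=64, num_heads=4, ffn_dim=256, dropout=0.2):
        super().__init__()
        # Stand-in for the notebook's MultiHeadAttention (assumption of this sketch)
        self.self_attn = nn.MultiheadAttention(emb_dim, num_heads, batch_first=True)
        self.self_attn_layer_norm = nn.LayerNorm(emb_dim)
        self.ffn = nn.Sequential(
            nn.Linear(emb_dim, ffn_dim), nn.ReLU(), nn.Linear(ffn_dim, emb_dim)
        )
        self.final_layer_norm = nn.LayerNorm(emb_dim)
        self.dropout = dropout

    def forward(self, x, padding_mask):
        # padding_mask: (batch, seq_len), True at <pad> positions
        residual = x
        x, attn_weights = self.self_attn(x, x, x, key_padding_mask=padding_mask)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.self_attn_layer_norm(residual + x)   # Add & Norm
        residual = x
        x = F.dropout(self.ffn(x), p=self.dropout, training=self.training)
        x = self.final_layer_norm(residual + x)       # Add & Norm
        return x, attn_weights

layer = EncoderLayerSketch().eval()
x = torch.randn(2, 5, 64)
mask = torch.zeros(2, 5, dtype=torch.bool)
mask[0, 3:] = True  # last two tokens of the first sentence are padding
y, w = layer(x, mask)
```

Here `w` has shape `(batch, seq_len, seq_len)` because `nn.MultiheadAttention` averages the weights over heads by default, whereas the notebook's own attention keeps a separate map per head for the visualization in section IV.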

### b. Encoder

Stack the encoder layers to build the full Transformer encoder.

In [None]:
class Encoder(nn.Module):
    def __init__(self, config, embed_tokens):
        super().__init__()

        self.dropout = config.dropout

        emb_dim = embed_tokens.embedding_dim
        self.padding_idx = embed_tokens.padding_idx
        self.max_source_positions = config.max_position_embeddings

        self.embed_tokens = embed_tokens
        self.embed_positions = SinusoidalPositionalEmbedding(
                config.max_position_embeddings, config.emb_dim, self.padding_idx
            )

        self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)])

    def forward(self, input_ids, attention_mask=None):
        """
        To-Do : Implement the transformer encoder
          Args:
            input_ids (Tensor): input to the layer of shape `(batch, seq_len)`
            attention_mask: binary BoolTensor of shape `(batch, src_len)`

          Returns:
            x: encoded output of shape `(batch, seq_len, emb_dim)`
            self_attn_scores: a list of self attention score of each layer
        """

        return x, self_attn_scores


## 3. Transformer Decoder

### a. Decoder layer
In this section, we will implement a single layer of the Transformer decoder.
![picture](http://incredible.ai/assets/images/transformer-decoder.png)

In [None]:
class DecoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.emb_dim = config.emb_dim
        self.ffn_dim = config.ffn_dim
        self.self_attn = MultiHeadAttention(
            emb_dim=self.emb_dim,
            num_heads=config.attention_heads,
            dropout=config.attention_dropout,
            causal=True,
        )
        self.dropout = config.dropout
        self.self_attn_layer_norm = nn.LayerNorm(self.emb_dim)
        self.encoder_attn = MultiHeadAttention(
            emb_dim=self.emb_dim,
            num_heads=config.attention_heads,
            dropout=config.attention_dropout,
            encoder_decoder_attention=True,
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(self.emb_dim)
        self.PositionWiseFeedForward = PositionWiseFeedForward(self.emb_dim, self.ffn_dim, config.dropout)
        self.final_layer_norm = nn.LayerNorm(self.emb_dim)


    def forward(
        self,
        x,
        encoder_hidden_states,
        encoder_attention_mask=None,
        causal_mask=None,
    ):
        """
        To-Do : Implement the transformer decoder layer
          Args:
            x (Tensor): input to the layer of shape `(batch, seq_len, emb_dim)`
            encoder_hidden_states: output from the encoder, used for
                encoder-side attention
            encoder_attention_mask: binary BoolTensor of shape `(batch, src_len)` to mask out encoder padding
            causal_mask: binary BoolTensor of shape `(tgt_len, tgt_len)` to mask out future tokens in decoder.


          Returns:
            x: decoded output of shape `(batch, seq_len, emb_dim)`
            self_attn_weights: self attention score
            cross_attn_weights: encoder-decoder attention score
        """
        return (
            x,
            self_attn_weights,
            cross_attn_weights,
        ) 

### b. Decoder

Stack decoder layers and build full Transformer decoder.

Unlike the encoder, the decoder needs one more step: pass the causal (unidirectional) mask to the decoder self-attention layer, so that each position can only attend to itself and to earlier positions.

In [None]:
class Decoder(nn.Module):
    """
    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a :class:`DecoderLayer`

    Args:
        config: model configuration (an object with attributes such as `emb_dim` and `decoder_layers`)
        embed_tokens (torch.nn.Embedding): output embedding
    """

    def __init__(self, config, embed_tokens: nn.Embedding):
        super().__init__()
        self.dropout = config.dropout
        self.padding_idx = embed_tokens.padding_idx
        self.max_target_positions = config.max_position_embeddings
        self.embed_tokens = embed_tokens
        self.embed_positions = SinusoidalPositionalEmbedding(
            config.max_position_embeddings, config.emb_dim, self.padding_idx
        )
        self.layers = nn.ModuleList([DecoderLayer(config) for _ in range(config.decoder_layers)])  # type: List[DecoderLayer]

    def forward(
        self,
        input_ids,
        encoder_hidden_states,
        encoder_attention_mask,
        decoder_causal_mask,
    ):
        """
        To-Do : Implement the transformer decoder

        Args:
            input_ids (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_hidden_states: output from the encoder, used for
                encoder-side attention
            encoder_attention_mask: binary BoolTensor of shape `(batch, src_len)` to mask out encoder padding
            decoder_causal_mask: binary BoolTensor of shape `(tgt_len, tgt_len)` to mask out future tokens in decoder.

          Returns:
            x: decoded output of shape `(batch, seq_len, emb_dim)`
            cross_attn_scores: list of encoder-decoder attention score of each layer
        """

        return x, cross_attention_scores
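
Why the causal mask matters can be checked with a small experiment: if the mask works, changing later tokens must not change the outputs at earlier positions. The sketch below uses a deliberately simplified single-head self-attention without projections; the function name `causal_self_attention` and the toy tensor sizes are assumptions made for illustration.

```python
import math
import torch
import torch.nn.functional as F

def causal_self_attention(x):
    """x: (batch, seq_len, dim). Single head, no learned projections."""
    seq_len, dim = x.size(1), x.size(2)
    scores = x @ x.transpose(1, 2) / math.sqrt(dim)
    # True above the diagonal = future positions that must be hidden
    future = torch.ones(seq_len, seq_len).triu(diagonal=1).bool()
    scores = scores.masked_fill(future, float('-inf'))
    return F.softmax(scores, dim=-1) @ x

torch.manual_seed(0)
x = torch.randn(1, 6, 8)
x_changed = x.clone()
x_changed[:, 4:] = torch.randn(1, 2, 8)   # replace only the last two tokens

out = causal_self_attention(x)
out_changed = causal_self_attention(x_changed)
# Positions 0..3 never see tokens 4 and 5, so their outputs are unchanged
print(torch.allclose(out[:, :4], out_changed[:, :4]))
```

This property is what allows training with teacher forcing on the whole target sentence in one pass: position $t$ is computed as if the tokens after $t$ did not exist.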

## 4. Transformer

Let's combine encoder and decoder in one place!

In [None]:
import torch
import torch.nn as nn


class Transformer(nn.Module):
    def __init__(self, SRC_vocab, TRG_vocab, config):
        super().__init__()

        self.SRC_vocab = SRC_vocab
        self.TRG_vocab = TRG_vocab

        self.enc_embedding = nn.Embedding(len(SRC_vocab.itos), config.emb_dim, padding_idx=SRC_vocab.stoi['<pad>'])
        self.dec_embedding = nn.Embedding(len(TRG_vocab.itos), config.emb_dim, padding_idx=TRG_vocab.stoi['<pad>'])

        self.encoder = Encoder(config, self.enc_embedding)
        self.decoder = Decoder(config, self.dec_embedding)

        self.prediction_head = nn.Linear(config.emb_dim, len(TRG_vocab.itos))

        self.init_weights()

    def generate_mask(self, src, trg):
        '''
        To-Do : Generate mask for encoder and decoder attention.

        Args:
            src(LongTensor): Input to the transformer of shape (batch_size, seq_len)
            trg(LongTensor): Decoding target of the transformer of shape (batch_size, seq_len)

        Returns:
            enc_attention_mask: padding mask for encoder
            dec_attention_mask: causal mask for decoder
        '''

        return enc_attention_mask, dec_attention_mask

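    def init_weights(self):
        for name, param in self.named_parameters():
            # Skip frozen parameters, and keep LayerNorm at its default scale 1 / shift 0:
            # a near-zero LayerNorm weight would shrink every normalized activation.
            if not param.requires_grad or 'layer_norm' in name:
                continue
            if 'weight' in name:
                nn.init.normal_(param.data, mean=0, std=0.01)
            else:
                nn.init.constant_(param.data, 0)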

    def forward(self, src, trg):
        enc_attention_mask, dec_causal_mask = self.generate_mask(src, trg)
        encoder_output, encoder_attention_scores = self.encoder(
            input_ids=src,
            attention_mask=enc_attention_mask
        )

        decoder_output, decoder_attention_scores = self.decoder(
            trg,
            encoder_output,
            encoder_attention_mask=enc_attention_mask,
            decoder_causal_mask=dec_causal_mask,
        )
        decoder_output = self.prediction_head(decoder_output)

        return decoder_output, encoder_attention_scores, decoder_attention_scores
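
For the `generate_mask` To-Do, a possible approach is sketched below as a standalone function. The name `make_masks`, the default `pad_idx=0`, and the convention that `True` marks positions that must not be attended to are assumptions; the real method would also have to match whatever convention the attention implementation uses.

```python
import torch

def make_masks(src, trg, pad_idx=0):
    """Return (padding mask for the encoder, causal mask for the decoder).
    True = this position must NOT be attended to (assumed convention)."""
    # (batch, src_len): True wherever the source token is <pad>
    enc_padding_mask = src.eq(pad_idx)
    # (tgt_len, tgt_len): True strictly above the diagonal, i.e. future tokens
    tgt_len = trg.size(1)
    causal_mask = torch.ones(tgt_len, tgt_len, device=trg.device).triu(diagonal=1).bool()
    return enc_padding_mask, causal_mask

src = torch.tensor([[5, 6, 7, 0, 0],
                    [8, 9, 4, 5, 6]])
trg = torch.tensor([[1, 10, 11, 2],
                    [1, 12, 2, 0]])
enc_mask, causal = make_masks(src, trg)
print(enc_mask)
print(causal)
```

Row $t$ of the causal mask tells decoder position $t$ which positions to ignore: the first row hides everything except position 0, and the last row hides nothing.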



# III. Train & Evaluate
<a id='3'></a>
This section is very similar to week 9, so please refer to that session for a detailed description.

## 1. Configuration

In [None]:
import easydict
import torch.nn as nn
import torch.optim as optim

config = easydict.EasyDict({
    "emb_dim": 64,
    "ffn_dim": 256,
    "attention_heads": 4,
    "attention_dropout": 0.0,
    "dropout": 0.2,
    "max_position_embeddings": 512,
    "encoder_layers": 3,
    "decoder_layers": 3,
})

N_EPOCHS = 100
learning_rate = 5e-4
CLIP = 1

# The loss is computed over target tokens, so take the padding index from the target vocabulary
PAD_IDX = TRG_vocab.stoi['<pad>']

model = Transformer(SRC_vocab, TRG_vocab, config)
model.to(device)

optimizer = optim.Adam(model.parameters(), lr=learning_rate)

criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)

best_valid_loss = float('inf')

## 2. Train & Eval

In [None]:
import math
import time
from tqdm import tqdm

def train(model: nn.Module,
          iterator: DataLoader,
          optimizer: optim.Optimizer,
          criterion: nn.Module,
          clip: float):

    model.train()
    epoch_loss = 0

    for batch in iterator:
        src = batch['SRC'].to(device)
        trg = batch['TRG'].to(device)
        optimizer.zero_grad()

        output, enc_attention_scores, _ = model(src, trg)

        output = output[:,:-1,:].reshape(-1, output.shape[-1])
        trg = trg[:,1:].reshape(-1)

        loss = criterion(output, trg)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        epoch_loss += loss.item()

    return epoch_loss / len(iterator)


def evaluate(model: nn.Module,
             iterator: DataLoader,
             criterion: nn.Module):

    model.eval()
    epoch_loss = 0

    with torch.no_grad():
        for batch in iterator:
            src = batch['SRC'].to(device)
            trg = batch['TRG'].to(device)


            output, attention_score, _ = model(src, trg)

            output = output[:,:-1,:].reshape(-1, output.shape[-1])
            trg = trg[:,1:].reshape(-1)

            loss = criterion(output, trg)
            epoch_loss += loss.item()

    return epoch_loss / len(iterator)

for epoch in tqdm(range(N_EPOCHS), total=N_EPOCHS):
    train_loss = train(model, train_iterator, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, valid_iterator, criterion)

    # Report before the early-stopping check so that the final epoch is logged too
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')

    # Early stopping: stop at the first epoch whose validation loss does not improve
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
    else:
        break

test_loss = evaluate(model, test_iterator, criterion)
print(f'| Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f} |')
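
The slicing in `train` and `evaluate` (`output[:, :-1]` against `trg[:, 1:]`) pairs the prediction at position $t$ with the gold token at position $t+1$, and the reported perplexity is simply the exponential of the mean cross-entropy. The toy example below illustrates both; the vocabulary size, token ids, and all-zero logits are made up for the illustration.

```python
import math
import torch
import torch.nn as nn

PAD = 0
vocab_size = 6
# <sos>=1, two words (4, 5), <eos>=2, then padding
trg = torch.tensor([[1, 4, 5, 2, PAD]])
# All-zero logits = a model that predicts every token with equal probability
logits = torch.zeros(1, trg.size(1), vocab_size)

pred = logits[:, :-1, :].reshape(-1, vocab_size)  # predictions made at positions 0..3
gold = trg[:, 1:].reshape(-1)                     # tokens to predict: 4, 5, 2, <pad>
loss = nn.CrossEntropyLoss(ignore_index=PAD)(pred, gold)  # <pad> target is skipped
ppl = math.exp(loss.item())
print(gold.tolist(), round(ppl, 3))  # [4, 5, 2, 0] 6.0
```

A uniform guess over 6 tokens yields a perplexity of 6, which gives an intuition for the metric: it is the effective number of choices the model is undecided between at each step.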


# IV. Visualization
<a id='4'></a>

## 1. Positional embedding visualization

In [None]:
import matplotlib.pyplot as plt
# Visualization
fig, ax = plt.subplots(figsize=(15, 9))
cax = ax.matshow(model.encoder.embed_positions.weight.data.cpu().numpy(), aspect='auto',cmap=plt.cm.YlOrRd)
fig.colorbar(cax)
ax.set_title('Positional Embedding Matrix', fontsize=18)
ax.set_xlabel('Embedding Dimension', fontsize=14)
ax.set_ylabel('Sequence Length', fontsize=14)

## 2. Attention visualization

In [None]:
from attentionviz import head_view

BATCH_SIZE = 1

train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)

model.eval()

In [None]:
import sys
if 'attentionviz' not in sys.path:
  sys.path += ['attentionviz']
!pip install regex

def call_html():
  # Load require.js with d3 and jquery so that head_view can render inside the notebook
  from IPython.display import HTML, display
  display(HTML('''
        <script src="/static/components/requirejs/require.js"></script>
        <script>
          requirejs.config({
            paths: {
              base: '/static/base',
              "d3": "https://cdnjs.cloudflare.com/ajax/libs/d3/3.5.8/d3.min",
              jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',
            },
          });
        </script>
        '''))

In [None]:
SAMPLE_IDX = 131

sample = test_data[SAMPLE_IDX]

src_numericalized = torch.LongTensor([SRC_vocab.numericalize(sample['SRC'])]).to(device)
trg_numericalized = torch.LongTensor([TRG_vocab.numericalize(sample['TRG'])]).to(device)

with torch.no_grad():
    output, enc_attention_score, dec_attention_score = model(src_numericalized, trg_numericalized) # the gold target is fed to the decoder (teacher forcing)
    attention_score = {'self': enc_attention_score, 'cross': dec_attention_score}

src_tok = [SRC_vocab.itos[x] for x in src_numericalized.squeeze()]
trg_tok = [TRG_vocab.itos[x] for x in trg_numericalized.squeeze()]

call_html()
head_view(attention_score, src_tok, trg_tok)