<a href="https://colab.research.google.com/github/satyajitghana/TSAI-DeepNLP-END2.0/blob/main/12_AttentionIsAllYouNeed/Attention_Is_All_You_Need.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Attention Is All You Need !

In [1]:
! nvidia-smi

Sat Jul 31 08:48:16 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.42.01    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   44C    P8     9W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
# ! pip3 install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl

In [3]:
! pip3 install git+https://github.com/extensive-nlp/ttc_nlp --quiet

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
    Preparing wheel metadata ... [?25l[?25hdone
[K     |████████████████████████████████| 76 kB 3.3 MB/s 
[K     |████████████████████████████████| 913 kB 16.0 MB/s 
[K     |████████████████████████████████| 22.3 MB 77.9 MB/s 
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
    Preparing wheel metadata ... [?25l[?25hdone
[K     |████████████████████████████████| 234 kB 41.0 MB/s 
[K     |████████████████████████████████| 74 kB 3.4 MB/s 
[K     |████████████████████████████████| 6.4 MB 29.0 MB/s 
[K     |████████████████████████████████| 636 kB 63.3 MB/s 
[K     |████████████████████████████████| 112 kB 64.6 MB/s 
[K     |████████████████████████████████| 829 kB 43.0 MB/s 
[K     |████████████████████████████████| 118 kB 69.3 MB/s 
[K     |████████████████████████████████| 10.6 MB 47.5 MB/s 


In [4]:
! python -m spacy download en_core_web_sm --quiet
! python -m spacy download de_core_news_sm --quiet

2021-07-31 08:49:12.594375: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0
[K     |████████████████████████████████| 13.6 MB 66 kB/s 
[?25h[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('en_core_web_sm')
2021-07-31 08:49:20.912233: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0
[K     |████████████████████████████████| 18.8 MB 1.1 MB/s 
[?25h[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('de_core_news_sm')


## Dataset

In [28]:
import torch
import pytorch_lightning as pl
from torchtext.datasets import Multi30k, IWSLT2016

from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

from torch.utils.data import DataLoader

from torchtext.datasets import Multi30k, IWSLT2016

from typing import *

# helper function to club together sequential operations
def sequential_transforms(*transforms):
    def func(txt_input):
        for transform in transforms:
            txt_input = transform(txt_input)
        return txt_input
    return func

class TTCTranslation(pl.LightningDataModule):

    DATASET_OPTIONS = ['multi30k', 'iwslt2016']
    LANGUAGE_OPTIONS = ['en', 'fr', 'de', 'cs', 'ar']

    # Define special symbols and indices
    UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3
    # Make sure the tokens are in order of their indices to properly insert them in vocab
    SPECIAL_SYMBOLS = ['<unk>', '<pad>', '<bos>', '<eos>']

    def __init__(self, language_pair=('en', 'de'), spacy_language_pair=('en_core_web_sm', 'de_core_news_sm'), dataset='multi30k', batch_size=64, batch_first=True):
        super().__init__()

        assert len(language_pair) == 2 and len(spacy_language_pair), f'tf are you doing? give me a language \"pair\"'
        assert dataset in self.DATASET_OPTIONS, f'{self.DATASET_OPTIONS} are only supported'
        assert all(x in self.LANGUAGE_OPTIONS for x in language_pair), f'{self.LANGUAGE_OPTIONS} are only supported'

        self.batch_size = batch_size
        self.batch_first = batch_first

        self.language_pair = language_pair
        self.src_lang, self.tgt_lang = language_pair
        self.spacy_src_lang, self.spacy_tgt_lang = spacy_language_pair

        if dataset == 'multi30k':
            self.train_dataset = Multi30k(split='train', language_pair=self.language_pair)
            self.val_dataset = Multi30k(split='valid', language_pair=self.language_pair)
            self.test_dataset = Multi30k(split='test', language_pair=self.language_pair)
        elif dataset == 'iwslt2016':
            self.train_dataset = IWSLT2016(split='train', language_pair=self.language_pair)
            self.val_dataset = IWSLT2016(split='valid', language_pair=self.language_pair)
            self.test_dataset = IWSLT2016(split='test', language_pair=self.language_pair)
        
        self.train_dataset, self.val_dataset, self.test_dataset = list(self.train_dataset), list(self.val_dataset), list(self.test_dataset)

        # --- token transform ---

        self.token_transform = {}
        self.token_transform[self.src_lang] = get_tokenizer('spacy', language=self.spacy_src_lang)
        self.token_transform[self.tgt_lang] = get_tokenizer('spacy', language=self.spacy_tgt_lang)

        # --- vocab transform ---
        # helper function to yield list of tokens
        def yield_tokens(data_iter: Iterable, language: str) -> List[str]:
            language_index = {self.src_lang: 0, self.tgt_lang: 1}

            for data_sample in data_iter:
                yield self.token_transform[language](data_sample[language_index[language]])
        
        self.vocab_transform = {}
        for ln in self.language_pair:
            # Create torchtext's Vocab object 
            self.vocab_transform[ln] = build_vocab_from_iterator(yield_tokens(self.train_dataset, ln),
                                                            min_freq=1,
                                                            specials=self.SPECIAL_SYMBOLS,
                                                            special_first=True)

        # Set UNK_IDX as the default index. This index is returned when the token is not found. 
        # If not set, it throws RuntimeError when the queried token is not found in the Vocabulary. 
        for ln in self.language_pair:
            self.vocab_transform[ln].set_default_index(self.UNK_IDX)

        # --- text/tensor transform ---

        # function to add BOS/EOS and create tensor for input sequence indices
        def tensor_transform(token_ids: List[int]):
            return torch.cat((torch.tensor([self.BOS_IDX]), 
                            torch.tensor(token_ids), 
                            torch.tensor([self.EOS_IDX])))

        # src and tgt language text transforms to convert raw strings into tensors indices
        self.text_transform = {}
        for ln in self.language_pair:
            self.text_transform[ln] = sequential_transforms(self.token_transform[ln], #Tokenization
                                                    self.vocab_transform[ln], #Numericalization
                                                    tensor_transform) # Add BOS/EOS and create tensor

    def prepare_data(self):
        # download, split, etc...
        # only called on 1 GPU/TPU in distributed
        pass

    def setup(self, stage):
        # make assignments here (val/train/test split)
        # called on every process in DDP
        pass

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            collate_fn=self.collator_fn 
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            collate_fn=self.collator_fn
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            collate_fn=self.collator_fn
        )

    @property
    def collator_fn(self):
        def wrapper(batch):
            src_batch, tgt_batch = [], []
            for src_sample, tgt_sample in batch:
                src_batch.append(self.text_transform[self.src_lang](src_sample.rstrip("\n")))
                tgt_batch.append(self.text_transform[self.tgt_lang](tgt_sample.rstrip("\n")))

            src_batch = torch.nn.utils.rnn.pad_sequence(src_batch, padding_value=self.PAD_IDX, batch_first=self.batch_first)
            tgt_batch = torch.nn.utils.rnn.pad_sequence(tgt_batch, padding_value=self.PAD_IDX, batch_first=self.batch_first)

            return src_batch, tgt_batch

        return wrapper

    def teardown(self, stage):
        # clean up after fit or test
        # called on every process in DDP
        pass

In [6]:
ttc_translation = TTCTranslation(
    language_pair=('en', 'de'),
    spacy_language_pair=('en_core_web_sm', 'de_core_news_sm'),
    dataset='multi30k',
    batch_size=4
)

training.tar.gz: 100%|██████████| 1.21M/1.21M [00:01<00:00, 948kB/s]
validation.tar.gz: 100%|██████████| 46.3k/46.3k [00:00<00:00, 164kB/s]
mmt16_task1_test.tar.gz: 100%|██████████| 43.9k/43.9k [00:00<00:00, 154kB/s]


In [7]:
train_loader = ttc_translation.train_dataloader()

In [8]:
src, tgt = next(iter(train_loader))

In [9]:
src.shape, tgt.shape

(torch.Size([4, 17]), torch.Size([4, 17]))

In [10]:
src_itos = ttc_translation.vocab_transform[ttc_translation.src_lang].get_itos()
tgt_itos = ttc_translation.vocab_transform[ttc_translation.tgt_lang].get_itos()

In [11]:
src[0]

tensor([   2,   20,   26,   16, 1170,  809,   18,   58,   85,  337, 1340,    6,
           3,    1,    1,    1,    1])

In [12]:
tgt[0]

tensor([   2,   22,   86,  258,   32,   88,   23,   95,    8,   17,  113, 7911,
        3210,    5,    3,    1,    1])

In [13]:
' '.join(src_itos[x] for x in src[0])

'<bos> Two young , White males are outside near many bushes . <eos> <pad> <pad> <pad> <pad>'

In [14]:
' '.join(tgt_itos[x] for x in tgt[0])

'<bos> Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche . <eos> <pad> <pad>'

In [15]:
tgt

tensor([[   2,   22,   86,  258,   32,   88,   23,   95,    8,   17,  113, 7911,
         3210,    5,    3,    1,    1],
        [   2,   85,   32,   11,  848, 2209,   16, 8269,    5,    3,    1,    1,
            1,    1,    1,    1,    1],
        [   2,    6,   70,   28,  220,    8,   16, 6770,   56,  509,    5,    3,
            1,    1,    1,    1,    1],
        [   2,    6,   13,    8,    7,   48,   42,   31,   12,   14,  544,   10,
          699,   16,  249,    5,    3]])

In [16]:
tgt[:, 1:]

tensor([[  22,   86,  258,   32,   88,   23,   95,    8,   17,  113, 7911, 3210,
            5,    3,    1,    1],
        [  85,   32,   11,  848, 2209,   16, 8269,    5,    3,    1,    1,    1,
            1,    1,    1,    1],
        [   6,   70,   28,  220,    8,   16, 6770,   56,  509,    5,    3,    1,
            1,    1,    1,    1],
        [   6,   13,    8,    7,   48,   42,   31,   12,   14,  544,   10,  699,
           16,  249,    5,    3]])

## Model

In [29]:
import torch
import torch.nn as nn


class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        # Get number of training examples
        N = query.shape[0]

        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split the embedding into self.heads different pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        query = query.reshape(N, query_len, self.heads, self.head_dim)

        values = self.values(values)  # (N, value_len, heads, head_dim)
        keys = self.keys(keys)  # (N, key_len, heads, head_dim)
        queries = self.queries(query)  # (N, query_len, heads, heads_dim)

        # Einsum does matrix mult. for query*keys for each training example
        # with every other training example, don't be confused by einsum
        # it's just how I like doing matrix multiplication & bmm

        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        # queries shape: (N, query_len, heads, heads_dim),
        # keys shape: (N, key_len, heads, heads_dim)
        # energy: (N, heads, query_len, key_len)

        # Mask padded indices so their weights become 0
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        # Normalize energy values similarly to seq2seq + attention
        # so that they sum to 1. Also divide by scaling factor for
        # better stability
        attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)
        # attention shape: (N, heads, query_len, key_len)

        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )
        # attention shape: (N, heads, query_len, key_len)
        # values shape: (N, value_len, heads, heads_dim)
        # out after matrix multiply: (N, query_len, heads, head_dim), then
        # we reshape and flatten the last two dimensions.

        out = self.fc_out(out)
        # Linear layer doesn't modify the shape, final shape will be
        # (N, query_len, embed_size)

        return out


class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super(TransformerBlock, self).__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size),
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        attention = self.attention(value, key, query, mask)

        # Add skip connection, run through normalization and finally dropout
        x = self.dropout(self.norm1(attention + query))
        forward = self.feed_forward(x)
        out = self.dropout(self.norm2(forward + x))
        return out


class Encoder(nn.Module):
    def __init__(
        self,
        src_vocab_size,
        embed_size,
        num_layers,
        heads,
        forward_expansion,
        dropout,
        max_length,
    ):

        super(Encoder, self).__init__()
        self.embed_size = embed_size
        self.word_embedding = nn.Embedding(src_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    embed_size,
                    heads,
                    dropout=dropout,
                    forward_expansion=forward_expansion,
                )
                for _ in range(num_layers)
            ]
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask, device):
        N, seq_length = x.shape
        positions = torch.arange(0, seq_length).expand(N, seq_length).to(device)
        out = self.dropout(
            (self.word_embedding(x) + self.position_embedding(positions))
        )

        # In the Encoder the query, key, value are all the same, it's in the
        # decoder this will change. This might look a bit odd in this case.
        for layer in self.layers:
            out = layer(out, out, out, mask)

        return out


class DecoderBlock(nn.Module):
    def __init__(self, embed_size, heads, forward_expansion, dropout):
        super(DecoderBlock, self).__init__()
        self.norm = nn.LayerNorm(embed_size)
        self.attention = SelfAttention(embed_size, heads=heads)
        self.transformer_block = TransformerBlock(
            embed_size, heads, dropout, forward_expansion
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, value, key, src_mask, trg_mask):
        attention = self.attention(x, x, x, trg_mask)
        query = self.dropout(self.norm(attention + x))
        out = self.transformer_block(value, key, query, src_mask)
        return out


class Decoder(nn.Module):
    def __init__(
        self,
        trg_vocab_size,
        embed_size,
        num_layers,
        heads,
        forward_expansion,
        dropout,
        max_length,
    ):
        super(Decoder, self).__init__()
        self.word_embedding = nn.Embedding(trg_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            [
                DecoderBlock(embed_size, heads, forward_expansion, dropout)
                for _ in range(num_layers)
            ]
        )
        self.fc_out = nn.Linear(embed_size, trg_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_out, src_mask, trg_mask, device):
        N, seq_length = x.shape
        positions = torch.arange(0, seq_length).expand(N, seq_length).to(device)
        x = self.dropout((self.word_embedding(x) + self.position_embedding(positions)))

        for layer in self.layers:
            x = layer(x, enc_out, enc_out, src_mask, trg_mask)

        out = self.fc_out(x)

        return out


class Transformer(nn.Module):
    def __init__(
        self,
        src_vocab_size,
        trg_vocab_size,
        src_pad_idx,
        trg_pad_idx,
        embed_size=512,
        num_layers=6,
        forward_expansion=4,
        heads=8,
        dropout=0,
        max_length=100,
    ):

        super(Transformer, self).__init__()

        self.encoder = Encoder(
            src_vocab_size,
            embed_size,
            num_layers,
            heads,
            forward_expansion,
            dropout,
            max_length,
        )

        self.decoder = Decoder(
            trg_vocab_size,
            embed_size,
            num_layers,
            heads,
            forward_expansion,
            dropout,
            max_length,
        )

        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx

    def make_src_mask(self, src, device):
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
        # (N, 1, 1, src_len)
        return src_mask.to(device)

    def make_trg_mask(self, trg, device):
        N, trg_len = trg.shape
        trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(
            N, 1, trg_len, trg_len
        )

        return trg_mask.to(device)

    def forward(self, src, trg, device):
        src_mask = self.make_src_mask(src, device)
        trg_mask = self.make_trg_mask(trg, device)
        enc_src = self.encoder(src, src_mask, device)
        out = self.decoder(trg, enc_src, src_mask, trg_mask, device)
        return out

In [30]:
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F

class TransformerTranslateModel(pl.LightningModule):

    def __init__(self, hparams, *args, **kwargs):
        super().__init__()

        self.save_hyperparameters(hparams)

        print(f'Hparams =>\n\n{hparams}')

        self.model = Transformer(
            self.hparams.src_vocab_size,
            self.hparams.tgt_vocab_size,
            self.hparams.src_pad_idx,
            self.hparams.tgt_pad_idx,
            embed_size=self.hparams.embed_size,
            num_layers=self.hparams.num_layers,
            forward_expansion=self.hparams.forward_expansion,
            heads=self.hparams.heads,
            dropout=self.hparams.dropout,
            max_length=self.hparams.max_length,
        )

        self.loss = nn.CrossEntropyLoss(ignore_index=self.hparams.tgt_pad_idx)

    def forward(self, src_text, tgt_text):

        output = self.model(src_text, tgt_text, device=self.device)

        return output

    def shared_step(self, batch, batch_idx):
        src_text, tgt_text = batch

        # print(src_text.shape, tgt_text.shape)

        # remove <eos> from target and send
        logits = self(src_text, tgt_text[:, :-1])

        # reshape it into words * output_dim
        logits = logits.reshape(-1, logits.shape[2])
        # print(logits.shape)
        # we remove the <bos> from target
        target = tgt_text[:, 1:].reshape(-1)

        # print(target.shape)
        
        loss = self.loss(logits, target)

        metric = {'loss': loss}

        return metric


    def training_step(self, batch, batch_idx):
        metrics = self.shared_step(batch, batch_idx)

        log_metrics = {'train_loss': metrics['loss']}

        self.log_dict(log_metrics, prog_bar=True)

        return metrics


    def validation_step(self, batch, batch_idx):
        metrics = self.shared_step(batch, batch_idx)

        return metrics
    

    def validation_epoch_end(self, outputs):
        loss = torch.stack([x['loss'] for x in outputs]).mean()

        log_metrics = {'val_loss': loss}

        self.log_dict(log_metrics, prog_bar=True)

        return log_metrics


    def test_step(self, batch, batch_idx):
        return self.validation_step(batch, batch_idx)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.hparams.lr)
        return optimizer

In [42]:
ttc_translation = TTCTranslation(
    language_pair=('en', 'de'),
    spacy_language_pair=('en_core_web_sm', 'de_core_news_sm'),
    dataset='multi30k',
    batch_size=512
)

In [43]:
from omegaconf import OmegaConf

In [44]:
hparams = OmegaConf.create({
    'src_vocab_size': len(ttc_translation.vocab_transform[ttc_translation.src_lang]),
    'tgt_vocab_size': len(ttc_translation.vocab_transform[ttc_translation.tgt_lang]),
    'src_pad_idx': ttc_translation.PAD_IDX,
    'tgt_pad_idx': ttc_translation.PAD_IDX,
    'embed_size': 256,
    'num_layers': 4,
    'forward_expansion': 4,
    'heads': 4,
    'dropout': 0,
    'max_length': 50, # should be >=50 for multi30k
    'lr': 5e-4,
    'epochs': 30,
    'use_lr_finder': False
})

In [45]:
transformer = TransformerTranslateModel(hparams)

Hparams =>

{'src_vocab_size': 10838, 'tgt_vocab_size': 19215, 'src_pad_idx': 1, 'tgt_pad_idx': 1, 'embed_size': 256, 'num_layers': 4, 'forward_expansion': 4, 'heads': 4, 'dropout': 0, 'max_length': 50, 'lr': 0.0005, 'epochs': 30, 'use_lr_finder': False}


In [46]:
# trainer = pl.Trainer(gpus=0, max_epochs=hparams.epochs)

In [47]:
trainer = pl.Trainer(gpus=1, max_epochs=hparams.epochs)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [48]:
# trainer = pl.Trainer(tpu_cores=8, precision=16, plugins="tpu_spawn_debug", max_epochs=hparams.epochs, checkpoint_callback=False)

In [49]:
trainer.fit(transformer, ttc_translation)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type             | Params
-------------------------------------------
0 | model | Transformer      | 17.8 M
1 | loss  | CrossEntropyLoss | 0     
-------------------------------------------
17.8 M    Trainable params
0         Non-trainable params
17.8 M    Total params
71.237    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: -1it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

In [66]:
def translate_sentence(model, datamodule, sentence, max_length=50):
    model.eval()

    # Create tokens using spacy and everything in lower case (which is what our vocab is)
    if type(sentence) == str:
        tokens = [token.lower() for token in datamodule.token_transform[datamodule.src_lang](sentence)]
    else:
        tokens = [token.lower() for token in sentence]

    # Add <SOS> and <EOS> in beginning and end respectively
    tokens.insert(0, '<bos>')
    tokens.append('<eos>')

    text_to_indices = [datamodule.vocab_transform[datamodule.src_lang].get_stoi()[token] for token in tokens]

    # Convert to Tensor
    sentence_tensor = torch.LongTensor(text_to_indices).unsqueeze(0)

    outputs = [datamodule.BOS_IDX]
    for i in range(max_length):
        trg_tensor = torch.LongTensor(outputs).unsqueeze(0)

        with torch.no_grad():
            output = model(sentence_tensor, trg_tensor)

        best_guess = output.argmax(2)[:, -1].item()
        outputs.append(best_guess)

        if best_guess == datamodule.EOS_IDX:
            break

    outputs = outputs[1:-1]
    translated_sentence = [datamodule.vocab_transform[datamodule.tgt_lang].get_itos()[idx] for idx in outputs]

    return ' '.join(translated_sentence)

In [67]:
sentence = "Two young, White males are outside near many bushes."

In [68]:
translate_sentence(transformer, ttc_translation, sentence)

'Zwei junge Männer sind in weißen Nähe von Büschen .'

## Tensorboard Logs: https://tensorboard.dev/experiment/pqsWXkyvS2umPiN8rr1snw/

In [None]:
! tensorboard dev upload --logdir lightning_logs \
    --name "END2 12 Attention Is All You Need - Satyajit" \
    --description "Experiments on AIAYN model on Multi30K Dataset"


In [None]:
from torchtext.data.metrics import bleu_score
# TODO: make this work
def bleu(data, model, german, english, device):
    targets = []
    outputs = []

    for example in data:
        src = vars(example)["src"]
        trg = vars(example)["trg"]

        prediction = translate_sentence(model, src, german, english, device)
        prediction = prediction[:-1]  # remove <eos> token

        targets.append([trg])
        outputs.append(prediction)

    return bleu_score(outputs, targets)