In [None]:
### For colab

# !pip install torchtext
# !pip install torchdatasets
# !pip intstall spacy

# !python -m spacy download en_core_web_sm
# !python -m spacy download de_core_news_sm

In [1]:
import sys
import torch
sys.path.append("..")
%reload_ext autoreload
%autoreload 2

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

# Dataset

In [2]:
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import multi30k, Multi30k

from typing import Iterable, List
from data_utils import Dataset

d = Dataset()

B = 2
for src, tgt in d.get_dataloader(B):
    print("src: ", src.shape)
    print("tgt: ", tgt.shape, "\n")
    for j in range(B):
        src_tokens = list(src[j, :].detach().numpy())
        tgt_tokens = list(tgt[j, :].detach().numpy())
        src_sentence = " ".join(d.src_vocab.lookup_tokens(src_tokens))
        tgt_sentence = " ".join(d.tgt_vocab.lookup_tokens(tgt_tokens))
        print(src_sentence, "\n", tgt_sentence, "\n")
        
    break

initializing Multi30k with train split ...
src:  torch.Size([2, 15])
tgt:  torch.Size([2, 14]) 

<bos> Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche . <eos> 
 <bos> Two young , White males are outside near many bushes . <eos> <pad> 

<bos> Mehrere Männer mit Schutzhelmen bedienen ein Antriebsradsystem . <eos> <pad> <pad> <pad> <pad> <pad> 
 <bos> Several men in hard hats are operating a giant pulley system . <eos> 



## Build model

from torch import nn
from transformer import Seq2SeqTransformer

In [3]:
from transformer import Seq2SeqTransformer

torch.manual_seed(0)

B = 128
D, Dff  = 8, 8
n_heads = 2
l_enc, l_dec = 1, 1

transformer = Seq2SeqTransformer(l_enc, l_dec, D, n_heads, d.src_vocab_size, d.tgt_vocab_size, Dff).to(device)
loss_fn = torch.nn.CrossEntropyLoss(ignore_index = d._PAD_IDX)
optimizer = torch.optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

## Train

In [4]:
import time
from torch import nn, Tensor
from torch.optim import Optimizer
from utils import create_masks
from train import train_epoch, evaluate
from typing import Callable

# train_epoch(d, transformer, optimizer, loss_fn, B, device=device, log_interval = 5, early_stop = 20)    
# evaluate(d, transformer, loss_fn, B, device)

def train( d: Dataset, model: nn.Module, optimizer: Optimizer, loss_fn: Callable[[Tensor, Tensor], Tensor], B: int, epochs: int, device: torch.device):
    for e in range(1, epochs+1):
        start_time = time.time()
        train_loss = train_epoch(d, transformer, optimizer, loss_fn, B, device=device, log_interval = 0, early_stop = 100)    
        val_loss = evaluate(d, transformer, loss_fn, B, device)
        
        msg = f"Epoch: {e}, Train loss: {train_loss:<.3f}, Val loss: {val_loss:<.3f}, " 
        msg += f"Epoch time: {time.time() - start_time:.3f}"
        print(msg)
        
epochs = 3
train(d, transformer, optimizer, loss_fn, B, epochs, device=device)

initializing Multi30k with valid split ...
Epoch: 1, Train loss: 9.314, Val loss: 9.251, Epoch time: 58.467
Epoch: 2, Train loss: 9.198, Val loss: 9.114, Epoch time: 56.275
Epoch: 3, Train loss: 9.060, Val loss: 8.970, Epoch time: 57.277


## Translate

In [5]:
from utils import generate_square_subsequent_mask

def greedy_decode(model: nn.Module, src: Tensor, src_mask: Tensor, max_len: int, start_symbol: int, device: torch.device):
    src, src_mask = src.to(device), src_mask.to(device)
    memory = model.encode(src, src_mask)
    # memory = memory.to(device)
    ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(device)
    
    for i in range(max_len-1):
        tgt_mask = generate_square_subsequent_mask(ys.size(1)).type(torch.bool).to(device)
        
        out = model.decode(ys, memory, tgt_mask)
        prob = model.generator(out[:, -1])
        
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.item()
        
        ys = torch.cat([ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=1)
        
        if next_word == d._EOS_IDX:
            break
        
    return ys
        
def translate(model: nn.Module, d: Dataset, src_sentence: str, device: torch.device) -> str:    
    src = d.src_transform(src_sentence).view(1, -1)
    ns = src.shape[1]
    src_mask = torch.zeros(ns, ns).type(torch.bool)

    max_len = ns + 5
    start_symbol = d._BOS_IDX

    ys = greedy_decode(model, src, src_mask, max_len, start_symbol, device)    
    ys = list(ys.squeeze().cpu().numpy())
    return " ".join(d.src_vocab.lookup_tokens(ys))

src_sentence = "Eine Gruppe von Menschen steht vor einem Iglu"
tgt_sentence = translate(transformer, d, src_sentence, device)
print(f"sentence:\n{src_sentence}\n\ntranslation:\n{tgt_sentence}")

sentence:
Eine Gruppe von Menschen steht vor einem Iglu

translation:
<bos> einer einer einer einer einer einer einer einer einer einer einer einer einer einer
