In [1]:
# ! python -m spacy download en_core_web_sm --quiet
# ! python -m spacy download de_core_news_sm --quiet

In [2]:
import torch
import torch.nn as nn
import torchtext
import spacy
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn.functional as F
import tqdm
import numpy as np
import datasets
import random

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Load the Multi30k corpus and pull out its three named splits.
dataset = datasets.load_dataset("bentrevett/multi30k")
train_data = dataset["train"]
valid_data = dataset["validation"]
test_data = dataset["test"]

In [None]:
# spaCy pipelines for English and German; this notebook only calls their
# `.tokenizer` (see tokenize_example and translate_sentence).
en_nlp = spacy.load("en_core_web_sm")
de_nlp = spacy.load("de_core_news_sm")

In [None]:
def tokenize_example(example, en_nlp, de_nlp, max_length, lower, sos_token, eos_token):
    """Tokenize one English/German sentence pair.

    Args:
        example: mapping with the raw sentences under the keys "en" and "de".
        en_nlp: object whose `.tokenizer(text)` yields tokens with a `.text` attribute.
        de_nlp: same as `en_nlp`, for the German sentence.
        max_length: maximum number of tokens kept per sentence (before the
            start/end markers are added).
        lower: if true, lower-case every token; otherwise keep the original case.
        sos_token: marker prepended to each token list.
        eos_token: marker appended to each token list.

    Returns:
        dict with the keys "en_tokens" and "de_tokens", each a list of strings.
    """
    en_tokens = [token.text for token in en_nlp.tokenizer(example["en"])][:max_length]
    de_tokens = [token.text for token in de_nlp.tokenizer(example["de"])][:max_length]
    # Honour the `lower` flag; previously tokens were lower-cased unconditionally.
    if lower:
        en_tokens = [token.lower() for token in en_tokens]
        de_tokens = [token.lower() for token in de_tokens]
    en_tokens = [sos_token] + en_tokens + [eos_token]
    de_tokens = [sos_token] + de_tokens + [eos_token]
    return {"en_tokens": en_tokens, "de_tokens": de_tokens}

In [None]:
# Tokenisation settings shared by all three splits.
max_length = 1_000
lower = True
sos_token = "<sos>"
eos_token = "<eos>"

fn_kwargs = dict(
    en_nlp=en_nlp,
    de_nlp=de_nlp,
    max_length=max_length,
    lower=lower,
    sos_token=sos_token,
    eos_token=eos_token,
)

# Add "en_tokens" / "de_tokens" columns to every split.
train_data, valid_data, test_data = (
    split.map(tokenize_example, fn_kwargs=fn_kwargs)
    for split in (train_data, valid_data, test_data)
)

In [None]:
# Vocabulary settings: tokens seen fewer than `min_freq` times are left out.
min_freq = 2
unk_token = "<unk>"
pad_token = "<pad>"

special_tokens = [unk_token, pad_token, sos_token, eos_token]

# One vocabulary per language, both built from the training split only and
# with identical settings.
en_vocab, de_vocab = (
    torchtext.vocab.build_vocab_from_iterator(
        train_data[column],
        min_freq=min_freq,
        specials=special_tokens,
        max_tokens=1_000,
    )
    for column in ("en_tokens", "de_tokens")
)

In [None]:
# Both vocabularies were built with the same specials, so the special-token
# indices must agree; a single unk/pad index is then used for both languages.
unk_index = en_vocab[unk_token]
pad_index = en_vocab[pad_token]

assert unk_index == de_vocab[unk_token]
assert pad_index == de_vocab[pad_token]

# Out-of-vocabulary lookups fall back to the <unk> index.
en_vocab.set_default_index(unk_index)
de_vocab.set_default_index(unk_index)

In [None]:
def numericalize_example(example, en_vocab, de_vocab):
    """Map the token lists of one example to vocabulary indices.

    Reads example["en_tokens"] / example["de_tokens"] and returns a dict with
    the looked-up indices under "en_ids" / "de_ids".
    """
    return {
        "en_ids": en_vocab.lookup_indices(example["en_tokens"]),
        "de_ids": de_vocab.lookup_indices(example["de_tokens"]),
    }

In [None]:
fn_kwargs = dict(en_vocab=en_vocab, de_vocab=de_vocab)

# Add "en_ids" / "de_ids" columns to every split.
train_data, valid_data, test_data = (
    split.map(numericalize_example, fn_kwargs=fn_kwargs)
    for split in (train_data, valid_data, test_data)
)

In [None]:
# Inspect the first training example; after the map above it also carries
# the "en_ids" / "de_ids" columns.
train_data[0]

In [None]:
# Have the id columns returned in "torch" format while still exposing the
# remaining columns (output_all_columns=True).
data_type = "torch"
format_columns = ["en_ids", "de_ids"]

train_data, valid_data, test_data = (
    split.with_format(
        type=data_type,
        columns=format_columns,
        output_all_columns=True,
    )
    for split in (train_data, valid_data, test_data)
)

In [None]:
# Same example after with_format(type="torch"); "en_ids" / "de_ids" are
# expected to come back as tensors now — confirm in the cell output.
train_data[0]

In [None]:
def get_collate_fn(pad_index):
    """Build a DataLoader collate function that pads with `pad_index`.

    The returned function takes a list of examples (each with 1-D "en_ids" and
    "de_ids" tensors) and returns a dict with the same two keys, where each
    value is a batch-first tensor padded to the longest sequence in the batch.
    """

    def collate_fn(batch):
        def pad_column(key):
            # Stack the variable-length sequences of one column into a
            # [batch_sz, longest_seq_len] tensor.
            sequences = [example[key] for example in batch]
            return nn.utils.rnn.pad_sequence(
                sequences, padding_value=pad_index, batch_first=True
            )

        return {key: pad_column(key) for key in ("en_ids", "de_ids")}

    return collate_fn

In [None]:
def get_data_loader(dataset, batch_size, pad_index, shuffle=False):
    """Wrap `dataset` in a DataLoader whose batches are padded with `pad_index`.

    Batches are collated by the function returned from get_collate_fn.
    """
    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=get_collate_fn(pad_index),
    )

In [None]:
batch_size = 128

# Only the training loader is shuffled; validation and test keep dataset order.
train_dataloader = get_data_loader(
    dataset=train_data, batch_size=batch_size, pad_index=pad_index, shuffle=True
)
valid_dataloader = get_data_loader(
    dataset=valid_data, batch_size=batch_size, pad_index=pad_index
)
test_dataloader = get_data_loader(
    dataset=test_data, batch_size=batch_size, pad_index=pad_index
)

In [None]:
# Peek at one batch. next(iter(...)) fetches a single batch instead of collating
# the whole shuffled epoch with list(...) just to index into it.
# Batches are batch-first, [batch_sz, seq_len] (pad_sequence(batch_first=True)):
# the attention layers below attend over all tokens of a sentence at once, so
# the data is laid out as [batch_sz, seq_len] rather than [seq_len, batch_sz].
next(iter(train_dataloader))["en_ids"].shape

In [None]:
# Show the padded English ids of one batch; fetch only that batch rather than
# collating the entire epoch into a list first.
next(iter(train_dataloader))["en_ids"]

In [None]:
# transformers (attention is all you need) implementation: https://arxiv.org/pdf/1706.03762

### Positional Encoding

In [None]:
import math
class PositionalEncoder(nn.Module):
    """Add fixed sinusoidal position encodings to token embeddings, then apply dropout.

    The table follows "Attention Is All You Need", section 3.5:
        PE(pos, 2i)     = sin(pos / 10000^(2i / emb_dim))
        PE(pos, 2i + 1) = cos(pos / 10000^(2i / emb_dim))

    Args:
        max_seq_len: number of positions the table is precomputed for.
        emb_dim: size of one token embedding; odd sizes are supported.
        p_dropout: dropout probability applied to the summed output.
    """

    def __init__(self, max_seq_len, emb_dim, p_dropout=0.1):
        # Initialise the nn.Module machinery before assigning any attribute.
        super().__init__()
        self.max_seq_len = max_seq_len
        self.emb_dim = emb_dim
        pe = torch.zeros(max_seq_len, emb_dim)
        pos = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)  # [max_seq_len, 1]
        # 1 / 10000^(2i / emb_dim), computed as exp(2i * -ln(10000) / emb_dim) to
        # avoid raising a large base to a power directly (a^b = exp(b * ln(a)));
        # negating the exponent gives the reciprocal, e.g. 5^2 = 25 and 5^-2 = 1/25.
        divisor = torch.exp(
            torch.arange(0, emb_dim, 2).float() * (-math.log(10_000.0) / emb_dim)
        )
        # sin goes to the even columns
        pe[:, 0::2] = torch.sin(pos * divisor)
        # cos goes to the odd columns; for an odd emb_dim there is one fewer odd
        # column than even columns, so the last frequency is dropped.
        pe[:, 1::2] = torch.cos(pos * divisor[: emb_dim // 2])
        pe = pe.unsqueeze(0)  # pe: [1, max_seq_len, emb_dim]
        # Buffer: follows .to(device) / state_dict but is not a trainable parameter.
        self.register_buffer("pe", pe)
        self.dropout = nn.Dropout(p=p_dropout)

    def forward(self, x):
        """Return dropout(x + PE) for x of shape [batch_sz, seq_len, emb_dim].

        seq_len must not exceed max_seq_len, otherwise the addition fails to broadcast.
        """
        seq_len = x.shape[1]
        return self.dropout(x + self.pe[:, :seq_len, :])  # [batch_sz, seq_len, emb_dim]
        
# wee test
# Smoke test: the output keeps the [batch_sz, seq_len, emb_dim] shape of the input.
batch_sz, seq_len, emb_dim = 3, 5, 6
pos_enc = PositionalEncoder(seq_len, emb_dim)
x_test = torch.randn(batch_sz, seq_len, emb_dim)
o = pos_enc(x_test)
assert tuple(o.shape) == (batch_sz, seq_len, emb_dim)

### Scaled Dot-Product Attention

In [None]:
class Attention(nn.Module):
    """Scaled dot-product attention: softmax(q @ k^T / sqrt(dim_head)) @ v.

    Args:
        dim_head: size of the last dimension of q and k; its square root is
            the score scaling factor.
    """

    def __init__(self, dim_head):
        super().__init__()
        self.dim_head = dim_head
        # Plain Python float: no device/dtype handling needed, unlike a tensor
        # attribute that is not registered as a buffer.
        self.scale = float(dim_head) ** 0.5

    def forward(self, q, k, v, mask=None):
        """Attend from q over k/v.

        q, k, v: [batch_sz, n_heads, seq_len, dim_head].
        mask: optional tensor broadcastable to the score matrix; positions where
            it equals 0 receive zero attention weight.

        Returns a tensor of shape [batch_sz, n_heads, q_seq_len, dim_head].
        """
        # scores compares every query position with every key position:
        # [batch_sz, n_heads, q_seq_len, k_seq_len]
        scores = (q @ k.transpose(-2, -1)) / self.scale
        if mask is not None:
            # Use the most negative finite value instead of -inf: masked entries
            # still get exactly zero weight after softmax, but a row that is
            # masked everywhere yields finite (uniform) weights instead of NaN.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        weights = torch.softmax(scores, dim=-1)
        return weights @ v

# wee test
# Smoke test: attention keeps the [bsize, n_heads, seq_len, dim_head] shape.
dmodel = 6
bsize = 5
seq_len = 4
emb_dim = 3  # per-head dimension of the test tensors (dmodel // n_heads)
n_heads = 2
input_test = torch.randn(bsize, n_heads, seq_len, emb_dim)
# Attention expects the per-head dimension (the last dim of q/k), not dmodel.
attn_test = Attention(emb_dim)
o = attn_test(input_test, input_test, input_test)
assert list(o.shape) == [bsize, n_heads, seq_len, emb_dim]

### Multi-Head Attention

In [None]:
class MultiHeadAttention(nn.Module):
    """Split q/k/v into heads, run scaled dot-product attention per head, merge, project.

    The q/k/v inputs are expected to be already projected by the caller; this
    module only applies the final output projection (`fc_out`).

    Args:
        emb_dim: model dimension; must be divisible by `num_heads`.
        num_heads: number of attention heads.
    """

    def __init__(self, emb_dim, num_heads=4):
        super().__init__()
        assert emb_dim % num_heads == 0, "emb_dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.dim_head = emb_dim // num_heads
        # Scores must be scaled by sqrt(dim_head), the size of the vectors that
        # are actually dotted together, not by sqrt(emb_dim).
        self.attention = Attention(self.dim_head)
        self.fc_out = nn.Linear(emb_dim, emb_dim)

    def split_input(self, x):
        """[batch_sz, seq_len, emb_dim] -> [batch_sz, num_heads, seq_len, dim_head]."""
        batch_sz, seq_len, _ = x.shape
        split = x.reshape(batch_sz, seq_len, self.num_heads, self.dim_head)
        return split.permute(0, 2, 1, 3)

    def cat_heads(self, x):
        """[batch_sz, num_heads, seq_len, dim_head] -> [batch_sz, seq_len, num_heads * dim_head]."""
        batch_sz, num_heads, seq_len, head_dim = x.shape
        x = torch.transpose(x, 2, 1)  # [batch_sz, seq_len, num_heads, dim_head]
        return torch.reshape(x, (batch_sz, seq_len, head_dim * num_heads))

    def forward(self, q, k, v, mask=None):
        """Return attention output of shape [batch_sz, q_seq_len, emb_dim].

        q, k, v: [batch_sz, seq_len, emb_dim]; k and v may have a different
        seq_len than q. `mask` is passed through to the attention module.
        """
        q = self.split_input(q)  # [batch_sz, num_heads, seq_len, dim_head]
        k = self.split_input(k)
        v = self.split_input(v)
        attn = self.attention(q, k, v, mask)  # [batch_sz, num_heads, seq_len, dim_head]
        attn = self.cat_heads(attn)  # [batch_sz, seq_len, emb_dim]
        return self.fc_out(attn)
        
# wee test
# Smoke test: multi-head attention maps [bsize, seq_len, dim] to the same shape.
# `head_dim` is the full model dimension passed to MultiHeadAttention here; with
# the default of 4 heads each head works on 16 / 4 = 4 features.
bsize, seq_len, head_dim = 5, 3, 16
input_test = torch.randn(bsize, seq_len, head_dim)
mha = MultiHeadAttention(head_dim)
o = mha(input_test, input_test, input_test)
assert tuple(o.shape) == (5, 3, 16)

### Encoder-Decoder

#### Encoder

In [None]:
class Encoder1L(nn.Module):
    """One encoder layer: self-attention and a feed-forward net, each followed by
    dropout, a residual connection to the sub-layer input, and LayerNorm.

    Args:
        emb_dim: model dimension.
        n_heads: number of attention heads.
        p_dropout: dropout probability applied to each sub-layer output.
    """

    def __init__(self, emb_dim, n_heads, p_dropout=0.1):
        super().__init__()
        self.Q = nn.Linear(emb_dim, emb_dim)
        self.K = nn.Linear(emb_dim, emb_dim)
        self.V = nn.Linear(emb_dim, emb_dim)
        self.mhead_attn = MultiHeadAttention(emb_dim, n_heads)
        self.feed_fwd = nn.Sequential(
            nn.Linear(emb_dim, emb_dim),
            nn.ReLU(),
            nn.Linear(emb_dim, emb_dim))
        self.layer_norm1 = nn.LayerNorm(emb_dim)
        self.layer_norm2 = nn.LayerNorm(emb_dim)
        self.dropout1 = nn.Dropout(p_dropout)
        self.dropout2 = nn.Dropout(p_dropout)

    def forward(self, x, mask):
        """x: [batch_sz, seq_len, emb_dim] -> tensor of the same shape."""
        attn_out = self.mhead_attn(self.Q(x), self.K(x), self.V(x), mask)
        # The residual adds the sub-layer INPUT (x) to the sub-layer output;
        # adding the output to itself would discard the skip path.
        x = self.layer_norm1(x + self.dropout1(attn_out))
        ff_out = self.feed_fwd(x)
        x = self.layer_norm2(x + self.dropout2(ff_out))
        return x

# wee test
# Smoke test: a single encoder layer preserves the input shape.
bsize, seq_len, emb_dim, n_heads = 5, 3, 16, 4
test_x = torch.randn(bsize, seq_len, emb_dim)
e1l = Encoder1L(emb_dim, n_heads)
o = e1l(test_x, mask=None)
assert tuple(o.shape) == (5, 3, 16)

In [None]:
class Encoder(nn.Module):
    """Token embedding + positional encoding followed by `n_layers` encoder layers.

    Args:
        in_vocab_sz: size of the source vocabulary.
        max_seq_len: longest source sequence supported by the positional table.
        emb_dim: model dimension.
        n_layers: number of stacked Encoder1L layers.
        n_attn_heads: attention heads per layer.
        p_dropout: dropout probability, forwarded to every layer and to the
            positional encoder.
    """

    def __init__(self, in_vocab_sz, max_seq_len, emb_dim, n_layers, n_attn_heads, p_dropout=0.1):
        super().__init__()
        self.n_layers = n_layers
        # Forward p_dropout so the constructor argument actually takes effect.
        self.enc = nn.ModuleList(
            [Encoder1L(emb_dim, n_attn_heads, p_dropout) for _ in range(n_layers)]
        )
        self.input_embeddings = nn.Embedding(in_vocab_sz, emb_dim)
        # The positional encoder applies dropout to (embedding + position)
        # itself, so no second dropout is applied on top of it here.
        self.pos_enc = PositionalEncoder(max_seq_len, emb_dim, p_dropout)

    def forward(self, x, mask):
        """x: [batch_sz, seq_len] token ids -> [batch_sz, seq_len, emb_dim]."""
        x = self.pos_enc(self.input_embeddings(x))
        for layer in self.enc:
            x = layer(x, mask)
        return x

# wee test
# Smoke test: token ids [bsize, max_seq_len] come out as [bsize, max_seq_len, emb_dim].
bsize, max_seq_len, emb_dim = 5, 3, 16
n_layers, n_heads, in_voc = 2, 4, 20
test_x = torch.randint(low=0, high=10, size=(bsize, max_seq_len))
enc = Encoder(in_voc, max_seq_len, emb_dim, n_layers, n_heads)
o = enc(test_x, mask=None)
assert tuple(o.shape) == (5, 3, 16)

In [None]:
class Decoder1L(nn.Module):
    """One decoder layer: masked self-attention, encoder-decoder attention and a
    feed-forward net, each followed by dropout, a residual connection to the
    sub-layer input, and LayerNorm.

    Args:
        emb_dim: model dimension.
        n_heads: number of attention heads.
        p_dropout: dropout probability applied to each sub-layer output.
    """

    def __init__(self, emb_dim, n_heads, p_dropout=0.1):
        super().__init__()
        self.Q1 = nn.Linear(emb_dim, emb_dim)
        self.K1 = nn.Linear(emb_dim, emb_dim)
        self.V1 = nn.Linear(emb_dim, emb_dim)
        self.Q2 = nn.Linear(emb_dim, emb_dim)
        self.K2 = nn.Linear(emb_dim, emb_dim)
        self.V2 = nn.Linear(emb_dim, emb_dim)
        self.mhead_attn = MultiHeadAttention(emb_dim, n_heads)
        self.masked_mhead_attn = MultiHeadAttention(emb_dim, n_heads)
        self.feed_fwd = nn.Sequential(
            nn.Linear(emb_dim, emb_dim),
            nn.ReLU(),
            nn.Linear(emb_dim, emb_dim))
        self.layer_norm1 = nn.LayerNorm(emb_dim)
        self.layer_norm2 = nn.LayerNorm(emb_dim)
        self.layer_norm3 = nn.LayerNorm(emb_dim)
        self.dropout1 = nn.Dropout(p_dropout)
        self.dropout2 = nn.Dropout(p_dropout)
        self.dropout3 = nn.Dropout(p_dropout)

    def forward(self, enc_x, x, src_mask, trg_mask):
        """Decode one layer.

        enc_x: encoder output, [batch_sz, src_len, emb_dim].
        x: decoder-side input, [batch_sz, trg_len, emb_dim].
        src_mask / trg_mask: masks passed to the cross- and self-attention.

        Returns a tensor of shape [batch_sz, trg_len, emb_dim].
        """
        # 1) masked self-attention over the target; residual to the layer input
        self_attn = self.masked_mhead_attn(self.Q1(x), self.K1(x), self.V1(x), mask=trg_mask)
        x = self.layer_norm1(x + self.dropout1(self_attn))
        # 2) cross-attention: queries from the target, keys/values from the encoder
        cross_attn = self.mhead_attn(self.Q2(x), self.K2(enc_x), self.V2(enc_x), mask=src_mask)
        o = self.layer_norm2(x + self.dropout2(cross_attn))
        # 3) position-wise feed-forward; residual to the feed-forward input
        ff_out = self.feed_fwd(o)
        o = self.layer_norm3(o + self.dropout3(ff_out))
        return o

# wee test
# Smoke test: a single decoder layer preserves the target-side shape.
bsize, seq_len, emb_dim, n_heads = 5, 3, 16, 4
test_x = torch.randn(bsize, seq_len, emb_dim)
test_enc_x = torch.randn(bsize, seq_len, emb_dim)
d1l = Decoder1L(emb_dim, n_heads)
o = d1l(test_enc_x, test_x, src_mask=None, trg_mask=None)
assert tuple(o.shape) == (5, 3, 16)

In [None]:
class Decoder(nn.Module):
    """Target embedding + positional encoding, `n_layers` decoder layers, and a
    final projection to vocabulary logits.

    Args:
        out_vocab_sz: size of the target vocabulary (= number of output logits).
        max_seq_len: longest target sequence supported by the positional table.
        emb_dim: model dimension.
        n_layers: number of stacked Decoder1L layers.
        n_attn_heads: attention heads per layer.
        p_dropout: dropout probability, forwarded to every layer and to the
            positional encoder.
    """

    def __init__(self, out_vocab_sz, max_seq_len, emb_dim, n_layers, n_attn_heads, p_dropout=0.1):
        super().__init__()
        self.n_layers = n_layers
        # Forward p_dropout so the constructor argument actually takes effect.
        self.dec = nn.ModuleList(
            [Decoder1L(emb_dim, n_attn_heads, p_dropout) for _ in range(n_layers)]
        )
        self.output_embeddings = nn.Embedding(out_vocab_sz, emb_dim)
        # Project to one logit per target-vocabulary entry. Projecting back to
        # emb_dim only works by accident when emb_dim == out_vocab_sz.
        self.fc = nn.Linear(emb_dim, out_vocab_sz)
        # The positional encoder applies dropout to (embedding + position)
        # itself, so no second dropout is applied on top of it here.
        self.pos_enc = PositionalEncoder(max_seq_len, emb_dim, p_dropout)

    def forward(self, enc_x, x, src_mask, trg_mask):
        """Return logits of shape [batch_sz, trg_len, out_vocab_sz].

        enc_x: encoder output, [batch_sz, src_len, emb_dim].
        x: target token ids, [batch_sz, trg_len].
        """
        x = self.pos_enc(self.output_embeddings(x))
        for layer in self.dec:
            x = layer(enc_x, x, src_mask, trg_mask)
        return self.fc(x)


# wee test
bsize = 5
max_seq_len = 3
emb_dim = 16
n_heads = 4
n_layers = 2
out_voc = 4
test_x = torch.randint(low=0, high=out_voc, size=(bsize, max_seq_len))
# Use this cell's own max_seq_len rather than a `seq_len` left over from an earlier cell.
test_enc_x = torch.randn(bsize, max_seq_len, emb_dim)
dec = Decoder(out_voc, max_seq_len, emb_dim, n_layers, n_heads)
o = dec(test_enc_x, test_x, src_mask=None, trg_mask=None)
# The last dimension is the target vocabulary size, not emb_dim.
assert list(o.shape) == [bsize, max_seq_len, out_voc]

## Transformer

In [None]:
class Transformer(nn.Module):
    """Encoder-decoder transformer with padding and look-ahead masking.

    Args:
        max_seq_len: longest sequence supported by the positional tables.
        in_vocab_sz: source vocabulary size.
        out_vocab_sz: target vocabulary size.
        emb_dim: model dimension.
        n_layers: number of encoder layers and of decoder layers.
        n_attn_heads: attention heads per layer.
        src_pad_idx: padding index in the source vocabulary.
        trg_pad_idx: padding index in the target vocabulary.
        p_dropout: dropout probability, forwarded to encoder and decoder.
    """

    def __init__(self, max_seq_len, in_vocab_sz, out_vocab_sz, emb_dim, n_layers,
                 n_attn_heads, src_pad_idx, trg_pad_idx, p_dropout=0.1):
        super().__init__()
        # Forward p_dropout; previously the argument was accepted but ignored.
        self.encoder = Encoder(in_vocab_sz, max_seq_len, emb_dim, n_layers, n_attn_heads,
                               p_dropout=p_dropout)
        self.decoder = Decoder(out_vocab_sz, max_seq_len, emb_dim, n_layers, n_attn_heads,
                               p_dropout=p_dropout)
        self.source_pad_index = src_pad_idx
        self.target_pad_index = trg_pad_idx
        self.n_attn_heads = n_attn_heads

    def get_source_mask(self, source):
        """Return a bool mask [batch_size, 1, 1, src_len]: True for real tokens, False for padding."""
        return (source != self.source_pad_index).unsqueeze(1).unsqueeze(2)

    def get_target_mask(self, target):
        """Return a bool mask [batch_size, 1, trg_len, trg_len].

        Entry [b, 0, i, j] is True when position i may attend to position j,
        i.e. j <= i (no look-ahead) and target[b, j] is not padding.
        """
        target_len = target.size(1)

        # pad_mask: [batch_size, 1, 1, trg_len]
        pad_mask = (target != self.target_pad_index).unsqueeze(1).unsqueeze(2)
        # lookahead_mask: lower-triangular [trg_len, trg_len]
        lookahead_mask = torch.tril(
            torch.ones(target_len, target_len, device=target.device)
        ).bool()

        # Broadcast [batch_size, 1, 1, trg_len] & [1, 1, trg_len, trg_len]
        return pad_mask & lookahead_mask.unsqueeze(0).unsqueeze(1)

    def forward(self, x, y):
        """x: source ids [batch_size, src_len]; y: target ids [batch_size, trg_len].

        Returns whatever the decoder produces for (encoder(x), y).
        """
        source_mask = self.get_source_mask(x)  # [batch_size, 1, 1, src_len]
        target_mask = self.get_target_mask(y)  # [batch_size, 1, trg_len, trg_len]

        x = self.encoder(x, source_mask)
        y = self.decoder(x, y, source_mask, target_mask)
        return y


## Training

In [None]:
def init_model(en_vocab, de_vocab, pad_index, lr=0.1):
    """Build the transformer, its optimizer and the loss criterion.

    Args:
        en_vocab: source (English) vocabulary; supports len() and ["<pad>"].
        de_vocab: target (German) vocabulary; supports len() and ["<pad>"].
        pad_index: target index ignored by the loss.
        lr: Adam learning rate. NOTE(review): the default of 0.1 is kept for
            backward compatibility but is unusually high for Adam; values
            around 1e-4 to 1e-3 are more typical — confirm before long runs.

    Returns:
        (transformer, optimizer, criterion)
    """
    # params
    input_vocab_sz = len(en_vocab)
    output_vocab_sz = len(de_vocab)
    embedding_dim = 1000
    n_layers = 2
    n_attn_heads = 4
    max_seq_len = 100
    # init models: the source is English and the target is German, so each
    # padding index comes from the matching vocabulary.
    transformer = Transformer(max_seq_len, input_vocab_sz, output_vocab_sz, embedding_dim,
                              n_layers, n_attn_heads,
                              src_pad_idx=en_vocab["<pad>"], trg_pad_idx=de_vocab["<pad>"])

    # Adam has no `momentum` argument (passing it raises TypeError); its
    # first-moment decay is betas[0], set here to the intended 0.9.
    optimizer = optim.Adam(transformer.parameters(), lr=lr, betas=(0.9, 0.999))

    criterion = nn.CrossEntropyLoss(ignore_index=pad_index)

    return transformer, optimizer, criterion

In [None]:
def _batch_loss(model, criterion, batch):
    """Teacher-forced loss of `model` on one batch with "en_ids" / "de_ids" tensors."""
    # Keep the batch on the same device as the model (a no-op when both are on CPU).
    device = next(model.parameters()).device
    src = batch["en_ids"].to(device)  # [batch_size, src_seq_len]
    trg = batch["de_ids"].to(device)  # [batch_size, trg_seq_len]

    # Shift the target: the decoder sees everything but the last token and is
    # asked to predict everything but the first.
    trg_input = trg[:, :-1]
    trg_output = trg[:, 1:]

    source_mask = model.get_source_mask(src)
    target_mask = model.get_target_mask(trg_input)

    enc_out = model.encoder(src, source_mask)
    y_pred = model.decoder(enc_out, trg_input, source_mask, target_mask)

    # Flatten to [batch_size * (trg_seq_len - 1), n_outputs] vs. [batch_size * (trg_seq_len - 1)]
    return criterion(y_pred.reshape(-1, y_pred.shape[-1]), trg_output.reshape(-1))


def _mean_loss(losses):
    """Average of a list of floats; NaN for an empty list instead of ZeroDivisionError."""
    return sum(losses) / len(losses) if losses else float("nan")


def train(train_dataloader, val_dataloader, en_vocab, de_vocab, pad_index, n_epochs=20):
    """Train a freshly initialised model and return it.

    Each epoch runs one pass over `train_dataloader` (gradient norm clipped to 1,
    learning rate decayed by 0.9 per epoch), then evaluates on `val_dataloader`
    without gradients. Average losses are printed per epoch.
    """
    model, optimizer, criterion = init_model(en_vocab, de_vocab, pad_index)
    clip = 1
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

    for epoch in tqdm.tqdm(range(n_epochs)):
        model.train()
        training_losses = []

        for batch in train_dataloader:
            optimizer.zero_grad()
            loss = _batch_loss(model, criterion, batch)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
            optimizer.step()
            training_losses.append(loss.item())

        scheduler.step()

        print(f"Epoch {epoch} average training loss: {_mean_loss(training_losses)}")

        # Validation loop
        model.eval()
        with torch.no_grad():
            validation_losses = [
                _batch_loss(model, criterion, batch).item() for batch in val_dataloader
            ]

        print(f"Epoch {epoch} average validation loss: {_mean_loss(validation_losses)}")

    return model


In [None]:
# train model
# train model: train() builds its own model via init_model and returns it after
# n_epochs (default 20); the bare `model` on the last line displays the module.
model = train(train_dataloader, valid_dataloader, en_vocab, de_vocab, pad_index)
model

In [None]:
def count_parameters(model):
    """Return the total element count of all parameters that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


# Report the number of trainable parameters with thousands separators.
print(f"The model has {count_parameters(model):,} trainable parameters")

In [None]:
def translate_sentence(
    sentence,
    model,
    en_nlp,
    de_nlp,
    en_vocab,
    de_vocab,
    sos_token,
    eos_token,
    device,
    pad_token=None,
    max_output_length=25,
):
    """Greedily translate one English sentence with `model`.

    Args:
        sentence: raw English text.
        model: object exposing eval(), get_source_mask, get_target_mask,
            encoder and decoder (as the Transformer in this notebook does).
        en_nlp: object whose `.tokenizer(text)` yields tokens with `.text`.
        de_nlp: unused; kept for interface compatibility.
        en_vocab: source vocabulary (needs `lookup_indices`).
        de_vocab: target vocabulary (needs `[]` and `lookup_tokens`).
        sos_token / eos_token: start / end markers.
        device: device the input tensors are moved to; the model is expected to
            live on the same device already.
        pad_token: unused; kept for interface compatibility. It used to default
            to `en_vocab["<pad>"]`, which was evaluated at definition time
            against a notebook global and made the cell order-dependent.
        max_output_length: maximum number of tokens to generate.

    Returns:
        list of generated target tokens (without the start marker; the end
        marker is included when it was produced).
    """
    model.eval()
    with torch.no_grad():
        # Tokenize and preprocess
        tokens = [token.text.lower() for token in en_nlp.tokenizer(sentence)]
        tokens = [sos_token] + tokens + [eos_token]

        ids = en_vocab.lookup_indices(tokens)
        print(f"Input tokens: {ids}")

        src_tensor = torch.LongTensor(ids).unsqueeze(0).to(device)  # shape: [1, src_len]

        # The source is encoded once and reused at every decoding step.
        source_mask = model.get_source_mask(src_tensor)
        enc_out = model.encoder(src_tensor, source_mask)

        eos_index = de_vocab[eos_token]
        # Autoregressive target sequence, seeded with the start marker: [1, 1]
        trg_tokens = torch.LongTensor([de_vocab[sos_token]]).unsqueeze(0).to(device)

        outputs = []

        for _ in range(max_output_length):
            target_mask = model.get_target_mask(trg_tokens)
            output = model.decoder(enc_out, trg_tokens, source_mask, target_mask)
            # Greedy choice at the last position only.
            next_token = output[:, -1, :].argmax(dim=-1).unsqueeze(0)  # shape: [1, 1]
            outputs.append(next_token.item())
            if outputs[-1] == eos_index:
                break

            trg_tokens = torch.cat((trg_tokens, next_token), dim=1)

        output_tokens = de_vocab.lookup_tokens(outputs)

        print(f"Output tokens: {outputs}")
        return output_tokens


In [None]:
# Translate one randomly chosen test sentence and compare it with the reference.
test_sentences = test_data
rando_idx = np.random.randint(low=0, high=len(test_sentences))
sentence = test_sentences[rando_idx]["en"]
expected_translation = test_sentences[rando_idx]["de"]
translation = translate_sentence(
    sentence,
    model,
    en_nlp,
    de_nlp,
    en_vocab,
    de_vocab,
    sos_token="<sos>",
    eos_token="<eos>",
    device="cpu",
)
print(f"\nsentence: {sentence}\n")
print(f"expected_translation: {expected_translation}\n")
print(f"actual translation: {' '.join(translation)}")