In [1]:
import warnings
warnings.filterwarnings("ignore")
import os, re
import copy
from typing import Iterable, List
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchtext.data.utils import get_tokenizer
from torch.nn.utils.rnn import pad_sequence
from torchtext.vocab import build_vocab_from_iterator

import numpy as np

from sklearn.model_selection import train_test_split

In [2]:
def seed_everything(seed=42):
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True
    
seed_everything()

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [4]:
SRC_LANGUAGE = "ru"
TGT_LANGUAGE = "en"

token_transform = {}
vocab_transform = {}

In [5]:
token_transform[SRC_LANGUAGE] = get_tokenizer("spacy", language="ru_core_news_sm")
token_transform[TGT_LANGUAGE] = get_tokenizer("spacy", language="en_core_web_sm")

In [6]:
data = [line.replace("\t", "\n").split("CC")[0] for line in open("rus.txt", "r",  encoding="utf8")]
data = [item for item in data if len(item.split("\n")[0]) >= 20 and len(item.split("\n")[0]) <= 35]

data = np.array(data)

In [7]:
def preprocessor(text):
    """ 
     simple text preprocessing (do not remove stop words for transformer model) 
     without that step ce-loss was near 0.67, otherwise ~ 0.75
    """
    text = text.lower()              
    cleanr = re.compile('<.*?>')
    text = re.sub(cleanr, ' ', text)
    text = re.sub(r'[?|!|\'|"|#]', r'', text)
    text = re.sub(r'[.|,|)|(|\|/]', r' ', text)
    text = re.sub(r'\s+', ' ', text)
    text = re.sub(r"it's", "it is", text)
    text = re.sub(r"i'd", "i had", text)
    text = re.sub(r"i've", "a have", text)
    text = re.sub(r"i'll", "i will", text)
    text = re.sub(r"we'll", "we will", text)
    # and so on...
    return text

In [8]:
class TransDataset(Dataset):
    def __init__(self, data, is_train=True):
        train, test = train_test_split(data, shuffle=True, random_state=42)
        if is_train:
            self.data = train
        else:
            self.data = test

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, ix):
        pairs_str = self.data[ix]
        pos = pairs_str.find('\n')
        # return preprocessor(pairs_str[pos+1:]), preprocessor(pairs_str[:pos+1])
        return pairs_str[pos+1:], pairs_str[:pos+1]

    def collate_fn(self, batch):
        src_batch, tgt_batch = [], []
        for src_sample, tgt_sample in batch:
            src_batch.append(text_transform[SRC_LANGUAGE](src_sample.rstrip("\n")))  # .rstrip()))
            tgt_batch.append(text_transform[TGT_LANGUAGE](tgt_sample.rstrip("\n")))  # .rstrip()))

        src_batch = pad_sequence(src_batch, padding_value=PAD_IDX)
        tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_IDX)
        return src_batch, tgt_batch

In [9]:
def yield_tokens(data_iter: Iterable, language: str) -> List[str]:
    language_index = {SRC_LANGUAGE: 0, TGT_LANGUAGE: 1}

    for data_sample in data_iter:
        yield token_transform[language](data_sample[language_index[language]])

In [10]:
UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3
special_symbols = ["<unk>", "<pad>", "<bos>", "<eos>"]

In [11]:
# the best alternative way is to build word-piece vocabulary

for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
    train_iter = TransDataset(data)
    vocab_transform[ln] = build_vocab_from_iterator(yield_tokens(train_iter, ln),
                                                    min_freq=1,
                                                    specials=special_symbols,
                                                    special_first=True)

In [12]:
for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
    vocab_transform[ln].set_default_index(UNK_IDX)

**Building transformer model**

In [13]:
class PositionalEncoding(nn.Module):
    def __init__(self,
                 emb_size: int,
                 dropout: float,
                 maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        den = torch.exp(-torch.arange(0, emb_size, 2) * np.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer("pos_embedding", pos_embedding)

    def forward(self, token_embedding: torch.Tensor):
        return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])

In [14]:
class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, emb_size):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: torch.Tensor):
        return self.embedding(tokens.long()) * np.sqrt(self.emb_size)

In [15]:
def create_trg_mask(size):
    """ mask for target sequence """
    np_mask = np.triu(np.ones((1, size, size)), k=1).astype("uint8")
    np_mask = (torch.from_numpy(np_mask) == 0.0).to(device)
    return np_mask

def create_masks(src, trg):
    
    src_mask = (src != PAD_IDX).unsqueeze(-2)

    if trg is not None:
        trg_mask = (trg != PAD_IDX).unsqueeze(-2)
        size = trg.size(1) 
        np_mask = create_trg_mask(size)
        trg_mask = trg_mask & np_mask
    else:
        trg_mask = None
        
    return src_mask, trg_mask

In [16]:
def attention(query, key, value, d_model, mask=None, dropout=None):
    
    scores = torch.einsum("... i d , ... j d -> ... i j", query, key) / np.sqrt(d_model)
    
    if mask is not None:
        mask = mask.unsqueeze(1)
        scores = scores.masked_fill(mask == 0, -1e9)  # mask padding
    scores = F.softmax(scores, dim=-1)
    
    if dropout is not None:
        scores = dropout(scores)
        
    output = torch.einsum("... i j , ... j d -> ... i d", scores, value)
    
    return output

In [17]:
class MultiHeadAttention(nn.Module):
    def __init__(self, heads, d_model, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        
        self.d_model = d_model
        self.d_k = d_model // heads
        self.heads = heads
        
        self.q_linear = nn.Linear(d_model, d_model, bias=False)
        self.v_linear = nn.Linear(d_model, d_model, bias=False)
        self.k_linear = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(d_model, d_model, bias=False)
    
    def forward(self, q, k, v, mask=None):
        
        batch_size = q.size(0)
                
        k = self.k_linear(k).view(batch_size, -1, self.heads, self.d_k)
        q = self.q_linear(q).view(batch_size, -1, self.heads, self.d_k)
        v = self.v_linear(v).view(batch_size, -1, self.heads, self.d_k)
               
        k = k.transpose(1,2)
        q = q.transpose(1,2)
        v = v.transpose(1,2)

        scores = attention(q, k, v, self.d_k, mask, self.dropout)
        
        concat = scores.transpose(1,2).contiguous().view(batch_size, -1, self.d_model)
        
        output = self.out(concat)
    
        return output

In [18]:
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff=2048, dropout=0.1):
        super(FeedForward, self).__init__() 
        self.linear_1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model)
        
    def forward(self, x):
        x = self.dropout(F.relu(self.linear_1(x)))  # or "gelu": [x * scipy.stats.norm.cdf(x, loc=0, scale=1)]
        x = self.linear_2(x)
        return x

In [19]:
class LayerNorm(nn.Module):
    def __init__(self, d_model, eps=1e-6):
        super(LayerNorm, self).__init__()
    
        self.size = d_model
        self.alpha = nn.Parameter(torch.ones(self.size))
        self.bias = nn.Parameter(torch.zeros(self.size))
        self.eps = eps
        
    def forward(self, x):
        norm = self.alpha * (x - x.mean(dim=-1, keepdim=True)) / (x.std(dim=-1, keepdim=True) + self.eps) + self.bias
        return norm

In [20]:
class EncoderLayer(nn.Module):
    def __init__(self, d_model, heads, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.norm_1 = LayerNorm(d_model)
        self.norm_2 = LayerNorm(d_model)
        self.mha = MultiHeadAttention(heads, d_model)
        self.ff = FeedForward(d_model)
        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)
        
    def forward(self, x, mask):
        x2 = self.norm_1(x)
        x = x + self.dropout_1(self.mha(x2, x2, x2, mask))
        x2 = self.norm_2(x)
        x = x + self.dropout_2(self.ff(x2))
        return x

In [21]:
class DecoderLayer(nn.Module):
    def __init__(self, d_model, heads, dropout=0.1):
        super(DecoderLayer, self).__init__()
        self.norm_1 = LayerNorm(d_model)
        self.norm_2 = LayerNorm(d_model)
        self.norm_3 = LayerNorm(d_model)
        
        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)
        self.dropout_3 = nn.Dropout(dropout)
        
        self.mha_1 = MultiHeadAttention(heads, d_model)
        self.mha_2 = MultiHeadAttention(heads, d_model)
        self.ffn = FeedForward(d_model)
        
    def forward(self, x, encoder_out, src_mask, trg_mask):
        x2 = self.norm_1(x)
        x = x + self.dropout_1(self.mha_1(x2, x2, x2, trg_mask))
        x2 = self.norm_2(x)
        x = x + self.dropout_2(self.mha_2(x2, encoder_out, encoder_out, src_mask))
        x2 = self.norm_3(x)
        x = x + self.dropout_3(self.ffn(x2))
        return x

In [22]:
def get_n_layers(module, n_layers):
    return nn.ModuleList([copy.deepcopy(module) for i in range(n_layers)])

In [23]:
class Encoder(nn.Module):
    def __init__(self, vocab_size, d_model, n_layers, heads):
        super(Encoder, self).__init__()
        self.n_layers = n_layers
        self.embedding = TokenEmbedding(vocab_size, d_model)
        self.pe = PositionalEncoding(d_model, dropout=0.1)
        self.layers = get_n_layers(EncoderLayer(d_model, heads), n_layers)
        self.norm = LayerNorm(d_model)
        
    def forward(self, src, mask):
        x = self.embedding(src)
        x = self.pe(x)
        for i in range(self.n_layers):
            x = self.layers[i](x, mask)
        return self.norm(x)
    
class Decoder(nn.Module):
    def __init__(self, vocab_size, d_model, n_layers, heads):
        super(Decoder, self).__init__()
        self.n_layers = n_layers
        self.embedding = TokenEmbedding(vocab_size, d_model)
        self.pe = PositionalEncoding(d_model, dropout=0.1)
        self.layers = get_n_layers(DecoderLayer(d_model, heads), n_layers)
        self.norm = LayerNorm(d_model)
        
    def forward(self, trg, encoder_out, src_mask, trg_mask):
        x = self.embedding(trg)
        x = self.pe(x)
        for i in range(self.n_layers):
            x = self.layers[i](x, encoder_out, src_mask, trg_mask)
        return self.norm(x)

In [24]:
class Transformer(nn.Module):
    def __init__(self, src_vocab, trg_vocab, d_model, n_layers, heads):
        super(Transformer, self).__init__()
        self.encoder = Encoder(src_vocab, d_model, n_layers, heads)
        self.decoder = Decoder(trg_vocab, d_model, n_layers, heads)
        self.out = nn.Linear(d_model, trg_vocab)
        
    def forward(self, src, trg, src_mask, trg_mask):
        encoder_out = self.encoder(src, src_mask)
        decoder_out = self.decoder(trg, encoder_out, src_mask, trg_mask)
        output = self.out(decoder_out)
        return output

In [25]:
class EarlyStopping:
    def __init__(self, patience=15, min_delta=0, path="model.pth"):
        self.path = path
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss, model=None):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif self.best_loss - val_loss > self.min_delta:
            if model is not None:
                checkpoint = {
                    "model": model,
                }
                torch.save(checkpoint, self.path)
                print(f"Model saved to: {self.path}")
            self.best_loss = val_loss
            self.counter = 0
        elif self.best_loss - val_loss < self.min_delta:
            self.counter += 1
            print(f"INFO: Early stopping counter {self.counter} of {self.patience}")
            if self.counter >= self.patience:
                print("INFO: Early stopping")
                self.early_stop = True

In [26]:
src_vocab = len(vocab_transform[SRC_LANGUAGE])
trg_vocab = len(vocab_transform[TGT_LANGUAGE])
d_model = 512  # quite long for given short sentences
heads = 8
batch_size = 128
n_layers = 3  # 6 in paper, but it is not magical number

model = Transformer(src_vocab, trg_vocab, d_model, n_layers, heads).to(device)
early = EarlyStopping(patience=3)

In [27]:
for param in model.parameters():
    if param.dim() > 1:
        nn.init.xavier_uniform_(param)
        
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=.5)

# scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=8, T_mult=1, eta_min=0.000001)

In [28]:
# helper function to club together sequential operations
def sequential_transforms(*transforms):
    def func(txt_input):
        for transform in transforms:
            txt_input = transform(txt_input)
        return txt_input
    return func

# function to add BOS/EOS and create tensor for input sequence indices
def tensor_transform(token_ids: List[int]):
    return torch.cat((torch.tensor([BOS_IDX]),
                      torch.tensor(token_ids),
                      torch.tensor([EOS_IDX])))

# src and tgt language text transforms to convert raw strings into tensors indices
text_transform = {}
for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
    text_transform[ln] = sequential_transforms(token_transform[ln],  # Tokenization
                                               vocab_transform[ln],  # Numericalization
                                               tensor_transform)  # Add BOS/EOS and create tensor

In [29]:
train_ds = TransDataset(data, is_train=True)
valid_ds = TransDataset(data, is_train=False)

train_dataloader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=train_ds.collate_fn)
valid_dataloader = DataLoader(valid_ds, batch_size=batch_size, collate_fn=valid_ds.collate_fn)

In [30]:
# TODO: torchmetrics.TranslationEditRate

def train_epoch(model, optimizer):
    model.train()
    losses = 0

    for src, tgt in tqdm(train_dataloader, total=len(train_dataloader)):
        src = src.transpose(0, 1).to(device)
        tgt = tgt.transpose(0, 1).to(device)

        tgt_input = tgt[:, :-1]
        tgt_out = tgt[:, 1:].contiguous().view(-1)

        src_mask, tgt_mask = create_masks(src, tgt_input)

        logits = model(src, tgt_input, src_mask, tgt_mask)

        optimizer.zero_grad()
        
        loss = loss_fn(logits.reshape(-1, logits.size(-1)), tgt_out)
        loss.backward()

        optimizer.step()
        losses += loss.item()
    scheduler.step()
    return losses / len(train_dataloader)

def evaluate(model):
    model.eval()
    losses = 0    

    for src, tgt in tqdm(valid_dataloader, total=len(valid_dataloader)):
        src = src.transpose(0, 1).to(device)
        tgt = tgt.transpose(0, 1).to(device)

        tgt_input = tgt[:, :-1]
        tgt_out = tgt[:, 1:].contiguous().view(-1)

        src_mask, tgt_mask = create_masks(src, tgt_input)

        logits = model(src, tgt_input, src_mask, tgt_mask)
        
        loss = loss_fn(logits.reshape(-1, logits.size(-1)), tgt_out)
        losses += loss.item()

    return losses / len(valid_dataloader)

In [31]:
NUM_EPOCHS = 50
        
for epoch in range(1, NUM_EPOCHS+1):
    train_loss = train_epoch(model, optimizer)
    val_loss = evaluate(model)
    print((f"Epoch: {epoch}, Train loss: {train_loss:.4f}, Val loss: {val_loss:.4f}"))
    early(val_loss, model)
    if early.early_stop:
        break

100%|██████████████████████████████████████████████████████████████████████████████| 1515/1515 [02:21<00:00, 10.74it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 505/505 [00:15<00:00, 32.42it/s]


Epoch: 1, Train loss: 3.1537, Val loss: 2.1276


100%|██████████████████████████████████████████████████████████████████████████████| 1515/1515 [02:21<00:00, 10.68it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 505/505 [00:15<00:00, 32.86it/s]


Epoch: 2, Train loss: 1.8344, Val loss: 1.4283
Model saved to: model.pth


100%|██████████████████████████████████████████████████████████████████████████████| 1515/1515 [02:22<00:00, 10.63it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 505/505 [00:15<00:00, 32.82it/s]


Epoch: 3, Train loss: 1.2920, Val loss: 1.0942
Model saved to: model.pth


100%|██████████████████████████████████████████████████████████████████████████████| 1515/1515 [02:21<00:00, 10.69it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 505/505 [00:16<00:00, 31.16it/s]


Epoch: 4, Train loss: 1.0039, Val loss: 0.9315
Model saved to: model.pth


100%|██████████████████████████████████████████████████████████████████████████████| 1515/1515 [02:25<00:00, 10.45it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 505/505 [00:16<00:00, 30.26it/s]


Epoch: 5, Train loss: 0.8344, Val loss: 0.8472
Model saved to: model.pth


100%|██████████████████████████████████████████████████████████████████████████████| 1515/1515 [02:24<00:00, 10.48it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 505/505 [00:15<00:00, 31.68it/s]


Epoch: 6, Train loss: 0.7210, Val loss: 0.7912
Model saved to: model.pth


100%|██████████████████████████████████████████████████████████████████████████████| 1515/1515 [02:23<00:00, 10.56it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 505/505 [00:16<00:00, 30.87it/s]


Epoch: 7, Train loss: 0.6374, Val loss: 0.7571
Model saved to: model.pth


100%|██████████████████████████████████████████████████████████████████████████████| 1515/1515 [02:24<00:00, 10.49it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 505/505 [00:16<00:00, 30.41it/s]


Epoch: 8, Train loss: 0.5728, Val loss: 0.7312
Model saved to: model.pth


100%|██████████████████████████████████████████████████████████████████████████████| 1515/1515 [02:23<00:00, 10.54it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 505/505 [00:15<00:00, 31.92it/s]


Epoch: 9, Train loss: 0.5204, Val loss: 0.7062
Model saved to: model.pth


100%|██████████████████████████████████████████████████████████████████████████████| 1515/1515 [02:22<00:00, 10.65it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 505/505 [00:16<00:00, 31.30it/s]


Epoch: 10, Train loss: 0.4775, Val loss: 0.7005
Model saved to: model.pth


100%|██████████████████████████████████████████████████████████████████████████████| 1515/1515 [02:24<00:00, 10.50it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 505/505 [00:16<00:00, 30.92it/s]


Epoch: 11, Train loss: 0.4118, Val loss: 0.6780
Model saved to: model.pth


100%|██████████████████████████████████████████████████████████████████████████████| 1515/1515 [02:24<00:00, 10.50it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 505/505 [00:15<00:00, 31.65it/s]


Epoch: 12, Train loss: 0.3898, Val loss: 0.6756
Model saved to: model.pth


100%|██████████████████████████████████████████████████████████████████████████████| 1515/1515 [02:24<00:00, 10.50it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 505/505 [00:16<00:00, 30.96it/s]


Epoch: 13, Train loss: 0.3743, Val loss: 0.6719
Model saved to: model.pth


100%|██████████████████████████████████████████████████████████████████████████████| 1515/1515 [02:22<00:00, 10.61it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 505/505 [00:16<00:00, 31.12it/s]


Epoch: 14, Train loss: 0.3615, Val loss: 0.6718
Model saved to: model.pth


100%|██████████████████████████████████████████████████████████████████████████████| 1515/1515 [02:23<00:00, 10.53it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 505/505 [00:16<00:00, 30.69it/s]


Epoch: 15, Train loss: 0.3500, Val loss: 0.6703
Model saved to: model.pth


100%|██████████████████████████████████████████████████████████████████████████████| 1515/1515 [02:23<00:00, 10.57it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 505/505 [00:16<00:00, 31.15it/s]


Epoch: 16, Train loss: 0.3378, Val loss: 0.6760
INFO: Early stopping counter 1 of 3


100%|██████████████████████████████████████████████████████████████████████████████| 1515/1515 [02:24<00:00, 10.49it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 505/505 [00:16<00:00, 31.24it/s]


Epoch: 17, Train loss: 0.3268, Val loss: 0.6738
INFO: Early stopping counter 2 of 3


100%|██████████████████████████████████████████████████████████████████████████████| 1515/1515 [02:24<00:00, 10.46it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 505/505 [00:16<00:00, 31.03it/s]

Epoch: 18, Train loss: 0.3170, Val loss: 0.6712
INFO: Early stopping counter 3 of 3
INFO: Early stopping





In [32]:
model = torch.load("model.pth")["model"].to(device)

In [33]:
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    src = src.to(device)
    src_mask = src_mask.to(device)

    memory = model.encoder(src, src_mask)
    ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(device)
    
    for i in range(max_len-1):
        memory = memory.to(device)
        tgt_mask = create_trg_mask(ys.size(1)).type(torch.bool)
        output = model.decoder(ys, memory, src_mask, tgt_mask)
        prob = model.out(output[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.item()
        if next_word == EOS_IDX:
            break
        ys = torch.cat([ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=1)
        
    return ys

def translate(model: torch.nn.Module, src_sentence: str):
    model.eval()
    src = text_transform[SRC_LANGUAGE](src_sentence).view(-1, 1)
    num_tokens = src.shape[0]
    src = src.transpose(0, 1)
    src_mask = (src != PAD_IDX).unsqueeze(-2)
    tgt_tokens = greedy_decode(model, src, src_mask, max_len=num_tokens + 5, start_symbol=BOS_IDX).flatten()
    return " ".join(vocab_transform[TGT_LANGUAGE].lookup_tokens(list(tgt_tokens.cpu().numpy()))[1:]).replace("", "").replace("", "")

In [34]:
print(translate(model, "Я люблю людей и спорт"))
print(translate(model, "Это долгая история которую стоит рассказать подробно"))
print(translate(model, "Необходимо сначала спросить разрешение"))

I like sports and sports .
That 's the long story in detail .
It 's necessary to ask first .
