In [1]:
import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(
        self, 
        embedding_size, 
        nhead, 
    ):
        """
            :param embedding_size: dimension of word embedding (d_model in paper)
            :param nhead: number of heads in multi-head attentions (in papers, nhead=8)
        """
        super(SelfAttention, self).__init__()
        
        self.embedding_size = embedding_size
        self.nhead = nhead
        self.d_model = embedding_size // nhead

        assert (self.d_model * self.nhead == self.embedding_size), "Embedding size must be div by n_heads"
        
        self.values = nn.Linear(
            in_features=self.d_model,
            out_features=self.d_model,
            bias=False
        )
        self.keys = nn.Linear(
            in_features=self.d_model,
            out_features=self.d_model,
            bias=False
        )
        self.queries = nn.Linear(
            in_features=self.d_model,
            out_features=self.d_model,
            bias=False
        )
        
        self.fc_out = nn.Linear(
            in_features=self.nhead * self.d_model,
            out_features=self.embedding_size
        )
        
    def forward(self, values, keys, queries, mask):
        N = queries.shape[0] # batch_size
        value_len, key_len, query_len = values.shape[1], keys.shape[1], queries.shape[1]

        # Split embedding into self.heads pieces
        values = values.reshape(N, value_len, self.nhead, self.d_model)
        keys = keys.reshape(N, key_len, self.nhead, self.d_model)
        queries = queries.reshape(N, query_len, self.nhead, self.d_model)
        
        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)
        
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        # queries shape: (N, query_len, n_heads, d_model)
        # keys shape: (N, key_len, n_heads, d_model)
        # energy shape: (N, n_heads, query_len, key_len)
        
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))
            # when the enery is approxminus inf, 
            # the value of energy go through softmax layer will be approx 0    
        
        attention = torch.softmax(energy / (self.embedding_size ** 0.5), dim=3)
        # attention shape: (N, nhead, query_len, key_len)
        
        # attention shape: (N, nhead, query_len, key_len)
        # values shape: (N, value_len, nhead, d_model)
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values])
        # output shape: (N, query_len, nhead, d_model)
        
        # concat all attention head
        out = out.reshape(N, query_len, self.nhead * self.d_model)
        # out shape: (N, query_len, embedding_size)
        
        out = self.fc_out(out)
        return out
        
class TransformerBlock(nn.Module):
    def __init__(
        self, 
        embedding_size, 
        nhead, 
        dropout, 
        forward_expansion
    ):
        super(TransformerBlock, self).__init__()
        
        self.attention = SelfAttention(
            embedding_size=embedding_size,
            nhead=nhead
        )
        self.norm1 = nn.LayerNorm(
            normalized_shape=embedding_size
        )
        self.norm2 = nn.LayerNorm(
            normalized_shape=embedding_size
        )
        
        self.feed_forward = nn.Sequential(
            nn.Linear(
                in_features=embedding_size,
                out_features=forward_expansion*embedding_size,
            ),
            nn.ReLU(),
            nn.Linear(
                in_features=forward_expansion*embedding_size,
                out_features=embedding_size
            )
        )
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, value, key, query, mask):
        attention = self.attention(value, key, query, mask)
        
        x = self.norm1(attention + query)
        x = self.dropout(x)
        
        forward = self.feed_forward(x)
        
        out = self.norm2(forward + x)
        out = self.dropout(out)
        
        return out
    
class Encoder(nn.Module):
    def __init__(
        self,
        src_vocab_size,
        embedding_size,
        num_layers,
        nhead,
        device,
        forward_expansion,
        dropout,
        max_length
    ):
        super(Encoder, self).__init__()
        self.embedding_size = embedding_size
        self.device = device
        self.word_embedding = nn.Embedding(
            num_embeddings=src_vocab_size, 
            embedding_dim=self.embedding_size
        )
        
        self.position_embedding = nn.Embedding(
            num_embeddings=max_length, 
            embedding_dim=self.embedding_size
        )
        
        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    self.embedding_size,
                    nhead=nhead,
                    dropout=dropout,
                    forward_expansion=forward_expansion
                )
                for _ in range(num_layers)
            ]
        )
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask):
        N, seq_length = x.shape
        positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)
        
        x = self.word_embedding(x) + self.position_embedding(positions)
        out = self.dropout(x)
        
        for layer in self.layers:
            out = layer(out, out, out, mask)
            
        return out
    
class DecoderBlock(nn.Module):
    def __init__(
        self, 
        embedding_size, 
        nhead, 
        forward_expansion, 
        dropout, 
        device
    ):
        """
        attn -> layer_norm -> transformer block
        """
        super(DecoderBlock, self).__init__()
        self.attention = SelfAttention(
            embedding_size=embedding_size,
            nhead=nhead,
        )
        self.norm = nn.LayerNorm(embedding_size)
        self.transformer_block = TransformerBlock(
            embedding_size=embedding_size,
            nhead=nhead,
            dropout=dropout,
            forward_expansion=forward_expansion
        )
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, enc_value, enc_key, src_mask, tgt_mask):
        """
        src_mask and tgt_mask are to hide padding values
        value and key come from the encoder output
        """
        # extract query from the input of decoder
        attention = self.attention(
            values=x,
            keys=x,
            queries=x,
            mask=tgt_mask
        )
        dec_query = self.norm(attention + x)
        dec_query = self.dropout(dec_query)
        
        # pass extracted query of decoder input, value and key from encoder
        # to the transformer block
        out = self.transformer_block(
            value=enc_value, 
            key=enc_key, 
            query=dec_query, 
            mask=src_mask
        )
        return out
    
class Decoder(nn.Module):
    def __init__(
        self,
        tgt_vocab_size,
        embedding_size,
        num_layers,
        nhead,
        forward_expansion,
        dropout,
        device,
        max_length,
    ):
        super(Decoder, self).__init__()
        self.device = device
        self.word_embedding = nn.Embedding(
            num_embeddings=tgt_vocab_size,
            embedding_dim=embedding_size
        )
        self.position_embedding = nn.Embedding(
            num_embeddings=max_length,
            embedding_dim=embedding_size
        )
        self.layers = nn.ModuleList(
            [
                DecoderBlock(
                    embedding_size=embedding_size,
                    nhead=nhead,
                    forward_expansion=forward_expansion,
                    dropout=dropout,
                    device=device
                )
                for _ in range (num_layers)
            ]
        )
        self.fc_out = nn.Linear(
            in_features=embedding_size,
            out_features=tgt_vocab_size
        )
        self.dropout = nn.Dropout(p=dropout)
        
    def forward(self, x, enc_out, src_mask, tgt_mask):
        N, seq_length = x.shape
        positions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)
        x = self.word_embedding(x) + self.position_embedding(positions)
        x = self.dropout(x)
        
        for layer in self.layers:
            x = layer(
                x=x,
                enc_value=enc_out,
                enc_key=enc_out,
                src_mask=src_mask,
                tgt_mask=tgt_mask
            )
            
        out = self.fc_out(x)
        return out
        
class Transformer(nn.Module):
    def __init__(
        self,
        src_vocab_size,
        tgt_vocab_size,
        src_pad_idx,
        tgt_pad_idx,
        embedding_size=512,
        num_layers=6,
        forward_expansion=4,
        nhead=8,
        dropout=0,
        device="cuda",
        max_length=128
    ):
        super(Transformer, self).__init__()
        
        self.encoder = Encoder(
            src_vocab_size=src_vocab_size,
            embedding_size=embedding_size,
            num_layers=num_layers,
            nhead=nhead,
            device=device,
            forward_expansion=forward_expansion,
            dropout=dropout,
            max_length=max_length
        )

        self.decoder = Decoder(
            tgt_vocab_size=tgt_vocab_size,
            embedding_size=embedding_size,
            num_layers=num_layers,
            nhead=nhead,
            forward_expansion=forward_expansion,
            dropout=dropout,
            device=device,
            max_length=max_length
        )

        self.src_pad_idx = src_pad_idx
        self.tgt_pad_idx = tgt_pad_idx
        self.device = device
        
    def make_src_mask(self, src):
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
        # src_mask.shape = (N, 1, 1, src_len)
        return src_mask.to(self.device)
    
    def make_tgt_mask(self, tgt):
        N, tgt_len = tgt.shape
        tgt_mask = torch.tril(torch.ones(tgt_len, tgt_len)).expand(
            N, 1, tgt_len, tgt_len
        )
        return tgt_mask.to(self.device)
    
    def forward(self, src, tgt):
        src_mask = self.make_src_mask(src)
        tgt_mask = self.make_tgt_mask(tgt)
        
        enc_src = self.encoder(src, src_mask)
        out = self.decoder(tgt, enc_src, src_mask, tgt_mask)
        
        return out


#### Dataset and vocab

In [2]:
import os
import gdown

if os.path.exists('/content/'):
    # os.system('!gdown 1ty8k-omlU3zvSUemx2gvBaEQWAWZAQ1C')
    # os.system("!gdown ")
    gdown.download("https://drive.google.com/file/d/1ty8k-omlU3zvSUemx2gvBaEQWAWZAQ1C", output='./train.en')
    gdown.download("https://drive.google.com/file/d/1mzDv83hvTlsLNSg7XNIIFLub36YVOf6u", output='./train.vi')

In [3]:
from __future__ import unicode_literals, print_function, division
from io import open
import unicodedata
import string
import re
import random
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [4]:
SOS_token = 0
EOS_token = 1
PAD_token = 2


class Lang:
    def __init__(self, name):
        self.name = name
        self.word2index = {}
        self.word2count = {}
        self.index2word = {0: "SOS", 1: "EOS", 2: "PAD"}
        self.n_words = 3  # Count SOS and EOS

    def addSentence(self, sentence):
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            self.word2count[word] += 1
            
# Turn a Unicode string to plain ASCII, thanks to
# https://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    # return ''.join(
    #     c for c in unicodedata.normalize('NFD', s)
    #     if unicodedata.category(c) != 'Mn'
    # )
    return s

# Lowercase, trim, and remove non-letter characters


def normalizeString(s):
    s = unicodeToAscii(s.lower().strip())
    # s = re.sub(r"([.!?])", r" \1", s)
    # s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    return s

In [5]:
def readLangs(filename, lang1, lang2, reverse=False):
    print("Reading lines...")

    # Read the file and split into lines
    from_lang = None
    with open(f'{filename}.{lang1}', 'r+', encoding='utf8') as f:
        from_lang = f.read().strip().split('\n')
        
    to_lang = None
    with open(f'{filename}.{lang2}', 'r+', encoding='utf8') as f:
        to_lang = f.read().strip().split('\n')
        
    # Split every line into pairs and normalize
    pairs = list(zip(from_lang, to_lang))
    pairs = [[normalizeString(s) for s in l] for l in pairs]

    # Reverse pairs, make Lang instances
    if reverse:
        pairs = [list(reversed(p)) for p in pairs]
        input_lang = Lang(lang2)
        output_lang = Lang(lang1)
    else:
        input_lang = Lang(lang1)
        output_lang = Lang(lang2)

    return input_lang, output_lang, pairs

In [6]:
MAX_LENGTH = 256

eng_prefixes = (
    "i am ", "i m ",
    "he is", "he s ",
    "she is", "she s ",
    "you are", "you re ",
    "we are", "we re ",
    "they are", "they re "
)

def filterPair(p):
    return len(p[0].split(' ')) < MAX_LENGTH and \
        len(p[1].split(' ')) < MAX_LENGTH

def filterPairs(pairs):
    return [pair for pair in pairs if filterPair(pair)]

def prepareData(filename, lang1, lang2, reverse=False):
    input_lang, output_lang, pairs = readLangs(filename, lang1, lang2, reverse)
    print("Read %s sentence pairs" % len(pairs))
    pairs = filterPairs(pairs)
    print("Trimmed to %s sentence pairs" % len(pairs))
    print("Counting words...")
    for pair in pairs:
        input_lang.addSentence(pair[0])
        output_lang.addSentence(pair[1])
    print("Counted words:")
    print(input_lang.name, input_lang.n_words)
    print(output_lang.name, output_lang.n_words)
    return input_lang, output_lang, pairs

input_lang, output_lang, pairs = prepareData('train', 'en', 'vi')

Reading lines...
Read 133317 sentence pairs
Trimmed to 133292 sentence pairs
Counting words...
Counted words:
en 41168
vi 18703


In [7]:
X = [pair[0] for pair in pairs]
Y = [pair[1] for pair in pairs]

X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2, random_state=42)
X_train[0], Y_train[0]

('it would not look or feel like anything that we see when we look at a flower so if you look at this flower here and you are a little bug if you are on that surface of that flower that is what the terrain would look like',
 'sẽ không thể nhìn hay cảm nhận bất cứ thứ gì giống như chúng ta thấy khi chúng ta nhìn một bông hoa vậy nếu bạn nhìn bông hoa ở đây và bạn là một con côn trùng bé xíu bạn ở trên bề mặt của bông hoa đó địa hình cũng giống như vậy')

In [8]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self, X, Y, input_lang, output_lang, max_length=MAX_LENGTH):
        self.X = X
        self.Y = Y
        self.input_lang = input_lang
        self.output_lang = output_lang
        
    def __len__(self):
        return len(self.X)
    
    def __getitem__(self, idx):
        src = [SOS_token] + [self.input_lang.word2index[word] for word in self.X[idx].split(' ')]
        src.append(EOS_token)
        src = torch.tensor(src, dtype=torch.long, device=device)
        if len(src) < MAX_LENGTH:
            src = torch.cat((src, torch.tensor([PAD_token] * (MAX_LENGTH - len(src)), dtype=torch.long, device=device)))
        else:
            src = src[:MAX_LENGTH]

        tgt = [SOS_token] + [self.output_lang.word2index[word] for word in self.Y[idx].split(' ')]
        # tgt.append(EOS_token)
        tgt = torch.tensor([self.output_lang.word2index[word] for word in self.Y[idx].split(' ')], dtype=torch.long, device=device)
        if len(tgt) < MAX_LENGTH:
            tgt = torch.cat((tgt, torch.tensor([PAD_token] * (MAX_LENGTH - len(tgt)), dtype=torch.long, device=device)))
        else:
            tgt = tgt[:MAX_LENGTH]

        return src, tgt

In [9]:
train_dataset = Dataset(X_train, Y_train, input_lang, output_lang)
test_dataset = Dataset(X_test, Y_test, input_lang, output_lang)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=4, shuffle=True)

x, y = train_dataset[0]
print(x, y)

tensor([    0,   103,    42,   270,    56,    65,   361,    43,   414,    50,
          156,    54,    58,   156,    56,   128,     8,  9077,   192,   452,
           46,    56,   128,    57,  9077,   165,    62,    46,    67,     8,
          651,  4697,   452,    46,    67,    23,    50,  1686,    17,    50,
         9077,    50,   104,   144,     5, 11851,    42,    56,    43,     1,
            2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
            2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
            2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
            2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
            2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
            2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
            2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
            2,     2,     2,     2,     2,     2,     2,     2, 

In [10]:
print([input_lang.index2word[i.item()] for i in x])

['SOS', 'it', 'would', 'not', 'look', 'or', 'feel', 'like', 'anything', 'that', 'we', 'see', 'when', 'we', 'look', 'at', 'a', 'flower', 'so', 'if', 'you', 'look', 'at', 'this', 'flower', 'here', 'and', 'you', 'are', 'a', 'little', 'bug', 'if', 'you', 'are', 'on', 'that', 'surface', 'of', 'that', 'flower', 'that', 'is', 'what', 'the', 'terrain', 'would', 'look', 'like', 'EOS', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD'

In [11]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_epochs = 1

learning_rate = 3e-4
batch_size = 32

In [13]:
embedding_size = 512
src_vocab_size = input_lang.n_words
tgt_vocab_size = output_lang.n_words
src_pad_idx = PAD_token
tgt_pad_idx = PAD_token
num_heads = 4
num_layers = 2
forward_expansion = 4
dropout = 0.1
max_len = MAX_LENGTH
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)


model = Transformer(
    src_vocab_size,
    tgt_vocab_size,
    src_pad_idx,
    tgt_pad_idx,
    embedding_size,
    num_layers,
    forward_expansion,
    num_heads,
    dropout,
    device,
    max_len,
).to(device)

optimizer = optim.Adam(model.parameters(), lr=learning_rate)

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, factor=0.1, patience=10, verbose=True
)

criterion = nn.CrossEntropyLoss(ignore_index=src_pad_idx)

sentence = "ein pferd geht unter einer brücke neben einem boot."

for epoch in range(num_epochs):
    print(f"[Epoch {epoch} / {num_epochs}]")

    # model.eval()
    # translated_sentence = translate_sentence(
    #     model, sentence, german, english, device, max_length=50
    # )
    # print(f"Translated example sentence: \n {translated_sentence}")

    model.train()
    losses = []

    for batch_idx, batch in enumerate(train_loader):
        # Get input and targets and get to cuda
        batch = tuple(t.to(device) for t in batch)
        src, tgt = batch
        print(f'src: {src.shape}, tgt: {tgt.shape}')
        # Forward prop
        input_tgt = tgt[:, :-1]
        output = model(src, tgt[:, :-1])
        print(f'output: {output.shape}')

        # Output is of shape (trg_len, batch_size, output_dim) but Cross Entropy Loss
        # doesn't take input in that form. For example if we have MNIST we want to have
        # output to be: (N, 10) and targets just (N). Here we can view it in a similar
        # way that we have output_words * batch_size that we want to send in into
        # our cost function, so we need to do some reshapin.
        # Let's also remove the start token while we're at it
        output = output.reshape(-1, output.shape[-1])
        target = tgt[:, 1:].reshape(-1)

        optimizer.zero_grad()

        loss = criterion(output, target)
        losses.append(loss.detach().cpu().item())

        # Back prop
        loss.backward()
        # Clip to avoid exploding gradient issues, makes sure grads are
        # within a healthy range
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)

        # Gradient descent step
        optimizer.step()

        # plot to tensorboard
        # writer.add_scalar("Training loss", loss, global_step=step)
        # step += 1
        print(f'Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(train_loader)} Loss: {loss.item():.4f}')

    mean_loss = sum(losses) / len(losses)
    scheduler.step(mean_loss)
    break

# running on entire test data takes a while
# score = bleu(test_data[1:100], model, german, english, device)
# print(f"Bleu score {score * 100:.2f}")

cuda
[Epoch 0 / 1]
src: torch.Size([4, 256]), tgt: torch.Size([4, 256])
output: torch.Size([4, 255, 18703])
Epoch [0/1] Batch 0/26659 Loss: 10.0714
src: torch.Size([4, 256]), tgt: torch.Size([4, 256])
output: torch.Size([4, 255, 18703])
Epoch [0/1] Batch 1/26659 Loss: 10.0248
src: torch.Size([4, 256]), tgt: torch.Size([4, 256])
output: torch.Size([4, 255, 18703])
Epoch [0/1] Batch 2/26659 Loss: 9.7974
src: torch.Size([4, 256]), tgt: torch.Size([4, 256])
output: torch.Size([4, 255, 18703])
Epoch [0/1] Batch 3/26659 Loss: 9.7332
src: torch.Size([4, 256]), tgt: torch.Size([4, 256])
output: torch.Size([4, 255, 18703])
Epoch [0/1] Batch 4/26659 Loss: 9.4779
src: torch.Size([4, 256]), tgt: torch.Size([4, 256])
output: torch.Size([4, 255, 18703])
Epoch [0/1] Batch 5/26659 Loss: 9.6524
src: torch.Size([4, 256]), tgt: torch.Size([4, 256])
output: torch.Size([4, 255, 18703])
Epoch [0/1] Batch 6/26659 Loss: 9.5352
src: torch.Size([4, 256]), tgt: torch.Size([4, 256])
output: torch.Size([4, 255, 18

KeyboardInterrupt: 

In [None]:
def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones((sz, sz), device=device)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask

# function to generate output sequence using greedy algorithm
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    src = src.to(device)
    src_mask = src_mask.to(device)

    memory = model.encode(src, src_mask)
    ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(device)
    for i in range(max_len-1):
        memory = memory.to(device)
        tgt_mask = (generate_square_subsequent_mask(ys.size(0))
                    .type(torch.bool)).to(device)
        out = model.decode(ys, memory, tgt_mask)
        out = out.transpose(0, 1)
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.item()

        ys = torch.cat([ys,
                        torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
        if next_word == EOS_token:
            break
    return ys


# actual function to translate input sentence into target language
def translate(model: torch.nn.Module, src_sentence: str):
    model.eval()
    src = [SOS_token] + [input_lang.word2index[word] for word in src_sentence.split(' ')] + [EOS_token]
    src = torch.tensor(src).unsqueeze(1).to(device)
    num_tokens = src.shape[0]
    src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
    tgt_tokens = greedy_decode(
        model,  src, src_mask, max_len=num_tokens + 5, start_symbol=SOS_token).flatten()
    return " ".join(output_lang.index2word[tok.item()] for tok in tgt_tokens)