In [1]:
from google.colab import drive
drive.mount('/content/mydrive')

Mounted at /content/mydrive


In [2]:
# install the dependencies

!pip install -U torchdata
!pip install -U spacy
!python -m spacy download en_core_web_sm
!python -m spacy download de_core_news_sm
!pip install 'portalocker==2.8.2'

Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=2->torchdata)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch>=2->torchdata)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch>=2->torchdata)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch>=2->torchdata)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch>=2->torchdata)
  Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)
Collecting nvidia-cufft-cu12==11.0.2.54 (from torch>=2->torchdata)
  Using cached nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)
Collecting nvidia-curand-cu12==10.3.2.106 (from torch>=2->torchdat

At this step, you will need to restart session to avoid future dependency errors.

In [33]:
%matplotlib inline

from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import multi30k, Multi30k
from typing import Iterable, List

from torch import Tensor
import torch
import torch.nn as nn
from torch.nn import Transformer
import math

import torch.nn.functional as F
# import copy as c
import copy
c=copy.deepcopy

In [34]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(DEVICE)

cuda


In [35]:
# download the data
multi30k.URL["train"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/training.tar.gz"
multi30k.URL["valid"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/validation.tar.gz"

SRC_LANGUAGE = 'de'
TGT_LANGUAGE = 'en'

# Place-holders
token_transform = {}
vocab_transform = {}

In [36]:
## THIS STEP IS TOKENIZATION

token_transform[SRC_LANGUAGE] = get_tokenizer('spacy', language='de_core_news_sm')
token_transform[TGT_LANGUAGE] = get_tokenizer('spacy', language='en_core_web_sm')


# helper function to yield list of tokens
def yield_tokens(data_iter: Iterable, language: str) -> List[str]:
    language_index = {SRC_LANGUAGE: 0, TGT_LANGUAGE: 1}

    for data_sample in data_iter:
        yield token_transform[language](data_sample[language_index[language]])

# Define special symbols and indices
UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3
# Make sure the tokens are in order of their indices to properly insert them in vocab
special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']

for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
    # Training data Iterator
    train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
    # Create torchtext's Vocab object
    vocab_transform[ln] = build_vocab_from_iterator(yield_tokens(train_iter, ln),
                                                    min_freq=1,
                                                    specials=special_symbols,
                                                    special_first=True)

# Set ``UNK_IDX`` as the default index. This index is returned when the token is not found.
# If not set, it throws ``RuntimeError`` when the queried token is not found in the Vocabulary.
for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
  vocab_transform[ln].set_default_index(UNK_IDX)



In [None]:
## FROM HERE, IT IS MY IMPLEMENTATION OF MY ENCODER AND DECODER, 
# refer from "https://jalammar.github.io/illustrated-transformer/"
# I integrate reference code with my translation task

## clone function, which clones the model layers
def clones(module, N):
    return nn.ModuleList([ copy.deepcopy(module) for _ in range(N) ])

In [38]:
## Encoder class()

class Encoder(nn.Module):
    def __init__(self, layer, N):
        super(Encoder, self).__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x,src_key_padding_mask, is_causal=None, mask=None):
        x= x.transpose(0, 1) # torch.permute(x, (1,0,2)) ## permute the dimension

        print('x:', x.shape)

        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)

In [39]:
class LayerNorm(nn.Module):
    def __init__(self, features_size, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.scale = nn.Parameter(torch.ones(features_size))
        self.shift = nn.Parameter(torch.zeros(features_size))
        self.eps = eps

    def forward(self, x):
        #x=torch.permute(x, (1,0,2))
        mean = x.mean(-1, keepdims=True)
        std = x.std(-1, keepdims=True)
        return (x - mean) * self.scale / (std + self.eps) + self.shift

In [8]:
# class SubLayerConnection(nn.Module):
#     def __init__(self, features_size, dropout):
#         super(SubLayerConnection, self).__init__()
#         self.norm = LayerNorm(features_size)
#         self.dropout = nn.Dropout(dropout)

#     def forward(self, x, sublayer):
#         #x=torch.permute(x, (1,0,2))
#         return x + self.dropout(sublayer(self.norm(x)))

In [40]:
# pre normalization don't run now
class SubLayerConnection(nn.Module):
    def __init__(self, features_size, dropout):
        super(SubLayerConnection, self).__init__()
        self.norm = LayerNorm(features_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        x=self.norm(x)
        return x + self.dropout(sublayer(x))

In [41]:
class EncoderLayer(nn.Module):
    def __init__(self, size, self_attn, feed_forward, dropout):
        super(EncoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayers = clones(SubLayerConnection(size, dropout), 2)

    def forward(self, x, mask):
        #x=torch.permute(x, (1,0,2))
        attn_function = lambda x: self.self_attn(x, x, x, mask)
        x = self.sublayers[0](x, attn_function)
        return self.sublayers[1](x, self.feed_forward)

In [42]:
## attention and Multihead Attention
def attention(query, key, value, mask=None, dropout=None):
    '''
    query, key and value contain vectors corresponding to each word in the input
    '''
    sqrt_d_k = math.sqrt(query.size(-1))
    scores = torch.matmul(query, key.transpose(-2,-1)) / sqrt_d_k

    if mask is not None:
        print('SCORES:', scores.shape, 'MASK:', mask.shape)
        scores += mask # scores.masked_fill(mask == 0, -1e9)
        #print ("scores",scores)

    prob_scores = F.softmax(scores, dim=-1)
    #print("prob_scores: ",prob_scores )

    if dropout is not None:
        p_attn = dropout(prob_scores)

    scaled_value = torch.matmul(prob_scores, value)
    #print("scaled_values: ", scaled_value)
    return scaled_value, prob_scores


class MultiHeadedAttention(nn.Module):
    def __init__(self, num_heads, dim_input=512, dropout=0.1):
        super(MultiHeadedAttention, self).__init__()
        assert dim_input % num_heads == 0
        self.num_heads = num_heads
        self.dropout = nn.Dropout(p=dropout)
        self.d_k = dim_input // num_heads

        # L1, L2, L3 and W0: four linear layers in all
        self.linears = clones(nn.Linear(dim_input, dim_input), 4)

        # this is used to store the prob_scores, just for visualization
        self.attn = None

        # helper function to resize the tensor as described above
        #self.resize_tensor = lambda tensor: tensor.view(batch_size, -1, self.num_heads, self.d_k).transpose(1,2)

    def forward(self, query, key, value, mask=None):
        #x=torch.permute(x, (1,0,2))
        if mask is not None:
            ## ?????
            mask = mask.unsqueeze(1) # same mask is applied to all heads ??
        batch_size = query.size(0) ## local

        # use the first three linear layers to transform query, key and value
        zipped = zip(self.linears, (query, key, value))
        ### ????
        query, key, value = [linear(x).view(batch_size, -1, self.num_heads, self.d_k).transpose(1,2) for (linear, x) in zipped]

        # apply self attention
        scaled_value, self.attn = attention(query, key, value, mask, self.dropout)
        ### ????
        scaled_value = scaled_value.transpose(1,2).contiguous().view(batch_size, -1, self.num_heads * self.d_k)
        return self.linears[-1](scaled_value)

In [43]:
class Decoder(nn.Module): ### here is my decoder definiton  here is I overwrite the original Decoder forward function
    def __init__(self, layer, N):
        super(Decoder, self).__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)
    ### here is I overwrite the original Decoder forward function, if I overwrite the
    ### the forward() funtion in original Decoder source code, I need class name
    ### the same as the original Decoder source code, so it is Decoder not override a function.
    ### override a function, inheriting from parent and that has already implemented and create a new definition for that funciton.
    ### this the same name with the existing function, or it will think I am creating a new function.
    ### inheeriting from Module, forward() like because how the pytorch transformer calls the Decoder()
    ### because pytorch transformer forward() call Decoder like this
    ### self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,
                              # tgt_key_padding_mask=tgt_key_padding_mask,
                              # memory_key_padding_mask=memory_key_padding_mask,
                              # tgt_is_causal=tgt_is_causal, memory_is_causal=memory_is_causal)
    def forward(self, x, memory, memory_mask, tgt_mask=None,
                tgt_key_padding_mask=None, memory_key_padding_mask=None,
                tgt_is_causal=None,memory_is_causal=None):
        x=torch.permute(x, (1,0,2))
        for layer in self.layers:
            x = layer(x, memory, memory_mask, tgt_mask)
        return torch.permute(self.norm(x), (1,0,2))



        #  output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,
        #                       tgt_key_padding_mask=tgt_key_padding_mask,
        #                       memory_key_padding_mask=memory_key_padding_mask,
        #                       tgt_is_causal=tgt_is_causal, memory_is_causal=memory_is_causal)

In [44]:
class DecoderLayer(nn.Module):
    def __init__(self, size, self_attn, enc_dec_attn, feed_forward, dropout):
        super(DecoderLayer, self).__init__()
        # enc_dec_attn is also called src_attn in the harvardnlp implementation
        self.self_attn = self_attn
        self.enc_dec_attn = enc_dec_attn
        self.feed_forward = feed_forward
        self.sublayers = clones(SubLayerConnection(size, dropout), 3)
        # we need to store size because it is used by LayerNorm in Decoder
        self.size = size

    def forward(self, x, encoder_outputs, src_mask, tgt_mask):
        # encoder_outputs are also called `memory` in the paper
        #x=torch.permute(x, (1,0,2))
        x = self.sublayers[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
        x = self.sublayers[1](x, lambda x: self.enc_dec_attn(x, encoder_outputs, encoder_outputs, src_mask))
        return self.sublayers[2](x, self.feed_forward)

In [45]:
class PositionWiseFeedForward(nn.Module):
    def __init__(self, input_size=512, output_size=2048, dropout=0.1):
        super(PositionWiseFeedForward, self).__init__()
        self.linear1 = nn.Linear(input_size, output_size)
        self.linear2 = nn.Linear(output_size, input_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        #x=torch.permute(x, (1,0,2))
        return self.linear2(self.dropout(F.relu(self.linear1(x))))

In [46]:
# helper Module that adds positional encoding to the token embedding to introduce a notion of word order.
class PositionalEncoding(nn.Module):
    def __init__(self,
                 emb_size: int,
                 dropout: float,
                 maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        den = torch.exp(- torch.arange(0, emb_size, 2)* math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: Tensor):
        return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])

# helper Module to convert tensor of input indices into corresponding tensor of token embeddings
class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, emb_size):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: Tensor):
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)

# Seq2Seq Network
class Seq2SeqTransformer(nn.Module):
    def __init__(self,
                 num_encoder_layers: int,#
                 num_decoder_layers: int,#
                 emb_size: int,#
                 nhead: int, #
                 #dim_model,
                 attn_layer,
                 enc_dec_attn,
                 feed_fwd_layer,
                 src_vocab_size: int,#
                 tgt_vocab_size: int,#
                 dim_feedforward: int = 512,#
                 dropout: float = 0.1):#
        super(Seq2SeqTransformer, self).__init__()

        my_encoder_layer = EncoderLayer(size=emb_size, self_attn=c(attn_layer), feed_forward=c(feed_fwd_layer), dropout=dropout)
        myencoder = Encoder(my_encoder_layer, N=num_encoder_layers)


        my_decoder_layer = DecoderLayer(size=emb_size, self_attn=c(attn_layer), enc_dec_attn=c(attn_layer),
                                        feed_forward=c(feed_fwd_layer),
                                        dropout=dropout) ####6 mydecoder is here
        mydecoder = Decoder(my_decoder_layer, N=num_decoder_layers)####6 mydecoder is here

        self.transformer = Transformer(d_model=emb_size, ###4 this is corresponding to transformer source code
                                       nhead=nhead,
                                       num_encoder_layers=num_encoder_layers,
                                       num_decoder_layers=num_decoder_layers,
                                       dim_feedforward=dim_feedforward,
                                       dropout=dropout,
                                       custom_encoder=myencoder, ###5 my decoder,
                                       custom_decoder=mydecoder)
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)


        self.positional_encoding = PositionalEncoding(
            emb_size, dropout=dropout)


    def forward(self, ####2
                src: Tensor,
                trg: Tensor,
                src_mask: Tensor,
                tgt_mask: Tensor,
                src_padding_mask: Tensor,
                tgt_padding_mask: Tensor,
                memory_key_padding_mask: Tensor):
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
        outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None, ####3 transformer source code
                                src_padding_mask, tgt_padding_mask, memory_key_padding_mask)

        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor):
        return self.transformer.encoder(self.positional_encoding( #?? which encoder?
                            self.src_tok_emb(src)), src_mask) ## I think the error comes from here

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        return self.transformer.decoder(self.positional_encoding(
                          self.tgt_tok_emb(tgt)), memory,
                          tgt_mask=tgt_mask, memory_mask=None)

In [47]:
def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones((sz, sz), device=DEVICE)) == 1).transpose(0, 1)
    ### ???
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask.unsqueeze(0) ##


def create_mask(src, tgt):
    src_seq_len = src.shape[0]
    tgt_seq_len = tgt.shape[0]

    tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
    src_mask = torch.zeros((1, src_seq_len, src_seq_len),device=DEVICE).type(torch.bool)

    src_padding_mask = (src == PAD_IDX).transpose(0, 1)
    tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask

In [48]:
## initialize the Seq2SeqTransformer class

torch.manual_seed(0)

SRC_VOCAB_SIZE = len(vocab_transform[SRC_LANGUAGE])
TGT_VOCAB_SIZE = len(vocab_transform[TGT_LANGUAGE])
EMB_SIZE = 512
NHEAD = 8
FFN_HID_DIM = 512
BATCH_SIZE = 128
NUM_ENCODER_LAYERS = 3
NUM_DECODER_LAYERS = 3

# class Seq2SeqTransformer(nn.Module):
#     def __init__(self,
#                  num_encoder_layers: int,#
#                  num_decoder_layers: int,#
#                  emb_size: int,#
#                  nhead: int, #
#                  dim_model,
#                  attn_layer,
#                  enc_dec_attn,
#                  feed_fwd_layer,
#                  src_vocab_size: int,#
#                  tgt_vocab_size: int,#
#                  dim_feedforward: int = 512,#
#                  dropout: float = 0.1):#
#         super(Seq2SeqTransformer, self).__init__()




ATTN_LAYER = MultiHeadedAttention(num_heads=8, dim_input=512)
FEED_FWD_LAYER = PositionWiseFeedForward(input_size=512, output_size=2048, dropout=0.1)


transformer = Seq2SeqTransformer(num_encoder_layers=NUM_ENCODER_LAYERS, ####1
                                 num_decoder_layers=NUM_DECODER_LAYERS,
                                 emb_size=EMB_SIZE,
                                 nhead=NHEAD,
                                 attn_layer=ATTN_LAYER,
                                 enc_dec_attn=ATTN_LAYER,
                                 feed_fwd_layer= FEED_FWD_LAYER,
                                src_vocab_size=SRC_VOCAB_SIZE,
                                tgt_vocab_size=TGT_VOCAB_SIZE,
                                 dim_feedforward=FFN_HID_DIM)

for p in transformer.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

transformer = transformer.to(DEVICE)

loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)

optimizer = torch.optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

In [49]:
from torch.nn.utils.rnn import pad_sequence

# helper function to club together sequential operations
def sequential_transforms(*transforms): ## ?? * is arbitray number of parameters pass any number of transform s(transform operation preprocessing operation ) I like
    def func(txt_input):
        for transform in transforms:
            txt_input = transform(txt_input)
        return txt_input
    return func # return a function that perform the transform. I can call it like a function

# function to add BOS/EOS and create tensor for input sequence indices
def tensor_transform(token_ids: List[int]): ## ??
    return torch.cat((torch.tensor([BOS_IDX]), # concatenation
                      torch.tensor(token_ids),
                      torch.tensor([EOS_IDX])))

# ``src`` and ``tgt`` language text transforms to convert raw strings into tensors indices
text_transform = {}
for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
    text_transform[ln] = sequential_transforms(token_transform[ln], #Tokenization
                                               vocab_transform[ln], #Numericalization ## ??
                                               tensor_transform) # Add BOS/EOS and create tensor
                                                # here is a series of text preprocessing


# function to collate data samples into batch tensors
def collate_fn(batch):
    src_batch, tgt_batch = [], []
    for src_sample, tgt_sample in batch:
        src_batch.append(text_transform[SRC_LANGUAGE](src_sample.rstrip("\n")))
        tgt_batch.append(text_transform[TGT_LANGUAGE](tgt_sample.rstrip("\n")))

    src_batch = pad_sequence(src_batch, padding_value=PAD_IDX)
    tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_IDX)
    return src_batch, tgt_batch

In [50]:
from torch.utils.data import DataLoader

def train_epoch(model, optimizer):
    model.train()
    losses = 0
    train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
    train_dataloader = DataLoader(train_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)

    for src, tgt in train_dataloader:
        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)

        tgt_input = tgt[:-1, :]

        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)
         ## is model here refers to transformer? the input is different than transformer how we know the input
         ## this is transformer forward method, model.forwar()?
        logits = model(src, tgt_input, src_mask, tgt_mask,src_padding_mask, tgt_padding_mask, src_padding_mask)

        optimizer.zero_grad()

        tgt_out = tgt[1:, :]
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        loss.backward()

        optimizer.step()
        losses += loss.item()

    return losses / len(list(train_dataloader))


def evaluate(model):
    model.eval()
    losses = 0

    val_iter = Multi30k(split='valid', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
    val_dataloader = DataLoader(val_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)

    for src, tgt in val_dataloader:
        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)

        tgt_input = tgt[:-1, :]

        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

        logits = model(src, tgt_input, src_mask, tgt_mask,src_padding_mask, tgt_padding_mask, src_padding_mask)

        tgt_out = tgt[1:, :]
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        losses += loss.item()

    return losses / len(list(val_dataloader))

In [51]:
from timeit import default_timer as timer
NUM_EPOCHS = 3

for epoch in range(1, NUM_EPOCHS+1):
    start_time = timer()
    train_loss = train_epoch(transformer, optimizer) # transforer here has different inputs
    end_time = timer()
    val_loss = evaluate(transformer)
    print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, "f"Epoch time = {(end_time - start_time):.3f}s"))




x: torch.Size([128, 27, 512])
SCORES: torch.Size([128, 8, 27, 27]) MASK: torch.Size([1, 1, 27, 27])
SCORES: torch.Size([128, 8, 27, 27]) MASK: torch.Size([1, 1, 27, 27])
SCORES: torch.Size([128, 8, 27, 27]) MASK: torch.Size([1, 1, 27, 27])
SCORES: torch.Size([128, 8, 23, 23]) MASK: torch.Size([1, 1, 23, 23])
SCORES: torch.Size([128, 8, 23, 23]) MASK: torch.Size([1, 1, 23, 23])
SCORES: torch.Size([128, 8, 23, 23]) MASK: torch.Size([1, 1, 23, 23])
x: torch.Size([128, 46, 512])
SCORES: torch.Size([128, 8, 46, 46]) MASK: torch.Size([1, 1, 46, 46])
SCORES: torch.Size([128, 8, 46, 46]) MASK: torch.Size([1, 1, 46, 46])
SCORES: torch.Size([128, 8, 46, 46]) MASK: torch.Size([1, 1, 46, 46])
SCORES: torch.Size([128, 8, 36, 36]) MASK: torch.Size([1, 1, 36, 36])
SCORES: torch.Size([128, 8, 36, 36]) MASK: torch.Size([1, 1, 36, 36])
SCORES: torch.Size([128, 8, 36, 36]) MASK: torch.Size([1, 1, 36, 36])
x: torch.Size([128, 33, 512])
SCORES: torch.Size([128, 8, 33, 33]) MASK: torch.Size([1, 1, 33, 33])


In [52]:
# function to generate output sequence using greedy algorithm
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    src = src.to(DEVICE)
    src_mask = None # src_mask.to(DEVICE)

    memory = model.encode(src, src_mask)
    ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(DEVICE)
    for i in range(max_len-1):
        memory = memory.to(DEVICE)
        tgt_mask = (generate_square_subsequent_mask(ys.size(0))
                    .type(torch.bool)).to(DEVICE)
        out = model.decode(ys, memory, tgt_mask)
        out = out.transpose(0, 1)
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.item()

        ys = torch.cat([ys,
                        torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
        if next_word == EOS_IDX:
            break
    return ys


# actual function to translate input sentence into target language
def translate(model: torch.nn.Module, src_sentence: str):
    model.eval()
    src = text_transform[SRC_LANGUAGE](src_sentence).view(-1, 1)
    num_tokens = src.shape[0]
    src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
    tgt_tokens = greedy_decode(
        model,  src, src_mask, max_len=num_tokens + 5, start_symbol=BOS_IDX).flatten()
    return " ".join(vocab_transform[TGT_LANGUAGE].lookup_tokens(list(tgt_tokens.cpu().numpy()))).replace("<bos>", "").replace("<eos>", "")

In [53]:
print(translate(transformer, "Eine Gruppe von Menschen steht vor einem Iglu ."))

x: torch.Size([1, 11, 512])
SCORES: torch.Size([1, 8, 1, 1]) MASK: torch.Size([1, 1, 1, 1])
SCORES: torch.Size([1, 8, 1, 1]) MASK: torch.Size([1, 1, 1, 1])
SCORES: torch.Size([1, 8, 1, 1]) MASK: torch.Size([1, 1, 1, 1])
SCORES: torch.Size([1, 8, 2, 2]) MASK: torch.Size([1, 1, 2, 2])
SCORES: torch.Size([1, 8, 2, 2]) MASK: torch.Size([1, 1, 2, 2])
SCORES: torch.Size([1, 8, 2, 2]) MASK: torch.Size([1, 1, 2, 2])
SCORES: torch.Size([1, 8, 3, 3]) MASK: torch.Size([1, 1, 3, 3])
SCORES: torch.Size([1, 8, 3, 3]) MASK: torch.Size([1, 1, 3, 3])
SCORES: torch.Size([1, 8, 3, 3]) MASK: torch.Size([1, 1, 3, 3])
SCORES: torch.Size([1, 8, 4, 4]) MASK: torch.Size([1, 1, 4, 4])
SCORES: torch.Size([1, 8, 4, 4]) MASK: torch.Size([1, 1, 4, 4])
SCORES: torch.Size([1, 8, 4, 4]) MASK: torch.Size([1, 1, 4, 4])
SCORES: torch.Size([1, 8, 5, 5]) MASK: torch.Size([1, 1, 5, 5])
SCORES: torch.Size([1, 8, 5, 5]) MASK: torch.Size([1, 1, 5, 5])
SCORES: torch.Size([1, 8, 5, 5]) MASK: torch.Size([1, 1, 5, 5])
SCORES: torc