# Bytenet
Implementation of the paper Neural Machine Translation in Linear Time (https://arxiv.org/pdf/1610.10099). The Bynet is a CNN based Encoder/Decoder Model used here for Sequence to Sequence Translation. The model is trained on a part of the WMT2014 english to german dataset.

### Data Preprocessing
Data is loaded and tokenized in this step.

In [1]:
# Define Imports
import json

import requests
import torch
import pickle
import torch.utils.data as data
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer
from datasets import load_dataset
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from transformers import AutoTokenizer


In [2]:
import requests


class WMT19JSONLoader:
    def __init__(self, file_path, source_lang='de', target_lang='en', max_length=128):
        self.source_lang = source_lang
        self.target_lang = target_lang
        self.max_length = max_length
        # self.tokenizer = BertTokenizer.from_pretrained("bert-base-multilingual-cased")
        self.file_path = file_path
        self.tokenizer = AutoTokenizer.from_pretrained("google/byt5-small")

    def load_json_data(self, file_path):
        """
        Function that loads the downloaded JSON file

        :param file_path:
        :return:
        """
        loaded_data = []
        with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
            for line in f:
                try:
                    loaded_data.append(json.loads(line.strip()))
                except json.JSONDecodeError as e:
                    print(f"Error when line is decoded: {e}")
        return loaded_data

    def convert_to_tensor(self, src, trg):
        """
        Checks if source and target are tensor
        If both are not tensor, they are converted to tensors

        :param src:
        :param trg:
        :return:
        """
        if not torch.is_tensor(src):
            src = torch.Tensor(src)
        if not torch.is_tensor(trg):
            trg = torch.tensor(trg, dtype=torch.int32)
        return src, trg

    def extract_source_target(self, load_data):
        """
        Function that extracts out of the downloaded JSON the
        german rows as source and the english rows as targets

        :param load_data:
        :param source_lang:
        :param target_lang:
        :return:
        """
        source_texts = []
        target_texts = []
        for item in load_data:
            if ('row' in item and 'translation' in item['row'] and
                    self.source_lang in item['row']['translation'] and
                    self.target_lang in item['row']['translation']):
                source_texts.append(item['row']['translation'][self.source_lang])
                target_texts.append(item['row']['translation'][self.target_lang])
        return source_texts, target_texts

    def tokenize_texts(self, texts):
        """
        Function for tokenizing the text data
        Uses BERT-Tokenizer as tokenizer model

        :param texts:
        :return:
        """
        tokenized_texts = []
        for text in texts:
            tokens = self.tokenizer(text, max_length=self.max_length, padding="max_length", truncation=True, return_tensors="pt")
            
            tokenized_texts.append(tokens['input_ids'].squeeze())
        return tokenized_texts

    def load_and_tokenize(self, json_file_path):
        """
        Function that does the load json data
        and the tokenizing process

        :param json_file_path:
        """
        loaded_data = self.load_json_data(json_file_path)

        source_texts, target_texts = self.extract_source_target(loaded_data)

        # The tokenized source and targets
        # self.tokenizer is a object of type transformers from the Bert model
        # padding="max_length": is used to fill sequence to maximal length
        # truncation = True: Means that the sequence is cutted, if longer than max_length
        # return_tensors="pt": Means that a pytorch tensor is returned
        # the source text is tokenized into smaller elements
        tokenized_source_texts = self.tokenize_texts(source_texts)

        # the target text is tokenized into smaller elements
        tokenized_target_texts = self.tokenize_texts(target_texts)

        #TODO: evetually squeeze as in WMTLoader

        return tokenized_source_texts, tokenized_target_texts


def download_data(offset, length):
    """
    Method for downloading the dataset as JSON
    F.e. if the first 10 rows have to be downloaded, offset has to
    be 0 and length has to be 10

    :param offset: The offset used in the url
    :param length: The length of the selected number of rows in the dataset
    :return:
    """
    url = f"https://datasets-server.huggingface.co/rows?dataset=wmt%2Fwmt19&config=de-en&split=train&offset={offset}&length={length}"
    query_parameters = {"downloadformat": "json"}
    response = requests.get(url, params=query_parameters)
    if response.status_code == 200:
        loaded_data = response.json()
        print(f"Downloading dataset-offset: {offset}")
        return loaded_data['rows']
    else:
        print(f"Error while downloading data: {response.status_code}")
        return []


def save_data_to_json(load_data, file_path):
    """
    Writes data into the JSON object

    :param load_data: The data that has to be writen into file
    :param file_path: The file path where the file has to be saved
    """
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    with open(file_path, 'a', encoding='utf-8') as f:
        for item in load_data:
            json.dump(item, f, ensure_ascii=False)
            f.write('\n')


def download_batch_and_save(offset, length, output_file):
    """
    Downloads and saves the batch

    :param offset: The offset which is currently used to download
    :param length: The length is defined with 100
    :param output_file: The name of the file to be saved
    """
    loaded_data = download_data(offset, length)
    save_data_to_json(loaded_data, output_file)


def download_entire_de_en_dataset(batch_size, output_dir, num_workers):
    """
    Downloads the entire WMT19 dataset. Uses a ThreadPoolExecutor for
    faster download of the dataset.

    :param batch_size:
    :param output_dir:
    :param num_workers:
    """
    offset = 0
    output_file = os.path.join(output_dir, 'wmt_19_de_en.json')
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        offset = 0
        futures = []
        while True:
            futures.append(executor.submit(download_batch_and_save, offset, batch_size, output_file))
            offset += batch_size
            # if offset >= 34800000:
            # This controls how much of the dataset is actually downloaded
            print(offset)
            if offset >= 50000:
                break

        for future in as_completed(futures):
            future.result()


### ByteNet Model
The following Cells implement the necessary parts of the ByteNet model. The model is made up of a number of sets of, each of which contains a number of residual blocks that apply LayerNorm and 1D Convolutions, which is further masked for the Decoder part of the Network.

In [3]:
# Imports for ByteNet
import torch
from torch.utils.data import Dataset
from torch import nn
from torch.nn import init
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
from data.data_loader import WMTLoader, WMT19JSONLoader, download_entire_de_en_dataset


In [4]:
class ResidualBlockReLu(nn.Module):
    """
    Implementation of residual Layer for Bytenet machine translation task.

    :param d: The number of input features.
    :param dilation: The initial dilation rate for the convolution layers.
    """

    def __init__(self, d, dilation, k, decoder=False):
        super(ResidualBlockReLu, self).__init__()
        self.seq_len = 256 if decoder else 128
        self.decoder = decoder
        self.layer_norm1 = DynamicLayerNorm()
        self.reLu1 = nn.ReLU()
        # 2*d -> d
        self.conv1 = nn.Conv1d(d * 2, d, 1)
        self.layer_norm2 = DynamicLayerNorm()
        self.reLu2 = nn.ReLU()
        # Masked kernel size is k
        # Dilation only used for masked convolution
        # d -> d
        if decoder:
            # Masked convolution basically means all padding on left side
            self.receptive_field = (k - 1) * dilation
            self.conv2 = nn.Conv1d(d, d, k, dilation=dilation)
        else:
            # Padding still needed to keep the size of the input and output the same
            padding = (k - 1) * dilation // 2
            if padding > 0:
                self.conv2 = nn.Conv1d(d, d, k, dilation=dilation, padding=padding)
            else:
                self.conv2 = nn.Conv1d(d, d, k, dilation=dilation)
        self.layer_norm3 = DynamicLayerNorm()
        self.reLu3 = nn.ReLU()
        # d -> 2*d
        self.conv3 = nn.Conv1d(d, d * 2, 1)

    def forward(self, x):
        residual = x
        x = self.layer_norm1(x)
        x = self.reLu1(x)
        x = self.conv1(x)
        x = self.layer_norm2(x)
        x = self.reLu2(x)
        # When Decoder is used, the convolution is causal
        if self.decoder and self.receptive_field > 0:
            x = F.pad(x, (self.receptive_field, 0))
        x = self.conv2(x)
        x = self.layer_norm3(x)
        x = self.reLu3(x)
        x = self.conv3(x)
        # Add back the residual
        x += residual
        return x


In [5]:
class BytenetEncoder(nn.Module):
    """
    Implementation of the ByteNet Encoder. Default Parameters are set to the ones used in the paper.
    
    :param kernel_size: The kernel size for the unmasked (padded) convolution in the residual block.
    :param max_dilation_rate: The maximum dilation rate for the convolution layers.
    :param masked_kernel_size: The kernel size for the masked convolution in the residual block (only interesting for decoder).
    :param num_sets: The number of sets of residual blocks.
    :param set_size: The number of residual blocks in each set.
    :param hidden_channels: The number of hidden channels in the model.
    """
    def __init__(self, kernel_size=3, max_dilation_rate=16, masked_kernel_size=3, num_sets=6, set_size=5,
                 hidden_channels=800, emb_size = 1600):
        super(BytenetEncoder, self).__init__()
        self.num_channels = hidden_channels
        self.kernel_size = kernel_size
        self.layers = nn.Sequential()
        # 128 is size of tokenizer
        # input of shape [batch_size, 128, 128] as [batch_size, tokens, embedding_size]
        # self.layers.append(nn.Conv1d(in_channels=emb_size, out_channels=hidden_channels * 2, kernel_size=1))
        # From the Paper:
        # Model has a series of residual blocks of increased dilation rate
        # With unmasked convolutions for the encoder
        for _ in range(num_sets):
            dilation_rate = 1
            for _ in range(set_size):
                # Dilation rate does not exceed a given maximum
                # Example from the paper: 16
                self.layers.append(ResidualBlockReLu(hidden_channels,
                                                     dilation_rate if dilation_rate <= max_dilation_rate else max_dilation_rate,
                                                     masked_kernel_size))
                                # Dilation Rate doubles each layer (starting out at 1)
                dilation_rate = dilation_rate * 2

            # "the network applies one more convolution"
        # Note: The output of the residual layers is 2*input_features, however the output of the final convolutions is not specified in the paper
        # Experimentation needed if it should be 2*input_features or input_features
        # self.encoder_out_conv = nn.Conv1d(in_channels=hidden_channels * 2, out_channels=2 * hidden_channels, kernel_size=1)
        # "and ReLU"
        # Not sure if these last 2 layers should be in encoder or just decoder
        # self.layers.append(nn.ReLU())
        # "followed by a convolution"
        # self.layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size))
        # "and a final softmax layer" (probably not for encoder, however paper does not specify)
        # self.layers.append(nn.Softmax(dim=1))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


In [6]:
class BytenetDecoder(nn.Module):
    """
    Implementation of the ByteNet Decoder. Default Parameters are set to the ones used in the paper.
    
    :param kernel_size: The kernel size for the unmasked (padded) convolution in the residual block (not important for decoder).
    :param max_dilation_rate: The maximum dilation rate for the convolution layers.
    :param masked_kernel_size: The kernel size for the masked convolution in the residual block.
    :param num_sets: The number of sets of residual blocks.
    :param set_size: The number of residual blocks in each set.
    :param hidden_channels: The number of hidden channels in the model.
    """
    def __init__(self, kernel_size=3, max_dilation_rate=16, masked_kernel_size=3, num_sets=6, set_size=5,
                 hidden_channels=800, output_channels=384):
        super(BytenetDecoder, self).__init__()
        self.embed = nn.Embedding(vocab_size, hidden_channels*2)
        self.num_channels = hidden_channels
        self.kernel_size = kernel_size
        self.layers = nn.Sequential()
        # From the Paper:
        # Model has a series of residual blocks of increased dilation rate
        # With masekd convolution for decoder
        for _ in range(num_sets):
            dilation_rate = 1
            for _ in range(set_size):
                # Dilation Rate doubles each layer (starting out at 1)
                # 1, 2, 4, 8, 16
                # Dilation rate does not exceed a given maximum
                # Example from the paper: 16
                self.layers.append(ResidualBlockReLu(hidden_channels,
                                                     dilation_rate if dilation_rate <= max_dilation_rate else max_dilation_rate,
                                                     masked_kernel_size, decoder=True))
                dilation_rate = dilation_rate * 2

        # "the network applies one more convolution"
        # Note: The output of the residual layers is 2*input_features, however the output of the final convolutions is not specified in the paper
        # Experimentation needed if it should be 2*input_features or input_features
        self.layers.append(nn.Conv1d(hidden_channels * 2, hidden_channels, 1))
        # "and ReLU"
        self.layers.append(nn.ReLU())
        # "followed by a convolution"
        self.layers.append(nn.Conv1d(hidden_channels, output_channels, 1))
        # "and a final softmax layer"
        # self.layers.append(nn.LogSoftmax(dim=-1))

        # self.layers.append(nn.Softmax(dim=1))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


In [7]:
class EncoderDecoderStacking(nn.Module):
    """
    Stacks the encoder and decoder for the ByteNet model.
    This means passing the output of the encoder as input to the decoder.
    
    :param kernel_size: The kernel size for the unmasked (padded) convolution in the residual block (for Encoder).
    :param max_dilation_rate: The maximum dilation rate for the convolution layers.
    :param masked_kernel_size: The kernel size for the masked convolution in the residual block (for Decoder).
    :param num_sets: The number of sets of residual blocks.
    :param set_size: The number of residual blocks in each set.
    :param hidden_channels: The number of hidden channels in the model.
    :param output_channels: The number of output channels in the model (vocab size).

    :return x: The output of the decoder.
    """

    def __init__(self, kernel_size=3, max_dilation_rate=16, masked_kernel_size=3, n_sets=6, blocks_per_set=5,
                 hidden_channels=800, output_channels = 384, emb_size= 1600):
        super(EncoderDecoderStacking, self).__init__()
        self.embed = nn.Embedding(vocab_size, emb_size)
        self.encoder = BytenetEncoder(kernel_size=kernel_size, max_dilation_rate=max_dilation_rate,
                                      masked_kernel_size=masked_kernel_size, num_sets=n_sets, set_size=blocks_per_set,
                                      hidden_channels=hidden_channels, emb_size=emb_size)
        self.decoder = BytenetDecoder(kernel_size=kernel_size, max_dilation_rate=max_dilation_rate,
                                      masked_kernel_size=masked_kernel_size, num_sets=n_sets, set_size=blocks_per_set,
                                      hidden_channels=hidden_channels, output_channels=output_channels)

    def forward(self, x):
        # This permutation is needed for embeddings in pytorch with 1d convolutions
        embed_x = self.embed(x).permute(0, 2, 1)
        x = self.encoder(embed_x)
        x = self.decoder(x)
        return x


In [8]:
class InputEmbeddingTensor:
    """
    Class which enables the embedding of tokens.

    :param vocab_size: The size of the vocabulary as int.
    :param embed_size: The size of the embedding units as int.
    """

    def __init__(self, vocab_size, embed_size):
        super(InputEmbeddingTensor, self).__init__()
        self.vocab_size = vocab_size
        self.embed_size = embed_size
        # This is the actual lookup table.
        # A lookup table is an array of data that maps input values to output values
        self.lookup_table_non_zero = nn.Embedding(vocab_size - 1, embed_size)
        init.xavier_uniform_(self.lookup_table_non_zero.weight)

    def embed(self, in_values):
        """
        In this method the first n tokens are embedded via look-up table.
        The n tokens serve as targets for the predictions.

        :param in_values: The train input values from batch, more exact: the tokens
        :return: A embedded tensor of size n × 2d where d is the number of inner
                channels in the network
        """
        lookup_table_zero = torch.zeros(1, self.embed_size).to(in_values.device)
        # Here the both look up tables are combined. The rows with the zeros and the rows
        # with values from the actual lookup table are combined therefore
        lookup_table = torch.cat((lookup_table_zero, self.lookup_table_non_zero.weight.to(device)),
                                 0)  # Move to the same device as inputs
        # Next the input ids are embedded into the lookup table, which means that each id has it own
        # embedding-vector, f.e:
        # id: 5 => [1,5,4]; id:7 => [3,2,9]
        # The input ids are the tokens
        # If a token sequence of 5;7 is used, the resulting matrix is:
        # [1,5,4],[3,2,9]
        return F.embedding(in_values, lookup_table).to(in_values.device)


### Training
The following cells implement the training of the ByteNet model. The model is trained on a part of the WMT2014 english to german dataset.

In [26]:
# Load the data
device = 'cuda' if torch.cuda.is_available() else 'cpu'
cache_dir = 'F:/wmt19_cache'
#    wmt_loader = WMTLoader(split="train", cache_dir=cache_dir)
# Number of workers provides parallel loading
num_workers = 8
#    data_load = DataLoader(wmt_loader, batch_size=32, collate_fn=wmt_loader.collate_fn, num_workers=num_workers)
#    temp = data_load
#
# for batch in wmt_loader:
#     src_batch, tgt_batch = batch
#     break
batch_size = 64
# change as needed
output_dir = 'F:\\wmt19_json'
download_entire_de_en_dataset(batch_size, output_dir, 4)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using device: {device}')
wmt_json_loader = WMT19JSONLoader(output_dir)


64
128
192
256
320
384
448
512
576
640
704
768
832
896
960
1024
1088
1152
1216
1280
1344
1408
1472
1536
1600
1664
1728
1792
1856
1920
1984
2048
2112
2176
2240
2304
2368
2432
2496
2560
2624
2688
2752
2816
2880
2944
3008
3072
3136
3200
3264
3328
3392
3456
3520
3584
3648
3712
3776
3840
3904
3968
4032
4096
4160
4224
4288
4352
4416
4480
4544
4608
4672
4736
4800
4864
4928
4992
5056
5120
5184
5248
5312
5376
5440
5504
5568
5632
5696
5760
5824
5888
5952
6016
6080
6144
6208
6272
6336
6400
6464
6528
6592
6656
6720
6784
6848
6912
6976
7040
7104
7168
7232
7296
7360
7424
7488
7552
7616
7680
7744
7808
7872
7936
8000
8064
8128
8192
8256
8320
8384
8448
8512
8576
8640
8704
8768
8832
8896
8960
9024
9088
9152
9216
9280
9344
9408
9472
9536
9600
9664
9728
9792
9856
9920
9984
10048
10112
10176
10240
10304
10368
10432
10496
10560
10624
10688
10752
10816
10880
10944
11008
11072
11136
11200
11264
11328
11392
11456
11520
11584
11648
11712
11776
11840
11904
11968
12032
12096
12160
12224
12288
12352
12416
12480
12

In [10]:
print(wmt_json_loader.tokenizer.get_vocab())

{'<pad>': 0, '</s>': 1, '<unk>': 2, '\x00': 3, '\x01': 4, '\x02': 5, '\x03': 6, '\x04': 7, '\x05': 8, '\x06': 9, '\x07': 10, '\x08': 11, '\t': 12, '\n': 13, '\x0b': 14, '\x0c': 15, '\r': 16, '\x0e': 17, '\x0f': 18, '\x10': 19, '\x11': 20, '\x12': 21, '\x13': 22, '\x14': 23, '\x15': 24, '\x16': 25, '\x17': 26, '\x18': 27, '\x19': 28, '\x1a': 29, '\x1b': 30, '\x1c': 31, '\x1d': 32, '\x1e': 33, '\x1f': 34, ' ': 35, '!': 36, '"': 37, '#': 38, '$': 39, '%': 40, '&': 41, "'": 42, '(': 43, ')': 44, '*': 45, '+': 46, ',': 47, '-': 48, '.': 49, '/': 50, '0': 51, '1': 52, '2': 53, '3': 54, '4': 55, '5': 56, '6': 57, '7': 58, '8': 59, '9': 60, ':': 61, ';': 62, '<': 63, '=': 64, '>': 65, '?': 66, '@': 67, 'A': 68, 'B': 69, 'C': 70, 'D': 71, 'E': 72, 'F': 73, 'G': 74, 'H': 75, 'I': 76, 'J': 77, 'K': 78, 'L': 79, 'M': 80, 'N': 81, 'O': 82, 'P': 83, 'Q': 84, 'R': 85, 'S': 86, 'T': 87, 'U': 88, 'V': 89, 'W': 90, 'X': 91, 'Y': 92, 'Z': 93, '[': 94, '\\': 95, ']': 96, '^': 97, '_': 98, '`': 99, 'a': 10

In [9]:
# HYPERPARAMETERS
num_sets = 3
set_size = 5
embed_size = 1600 # Paper
batch_size = 64

In [10]:
class TranslationDataset(Dataset):
    def __init__(self, source_texts, target_texts):
        self.source_texts = source_texts
        self.target_texts = target_texts

    def __len__(self):
        return len(self.source_texts)

    def __getitem__(self, idx):
        return self.source_texts[idx], self.target_texts[idx]


In [11]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
output_dir = 'F:\\wmt19_json'

print(f'Using device: {device}')
wmt_json_loader = WMT19JSONLoader(output_dir)

cache_dir = 'F:/wmt19_cache'
# wmt_loader = WMTLoader(split="train", cache_dir=cache_dir)
# index = 0
# source, target = wmt_loader[index]
# print("Source:", source)# print("Target:", target)

# use drive in which to save dataset in cache
tokenized_source_texts, tokenized_target_texts = wmt_json_loader.load_and_tokenize(
    'F:\\wmt19_json\\wmt_19_de_en.json')
src = tokenized_source_texts
trgt = tokenized_target_texts
vocab_size = len(wmt_json_loader.tokenizer.get_vocab())
print(f"Vocabulary size: {vocab_size}")

Using device: cuda
Error when line is decoded: Expecting value: line 1 column 1 (char 0)
Error when line is decoded: Expecting ',' delimiter: line 1 column 197 (char 196)
Error when line is decoded: Expecting value: line 1 column 1 (char 0)
Error when line is decoded: Expecting ',' delimiter: line 1 column 490 (char 489)
Error when line is decoded: Expecting value: line 1 column 1 (char 0)
Error when line is decoded: Expecting ',' delimiter: line 1 column 412 (char 411)
Error when line is decoded: Expecting value: line 1 column 1 (char 0)
Error when line is decoded: Expecting value: line 1 column 1 (char 0)
Error when line is decoded: Extra data: line 1 column 6 (char 5)
Error when line is decoded: Expecting ',' delimiter: line 1 column 709 (char 708)
Error when line is decoded: Expecting ':' delimiter: line 1 column 326 (char 325)
Error when line is decoded: Expecting ',' delimiter: line 1 column 296 (char 295)
Error when line is decoded: Expecting ':' delimiter: line 1 column 311 (ch

In [14]:
print(src[:1])

[tensor([ 75, 108, 104, 117,  35, 107, 100, 119,  35, 113, 108, 102, 107, 119,
         35, 104, 119, 122, 100,  35, 103, 100, 118,  35,  83, 100, 117, 111,
        100, 112, 104, 113, 119,  35, 103, 108, 104,  35,  83, 117, 108, 113,
        125, 108, 115, 108, 104, 113,  35, 103, 104, 117,  35,  87, 117, 100,
        113, 118, 115, 100, 117, 104, 113, 125,  35, 120, 113, 103,  35,  82,
        105, 105, 104, 113, 107, 104, 108, 119,  35, 108, 113,  35, 103, 108,
        104,  35,  87, 100, 119,  35, 120, 112, 106, 104, 118, 104, 119, 125,
        119,  47,  35, 118, 114, 113, 103, 104, 117, 113,  35, 104, 118,  35,
        122, 120, 117, 103, 104, 113,  35, 105, 198, 191, 117,  35, 104, 108,
        113,   1])]


In [15]:
print(trgt[:500])

[tensor([ 87, 107, 108, 118,  35, 122, 100, 118,  35, 113, 114, 119,  35, 119,
        107, 104,  35, 117, 104, 118, 120, 111, 119,  35, 114, 105,  35,  83,
        100, 117, 111, 108, 100, 112, 104, 113, 119,  35, 115, 120, 119, 119,
        108, 113, 106,  35, 108, 113, 119, 114,  35, 115, 117, 100, 102, 119,
        108, 102, 104,  35, 119, 107, 104,  35, 115, 117, 108, 113, 102, 108,
        115, 111, 104, 118,  35, 114, 105,  35, 119, 117, 100, 113, 118, 115,
        100, 117, 104, 113, 102, 124,  35, 100, 113, 103,  35, 114, 115, 104,
        113, 113, 104, 118, 118,  47,  35, 101, 120, 119,  35, 105, 114, 111,
        111, 114, 122, 118,  35, 119, 107, 104,  35, 102, 114, 115, 124, 108,
        113,   1]), tensor([ 81, 114, 122,  35, 119, 107, 100, 119,  35, 119, 107, 104,  35, 117,
        104, 102, 114, 117, 103, 118,  35, 107, 100, 121, 104,  35, 101, 104,
        104, 113,  35, 115, 120, 119,  35, 108, 113, 119, 114,  35, 119, 107,
        104,  35, 115, 120, 101, 111, 108, 

In [16]:
# Idea for what could be used for unfolding, however even with batch size 1, it does not work as I run ot of memory so I could not test it
def decode(decoder, context_vector):
    batch_size = context_vector.shape[0]

    output_sequence = []
    end_of_sequence_mask = torch.zeros((batch_size, 384), dtype=torch.bool, device=context_vector.device)

    for _ in range(128):
        if len(output_sequence) == 0:  # for the first step, use context vector
            decoder_input = context_vector
        else:  # for subsequent steps, use the last output token
            input_token = output_sequence[-1]
            input_token_embedded = encoder_decoder.embed(input_token).permute(0, 2, 1)
            decoder_input = torch.cat([context_vector, input_token_embedded], dim=-1)

        output_token = decoder(decoder_input)
        predicted_token = torch.argmax(output_token, dim=-1)
        end_of_sequence_mask |= (predicted_token.squeeze() == 1)
        predicted_token[end_of_sequence_mask] = 0

        output_sequence.append(predicted_token.squeeze())

    return torch.stack(output_sequence, dim=1)[:, 1:]

In [12]:
# LayerNorm that can take changing input sequence lengths
class DynamicLayerNorm(nn.Module):
    def __init__(self):
        super(DynamicLayerNorm, self).__init__()
        self.norm = None

    def forward(self, x):
        if self.norm is None or self.norm.normalized_shape[0] != x.shape[2]:
            self.norm = nn.LayerNorm(x.shape[2]).to(x.device)
        return self.norm(x)

In [13]:
translation_dataset = TranslationDataset(tokenized_source_texts, tokenized_target_texts)
dataset_size = len(translation_dataset)
train_size = int(0.8 * dataset_size)
test_size = dataset_size - train_size
train_dataset, test_dataset = torch.utils.data.random_split(translation_dataset, [train_size, test_size])
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

In [14]:
criterion = torch.nn.CrossEntropyLoss()
# size and all params according to the paper, reduce for performance
encoder_decoder = EncoderDecoderStacking(n_sets=3, blocks_per_set=5, output_channels=vocab_size,emb_size=embed_size).to(
    device)

# Define a loss function and an optimizer
# When changing Loss function, make sur to check if the decoder should have the softmax layer, and adjust that
optimizer = torch.optim.Adam(encoder_decoder.parameters(), lr=0.0003)  #  Paper: 0.0003
# Number of epochs
num_epochs = 2


In [20]:
!pip install tensorboard




[notice] A new release of pip is available: 23.2.1 -> 24.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [16]:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

In [17]:
 # Train the model loop
 for epoch in range(num_epochs):
    for i, (inputs, targets) in tqdm(enumerate(train_loader), total=len(train_loader)):
        # Move data to the appropriate device
        # inputs = inputEmbedding.embed(inputs.to(device))  # Ad batch dimension
        inputs = inputs.to(device)  # Add batch
        targets = targets.to(device)  # Add batch
        inputs = encoder_decoder.embed(inputs).permute(0, 2, 1)
        enc = encoder_decoder.encoder(inputs)
        enc = torch.cat((enc, encoder_decoder.embed(targets).permute(0, 2, 1)), dim=2)
        
        outputs = encoder_decoder.decoder(enc)
        predicted_targets = outputs[:, :, outputs.shape[2]//2:]

        # Compute loss
        loss = criterion(predicted_targets, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Print loss every 100 steps
        if i % 25 == 0:
            tqdm.write(
            f'Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(train_loader)}], Loss: {loss.item()}')
            writer.add_scalar('Loss/train', loss.item(), (i+1)/len(train_loader))
        if i % 200 == 0:
            torch.save(encoder_decoder, 'model_new.pth')
            encoder_decoder.eval()
            total_val_loss = 0
            with torch.inference_mode():
                for i, (inputs, targets) in enumerate(test_loader):
                    inputs = inputs.to(device)
                    targets = targets.to(device)
                    inputs = encoder_decoder.embed(inputs).permute(0, 2, 1)
                    enc = encoder_decoder.encoder(inputs)
                    enc = torch.cat((enc, encoder_decoder.embed(targets).permute(0, 2, 1)), dim=2)
    
                    outputs = encoder_decoder.decoder(enc)
                    predicted_targets = outputs[:, :, outputs.shape[2]//2:]
                    loss = criterion(predicted_targets, targets)
                    total_val_loss += loss.item()
                avg_val_loss = total_val_loss / len(test_loader)
                tqdm.write(f'Validation Loss: {avg_val_loss}')
                writer.add_scalar('Loss/val', avg_val_loss, (i+1)/len(train_loader))
            encoder_decoder.train()

  0%|          | 0/635 [00:08<?, ?it/s]

Epoch [1/2], Step [1/635], Loss: 5.970308303833008


  0%|          | 1/635 [00:52<9:17:21, 52.75s/it]

Validation Loss: 2.1264941977255


  4%|▍         | 26/635 [07:16<2:38:50, 15.65s/it]

Epoch [1/2], Step [26/635], Loss: 0.009035281836986542


  8%|▊         | 51/635 [13:20<2:21:11, 14.51s/it]

Epoch [1/2], Step [51/635], Loss: 0.006053619086742401


 12%|█▏        | 76/635 [19:25<2:17:44, 14.78s/it]

Epoch [1/2], Step [76/635], Loss: 0.0005188676295801997


 16%|█▌        | 101/635 [25:30<2:10:39, 14.68s/it]

Epoch [1/2], Step [101/635], Loss: 0.0006064572371542454


 20%|█▉        | 126/635 [31:34<2:03:19, 14.54s/it]

Epoch [1/2], Step [126/635], Loss: 0.0001635016524232924


 24%|██▍       | 151/635 [37:38<1:57:00, 14.50s/it]

Epoch [1/2], Step [151/635], Loss: 0.0012078034924343228


 28%|██▊       | 176/635 [43:43<1:52:18, 14.68s/it]

Epoch [1/2], Step [176/635], Loss: 3.1776278774486855e-05


 31%|███▏      | 200/635 [49:47<1:45:09, 14.50s/it]

Epoch [1/2], Step [201/635], Loss: 0.0007605839055031538


 32%|███▏      | 201/635 [50:40<3:38:39, 30.23s/it]

Validation Loss: 0.00026583782636859594


 36%|███▌      | 226/635 [56:58<1:47:40, 15.80s/it]

Epoch [1/2], Step [226/635], Loss: 6.42079976387322e-05


 40%|███▉      | 251/635 [1:03:33<1:41:13, 15.82s/it]

Epoch [1/2], Step [251/635], Loss: 1.6992411474348046e-05


 43%|████▎     | 276/635 [1:10:08<1:34:22, 15.77s/it]

Epoch [1/2], Step [276/635], Loss: 0.00030888113542459905


 47%|████▋     | 301/635 [1:16:43<1:28:11, 15.84s/it]

Epoch [1/2], Step [301/635], Loss: 1.5953764886944555e-05


 51%|█████▏    | 326/635 [1:23:19<1:21:28, 15.82s/it]

Epoch [1/2], Step [326/635], Loss: 1.0803877557918895e-05


 55%|█████▌    | 351/635 [1:29:52<1:14:51, 15.81s/it]

Epoch [1/2], Step [351/635], Loss: 3.72878675989341e-05


 59%|█████▉    | 376/635 [1:36:27<1:08:16, 15.82s/it]

Epoch [1/2], Step [376/635], Loss: 2.0411953300936148e-05


 63%|██████▎   | 400/635 [1:43:02<1:01:52, 15.80s/it]

Epoch [1/2], Step [401/635], Loss: 0.0003844169550575316


 63%|██████▎   | 401/635 [1:43:46<1:53:04, 29.00s/it]

Validation Loss: 8.561797716561302e-05


 67%|██████▋   | 426/635 [1:49:59<50:53, 14.61s/it]  

Epoch [1/2], Step [426/635], Loss: 0.0006308655138127506


 71%|███████   | 451/635 [1:56:03<44:51, 14.63s/it]

Epoch [1/2], Step [451/635], Loss: 2.821164162014611e-05


 75%|███████▍  | 476/635 [2:02:08<38:43, 14.61s/it]

Epoch [1/2], Step [476/635], Loss: 3.2159456168301404e-05


 79%|███████▉  | 501/635 [2:08:13<32:44, 14.66s/it]

Epoch [1/2], Step [501/635], Loss: 7.5863395068154205e-06


 83%|████████▎ | 526/635 [2:14:18<26:26, 14.55s/it]

Epoch [1/2], Step [526/635], Loss: 1.815696305129677e-05


 87%|████████▋ | 551/635 [2:20:22<20:31, 14.67s/it]

Epoch [1/2], Step [551/635], Loss: 1.9822799004032277e-05


 91%|█████████ | 576/635 [2:26:27<14:23, 14.63s/it]

Epoch [1/2], Step [576/635], Loss: 3.7699098811572185e-06


 94%|█████████▍| 600/635 [2:32:33<08:31, 14.63s/it]

Epoch [1/2], Step [601/635], Loss: 4.59348302683793e-06


 95%|█████████▍| 601/635 [2:33:17<15:45, 27.82s/it]

Validation Loss: 5.278500140772717e-05


 99%|█████████▊| 626/635 [2:39:30<02:17, 15.30s/it]

Epoch [1/2], Step [626/635], Loss: 5.214518751017749e-06


100%|██████████| 635/635 [2:41:37<00:00, 15.27s/it]
  0%|          | 0/635 [00:14<?, ?it/s]

Epoch [2/2], Step [1/635], Loss: 4.878184881818015e-06


  0%|          | 1/635 [01:25<15:03:32, 85.51s/it]

Validation Loss: 5.898518549725369e-05


  4%|▍         | 26/635 [06:48<2:11:45, 12.98s/it]

Epoch [2/2], Step [26/635], Loss: 4.270076260581845e-06


  8%|▊         | 51/635 [12:04<2:02:14, 12.56s/it]

Epoch [2/2], Step [51/635], Loss: 2.089683221129235e-06


 12%|█▏        | 76/635 [17:28<2:03:53, 13.30s/it]

Epoch [2/2], Step [76/635], Loss: 2.051925093837781e-06


 16%|█▌        | 101/635 [22:49<1:57:33, 13.21s/it]

Epoch [2/2], Step [101/635], Loss: 1.5733054397060187e-06


 20%|█▉        | 126/635 [28:19<1:51:28, 13.14s/it]

Epoch [2/2], Step [126/635], Loss: 1.039941798808286e-05


 24%|██▍       | 151/635 [33:51<1:46:33, 13.21s/it]

Epoch [2/2], Step [151/635], Loss: 4.1290927583759185e-06


 28%|██▊       | 176/635 [39:15<1:40:23, 13.12s/it]

Epoch [2/2], Step [176/635], Loss: 3.4547742870927323e-06


 31%|███▏      | 200/635 [44:43<1:36:50, 13.36s/it]

Epoch [2/2], Step [201/635], Loss: 2.676817530300468e-06


 32%|███▏      | 201/635 [45:28<3:13:44, 26.78s/it]

Validation Loss: 3.776716647543534e-05


 36%|███▌      | 226/635 [50:52<1:27:41, 12.86s/it]

Epoch [2/2], Step [226/635], Loss: 1.9059816622757353e-06


 40%|███▉      | 251/635 [56:15<1:23:00, 12.97s/it]

Epoch [2/2], Step [251/635], Loss: 7.2257848842127714e-06


 43%|████▎     | 276/635 [1:01:38<1:19:00, 13.20s/it]

Epoch [2/2], Step [276/635], Loss: 2.1225057480478426e-06


 47%|████▋     | 301/635 [1:07:14<1:24:09, 15.12s/it]

Epoch [2/2], Step [301/635], Loss: 2.943995923487819e-06


 51%|█████▏    | 326/635 [1:12:41<1:06:56, 13.00s/it]

Epoch [2/2], Step [326/635], Loss: 2.3795732886355836e-06


 55%|█████▌    | 351/635 [1:19:00<1:23:02, 17.54s/it]

Epoch [2/2], Step [351/635], Loss: 1.7238079408343765e-06


 59%|█████▉    | 376/635 [1:25:49<1:11:00, 16.45s/it]

Epoch [2/2], Step [376/635], Loss: 2.252018475701334e-06


 63%|██████▎   | 400/635 [1:32:49<1:04:01, 16.35s/it]

Epoch [2/2], Step [401/635], Loss: 5.807225534226745e-06


 63%|██████▎   | 401/635 [1:33:44<2:08:46, 33.02s/it]

Validation Loss: 4.0299756606918036e-05


 67%|██████▋   | 426/635 [1:40:34<57:05, 16.39s/it]  

Epoch [2/2], Step [426/635], Loss: 3.311465843580663e-05


 71%|███████   | 451/635 [1:47:40<50:30, 16.47s/it]  

Epoch [2/2], Step [451/635], Loss: 1.2720528275167453e-06


 75%|███████▍  | 476/635 [1:54:48<45:46, 17.27s/it]

Epoch [2/2], Step [476/635], Loss: 8.267182352028613e-07


 79%|███████▉  | 501/635 [2:01:48<38:02, 17.04s/it]

Epoch [2/2], Step [501/635], Loss: 1.0231758551526582e-06


 83%|████████▎ | 526/635 [2:08:45<30:08, 16.59s/it]

Epoch [2/2], Step [526/635], Loss: 8.331690537488612e-07


 87%|████████▋ | 551/635 [2:15:46<22:09, 15.83s/it]

Epoch [2/2], Step [551/635], Loss: 1.8997469624082441e-06


 91%|█████████ | 576/635 [2:27:48<41:00, 41.71s/it]

Epoch [2/2], Step [576/635], Loss: 6.91514685513539e-07


 94%|█████████▍| 600/635 [2:46:08<24:33, 42.09s/it]

Epoch [2/2], Step [601/635], Loss: 7.58590840632678e-06


 95%|█████████▍| 601/635 [2:47:40<40:32, 71.55s/it]

Validation Loss: 3.078527900269444e-05


 99%|█████████▊| 626/635 [2:52:36<01:44, 11.58s/it]

Epoch [2/2], Step [626/635], Loss: 8.595839631198032e-07


100%|██████████| 635/635 [2:54:27<00:00, 16.48s/it]


In [69]:
def translate(to_translate, model, loader):
    model.eval()
    with torch.no_grad():
        for i in range(1):
            # initialize character sequence
            pred_prev = torch.zeros((1, 128)).long().to(device)
            pred = torch.zeros((1, 128)).long().to(device)
    
            # generate output sequence
            for i in range(128):
                # predict character
                to_translate = to_translate[0].unsqueeze(0).to(device)
                out = model.encoder(model.embed(to_translate).permute(0, 2, 1))
                out = torch.cat((out, model.embed(pred_prev).permute(0, 2, 1)), dim=2)
                out = model.decoder(out)
                out = out[:, :, out.shape[2]//2:]
                out = torch.argmax(out, dim=2)
                # update character sequence
                if i < 127:
                    pred_prev[:, i + 1] = out[:, i]
                pred[:, i] = out[:, i]
    print(pred.shape)
    print(f"Translation: {loader.tokenizer.decode(pred.squeeze(0).tolist())}")

In [19]:
encoder_decoder = torch.load("model_new.pth")

In [72]:
encoder_decoder.eval()
text = ["Hallo wer bist du, wie ist denn dein name"]
print(f"Translating: {text}")
translate(wmt_json_loader.tokenize_texts([text][0]), encoder_decoder, wmt_json_loader)


Translating: ['Hallo wer bist du, wie ist denn dein name']
torch.Size([1, 128])
Translation: </s>O2f</s>,-6;:U4DMD</s>N<<	A</s>MGSW+EW[K</s>N&,%&</s>&$
7</s>532<pad>@&&$ <unk>4<pad>	<pad><unk>9^</s>VN\Wy x	Z&&]khz<d
