# Bytenet
Implementation of the paper Neural Machine Translation in Linear Time (https://arxiv.org/pdf/1610.10099). The Bynet is a CNN based Encoder/Decoder Model used here for Sequence to Sequence Translation. The model is trained on a part of the WMT2014 english to german dataset.

### Data Preprocessing
Data is loaded and tokenized in this step.

In [1]:
# Define Imports
import json

import requests
import torch
import pickle
import torch.utils.data as data
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer
from datasets import load_dataset
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from transformers import AutoTokenizer


In [23]:
import requests


class WMT19JSONLoader:
    def __init__(self, file_path, source_lang='de', target_lang='en', max_length=128):
        self.source_lang = source_lang
        self.target_lang = target_lang
        self.max_length = max_length
        # self.tokenizer = BertTokenizer.from_pretrained("bert-base-multilingual-cased")
        self.file_path = file_path
        self.tokenizer = AutoTokenizer.from_pretrained("google/byt5-small")

    def load_json_data(self, file_path):
        """
        Function that loads the downloaded JSON file

        :param file_path:
        :return:
        """
        loaded_data = []
        with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
            for line in f:
                try:
                    loaded_data.append(json.loads(line.strip()))
                except json.JSONDecodeError as e:
                    print(f"Error when line is decoded: {e}")
        return loaded_data

    def convert_to_tensor(self, src, trg):
        """
        Checks if source and target are tensor
        If both are not tensor, they are converted to tensors

        :param src:
        :param trg:
        :return:
        """
        if not torch.is_tensor(src):
            src = torch.Tensor(src)
        if not torch.is_tensor(trg):
            trg = torch.tensor(trg, dtype=torch.int32)
        return src, trg

    def extract_source_target(self, load_data):
        """
        Function that extracts out of the downloaded JSON the
        german rows as source and the english rows as targets

        :param load_data:
        :param source_lang:
        :param target_lang:
        :return:
        """
        source_texts = []
        target_texts = []
        for item in load_data:
            if ('row' in item and 'translation' in item['row'] and
                    self.source_lang in item['row']['translation'] and
                    self.target_lang in item['row']['translation']):
                source_texts.append(item['row']['translation'][self.source_lang])
                target_texts.append(item['row']['translation'][self.target_lang])
        return source_texts, target_texts

    def tokenize_texts(self, texts):
        """
        Function for tokenizing the text data
        Uses BERT-Tokenizer as tokenizer model

        :param texts:
        :return:
        """
        tokenized_texts = []
        for text in texts:
            tokens = self.tokenizer(text, max_length=self.max_length, padding="max_length", truncation=True, return_tensors="pt")
            
            tokenized_texts.append(tokens['input_ids'].squeeze())
        return tokenized_texts

    def load_and_tokenize(self, json_file_path):
        """
        Function that does the load json data
        and the tokenizing process

        :param json_file_path:
        """
        loaded_data = self.load_json_data(json_file_path)

        source_texts, target_texts = self.extract_source_target(loaded_data)

        # The tokenized source and targets
        # self.tokenizer is a object of type transformers from the Bert model
        # padding="max_length": is used to fill sequence to maximal length
        # truncation = True: Means that the sequence is cutted, if longer than max_length
        # return_tensors="pt": Means that a pytorch tensor is returned
        # the source text is tokenized into smaller elements
        tokenized_source_texts = self.tokenize_texts(source_texts)

        # the target text is tokenized into smaller elements
        tokenized_target_texts = self.tokenize_texts(target_texts)

        #TODO: evetually squeeze as in WMTLoader

        return tokenized_source_texts, tokenized_target_texts


def download_data(offset, length):
    """
    Method for downloading the dataset as JSON
    F.e. if the first 10 rows have to be downloaded, offset has to
    be 0 and length has to be 10

    :param offset: The offset used in the url
    :param length: The length of the selected number of rows in the dataset
    :return:
    """
    url = f"https://datasets-server.huggingface.co/rows?dataset=wmt%2Fwmt19&config=de-en&split=train&offset={offset}&length={length}"
    query_parameters = {"downloadformat": "json"}
    response = requests.get(url, params=query_parameters)
    if response.status_code == 200:
        loaded_data = response.json()
        print(f"Downloading dataset-offset: {offset}")
        return loaded_data['rows']
    else:
        print(f"Error while downloading data: {response.status_code}")
        return []


def save_data_to_json(load_data, file_path):
    """
    Writes data into the JSON object

    :param load_data: The data that has to be writen into file
    :param file_path: The file path where the file has to be saved
    """
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    with open(file_path, 'a', encoding='utf-8') as f:
        for item in load_data:
            json.dump(item, f, ensure_ascii=False)
            f.write('\n')


def download_batch_and_save(offset, length, output_file):
    """
    Downloads and saves the batch

    :param offset: The offset which is currently used to download
    :param length: The length is defined with 100
    :param output_file: The name of the file to be saved
    """
    loaded_data = download_data(offset, length)
    save_data_to_json(loaded_data, output_file)


def download_entire_de_en_dataset(batch_size, output_dir, num_workers):
    """
    Downloads the entire WMT19 dataset. Uses a ThreadPoolExecutor for
    faster download of the dataset.

    :param batch_size:
    :param output_dir:
    :param num_workers:
    """
    offset = 0
    output_file = os.path.join(output_dir, 'wmt_19_de_en.json')
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = []
        while True:
            futures.append(executor.submit(download_batch_and_save, offset, batch_size, output_file))
            offset += batch_size
            # if offset >= 34800000:
            # This controls how much of the dataset is actually downloaded
            if offset >= 34800:
                break

        for future in as_completed(futures):
            future.result()


### ByteNet Model
The following Cells implement the necessary parts of the ByteNet model. The model is made up of a number of sets of, each of which contains a number of residual blocks that apply LayerNorm and 1D Convolutions, which is further masked for the Decoder part of the Network.

In [3]:
# Imports for ByteNet
import torch
from torch.utils.data import Dataset
from torch import nn
from torch.nn import init
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
from data.data_loader import WMTLoader, WMT19JSONLoader, download_entire_de_en_dataset


In [120]:
class ResidualBlockReLu(nn.Module):
    """
    Implementation of residual Layer for Bytenet machine translation task.

    :param d: The number of input features.
    :param dilation: The initial dilation rate for the convolution layers.
    """

    def __init__(self, d, dilation, k, decoder=False):
        super(ResidualBlockReLu, self).__init__()
        self.decoder = decoder
        self.layer_norm1 = nn.LayerNorm(128)
        self.reLu1 = nn.ReLU()
        # 2*d -> d
        self.conv1 = nn.Conv1d(d * 2, d, 1)
        self.layer_norm2 = nn.LayerNorm(128)
        self.reLu2 = nn.ReLU()
        # Masked kernel size is k
        # Dilation only used for masked convolution
        # d -> d
        if decoder:
            # Masked convolution basically means all padding on left side
            self.receptive_field = (k - 1) * dilation
            self.conv2 = nn.Conv1d(d, d, k, dilation=dilation)
        else:
            # Padding still needed to keep the size of the input and output the same
            padding = (k - 1) * dilation // 2
            self.conv2 = nn.Conv1d(d, d, k, dilation=dilation, padding=padding)
        self.layer_norm3 = nn.LayerNorm(128)
        self.reLu3 = nn.ReLU()
        # d -> 2*d
        self.conv3 = nn.Conv1d(d, d * 2, 1)

    def forward(self, x):
        residual = x
        x = self.layer_norm1(x)
        x = self.reLu1(x)
        x = self.conv1(x)
        x = self.layer_norm2(x)
        x = self.reLu2(x)
        # When Decoder is used, the convolution is causal
        if self.decoder:
            x = F.pad(x, (self.receptive_field, 0))
        x = self.conv2(x)
        x = self.layer_norm3(x)
        x = self.reLu3(x)
        x = self.conv3(x)
        # Add back the residual
        x += residual
        return x


In [121]:
class BytenetEncoder(nn.Module):
    """
    Implementation of the ByteNet Encoder. Default Parameters are set to the ones used in the paper.
    
    :param kernel_size: The kernel size for the unmasked (padded) convolution in the residual block.
    :param max_dilation_rate: The maximum dilation rate for the convolution layers.
    :param masked_kernel_size: The kernel size for the masked convolution in the residual block (only interesting for decoder).
    :param num_sets: The number of sets of residual blocks.
    :param set_size: The number of residual blocks in each set.
    :param hidden_channels: The number of hidden channels in the model.
    """
    def __init__(self, kernel_size=3, max_dilation_rate=16, masked_kernel_size=3, num_sets=6, set_size=5,
                 hidden_channels=800, emb_size = 1600):
        super(BytenetEncoder, self).__init__()
        self.num_channels = hidden_channels
        self.kernel_size = kernel_size
        self.layers = nn.Sequential()
        # 128 is size of tokenizer
        # input of shape [batch_size, 128, 128] as [batch_size, tokens, embedding_size]
        self.layers.append(nn.Conv1d(in_channels=emb_size, out_channels=hidden_channels * 2, kernel_size=1))
        # From the Paper:
        # Model has a series of residual blocks of increased dilation rate
        # With unmasked convolutions for the encoder
        for _ in range(num_sets):
            dilation_rate = 1
            for _ in range(set_size):
                # Dilation Rate doubles each layer (starting out at 1)
                dilation_rate = dilation_rate * 2
                # Dilation rate does not exceed a given maximum
                # Example from the paper: 16
                self.layers.append(ResidualBlockReLu(hidden_channels,
                                                     dilation_rate if dilation_rate <= max_dilation_rate else max_dilation_rate,
                                                     masked_kernel_size))
            # "the network applies one more convolution"
        # Note: The output of the residual layers is 2*input_features, however the output of the final convolutions is not specified in the paper
        # Experimentation needed if it should be 2*input_features or input_features
        self.encoder_out_conv = nn.Conv1d(in_channels=hidden_channels * 2, out_channels=2 * hidden_channels, kernel_size=1)
        # "and ReLU"
        # Not sure if these last 2 layers should be in encoder or just decoder
        # self.layers.append(nn.ReLU())
        # "followed by a convolution"
        # self.layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size))
        # "and a final softmax layer" (probably not for encoder, however paper does not specify)
        # self.layers.append(nn.Softmax(dim=1))

    def forward(self, x):
        # Temporary
        x = x.float()
        for layer in self.layers:
            x = layer(x)
        x = self.encoder_out_conv(x)
        return x


In [122]:
class BytenetDecoder(nn.Module):
    """
    Implementation of the ByteNet Decoder. Default Parameters are set to the ones used in the paper.
    
    :param kernel_size: The kernel size for the unmasked (padded) convolution in the residual block (not important for decoder).
    :param max_dilation_rate: The maximum dilation rate for the convolution layers.
    :param masked_kernel_size: The kernel size for the masked convolution in the residual block.
    :param num_sets: The number of sets of residual blocks.
    :param set_size: The number of residual blocks in each set.
    :param hidden_channels: The number of hidden channels in the model.
    """
    def __init__(self, kernel_size=3, max_dilation_rate=16, masked_kernel_size=3, num_sets=6, set_size=5,
                 hidden_channels=800, output_channels=384):
        super(BytenetDecoder, self).__init__()
        self.num_channels = hidden_channels
        self.kernel_size = kernel_size
        self.layers = nn.Sequential()
        # From the Paper:
        # Model has a series of residual blocks of increased dilation rate
        # With masekd convolution for decoder
        for _ in range(num_sets):
            dilation_rate = 1
            for _ in range(set_size):
                # Dilation Rate doubles each layer (starting out at 1)
                dilation_rate = dilation_rate * 2
                # Dilation rate does not exceed a given maximum
                # Example from the paper: 16
                self.layers.append(ResidualBlockReLu(hidden_channels,
                                                     dilation_rate if dilation_rate <= max_dilation_rate else max_dilation_rate,
                                                     masked_kernel_size, decoder=True))
        # "the network applies one more convolution"
        # Note: The output of the residual layers is 2*input_features, however the output of the final convolutions is not specified in the paper
        # Experimentation needed if it should be 2*input_features or input_features
        self.layers.append(nn.Conv1d(hidden_channels * 2, hidden_channels, 1))
        # "and ReLU"
        self.layers.append(nn.ReLU())
        # "followed by a convolution"
        self.layers.append(nn.Conv1d(hidden_channels, output_channels, 1))
        # "and a final softmax layer"
        # self.layers.append(nn.LogSoftmax(dim=-1))

        # self.layers.append(nn.Softmax(dim=1))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


In [123]:
class EncoderDecoderStacking(nn.Module):
    """
    Stacks the encoder and decoder for the ByteNet model.
    This means passing the output of the encoder as input to the decoder.
    
    :param kernel_size: The kernel size for the unmasked (padded) convolution in the residual block (for Encoder).
    :param max_dilation_rate: The maximum dilation rate for the convolution layers.
    :param masked_kernel_size: The kernel size for the masked convolution in the residual block (for Decoder).
    :param num_sets: The number of sets of residual blocks.
    :param set_size: The number of residual blocks in each set.
    :param hidden_channels: The number of hidden channels in the model.
    :param output_channels: The number of output channels in the model (vocab size).

    :return x: The output of the decoder.
    """

    def __init__(self, kernel_size=3, max_dilation_rate=16, masked_kernel_size=3, n_sets=6, blocks_per_set=5,
                 hidden_channels=800, output_channels = 384, emb_size= 1600):
        super(EncoderDecoderStacking, self).__init__()
        self.embed = nn.Embedding(vocab_size, emb_size)
        self.encoder = BytenetEncoder(kernel_size=kernel_size, max_dilation_rate=max_dilation_rate,
                                      masked_kernel_size=masked_kernel_size, num_sets=n_sets, set_size=blocks_per_set,
                                      hidden_channels=hidden_channels, emb_size=emb_size)
        self.decoder = BytenetDecoder(kernel_size=kernel_size, max_dilation_rate=max_dilation_rate,
                                      masked_kernel_size=masked_kernel_size, num_sets=n_sets, set_size=blocks_per_set,
                                      hidden_channels=hidden_channels, output_channels=output_channels)

    def forward(self, x):
        # This permutation is needed for embeddings in pytorch with 1d convolutions
        embed_x = self.embed(x).permute(0, 2, 1)
        x = self.encoder(embed_x)
        x = self.decoder(x)
        return x


In [91]:
class InputEmbeddingTensor:
    """
    Class which enables the embedding of tokens.

    :param vocab_size: The size of the vocabulary as int.
    :param embed_size: The size of the embedding units as int.
    """

    def __init__(self, vocab_size, embed_size):
        super(InputEmbeddingTensor, self).__init__()
        self.vocab_size = vocab_size
        self.embed_size = embed_size
        # This is the actual lookup table.
        # A lookup table is an array of data that maps input values to output values
        self.lookup_table_non_zero = nn.Embedding(vocab_size - 1, embed_size)
        init.xavier_uniform_(self.lookup_table_non_zero.weight)

    def embed(self, in_values):
        """
        In this method the first n tokens are embedded via look-up table.
        The n tokens serve as targets for the predictions.

        :param in_values: The train input values from batch, more exact: the tokens
        :return: A embedded tensor of size n × 2d where d is the number of inner
                channels in the network
        """
        lookup_table_zero = torch.zeros(1, self.embed_size).to(in_values.device)
        # Here the both look up tables are combined. The rows with the zeros and the rows
        # with values from the actual lookup table are combined therefore
        lookup_table = torch.cat((lookup_table_zero, self.lookup_table_non_zero.weight.to(device)),
                                 0)  # Move to the same device as inputs
        # Next the input ids are embedded into the lookup table, which means that each id has it own
        # embedding-vector, f.e:
        # id: 5 => [1,5,4]; id:7 => [3,2,9]
        # The input ids are the tokens
        # If a token sequence of 5;7 is used, the resulting matrix is:
        # [1,5,4],[3,2,9]
        return F.embedding(in_values, lookup_table).to(in_values.device)


### Training
The following cells implement the training of the ByteNet model. The model is trained on a part of the WMT2014 english to german dataset.

In [24]:
# Load the data
device = 'cuda' if torch.cuda.is_available() else 'cpu'
cache_dir = 'F:/wmt19_cache'
#    wmt_loader = WMTLoader(split="train", cache_dir=cache_dir)
# Number of workers provides parallel loading
num_workers = 4
#    data_load = DataLoader(wmt_loader, batch_size=32, collate_fn=wmt_loader.collate_fn, num_workers=num_workers)
#    temp = data_load
#
# for batch in wmt_loader:
#     src_batch, tgt_batch = batch
#     break

batch_size = 100
# change as needed
output_dir = 'F:\\wmt19_json'
download_entire_de_en_dataset(batch_size, output_dir, 4)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using device: {device}')
wmt_json_loader = WMT19JSONLoader(output_dir)


Downloading dataset-offset: 300
Downloading dataset-offset: 0
Downloading dataset-offset: 200
Downloading dataset-offset: 100
Downloading dataset-offset: 400
Downloading dataset-offset: 800
Downloading dataset-offset: 700
Downloading dataset-offset: 500
Downloading dataset-offset: 900
Downloading dataset-offset: 1000
Downloading dataset-offset: 1100
Downloading dataset-offset: 1300
Downloading dataset-offset: 600
Downloading dataset-offset: 1200
Downloading dataset-offset: 1400
Downloading dataset-offset: 1500
Downloading dataset-offset: 1600
Downloading dataset-offset: 1800
Downloading dataset-offset: 1700
Downloading dataset-offset: 2100
Downloading dataset-offset: 2200
Downloading dataset-offset: 1900
Downloading dataset-offset: 2000
Downloading dataset-offset: 2300
Downloading dataset-offset: 2400
Downloading dataset-offset: 2700
Downloading dataset-offset: 2800
Downloading dataset-offset: 2500
Downloading dataset-offset: 2600
Downloading dataset-offset: 2900
Downloading dataset-of

In [25]:
print(wmt_json_loader.tokenizer.get_vocab())

{'<pad>': 0, '</s>': 1, '<unk>': 2, '\x00': 3, '\x01': 4, '\x02': 5, '\x03': 6, '\x04': 7, '\x05': 8, '\x06': 9, '\x07': 10, '\x08': 11, '\t': 12, '\n': 13, '\x0b': 14, '\x0c': 15, '\r': 16, '\x0e': 17, '\x0f': 18, '\x10': 19, '\x11': 20, '\x12': 21, '\x13': 22, '\x14': 23, '\x15': 24, '\x16': 25, '\x17': 26, '\x18': 27, '\x19': 28, '\x1a': 29, '\x1b': 30, '\x1c': 31, '\x1d': 32, '\x1e': 33, '\x1f': 34, ' ': 35, '!': 36, '"': 37, '#': 38, '$': 39, '%': 40, '&': 41, "'": 42, '(': 43, ')': 44, '*': 45, '+': 46, ',': 47, '-': 48, '.': 49, '/': 50, '0': 51, '1': 52, '2': 53, '3': 54, '4': 55, '5': 56, '6': 57, '7': 58, '8': 59, '9': 60, ':': 61, ';': 62, '<': 63, '=': 64, '>': 65, '?': 66, '@': 67, 'A': 68, 'B': 69, 'C': 70, 'D': 71, 'E': 72, 'F': 73, 'G': 74, 'H': 75, 'I': 76, 'J': 77, 'K': 78, 'L': 79, 'M': 80, 'N': 81, 'O': 82, 'P': 83, 'Q': 84, 'R': 85, 'S': 86, 'T': 87, 'U': 88, 'V': 89, 'W': 90, 'X': 91, 'Y': 92, 'Z': 93, '[': 94, '\\': 95, ']': 96, '^': 97, '_': 98, '`': 99, 'a': 10

In [102]:
# HYPERPARAMETERS
num_sets = 3
set_size = 5
embed_size = 256 # Paper
batch_size = 64

In [116]:
class TranslationDataset(Dataset):
    def __init__(self, source_texts, target_texts):
        self.source_texts = source_texts
        self.target_texts = target_texts

    def __len__(self):
        return len(self.source_texts)

    def __getitem__(self, idx):
        return self.source_texts[idx], self.target_texts[idx]


In [45]:
cache_dir = 'F:/wmt19_cache'
# wmt_loader = WMTLoader(split="train", cache_dir=cache_dir)
# index = 0
# source, target = wmt_loader[index]
# print("Source:", source)
# print("Target:", target)

# use drive in which to save dataset in cache
tokenized_source_texts, tokenized_target_texts = wmt_json_loader.load_and_tokenize(
    'F:\\wmt19_json\\wmt_19_de_en.json')
src = tokenized_source_texts
trgt = tokenized_target_texts
vocab_size = len(wmt_json_loader.tokenizer.get_vocab())
print(f"Vocabulary size: {vocab_size}")

Error when line is decoded: Expecting ',' delimiter: line 1 column 1109 (char 1108)
Error when line is decoded: Expecting ':' delimiter: line 1 column 401 (char 400)
Error when line is decoded: Expecting ',' delimiter: line 1 column 614 (char 613)
Error when line is decoded: Extra data: line 1 column 489 (char 488)
Error when line is decoded: Expecting ',' delimiter: line 1 column 462 (char 461)
Error when line is decoded: Expecting ':' delimiter: line 1 column 784 (char 783)
Error when line is decoded: Expecting ',' delimiter: line 1 column 186 (char 185)
Error when line is decoded: Expecting value: line 1 column 1 (char 0)
Error when line is decoded: Expecting ',' delimiter: line 1 column 390 (char 389)
Error when line is decoded: Expecting ':' delimiter: line 1 column 274 (char 273)
Error when line is decoded: Expecting ',' delimiter: line 1 column 227 (char 226)
Error when line is decoded: Extra data: line 1 column 314 (char 313)
Error when line is decoded: Expecting value: line 1 

In [46]:
print(src[:1])

[tensor([ 90, 108, 104, 103, 104, 117, 100, 120, 105, 113, 100, 107, 112, 104,
         35, 103, 104, 117,  35,  86, 108, 119, 125, 120, 113, 106, 118, 115,
        104, 117, 108, 114, 103, 104,   1,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0])]


In [30]:
print(trgt[:500])

[tensor([ 85, 104, 118, 120, 112, 115, 119, 108, 114, 113,  35, 114, 105,  35,
        119, 107, 104,  35, 118, 104, 118, 118, 108, 114, 113,   1,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0]), tensor([ 76,  35, 103, 104, 102, 111, 100, 117, 104,  35, 117, 104, 118, 120,
        112, 104, 103,  35, 119, 107, 104,  35, 118, 104, 118, 118, 108, 114,
        113,  35, 114, 105,  35, 119, 107, 104,  35,  72, 120, 117, 114, 115,
        104, 100, 113,  35,  83, 100, 117, 

In [117]:
translation_dataset = TranslationDataset(tokenized_source_texts, tokenized_target_texts)
dataset_size = len(translation_dataset)
train_size = int(0.8 * dataset_size)
test_size = dataset_size - train_size
train_dataset, test_dataset = torch.utils.data.random_split(translation_dataset, [train_size, test_size])
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

In [124]:
criterion = torch.nn.CrossEntropyLoss()
inputEmbedding = InputEmbeddingTensor(vocab_size, embed_size)
# size and all params according to the paper, reduce for performance
encoder_decoder = EncoderDecoderStacking(n_sets=num_sets, blocks_per_set=set_size, output_channels=vocab_size,emb_size=embed_size).to(
    device)

# Define a loss function and an optimizer
# When changing Loss function, make sure to check if the decoder should have the softmax layer, and adjust that
optimizer = torch.optim.Adam(encoder_decoder.parameters(), lr=0.0003)  # replace with your actual optimizer
# Number of epochs
num_epochs = 1


In [125]:
 # Train the model loop
 for epoch in range(1):
    for i, (inputs, targets) in tqdm(enumerate(train_loader), total=len(train_loader)):
        # Move data to the appropriate device
        # inputs = inputEmbedding.embed(inputs.to(device))  # Add batch dimension
        inputs = inputs.to(device)  # Add batch
        targets = targets.to(device)  # Add batch

        outputs = encoder_decoder(inputs.to(device))
        # Compute loss
        loss = criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Print loss every 100 steps
        if i % 25 == 0:
            tqdm.write(
                f'Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(train_loader)}], Loss: {loss.item()}')

  0%|          | 1/18391 [00:18<92:45:42, 18.16s/it]

Epoch [1/1], Step [1/18391], Loss: 6.1425862312316895


  0%|          | 1/18391 [00:52<270:09:16, 52.89s/it]


KeyboardInterrupt: 