# Bytenet
Implementation of the paper Neural Machine Translation in Linear Time (https://arxiv.org/pdf/1610.10099). The Bynet is a CNN based Encoder/Decoder Model used here for Sequence to Sequence Translation. The model is trained on a part of the WMT2014 english to german dataset.

### Data Preprocessing
Data is loaded and tokenized in this step.

In [3]:
# Define Imports
import json

import requests
import torch
import pickle
import torch.utils.data as data
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer
from datasets import load_dataset
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from transformers import AutoTokenizer


In [6]:
import requests


class WMT19JSONLoader:
    def __init__(self, file_path, source_lang='de', target_lang='en', max_length=128):
        self.source_lang = source_lang
        self.target_lang = target_lang
        self.max_length = max_length
        # self.tokenizer = BertTokenizer.from_pretrained("bert-base-multilingual-cased")
        self.file_path = file_path
        self.tokenizer = AutoTokenizer.from_pretrained("google/byt5-small")

    def load_json_data(self, file_path):
        """
        Function that loads the downloaded JSON file

        :param file_path:
        :return:
        """
        loaded_data = []
        with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
            for line in f:
                try:
                    loaded_data.append(json.loads(line.strip()))
                except json.JSONDecodeError as e:
                    print(f"Error when line is decoded: {e}")
        return loaded_data

    def convert_to_tensor(self, src, trg):
        """
        Checks if source and target are tensor
        If both are not tensor, they are converted to tensors

        :param src:
        :param trg:
        :return:
        """
        if not torch.is_tensor(src):
            src = torch.Tensor(src)
        if not torch.is_tensor(trg):
            trg = torch.tensor(trg, dtype=torch.int32)
        return src, trg

    def extract_source_target(self, load_data):
        """
        Function that extracts out of the downloaded JSON the
        german rows as source and the english rows as targets

        :param load_data:
        :param source_lang:
        :param target_lang:
        :return:
        """
        source_texts = []
        target_texts = []
        for item in load_data:
            if ('row' in item and 'translation' in item['row'] and
                    self.source_lang in item['row']['translation'] and
                    self.target_lang in item['row']['translation']):
                source_texts.append(item['row']['translation'][self.source_lang])
                target_texts.append(item['row']['translation'][self.target_lang])
        return source_texts, target_texts

    def tokenize_texts(self, texts):
        """
        Function for tokenizing the text data
        Uses BERT-Tokenizer as tokenizer model

        :param texts:
        :return:
        """
        tokenized_texts = []
        for text in texts:
            tokens = self.tokenizer(text, max_length=self.max_length, padding="max_length", truncation=True, return_tensors="pt")
            
            tokenized_texts.append(tokens['input_ids'].squeeze())
        return tokenized_texts

    def load_and_tokenize(self, json_file_path):
        """
        Function that does the load json data
        and the tokenizing process

        :param json_file_path:
        """
        loaded_data = self.load_json_data(json_file_path)

        source_texts, target_texts = self.extract_source_target(loaded_data)

        # The tokenized source and targets
        # self.tokenizer is a object of type transformers from the Bert model
        # padding="max_length": is used to fill sequence to maximal length
        # truncation = True: Means that the sequence is cutted, if longer than max_length
        # return_tensors="pt": Means that a pytorch tensor is returned
        # the source text is tokenized into smaller elements
        tokenized_source_texts = self.tokenize_texts(source_texts)

        # the target text is tokenized into smaller elements
        tokenized_target_texts = self.tokenize_texts(target_texts)

        #TODO: evetually squeeze as in WMTLoader

        return tokenized_source_texts, tokenized_target_texts


def download_data(offset, length):
    """
    Method for downloading the dataset as JSON
    F.e. if the first 10 rows have to be downloaded, offset has to
    be 0 and length has to be 10

    :param offset: The offset used in the url
    :param length: The length of the selected number of rows in the dataset
    :return:
    """
    url = f"https://datasets-server.huggingface.co/rows?dataset=wmt%2Fwmt19&config=de-en&split=train&offset={offset}&length={length}"
    query_parameters = {"downloadformat": "json"}
    response = requests.get(url, params=query_parameters)
    if response.status_code == 200:
        loaded_data = response.json()
        print(f"Downloading dataset-offset: {offset}")
        return loaded_data['rows']
    else:
        print(f"Error while downloading data: {response.status_code}")
        return []


def save_data_to_json(load_data, file_path):
    """
    Writes data into the JSON object

    :param load_data: The data that has to be writen into file
    :param file_path: The file path where the file has to be saved
    """
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    with open(file_path, 'a', encoding='utf-8') as f:
        for item in load_data:
            json.dump(item, f, ensure_ascii=False)
            f.write('\n')


def download_batch_and_save(offset, length, output_file):
    """
    Downloads and saves the batch

    :param offset: The offset which is currently used to download
    :param length: The length is defined with 100
    :param output_file: The name of the file to be saved
    """
    loaded_data = download_data(offset, length)
    save_data_to_json(loaded_data, output_file)


def download_entire_de_en_dataset(batch_size, output_dir, num_workers):
    """
    Downloads the entire WMT19 dataset. Uses a ThreadPoolExecutor for
    faster download of the dataset.

    :param batch_size:
    :param output_dir:
    :param num_workers:
    """
    offset = 0
    output_file = os.path.join(output_dir, 'wmt_19_de_en.json')
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        offset = 0
        futures = []
        while True:
            futures.append(executor.submit(download_batch_and_save, offset, batch_size, output_file))
            offset += batch_size
            # if offset >= 34800000:
            # This controls how much of the dataset is actually downloaded
            if offset >= 50000:
                break

        for future in as_completed(futures):
            future.result()


### ByteNet Model
The following Cells implement the necessary parts of the ByteNet model. The model is made up of a number of sets of, each of which contains a number of residual blocks that apply LayerNorm and 1D Convolutions, which is further masked for the Decoder part of the Network.

In [7]:
# Imports for ByteNet
import torch
from torch.utils.data import Dataset
from torch import nn
from torch.nn import init
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
from data.data_loader import WMTLoader, WMT19JSONLoader, download_entire_de_en_dataset


In [8]:
class ResidualBlockReLu(nn.Module):
    """
    Implementation of residual Layer for Bytenet machine translation task.

    :param d: The number of input features.
    :param dilation: The initial dilation rate for the convolution layers.
    """

    def __init__(self, d, dilation, k, decoder=False):
        super(ResidualBlockReLu, self).__init__()
        self.decoder = decoder
        self.layer_norm1 = nn.LayerNorm(128)
        self.reLu1 = nn.ReLU()
        # 2*d -> d
        self.conv1 = nn.Conv1d(d * 2, d, 1)
        self.layer_norm2 = nn.LayerNorm(128)
        self.reLu2 = nn.ReLU()
        # Masked kernel size is k
        # Dilation only used for masked convolution
        # d -> d
        if decoder:
            # Masked convolution basically means all padding on left side
            self.receptive_field = (k - 1) * dilation
            self.conv2 = nn.Conv1d(d, d, k, dilation=dilation)
        else:
            # Padding still needed to keep the size of the input and output the same
            padding = (k - 1) * dilation // 2
            if padding > 0:
                self.conv2 = nn.Conv1d(d, d, k, dilation=dilation, padding=padding)
            else:
                self.conv2 = nn.Conv1d(d, d, k, dilation=dilation)
        self.layer_norm3 = nn.LayerNorm(128)
        self.reLu3 = nn.ReLU()
        # d -> 2*d
        self.conv3 = nn.Conv1d(d, d * 2, 1)

    def forward(self, x):
        residual = x
        x = self.layer_norm1(x)
        x = self.reLu1(x)
        x = self.conv1(x)
        x = self.layer_norm2(x)
        x = self.reLu2(x)
        # When Decoder is used, the convolution is causal
        if self.decoder and self.receptive_field > 0:
            x = F.pad(x, (self.receptive_field, 0))
        x = self.conv2(x)
        x = self.layer_norm3(x)
        x = self.reLu3(x)
        x = self.conv3(x)
        # Add back the residual
        x += residual
        return x


In [9]:
class BytenetEncoder(nn.Module):
    """
    Implementation of the ByteNet Encoder. Default Parameters are set to the ones used in the paper.
    
    :param kernel_size: The kernel size for the unmasked (padded) convolution in the residual block.
    :param max_dilation_rate: The maximum dilation rate for the convolution layers.
    :param masked_kernel_size: The kernel size for the masked convolution in the residual block (only interesting for decoder).
    :param num_sets: The number of sets of residual blocks.
    :param set_size: The number of residual blocks in each set.
    :param hidden_channels: The number of hidden channels in the model.
    """
    def __init__(self, kernel_size=3, max_dilation_rate=16, masked_kernel_size=3, num_sets=6, set_size=5,
                 hidden_channels=800, emb_size = 1600):
        super(BytenetEncoder, self).__init__()
        self.num_channels = hidden_channels
        self.kernel_size = kernel_size
        self.layers = nn.Sequential()
        # 128 is size of tokenizer
        # input of shape [batch_size, 128, 128] as [batch_size, tokens, embedding_size]
        self.layers.append(nn.Conv1d(in_channels=emb_size, out_channels=hidden_channels * 2, kernel_size=1))
        # From the Paper:
        # Model has a series of residual blocks of increased dilation rate
        # With unmasked convolutions for the encoder
        for _ in range(num_sets):
            dilation_rate = 1
            for _ in range(set_size):
                # Dilation rate does not exceed a given maximum
                # Example from the paper: 16
                self.layers.append(ResidualBlockReLu(hidden_channels,
                                                     dilation_rate if dilation_rate <= max_dilation_rate else max_dilation_rate,
                                                     masked_kernel_size))
                                # Dilation Rate doubles each layer (starting out at 1)
                dilation_rate = dilation_rate * 2

            # "the network applies one more convolution"
        # Note: The output of the residual layers is 2*input_features, however the output of the final convolutions is not specified in the paper
        # Experimentation needed if it should be 2*input_features or input_features
        self.encoder_out_conv = nn.Conv1d(in_channels=hidden_channels * 2, out_channels=2 * hidden_channels, kernel_size=1)
        # "and ReLU"
        # Not sure if these last 2 layers should be in encoder or just decoder
        # self.layers.append(nn.ReLU())
        # "followed by a convolution"
        # self.layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size))
        # "and a final softmax layer" (probably not for encoder, however paper does not specify)
        # self.layers.append(nn.Softmax(dim=1))

    def forward(self, x):
        # Temporary
        x = x.float()
        for layer in self.layers:
            x = layer(x)
        x = self.encoder_out_conv(x)
        return x


In [10]:
class BytenetDecoder(nn.Module):
    """
    Implementation of the ByteNet Decoder. Default Parameters are set to the ones used in the paper.
    
    :param kernel_size: The kernel size for the unmasked (padded) convolution in the residual block (not important for decoder).
    :param max_dilation_rate: The maximum dilation rate for the convolution layers.
    :param masked_kernel_size: The kernel size for the masked convolution in the residual block.
    :param num_sets: The number of sets of residual blocks.
    :param set_size: The number of residual blocks in each set.
    :param hidden_channels: The number of hidden channels in the model.
    """
    def __init__(self, kernel_size=3, max_dilation_rate=16, masked_kernel_size=3, num_sets=6, set_size=5,
                 hidden_channels=800, output_channels=384):
        super(BytenetDecoder, self).__init__()
        self.num_channels = hidden_channels
        self.kernel_size = kernel_size
        self.layers = nn.Sequential()
        # From the Paper:
        # Model has a series of residual blocks of increased dilation rate
        # With masekd convolution for decoder
        for _ in range(num_sets):
            dilation_rate = 1
            for _ in range(set_size):
                # Dilation Rate doubles each layer (starting out at 1)
                # 1, 2, 4, 8, 16
                # Dilation rate does not exceed a given maximum
                # Example from the paper: 16
                self.layers.append(ResidualBlockReLu(hidden_channels,
                                                     dilation_rate if dilation_rate <= max_dilation_rate else max_dilation_rate,
                                                     masked_kernel_size, decoder=True))
                dilation_rate = dilation_rate * 2

        # "the network applies one more convolution"
        # Note: The output of the residual layers is 2*input_features, however the output of the final convolutions is not specified in the paper
        # Experimentation needed if it should be 2*input_features or input_features
        self.layers.append(nn.Conv1d(hidden_channels * 2, hidden_channels, 1))
        # "and ReLU"
        self.layers.append(nn.ReLU())
        # "followed by a convolution"
        self.layers.append(nn.Conv1d(hidden_channels, output_channels, 1))
        # "and a final softmax layer"
        # self.layers.append(nn.LogSoftmax(dim=-1))

        # self.layers.append(nn.Softmax(dim=1))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


In [11]:
class EncoderDecoderStacking(nn.Module):
    """
    Stacks the encoder and decoder for the ByteNet model.
    This means passing the output of the encoder as input to the decoder.
    
    :param kernel_size: The kernel size for the unmasked (padded) convolution in the residual block (for Encoder).
    :param max_dilation_rate: The maximum dilation rate for the convolution layers.
    :param masked_kernel_size: The kernel size for the masked convolution in the residual block (for Decoder).
    :param num_sets: The number of sets of residual blocks.
    :param set_size: The number of residual blocks in each set.
    :param hidden_channels: The number of hidden channels in the model.
    :param output_channels: The number of output channels in the model (vocab size).

    :return x: The output of the decoder.
    """

    def __init__(self, kernel_size=3, max_dilation_rate=16, masked_kernel_size=3, n_sets=6, blocks_per_set=5,
                 hidden_channels=800, output_channels = 384, emb_size= 1600):
        super(EncoderDecoderStacking, self).__init__()
        self.embed = nn.Embedding(vocab_size, emb_size)
        self.encoder = BytenetEncoder(kernel_size=kernel_size, max_dilation_rate=max_dilation_rate,
                                      masked_kernel_size=masked_kernel_size, num_sets=n_sets, set_size=blocks_per_set,
                                      hidden_channels=hidden_channels, emb_size=emb_size)
        self.decoder = BytenetDecoder(kernel_size=kernel_size, max_dilation_rate=max_dilation_rate,
                                      masked_kernel_size=masked_kernel_size, num_sets=n_sets, set_size=blocks_per_set,
                                      hidden_channels=hidden_channels, output_channels=output_channels)

    def forward(self, x):
        # This permutation is needed for embeddings in pytorch with 1d convolutions
        embed_x = self.embed(x).permute(0, 2, 1)
        x = self.encoder(embed_x)
        x = self.decoder(x)
        return x


In [12]:
class InputEmbeddingTensor:
    """
    Class which enables the embedding of tokens.

    :param vocab_size: The size of the vocabulary as int.
    :param embed_size: The size of the embedding units as int.
    """

    def __init__(self, vocab_size, embed_size):
        super(InputEmbeddingTensor, self).__init__()
        self.vocab_size = vocab_size
        self.embed_size = embed_size
        # This is the actual lookup table.
        # A lookup table is an array of data that maps input values to output values
        self.lookup_table_non_zero = nn.Embedding(vocab_size - 1, embed_size)
        init.xavier_uniform_(self.lookup_table_non_zero.weight)

    def embed(self, in_values):
        """
        In this method the first n tokens are embedded via look-up table.
        The n tokens serve as targets for the predictions.

        :param in_values: The train input values from batch, more exact: the tokens
        :return: A embedded tensor of size n × 2d where d is the number of inner
                channels in the network
        """
        lookup_table_zero = torch.zeros(1, self.embed_size).to(in_values.device)
        # Here the both look up tables are combined. The rows with the zeros and the rows
        # with values from the actual lookup table are combined therefore
        lookup_table = torch.cat((lookup_table_zero, self.lookup_table_non_zero.weight.to(device)),
                                 0)  # Move to the same device as inputs
        # Next the input ids are embedded into the lookup table, which means that each id has it own
        # embedding-vector, f.e:
        # id: 5 => [1,5,4]; id:7 => [3,2,9]
        # The input ids are the tokens
        # If a token sequence of 5;7 is used, the resulting matrix is:
        # [1,5,4],[3,2,9]
        return F.embedding(in_values, lookup_table).to(in_values.device)


### Training
The following cells implement the training of the ByteNet model. The model is trained on a part of the WMT2014 english to german dataset.

In [10]:
# Load the data
device = 'cuda' if torch.cuda.is_available() else 'cpu'
cache_dir = 'F:/wmt19_cache'
#    wmt_loader = WMTLoader(split="train", cache_dir=cache_dir)
# Number of workers provides parallel loading
num_workers = 8
#    data_load = DataLoader(wmt_loader, batch_size=32, collate_fn=wmt_loader.collate_fn, num_workers=num_workers)
#    temp = data_load
#
# for batch in wmt_loader:
#     src_batch, tgt_batch = batch
#     break
batch_size = 64
# change as needed
output_dir = 'F:\\wmt19_json'
download_entire_de_en_dataset(batch_size, output_dir, 4)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using device: {device}')
wmt_json_loader = WMT19JSONLoader(output_dir)


Downloading dataset-offset: 0
Downloading dataset-offset: 128
Downloading dataset-offset: 192
Downloading dataset-offset: 64
Downloading dataset-offset: 256
Downloading dataset-offset: 320
Using device: cuda


In [11]:
print(wmt_json_loader.tokenizer.get_vocab())

{'<pad>': 0, '</s>': 1, '<unk>': 2, '\x00': 3, '\x01': 4, '\x02': 5, '\x03': 6, '\x04': 7, '\x05': 8, '\x06': 9, '\x07': 10, '\x08': 11, '\t': 12, '\n': 13, '\x0b': 14, '\x0c': 15, '\r': 16, '\x0e': 17, '\x0f': 18, '\x10': 19, '\x11': 20, '\x12': 21, '\x13': 22, '\x14': 23, '\x15': 24, '\x16': 25, '\x17': 26, '\x18': 27, '\x19': 28, '\x1a': 29, '\x1b': 30, '\x1c': 31, '\x1d': 32, '\x1e': 33, '\x1f': 34, ' ': 35, '!': 36, '"': 37, '#': 38, '$': 39, '%': 40, '&': 41, "'": 42, '(': 43, ')': 44, '*': 45, '+': 46, ',': 47, '-': 48, '.': 49, '/': 50, '0': 51, '1': 52, '2': 53, '3': 54, '4': 55, '5': 56, '6': 57, '7': 58, '8': 59, '9': 60, ':': 61, ';': 62, '<': 63, '=': 64, '>': 65, '?': 66, '@': 67, 'A': 68, 'B': 69, 'C': 70, 'D': 71, 'E': 72, 'F': 73, 'G': 74, 'H': 75, 'I': 76, 'J': 77, 'K': 78, 'L': 79, 'M': 80, 'N': 81, 'O': 82, 'P': 83, 'Q': 84, 'R': 85, 'S': 86, 'T': 87, 'U': 88, 'V': 89, 'W': 90, 'X': 91, 'Y': 92, 'Z': 93, '[': 94, '\\': 95, ']': 96, '^': 97, '_': 98, '`': 99, 'a': 10

In [12]:
# HYPERPARAMETERS
num_sets = 3
set_size = 5
embed_size = 1600 # Paper
batch_size = 64

In [13]:
class TranslationDataset(Dataset):
    def __init__(self, source_texts, target_texts):
        self.source_texts = source_texts
        self.target_texts = target_texts

    def __len__(self):
        return len(self.source_texts)

    def __getitem__(self, idx):
        return self.source_texts[idx], self.target_texts[idx]


In [17]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
output_dir = 'F:\\wmt19_json'

print(f'Using device: {device}')
wmt_json_loader = WMT19JSONLoader(output_dir)

cache_dir = 'F:/wmt19_cache'
# wmt_loader = WMTLoader(split="train", cache_dir=cache_dir)
# index = 0
# source, target = wmt_loader[index]
# print("Source:", source)
# print("Target:", target)

# use drive in which to save dataset in cache
tokenized_source_texts, tokenized_target_texts = wmt_json_loader.load_and_tokenize(
    'F:\\wmt19_json\\wmt_19_de_en.json')
src = tokenized_source_texts
trgt = tokenized_target_texts
vocab_size = len(wmt_json_loader.tokenizer.get_vocab())
print(f"Vocabulary size: {vocab_size}")

Using device: cuda
Error when line is decoded: Expecting ',' delimiter: line 1 column 228 (char 227)
Error when line is decoded: Expecting value: line 1 column 1 (char 0)
Error when line is decoded: Expecting ',' delimiter: line 1 column 619 (char 618)
Error when line is decoded: Extra data: line 1 column 2 (char 1)
Error when line is decoded: Expecting value: line 1 column 1 (char 0)
Error when line is decoded: Expecting value: line 1 column 1 (char 0)
Error when line is decoded: Expecting ',' delimiter: line 1 column 745 (char 744)
Error when line is decoded: Expecting value: line 1 column 1 (char 0)
Error when line is decoded: Expecting ',' delimiter: line 1 column 257 (char 256)
Error when line is decoded: Expecting ':' delimiter: line 1 column 138 (char 137)
Error when line is decoded: Expecting value: line 1 column 1 (char 0)
Error when line is decoded: Expecting ',' delimiter: line 1 column 550 (char 549)
Error when line is decoded: Expecting value: line 1 column 1 (char 0)
Erro

In [15]:
print(src[:1])

[tensor([ 75, 108, 104, 117,  35, 107, 100, 119,  35, 113, 108, 102, 107, 119,
         35, 104, 119, 122, 100,  35, 103, 100, 118,  35,  83, 100, 117, 111,
        100, 112, 104, 113, 119,  35, 103, 108, 104,  35,  83, 117, 108, 113,
        125, 108, 115, 108, 104, 113,  35, 103, 104, 117,  35,  87, 117, 100,
        113, 118, 115, 100, 117, 104, 113, 125,  35, 120, 113, 103,  35,  82,
        105, 105, 104, 113, 107, 104, 108, 119,  35, 108, 113,  35, 103, 108,
        104,  35,  87, 100, 119,  35, 120, 112, 106, 104, 118, 104, 119, 125,
        119,  47,  35, 118, 114, 113, 103, 104, 117, 113,  35, 104, 118,  35,
        122, 120, 117, 103, 104, 113,  35, 105, 198, 191, 117,  35, 104, 108,
        113,   1])]


In [16]:
print(trgt[:500])

[tensor([ 87, 107, 108, 118,  35, 122, 100, 118,  35, 113, 114, 119,  35, 119,
        107, 104,  35, 117, 104, 118, 120, 111, 119,  35, 114, 105,  35,  83,
        100, 117, 111, 108, 100, 112, 104, 113, 119,  35, 115, 120, 119, 119,
        108, 113, 106,  35, 108, 113, 119, 114,  35, 115, 117, 100, 102, 119,
        108, 102, 104,  35, 119, 107, 104,  35, 115, 117, 108, 113, 102, 108,
        115, 111, 104, 118,  35, 114, 105,  35, 119, 117, 100, 113, 118, 115,
        100, 117, 104, 113, 102, 124,  35, 100, 113, 103,  35, 114, 115, 104,
        113, 113, 104, 118, 118,  47,  35, 101, 120, 119,  35, 105, 114, 111,
        111, 114, 122, 118,  35, 119, 107, 104,  35, 102, 114, 115, 124, 108,
        113,   1]), tensor([ 81, 114, 122,  35, 119, 107, 100, 119,  35, 119, 107, 104,  35, 117,
        104, 102, 114, 117, 103, 118,  35, 107, 100, 121, 104,  35, 101, 104,
        104, 113,  35, 115, 120, 119,  35, 108, 113, 119, 114,  35, 119, 107,
        104,  35, 115, 120, 101, 111, 108, 

In [1]:
# Idea for what could be used for unfolding, however even with batch size 1, it does not work as I run ot of memory so I could not test it
def decode(decoder, context_vector):
    batch_size = context_vector.shape[0]

    output_sequence = []
    end_of_sequence_mask = torch.zeros((batch_size, 384), dtype=torch.bool, device=context_vector.device)

    for _ in range(128):
        if len(output_sequence) == 0:  # for the first step, use context vector
            decoder_input = context_vector
        else:  # for subsequent steps, use the last output token
            input_token = output_sequence[-1]
            input_token_embedded = encoder_decoder.embed(input_token).permute(0, 2, 1)
            decoder_input = torch.cat([context_vector, input_token_embedded], dim=-1)

        output_token = decoder(decoder_input)
        predicted_token = torch.argmax(output_token, dim=-1)
        end_of_sequence_mask |= (predicted_token.squeeze() == 1)
        predicted_token[end_of_sequence_mask] = 0

        output_sequence.append(predicted_token.squeeze())

    return torch.stack(output_sequence, dim=1)[:, 1:]

In [2]:
# LayerNorm that can take changing input sequence lengths
class DynamicLayerNorm(nn.Module):
    def __init__(self):
        super(DynamicLayerNorm, self).__init__()
        self.norm = None

    def forward(self, x):
        if self.norm is None or self.norm.normalized_shape[0] != x.shape[2]:
            self.norm = nn.LayerNorm(x.shape[2]).to(x.device)
        return self.norm(x)

NameError: name 'nn' is not defined

In [17]:
translation_dataset = TranslationDataset(tokenized_source_texts, tokenized_target_texts)
dataset_size = len(translation_dataset)
train_size = int(0.9 * dataset_size)
test_size = dataset_size - train_size
train_dataset, test_dataset = torch.utils.data.random_split(translation_dataset, [train_size, test_size])
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

In [18]:
criterion = torch.nn.CrossEntropyLoss()
inputEmbedding = InputEmbeddingTensor(vocab_size, embed_size)
# size and all params according to the paper, reduce for performance
encoder_decoder = EncoderDecoderStacking(n_sets=3, blocks_per_set=5, output_channels=vocab_size,emb_size=embed_size).to(
    device)

# Define a loss function and an optimizer
# When changing Loss function, make sure to check if the decoder should have the softmax layer, and adjust that
optimizer = torch.optim.Adam(encoder_decoder.parameters(), lr=0.001)  #  Paper: 0.0003
# Number of epochs
num_epochs = 5


In [19]:
!pip install tensorboard




[notice] A new release of pip is available: 23.2.1 -> 24.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [20]:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

In [21]:
 # Train the model loop
 for epoch in range(num_epochs):
    encoder_decoder.train()
    for i, (inputs, targets) in tqdm(enumerate(train_loader), total=len(train_loader)):
        # Move data to the appropriate device
        # inputs = inputEmbedding.embed(inputs.to(device))  # Ad batch dimension
        inputs = inputs.to(device)  # Add batch
        targets = targets.to(device)  # Add batch

        outputs = encoder_decoder(inputs.to(device))
        # Compute loss
        loss = criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Print loss every 100 steps
        if i % 50 == 0:
            tqdm.write(
            f'Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(train_loader)}], Loss: {loss.item()}')
            writer.add_scalar('Loss/train', loss.item(), (i+1)/len(train_loader))
    encoder_decoder.eval()
    total_val_loss = 0
    with torch.inference_mode():
        for i, (inputs, targets) in enumerate(test_loader):
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = encoder_decoder(inputs)
            loss = criterion(outputs, targets)
            total_val_loss += loss.item()
        avg_val_loss = total_val_loss / len(test_loader)
        tqdm.write(f'Validation Loss: {avg_val_loss}')
        writer.add_scalar('Loss/val', avg_val_loss, epoch)
        

  0%|          | 1/3551 [00:00<50:37,  1.17it/s]

Epoch [1/5], Step [1/3551], Loss: 5.874279499053955


  1%|▏         | 51/3551 [00:27<31:52,  1.83it/s]

Epoch [1/5], Step [51/3551], Loss: 2.6860663890838623


  3%|▎         | 101/3551 [00:54<32:02,  1.79it/s]

Epoch [1/5], Step [101/3551], Loss: 2.603867530822754


  4%|▍         | 151/3551 [01:21<30:57,  1.83it/s]

Epoch [1/5], Step [151/3551], Loss: 2.613537311553955


  6%|▌         | 201/3551 [01:48<30:48,  1.81it/s]

Epoch [1/5], Step [201/3551], Loss: 2.632150888442993


  7%|▋         | 251/3551 [02:16<30:43,  1.79it/s]

Epoch [1/5], Step [251/3551], Loss: 2.5785038471221924


  8%|▊         | 301/3551 [02:43<29:21,  1.85it/s]

Epoch [1/5], Step [301/3551], Loss: 2.506746768951416


 10%|▉         | 351/3551 [03:10<29:10,  1.83it/s]

Epoch [1/5], Step [351/3551], Loss: 2.5559065341949463


 11%|█▏        | 401/3551 [03:38<30:04,  1.75it/s]

Epoch [1/5], Step [401/3551], Loss: 2.5895378589630127


 13%|█▎        | 451/3551 [04:05<28:00,  1.84it/s]

Epoch [1/5], Step [451/3551], Loss: 2.6078150272369385


 14%|█▍        | 501/3551 [04:32<27:44,  1.83it/s]

Epoch [1/5], Step [501/3551], Loss: 2.4895026683807373


 16%|█▌        | 551/3551 [04:59<27:07,  1.84it/s]

Epoch [1/5], Step [551/3551], Loss: 2.643693685531616


 17%|█▋        | 601/3551 [05:26<26:43,  1.84it/s]

Epoch [1/5], Step [601/3551], Loss: 2.5177292823791504


 18%|█▊        | 651/3551 [05:53<26:14,  1.84it/s]

Epoch [1/5], Step [651/3551], Loss: 2.287750244140625


 20%|█▉        | 701/3551 [06:20<25:55,  1.83it/s]

Epoch [1/5], Step [701/3551], Loss: 2.370805025100708


 21%|██        | 751/3551 [06:47<26:06,  1.79it/s]

Epoch [1/5], Step [751/3551], Loss: 2.5987722873687744


 23%|██▎       | 801/3551 [07:14<24:49,  1.85it/s]

Epoch [1/5], Step [801/3551], Loss: 2.4467763900756836


 24%|██▍       | 851/3551 [07:41<24:21,  1.85it/s]

Epoch [1/5], Step [851/3551], Loss: 2.535731554031372


 25%|██▌       | 901/3551 [08:08<24:18,  1.82it/s]

Epoch [1/5], Step [901/3551], Loss: 2.4842069149017334


 27%|██▋       | 951/3551 [08:35<23:50,  1.82it/s]

Epoch [1/5], Step [951/3551], Loss: 2.5968689918518066


 28%|██▊       | 1001/3551 [09:02<23:06,  1.84it/s]

Epoch [1/5], Step [1001/3551], Loss: 2.5029633045196533


 30%|██▉       | 1051/3551 [09:29<22:57,  1.81it/s]

Epoch [1/5], Step [1051/3551], Loss: 2.53466796875


 31%|███       | 1101/3551 [09:57<22:24,  1.82it/s]

Epoch [1/5], Step [1101/3551], Loss: 2.325892686843872


 32%|███▏      | 1151/3551 [10:24<21:59,  1.82it/s]

Epoch [1/5], Step [1151/3551], Loss: 2.563166856765747


 34%|███▍      | 1201/3551 [10:51<21:32,  1.82it/s]

Epoch [1/5], Step [1201/3551], Loss: 2.6073241233825684


 35%|███▌      | 1251/3551 [11:18<20:51,  1.84it/s]

Epoch [1/5], Step [1251/3551], Loss: 2.6762595176696777


 37%|███▋      | 1301/3551 [11:46<20:53,  1.79it/s]

Epoch [1/5], Step [1301/3551], Loss: 2.4904122352600098


 38%|███▊      | 1351/3551 [12:14<20:04,  1.83it/s]

Epoch [1/5], Step [1351/3551], Loss: 2.528141736984253


 39%|███▉      | 1401/3551 [12:41<19:56,  1.80it/s]

Epoch [1/5], Step [1401/3551], Loss: 2.4934821128845215


 41%|████      | 1451/3551 [13:09<19:12,  1.82it/s]

Epoch [1/5], Step [1451/3551], Loss: 2.5008108615875244


 42%|████▏     | 1501/3551 [13:36<18:42,  1.83it/s]

Epoch [1/5], Step [1501/3551], Loss: 2.545868158340454


 44%|████▎     | 1551/3551 [14:03<18:09,  1.84it/s]

Epoch [1/5], Step [1551/3551], Loss: 2.515972852706909


 45%|████▌     | 1601/3551 [14:31<17:53,  1.82it/s]

Epoch [1/5], Step [1601/3551], Loss: 2.555776357650757


 46%|████▋     | 1651/3551 [14:58<17:32,  1.81it/s]

Epoch [1/5], Step [1651/3551], Loss: 2.484743595123291


 48%|████▊     | 1701/3551 [15:25<17:46,  1.73it/s]

Epoch [1/5], Step [1701/3551], Loss: 2.583036184310913


 49%|████▉     | 1751/3551 [15:52<16:19,  1.84it/s]

Epoch [1/5], Step [1751/3551], Loss: 2.5414388179779053


 51%|█████     | 1801/3551 [16:19<15:55,  1.83it/s]

Epoch [1/5], Step [1801/3551], Loss: 2.5452709197998047


 52%|█████▏    | 1851/3551 [16:46<15:35,  1.82it/s]

Epoch [1/5], Step [1851/3551], Loss: 2.468043327331543


 54%|█████▎    | 1901/3551 [17:14<14:54,  1.85it/s]

Epoch [1/5], Step [1901/3551], Loss: 2.452723264694214


 55%|█████▍    | 1951/3551 [17:41<14:37,  1.82it/s]

Epoch [1/5], Step [1951/3551], Loss: 2.442843437194824


 56%|█████▋    | 2001/3551 [18:08<14:18,  1.80it/s]

Epoch [1/5], Step [2001/3551], Loss: 2.574573040008545


 58%|█████▊    | 2051/3551 [18:35<13:36,  1.84it/s]

Epoch [1/5], Step [2051/3551], Loss: 2.405710458755493


 59%|█████▉    | 2101/3551 [19:02<13:05,  1.85it/s]

Epoch [1/5], Step [2101/3551], Loss: 2.5057902336120605


 61%|██████    | 2151/3551 [19:29<13:14,  1.76it/s]

Epoch [1/5], Step [2151/3551], Loss: 2.4921913146972656


 62%|██████▏   | 2201/3551 [19:58<12:44,  1.77it/s]

Epoch [1/5], Step [2201/3551], Loss: 2.448138475418091


 63%|██████▎   | 2251/3551 [20:26<12:10,  1.78it/s]

Epoch [1/5], Step [2251/3551], Loss: 2.6758856773376465


 65%|██████▍   | 2301/3551 [20:53<11:50,  1.76it/s]

Epoch [1/5], Step [2301/3551], Loss: 2.5243256092071533


 66%|██████▌   | 2351/3551 [21:21<11:01,  1.81it/s]

Epoch [1/5], Step [2351/3551], Loss: 2.5018162727355957


 68%|██████▊   | 2401/3551 [21:48<10:35,  1.81it/s]

Epoch [1/5], Step [2401/3551], Loss: 2.46651029586792


 69%|██████▉   | 2451/3551 [22:16<10:03,  1.82it/s]

Epoch [1/5], Step [2451/3551], Loss: 2.4105265140533447


 70%|███████   | 2501/3551 [22:43<09:35,  1.82it/s]

Epoch [1/5], Step [2501/3551], Loss: 2.544501543045044


 72%|███████▏  | 2551/3551 [23:10<09:26,  1.76it/s]

Epoch [1/5], Step [2551/3551], Loss: 2.4084837436676025


 73%|███████▎  | 2601/3551 [23:38<08:46,  1.81it/s]

Epoch [1/5], Step [2601/3551], Loss: 2.440166711807251


 75%|███████▍  | 2651/3551 [24:06<08:12,  1.83it/s]

Epoch [1/5], Step [2651/3551], Loss: 2.3053905963897705


 76%|███████▌  | 2701/3551 [24:33<07:54,  1.79it/s]

Epoch [1/5], Step [2701/3551], Loss: 2.4461276531219482


 77%|███████▋  | 2751/3551 [25:01<07:41,  1.73it/s]

Epoch [1/5], Step [2751/3551], Loss: 2.500408411026001


 79%|███████▉  | 2801/3551 [25:29<07:07,  1.75it/s]

Epoch [1/5], Step [2801/3551], Loss: 2.38733172416687


 80%|████████  | 2851/3551 [25:57<06:38,  1.76it/s]

Epoch [1/5], Step [2851/3551], Loss: 2.4048938751220703


 82%|████████▏ | 2901/3551 [26:25<05:59,  1.81it/s]

Epoch [1/5], Step [2901/3551], Loss: 2.5767698287963867


 83%|████████▎ | 2951/3551 [26:52<05:34,  1.79it/s]

Epoch [1/5], Step [2951/3551], Loss: 2.513230800628662


 85%|████████▍ | 3001/3551 [27:20<05:13,  1.75it/s]

Epoch [1/5], Step [3001/3551], Loss: 2.5411903858184814


 86%|████████▌ | 3051/3551 [27:49<04:51,  1.72it/s]

Epoch [1/5], Step [3051/3551], Loss: 2.5570294857025146


 87%|████████▋ | 3101/3551 [28:17<04:14,  1.77it/s]

Epoch [1/5], Step [3101/3551], Loss: 2.4853034019470215


 89%|████████▊ | 3151/3551 [28:45<03:48,  1.75it/s]

Epoch [1/5], Step [3151/3551], Loss: 2.571009397506714


 90%|█████████ | 3201/3551 [29:14<03:18,  1.77it/s]

Epoch [1/5], Step [3201/3551], Loss: 2.468646764755249


 92%|█████████▏| 3251/3551 [29:41<02:44,  1.82it/s]

Epoch [1/5], Step [3251/3551], Loss: 2.5347537994384766


 93%|█████████▎| 3301/3551 [30:09<02:21,  1.77it/s]

Epoch [1/5], Step [3301/3551], Loss: 2.4041059017181396


 94%|█████████▍| 3351/3551 [30:37<01:50,  1.81it/s]

Epoch [1/5], Step [3351/3551], Loss: 2.6321582794189453


 96%|█████████▌| 3401/3551 [31:05<01:26,  1.73it/s]

Epoch [1/5], Step [3401/3551], Loss: 2.4252851009368896


 97%|█████████▋| 3451/3551 [31:32<00:55,  1.81it/s]

Epoch [1/5], Step [3451/3551], Loss: 2.3623440265655518


 99%|█████████▊| 3501/3551 [32:00<00:27,  1.81it/s]

Epoch [1/5], Step [3501/3551], Loss: 2.5065364837646484


100%|██████████| 3551/3551 [32:27<00:00,  1.82it/s]


Epoch [1/5], Step [3551/3551], Loss: 2.677116632461548
Validation Loss: 2.472609027427963


  0%|          | 1/3551 [00:00<32:45,  1.81it/s]

Epoch [2/5], Step [1/3551], Loss: 2.494256019592285


  1%|▏         | 51/3551 [00:27<32:28,  1.80it/s]

Epoch [2/5], Step [51/3551], Loss: 2.358346700668335


  3%|▎         | 101/3551 [00:55<32:17,  1.78it/s]

Epoch [2/5], Step [101/3551], Loss: 2.537179470062256


  4%|▍         | 151/3551 [01:23<32:36,  1.74it/s]

Epoch [2/5], Step [151/3551], Loss: 2.411990165710449


  6%|▌         | 201/3551 [01:52<44:00,  1.27it/s]

Epoch [2/5], Step [201/3551], Loss: 2.5170958042144775


  7%|▋         | 251/3551 [02:26<33:45,  1.63it/s]

Epoch [2/5], Step [251/3551], Loss: 2.349385976791382


  8%|▊         | 301/3551 [03:03<47:55,  1.13it/s]

Epoch [2/5], Step [301/3551], Loss: 2.609746217727661


 10%|▉         | 351/3551 [03:36<36:09,  1.47it/s]

Epoch [2/5], Step [351/3551], Loss: 2.505084753036499


 11%|█▏        | 401/3551 [04:08<32:39,  1.61it/s]

Epoch [2/5], Step [401/3551], Loss: 2.3499257564544678


 13%|█▎        | 451/3551 [04:41<33:46,  1.53it/s]

Epoch [2/5], Step [451/3551], Loss: 2.523761034011841


 14%|█▍        | 501/3551 [05:12<31:32,  1.61it/s]

Epoch [2/5], Step [501/3551], Loss: 2.4324424266815186


 16%|█▌        | 551/3551 [05:45<31:42,  1.58it/s]

Epoch [2/5], Step [551/3551], Loss: 2.548150062561035


 17%|█▋        | 601/3551 [06:21<39:36,  1.24it/s]

Epoch [2/5], Step [601/3551], Loss: 2.4298758506774902


 18%|█▊        | 651/3551 [06:58<32:40,  1.48it/s]

Epoch [2/5], Step [651/3551], Loss: 2.4512040615081787


 20%|█▉        | 701/3551 [07:30<31:46,  1.49it/s]

Epoch [2/5], Step [701/3551], Loss: 2.459383487701416


 21%|██        | 751/3551 [08:07<30:24,  1.53it/s]

Epoch [2/5], Step [751/3551], Loss: 2.520547389984131


 23%|██▎       | 801/3551 [08:48<29:29,  1.55it/s]

Epoch [2/5], Step [801/3551], Loss: 2.405019521713257


 24%|██▍       | 851/3551 [09:26<29:46,  1.51it/s]

Epoch [2/5], Step [851/3551], Loss: 2.6540169715881348


 25%|██▌       | 901/3551 [10:00<32:53,  1.34it/s]

Epoch [2/5], Step [901/3551], Loss: 2.552517890930176


 27%|██▋       | 951/3551 [10:42<50:24,  1.16s/it]

Epoch [2/5], Step [951/3551], Loss: 2.517550468444824


 28%|██▊       | 1001/3551 [11:24<43:41,  1.03s/it]

Epoch [2/5], Step [1001/3551], Loss: 2.4128611087799072


 30%|██▉       | 1051/3551 [12:09<25:42,  1.62it/s]

Epoch [2/5], Step [1051/3551], Loss: 2.464664936065674


 31%|███       | 1101/3551 [12:46<29:55,  1.36it/s]

Epoch [2/5], Step [1101/3551], Loss: 2.3766589164733887


 32%|███▏      | 1151/3551 [13:22<32:36,  1.23it/s]

Epoch [2/5], Step [1151/3551], Loss: 2.482229709625244


 34%|███▍      | 1201/3551 [13:56<25:26,  1.54it/s]

Epoch [2/5], Step [1201/3551], Loss: 2.357384204864502


 35%|███▌      | 1251/3551 [14:29<24:46,  1.55it/s]

Epoch [2/5], Step [1251/3551], Loss: 2.4759130477905273


 37%|███▋      | 1301/3551 [15:04<25:00,  1.50it/s]

Epoch [2/5], Step [1301/3551], Loss: 2.43741774559021


 38%|███▊      | 1351/3551 [15:37<23:59,  1.53it/s]

Epoch [2/5], Step [1351/3551], Loss: 2.4414618015289307


 39%|███▉      | 1401/3551 [16:13<23:22,  1.53it/s]

Epoch [2/5], Step [1401/3551], Loss: 2.3750083446502686


 41%|████      | 1451/3551 [16:47<23:45,  1.47it/s]

Epoch [2/5], Step [1451/3551], Loss: 2.3814964294433594


 42%|████▏     | 1501/3551 [17:22<26:31,  1.29it/s]

Epoch [2/5], Step [1501/3551], Loss: 2.440310478210449


 44%|████▎     | 1551/3551 [18:02<22:10,  1.50it/s]

Epoch [2/5], Step [1551/3551], Loss: 2.521256923675537


 45%|████▌     | 1601/3551 [18:36<20:39,  1.57it/s]

Epoch [2/5], Step [1601/3551], Loss: 2.5201971530914307


 46%|████▋     | 1651/3551 [19:16<20:50,  1.52it/s]

Epoch [2/5], Step [1651/3551], Loss: 2.3889944553375244


 48%|████▊     | 1701/3551 [19:56<31:14,  1.01s/it]

Epoch [2/5], Step [1701/3551], Loss: 2.515343189239502


 49%|████▉     | 1751/3551 [20:34<20:52,  1.44it/s]

Epoch [2/5], Step [1751/3551], Loss: 2.5665714740753174


 51%|█████     | 1801/3551 [21:08<22:44,  1.28it/s]

Epoch [2/5], Step [1801/3551], Loss: 2.5205440521240234


 52%|█████▏    | 1851/3551 [21:42<20:04,  1.41it/s]

Epoch [2/5], Step [1851/3551], Loss: 2.4301342964172363


 54%|█████▎    | 1901/3551 [22:17<17:26,  1.58it/s]

Epoch [2/5], Step [1901/3551], Loss: 2.518019199371338


 55%|█████▍    | 1951/3551 [22:52<19:33,  1.36it/s]

Epoch [2/5], Step [1951/3551], Loss: 2.3188838958740234


 56%|█████▋    | 2001/3551 [23:25<17:58,  1.44it/s]

Epoch [2/5], Step [2001/3551], Loss: 2.5274245738983154


 58%|█████▊    | 2051/3551 [24:02<24:31,  1.02it/s]

Epoch [2/5], Step [2051/3551], Loss: 2.5117993354797363


 59%|█████▉    | 2101/3551 [24:38<17:41,  1.37it/s]

Epoch [2/5], Step [2101/3551], Loss: 2.557192802429199


 61%|██████    | 2151/3551 [25:15<19:24,  1.20it/s]

Epoch [2/5], Step [2151/3551], Loss: 2.5969157218933105


 62%|██████▏   | 2201/3551 [25:48<14:07,  1.59it/s]

Epoch [2/5], Step [2201/3551], Loss: 2.4098777770996094


 63%|██████▎   | 2251/3551 [26:18<13:10,  1.65it/s]

Epoch [2/5], Step [2251/3551], Loss: 2.4969322681427


 65%|██████▍   | 2301/3551 [26:48<14:43,  1.42it/s]

Epoch [2/5], Step [2301/3551], Loss: 2.4259538650512695


 66%|██████▌   | 2351/3551 [27:18<11:49,  1.69it/s]

Epoch [2/5], Step [2351/3551], Loss: 2.3403661251068115


 68%|██████▊   | 2401/3551 [27:48<11:28,  1.67it/s]

Epoch [2/5], Step [2401/3551], Loss: 2.461550235748291


 69%|██████▉   | 2451/3551 [28:17<10:57,  1.67it/s]

Epoch [2/5], Step [2451/3551], Loss: 2.454007625579834


 70%|███████   | 2501/3551 [28:54<14:16,  1.23it/s]

Epoch [2/5], Step [2501/3551], Loss: 2.466019630432129


 72%|███████▏  | 2551/3551 [29:35<13:36,  1.22it/s]

Epoch [2/5], Step [2551/3551], Loss: 2.5419952869415283


 73%|███████▎  | 2601/3551 [30:16<13:18,  1.19it/s]

Epoch [2/5], Step [2601/3551], Loss: 2.454570770263672


 75%|███████▍  | 2651/3551 [30:59<13:01,  1.15it/s]

Epoch [2/5], Step [2651/3551], Loss: 2.3878135681152344


 76%|███████▌  | 2701/3551 [31:42<12:13,  1.16it/s]

Epoch [2/5], Step [2701/3551], Loss: 2.5399880409240723


 77%|███████▋  | 2751/3551 [32:24<11:32,  1.16it/s]

Epoch [2/5], Step [2751/3551], Loss: 2.4130282402038574


 79%|███████▉  | 2801/3551 [33:07<10:23,  1.20it/s]

Epoch [2/5], Step [2801/3551], Loss: 2.606778621673584


 80%|████████  | 2851/3551 [33:48<09:47,  1.19it/s]

Epoch [2/5], Step [2851/3551], Loss: 2.4339966773986816


 82%|████████▏ | 2901/3551 [34:29<09:22,  1.16it/s]

Epoch [2/5], Step [2901/3551], Loss: 2.5462300777435303


 83%|████████▎ | 2951/3551 [35:10<08:16,  1.21it/s]

Epoch [2/5], Step [2951/3551], Loss: 2.658470869064331


 85%|████████▍ | 3001/3551 [35:55<08:17,  1.11it/s]

Epoch [2/5], Step [3001/3551], Loss: 2.592412233352661


 86%|████████▌ | 3051/3551 [36:36<07:01,  1.19it/s]

Epoch [2/5], Step [3051/3551], Loss: 2.479116439819336


 87%|████████▋ | 3101/3551 [37:18<06:16,  1.20it/s]

Epoch [2/5], Step [3101/3551], Loss: 2.622445583343506


 89%|████████▊ | 3151/3551 [37:59<05:22,  1.24it/s]

Epoch [2/5], Step [3151/3551], Loss: 2.471534490585327


 90%|█████████ | 3201/3551 [38:41<04:59,  1.17it/s]

Epoch [2/5], Step [3201/3551], Loss: 2.3571715354919434


 92%|█████████▏| 3251/3551 [39:22<04:11,  1.19it/s]

Epoch [2/5], Step [3251/3551], Loss: 2.369288206100464


 93%|█████████▎| 3301/3551 [40:03<03:30,  1.19it/s]

Epoch [2/5], Step [3301/3551], Loss: 2.33292555809021


 94%|█████████▍| 3351/3551 [40:44<02:44,  1.21it/s]

Epoch [2/5], Step [3351/3551], Loss: 2.3339903354644775


 96%|█████████▌| 3401/3551 [41:24<01:58,  1.27it/s]

Epoch [2/5], Step [3401/3551], Loss: 2.448197841644287


 97%|█████████▋| 3451/3551 [42:04<01:19,  1.26it/s]

Epoch [2/5], Step [3451/3551], Loss: 2.5420141220092773


 99%|█████████▊| 3501/3551 [42:43<00:31,  1.60it/s]

Epoch [2/5], Step [3501/3551], Loss: 2.6272950172424316


100%|██████████| 3551/3551 [43:16<00:00,  1.37it/s]


Epoch [2/5], Step [3551/3551], Loss: 1.469866394996643
Validation Loss: 2.4517747474622125


  0%|          | 1/3551 [00:00<37:10,  1.59it/s]

Epoch [3/5], Step [1/3551], Loss: 2.468337059020996


  1%|▏         | 51/3551 [00:32<35:54,  1.62it/s]

Epoch [3/5], Step [51/3551], Loss: 2.408777952194214


  3%|▎         | 101/3551 [01:04<33:29,  1.72it/s]

Epoch [3/5], Step [101/3551], Loss: 2.490530252456665


  4%|▍         | 151/3551 [01:38<43:55,  1.29it/s]

Epoch [3/5], Step [151/3551], Loss: 2.449276924133301


  6%|▌         | 201/3551 [02:11<38:16,  1.46it/s]

Epoch [3/5], Step [201/3551], Loss: 2.4751787185668945


  7%|▋         | 251/3551 [02:46<38:45,  1.42it/s]

Epoch [3/5], Step [251/3551], Loss: 2.3216686248779297


  8%|▊         | 301/3551 [03:28<49:19,  1.10it/s]

Epoch [3/5], Step [301/3551], Loss: 2.3752975463867188


 10%|▉         | 351/3551 [04:08<51:39,  1.03it/s]

Epoch [3/5], Step [351/3551], Loss: 2.4226958751678467


 11%|█▏        | 401/3551 [04:43<54:20,  1.03s/it]

Epoch [3/5], Step [401/3551], Loss: 2.278576135635376


 13%|█▎        | 451/3551 [05:27<51:53,  1.00s/it]

Epoch [3/5], Step [451/3551], Loss: 2.4513065814971924


 14%|█▍        | 501/3551 [06:08<40:43,  1.25it/s]

Epoch [3/5], Step [501/3551], Loss: 2.494425058364868


 16%|█▌        | 551/3551 [06:49<48:31,  1.03it/s]

Epoch [3/5], Step [551/3551], Loss: 2.3765108585357666


 17%|█▋        | 601/3551 [07:33<31:22,  1.57it/s]  

Epoch [3/5], Step [601/3551], Loss: 2.3054709434509277


 18%|█▊        | 651/3551 [08:07<37:08,  1.30it/s]

Epoch [3/5], Step [651/3551], Loss: 2.3429172039031982


 20%|█▉        | 701/3551 [08:42<34:29,  1.38it/s]

Epoch [3/5], Step [701/3551], Loss: 2.333503484725952


 21%|██        | 751/3551 [09:18<36:30,  1.28it/s]

Epoch [3/5], Step [751/3551], Loss: 2.3649542331695557


 23%|██▎       | 801/3551 [09:56<29:24,  1.56it/s]

Epoch [3/5], Step [801/3551], Loss: 2.323568820953369


 24%|██▍       | 851/3551 [10:26<28:05,  1.60it/s]

Epoch [3/5], Step [851/3551], Loss: 2.517500638961792


 25%|██▌       | 901/3551 [10:59<33:51,  1.30it/s]

Epoch [3/5], Step [901/3551], Loss: 2.5006489753723145


 27%|██▋       | 951/3551 [11:32<26:26,  1.64it/s]

Epoch [3/5], Step [951/3551], Loss: 2.4240496158599854


 28%|██▊       | 1001/3551 [12:15<49:36,  1.17s/it]

Epoch [3/5], Step [1001/3551], Loss: 2.4317030906677246


 30%|██▉       | 1051/3551 [12:57<28:32,  1.46it/s]

Epoch [3/5], Step [1051/3551], Loss: 2.462522506713867


 31%|███       | 1101/3551 [13:39<34:24,  1.19it/s]

Epoch [3/5], Step [1101/3551], Loss: 2.5181384086608887


 32%|███▏      | 1151/3551 [14:12<25:14,  1.59it/s]

Epoch [3/5], Step [1151/3551], Loss: 2.4276044368743896


 34%|███▍      | 1201/3551 [14:43<24:10,  1.62it/s]

Epoch [3/5], Step [1201/3551], Loss: 2.4819416999816895


 35%|███▌      | 1251/3551 [15:13<23:18,  1.65it/s]

Epoch [3/5], Step [1251/3551], Loss: 2.4686851501464844


 37%|███▋      | 1301/3551 [15:43<23:11,  1.62it/s]

Epoch [3/5], Step [1301/3551], Loss: 2.504477024078369


 38%|███▊      | 1351/3551 [16:13<22:32,  1.63it/s]

Epoch [3/5], Step [1351/3551], Loss: 2.444460391998291


 39%|███▉      | 1401/3551 [16:43<21:40,  1.65it/s]

Epoch [3/5], Step [1401/3551], Loss: 2.5618743896484375


 41%|████      | 1451/3551 [17:14<21:27,  1.63it/s]

Epoch [3/5], Step [1451/3551], Loss: 2.634786605834961


 42%|████▏     | 1501/3551 [17:43<20:32,  1.66it/s]

Epoch [3/5], Step [1501/3551], Loss: 2.4487195014953613


 44%|████▎     | 1551/3551 [18:13<20:37,  1.62it/s]

Epoch [3/5], Step [1551/3551], Loss: 2.5812089443206787


 45%|████▌     | 1601/3551 [18:44<20:09,  1.61it/s]

Epoch [3/5], Step [1601/3551], Loss: 2.4220099449157715


 46%|████▋     | 1651/3551 [19:15<19:43,  1.61it/s]

Epoch [3/5], Step [1651/3551], Loss: 2.3748390674591064


 48%|████▊     | 1701/3551 [19:45<18:27,  1.67it/s]

Epoch [3/5], Step [1701/3551], Loss: 2.3574399948120117


 49%|████▉     | 1751/3551 [20:15<18:32,  1.62it/s]

Epoch [3/5], Step [1751/3551], Loss: 2.507389783859253


 51%|█████     | 1801/3551 [20:45<18:29,  1.58it/s]

Epoch [3/5], Step [1801/3551], Loss: 2.4715576171875


 52%|█████▏    | 1851/3551 [21:15<16:57,  1.67it/s]

Epoch [3/5], Step [1851/3551], Loss: 2.4894115924835205


 54%|█████▎    | 1901/3551 [21:46<17:03,  1.61it/s]

Epoch [3/5], Step [1901/3551], Loss: 2.3646440505981445


 55%|█████▍    | 1951/3551 [22:16<16:23,  1.63it/s]

Epoch [3/5], Step [1951/3551], Loss: 2.545091152191162


 56%|█████▋    | 2001/3551 [22:47<15:53,  1.62it/s]

Epoch [3/5], Step [2001/3551], Loss: 2.3303966522216797


 58%|█████▊    | 2051/3551 [23:17<15:21,  1.63it/s]

Epoch [3/5], Step [2051/3551], Loss: 2.442039728164673


 59%|█████▉    | 2101/3551 [23:47<14:32,  1.66it/s]

Epoch [3/5], Step [2101/3551], Loss: 2.497882843017578


 61%|██████    | 2151/3551 [24:16<14:06,  1.65it/s]

Epoch [3/5], Step [2151/3551], Loss: 2.409789800643921


 62%|██████▏   | 2201/3551 [24:46<13:43,  1.64it/s]

Epoch [3/5], Step [2201/3551], Loss: 2.4090585708618164


 63%|██████▎   | 2251/3551 [25:16<13:21,  1.62it/s]

Epoch [3/5], Step [2251/3551], Loss: 2.503422498703003


 65%|██████▍   | 2301/3551 [25:46<12:45,  1.63it/s]

Epoch [3/5], Step [2301/3551], Loss: 2.450700044631958


 66%|██████▌   | 2351/3551 [26:17<12:24,  1.61it/s]

Epoch [3/5], Step [2351/3551], Loss: 2.545867443084717


 68%|██████▊   | 2401/3551 [26:47<11:44,  1.63it/s]

Epoch [3/5], Step [2401/3551], Loss: 2.391794443130493


 69%|██████▉   | 2451/3551 [27:18<11:28,  1.60it/s]

Epoch [3/5], Step [2451/3551], Loss: 2.460563898086548


 70%|███████   | 2501/3551 [27:48<10:46,  1.62it/s]

Epoch [3/5], Step [2501/3551], Loss: 2.492554187774658


 72%|███████▏  | 2551/3551 [28:19<10:15,  1.62it/s]

Epoch [3/5], Step [2551/3551], Loss: 2.5037968158721924


 73%|███████▎  | 2601/3551 [28:48<09:23,  1.69it/s]

Epoch [3/5], Step [2601/3551], Loss: 2.4076733589172363


 75%|███████▍  | 2651/3551 [29:18<09:10,  1.63it/s]

Epoch [3/5], Step [2651/3551], Loss: 2.5544800758361816


 76%|███████▌  | 2701/3551 [29:48<08:41,  1.63it/s]

Epoch [3/5], Step [2701/3551], Loss: 2.479673147201538


 77%|███████▋  | 2751/3551 [30:19<08:13,  1.62it/s]

Epoch [3/5], Step [2751/3551], Loss: 2.446486234664917


 79%|███████▉  | 2801/3551 [30:49<07:45,  1.61it/s]

Epoch [3/5], Step [2801/3551], Loss: 2.3800365924835205


 80%|████████  | 2851/3551 [31:22<07:23,  1.58it/s]

Epoch [3/5], Step [2851/3551], Loss: 2.561187982559204


 82%|████████▏ | 2901/3551 [31:52<06:40,  1.62it/s]

Epoch [3/5], Step [2901/3551], Loss: 2.405590295791626


 83%|████████▎ | 2951/3551 [32:23<06:10,  1.62it/s]

Epoch [3/5], Step [2951/3551], Loss: 2.404393434524536


 85%|████████▍ | 3001/3551 [32:53<05:33,  1.65it/s]

Epoch [3/5], Step [3001/3551], Loss: 2.293161630630493


 86%|████████▌ | 3051/3551 [33:23<05:05,  1.64it/s]

Epoch [3/5], Step [3051/3551], Loss: 2.5190138816833496


 87%|████████▋ | 3101/3551 [33:54<04:34,  1.64it/s]

Epoch [3/5], Step [3101/3551], Loss: 2.48305344581604


 89%|████████▊ | 3151/3551 [34:23<04:00,  1.66it/s]

Epoch [3/5], Step [3151/3551], Loss: 2.446277379989624


 90%|█████████ | 3201/3551 [34:53<03:32,  1.65it/s]

Epoch [3/5], Step [3201/3551], Loss: 2.407923460006714


 92%|█████████▏| 3251/3551 [35:23<03:05,  1.61it/s]

Epoch [3/5], Step [3251/3551], Loss: 2.6499416828155518


 93%|█████████▎| 3301/3551 [35:54<02:49,  1.48it/s]

Epoch [3/5], Step [3301/3551], Loss: 2.467775344848633


 94%|█████████▍| 3351/3551 [36:25<02:01,  1.64it/s]

Epoch [3/5], Step [3351/3551], Loss: 2.4871327877044678


 96%|█████████▌| 3401/3551 [36:55<01:31,  1.64it/s]

Epoch [3/5], Step [3401/3551], Loss: 2.4619011878967285


 97%|█████████▋| 3451/3551 [37:26<01:01,  1.62it/s]

Epoch [3/5], Step [3451/3551], Loss: 2.4009716510772705


 99%|█████████▊| 3501/3551 [37:56<00:30,  1.63it/s]

Epoch [3/5], Step [3501/3551], Loss: 2.3631632328033447


100%|██████████| 3551/3551 [38:26<00:00,  1.54it/s]


Epoch [3/5], Step [3551/3551], Loss: 2.164804220199585
Validation Loss: 2.4500190390816217


  0%|          | 1/3551 [00:00<38:23,  1.54it/s]

Epoch [4/5], Step [1/3551], Loss: 2.4277961254119873


  1%|▏         | 51/3551 [00:30<34:53,  1.67it/s]

Epoch [4/5], Step [51/3551], Loss: 2.5398168563842773


  3%|▎         | 101/3551 [01:00<35:08,  1.64it/s]

Epoch [4/5], Step [101/3551], Loss: 2.3583016395568848


  4%|▍         | 151/3551 [01:31<34:42,  1.63it/s]

Epoch [4/5], Step [151/3551], Loss: 2.4253480434417725


  6%|▌         | 201/3551 [02:01<34:00,  1.64it/s]

Epoch [4/5], Step [201/3551], Loss: 2.329566240310669


  7%|▋         | 251/3551 [02:33<39:48,  1.38it/s]

Epoch [4/5], Step [251/3551], Loss: 2.4421374797821045


  8%|▊         | 301/3551 [03:06<36:11,  1.50it/s]

Epoch [4/5], Step [301/3551], Loss: 2.459641695022583


 10%|▉         | 351/3551 [03:38<33:18,  1.60it/s]

Epoch [4/5], Step [351/3551], Loss: 2.46999454498291


 11%|█▏        | 401/3551 [04:18<37:51,  1.39it/s]

Epoch [4/5], Step [401/3551], Loss: 2.492246150970459


 13%|█▎        | 451/3551 [04:54<36:33,  1.41it/s]

Epoch [4/5], Step [451/3551], Loss: 2.4722025394439697


 14%|█▍        | 501/3551 [05:26<32:29,  1.56it/s]

Epoch [4/5], Step [501/3551], Loss: 2.5053882598876953


 16%|█▌        | 551/3551 [06:01<31:06,  1.61it/s]

Epoch [4/5], Step [551/3551], Loss: 2.2801156044006348


 17%|█▋        | 601/3551 [06:42<52:45,  1.07s/it]

Epoch [4/5], Step [601/3551], Loss: 2.4842734336853027


 18%|█▊        | 651/3551 [07:29<36:38,  1.32it/s]

Epoch [4/5], Step [651/3551], Loss: 2.5668089389801025


 20%|█▉        | 701/3551 [08:06<34:01,  1.40it/s]

Epoch [4/5], Step [701/3551], Loss: 2.3175406455993652


 21%|██        | 751/3551 [08:40<37:07,  1.26it/s]

Epoch [4/5], Step [751/3551], Loss: 2.4494524002075195


 23%|██▎       | 801/3551 [09:16<36:42,  1.25it/s]

Epoch [4/5], Step [801/3551], Loss: 2.394958257675171


 24%|██▍       | 851/3551 [09:49<31:19,  1.44it/s]

Epoch [4/5], Step [851/3551], Loss: 2.4166722297668457


 25%|██▌       | 901/3551 [10:25<45:48,  1.04s/it]

Epoch [4/5], Step [901/3551], Loss: 2.3432559967041016


 27%|██▋       | 951/3551 [11:01<30:19,  1.43it/s]

Epoch [4/5], Step [951/3551], Loss: 2.4579274654388428


 28%|██▊       | 1001/3551 [11:36<29:51,  1.42it/s]

Epoch [4/5], Step [1001/3551], Loss: 2.3331451416015625


 30%|██▉       | 1051/3551 [12:11<27:58,  1.49it/s]

Epoch [4/5], Step [1051/3551], Loss: 2.3483047485351562


 31%|███       | 1101/3551 [12:46<30:01,  1.36it/s]

Epoch [4/5], Step [1101/3551], Loss: 2.4297471046447754


 32%|███▏      | 1151/3551 [13:20<29:16,  1.37it/s]

Epoch [4/5], Step [1151/3551], Loss: 2.1846938133239746


 34%|███▍      | 1201/3551 [13:55<27:52,  1.41it/s]

Epoch [4/5], Step [1201/3551], Loss: 2.465590238571167


 35%|███▌      | 1251/3551 [14:29<27:22,  1.40it/s]

Epoch [4/5], Step [1251/3551], Loss: 2.3645179271698


 37%|███▋      | 1301/3551 [15:04<25:42,  1.46it/s]

Epoch [4/5], Step [1301/3551], Loss: 2.320164918899536


 38%|███▊      | 1351/3551 [15:38<25:12,  1.45it/s]

Epoch [4/5], Step [1351/3551], Loss: 2.67384934425354


 39%|███▉      | 1401/3551 [16:12<25:00,  1.43it/s]

Epoch [4/5], Step [1401/3551], Loss: 2.395723342895508


 41%|████      | 1451/3551 [16:54<30:53,  1.13it/s]

Epoch [4/5], Step [1451/3551], Loss: 2.4171221256256104


 42%|████▏     | 1501/3551 [17:30<24:29,  1.39it/s]

Epoch [4/5], Step [1501/3551], Loss: 2.3908658027648926


 44%|████▎     | 1551/3551 [18:08<27:43,  1.20it/s]

Epoch [4/5], Step [1551/3551], Loss: 2.4089131355285645


 45%|████▌     | 1601/3551 [18:46<27:42,  1.17it/s]

Epoch [4/5], Step [1601/3551], Loss: 2.343975782394409


 46%|████▋     | 1651/3551 [19:23<22:45,  1.39it/s]

Epoch [4/5], Step [1651/3551], Loss: 2.47025728225708


 48%|████▊     | 1701/3551 [20:01<27:24,  1.12it/s]

Epoch [4/5], Step [1701/3551], Loss: 2.4582719802856445


 49%|████▉     | 1751/3551 [20:36<20:37,  1.45it/s]

Epoch [4/5], Step [1751/3551], Loss: 2.423372745513916


 51%|█████     | 1801/3551 [21:12<20:01,  1.46it/s]

Epoch [4/5], Step [1801/3551], Loss: 2.503695011138916


 52%|█████▏    | 1851/3551 [21:49<21:02,  1.35it/s]

Epoch [4/5], Step [1851/3551], Loss: 2.4679229259490967


 54%|█████▎    | 1901/3551 [22:27<25:22,  1.08it/s]

Epoch [4/5], Step [1901/3551], Loss: 2.2710952758789062


 55%|█████▍    | 1951/3551 [23:10<18:32,  1.44it/s]

Epoch [4/5], Step [1951/3551], Loss: 2.394608974456787


 56%|█████▋    | 2001/3551 [23:44<16:42,  1.55it/s]

Epoch [4/5], Step [2001/3551], Loss: 2.341055393218994


 58%|█████▊    | 2051/3551 [24:21<19:17,  1.30it/s]

Epoch [4/5], Step [2051/3551], Loss: 2.488637924194336


 59%|█████▉    | 2101/3551 [24:57<21:47,  1.11it/s]

Epoch [4/5], Step [2101/3551], Loss: 2.5185210704803467


 61%|██████    | 2151/3551 [25:33<15:08,  1.54it/s]

Epoch [4/5], Step [2151/3551], Loss: 2.2169137001037598


 62%|██████▏   | 2201/3551 [26:06<16:18,  1.38it/s]

Epoch [4/5], Step [2201/3551], Loss: 2.3455092906951904


 63%|██████▎   | 2251/3551 [26:37<13:13,  1.64it/s]

Epoch [4/5], Step [2251/3551], Loss: 2.4422366619110107


 65%|██████▍   | 2301/3551 [27:10<12:27,  1.67it/s]

Epoch [4/5], Step [2301/3551], Loss: 2.3519065380096436


 66%|██████▌   | 2351/3551 [27:41<11:54,  1.68it/s]

Epoch [4/5], Step [2351/3551], Loss: 2.421510934829712


 68%|██████▊   | 2401/3551 [28:12<11:03,  1.73it/s]

Epoch [4/5], Step [2401/3551], Loss: 2.4246673583984375


 69%|██████▉   | 2451/3551 [28:41<10:34,  1.73it/s]

Epoch [4/5], Step [2451/3551], Loss: 2.510911703109741


 70%|███████   | 2501/3551 [29:10<09:58,  1.75it/s]

Epoch [4/5], Step [2501/3551], Loss: 2.528066873550415


 72%|███████▏  | 2551/3551 [29:40<09:45,  1.71it/s]

Epoch [4/5], Step [2551/3551], Loss: 2.1578586101531982


 73%|███████▎  | 2601/3551 [30:09<11:53,  1.33it/s]

Epoch [4/5], Step [2601/3551], Loss: 2.3580589294433594


 75%|███████▍  | 2651/3551 [30:42<11:20,  1.32it/s]

Epoch [4/5], Step [2651/3551], Loss: 2.4564356803894043


 76%|███████▌  | 2701/3551 [31:14<10:04,  1.41it/s]

Epoch [4/5], Step [2701/3551], Loss: 2.2973039150238037


 77%|███████▋  | 2751/3551 [31:48<09:27,  1.41it/s]

Epoch [4/5], Step [2751/3551], Loss: 2.4613895416259766


 79%|███████▉  | 2801/3551 [32:20<07:05,  1.76it/s]

Epoch [4/5], Step [2801/3551], Loss: 2.3817360401153564


 80%|████████  | 2851/3551 [32:48<06:35,  1.77it/s]

Epoch [4/5], Step [2851/3551], Loss: 2.6874265670776367


 82%|████████▏ | 2901/3551 [33:17<06:20,  1.71it/s]

Epoch [4/5], Step [2901/3551], Loss: 2.5024943351745605


 83%|████████▎ | 2951/3551 [33:45<05:46,  1.73it/s]

Epoch [4/5], Step [2951/3551], Loss: 2.447706937789917


 85%|████████▍ | 3001/3551 [34:17<05:21,  1.71it/s]

Epoch [4/5], Step [3001/3551], Loss: 2.3885838985443115


 86%|████████▌ | 3051/3551 [34:48<05:36,  1.49it/s]

Epoch [4/5], Step [3051/3551], Loss: 2.579310178756714


 87%|████████▋ | 3101/3551 [35:19<04:41,  1.60it/s]

Epoch [4/5], Step [3101/3551], Loss: 2.3635711669921875


 89%|████████▊ | 3151/3551 [35:49<03:44,  1.78it/s]

Epoch [4/5], Step [3151/3551], Loss: 2.4435954093933105


 90%|█████████ | 3201/3551 [36:16<03:15,  1.79it/s]

Epoch [4/5], Step [3201/3551], Loss: 2.3870677947998047


 92%|█████████▏| 3251/3551 [36:44<02:48,  1.79it/s]

Epoch [4/5], Step [3251/3551], Loss: 2.532822608947754


 93%|█████████▎| 3301/3551 [37:12<02:20,  1.78it/s]

Epoch [4/5], Step [3301/3551], Loss: 2.418844699859619


 94%|█████████▍| 3351/3551 [37:40<01:52,  1.78it/s]

Epoch [4/5], Step [3351/3551], Loss: 2.327195167541504


 96%|█████████▌| 3401/3551 [38:08<01:23,  1.79it/s]

Epoch [4/5], Step [3401/3551], Loss: 2.448355197906494


 97%|█████████▋| 3451/3551 [38:36<01:04,  1.55it/s]

Epoch [4/5], Step [3451/3551], Loss: 2.4969096183776855


 99%|█████████▊| 3501/3551 [39:09<00:29,  1.69it/s]

Epoch [4/5], Step [3501/3551], Loss: 2.126668930053711


100%|██████████| 3551/3551 [39:54<00:00,  1.48it/s]


Epoch [4/5], Step [3551/3551], Loss: 2.0919885635375977
Validation Loss: 2.4308744665942617


  0%|          | 1/3551 [00:00<58:32,  1.01it/s]

Epoch [5/5], Step [1/3551], Loss: 2.413022041320801


  1%|▏         | 51/3551 [00:48<50:00,  1.17it/s]  

Epoch [5/5], Step [51/3551], Loss: 2.484943389892578


  3%|▎         | 101/3551 [01:32<49:56,  1.15it/s] 

Epoch [5/5], Step [101/3551], Loss: 2.2748396396636963


  4%|▍         | 151/3551 [02:21<51:31,  1.10it/s]  

Epoch [5/5], Step [151/3551], Loss: 2.3111703395843506


  6%|▌         | 201/3551 [03:04<50:13,  1.11it/s]  

Epoch [5/5], Step [201/3551], Loss: 2.6119844913482666


  7%|▋         | 251/3551 [03:51<39:46,  1.38it/s]  

Epoch [5/5], Step [251/3551], Loss: 2.5571093559265137


  8%|▊         | 301/3551 [04:45<47:30,  1.14it/s]  

Epoch [5/5], Step [301/3551], Loss: 2.4009828567504883


 10%|▉         | 351/3551 [05:32<1:01:06,  1.15s/it]

Epoch [5/5], Step [351/3551], Loss: 2.2521119117736816


 11%|█▏        | 401/3551 [06:21<49:32,  1.06it/s]  

Epoch [5/5], Step [401/3551], Loss: 2.4921646118164062


 13%|█▎        | 451/3551 [07:07<46:13,  1.12it/s]

Epoch [5/5], Step [451/3551], Loss: 2.479564905166626


 14%|█▍        | 501/3551 [07:38<31:04,  1.64it/s]

Epoch [5/5], Step [501/3551], Loss: 2.5162758827209473


 16%|█▌        | 551/3551 [08:09<30:30,  1.64it/s]

Epoch [5/5], Step [551/3551], Loss: 2.3637826442718506


 17%|█▋        | 601/3551 [08:40<30:02,  1.64it/s]

Epoch [5/5], Step [601/3551], Loss: 2.5399746894836426


 18%|█▊        | 651/3551 [09:13<37:12,  1.30it/s]

Epoch [5/5], Step [651/3551], Loss: 2.2613024711608887


 20%|█▉        | 701/3551 [09:51<31:48,  1.49it/s]

Epoch [5/5], Step [701/3551], Loss: 2.4682750701904297


 21%|██        | 751/3551 [10:32<47:41,  1.02s/it]

Epoch [5/5], Step [751/3551], Loss: 2.4792063236236572


 23%|██▎       | 801/3551 [11:22<43:46,  1.05it/s]  

Epoch [5/5], Step [801/3551], Loss: 2.4685797691345215


 24%|██▍       | 851/3551 [12:10<40:40,  1.11it/s]

Epoch [5/5], Step [851/3551], Loss: 2.672905445098877


 25%|██▌       | 901/3551 [12:52<32:14,  1.37it/s]

Epoch [5/5], Step [901/3551], Loss: 2.3609347343444824


 27%|██▋       | 951/3551 [13:38<45:25,  1.05s/it]

Epoch [5/5], Step [951/3551], Loss: 2.5506467819213867


 28%|██▊       | 1001/3551 [14:28<35:59,  1.18it/s] 

Epoch [5/5], Step [1001/3551], Loss: 2.47353458404541


 30%|██▉       | 1051/3551 [15:06<1:00:07,  1.44s/it]

Epoch [5/5], Step [1051/3551], Loss: 2.45589280128479


 31%|███       | 1101/3551 [15:46<27:15,  1.50it/s]  

Epoch [5/5], Step [1101/3551], Loss: 2.29241943359375


 32%|███▏      | 1151/3551 [16:18<24:20,  1.64it/s]

Epoch [5/5], Step [1151/3551], Loss: 2.3158631324768066


 34%|███▍      | 1201/3551 [16:50<25:54,  1.51it/s]

Epoch [5/5], Step [1201/3551], Loss: 2.282144784927368


 35%|███▌      | 1251/3551 [17:22<23:45,  1.61it/s]

Epoch [5/5], Step [1251/3551], Loss: 2.4173953533172607


 37%|███▋      | 1301/3551 [17:55<25:07,  1.49it/s]

Epoch [5/5], Step [1301/3551], Loss: 2.2517504692077637


 38%|███▊      | 1351/3551 [18:28<22:59,  1.59it/s]

Epoch [5/5], Step [1351/3551], Loss: 2.316084146499634


 39%|███▉      | 1401/3551 [19:03<23:16,  1.54it/s]

Epoch [5/5], Step [1401/3551], Loss: 2.3639333248138428


 41%|████      | 1451/3551 [19:41<22:11,  1.58it/s]

Epoch [5/5], Step [1451/3551], Loss: 2.3742687702178955


 42%|████▏     | 1501/3551 [20:14<22:47,  1.50it/s]

Epoch [5/5], Step [1501/3551], Loss: 2.485152244567871


 44%|████▎     | 1551/3551 [20:49<24:17,  1.37it/s]

Epoch [5/5], Step [1551/3551], Loss: 2.4057400226593018


 45%|████▌     | 1601/3551 [21:23<19:43,  1.65it/s]

Epoch [5/5], Step [1601/3551], Loss: 2.5207386016845703


 46%|████▋     | 1651/3551 [21:53<19:09,  1.65it/s]

Epoch [5/5], Step [1651/3551], Loss: 2.367220878601074


 48%|████▊     | 1701/3551 [22:26<22:38,  1.36it/s]

Epoch [5/5], Step [1701/3551], Loss: 2.4486727714538574


 49%|████▉     | 1751/3551 [23:07<24:39,  1.22it/s]

Epoch [5/5], Step [1751/3551], Loss: 2.4442222118377686


 51%|█████     | 1801/3551 [23:58<34:10,  1.17s/it]

Epoch [5/5], Step [1801/3551], Loss: 2.3982596397399902


 52%|█████▏    | 1851/3551 [24:39<17:53,  1.58it/s]

Epoch [5/5], Step [1851/3551], Loss: 2.312260866165161


 54%|█████▎    | 1901/3551 [25:12<17:52,  1.54it/s]

Epoch [5/5], Step [1901/3551], Loss: 2.415961503982544


 55%|█████▍    | 1951/3551 [25:49<19:38,  1.36it/s]

Epoch [5/5], Step [1951/3551], Loss: 2.4436285495758057


 56%|█████▋    | 2001/3551 [26:39<32:16,  1.25s/it]

Epoch [5/5], Step [2001/3551], Loss: 2.492671489715576


 58%|█████▊    | 2051/3551 [27:15<15:18,  1.63it/s]

Epoch [5/5], Step [2051/3551], Loss: 2.439377546310425


 59%|█████▉    | 2101/3551 [27:45<14:49,  1.63it/s]

Epoch [5/5], Step [2101/3551], Loss: 2.3586983680725098


 61%|██████    | 2151/3551 [28:15<14:33,  1.60it/s]

Epoch [5/5], Step [2151/3551], Loss: 2.3991892337799072


 62%|██████▏   | 2201/3551 [28:46<13:42,  1.64it/s]

Epoch [5/5], Step [2201/3551], Loss: 2.4774744510650635


 63%|██████▎   | 2251/3551 [29:16<13:29,  1.61it/s]

Epoch [5/5], Step [2251/3551], Loss: 2.1582319736480713


 65%|██████▍   | 2301/3551 [29:47<13:00,  1.60it/s]

Epoch [5/5], Step [2301/3551], Loss: 2.435537338256836


 66%|██████▌   | 2351/3551 [30:20<18:41,  1.07it/s]

Epoch [5/5], Step [2351/3551], Loss: 2.383193016052246


 68%|██████▊   | 2401/3551 [31:11<15:55,  1.20it/s]

Epoch [5/5], Step [2401/3551], Loss: 2.398545026779175


 69%|██████▉   | 2451/3551 [31:44<11:21,  1.61it/s]

Epoch [5/5], Step [2451/3551], Loss: 2.376009702682495


 70%|███████   | 2501/3551 [32:17<10:54,  1.60it/s]

Epoch [5/5], Step [2501/3551], Loss: 2.364232301712036


 72%|███████▏  | 2551/3551 [32:57<10:36,  1.57it/s]

Epoch [5/5], Step [2551/3551], Loss: 2.501002550125122


 73%|███████▎  | 2601/3551 [33:28<09:48,  1.61it/s]

Epoch [5/5], Step [2601/3551], Loss: 2.4274566173553467


 75%|███████▍  | 2651/3551 [34:27<18:53,  1.26s/it]

Epoch [5/5], Step [2651/3551], Loss: 2.363659620285034


 76%|███████▌  | 2701/3551 [35:26<11:50,  1.20it/s]

Epoch [5/5], Step [2701/3551], Loss: 2.4274699687957764


 77%|███████▋  | 2751/3551 [36:00<09:53,  1.35it/s]

Epoch [5/5], Step [2751/3551], Loss: 2.40085506439209


 79%|███████▉  | 2801/3551 [36:34<07:59,  1.56it/s]

Epoch [5/5], Step [2801/3551], Loss: 2.449610710144043


 80%|████████  | 2851/3551 [37:10<09:53,  1.18it/s]

Epoch [5/5], Step [2851/3551], Loss: 2.4753453731536865


 82%|████████▏ | 2901/3551 [37:43<07:44,  1.40it/s]

Epoch [5/5], Step [2901/3551], Loss: 2.3705391883850098


 83%|████████▎ | 2951/3551 [38:14<06:11,  1.61it/s]

Epoch [5/5], Step [2951/3551], Loss: 2.3870012760162354


 85%|████████▍ | 3001/3551 [38:57<07:24,  1.24it/s]

Epoch [5/5], Step [3001/3551], Loss: 2.497218608856201


 86%|████████▌ | 3051/3551 [39:31<05:36,  1.49it/s]

Epoch [5/5], Step [3051/3551], Loss: 2.5153021812438965


 87%|████████▋ | 3101/3551 [40:07<05:26,  1.38it/s]

Epoch [5/5], Step [3101/3551], Loss: 2.4815151691436768


 89%|████████▊ | 3151/3551 [40:46<04:56,  1.35it/s]

Epoch [5/5], Step [3151/3551], Loss: 2.4224064350128174


 90%|█████████ | 3201/3551 [41:22<04:11,  1.39it/s]

Epoch [5/5], Step [3201/3551], Loss: 2.2583913803100586


 92%|█████████▏| 3251/3551 [41:56<03:15,  1.54it/s]

Epoch [5/5], Step [3251/3551], Loss: 2.333813428878784


 93%|█████████▎| 3301/3551 [42:30<02:39,  1.57it/s]

Epoch [5/5], Step [3301/3551], Loss: 2.4086861610412598


 94%|█████████▍| 3351/3551 [43:03<02:02,  1.63it/s]

Epoch [5/5], Step [3351/3551], Loss: 2.3760459423065186


 96%|█████████▌| 3401/3551 [43:37<01:51,  1.34it/s]

Epoch [5/5], Step [3401/3551], Loss: 2.3026256561279297


 97%|█████████▋| 3451/3551 [44:12<01:19,  1.26it/s]

Epoch [5/5], Step [3451/3551], Loss: 2.490705966949463


 99%|█████████▊| 3501/3551 [44:44<00:33,  1.49it/s]

Epoch [5/5], Step [3501/3551], Loss: 2.412125587463379


100%|██████████| 3551/3551 [45:17<00:00,  1.31it/s]


Epoch [5/5], Step [3551/3551], Loss: 2.256152868270874
Validation Loss: 2.4189310043672974


In [22]:
torch.save(encoder_decoder.state_dict(), 'model_state3.pth')


In [23]:
!tensorboard --logdir=runs

^C


In [18]:
# TODO: I have no idea what I'm doing or if this is correct
def translate(to_translate, model, loader):
    model.eval()
    inp = loader.tokenize_texts([to_translate])[0].unsqueeze(0).to(device)
    with torch.no_grad():
        out = model(inp)
    print(out.shape)

    out = torch.argmax(out.squeeze(0), dim=0)
    token_ids = out.tolist()
    translated_texts = loader.tokenizer.decode(token_ids)
    translated_texts = translated_texts.replace("<pad>", "")
    print(f"Translated text: {translated_texts}")


In [19]:
torch.save(encoder_decoder, 'model_whole.pth')


NameError: name 'encoder_decoder' is not defined

In [31]:
encoder_decoder.eval()
text = ["Hallo, wie geht es dir?"]
print(f"Translating: {text}")
translate(text, encoder_decoder, wmt_json_loader)
writer.close()

Translating: ['Hallo, wie geht es dir?']
torch.Size([1, 384, 128])
Translated text: Ia tor             


In [20]:
loaded_model = torch.load('model_whole.pth')

In [24]:
loaded_model.eval()
text = ["Was ist es?"]
print(f"Translating: {text}")
translate(text, loaded_model, wmt_json_loader)


Translating: ['Was ist es?']
torch.Size([1, 384, 128])
Translated text: What is    


In [3]:
writer.close()


NameError: name 'writer' is not defined