In [3]:
from os import path as os_path
from logging import getLogger
from typing import Dict, Callable
import torch
from torchtext.data import to_map_style_dataset
from torch.utils.data import DataLoader
from functools import partial

MIN_FREQUENCY_OF_WORD = 10
UNKNOWN_WORD_INDEX = 0
CBOW_N_WORDS = 2
SKIPGRAM_N_WORDS = 2
MAX_SEQUENCE_LENGTH = 100



class DataSet:
    """
    This class provides a simple API for loading the data set
    The preprocessing of the text will be handled in this class
    """

    def __init__(self, 
                 path_to_data_set: str = 
                 "path_to_dataset"):
        # Make sure that the provided file path is correct
        if not os_path.isfile(path_to_data_set):
            print(f"The provided file path to the data set is invalid : {path_to_data_set}")
            raise IOError("Unable to load data set file")

        self._data_set_path = path_to_data_set
        # Read the entire contents of the dataset to memory
        with open(self._data_set_path, mode="r") as open_dataset:
            self._loaded_data_set = open_dataset.read()

        self._data_set_lines = self._loaded_data_set.splitlines()

        # Now the entire data set is loaded into memory
        # Proceed to extract the vocabulary from the data set
        self._vocabulary = self.__extract_vocabulary()
        self._indexed_vocabulary = self.__get_indexed_vocabulary()

    def __extract_vocabulary(self) -> Dict[str, int]:
        """
        This method extracts all the unique words in the data set
        :return: The vocabulary as a dictionary with the frequency of words as values and words as keys
        """
        vocabulary = {}
        for line in self._loaded_data_set.splitlines():
            for word in line.split():
                if word not in vocabulary:
                    vocabulary[word] = 1
                else:
                    vocabulary[word] += 1

        print("Succesfully built vocabulary from the data set")
        print(f"The total number of unique words found are : {len(vocabulary)}")

        return vocabulary

    def __get_indexed_vocabulary(self) -> Dict[str, int]:
        """
        This function maps each word to an index, while taking into account
        the minimum frequency that is needed for  a word to be included
        :return: The method returns a dictionary which maps each word to an index
        """
        indexed_vocabulary = {}
        index = 1
        for word in self._vocabulary.keys():
            if self._vocabulary[word] < MIN_FREQUENCY_OF_WORD:
                continue
            else:
                indexed_vocabulary[word] = index
                index += 1

        return indexed_vocabulary

    
    def __getitem__(self, index: int):
        return self._data_set_lines[index]

    def __len__(self):
        return len(self._data_set_lines)
    
    def get_vocabulary(self):
        return self._indexed_vocabulary

    
    def get_word_indices(self, sentence: str) -> list:
        """
        This method returns the indices of each word in the sentence
        An unknown word is given the default index 0
        :param sentence: The sentence from the data set as a string
        :return: The indices from the _indexed_vocabulary
        """
        indices = []
        for word in sentence.split():
            if word in self._indexed_vocabulary:
                indices.append(self._indexed_vocabulary[word])
            else:
                # The vocabulary does not contain the word
                indices.append(UNKNOWN_WORD_INDEX)

        return indices

    @staticmethod
    def collate_cbow(batch, get_word_indices: Callable):
        """
        Collate_fn for CBOW model to be used with Dataloader.
        `batch` is expected to be list of text paragrahs.

        Context is represented as N=CBOW_N_WORDS past words
        and N=CBOW_N_WORDS future words.

        Long paragraphs will be truncated to contain
        no more that MAX_SEQUENCE_LENGTH tokens.

        Each element in `batch_input` is N=CBOW_N_WORDS*2 context words.
        Each element in `batch_output` is a middle word.
        """
        batch_input, batch_output = [], []
        for text in batch:
            text_tokens_ids = get_word_indices(text)

            if len(text_tokens_ids) < CBOW_N_WORDS * 2 + 1:
                continue

            if MAX_SEQUENCE_LENGTH:
                text_tokens_ids = text_tokens_ids[:MAX_SEQUENCE_LENGTH]

            for idx in range(len(text_tokens_ids) - CBOW_N_WORDS * 2):
                token_id_sequence = text_tokens_ids[idx: (idx + CBOW_N_WORDS * 2 + 1)]
                output = token_id_sequence.pop(CBOW_N_WORDS)
                input_ = token_id_sequence
                batch_input.append(input_)
                batch_output.append(output)

        batch_input = torch.tensor(batch_input, dtype=torch.long)
        batch_output = torch.tensor(batch_output, dtype=torch.long)
        return batch_input, batch_output


    @staticmethod
    def collate_skipgram(batch, get_word_indices: Callable):
        """
        Collate_fn for Skip-Gram model to be used with Dataloader.
        `batch` is going to be sentences from the data set

        Context is represented as N=SKIPGRAM_N_WORDS past words
        and N=SKIPGRAM_N_WORDS future words.

        Each element in `batch_input` is a middle word.
        Each element in `batch_output` is a context word.
        """
        batch_input, batch_output = [], []
        for text in batch:
            text_tokens_ids = get_word_indices(text)

            if len(text_tokens_ids) < SKIPGRAM_N_WORDS * 2 + 1:
                continue

            if MAX_SEQUENCE_LENGTH:
                text_tokens_ids = text_tokens_ids[:MAX_SEQUENCE_LENGTH]

            for idx in range(len(text_tokens_ids) - SKIPGRAM_N_WORDS * 2):
                token_id_sequence = text_tokens_ids[idx: (idx + SKIPGRAM_N_WORDS * 2 + 1)]
                input_ = token_id_sequence.pop(SKIPGRAM_N_WORDS)
                outputs = token_id_sequence

                for output in outputs:
                    batch_input.append(input_)
                    batch_output.append(output)

        batch_input = torch.tensor(batch_input, dtype=torch.long)
        batch_output = torch.tensor(batch_output, dtype=torch.long)
        return batch_input, batch_output

    @staticmethod
    def get_dataloader_and_vocab(
            model_name, batch_size, shuffle, vocab=None
    ):
        data_set = DataSet()
        if not vocab:
            vocab = data_set.get_vocabulary()

        text_pipeline = data_set.get_word_indices

        if model_name == "cbow":
            collate_fn = DataSet.collate_cbow
        elif model_name == "skipgram":
            collate_fn = DataSet.collate_skipgram
        else:
            raise ValueError("Choose model from: cbow, skipgram")

        dataloader = DataLoader(
            data_set,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=partial(collate_fn, get_word_indices=text_pipeline),
        )
        return dataloader, vocab


In [4]:
import torch.nn as nn

EMBEDDING_VECTOR_DIM = 150
# Restrict the maximum value of the weights for a word to prevent them from becoming too large
EMBED_MAX_NORM = 1


class CbowModel(nn.Module):
    def __init__(self, vocab_size: int):
        super().__init__()
        self.embedding_layer = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=EMBEDDING_VECTOR_DIM,
            max_norm=EMBED_MAX_NORM,
        )
        self.linear = nn.Linear(
            in_features=EMBEDDING_VECTOR_DIM,
            out_features=vocab_size)

    def forward(self, input_features):
        x = self.embedding_layer(input_features)
        # For the CBOW approach we have to use the mean of the embedded context words
        x = x.mean(axis=1)
        x = self.linear(x)
        return x


class SkipGramModel(nn.Module):
    """
    Class to create a skip gram model
    """
    def __init__(self, vocab_size: int):
        super(SkipGramModel, self).__init__()
        self.embedding_layer = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=EMBEDDING_VECTOR_DIM,
            max_norm=EMBED_MAX_NORM,
        )
        self.linear = nn.Linear(
            in_features=EMBEDDING_VECTOR_DIM,
            out_features=vocab_size,
        )

    def forward(self, input_features):
        x = self.embedding_layer(input_features)
        # No need to take mean in case of a skip gram mode
        x = self.linear(x)
        return x

In [7]:
import numpy as np
import torch.optim
from tqdm import tqdm
# Create the train loop
def train_epoch(model, train_dataloader,optimizer, criterion, device = torch.device("cuda"), train_steps= 100):
        model.train()
        running_loss = []

        for i in tqdm(range(1000)):
            for i, batch_data in enumerate(train_dataloader, 1):
                inputs = batch_data[0].to(device)
                labels = batch_data[1].to(device)

                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                running_loss.append(loss.item())

                if i == train_steps:
                    break

            epoch_loss = np.mean(running_loss)
            print(f"Epoch loss : {epoch_loss}")

# Create a CBOW model

data_set = DataSet()
vocab = data_set.get_vocabulary()

model = CbowModel(vocab_size=len(vocab))
model.to(torch.device("cuda"))
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_criteria = torch.nn.CrossEntropyLoss()

train_epoch(model=model, 
            train_dataloader=DataSet.get_dataloader_and_vocab(model_name="cbow", batch_size=15, vocab=vocab, shuffle=False)[0],
            optimizer=optimizer, criterion=loss_criteria
            )


Succesfully built vocabulary from the data set
The total number of unique words found are : 15517
Succesfully built vocabulary from the data set
The total number of unique words found are : 15517


  0%|          | 2/1000 [00:00<02:36,  6.37it/s]

Epoch loss : 6.770091948509216
Epoch loss : 6.223940343856811


  0%|          | 4/1000 [00:00<02:36,  6.35it/s]

Epoch loss : 5.794501754442851
Epoch loss : 5.521556769609451


  1%|          | 6/1000 [00:00<02:36,  6.36it/s]

Epoch loss : 5.341455503463745
Epoch loss : 5.210314040184021


  1%|          | 8/1000 [00:01<02:37,  6.28it/s]

Epoch loss : 5.106981028829302
Epoch loss : 5.020740823149681


  1%|          | 10/1000 [00:01<02:35,  6.36it/s]

Epoch loss : 4.945663929250506
Epoch loss : 4.878316205263138


  1%|          | 12/1000 [00:01<02:35,  6.37it/s]

Epoch loss : 4.816655064062639
Epoch loss : 4.759397992690404


  1%|▏         | 14/1000 [00:02<02:37,  6.26it/s]

Epoch loss : 4.705680034160614
Epoch loss : 4.654883821862085


  2%|▏         | 16/1000 [00:02<02:35,  6.34it/s]

Epoch loss : 4.606566417058309
Epoch loss : 4.560385833829641


  2%|▏         | 18/1000 [00:02<02:33,  6.40it/s]

Epoch loss : 4.5160886315738455
Epoch loss : 4.473481070200602


  2%|▏         | 20/1000 [00:03<02:32,  6.42it/s]

Epoch loss : 4.432406697022287
Epoch loss : 4.392734553813934


  2%|▏         | 22/1000 [00:03<02:32,  6.43it/s]

Epoch loss : 4.354361511525654
Epoch loss : 4.317204540100964


  2%|▏         | 24/1000 [00:03<02:31,  6.44it/s]

Epoch loss : 4.281183646139891
Epoch loss : 4.246231744587422


  3%|▎         | 26/1000 [00:04<02:33,  6.34it/s]

Epoch loss : 4.212289774894715
Epoch loss : 4.1792996795360855


  3%|▎         | 28/1000 [00:04<02:34,  6.31it/s]

Epoch loss : 4.147214340015694
Epoch loss : 4.115989486234529


  3%|▎         | 30/1000 [00:04<02:36,  6.22it/s]

Epoch loss : 4.085579890136061
Epoch loss : 4.055946831782659


  3%|▎         | 32/1000 [00:05<02:35,  6.24it/s]

Epoch loss : 4.027054430976991
Epoch loss : 3.998867204412818


  3%|▎         | 34/1000 [00:05<02:32,  6.34it/s]

Epoch loss : 3.971355118896022
Epoch loss : 3.944487526276532


  4%|▎         | 36/1000 [00:05<02:30,  6.40it/s]

Epoch loss : 3.9182374488966807
Epoch loss : 3.892578712105751


  4%|▍         | 38/1000 [00:05<02:30,  6.41it/s]

Epoch loss : 3.8674886890359828
Epoch loss : 3.8429421268011392


  4%|▍         | 40/1000 [00:06<02:27,  6.53it/s]

Epoch loss : 3.8189187182524265
Epoch loss : 3.795395875453949


  4%|▍         | 42/1000 [00:06<02:26,  6.53it/s]

Epoch loss : 3.772361300747569
Epoch loss : 3.7497872752802714


  4%|▍         | 44/1000 [00:06<02:40,  5.96it/s]

Epoch loss : 3.727665473361348
Epoch loss : 3.705973272919655


  5%|▍         | 46/1000 [00:07<02:48,  5.66it/s]

Epoch loss : 3.684695558812883
Epoch loss : 3.663822623802268


  5%|▍         | 48/1000 [00:07<02:50,  5.60it/s]

Epoch loss : 3.643334564198839
Epoch loss : 3.6232215749224026


  5%|▌         | 50/1000 [00:08<02:51,  5.53it/s]

Epoch loss : 3.6034710251068582
Epoch loss : 3.5840705862998963


  5%|▌         | 52/1000 [00:08<02:51,  5.54it/s]

Epoch loss : 3.565008569698708
Epoch loss : 3.5462760355151617


  5%|▌         | 53/1000 [00:08<02:54,  5.44it/s]

Epoch loss : 3.5278589056122978


  6%|▌         | 55/1000 [00:08<02:57,  5.32it/s]

Epoch loss : 3.509748618580677
Epoch loss : 3.49193760620464


  6%|▌         | 57/1000 [00:09<02:56,  5.35it/s]

Epoch loss : 3.4744109896038258
Epoch loss : 3.457164896538383


  6%|▌         | 59/1000 [00:09<02:54,  5.41it/s]

Epoch loss : 3.4401909112724764
Epoch loss : 3.4234784173561357


  6%|▌         | 61/1000 [00:10<02:56,  5.33it/s]

Epoch loss : 3.4070231434504192
Epoch loss : 3.39081609358553


  6%|▋         | 63/1000 [00:10<02:48,  5.55it/s]

Epoch loss : 3.3748515073906993
Epoch loss : 3.359120491421412


  6%|▋         | 65/1000 [00:10<02:40,  5.84it/s]

Epoch loss : 3.343618005346507
Epoch loss : 3.328338636948512


  7%|▋         | 67/1000 [00:11<02:34,  6.05it/s]

Epoch loss : 3.3132743948156183
Epoch loss : 3.2984185558824395


  7%|▋         | 69/1000 [00:11<02:30,  6.18it/s]

Epoch loss : 3.28376686064636
Epoch loss : 3.269314330989036


  7%|▋         | 71/1000 [00:11<02:27,  6.32it/s]

Epoch loss : 3.2550558101449694
Epoch loss : 3.2409880500276325


  7%|▋         | 73/1000 [00:12<02:26,  6.33it/s]

Epoch loss : 3.227103102952242
Epoch loss : 3.213398722361212


  8%|▊         | 75/1000 [00:12<02:24,  6.39it/s]

Epoch loss : 3.19987073023577
Epoch loss : 3.186514009284973


  8%|▊         | 77/1000 [00:12<02:24,  6.40it/s]

Epoch loss : 3.173323104444303
Epoch loss : 3.1602967354539153


  8%|▊         | 79/1000 [00:13<02:28,  6.20it/s]

Epoch loss : 3.147429350950779
Epoch loss : 3.134716353491892


  8%|▊         | 81/1000 [00:13<02:29,  6.16it/s]

Epoch loss : 3.122155634731054
Epoch loss : 3.1097437883895123


  8%|▊         | 83/1000 [00:13<02:27,  6.21it/s]

Epoch loss : 3.097476529653479
Epoch loss : 3.085352372361953


  8%|▊         | 85/1000 [00:13<02:30,  6.07it/s]

Epoch loss : 3.0733667119344075
Epoch loss : 3.061517538323122


  9%|▊         | 87/1000 [00:14<02:24,  6.31it/s]

Epoch loss : 3.0498002832296285
Epoch loss : 3.0382130693978278


  9%|▉         | 89/1000 [00:14<02:20,  6.47it/s]

Epoch loss : 3.026755095232617
Epoch loss : 3.0154202569334694


  9%|▉         | 91/1000 [00:14<02:20,  6.49it/s]

Epoch loss : 3.0042072251240413
Epoch loss : 2.9931145961861034


  9%|▉         | 93/1000 [00:15<02:20,  6.48it/s]

Epoch loss : 2.9821415464774423
Epoch loss : 2.9712816458748232


 10%|▉         | 95/1000 [00:15<02:19,  6.51it/s]

Epoch loss : 2.9605351382493974
Epoch loss : 2.9498980754300166


 10%|▉         | 97/1000 [00:15<02:16,  6.59it/s]

Epoch loss : 2.939370589690904
Epoch loss : 2.9289487120662767


 10%|▉         | 99/1000 [00:16<02:17,  6.54it/s]

Epoch loss : 2.9186303889143224
Epoch loss : 2.9084154141310488


 10%|█         | 101/1000 [00:16<02:18,  6.47it/s]

Epoch loss : 2.8983006455540656
Epoch loss : 2.888284504850312


 10%|█         | 103/1000 [00:16<02:18,  6.49it/s]

Epoch loss : 2.8783656873071894
Epoch loss : 2.8685428322403177


 10%|█         | 105/1000 [00:17<02:20,  6.39it/s]

Epoch loss : 2.858811664260351
Epoch loss : 2.8491723182428452


 11%|█         | 107/1000 [00:17<02:19,  6.40it/s]

Epoch loss : 2.839622476112168
Epoch loss : 2.8301613731807636


 11%|█         | 109/1000 [00:17<02:18,  6.43it/s]

Epoch loss : 2.8207879796293045
Epoch loss : 2.8114991354504855


 11%|█         | 111/1000 [00:18<02:20,  6.32it/s]

Epoch loss : 2.8022944878773255
Epoch loss : 2.7931722455948322


 11%|█▏        | 113/1000 [00:18<02:20,  6.33it/s]

Epoch loss : 2.784131356041346
Epoch loss : 2.7751695026655114


 12%|█▏        | 115/1000 [00:18<02:18,  6.41it/s]

Epoch loss : 2.7662863895767615
Epoch loss : 2.757479410337365


 12%|█▏        | 117/1000 [00:18<02:17,  6.45it/s]

Epoch loss : 2.7487505491437583
Epoch loss : 2.7400958623947242


 12%|█▏        | 119/1000 [00:19<02:18,  6.37it/s]

Epoch loss : 2.7315147685903614
Epoch loss : 2.723006619475469


 12%|█▏        | 121/1000 [00:19<02:20,  6.27it/s]

Epoch loss : 2.7145692344208556
Epoch loss : 2.706201656111016


 12%|█▏        | 123/1000 [00:19<02:16,  6.40it/s]

Epoch loss : 2.6979035767668584
Epoch loss : 2.6896733255018064


 12%|█▎        | 125/1000 [00:20<02:17,  6.38it/s]

Epoch loss : 2.681510900730087
Epoch loss : 2.6734155347442625


 13%|█▎        | 127/1000 [00:20<02:31,  5.77it/s]

Epoch loss : 2.6653842304244875
Epoch loss : 2.657417260192511


 13%|█▎        | 128/1000 [00:20<02:34,  5.65it/s]

Epoch loss : 2.6495134787354617
Epoch loss : 2.6416720078065414


 13%|█▎        | 131/1000 [00:21<02:39,  5.46it/s]

Epoch loss : 2.6338920372541135
Epoch loss : 2.6261720835252573


 13%|█▎        | 132/1000 [00:21<02:41,  5.36it/s]

Epoch loss : 2.618512190255252


 13%|█▎        | 134/1000 [00:21<02:45,  5.22it/s]

Epoch loss : 2.6109134440852286
Epoch loss : 2.6033699622083066


 14%|█▎        | 136/1000 [00:22<02:43,  5.28it/s]

Epoch loss : 2.595884419873909
Epoch loss : 2.5884557689112775


 14%|█▎        | 137/1000 [00:22<03:21,  4.29it/s]

Epoch loss : 2.581082394671266


 14%|█▍        | 139/1000 [00:23<03:25,  4.19it/s]

Epoch loss : 2.5737652133426803
Epoch loss : 2.5665014301787177


 14%|█▍        | 141/1000 [00:23<03:04,  4.67it/s]

Epoch loss : 2.5592909621340887
Epoch loss : 2.552134129240158


 14%|█▍        | 142/1000 [00:23<02:57,  4.85it/s]

Epoch loss : 2.545027752468284


 14%|█▍        | 143/1000 [00:23<02:58,  4.79it/s]

Epoch loss : 2.537972813211121


 14%|█▍        | 144/1000 [00:24<03:12,  4.45it/s]

Epoch loss : 2.5309699517074558


 15%|█▍        | 146/1000 [00:24<02:53,  4.93it/s]

Epoch loss : 2.524016710721213
Epoch loss : 2.517113129525152


 15%|█▍        | 148/1000 [00:24<02:46,  5.12it/s]

Epoch loss : 2.51025889497225
Epoch loss : 2.50345278169255


 15%|█▌        | 150/1000 [00:25<02:45,  5.12it/s]

Epoch loss : 2.4966952809911445
Epoch loss : 2.48998396538496


 15%|█▌        | 152/1000 [00:25<02:28,  5.72it/s]

Epoch loss : 2.483320659197719
Epoch loss : 2.476702703821816


 15%|█▌        | 153/1000 [00:25<02:37,  5.36it/s]

Epoch loss : 2.4701300130599466


 16%|█▌        | 155/1000 [00:26<02:32,  5.55it/s]

Epoch loss : 2.46360188781441
Epoch loss : 2.457117914430557


 16%|█▌        | 157/1000 [00:26<02:19,  6.04it/s]

Epoch loss : 2.450678428426767
Epoch loss : 2.4442817586432595


 16%|█▌        | 159/1000 [00:26<02:13,  6.29it/s]

Epoch loss : 2.4379277078222623
Epoch loss : 2.4316166669392736


 16%|█▌        | 161/1000 [00:27<02:11,  6.37it/s]

Epoch loss : 2.4253476281799378
Epoch loss : 2.419119760793929


 16%|█▋        | 163/1000 [00:27<02:11,  6.38it/s]

Epoch loss : 2.4129327896053407
Epoch loss : 2.406786882076527


 16%|█▋        | 165/1000 [00:27<02:11,  6.35it/s]

Epoch loss : 2.40068038616602
Epoch loss : 2.3946142248782243


 17%|█▋        | 167/1000 [00:28<02:10,  6.39it/s]

Epoch loss : 2.3885874701408016
Epoch loss : 2.3825992845311137


 17%|█▋        | 169/1000 [00:28<02:09,  6.40it/s]

Epoch loss : 2.37664919039678
Epoch loss : 2.370737322833411


 17%|█▋        | 171/1000 [00:28<02:07,  6.50it/s]

Epoch loss : 2.364863151960513
Epoch loss : 2.3590261443734866


 17%|█▋        | 173/1000 [00:29<02:06,  6.52it/s]

Epoch loss : 2.353224992623856
Epoch loss : 2.3474608260496503


 18%|█▊        | 175/1000 [00:29<02:08,  6.43it/s]

Epoch loss : 2.3417319603144438
Epoch loss : 2.3360401822294508


 18%|█▊        | 176/1000 [00:29<02:09,  6.38it/s]

Epoch loss : 2.3303826280513946


 18%|█▊        | 178/1000 [00:29<02:27,  5.57it/s]

Epoch loss : 2.324760696988995
Epoch loss : 2.319173120802708


 18%|█▊        | 180/1000 [00:30<02:17,  5.94it/s]

Epoch loss : 2.3136198220605957
Epoch loss : 2.3081001093486946


 18%|█▊        | 182/1000 [00:30<02:13,  6.14it/s]

Epoch loss : 2.30261472566681
Epoch loss : 2.2971618982101534


 18%|█▊        | 183/1000 [00:30<02:10,  6.28it/s]

Epoch loss : 2.2917419499517138


 18%|█▊        | 184/1000 [00:30<02:30,  5.42it/s]

Epoch loss : 2.2863544383353513


 19%|█▊        | 186/1000 [00:31<02:55,  4.63it/s]

Epoch loss : 2.280999750446629
Epoch loss : 2.2756761376800076


 19%|█▉        | 188/1000 [00:31<02:30,  5.40it/s]

Epoch loss : 2.2703846100020537
Epoch loss : 2.265123911636307


 19%|█▉        | 190/1000 [00:32<02:21,  5.71it/s]

Epoch loss : 2.2598946628557943
Epoch loss : 2.254696665221139


 19%|█▉        | 192/1000 [00:32<02:15,  5.97it/s]

Epoch loss : 2.249529105187711
Epoch loss : 2.244391080988571


 19%|█▉        | 194/1000 [00:32<02:13,  6.06it/s]

Epoch loss : 2.239283434464524
Epoch loss : 2.234205230308562


 20%|█▉        | 196/1000 [00:33<02:10,  6.17it/s]

Epoch loss : 2.2291564991015655
Epoch loss : 2.224136783772585


 20%|█▉        | 198/1000 [00:33<02:08,  6.25it/s]

Epoch loss : 2.2191461557149887
Epoch loss : 2.2141837774473005


 20%|██        | 200/1000 [00:33<02:06,  6.33it/s]

Epoch loss : 2.2092495872117768
Epoch loss : 2.204343351781368


 20%|██        | 202/1000 [00:34<02:14,  5.92it/s]

Epoch loss : 2.1994650679085384
Epoch loss : 2.194613857853531


 20%|██        | 203/1000 [00:34<02:18,  5.74it/s]

Epoch loss : 2.1897903212185565


 20%|██        | 205/1000 [00:34<02:25,  5.46it/s]

Epoch loss : 2.1849933090513827
Epoch loss : 2.1802241941690443


 21%|██        | 207/1000 [00:35<02:26,  5.39it/s]

Epoch loss : 2.1754811980597024
Epoch loss : 2.1707651326132282


 21%|██        | 209/1000 [00:35<02:28,  5.32it/s]

Epoch loss : 2.1660752428953463
Epoch loss : 2.161411411528382


 21%|██        | 211/1000 [00:35<02:28,  5.30it/s]

Epoch loss : 2.1567728026821498
Epoch loss : 2.152159385144428


 21%|██▏       | 213/1000 [00:36<02:28,  5.31it/s]

Epoch loss : 2.1475717562577636
Epoch loss : 2.1430082382590556


 22%|██▏       | 215/1000 [00:36<02:27,  5.32it/s]

Epoch loss : 2.1384702489754863
Epoch loss : 2.133956446789032


 22%|██▏       | 217/1000 [00:36<02:26,  5.33it/s]

Epoch loss : 2.1294669600879703
Epoch loss : 2.125001620622824


 22%|██▏       | 219/1000 [00:37<02:27,  5.31it/s]

Epoch loss : 2.1205607157590194
Epoch loss : 2.116143144431724


 22%|██▏       | 221/1000 [00:37<02:22,  5.47it/s]

Epoch loss : 2.1117486561292953
Epoch loss : 2.1073787484098885


 22%|██▏       | 223/1000 [00:37<02:14,  5.79it/s]

Epoch loss : 2.103031404222454
Epoch loss : 2.098706601780626


 22%|██▎       | 225/1000 [00:38<02:10,  5.95it/s]

Epoch loss : 2.094405264284994
Epoch loss : 2.090126680217849


 23%|██▎       | 227/1000 [00:38<02:05,  6.14it/s]

Epoch loss : 2.085870565054691
Epoch loss : 2.081637157172884


 23%|██▎       | 229/1000 [00:38<02:02,  6.31it/s]

Epoch loss : 2.077425892394886
Epoch loss : 2.0732362762002445


 23%|██▎       | 231/1000 [00:39<02:00,  6.40it/s]

Epoch loss : 2.0690682840787846
Epoch loss : 2.064922449681666


 23%|██▎       | 233/1000 [00:39<02:01,  6.31it/s]

Epoch loss : 2.060797491641394
Epoch loss : 2.0566937204287288


 24%|██▎       | 235/1000 [00:39<01:59,  6.41it/s]

Epoch loss : 2.0526109593673647
Epoch loss : 2.048549589486832


 24%|██▎       | 237/1000 [00:40<01:57,  6.51it/s]

Epoch loss : 2.044509028986854
Epoch loss : 2.040489462827831


 24%|██▍       | 239/1000 [00:40<01:56,  6.51it/s]

Epoch loss : 2.0364894622238743
Epoch loss : 2.0325104689872413


 24%|██▍       | 241/1000 [00:40<01:56,  6.51it/s]

Epoch loss : 2.0285511594762404
Epoch loss : 2.0246119140811976


 24%|██▍       | 243/1000 [00:41<01:56,  6.47it/s]

Epoch loss : 2.0206921528439876
Epoch loss : 2.0167923796716543


 24%|██▍       | 245/1000 [00:41<01:57,  6.42it/s]

Epoch loss : 2.012911947218121
Epoch loss : 2.009051558482404


 25%|██▍       | 247/1000 [00:41<01:58,  6.35it/s]

Epoch loss : 2.005209945098656
Epoch loss : 2.0013879395882612


 25%|██▍       | 249/1000 [00:42<01:57,  6.38it/s]

Epoch loss : 1.9975845462252055
Epoch loss : 1.9938001775167074


 25%|██▌       | 251/1000 [00:42<01:57,  6.35it/s]

Epoch loss : 1.9900344969415664
Epoch loss : 1.986287476681618


 25%|██▌       | 253/1000 [00:42<01:58,  6.33it/s]

Epoch loss : 1.9825589154708008
Epoch loss : 1.9788483013937124


 26%|██▌       | 255/1000 [00:42<01:56,  6.39it/s]

Epoch loss : 1.975155965258756
Epoch loss : 1.9714819390890646


 26%|██▌       | 257/1000 [00:43<01:58,  6.28it/s]

Epoch loss : 1.9678252969821914
Epoch loss : 1.964186840704443


 26%|██▌       | 259/1000 [00:43<01:58,  6.26it/s]

Epoch loss : 1.960566240939521
Epoch loss : 1.9569631748540062


 26%|██▌       | 261/1000 [00:43<01:55,  6.41it/s]

Epoch loss : 1.953377260024731
Epoch loss : 1.9498090475226728


 26%|██▋       | 263/1000 [00:44<01:53,  6.47it/s]

Epoch loss : 1.946257384621915
Epoch loss : 1.9427229445810101


 26%|██▋       | 264/1000 [00:44<01:53,  6.49it/s]

Epoch loss : 1.9392058031464166


 27%|██▋       | 266/1000 [00:44<02:13,  5.48it/s]

Epoch loss : 1.9357051305028627
Epoch loss : 1.9322211648512604


 27%|██▋       | 268/1000 [00:45<02:04,  5.89it/s]

Epoch loss : 1.9287533411751971
Epoch loss : 1.925301936533469


 27%|██▋       | 270/1000 [00:45<02:15,  5.37it/s]

Epoch loss : 1.9218671454638796
Epoch loss : 1.9184482444745523


 27%|██▋       | 272/1000 [00:45<02:06,  5.77it/s]

Epoch loss : 1.9150458865500024
Epoch loss : 1.9116590197976022


 27%|██▋       | 274/1000 [00:46<01:59,  6.07it/s]

Epoch loss : 1.9082886049599002
Epoch loss : 1.9049336680748168


 28%|██▊       | 276/1000 [00:46<01:56,  6.21it/s]

Epoch loss : 1.9015945122718811
Epoch loss : 1.898270310299552


 28%|██▊       | 278/1000 [00:46<01:54,  6.31it/s]

Epoch loss : 1.8949621889350217
Epoch loss : 1.8916689280082855


 28%|██▊       | 280/1000 [00:47<01:52,  6.42it/s]

Epoch loss : 1.8883916071388456
Epoch loss : 1.885129134808268


 28%|██▊       | 282/1000 [00:47<01:51,  6.41it/s]

Epoch loss : 1.881881804273222
Epoch loss : 1.8786489175207226


 28%|██▊       | 284/1000 [00:47<02:05,  5.71it/s]

Epoch loss : 1.8754313263455045
Epoch loss : 1.8722280415955563


 29%|██▊       | 286/1000 [00:48<02:08,  5.57it/s]

Epoch loss : 1.8690397607799163
Epoch loss : 1.8658658855361538


 29%|██▉       | 288/1000 [00:48<02:08,  5.53it/s]

Epoch loss : 1.8627067567828641
Epoch loss : 1.8595615359892448


 29%|██▉       | 290/1000 [00:48<02:11,  5.41it/s]

Epoch loss : 1.8564308021101572
Epoch loss : 1.8533140817202371


 29%|██▉       | 292/1000 [00:49<02:09,  5.45it/s]

Epoch loss : 1.8502118018300262
Epoch loss : 1.8471232867302143


 29%|██▉       | 294/1000 [00:49<02:10,  5.42it/s]

Epoch loss : 1.8440490916890089
Epoch loss : 1.8409879940423837


 30%|██▉       | 296/1000 [00:50<02:13,  5.28it/s]

Epoch loss : 1.837940743959556
Epoch loss : 1.8349068646132947


 30%|██▉       | 298/1000 [00:50<02:12,  5.30it/s]

Epoch loss : 1.8318871606339509
Epoch loss : 1.82888017064773


 30%|███       | 300/1000 [00:50<02:12,  5.27it/s]

Epoch loss : 1.8258870177304865
Epoch loss : 1.8229065101881823


 30%|███       | 302/1000 [00:51<02:03,  5.65it/s]

Epoch loss : 1.8199394691049855
Epoch loss : 1.8169857174908088


 30%|███       | 304/1000 [00:51<01:57,  5.92it/s]

Epoch loss : 1.8140451712970294
Epoch loss : 1.8111169500394086


 31%|███       | 306/1000 [00:51<01:55,  6.00it/s]

Epoch loss : 1.8082017318186212
Epoch loss : 1.8052993351532742


 31%|███       | 308/1000 [00:52<01:53,  6.08it/s]

Epoch loss : 1.8024093682311646
Epoch loss : 1.7995322513251335


 31%|███       | 310/1000 [00:52<01:50,  6.25it/s]

Epoch loss : 1.796667800561124
Epoch loss : 1.7938153687965486


 31%|███       | 312/1000 [00:52<01:48,  6.35it/s]

Epoch loss : 1.7909756407956219
Epoch loss : 1.7881475723821383


 31%|███▏      | 314/1000 [00:53<01:47,  6.40it/s]

Epoch loss : 1.785332606151081
Epoch loss : 1.7825289748675504


 32%|███▏      | 316/1000 [00:53<01:45,  6.47it/s]

Epoch loss : 1.7797378657734584
Epoch loss : 1.7769582452106325


 32%|███▏      | 318/1000 [00:53<01:44,  6.50it/s]

Epoch loss : 1.774191332466971
Epoch loss : 1.7714357051759395


 32%|███▏      | 320/1000 [00:53<01:43,  6.54it/s]

Epoch loss : 1.7686917968229814
Epoch loss : 1.765959402002394


 32%|███▏      | 322/1000 [00:54<01:43,  6.58it/s]

Epoch loss : 1.7632386827450304
Epoch loss : 1.7605297438013627


 32%|███▏      | 324/1000 [00:54<01:43,  6.54it/s]

Epoch loss : 1.7578318244680162
Epoch loss : 1.7551453742771237


 33%|███▎      | 326/1000 [00:54<01:43,  6.49it/s]

Epoch loss : 1.752470486993056
Epoch loss : 1.7498070038041453


 33%|███▎      | 328/1000 [00:55<01:43,  6.47it/s]

Epoch loss : 1.7471547254856208
Epoch loss : 1.744513259949844


 33%|███▎      | 330/1000 [00:55<01:44,  6.40it/s]

Epoch loss : 1.7418830182396532
Epoch loss : 1.7392634184541125


 33%|███▎      | 332/1000 [00:55<01:43,  6.44it/s]

Epoch loss : 1.7366551958001992
Epoch loss : 1.734057342216193


 33%|███▎      | 334/1000 [00:56<01:46,  6.28it/s]

Epoch loss : 1.7314705386587808
Epoch loss : 1.7288941700283638


 34%|███▎      | 336/1000 [00:56<01:46,  6.25it/s]

Epoch loss : 1.7263287769129028
Epoch loss : 1.723773679367843


 34%|███▍      | 338/1000 [00:56<01:44,  6.34it/s]

Epoch loss : 1.7212292484713592
Epoch loss : 1.7186950597484436


 34%|███▍      | 340/1000 [00:57<01:44,  6.34it/s]

Epoch loss : 1.7161711759500446
Epoch loss : 1.713657628054128


 34%|███▍      | 342/1000 [00:57<01:43,  6.36it/s]

Epoch loss : 1.7111544873515183
Epoch loss : 1.70866187945444


 34%|███▍      | 344/1000 [00:57<01:42,  6.39it/s]

Epoch loss : 1.7061793258221434
Epoch loss : 1.7037067184770522


 35%|███▍      | 346/1000 [00:58<01:43,  6.33it/s]

Epoch loss : 1.7012443541046502
Epoch loss : 1.698791473985063


 35%|███▍      | 348/1000 [00:58<01:41,  6.41it/s]

Epoch loss : 1.696348885694567
Epoch loss : 1.6939160385779266


 35%|███▌      | 350/1000 [00:58<01:42,  6.36it/s]

Epoch loss : 1.6914928997399814
Epoch loss : 1.689079232258456


 35%|███▌      | 352/1000 [00:58<01:43,  6.25it/s]

Epoch loss : 1.686675112492848
Epoch loss : 1.6842808010158214


 35%|███▌      | 354/1000 [00:59<01:41,  6.39it/s]

Epoch loss : 1.6818959541276899
Epoch loss : 1.679520876002682


 36%|███▌      | 356/1000 [00:59<01:39,  6.45it/s]

Epoch loss : 1.6771552978796018
Epoch loss : 1.6747992008212913


 36%|███▌      | 358/1000 [00:59<01:39,  6.42it/s]

Epoch loss : 1.6724522458271487
Epoch loss : 1.6701146504649236


 36%|███▌      | 360/1000 [01:00<01:40,  6.36it/s]

Epoch loss : 1.667786206338233
Epoch loss : 1.6654666647157734


 36%|███▌      | 362/1000 [01:00<01:40,  6.37it/s]

Epoch loss : 1.6631566639644948
Epoch loss : 1.660855518965596


 36%|███▋      | 364/1000 [01:00<01:40,  6.31it/s]

Epoch loss : 1.658563677186644
Epoch loss : 1.6562804949381849


 37%|███▋      | 366/1000 [01:01<01:52,  5.65it/s]

Epoch loss : 1.6540065424515777
Epoch loss : 1.6517413268828653


 37%|███▋      | 367/1000 [01:01<01:54,  5.53it/s]

Epoch loss : 1.649484998124172
Epoch loss : 1.6472376023265332

 37%|███▋      | 369/1000 [01:01<01:57,  5.36it/s]


Epoch loss : 1.6449985455545952


 37%|███▋      | 370/1000 [01:02<02:01,  5.20it/s]

Epoch loss : 1.6427682985019039


 37%|███▋      | 372/1000 [01:02<02:02,  5.15it/s]

Epoch loss : 1.6405466677083802
Epoch loss : 1.6383338129127858


 37%|███▋      | 374/1000 [01:02<02:01,  5.15it/s]

Epoch loss : 1.6361292093353041
Epoch loss : 1.633932964841631


 38%|███▊      | 376/1000 [01:03<02:00,  5.17it/s]

Epoch loss : 1.6317451172947883
Epoch loss : 1.6295655703322685


 38%|███▊      | 378/1000 [01:03<02:00,  5.17it/s]

Epoch loss : 1.6273944306294545
Epoch loss : 1.625231301132491


 38%|███▊      | 380/1000 [01:03<01:59,  5.17it/s]

Epoch loss : 1.6230769457422334
Epoch loss : 1.6209304110839178


 38%|███▊      | 382/1000 [01:04<01:57,  5.28it/s]

Epoch loss : 1.6187922867960505
Epoch loss : 1.6166620774231657


 38%|███▊      | 384/1000 [01:04<01:48,  5.68it/s]

Epoch loss : 1.6145404065254463
Epoch loss : 1.612426353385672


 39%|███▊      | 386/1000 [01:05<01:42,  6.01it/s]

Epoch loss : 1.6103204399765312
Epoch loss : 1.608222207677951


 39%|███▉      | 388/1000 [01:05<01:39,  6.13it/s]

Epoch loss : 1.606131830321726
Epoch loss : 1.6040491612270935


 39%|███▉      | 390/1000 [01:05<01:36,  6.31it/s]

Epoch loss : 1.6019744378688403
Epoch loss : 1.5999074415679162


 39%|███▉      | 392/1000 [01:05<01:36,  6.32it/s]

Epoch loss : 1.5978483126703125
Epoch loss : 1.5957966830528207


 39%|███▉      | 394/1000 [01:06<01:36,  6.30it/s]

Epoch loss : 1.5937528064050748
Epoch loss : 1.5917161605072203


 40%|███▉      | 396/1000 [01:06<01:37,  6.22it/s]

Epoch loss : 1.5896874367655078
Epoch loss : 1.5876660008276955


 40%|███▉      | 398/1000 [01:06<01:35,  6.32it/s]

Epoch loss : 1.585652069345079
Epoch loss : 1.5836453409627753


 40%|████      | 400/1000 [01:07<01:34,  6.33it/s]

Epoch loss : 1.5816464142960713
Epoch loss : 1.5796545321822166


 40%|████      | 402/1000 [01:07<01:33,  6.41it/s]

Epoch loss : 1.5776699668704126
Epoch loss : 1.5756925088752858


 40%|████      | 404/1000 [01:07<01:32,  6.43it/s]

Epoch loss : 1.5737222532709834
Epoch loss : 1.5717594557196493


 41%|████      | 406/1000 [01:08<01:33,  6.37it/s]

Epoch loss : 1.5698038339062974
Epoch loss : 1.5678552741282092


 41%|████      | 408/1000 [01:08<01:31,  6.44it/s]

Epoch loss : 1.5659137710225963
Epoch loss : 1.5639789662283718


 41%|████      | 410/1000 [01:08<01:30,  6.54it/s]

Epoch loss : 1.5620516087526506
Epoch loss : 1.5601305790921536


 41%|████      | 412/1000 [01:09<01:29,  6.55it/s]

Epoch loss : 1.5582169826781953
Epoch loss : 1.5563100058759012


 41%|████▏     | 414/1000 [01:09<01:30,  6.48it/s]

Epoch loss : 1.5544103242193528
Epoch loss : 1.5525169283417977


 42%|████▏     | 416/1000 [01:09<01:29,  6.55it/s]

Epoch loss : 1.5506305089937635
Epoch loss : 1.548750783817556


 42%|████▏     | 418/1000 [01:10<01:29,  6.50it/s]

Epoch loss : 1.546877772966735
Epoch loss : 1.5450113841620359


 42%|████▏     | 420/1000 [01:10<01:29,  6.51it/s]

Epoch loss : 1.5431514586547395
Epoch loss : 1.5412983613965057


 42%|████▏     | 422/1000 [01:10<01:29,  6.45it/s]

Epoch loss : 1.539451656679926
Epoch loss : 1.5376115444945215


 42%|████▏     | 424/1000 [01:10<01:28,  6.48it/s]

Epoch loss : 1.5357780168446806
Epoch loss : 1.5339507628126807


 43%|████▎     | 426/1000 [01:11<01:28,  6.48it/s]

Epoch loss : 1.5321300516717573
Epoch loss : 1.5303156661210766


 43%|████▎     | 428/1000 [01:11<01:29,  6.38it/s]

Epoch loss : 1.5285079285781613
Epoch loss : 1.5267060628157352


 43%|████▎     | 430/1000 [01:11<01:30,  6.33it/s]

Epoch loss : 1.524910795227353
Epoch loss : 1.523121587896763


 43%|████▎     | 432/1000 [01:12<01:29,  6.34it/s]

Epoch loss : 1.5213386330195204
Epoch loss : 1.519561697723413


 43%|████▎     | 434/1000 [01:12<01:27,  6.46it/s]

Epoch loss : 1.5177911769827575
Epoch loss : 1.5160264565442014


 44%|████▎     | 436/1000 [01:12<01:27,  6.47it/s]

Epoch loss : 1.5142678261234843
Epoch loss : 1.5125154110051076


 44%|████▍     | 438/1000 [01:13<01:26,  6.53it/s]

Epoch loss : 1.510768905712347
Epoch loss : 1.5090284339763802


 44%|████▍     | 440/1000 [01:13<01:26,  6.45it/s]

Epoch loss : 1.5072940118722873
Epoch loss : 1.5055656904876231


 44%|████▍     | 442/1000 [01:13<01:25,  6.50it/s]

Epoch loss : 1.5038434623509578
Epoch loss : 1.5021269560797452


 44%|████▍     | 444/1000 [01:14<01:25,  6.54it/s]

Epoch loss : 1.500416793229214
Epoch loss : 1.4987118791063896


 45%|████▍     | 446/1000 [01:14<01:28,  6.25it/s]

Epoch loss : 1.4970129544406794
Epoch loss : 1.4953194859055927


 45%|████▍     | 448/1000 [01:14<01:35,  5.79it/s]

Epoch loss : 1.493632044960855
Epoch loss : 1.4919500754946577


 45%|████▌     | 450/1000 [01:15<01:37,  5.63it/s]

Epoch loss : 1.49027387290662
Epoch loss : 1.4886031708319982


 45%|████▌     | 452/1000 [01:15<01:40,  5.47it/s]

Epoch loss : 1.48693814128373
Epoch loss : 1.4852786730935352


 45%|████▌     | 454/1000 [01:15<01:43,  5.28it/s]

Epoch loss : 1.4836247748807565
Epoch loss : 1.4819764445544865


 46%|████▌     | 456/1000 [01:16<01:42,  5.31it/s]

Epoch loss : 1.480333580681911
Epoch loss : 1.478696119654597


 46%|████▌     | 458/1000 [01:16<01:43,  5.26it/s]

Epoch loss : 1.4770641617832352
Epoch loss : 1.475437385100855


 46%|████▌     | 459/1000 [01:16<01:42,  5.28it/s]

Epoch loss : 1.4738163184833941


 46%|████▌     | 461/1000 [01:17<01:43,  5.21it/s]

Epoch loss : 1.4722003319859505
Epoch loss : 1.4705901220486117


 46%|████▌     | 462/1000 [01:17<01:43,  5.18it/s]

Epoch loss : 1.46898492998246


 46%|████▋     | 463/1000 [01:17<01:46,  5.06it/s]

Epoch loss : 1.4673852077731562
Epoch loss : 1.4657905712775114

 46%|████▋     | 465/1000 [01:17<01:40,  5.31it/s]


Epoch loss : 1.464201180035709


 47%|████▋     | 467/1000 [01:18<01:33,  5.71it/s]

Epoch loss : 1.4626171675081416
Epoch loss : 1.4610383338648707


 47%|████▋     | 469/1000 [01:18<01:29,  5.92it/s]

Epoch loss : 1.4594649628581655
Epoch loss : 1.4578963802453042


 47%|████▋     | 471/1000 [01:18<01:25,  6.17it/s]

Epoch loss : 1.4563330462815913
Epoch loss : 1.4547748449748488


 47%|████▋     | 473/1000 [01:19<01:24,  6.21it/s]

Epoch loss : 1.4532217068231459
Epoch loss : 1.4516737327382883


 48%|████▊     | 475/1000 [01:19<01:25,  6.11it/s]

Epoch loss : 1.4501302844430575
Epoch loss : 1.4485921140733518


 48%|████▊     | 477/1000 [01:19<01:24,  6.16it/s]

Epoch loss : 1.4470587972351232
Epoch loss : 1.4455308054753069


 48%|████▊     | 479/1000 [01:20<01:22,  6.31it/s]

Epoch loss : 1.4440072544679732
Epoch loss : 1.4424887713609509


 48%|████▊     | 481/1000 [01:20<01:20,  6.41it/s]

Epoch loss : 1.4409749247487635
Epoch loss : 1.4394659549567903


 48%|████▊     | 483/1000 [01:20<01:21,  6.34it/s]

Epoch loss : 1.437961797887359
Epoch loss : 1.436462593000868


 48%|████▊     | 485/1000 [01:21<01:21,  6.33it/s]

Epoch loss : 1.434968211839268
Epoch loss : 1.4334787256680812


 49%|████▊     | 487/1000 [01:21<01:20,  6.36it/s]

Epoch loss : 1.4319939008263158
Epoch loss : 1.4305138166343652


 49%|████▉     | 489/1000 [01:21<01:22,  6.21it/s]

Epoch loss : 1.4290383023249567
Epoch loss : 1.4275674652279275


 49%|████▉     | 491/1000 [01:22<01:23,  6.10it/s]

Epoch loss : 1.4261011436466051
Epoch loss : 1.4246396761968763


 49%|████▉     | 493/1000 [01:22<01:22,  6.18it/s]

Epoch loss : 1.4231826087553812
Epoch loss : 1.4217307229349143


 50%|████▉     | 495/1000 [01:22<01:22,  6.16it/s]

Epoch loss : 1.4202829769814787
Epoch loss : 1.4188397289183405


 50%|████▉     | 497/1000 [01:23<01:20,  6.22it/s]

Epoch loss : 1.4174010046769774
Epoch loss : 1.4159667394092865


 50%|████▉     | 499/1000 [01:23<01:20,  6.24it/s]

Epoch loss : 1.4145368938884104
Epoch loss : 1.413111414151464


 50%|█████     | 501/1000 [01:23<01:19,  6.26it/s]

Epoch loss : 1.411690513037443
Epoch loss : 1.4102738855794041


 50%|█████     | 503/1000 [01:24<01:19,  6.24it/s]

Epoch loss : 1.408861716272109
Epoch loss : 1.4074537157846492


 50%|█████     | 505/1000 [01:24<01:19,  6.24it/s]

Epoch loss : 1.4060504446894166
Epoch loss : 1.404651449959467


 51%|█████     | 507/1000 [01:24<01:17,  6.35it/s]

Epoch loss : 1.4032566993333133
Epoch loss : 1.4018663319892433


 51%|█████     | 509/1000 [01:25<01:16,  6.44it/s]

Epoch loss : 1.4004803369065204
Epoch loss : 1.3990989634230704


 51%|█████     | 511/1000 [01:25<01:16,  6.42it/s]

Epoch loss : 1.3977215035248036
Epoch loss : 1.3963483158383818


 51%|█████▏    | 513/1000 [01:25<01:16,  6.37it/s]

Epoch loss : 1.3949789860908641
Epoch loss : 1.393613794190377


 52%|█████▏    | 515/1000 [01:25<01:16,  6.36it/s]

Epoch loss : 1.392252845989707
Epoch loss : 1.3908959434773158


 52%|█████▏    | 517/1000 [01:26<01:15,  6.39it/s]

Epoch loss : 1.389543332719179
Epoch loss : 1.3881947876550473


 52%|█████▏    | 519/1000 [01:26<01:16,  6.31it/s]

Epoch loss : 1.3868503409474513
Epoch loss : 1.3855099164555298


 52%|█████▏    | 521/1000 [01:26<01:17,  6.19it/s]

Epoch loss : 1.384173453582021
Epoch loss : 1.3828410898666692


 52%|█████▏    | 523/1000 [01:27<01:17,  6.17it/s]

Epoch loss : 1.3815125650709845
Epoch loss : 1.3801881342368418


 52%|█████▎    | 525/1000 [01:27<01:17,  6.10it/s]

Epoch loss : 1.3788676281853487
Epoch loss : 1.377551548153446


 53%|█████▎    | 527/1000 [01:27<01:22,  5.72it/s]

Epoch loss : 1.376239020210822
Epoch loss : 1.3749310226537697


 53%|█████▎    | 529/1000 [01:28<01:28,  5.31it/s]

Epoch loss : 1.373626524865853
Epoch loss : 1.372325994469831


 53%|█████▎    | 531/1000 [01:28<01:29,  5.25it/s]

Epoch loss : 1.3710291874155682
Epoch loss : 1.3697360564501722


 53%|█████▎    | 533/1000 [01:29<01:28,  5.27it/s]

Epoch loss : 1.3684467130540905
Epoch loss : 1.3671611906803347


 54%|█████▎    | 535/1000 [01:29<01:28,  5.24it/s]

Epoch loss : 1.3658794643905725
Epoch loss : 1.364601448469073


 54%|█████▎    | 536/1000 [01:29<01:28,  5.27it/s]

Epoch loss : 1.363327225238752


 54%|█████▍    | 538/1000 [01:30<01:29,  5.15it/s]

Epoch loss : 1.3620567580291678
Epoch loss : 1.3607901697863434


 54%|█████▍    | 540/1000 [01:30<01:29,  5.14it/s]

Epoch loss : 1.3595273866812683
Epoch loss : 1.358268411314046


 54%|█████▍    | 542/1000 [01:30<01:29,  5.14it/s]

Epoch loss : 1.3570131074798306
Epoch loss : 1.3557613770520776


 54%|█████▍    | 544/1000 [01:31<01:29,  5.08it/s]

Epoch loss : 1.354513404551022
Epoch loss : 1.3532689936935682


 55%|█████▍    | 546/1000 [01:31<01:24,  5.38it/s]

Epoch loss : 1.3520284186060276
Epoch loss : 1.3507914062592137


 55%|█████▍    | 548/1000 [01:31<01:20,  5.61it/s]

Epoch loss : 1.3495578827184558
Epoch loss : 1.3483279190065651


 55%|█████▌    | 550/1000 [01:32<01:16,  5.89it/s]

Epoch loss : 1.3471012039773016
Epoch loss : 1.3458782378283414


 55%|█████▌    | 552/1000 [01:32<01:14,  6.03it/s]

Epoch loss : 1.3446585462746732
Epoch loss : 1.3434423482272289


 55%|█████▌    | 554/1000 [01:32<01:13,  6.09it/s]

Epoch loss : 1.342229586396144
Epoch loss : 1.3410201919573739


 56%|█████▌    | 556/1000 [01:33<01:12,  6.09it/s]

Epoch loss : 1.339814472410056
Epoch loss : 1.3386121467109635


 56%|█████▌    | 558/1000 [01:33<01:11,  6.17it/s]

Epoch loss : 1.3374132627478836
Epoch loss : 1.336217772414821


 56%|█████▌    | 560/1000 [01:33<01:11,  6.12it/s]

Epoch loss : 1.3350259466983765
Epoch loss : 1.333837367461196


 56%|█████▌    | 562/1000 [01:34<01:10,  6.17it/s]

Epoch loss : 1.3326525743791764
Epoch loss : 1.3314709389909731


 56%|█████▋    | 564/1000 [01:34<01:10,  6.19it/s]

Epoch loss : 1.3302929448827334
Epoch loss : 1.329118034535266


 57%|█████▋    | 566/1000 [01:34<01:10,  6.17it/s]

Epoch loss : 1.3279461566378585
Epoch loss : 1.3267777869500246


 57%|█████▋    | 568/1000 [01:35<01:09,  6.21it/s]

Epoch loss : 1.3256125405157475
Epoch loss : 1.324451035222327


 57%|█████▋    | 570/1000 [01:35<01:07,  6.39it/s]

Epoch loss : 1.3232923391304243
Epoch loss : 1.3221369832748906


 57%|█████▋    | 572/1000 [01:35<01:06,  6.45it/s]

Epoch loss : 1.3209845443777153
Epoch loss : 1.3198356203646302


 57%|█████▋    | 574/1000 [01:36<01:06,  6.37it/s]

Epoch loss : 1.3186897494822065
Epoch loss : 1.3175472190851534


 58%|█████▊    | 576/1000 [01:36<01:06,  6.33it/s]

Epoch loss : 1.316407869514175
Epoch loss : 1.3152716370283937


 58%|█████▊    | 578/1000 [01:36<01:06,  6.39it/s]

Epoch loss : 1.3141387070816146
Epoch loss : 1.3130088491995648


 58%|█████▊    | 580/1000 [01:37<01:06,  6.35it/s]

Epoch loss : 1.3118826469377534
Epoch loss : 1.3107592473770011


 58%|█████▊    | 582/1000 [01:37<01:06,  6.28it/s]

Epoch loss : 1.3096388968050787
Epoch loss : 1.3085212592903486


 58%|█████▊    | 584/1000 [01:37<01:06,  6.23it/s]

Epoch loss : 1.3074067610998097
Epoch loss : 1.3062952406189008


 59%|█████▊    | 586/1000 [01:38<01:06,  6.23it/s]

Epoch loss : 1.3051868326312457
Epoch loss : 1.3040815530529202


 59%|█████▉    | 588/1000 [01:38<01:05,  6.30it/s]

Epoch loss : 1.302979563495471
Epoch loss : 1.3018805965233822


 59%|█████▉    | 590/1000 [01:38<01:06,  6.19it/s]

Epoch loss : 1.3007843959437768
Epoch loss : 1.2996911468096708


 59%|█████▉    | 592/1000 [01:39<01:06,  6.15it/s]

Epoch loss : 1.2986010829123547
Epoch loss : 1.2975137478646797


 59%|█████▉    | 594/1000 [01:39<01:04,  6.29it/s]

Epoch loss : 1.296429434937329
Epoch loss : 1.295347988933626


 60%|█████▉    | 596/1000 [01:39<01:03,  6.33it/s]

Epoch loss : 1.294269649951398
Epoch loss : 1.2931943862990245


 60%|█████▉    | 598/1000 [01:39<01:04,  6.23it/s]

Epoch loss : 1.2921220109540614
Epoch loss : 1.2910523485399807


 60%|██████    | 600/1000 [01:40<01:04,  6.24it/s]

Epoch loss : 1.28998551974611
Epoch loss : 1.2889215804621577


 60%|██████    | 602/1000 [01:40<01:03,  6.23it/s]

Epoch loss : 1.2878605231419578
Epoch loss : 1.2868022042453486


 60%|██████    | 604/1000 [01:40<01:02,  6.38it/s]

Epoch loss : 1.2857468545135375
Epoch loss : 1.2846940580814683


 61%|██████    | 606/1000 [01:41<01:02,  6.29it/s]

Epoch loss : 1.2836444122687845
Epoch loss : 1.2825973035768905


 61%|██████    | 608/1000 [01:41<01:07,  5.80it/s]

Epoch loss : 1.2815534085010776
Epoch loss : 1.2805119658281143


 61%|██████    | 610/1000 [01:42<01:11,  5.43it/s]

Epoch loss : 1.2794736921562153
Epoch loss : 1.2784379217023731


 61%|██████    | 612/1000 [01:42<01:13,  5.28it/s]

Epoch loss : 1.2774047839319647
Epoch loss : 1.2763744382662516


 61%|██████▏   | 614/1000 [01:42<01:13,  5.23it/s]

Epoch loss : 1.2753465933570272
Epoch loss : 1.274321579981219


 62%|██████▏   | 616/1000 [01:43<01:13,  5.23it/s]

Epoch loss : 1.2732991919415753
Epoch loss : 1.2722796371879128


 62%|██████▏   | 618/1000 [01:43<01:12,  5.25it/s]

Epoch loss : 1.2712626595855532
Epoch loss : 1.2702484423983058


 62%|██████▏   | 620/1000 [01:43<01:13,  5.20it/s]

Epoch loss : 1.2692368055829517
Epoch loss : 1.2682278220595853


 62%|██████▏   | 621/1000 [01:44<01:14,  5.10it/s]

Epoch loss : 1.2672214210844461


 62%|██████▏   | 623/1000 [01:44<01:15,  4.97it/s]

Epoch loss : 1.266217592585221
Epoch loss : 1.2652165102657307


 62%|██████▎   | 625/1000 [01:44<01:13,  5.11it/s]

Epoch loss : 1.264217950100891
Epoch loss : 1.2632223015375137


 63%|██████▎   | 627/1000 [01:45<01:07,  5.50it/s]

Epoch loss : 1.262229245050837
Epoch loss : 1.261238799008171


 63%|██████▎   | 629/1000 [01:45<01:03,  5.81it/s]

Epoch loss : 1.2602507370824267
Epoch loss : 1.259265437001549


 63%|██████▎   | 631/1000 [01:45<01:01,  6.04it/s]

Epoch loss : 1.2582825017434265
Epoch loss : 1.2573022536411527


 63%|██████▎   | 633/1000 [01:46<00:59,  6.15it/s]

Epoch loss : 1.2563242388950495
Epoch loss : 1.2553489275377885


 64%|██████▎   | 635/1000 [01:46<00:57,  6.32it/s]

Epoch loss : 1.2543759190627266
Epoch loss : 1.2534054928895995


 64%|██████▎   | 637/1000 [01:46<00:57,  6.31it/s]

Epoch loss : 1.2524375286923264
Epoch loss : 1.2514719656390523


 64%|██████▍   | 639/1000 [01:47<00:57,  6.28it/s]

Epoch loss : 1.2505091523269314
Epoch loss : 1.2495485263204724


 64%|██████▍   | 641/1000 [01:47<00:57,  6.28it/s]

Epoch loss : 1.2485904451394454
Epoch loss : 1.247634573754115


 64%|██████▍   | 643/1000 [01:47<00:57,  6.26it/s]

Epoch loss : 1.2466811984146124
Epoch loss : 1.245730254156037


 64%|██████▍   | 645/1000 [01:48<00:55,  6.34it/s]

Epoch loss : 1.2447816078397242
Epoch loss : 1.2438355698035668


 65%|██████▍   | 647/1000 [01:48<00:57,  6.15it/s]

Epoch loss : 1.2428916687059328
Epoch loss : 1.2419503975554818


 65%|██████▍   | 649/1000 [01:48<00:56,  6.16it/s]

Epoch loss : 1.2410113805325496
Epoch loss : 1.2400748430667747


 65%|██████▌   | 651/1000 [01:49<00:56,  6.16it/s]

Epoch loss : 1.2391406316133646
Epoch loss : 1.2382089021545584


 65%|██████▌   | 653/1000 [01:49<00:58,  5.96it/s]

Epoch loss : 1.2372795825844711
Epoch loss : 1.2363528341844472


 66%|██████▌   | 655/1000 [01:49<00:56,  6.09it/s]

Epoch loss : 1.2354281172252028
Epoch loss : 1.234505807174981


 66%|██████▌   | 657/1000 [01:50<00:55,  6.18it/s]

Epoch loss : 1.2335858489500313
Epoch loss : 1.232668046352526


 66%|██████▌   | 659/1000 [01:50<00:55,  6.17it/s]

Epoch loss : 1.2317525059949481
Epoch loss : 1.2308391473497715


 66%|██████▌   | 661/1000 [01:50<00:55,  6.14it/s]

Epoch loss : 1.229928040786223
Epoch loss : 1.2290191397927532


 66%|██████▋   | 663/1000 [01:51<00:54,  6.21it/s]

Epoch loss : 1.228112715633524
Epoch loss : 1.2272084841370403


 66%|██████▋   | 665/1000 [01:51<00:53,  6.23it/s]

Epoch loss : 1.2263064297405353
Epoch loss : 1.2254063879291814


 67%|██████▋   | 667/1000 [01:51<00:53,  6.19it/s]

Epoch loss : 1.2245086101207647
Epoch loss : 1.2236129218045024


 67%|██████▋   | 669/1000 [01:52<00:53,  6.23it/s]

Epoch loss : 1.2227194740275245
Epoch loss : 1.221828198806231


 67%|██████▋   | 671/1000 [01:52<00:53,  6.13it/s]

Epoch loss : 1.2209393306156593
Epoch loss : 1.2200525336569361


 67%|██████▋   | 673/1000 [01:52<00:53,  6.15it/s]

Epoch loss : 1.2191681080466756
Epoch loss : 1.2182858080947205


 68%|██████▊   | 675/1000 [01:53<00:52,  6.23it/s]

Epoch loss : 1.2174058716936529
Epoch loss : 1.2165277662767304


 68%|██████▊   | 677/1000 [01:53<00:51,  6.32it/s]

Epoch loss : 1.2156520895655516
Epoch loss : 1.214778442083909


 68%|██████▊   | 679/1000 [01:53<00:50,  6.39it/s]

Epoch loss : 1.213906858897051
Epoch loss : 1.2130372898813901


 68%|██████▊   | 681/1000 [01:53<00:50,  6.36it/s]

Epoch loss : 1.2121699757860864
Epoch loss : 1.2113048699616336


 68%|██████▊   | 683/1000 [01:54<00:50,  6.26it/s]

Epoch loss : 1.210441580198902
Epoch loss : 1.2095803397143394


 68%|██████▊   | 685/1000 [01:54<00:50,  6.29it/s]

Epoch loss : 1.2087212205068236
Epoch loss : 1.207864137760914


 69%|██████▊   | 687/1000 [01:54<00:52,  5.94it/s]

Epoch loss : 1.2070094470866568
Epoch loss : 1.2061564625398387


 69%|██████▉   | 689/1000 [01:55<00:55,  5.56it/s]

Epoch loss : 1.2053053462267096
Epoch loss : 1.2044563613384314


 69%|██████▉   | 691/1000 [01:55<00:57,  5.39it/s]

Epoch loss : 1.2036093211428842
Epoch loss : 1.2027646420415161


 69%|██████▉   | 692/1000 [01:55<00:58,  5.31it/s]

Epoch loss : 1.2019219628497528
Epoch loss : 1.2010814495369642


 70%|██████▉   | 695/1000 [01:56<01:00,  5.07it/s]

Epoch loss : 1.2002426766065253
Epoch loss : 1.1994058127047347


 70%|██████▉   | 697/1000 [01:56<00:59,  5.11it/s]

Epoch loss : 1.1985709045857362
Epoch loss : 1.1977381043416184


 70%|██████▉   | 699/1000 [01:57<00:57,  5.19it/s]

Epoch loss : 1.1969073437278113
Epoch loss : 1.1960790136119668


 70%|███████   | 701/1000 [01:57<00:57,  5.17it/s]

Epoch loss : 1.1952521366583448
Epoch loss : 1.1944272910217926


 70%|███████   | 703/1000 [01:58<00:58,  5.06it/s]

Epoch loss : 1.1936042178870097
Epoch loss : 1.1927830449386816


 70%|███████   | 705/1000 [01:58<00:56,  5.20it/s]

Epoch loss : 1.1919638376488266
Epoch loss : 1.1911468414634678


 71%|███████   | 707/1000 [01:58<00:51,  5.69it/s]

Epoch loss : 1.190331511322736
Epoch loss : 1.1895182520960481


 71%|███████   | 709/1000 [01:59<00:49,  5.94it/s]

Epoch loss : 1.1887067748496762
Epoch loss : 1.1878972858030805


 71%|███████   | 711/1000 [01:59<00:47,  6.12it/s]

Epoch loss : 1.1870894432214785
Epoch loss : 1.1862834052726856


 71%|███████▏  | 713/1000 [01:59<00:48,  5.94it/s]

Epoch loss : 1.185479169440236
Epoch loss : 1.1846767724584397


 72%|███████▏  | 715/1000 [02:00<00:46,  6.08it/s]

Epoch loss : 1.1838763824873277
Epoch loss : 1.18307794066951


 72%|███████▏  | 717/1000 [02:00<00:45,  6.18it/s]

Epoch loss : 1.1822816043796653
Epoch loss : 1.1814870404586133


 72%|███████▏  | 719/1000 [02:00<00:44,  6.29it/s]

Epoch loss : 1.1806944060275815
Epoch loss : 1.1799037920554092


 72%|███████▏  | 721/1000 [02:01<00:43,  6.37it/s]

Epoch loss : 1.1791149853753546
Epoch loss : 1.1783278121477854


 72%|███████▏  | 723/1000 [02:01<00:43,  6.37it/s]

Epoch loss : 1.177542930164083
Epoch loss : 1.1767592964326528


 72%|███████▎  | 725/1000 [02:01<00:43,  6.31it/s]

Epoch loss : 1.1759775078547758
Epoch loss : 1.1751970486735475


 73%|███████▎  | 727/1000 [02:02<00:44,  6.18it/s]

Epoch loss : 1.1744183752671418
Epoch loss : 1.173641366665924


 73%|███████▎  | 729/1000 [02:02<00:43,  6.25it/s]

Epoch loss : 1.1728661477115456
Epoch loss : 1.1720928472157859


 73%|███████▎  | 731/1000 [02:02<00:42,  6.28it/s]

Epoch loss : 1.1713213155212467
Epoch loss : 1.1705516279232975


 73%|███████▎  | 733/1000 [02:02<00:42,  6.33it/s]

Epoch loss : 1.1697837157482331
Epoch loss : 1.1690179258503661


 74%|███████▎  | 735/1000 [02:03<00:42,  6.21it/s]

Epoch loss : 1.1682538528906876
Epoch loss : 1.167491437911582


 74%|███████▎  | 737/1000 [02:03<00:42,  6.22it/s]

Epoch loss : 1.1667305429474166
Epoch loss : 1.1659714324802686


 74%|███████▍  | 739/1000 [02:03<00:42,  6.21it/s]

Epoch loss : 1.1652137964576241
Epoch loss : 1.1644578005242412


 74%|███████▍  | 741/1000 [02:04<00:41,  6.19it/s]

Epoch loss : 1.1637035144611791
Epoch loss : 1.1629510279693585


 74%|███████▍  | 743/1000 [02:04<00:41,  6.13it/s]

Epoch loss : 1.1622002750848022
Epoch loss : 1.1614513964620123


 74%|███████▍  | 745/1000 [02:04<00:41,  6.13it/s]

Epoch loss : 1.1607041725456233
Epoch loss : 1.1599591370424969


 75%|███████▍  | 747/1000 [02:05<00:40,  6.23it/s]

Epoch loss : 1.159215601003921
Epoch loss : 1.1584742451843169


 75%|███████▍  | 749/1000 [02:05<00:39,  6.29it/s]

Epoch loss : 1.1577340257530384
Epoch loss : 1.1569954522914976


 75%|███████▌  | 751/1000 [02:05<00:39,  6.24it/s]

Epoch loss : 1.1562579548609258
Epoch loss : 1.1555221025869944


 75%|███████▌  | 753/1000 [02:06<00:39,  6.30it/s]

Epoch loss : 1.1547876728397417
Epoch loss : 1.1540550702169121


 76%|███████▌  | 755/1000 [02:06<00:39,  6.27it/s]

Epoch loss : 1.1533241151553884
Epoch loss : 1.1525947030534018


 76%|███████▌  | 757/1000 [02:06<00:39,  6.14it/s]

Epoch loss : 1.151866875315706
Epoch loss : 1.151140805422236


 76%|███████▌  | 759/1000 [02:07<00:38,  6.27it/s]

Epoch loss : 1.1504163178872306
Epoch loss : 1.1496933813124148


 76%|███████▌  | 761/1000 [02:07<00:37,  6.36it/s]

Epoch loss : 1.1489721157895891
Epoch loss : 1.1482528222014494


 76%|███████▋  | 763/1000 [02:07<00:37,  6.26it/s]

Epoch loss : 1.1475351972823225
Epoch loss : 1.1468191239413552


 76%|███████▋  | 765/1000 [02:08<00:37,  6.24it/s]

Epoch loss : 1.1461043778760114
Epoch loss : 1.1453911756767947


 77%|███████▋  | 767/1000 [02:08<00:40,  5.77it/s]

Epoch loss : 1.1446793821753627
Epoch loss : 1.1439689852291477


 77%|███████▋  | 769/1000 [02:08<00:42,  5.39it/s]

Epoch loss : 1.1432602396425016
Epoch loss : 1.1425533432900052


 77%|███████▋  | 771/1000 [02:09<00:42,  5.40it/s]

Epoch loss : 1.1418476470132153
Epoch loss : 1.1411436149342669


 77%|███████▋  | 773/1000 [02:09<00:42,  5.28it/s]

Epoch loss : 1.1404409007450151
Epoch loss : 1.1397397038151653


 78%|███████▊  | 775/1000 [02:10<00:43,  5.20it/s]

Epoch loss : 1.1390400660619409
Epoch loss : 1.1383421496402832


 78%|███████▊  | 776/1000 [02:10<00:43,  5.14it/s]

Epoch loss : 1.1376458736343944


 78%|███████▊  | 778/1000 [02:10<00:43,  5.12it/s]

Epoch loss : 1.1369512134713708
Epoch loss : 1.1362582384187796


 78%|███████▊  | 780/1000 [02:11<00:43,  5.02it/s]

Epoch loss : 1.1355663021258886
Epoch loss : 1.134875827838595


 78%|███████▊  | 782/1000 [02:11<00:42,  5.07it/s]

Epoch loss : 1.1341866311537776
Epoch loss : 1.13349907350624


 78%|███████▊  | 784/1000 [02:11<00:43,  5.02it/s]

Epoch loss : 1.1328129019313236
Epoch loss : 1.132128289692028


 79%|███████▊  | 786/1000 [02:12<00:38,  5.50it/s]

Epoch loss : 1.1314449802887667
Epoch loss : 1.1307631614451645


 79%|███████▉  | 788/1000 [02:12<00:36,  5.82it/s]

Epoch loss : 1.1300828217733012
Epoch loss : 1.129404090560481


 79%|███████▉  | 790/1000 [02:12<00:34,  6.06it/s]

Epoch loss : 1.1287267670632164
Epoch loss : 1.1280513336149198


 79%|███████▉  | 792/1000 [02:13<00:34,  6.00it/s]

Epoch loss : 1.1273773568812122
Epoch loss : 1.1267046727702925


 79%|███████▉  | 794/1000 [02:13<00:33,  6.08it/s]

Epoch loss : 1.1260332830500182
Epoch loss : 1.1253633103971967


 80%|███████▉  | 796/1000 [02:13<00:33,  6.18it/s]

Epoch loss : 1.1246943214903087
Epoch loss : 1.124026662515261


 80%|███████▉  | 798/1000 [02:14<00:32,  6.18it/s]

Epoch loss : 1.1233607963980217
Epoch loss : 1.1226961602352765


 80%|████████  | 800/1000 [02:14<00:32,  6.17it/s]

Epoch loss : 1.1220333428458964
Epoch loss : 1.1213713533680887


 80%|████████  | 802/1000 [02:14<00:32,  6.18it/s]

Epoch loss : 1.1207109102137973
Epoch loss : 1.1200518939582784


 80%|████████  | 804/1000 [02:15<00:32,  6.05it/s]

Epoch loss : 1.1193942782872937
Epoch loss : 1.1187384859238987


 81%|████████  | 806/1000 [02:15<00:31,  6.08it/s]

Epoch loss : 1.118083948816572
Epoch loss : 1.1174306892143289


 81%|████████  | 808/1000 [02:15<00:31,  6.14it/s]

Epoch loss : 1.1167787054757203
Epoch loss : 1.1161278772590184


 81%|████████  | 810/1000 [02:16<00:31,  6.09it/s]

Epoch loss : 1.1154783622100857
Epoch loss : 1.1148301206974336


 81%|████████  | 812/1000 [02:16<00:30,  6.19it/s]

Epoch loss : 1.114183309101735
Epoch loss : 1.1135379213013965


 81%|████████▏ | 814/1000 [02:16<00:29,  6.28it/s]

Epoch loss : 1.112893861848563
Epoch loss : 1.1122510592925021


 82%|████████▏ | 816/1000 [02:17<00:29,  6.18it/s]

Epoch loss : 1.1116096252616199
Epoch loss : 1.1109698399320682


 82%|████████▏ | 818/1000 [02:17<00:30,  5.92it/s]

Epoch loss : 1.1103314959813215
Epoch loss : 1.1096946063258828


 82%|████████▏ | 820/1000 [02:17<00:32,  5.51it/s]

Epoch loss : 1.109058937840846
Epoch loss : 1.1084244364641789


 82%|████████▏ | 821/1000 [02:18<00:33,  5.34it/s]

Epoch loss : 1.1077913854488066


 82%|████████▏ | 823/1000 [02:18<00:34,  5.16it/s]

Epoch loss : 1.1071592353074076
Epoch loss : 1.106528640480334


 82%|████████▎ | 825/1000 [02:18<00:34,  5.05it/s]

Epoch loss : 1.1058989083582482
Epoch loss : 1.1052708202445145


 83%|████████▎ | 827/1000 [02:19<00:33,  5.14it/s]

Epoch loss : 1.1046438718228162
Epoch loss : 1.1040181907333948


 83%|████████▎ | 829/1000 [02:19<00:33,  5.08it/s]

Epoch loss : 1.1033935222529991
Epoch loss : 1.1027700887061431


 83%|████████▎ | 831/1000 [02:20<00:33,  5.09it/s]

Epoch loss : 1.102147889965988
Epoch loss : 1.1015269318111824


 83%|████████▎ | 833/1000 [02:20<00:32,  5.11it/s]

Epoch loss : 1.1009074700448231
Epoch loss : 1.100289424361945


 84%|████████▎ | 835/1000 [02:20<00:32,  5.13it/s]

Epoch loss : 1.0996726754736557
Epoch loss : 1.0990571802872384


 84%|████████▎ | 837/1000 [02:21<00:29,  5.58it/s]

Epoch loss : 1.098442854385413
Epoch loss : 1.0978298017357697


 84%|████████▍ | 839/1000 [02:21<00:27,  5.80it/s]

Epoch loss : 1.0972176956596977
Epoch loss : 1.096607038260785


 84%|████████▍ | 841/1000 [02:21<00:26,  5.93it/s]

Epoch loss : 1.0959973873627327
Epoch loss : 1.095389273263399


 84%|████████▍ | 843/1000 [02:22<00:28,  5.42it/s]

Epoch loss : 1.094782323133988
Epoch loss : 1.0941768548405184


 84%|████████▍ | 845/1000 [02:22<00:29,  5.30it/s]

Epoch loss : 1.0935723564378317
Epoch loss : 1.0929690746633258


 85%|████████▍ | 846/1000 [02:22<00:29,  5.20it/s]

Epoch loss : 1.0923671885169426


 85%|████████▍ | 847/1000 [02:22<00:30,  5.06it/s]

Epoch loss : 1.0917665910896754
Epoch loss : 1.0911673064681016


 85%|████████▍ | 849/1000 [02:23<00:29,  5.06it/s]

Epoch loss : 1.0905691693461685


 85%|████████▌ | 850/1000 [02:23<00:29,  5.02it/s]

Epoch loss : 1.0899720505191999


 85%|████████▌ | 852/1000 [02:23<00:29,  5.05it/s]

Epoch loss : 1.0893763612576712
Epoch loss : 1.0887811453724412


 85%|████████▌ | 853/1000 [02:24<00:28,  5.09it/s]

Epoch loss : 1.0881872725347002


 86%|████████▌ | 855/1000 [02:24<00:28,  5.09it/s]

Epoch loss : 1.0875944097130723
Epoch loss : 1.087002704084268


 86%|████████▌ | 857/1000 [02:24<00:27,  5.15it/s]

Epoch loss : 1.0864122368806155
Epoch loss : 1.0858231328019183


 86%|████████▌ | 858/1000 [02:25<00:27,  5.13it/s]

Epoch loss : 1.0852352483212808


 86%|████████▌ | 860/1000 [02:25<00:27,  5.12it/s]

Epoch loss : 1.0846483811569574
Epoch loss : 1.0840627824858178


 86%|████████▌ | 862/1000 [02:25<00:24,  5.59it/s]

Epoch loss : 1.0834785047699627
Epoch loss : 1.0828955403533862


 86%|████████▋ | 864/1000 [02:26<00:23,  5.85it/s]

Epoch loss : 1.0823136073381987
Epoch loss : 1.0817326427340783


 87%|████████▋ | 866/1000 [02:26<00:22,  6.02it/s]

Epoch loss : 1.0811528436703488
Epoch loss : 1.0805740450642676


 87%|████████▋ | 868/1000 [02:26<00:21,  6.04it/s]

Epoch loss : 1.0799967203987924
Epoch loss : 1.0794205231422103


 87%|████████▋ | 870/1000 [02:27<00:21,  6.06it/s]

Epoch loss : 1.0788456997206763
Epoch loss : 1.0782715697562557


 87%|████████▋ | 872/1000 [02:27<00:20,  6.12it/s]

Epoch loss : 1.077698738271826
Epoch loss : 1.0771266573072846


 87%|████████▋ | 874/1000 [02:27<00:20,  6.18it/s]

Epoch loss : 1.0765558107393578
Epoch loss : 1.0759857550212257


 88%|████████▊ | 876/1000 [02:28<00:20,  6.15it/s]

Epoch loss : 1.075417225350993
Epoch loss : 1.074849732132739


 88%|████████▊ | 878/1000 [02:28<00:19,  6.23it/s]

Epoch loss : 1.0742836008706005
Epoch loss : 1.0737182734723378


 88%|████████▊ | 880/1000 [02:28<00:18,  6.38it/s]

Epoch loss : 1.073154127889051
Epoch loss : 1.0725908114398746


 88%|████████▊ | 882/1000 [02:29<00:18,  6.31it/s]

Epoch loss : 1.0720286076942989
Epoch loss : 1.0714673419742762


 88%|████████▊ | 884/1000 [02:29<00:18,  6.19it/s]

Epoch loss : 1.070907252782941
Epoch loss : 1.0703481649954665


 89%|████████▊ | 886/1000 [02:29<00:18,  6.18it/s]

Epoch loss : 1.0697903355009812
Epoch loss : 1.0692335614844586


 89%|████████▉ | 888/1000 [02:30<00:18,  6.07it/s]

Epoch loss : 1.0686778992743744
Epoch loss : 1.068123281802143


 89%|████████▉ | 890/1000 [02:30<00:17,  6.13it/s]

Epoch loss : 1.0675699907977079
Epoch loss : 1.0670176858741245


 89%|████████▉ | 892/1000 [02:30<00:17,  6.22it/s]

Epoch loss : 1.0664665371918784
Epoch loss : 1.065916859680628


 89%|████████▉ | 894/1000 [02:31<00:16,  6.27it/s]

Epoch loss : 1.065368587395045
Epoch loss : 1.0648208367424523


 90%|████████▉ | 896/1000 [02:31<00:16,  6.23it/s]

Epoch loss : 1.0642742273304715
Epoch loss : 1.063728192618915


 90%|████████▉ | 898/1000 [02:31<00:16,  6.23it/s]

Epoch loss : 1.0631832817419449
Epoch loss : 1.0626391664329908


 90%|█████████ | 900/1000 [02:32<00:16,  6.08it/s]

Epoch loss : 1.0620961017298354
Epoch loss : 1.0615537785841358


 90%|█████████ | 902/1000 [02:32<00:16,  6.10it/s]

Epoch loss : 1.0610125818206917
Epoch loss : 1.0604722390299097


 90%|█████████ | 904/1000 [02:32<00:15,  6.09it/s]

Epoch loss : 1.0599329833691302
Epoch loss : 1.0593947999670574


 91%|█████████ | 906/1000 [02:33<00:15,  6.06it/s]

Epoch loss : 1.058857577358491
Epoch loss : 1.0583217590376242


 91%|█████████ | 908/1000 [02:33<00:15,  6.11it/s]

Epoch loss : 1.057786989335254
Epoch loss : 1.057253394570973


 91%|█████████ | 910/1000 [02:33<00:14,  6.17it/s]

Epoch loss : 1.0567207366729727
Epoch loss : 1.0561891292169854


 91%|█████████ | 912/1000 [02:34<00:14,  6.13it/s]

Epoch loss : 1.0556584628798173
Epoch loss : 1.0551289099376453


 91%|█████████▏| 914/1000 [02:34<00:13,  6.20it/s]

Epoch loss : 1.054600038059374
Epoch loss : 1.0540720541370907


 92%|█████████▏| 916/1000 [02:34<00:13,  6.16it/s]

Epoch loss : 1.0535452706761699
Epoch loss : 1.0530191705899727


 92%|█████████▏| 918/1000 [02:34<00:13,  6.06it/s]

Epoch loss : 1.052494159601507
Epoch loss : 1.0519700857217795


 92%|█████████▏| 920/1000 [02:35<00:13,  5.94it/s]

Epoch loss : 1.0514470675620082
Epoch loss : 1.050925148313784


 92%|█████████▏| 922/1000 [02:35<00:14,  5.54it/s]

Epoch loss : 1.0504044067930838
Epoch loss : 1.0498853754173003


 92%|█████████▏| 924/1000 [02:36<00:14,  5.26it/s]

Epoch loss : 1.049366483587937
Epoch loss : 1.0488486842717573


 92%|█████████▎| 925/1000 [02:36<00:14,  5.19it/s]

Epoch loss : 1.0483315347162452


 93%|█████████▎| 927/1000 [02:36<00:14,  5.09it/s]

Epoch loss : 1.0478150913782305
Epoch loss : 1.0472997052615514


 93%|█████████▎| 929/1000 [02:37<00:14,  5.00it/s]

Epoch loss : 1.0467852207999033
Epoch loss : 1.046271867662251


 93%|█████████▎| 931/1000 [02:37<00:13,  5.00it/s]

Epoch loss : 1.045759335481031
Epoch loss : 1.045247653312381


 93%|█████████▎| 932/1000 [02:37<00:13,  4.94it/s]

Epoch loss : 1.0447367008996111


 93%|█████████▎| 934/1000 [02:38<00:13,  4.98it/s]

Epoch loss : 1.0442268309626421
Epoch loss : 1.0437176969411286


 94%|█████████▎| 936/1000 [02:38<00:12,  5.07it/s]

Epoch loss : 1.0432097862015433
Epoch loss : 1.0427027173114256


 94%|█████████▍| 938/1000 [02:38<00:12,  4.98it/s]

Epoch loss : 1.0421969744025579
Epoch loss : 1.041692281686294


 94%|█████████▍| 940/1000 [02:39<00:11,  5.28it/s]

Epoch loss : 1.0411884675848597
Epoch loss : 1.0406852460129463


 94%|█████████▍| 942/1000 [02:39<00:10,  5.55it/s]

Epoch loss : 1.040183066676641
Epoch loss : 1.0396817330550996


 94%|█████████▍| 944/1000 [02:39<00:09,  5.71it/s]

Epoch loss : 1.039181479138825
Epoch loss : 1.0386820518272786


 95%|█████████▍| 946/1000 [02:40<00:09,  5.89it/s]

Epoch loss : 1.0381840735163008
Epoch loss : 1.037686646759636


 95%|█████████▍| 948/1000 [02:40<00:08,  5.88it/s]

Epoch loss : 1.0371902473716197
Epoch loss : 1.0366941989147211


 95%|█████████▌| 950/1000 [02:40<00:08,  5.95it/s]

Epoch loss : 1.036199345515513
Epoch loss : 1.0357050565085912


 95%|█████████▌| 952/1000 [02:41<00:07,  6.06it/s]

Epoch loss : 1.0352118253118734
Epoch loss : 1.034719603547529


 95%|█████████▌| 954/1000 [02:41<00:07,  6.06it/s]

Epoch loss : 1.0342281286363713
Epoch loss : 1.0337375305571146


 96%|█████████▌| 956/1000 [02:41<00:07,  6.08it/s]

Epoch loss : 1.0332478105141853
Epoch loss : 1.032759013887978


 96%|█████████▌| 958/1000 [02:42<00:06,  6.01it/s]

Epoch loss : 1.0322709032435402
Epoch loss : 1.031783791963081


 96%|█████████▌| 960/1000 [02:42<00:06,  6.00it/s]

Epoch loss : 1.0312974600716105
Epoch loss : 1.0308120878621314


 96%|█████████▌| 962/1000 [02:42<00:06,  5.96it/s]

Epoch loss : 1.0303275908708076
Epoch loss : 1.0298441202845618


 96%|█████████▋| 964/1000 [02:43<00:05,  6.04it/s]

Epoch loss : 1.0293613670987627
Epoch loss : 1.0288794742196923


 97%|█████████▋| 966/1000 [02:43<00:05,  6.07it/s]

Epoch loss : 1.028398542005781
Epoch loss : 1.0279184957891394


 97%|█████████▋| 968/1000 [02:43<00:05,  6.05it/s]

Epoch loss : 1.027439189539084
Epoch loss : 1.0269611864004375


 97%|█████████▋| 970/1000 [02:44<00:04,  6.11it/s]

Epoch loss : 1.0264837699399645
Epoch loss : 1.026007295995029


 97%|█████████▋| 972/1000 [02:44<00:04,  6.08it/s]

Epoch loss : 1.0255317002697537
Epoch loss : 1.0250570793830074


 97%|█████████▋| 974/1000 [02:44<00:04,  5.98it/s]

Epoch loss : 1.0245830091785062
Epoch loss : 1.024110008387901


 98%|█████████▊| 976/1000 [02:45<00:03,  6.09it/s]

Epoch loss : 1.0236374879075931
Epoch loss : 1.023165830456209


 98%|█████████▊| 978/1000 [02:45<00:03,  6.07it/s]

Epoch loss : 1.0226947470703291
Epoch loss : 1.0222245094756042


 98%|█████████▊| 980/1000 [02:45<00:03,  6.01it/s]

Epoch loss : 1.0217552627441955
Epoch loss : 1.0212867671482417


 98%|█████████▊| 982/1000 [02:46<00:02,  6.10it/s]

Epoch loss : 1.0208191491993075
Epoch loss : 1.0203524706495999


 98%|█████████▊| 984/1000 [02:46<00:02,  6.11it/s]

Epoch loss : 1.0198864782077857
Epoch loss : 1.0194214693814274


 99%|█████████▊| 986/1000 [02:46<00:02,  6.00it/s]

Epoch loss : 1.0189570206742602
Epoch loss : 1.0184934698187313


 99%|█████████▉| 988/1000 [02:47<00:01,  6.11it/s]

Epoch loss : 1.0180305058974868
Epoch loss : 1.017568467175852


 99%|█████████▉| 990/1000 [02:47<00:01,  6.21it/s]

Epoch loss : 1.0171073068227878
Epoch loss : 1.0166469814708137


 99%|█████████▉| 992/1000 [02:47<00:01,  6.14it/s]

Epoch loss : 1.0161875459858556
Epoch loss : 1.0157290610839282


 99%|█████████▉| 994/1000 [02:48<00:00,  6.25it/s]

Epoch loss : 1.0152713979401018
Epoch loss : 1.0148149546072396


100%|█████████▉| 996/1000 [02:48<00:00,  6.22it/s]

Epoch loss : 1.0143585111924152
Epoch loss : 1.0139028882489625


100%|█████████▉| 998/1000 [02:48<00:00,  6.13it/s]

Epoch loss : 1.013448056762173
Epoch loss : 1.012994520019792


100%|█████████▉| 999/1000 [02:49<00:00,  6.16it/s]

Epoch loss : 1.0125415116056307
Epoch loss : 1.0120892664995789


100%|██████████| 1000/1000 [02:49<00:00,  5.91it/s]


In [6]:
print(torch.cuda.is_available())

True


In [8]:
import numpy as np
import pandas as pd
import torch
import sys

from sklearn.manifold import TSNE
import plotly.graph_objects as go


In [9]:
# embedding from first model layer
embeddings = list(model.parameters())[0]
embeddings = embeddings.cpu().detach().numpy()

# normalization
norms = (embeddings ** 2).sum(axis=1) ** (1 / 2)
norms = np.reshape(norms, (len(norms), 1))
embeddings_norm = embeddings / norms
embeddings_norm.shape

(1123, 150)

In [11]:


# get embeddings
embeddings_df = pd.DataFrame(embeddings)

# t-SNE transform
tsne = TSNE(n_components=2)
embeddings_df_trans = tsne.fit_transform(embeddings_df)
embeddings_df_trans = pd.DataFrame(embeddings_df_trans)

# get token order
embeddings_df_trans.index = data_set._indexed_vocabulary.keys()

# if token is a number
is_numeric = embeddings_df_trans.index.str.isnumeric()



In [12]:

color = np.where(is_numeric, "green", "black")
fig = go.Figure()

fig.add_trace(
    go.Scatter(
        x=embeddings_df_trans[0],
        y=embeddings_df_trans[1],
        mode="text",
        text=embeddings_df_trans.index,
        textposition="middle center",
        textfont=dict(color=color),
    )
)
fig.show()

