In [160]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
import torchtext
# Silence the torchtext deprecation notice. This has to run BEFORE any torchtext
# submodule (e.g. torchtext.vocab below) is imported, because the warning is
# emitted when the submodule is first imported.
torchtext.disable_torchtext_deprecation_warning()
import pandas as pd
import os
from string import punctuation
from nltk import word_tokenize
from torchtext.vocab import build_vocab_from_iterator
import math
from sklearn.model_selection import train_test_split
from tqdm import tqdm
# Project-local modules: model definitions and corpus-reading helpers.
import models
import data_reader

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [4]:
# Load the raw Vietnamese and English sides of the parallel corpus through the
# project-local data_reader module.
data_vie = data_reader.read_text("./eng_vie/vi_sents")
data_eng = data_reader.read_text("./eng_vie/en_sents")

# Break each raw text into individual sentences.
# NOTE(review): presumably one sentence per line — confirm against data_reader.to_lines.
sents_vie = data_reader.to_lines(data_vie)
sents_eng = data_reader.to_lines(data_eng)

In [5]:
# Pair the sentences positionally (row i = i-th English sentence and i-th Vietnamese sentence).
# pandas raises if the two lists differ in length.
dataset = pd.DataFrame({'sents_eng': sents_eng, 'sents_vie': sents_vie})

In [6]:
# Preview the first five raw sentence pairs.
dataset.head()

Unnamed: 0,sents_eng,sents_vie
0,Please put the dustpan in the broom closet,xin vui lòng đặt người quét rác trong tủ chổi
1,Be quiet for a moment.,im lặng một lát
2,Read this,đọc này
3,Tom persuaded the store manager to give him ba...,tom thuyết phục người quản lý cửa hàng trả lại...
4,Friendship consists of mutual understanding,tình bạn bao gồm sự hiểu biết lẫn nhau


In [7]:
def clean_text(text, lowercase=False, remove_punc=False, remove_num=False, sos_token='<sos>', eos_token='<eos>'):
    """Normalise a sentence and tokenise it, wrapped in start/end markers.

    Args:
        text: Sentence to process.
        lowercase: Lower-case the text before anything else.
        remove_punc: Drop every character found in ``string.punctuation``
            (ASCII punctuation only; apostrophes are removed too).
        remove_num: Drop the ASCII digits 0-9.
        sos_token: Marker placed in front of the tokens.
        eos_token: Marker placed after the tokens.

    Returns:
        ``[sos_token, <tokens from nltk.word_tokenize>, eos_token]``.
    """
    normalised = text.lower() if lowercase else text
    if remove_punc:
        normalised = ''.join(filter(lambda ch: ch not in punctuation, normalised))
    if remove_num:
        normalised = ''.join(filter(lambda ch: ch not in '1234567890', normalised))
    tokens = word_tokenize(normalised)
    return [sos_token, *tokens, eos_token]

In [8]:
# Lower-case, strip ASCII punctuation and tokenise both languages (digits are kept).
# Series.apply forwards the keyword arguments to clean_text for every sentence.
dataset['clean_eng'] = dataset['sents_eng'].apply(clean_text, lowercase=True, remove_punc=True, remove_num=False)
dataset['clean_vie'] = dataset['sents_vie'].apply(clean_text, lowercase=True, remove_punc=True, remove_num=False)

In [9]:
# Check the token lists produced by clean_text (each starts with <sos> and ends with <eos>).
dataset.head()

Unnamed: 0,sents_eng,sents_vie,clean_eng,clean_vie
0,Please put the dustpan in the broom closet,xin vui lòng đặt người quét rác trong tủ chổi,"[<sos>, please, put, the, dustpan, in, the, br...","[<sos>, xin, vui, lòng, đặt, người, quét, rác,..."
1,Be quiet for a moment.,im lặng một lát,"[<sos>, be, quiet, for, a, moment, <eos>]","[<sos>, im, lặng, một, lát, <eos>]"
2,Read this,đọc này,"[<sos>, read, this, <eos>]","[<sos>, đọc, này, <eos>]"
3,Tom persuaded the store manager to give him ba...,tom thuyết phục người quản lý cửa hàng trả lại...,"[<sos>, tom, persuaded, the, store, manager, t...","[<sos>, tom, thuyết, phục, người, quản, lý, cử..."
4,Friendship consists of mutual understanding,tình bạn bao gồm sự hiểu biết lẫn nhau,"[<sos>, friendship, consists, of, mutual, unde...","[<sos>, tình, bạn, bao, gồm, sự, hiểu, biết, l..."


In [10]:
# Special vocabulary symbols: unknown word, padding, start-of-sentence, end-of-sentence.
# sos/eos match the default markers that clean_text wraps around every sentence.
unk_token = '<unk>'
pad_token = '<pad>'
sos_token = '<sos>'
eos_token = '<eos>'

In [11]:
# Order matters: with this list both vocabularies get <unk>=0, <pad>=1, <sos>=2, <eos>=3.
specials = [unk_token, pad_token, sos_token, eos_token]

In [12]:
# One vocabulary per language, built from the tokenised sentences.
# NOTE(review): the vocabularies are built from the full dataset, before the train/test
# split, so every held-out token is in-vocabulary and <unk> never occurs in the encoded data.
eng_vocab = build_vocab_from_iterator(dataset['clean_eng'], specials = specials)
vie_vocab = build_vocab_from_iterator(dataset['clean_vie'], specials = specials)

In [13]:
# Inspect the first ten entries of each vocabulary (the four specials come first).
vie_vocab.get_itos()[:10], eng_vocab.get_itos()[:10]

(['<unk>',
  '<pad>',
  '<sos>',
  '<eos>',
  'tôi',
  'bạn',
  'không',
  'tom',
  'có',
  'một'],
 ['<unk>', '<pad>', '<sos>', '<eos>', 'the', 'to', 'i', 'tom', 'you', 'a'])

In [14]:
def text_to_number(text, vocab):
    """Convert a token sequence into vocabulary indices.

    Args:
        text: Sequence of token strings.
        vocab: Object exposing ``lookup_indices(tokens)``.

    Returns:
        Whatever ``vocab.lookup_indices`` returns for ``text``.
    """
    indices = vocab.lookup_indices(text)
    return indices

In [15]:
# Encode every token list as indices in the matching language's vocabulary.
dataset['eng_nums'] = dataset['clean_eng'].apply(text_to_number, args=(eng_vocab,))
dataset['vie_nums'] = dataset['clean_vie'].apply(text_to_number, args=(vie_vocab,))

In [16]:
# Check the index-encoded columns (every sequence starts with 2 = <sos> and ends with 3 = <eos>).
dataset.head()

Unnamed: 0,sents_eng,sents_vie,clean_eng,clean_vie,eng_nums,vie_nums
0,Please put the dustpan in the broom closet,xin vui lòng đặt người quét rác trong tủ chổi,"[<sos>, please, put, the, dustpan, in, the, br...","[<sos>, xin, vui, lòng, đặt, người, quét, rác,...","[2, 104, 145, 4, 15795, 12, 4, 5736, 2020, 3]","[2, 161, 120, 173, 239, 33, 1742, 1282, 31, 92..."
1,Be quiet for a moment.,im lặng một lát,"[<sos>, be, quiet, for, a, moment, <eos>]","[<sos>, im, lặng, một, lát, <eos>]","[2, 25, 729, 20, 9, 495, 3]","[2, 668, 606, 9, 1121, 3]"
2,Read this,đọc này,"[<sos>, read, this, <eos>]","[<sos>, đọc, này, <eos>]","[2, 244, 19, 3]","[2, 235, 30, 3]"
3,Tom persuaded the store manager to give him ba...,tom thuyết phục người quản lý cửa hàng trả lại...,"[<sos>, tom, persuaded, the, store, manager, t...","[<sos>, tom, thuyết, phục, người, quản, lý, cử...","[2, 7, 1982, 4, 576, 1515, 5, 122, 43, 109, 24...","[2, 7, 451, 426, 33, 777, 212, 158, 148, 178, ..."
4,Friendship consists of mutual understanding,tình bạn bao gồm sự hiểu biết lẫn nhau,"[<sos>, friendship, consists, of, mutual, unde...","[<sos>, tình, bạn, bao, gồm, sự, hiểu, biết, l...","[2, 2027, 3004, 13, 5537, 2324, 3]","[2, 312, 5, 71, 1257, 57, 211, 37, 1187, 156, 3]"


In [17]:
def train_test(dataset, test_size = 0.2):
    """Split ``dataset`` into train and test frames with a fixed seed.

    Args:
        dataset: DataFrame to split.
        test_size: Passed to ``train_test_split`` as the held-out share.

    Returns:
        ``(train, test)``, each re-indexed from 0. The split uses
        ``random_state=42``, so it is reproducible.
    """
    parts = train_test_split(dataset, test_size=test_size, random_state=42)
    train, test = (part.reset_index(drop=True) for part in parts)
    return train, test

In [18]:
# 80/20 split using the default test_size=0.2.
train, test = train_test(dataset)

In [19]:
# Preview the shuffled, re-indexed training split.
train.head()

Unnamed: 0,sents_eng,sents_vie,clean_eng,clean_vie,eng_nums,vie_nums
0,He was never to see his homeland again,anh không bao giờ được gặp lại quê hương,"[<sos>, he, was, never, to, see, his, homeland...","[<sos>, anh, không, bao, giờ, được, gặp, lại, ...","[2, 11, 15, 97, 5, 83, 24, 8738, 186, 3]","[2, 11, 6, 71, 49, 41, 103, 62, 1424, 1461, 3]"
1,She had a basket full of apples,cô ấy có một giỏ đầy táo,"[<sos>, she, had, a, basket, full, of, apples,...","[<sos>, cô, ấy, có, một, giỏ, đầy, táo, <eos>]","[2, 35, 48, 9, 2274, 644, 13, 1085, 3]","[2, 26, 14, 8, 9, 1736, 594, 701, 3]"
2,There was a larger crowd at the concert than w...,có một đám đông lớn hơn tại buổi hòa nhạc hơn ...,"[<sos>, there, was, a, larger, crowd, at, the,...","[<sos>, có, một, đám, đông, lớn, hơn, tại, buổ...","[2, 56, 15, 9, 1919, 1093, 37, 4, 793, 103, 23...","[2, 8, 9, 649, 551, 160, 66, 80, 363, 501, 355..."
3,"Windy this morning, isn't it?",sáng nay có gió không?,"[<sos>, windy, this, morning, isnt, it, <eos>]","[<sos>, sáng, nay, có, gió, không, <eos>]","[2, 3844, 19, 237, 150, 16, 3]","[2, 210, 112, 8, 1032, 6, 3]"
4,I worked hard for this money.,tôi đã làm việc chăm chỉ vì số tiền này,"[<sos>, i, worked, hard, for, this, money, <eos>]","[<sos>, tôi, đã, làm, việc, chăm, chỉ, vì, số,...","[2, 6, 619, 219, 20, 19, 116, 3]","[2, 4, 10, 18, 59, 324, 75, 104, 118, 115, 30, 3]"


In [20]:
# The special-token ids are read from the English vocabulary only. Both vocabularies
# were built with the same `specials` list, so the ids are identical in vie_vocab
# (the get_itos() output above shows <unk>=0, <pad>=1 for both).
pad_index = eng_vocab[pad_token]
unk_index = eng_vocab[unk_token]

In [21]:
# From here on, out-of-vocabulary tokens map to <unk> instead of raising.
# Note: this runs after the dataset was already encoded in the cells above.
vie_vocab.set_default_index(unk_index)
eng_vocab.set_default_index(unk_index)

In [22]:
class CustomDataset(Dataset):
    """Map-style dataset over a DataFrame of index-encoded sentence pairs.

    Each row of ``data`` must provide an ``eng_nums`` and a ``vie_nums`` entry,
    each a sequence of integer token ids.
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(eng_ids, vie_ids)`` for row ``idx`` as 1-D int64 tensors."""
        # Materialise the row once; every .iloc[idx] builds a new row Series.
        row = self.data.iloc[idx]
        # Explicit dtype keeps the ids integral even for an empty sequence
        # (torch.tensor([]) would otherwise be inferred as float32).
        eng_nums = torch.tensor(row['eng_nums'], dtype=torch.long)
        viet_nums = torch.tensor(row['vie_nums'], dtype=torch.long)
        return eng_nums, viet_nums


In [23]:
# Wrap each split so DataLoader can index it.
train_dataset = CustomDataset(train)
test_dataset = CustomDataset(test)

In [24]:
# Sanity check: the first (English ids, Vietnamese ids) pair of each split.
train_dataset[0], test_dataset[0]

((tensor([   2,   11,   15,   97,    5,   83,   24, 8738,  186,    3]),
  tensor([   2,   11,    6,   71,   49,   41,  103,   62, 1424, 1461,    3])),
 (tensor([  2,   6, 282,  22, 737,   4, 173, 123,   3]),
  tensor([  2,   4,  10, 186, 201, 300,  12,   4,  45,  64, 127,   3])))

In [25]:
# DataLoader settings: sentence pairs per batch and number of loader worker processes.
batch_size = 512
num_workers = 4

In [26]:
def collate_fn(batch):
    """Pad a list of ``(eng_ids, vie_ids)`` pairs into two batch tensors.

    Every sequence is right-padded with the module-level ``pad_index`` up to
    the longest sequence on its side. ``pad_sequence`` is used with its default
    layout, so each returned tensor has shape ``(max_len, batch_size)``.

    Returns:
        ``(eng_batch, vie_batch)``.
    """
    def _pad_column(position):
        # Gather one side of every pair, then pad that side to a common length.
        sequences = [pair[position] for pair in batch]
        return nn.utils.rnn.pad_sequence(sequences, padding_value=pad_index)

    return _pad_column(0), _pad_column(1)

In [27]:
# Only the training loader shuffles; the test loader keeps a fixed order.
# Both pad per batch through collate_fn.
train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers, collate_fn= collate_fn, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers,collate_fn=collate_fn, shuffle=False)

In [88]:
def generate_square_subsequent_mask(sz):
    """Build an additive causal mask of shape ``(sz, sz)`` on the module-level ``device``.

    Entry ``[i, j]`` is ``0.0`` when ``j <= i`` (position ``i`` may attend to
    ``j``) and ``-inf`` when ``j > i`` (a later position).
    """
    ones = torch.ones((sz, sz), device=device)
    # Strictly-upper triangle marks the "future" positions to block.
    future = torch.triu(ones, diagonal=1) == 1
    return torch.zeros((sz, sz), device=device).masked_fill(future, float('-inf'))


def create_mask(src, tgt):
    """Build the attention and padding masks for one source/target batch.

    Dimension 0 of ``src`` and ``tgt`` is taken as the sequence length. Masks
    are created on the module-level ``device``; padding positions are those
    equal to the module-level ``pad_index``.

    Returns:
        ``(src_mask, tgt_mask, src_padding_mask, tgt_padding_mask)`` where
        ``src_mask`` is an all-False boolean ``(src_len, src_len)`` tensor,
        ``tgt_mask`` is the additive causal mask for the target length, and the
        padding masks are boolean tensors with dimensions 0 and 1 swapped
        relative to their inputs.
    """
    src_len, tgt_len = src.shape[0], tgt.shape[0]

    # Source self-attention is unrestricted: nothing is masked out.
    src_mask = torch.zeros((src_len, src_len), dtype=torch.bool, device=device)
    # Target self-attention may only look at the current and earlier positions.
    tgt_mask = generate_square_subsequent_mask(tgt_len)

    src_is_pad = src == pad_index
    tgt_is_pad = tgt == pad_index
    return src_mask, tgt_mask, src_is_pad.transpose(0, 1), tgt_is_pad.transpose(0, 1)

In [161]:
# Fix the PyTorch RNG so parameter initialisation is repeatable.
torch.manual_seed(0)

# Model hyper-parameters. Vocabulary sizes come from the vocabularies built above.
SRC_VOCAB_SIZE = len(eng_vocab)
TGT_VOCAB_SIZE = len(vie_vocab)
EMB_SIZE = 512
NHEAD = 8
FFN_HID_DIM = 512
# NOTE(review): BATCH_SIZE is not referenced in the visible cells — the DataLoaders use
# batch_size = 512 defined earlier. Confirm which value is intended.
BATCH_SIZE = 128
NUM_ENCODER_LAYERS = 3
NUM_DECODER_LAYERS = 3

# Arguments are passed positionally; the order must match models.EngToVieTranslation's
# signature (defined in the project-local models module, not visible here).
transformer = models.EngToVieTranslation(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE,
                                 NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM)

transformer = transformer.to(device)




In [33]:
def train_fn(model, data_loader, criterion, optimizer, device):
    """Run one training pass over ``data_loader`` and return the mean batch loss.

    Args:
        model: Called as ``model(src, tgt_in, src_mask, tgt_mask,
            src_padding_mask, tgt_padding_mask)``; its output's last dimension
            is treated as the class dimension.
        data_loader: Yields ``(source, target)`` batches with the sequence on
            dimension 0.
        criterion: Loss taking ``(logits_2d, targets_1d)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to. The masks themselves are built
            by ``create_mask`` on the module-level ``device``.

    Returns:
        Sum of the per-batch losses divided by the number of batches.
    """
    model.train()
    running_loss = 0
    for src, tgt in data_loader:
        src, tgt = src.to(device), tgt.to(device)

        # Teacher forcing: the decoder is fed the target without its last position
        # and is scored against the target without its first position.
        decoder_input, target = tgt[:-1, :], tgt[1:, :]

        masks = create_mask(src, decoder_input)
        logits = model(src, decoder_input, *masks).to(device)

        optimizer.zero_grad()

        # Flatten every dimension except the class dimension so each position
        # is scored as an independent classification.
        loss = criterion(logits.view(-1, logits.size(-1)), target.view(-1))
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    return running_loss / len(data_loader)

In [34]:
def eval_fn(model, data_loader, criterion, device):
    """Score ``model`` on ``data_loader`` without updating it; return the mean batch loss.

    Args:
        model: Called as ``model(src, tgt_in, src_mask, tgt_mask,
            src_padding_mask, tgt_padding_mask)``; its output's last dimension
            is treated as the class dimension.
        data_loader: Yields ``(source, target)`` batches with the sequence on
            dimension 0.
        criterion: Loss taking ``(logits_2d, targets_1d)``.
        device: Device the batches are moved to. The masks themselves are built
            by ``create_mask`` on the module-level ``device``.

    Returns:
        Sum of the per-batch losses divided by the number of batches.
    """
    model.eval()
    running_loss = 0
    # No gradients are needed for evaluation.
    with torch.no_grad():
        for src, tgt in data_loader:
            src, tgt = src.to(device), tgt.to(device)

            # Same shifted input/target pairing as in training.
            decoder_input, target = tgt[:-1, :], tgt[1:, :]

            masks = create_mask(src, decoder_input)
            logits = model(src, decoder_input, *masks).to(device)

            loss = criterion(logits.view(-1, logits.size(-1)), target.view(-1))
            running_loss += loss.item()

    return running_loss / len(data_loader)

In [35]:
# Training configuration.
epochs = 100
# Padding positions are excluded from the loss.
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=pad_index)
optimizer = torch.optim.Adam(transformer.parameters(), lr=0.0001)

In [36]:
# Train for `epochs` epochs, checkpointing whenever the held-out loss improves and
# saving the final weights at the end.
# NOTE(review): in the recorded run below, test loss bottoms out at epoch 42 (~0.939) and
# climbs to ~1.065 by epoch 99 while train loss keeps falling — the model overfits and
# over half of the ~3.5 h run brings no improvement. Consider early stopping.
# NOTE(review): the test split is also used to pick the best checkpoint, so the reported
# test loss is not an unbiased estimate; a separate validation split would fix that.
min_loss = float('inf')
for epoch in tqdm(range(epochs)):
    train_loss = train_fn(transformer, train_loader, loss_fn, optimizer, device)
    test_loss = eval_fn(transformer, test_loader, loss_fn, device)
    
    # Keep the weights with the lowest held-out loss seen so far.
    if test_loss < min_loss:
        min_loss = test_loss
        torch.save(transformer.state_dict(), "best_translate_model.pth")
        
    print(f"Epoch: {epoch}, Train Loss: {train_loss}, Test Loss: {test_loss}")
# Weights after the final epoch (not necessarily the best ones).
torch.save(transformer.state_dict(), "last_translate_model.pth")

  1%|          | 1/100 [02:04<3:24:49, 124.13s/it]

Epoch: 0, Train Loss: 4.677931672364623, Test Loss: 3.493857219219208


  2%|▏         | 2/100 [04:08<3:23:03, 124.32s/it]

Epoch: 1, Train Loss: 3.228267881139439, Test Loss: 2.7307542824745177


  3%|▎         | 3/100 [06:12<3:20:41, 124.14s/it]

Epoch: 2, Train Loss: 2.673386968559955, Test Loss: 2.3250432920455935


  4%|▍         | 4/100 [08:16<3:18:13, 123.89s/it]

Epoch: 3, Train Loss: 2.3378920051919754, Test Loss: 2.0714025712013244


  5%|▌         | 5/100 [10:20<3:16:20, 124.01s/it]

Epoch: 4, Train Loss: 2.097836846083253, Test Loss: 1.8880160224437714


  6%|▌         | 6/100 [12:24<3:14:17, 124.02s/it]

Epoch: 5, Train Loss: 1.918555087180593, Test Loss: 1.7399869751930237


  7%|▋         | 7/100 [14:28<3:12:29, 124.18s/it]

Epoch: 6, Train Loss: 1.769673750628179, Test Loss: 1.620446035861969


  8%|▊         | 8/100 [16:32<3:10:16, 124.10s/it]

Epoch: 7, Train Loss: 1.650155202827262, Test Loss: 1.5363389837741852


  9%|▉         | 9/100 [18:36<3:08:11, 124.08s/it]

Epoch: 8, Train Loss: 1.5476742680348343, Test Loss: 1.4566003918647765


 10%|█         | 10/100 [20:41<3:06:19, 124.22s/it]

Epoch: 9, Train Loss: 1.4581737254732219, Test Loss: 1.3965217065811157


 11%|█         | 11/100 [22:45<3:04:04, 124.10s/it]

Epoch: 10, Train Loss: 1.3834167272601294, Test Loss: 1.336302136182785


 12%|█▏        | 12/100 [24:49<3:02:01, 124.11s/it]

Epoch: 11, Train Loss: 1.3147532457682356, Test Loss: 1.294531512260437


 13%|█▎        | 13/100 [26:53<2:59:53, 124.07s/it]

Epoch: 12, Train Loss: 1.2529109246167704, Test Loss: 1.2484520399570465


 14%|█▍        | 14/100 [28:57<2:57:55, 124.13s/it]

Epoch: 13, Train Loss: 1.1964951678736127, Test Loss: 1.2183717107772827


 15%|█▌        | 15/100 [31:02<2:56:03, 124.27s/it]

Epoch: 14, Train Loss: 1.1460645579213473, Test Loss: 1.1910168147087097


 16%|█▌        | 16/100 [33:06<2:53:54, 124.22s/it]

Epoch: 15, Train Loss: 1.0984285226718864, Test Loss: 1.1582943511009216


 17%|█▋        | 17/100 [35:09<2:51:31, 124.00s/it]

Epoch: 16, Train Loss: 1.0584338632360775, Test Loss: 1.1316233187913896


 18%|█▊        | 18/100 [37:13<2:49:32, 124.05s/it]

Epoch: 17, Train Loss: 1.0185162210883807, Test Loss: 1.1111142444610596


 19%|█▉        | 19/100 [39:17<2:47:21, 123.97s/it]

Epoch: 18, Train Loss: 0.9810378449646073, Test Loss: 1.0972482722997665


 20%|██        | 20/100 [41:21<2:45:15, 123.94s/it]

Epoch: 19, Train Loss: 0.9477852016837154, Test Loss: 1.077962371110916


 21%|██        | 21/100 [43:25<2:43:04, 123.86s/it]

Epoch: 20, Train Loss: 0.9153992076914514, Test Loss: 1.0600369477272034


 22%|██▏       | 22/100 [45:29<2:41:15, 124.05s/it]

Epoch: 21, Train Loss: 0.8828686508701076, Test Loss: 1.0426894265413285


 23%|██▎       | 23/100 [47:33<2:39:10, 124.04s/it]

Epoch: 22, Train Loss: 0.85713005799744, Test Loss: 1.0302476394176483


 24%|██▍       | 24/100 [49:38<2:37:21, 124.23s/it]

Epoch: 23, Train Loss: 0.8286785332102272, Test Loss: 1.02208635866642


 25%|██▌       | 25/100 [51:42<2:35:07, 124.10s/it]

Epoch: 24, Train Loss: 0.8042395871188772, Test Loss: 1.0128840279579163


 26%|██▌       | 26/100 [53:46<2:33:06, 124.14s/it]

Epoch: 25, Train Loss: 0.783319176441461, Test Loss: 1.0048552775382995


 27%|██▋       | 27/100 [55:50<2:30:58, 124.09s/it]

Epoch: 26, Train Loss: 0.7601006962546152, Test Loss: 1.000624617934227


 28%|██▊       | 28/100 [57:54<2:28:53, 124.07s/it]

Epoch: 27, Train Loss: 0.7399951084774343, Test Loss: 0.9866133576631546


 29%|██▉       | 29/100 [59:58<2:26:42, 123.97s/it]

Epoch: 28, Train Loss: 0.7195035767315621, Test Loss: 0.9877449071407318


 30%|███       | 30/100 [1:02:02<2:24:44, 124.06s/it]

Epoch: 29, Train Loss: 0.6962758686824061, Test Loss: 0.9715693038702011


 31%|███       | 31/100 [1:04:06<2:22:40, 124.06s/it]

Epoch: 30, Train Loss: 0.6807907024520127, Test Loss: 0.9668121755123138


 32%|███▏      | 32/100 [1:06:10<2:20:44, 124.18s/it]

Epoch: 31, Train Loss: 0.6656682501785719, Test Loss: 0.962508682012558


 33%|███▎      | 33/100 [1:08:15<2:18:50, 124.33s/it]

Epoch: 32, Train Loss: 0.6462322254875796, Test Loss: 0.9563987374305725


 34%|███▍      | 34/100 [1:10:19<2:16:36, 124.19s/it]

Epoch: 33, Train Loss: 0.6314674830017377, Test Loss: 0.9591618025302887


 35%|███▌      | 35/100 [1:12:23<2:14:25, 124.08s/it]

Epoch: 34, Train Loss: 0.6175579167340868, Test Loss: 0.9501152366399765


 36%|███▌      | 36/100 [1:14:26<2:12:09, 123.90s/it]

Epoch: 35, Train Loss: 0.6012887529392338, Test Loss: 0.95118792116642


 37%|███▋      | 37/100 [1:16:30<2:09:56, 123.75s/it]

Epoch: 36, Train Loss: 0.5879362708509867, Test Loss: 0.9483805841207504


 38%|███▊      | 38/100 [1:18:34<2:08:01, 123.89s/it]

Epoch: 37, Train Loss: 0.5756701961833628, Test Loss: 0.9450503414869309


 39%|███▉      | 39/100 [1:20:37<2:05:48, 123.74s/it]

Epoch: 38, Train Loss: 0.5601198997479587, Test Loss: 0.954067011475563


 40%|████      | 40/100 [1:22:42<2:03:58, 123.97s/it]

Epoch: 39, Train Loss: 0.5472474593912536, Test Loss: 0.9440027260780335


 41%|████      | 41/100 [1:24:45<2:01:38, 123.71s/it]

Epoch: 40, Train Loss: 0.5371818200277922, Test Loss: 0.94637022793293


 42%|████▏     | 42/100 [1:26:49<1:59:36, 123.73s/it]

Epoch: 41, Train Loss: 0.5264197326634996, Test Loss: 0.9451385790109634


 43%|████▎     | 43/100 [1:28:52<1:57:29, 123.67s/it]

Epoch: 42, Train Loss: 0.5118161139626, Test Loss: 0.9389835101366043


 44%|████▍     | 44/100 [1:30:56<1:55:25, 123.67s/it]

Epoch: 43, Train Loss: 0.5015529167262753, Test Loss: 0.943368730545044


 45%|████▌     | 45/100 [1:33:00<1:53:25, 123.73s/it]

Epoch: 44, Train Loss: 0.48983719061367476, Test Loss: 0.9417859828472137


 46%|████▌     | 46/100 [1:35:03<1:51:13, 123.59s/it]

Epoch: 45, Train Loss: 0.4801909497485089, Test Loss: 0.9448968809843064


 47%|████▋     | 47/100 [1:37:07<1:49:17, 123.73s/it]

Epoch: 46, Train Loss: 0.473440744020232, Test Loss: 0.9480032598972321


 48%|████▊     | 48/100 [1:39:10<1:47:06, 123.58s/it]

Epoch: 47, Train Loss: 0.4631413628707579, Test Loss: 0.941313738822937


 49%|████▉     | 49/100 [1:41:14<1:45:04, 123.62s/it]

Epoch: 48, Train Loss: 0.45345540329739076, Test Loss: 0.9508110934495926


 50%|█████     | 50/100 [1:43:18<1:43:05, 123.71s/it]

Epoch: 49, Train Loss: 0.44304141589445084, Test Loss: 0.9423506021499634


 51%|█████     | 51/100 [1:45:21<1:40:56, 123.61s/it]

Epoch: 50, Train Loss: 0.4333682964045798, Test Loss: 0.946059240102768


 52%|█████▏    | 52/100 [1:47:25<1:38:54, 123.63s/it]

Epoch: 51, Train Loss: 0.4267589560705214, Test Loss: 0.9505552124977111


 53%|█████▎    | 53/100 [1:49:29<1:36:49, 123.61s/it]

Epoch: 52, Train Loss: 0.42244502055884603, Test Loss: 0.9459740614891052


 54%|█████▍    | 54/100 [1:51:32<1:34:46, 123.61s/it]

Epoch: 53, Train Loss: 0.41389340873639185, Test Loss: 0.9454494726657867


 55%|█████▌    | 55/100 [1:53:36<1:32:44, 123.65s/it]

Epoch: 54, Train Loss: 0.4046127058603057, Test Loss: 0.9530827015638351


 56%|█████▌    | 56/100 [1:55:39<1:30:34, 123.51s/it]

Epoch: 55, Train Loss: 0.3944797005934931, Test Loss: 0.9564487046003342


 57%|█████▋    | 57/100 [1:57:43<1:28:33, 123.56s/it]

Epoch: 56, Train Loss: 0.3891421394731531, Test Loss: 0.9576662731170654


 58%|█████▊    | 58/100 [1:59:46<1:26:24, 123.45s/it]

Epoch: 57, Train Loss: 0.3814548089276606, Test Loss: 0.9592425906658173


 59%|█████▉    | 59/100 [2:01:49<1:24:21, 123.45s/it]

Epoch: 58, Train Loss: 0.3755753675297876, Test Loss: 0.9601756268739701


 60%|██████    | 60/100 [2:03:53<1:22:25, 123.64s/it]

Epoch: 59, Train Loss: 0.36916793935262976, Test Loss: 0.9726923096179962


 61%|██████    | 61/100 [2:05:57<1:20:15, 123.49s/it]

Epoch: 60, Train Loss: 0.36313818762050804, Test Loss: 0.9620252555608749


 62%|██████▏   | 62/100 [2:08:00<1:18:13, 123.51s/it]

Epoch: 61, Train Loss: 0.3558013045308578, Test Loss: 0.9738362908363343


 63%|██████▎   | 63/100 [2:10:04<1:16:10, 123.53s/it]

Epoch: 62, Train Loss: 0.34877562769992865, Test Loss: 0.9784004616737366


 64%|██████▍   | 64/100 [2:12:08<1:14:13, 123.70s/it]

Epoch: 63, Train Loss: 0.34859160846801257, Test Loss: 0.9747519785165787


 65%|██████▌   | 65/100 [2:14:11<1:12:04, 123.55s/it]

Epoch: 64, Train Loss: 0.3406301973453119, Test Loss: 0.975865650177002


 66%|██████▌   | 66/100 [2:16:15<1:10:06, 123.73s/it]

Epoch: 65, Train Loss: 0.3365857940373109, Test Loss: 0.9756537634134292


 67%|██████▋   | 67/100 [2:18:20<1:08:08, 123.90s/it]

Epoch: 66, Train Loss: 0.32922653464516205, Test Loss: 0.9752638107538223


 68%|██████▊   | 68/100 [2:20:23<1:06:01, 123.79s/it]

Epoch: 67, Train Loss: 0.32614683782934545, Test Loss: 0.9912212014198303


 69%|██████▉   | 69/100 [2:22:26<1:03:50, 123.58s/it]

Epoch: 68, Train Loss: 0.3168336930286944, Test Loss: 0.9839317178726197


 70%|███████   | 70/100 [2:24:30<1:01:46, 123.54s/it]

Epoch: 69, Train Loss: 0.31196999227880834, Test Loss: 0.9832670706510543


 71%|███████   | 71/100 [2:26:33<59:41, 123.51s/it]  

Epoch: 70, Train Loss: 0.30842507312345746, Test Loss: 0.9898216986656189


 72%|███████▏  | 72/100 [2:28:37<57:40, 123.59s/it]

Epoch: 71, Train Loss: 0.3033208087011797, Test Loss: 0.9860882312059402


 73%|███████▎  | 73/100 [2:30:41<55:37, 123.63s/it]

Epoch: 72, Train Loss: 0.2969785558443573, Test Loss: 0.9880221402645111


 74%|███████▍  | 74/100 [2:32:44<53:33, 123.60s/it]

Epoch: 73, Train Loss: 0.29257280787630896, Test Loss: 0.990030533671379


 75%|███████▌  | 75/100 [2:34:48<51:29, 123.59s/it]

Epoch: 74, Train Loss: 0.28965741908879733, Test Loss: 1.0036188495159148


 76%|███████▌  | 76/100 [2:36:51<49:25, 123.57s/it]

Epoch: 75, Train Loss: 0.28658976110082174, Test Loss: 1.0060695266723634


 77%|███████▋  | 77/100 [2:38:55<47:24, 123.66s/it]

Epoch: 76, Train Loss: 0.2869123524457366, Test Loss: 1.0077113562822342


 78%|███████▊  | 78/100 [2:40:58<45:16, 123.48s/it]

Epoch: 77, Train Loss: 0.2797270599921145, Test Loss: 1.0090661644935608


 79%|███████▉  | 79/100 [2:43:02<43:14, 123.56s/it]

Epoch: 78, Train Loss: 0.27228720912981275, Test Loss: 1.0076010203361512


 80%|████████  | 80/100 [2:45:05<41:11, 123.59s/it]

Epoch: 79, Train Loss: 0.26848952446001856, Test Loss: 1.0080369591712952


 81%|████████  | 81/100 [2:47:08<39:03, 123.37s/it]

Epoch: 80, Train Loss: 0.26724835163235067, Test Loss: 1.021067392230034


 82%|████████▏ | 82/100 [2:49:12<37:01, 123.41s/it]

Epoch: 81, Train Loss: 0.26427091845315903, Test Loss: 1.013731126189232


 83%|████████▎ | 83/100 [2:51:15<34:55, 123.29s/it]

Epoch: 82, Train Loss: 0.2610260331301234, Test Loss: 1.0174242448806763


 84%|████████▍ | 84/100 [2:53:18<32:53, 123.36s/it]

Epoch: 83, Train Loss: 0.2546066192526314, Test Loss: 1.024842211008072


 85%|████████▌ | 85/100 [2:55:22<30:51, 123.42s/it]

Epoch: 84, Train Loss: 0.25125569011548055, Test Loss: 1.0291671192646026


 86%|████████▌ | 86/100 [2:57:25<28:47, 123.42s/it]

Epoch: 85, Train Loss: 0.2509644222094785, Test Loss: 1.0222955310344697


 87%|████████▋ | 87/100 [2:59:29<26:45, 123.53s/it]

Epoch: 86, Train Loss: 0.24260508752048915, Test Loss: 1.0246473556756974


 88%|████████▊ | 88/100 [3:01:32<24:41, 123.44s/it]

Epoch: 87, Train Loss: 0.24116310604553126, Test Loss: 1.0339488542079927


 89%|████████▉ | 89/100 [3:03:36<22:37, 123.41s/it]

Epoch: 88, Train Loss: 0.24085558973365093, Test Loss: 1.0364917153120041


 90%|█████████ | 90/100 [3:05:39<20:33, 123.38s/it]

Epoch: 89, Train Loss: 0.23537320523855076, Test Loss: 1.0363112318515777


 91%|█████████ | 91/100 [3:07:42<18:29, 123.29s/it]

Epoch: 90, Train Loss: 0.23556178201682604, Test Loss: 1.040226281285286


 92%|█████████▏| 92/100 [3:09:46<16:27, 123.43s/it]

Epoch: 91, Train Loss: 0.2310415400033021, Test Loss: 1.047361701130867


 93%|█████████▎| 93/100 [3:11:49<14:24, 123.45s/it]

Epoch: 92, Train Loss: 0.23031144956098729, Test Loss: 1.0438429564237595


 94%|█████████▍| 94/100 [3:13:53<12:20, 123.39s/it]

Epoch: 93, Train Loss: 0.22757805932556566, Test Loss: 1.0475970804691315


 95%|█████████▌| 95/100 [3:15:56<10:17, 123.41s/it]

Epoch: 94, Train Loss: 0.22751960914638175, Test Loss: 1.0545192509889603


 96%|█████████▌| 96/100 [3:18:00<08:14, 123.53s/it]

Epoch: 95, Train Loss: 0.22064234892926624, Test Loss: 1.0491971176862718


 97%|█████████▋| 97/100 [3:20:04<06:10, 123.59s/it]

Epoch: 96, Train Loss: 0.217346768732646, Test Loss: 1.0585578399896622


 98%|█████████▊| 98/100 [3:22:07<04:07, 123.55s/it]

Epoch: 97, Train Loss: 0.2116531738683806, Test Loss: 1.0546039605140687


 99%|█████████▉| 99/100 [3:24:11<02:03, 123.74s/it]

Epoch: 98, Train Loss: 0.21071938515158753, Test Loss: 1.0658611744642257


100%|██████████| 100/100 [3:26:15<00:00, 123.75s/it]

Epoch: 99, Train Loss: 0.2113868618505684, Test Loss: 1.064922543168068



