In [6]:
import torchvision.models as models
from torch import nn
import torch
import pandas as pd
import random
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
from torchvision.transforms import transforms
from nltk.tokenize import word_tokenize
from string import punctuation
from torchtext.vocab import build_vocab_from_iterator
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from models import Encoder, Decoder
import torchtext; torchtext.disable_torchtext_deprecation_warning()
torch.manual_seed(42)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [7]:
data = pd.read_csv("./flickr8k/captions.txt")

In [8]:
data.head()

Unnamed: 0,image,caption
0,1000268201_693b08cb0e.jpg,A child in a pink dress is climbing up a set o...
1,1000268201_693b08cb0e.jpg,A girl going into a wooden building .
2,1000268201_693b08cb0e.jpg,A little girl climbing into a wooden playhouse .
3,1000268201_693b08cb0e.jpg,A little girl climbing the stairs to her playh...
4,1000268201_693b08cb0e.jpg,A little girl in a pink dress going into a woo...


In [9]:
def clean_text(text, lowercase=False, remove_punc=False, remove_num=False, sos_token='<sos>', eos_token='<eos>'):
    if lowercase:
        text = text.lower()
    if remove_punc:
        text = ''.join([ch for ch in text if ch not in punctuation])
    if remove_num:
        text = ''.join([ch for ch in text if ch not in '1234567890'])
    text = [sos_token] + word_tokenize(text) + [eos_token]
    return text

In [10]:
clean_text("A cat is sitting on the table.", lowercase=True, remove_punc=True, remove_num=True)

['<sos>', 'a', 'cat', 'is', 'sitting', 'on', 'the', 'table', '<eos>']

In [11]:
unk_token = '<unk>'
pad_token = '<pad>'
sos_token = '<sos>'
eos_token = '<eos>'

In [12]:
clean_cap = data['caption'].apply(lambda x: clean_text(x, lowercase=True, remove_punc=True, remove_num=True))

In [13]:
clean_cap.head()

0    [<sos>, a, child, in, a, pink, dress, is, clim...
1    [<sos>, a, girl, going, into, a, wooden, build...
2    [<sos>, a, little, girl, climbing, into, a, wo...
3    [<sos>, a, little, girl, climbing, the, stairs...
4    [<sos>, a, little, girl, in, a, pink, dress, g...
Name: caption, dtype: object

In [14]:
data['clean_caption'] = clean_cap

In [15]:
data.head()

Unnamed: 0,image,caption,clean_caption
0,1000268201_693b08cb0e.jpg,A child in a pink dress is climbing up a set o...,"[<sos>, a, child, in, a, pink, dress, is, clim..."
1,1000268201_693b08cb0e.jpg,A girl going into a wooden building .,"[<sos>, a, girl, going, into, a, wooden, build..."
2,1000268201_693b08cb0e.jpg,A little girl climbing into a wooden playhouse .,"[<sos>, a, little, girl, climbing, into, a, wo..."
3,1000268201_693b08cb0e.jpg,A little girl climbing the stairs to her playh...,"[<sos>, a, little, girl, climbing, the, stairs..."
4,1000268201_693b08cb0e.jpg,A little girl in a pink dress going into a woo...,"[<sos>, a, little, girl, in, a, pink, dress, g..."


In [16]:
vocab = build_vocab_from_iterator(clean_cap, specials=[unk_token, pad_token, sos_token, eos_token])

In [17]:
vocab.get_itos()[:10]

['<unk>', '<pad>', '<sos>', '<eos>', 'a', 'in', 'the', 'on', 'is', 'and']

In [18]:
pad_token_idx = vocab[pad_token]
unk_token_idx = vocab[unk_token]

In [19]:
vocab.set_default_index(unk_token_idx)

In [20]:
# to number
def text_to_number(text, vocab):
    return [vocab[token] for token in text]

In [21]:
to_int = clean_cap.apply(lambda x: text_to_number(x, vocab))

In [22]:
to_int

0        [2, 4, 43, 5, 4, 91, 171, 8, 120, 54, 4, 400, ...
1                      [2, 4, 20, 316, 65, 4, 196, 118, 3]
2                 [2, 4, 41, 20, 120, 65, 4, 196, 2569, 3]
3             [2, 4, 41, 20, 120, 6, 394, 21, 61, 2569, 3]
4        [2, 4, 41, 20, 5, 4, 91, 171, 316, 65, 4, 196,...
                               ...                        
40450         [2, 4, 12, 5, 4, 91, 38, 253, 4, 85, 124, 3]
40451             [2, 4, 12, 8, 85, 120, 197, 5, 6, 66, 3]
40452    [2, 4, 44, 5, 4, 26, 38, 120, 54, 4, 85, 124, ...
40453                     [2, 4, 85, 359, 5, 4, 26, 38, 3]
40454         [2, 4, 85, 359, 1915, 7, 4, 85, 120, 110, 3]
Name: caption, Length: 40455, dtype: object

In [23]:
data['embed_caption'] = to_int

In [24]:
data.head()

Unnamed: 0,image,caption,clean_caption,embed_caption
0,1000268201_693b08cb0e.jpg,A child in a pink dress is climbing up a set o...,"[<sos>, a, child, in, a, pink, dress, is, clim...","[2, 4, 43, 5, 4, 91, 171, 8, 120, 54, 4, 400, ..."
1,1000268201_693b08cb0e.jpg,A girl going into a wooden building .,"[<sos>, a, girl, going, into, a, wooden, build...","[2, 4, 20, 316, 65, 4, 196, 118, 3]"
2,1000268201_693b08cb0e.jpg,A little girl climbing into a wooden playhouse .,"[<sos>, a, little, girl, climbing, into, a, wo...","[2, 4, 41, 20, 120, 65, 4, 196, 2569, 3]"
3,1000268201_693b08cb0e.jpg,A little girl climbing the stairs to her playh...,"[<sos>, a, little, girl, climbing, the, stairs...","[2, 4, 41, 20, 120, 6, 394, 21, 61, 2569, 3]"
4,1000268201_693b08cb0e.jpg,A little girl in a pink dress going into a woo...,"[<sos>, a, little, girl, in, a, pink, dress, g...","[2, 4, 41, 20, 5, 4, 91, 171, 316, 65, 4, 196,..."


In [25]:
vocab.lookup_tokens(data['embed_caption'][0])

['<sos>',
 'a',
 'child',
 'in',
 'a',
 'pink',
 'dress',
 'is',
 'climbing',
 'up',
 'a',
 'set',
 'of',
 'stairs',
 'in',
 'an',
 'entry',
 'way',
 '<eos>']

In [26]:
train, test = train_test_split(data, test_size=0.2, random_state=42)
train = train.reset_index(drop=True)
test = test.reset_index(drop=True)

In [27]:
def get_collate_fn(pad_index):
    def collate_fn(batch):
        images = []
        captions = []
        for img, cap in batch:
            images.append(img)
            captions.append(cap)
        images = torch.stack(images)
        captions = torch.nn.utils.rnn.pad_sequence(captions, batch_first=True, padding_value=pad_index)
        return images, captions

    return collate_fn

In [28]:
class CustomDataset(Dataset):
    def __init__(self, root_dir, data, transform=None):
        self.root_dir = root_dir
        self.captions = data['embed_caption']
        self.images = data['image']
        self.transform = transform
        
    def __len__(self):
        return len(self.captions)

    def __getitem__(self, idx):
        image = Image.open(os.path.join(self.root_dir, self.images[idx]))
        caption = torch.tensor(self.captions[idx])
        if self.transform:
            image = self.transform(image)
    
        return image, caption

In [29]:
transform = transforms.Compose([
    # data type convert to tensor
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    
])

In [30]:
train_dataset = CustomDataset("./flickr8k/Images", train, transform=transform)
test_dataset = CustomDataset("./flickr8k/Images", test, transform=transform)

In [31]:
batch_size = 512 
num_workers = 4

In [32]:
train_data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers = num_workers, collate_fn=get_collate_fn(pad_token_idx))
test_data_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers = num_workers, collate_fn=get_collate_fn(pad_token_idx))

In [33]:
embed_dim = 256
hidden_dim = 512
vocab_size = len(vocab)
num_layers = 2
dropout = 0.5

In [34]:
encoder = Encoder(embed_dim, dropout)
model = Decoder(embed_dim, hidden_dim, vocab_size, num_layers, device, encoder, dropout )
model = model.to(device)

In [35]:
n_epochs = 100
learning_rate = 0.001
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss(ignore_index = pad_token_idx)
clip = 1.0
teacher_forcing_ratio = 0.5
best_valid_loss = float("inf")


In [36]:
def train_fn(model, data_loader, optimizer, criterion, clip, device):
    model.train()
    epoch_loss = 0
    
    for i, batch in enumerate(data_loader):
        images, captions = batch
        images, captions = images.to(device), captions.to(device)
     
        optimizer.zero_grad()
        
        captions_in = captions[:,:-1]
        outputs = model(images, captions_in)
        outputs = outputs.view(-1, outputs.shape[2]).to(device)
        
        captions = captions.view(-1)
        
        loss = criterion(outputs, captions)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        
        epoch_loss += loss.item()
    
    return epoch_loss / len(data_loader)


In [37]:
def evaluate_fn(model, data_loader, criterion, device):
    model.eval()
    epoch_loss = 0
    with torch.no_grad():
        for i, batch in enumerate(data_loader):
            images, captions = batch
            images, captions = images.to(device), captions.to(device)
            
            captions_in = captions[:,:-1]
            outputs = model(images, captions_in)
            
            outputs = outputs.view(-1, outputs.shape[2]).to(device)
            captions = captions.view(-1)
        
            loss = criterion(outputs, captions)
            epoch_loss += loss.item()
            
    return epoch_loss / len(data_loader)


In [35]:
for epoch in tqdm(range(n_epochs)):
    train_loss = train_fn(
        model,
        train_data_loader,
        optimizer,
        criterion,
        clip,
        device)
    
    
    valid_loss = evaluate_fn(
        model,
        test_data_loader,
        criterion,
        device,
    )
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), "best-model.pt")
    print(f"\tTrain Loss: {train_loss:7.3f}")
    print(f"\tValid Loss: {valid_loss:7.3f}")
torch.save(model.state_dict(), "last_model.pt")

  1%|          | 1/100 [03:23<5:35:21, 203.24s/it]

	Train Loss:   5.645
	Valid Loss:   4.709


  2%|▏         | 2/100 [06:48<5:33:46, 204.35s/it]

	Train Loss:   4.505
	Valid Loss:   4.271


  3%|▎         | 3/100 [10:13<5:31:14, 204.89s/it]

	Train Loss:   4.100
	Valid Loss:   3.969


  4%|▍         | 4/100 [13:38<5:27:56, 204.96s/it]

	Train Loss:   3.872
	Valid Loss:   3.798


  5%|▌         | 5/100 [17:04<5:24:50, 205.16s/it]

	Train Loss:   3.723
	Valid Loss:   3.681


  6%|▌         | 6/100 [20:29<5:21:33, 205.25s/it]

	Train Loss:   3.605
	Valid Loss:   3.584


  7%|▋         | 7/100 [23:55<5:18:30, 205.49s/it]

	Train Loss:   3.510
	Valid Loss:   3.498


  8%|▊         | 8/100 [27:21<5:15:17, 205.62s/it]

	Train Loss:   3.426
	Valid Loss:   3.433


  9%|▉         | 9/100 [30:47<5:11:56, 205.67s/it]

	Train Loss:   3.349
	Valid Loss:   3.368


 10%|█         | 10/100 [34:13<5:08:30, 205.67s/it]

	Train Loss:   3.277
	Valid Loss:   3.302


 11%|█         | 11/100 [37:39<5:05:24, 205.90s/it]

	Train Loss:   3.217
	Valid Loss:   3.262


 12%|█▏        | 12/100 [41:05<5:02:06, 205.99s/it]

	Train Loss:   3.166
	Valid Loss:   3.222


 13%|█▎        | 13/100 [44:30<4:58:18, 205.73s/it]

	Train Loss:   3.122
	Valid Loss:   3.185


 14%|█▍        | 14/100 [47:57<4:55:15, 205.99s/it]

	Train Loss:   3.079
	Valid Loss:   3.158


 15%|█▌        | 15/100 [51:22<4:51:31, 205.78s/it]

	Train Loss:   3.036
	Valid Loss:   3.127


 16%|█▌        | 16/100 [54:48<4:48:13, 205.88s/it]

	Train Loss:   2.997
	Valid Loss:   3.099


 17%|█▋        | 17/100 [58:14<4:44:44, 205.84s/it]

	Train Loss:   2.959
	Valid Loss:   3.066


 18%|█▊        | 18/100 [1:01:40<4:41:18, 205.84s/it]

	Train Loss:   2.924
	Valid Loss:   3.045


 19%|█▉        | 19/100 [1:05:05<4:37:30, 205.56s/it]

	Train Loss:   2.891
	Valid Loss:   3.025


 20%|██        | 20/100 [1:08:31<4:34:06, 205.59s/it]

	Train Loss:   2.858
	Valid Loss:   3.005


 21%|██        | 21/100 [1:11:56<4:30:37, 205.54s/it]

	Train Loss:   2.829
	Valid Loss:   2.989


 22%|██▏       | 22/100 [1:15:21<4:27:00, 205.39s/it]

	Train Loss:   2.800
	Valid Loss:   2.972


 23%|██▎       | 23/100 [1:18:47<4:23:41, 205.47s/it]

	Train Loss:   2.771
	Valid Loss:   2.955


 24%|██▍       | 24/100 [1:22:12<4:20:09, 205.39s/it]

	Train Loss:   2.745
	Valid Loss:   2.938


 25%|██▌       | 25/100 [1:25:38<4:17:00, 205.61s/it]

	Train Loss:   2.720
	Valid Loss:   2.922


 26%|██▌       | 26/100 [1:29:03<4:13:29, 205.54s/it]

	Train Loss:   2.691
	Valid Loss:   2.912


 27%|██▋       | 27/100 [1:32:30<4:10:15, 205.70s/it]

	Train Loss:   2.668
	Valid Loss:   2.895


 28%|██▊       | 28/100 [1:35:55<4:06:43, 205.60s/it]

	Train Loss:   2.645
	Valid Loss:   2.885


 29%|██▉       | 29/100 [1:39:21<4:03:20, 205.65s/it]

	Train Loss:   2.621
	Valid Loss:   2.880


 30%|███       | 30/100 [1:42:46<3:59:45, 205.50s/it]

	Train Loss:   2.600
	Valid Loss:   2.870


 31%|███       | 31/100 [1:46:11<3:56:16, 205.46s/it]

	Train Loss:   2.578
	Valid Loss:   2.862


 32%|███▏      | 32/100 [1:49:37<3:53:00, 205.59s/it]

	Train Loss:   2.559
	Valid Loss:   2.853


 33%|███▎      | 33/100 [1:53:02<3:49:25, 205.45s/it]

	Train Loss:   2.540
	Valid Loss:   2.846


 34%|███▍      | 34/100 [1:56:28<3:46:02, 205.50s/it]

	Train Loss:   2.520
	Valid Loss:   2.841


 35%|███▌      | 35/100 [1:59:53<3:42:29, 205.38s/it]

	Train Loss:   2.500
	Valid Loss:   2.838


 36%|███▌      | 36/100 [2:03:19<3:39:19, 205.62s/it]

	Train Loss:   2.483
	Valid Loss:   2.830


 37%|███▋      | 37/100 [2:06:44<3:35:43, 205.45s/it]

	Train Loss:   2.465
	Valid Loss:   2.825


 38%|███▊      | 38/100 [2:10:10<3:32:24, 205.56s/it]

	Train Loss:   2.449
	Valid Loss:   2.821


 39%|███▉      | 39/100 [2:13:36<3:29:05, 205.67s/it]

	Train Loss:   2.430
	Valid Loss:   2.814


 40%|████      | 40/100 [2:17:02<3:25:49, 205.83s/it]

	Train Loss:   2.414
	Valid Loss:   2.812


 41%|████      | 41/100 [2:20:28<3:22:25, 205.85s/it]

	Train Loss:   2.396
	Valid Loss:   2.810


 42%|████▏     | 42/100 [2:23:53<3:18:37, 205.48s/it]

	Train Loss:   2.382
	Valid Loss:   2.806


 43%|████▎     | 43/100 [2:27:18<3:15:06, 205.38s/it]

	Train Loss:   2.366
	Valid Loss:   2.802


 44%|████▍     | 44/100 [2:30:43<3:11:47, 205.49s/it]

	Train Loss:   2.348
	Valid Loss:   2.798


 45%|████▌     | 45/100 [2:34:09<3:08:26, 205.57s/it]

	Train Loss:   2.335
	Valid Loss:   2.797


 46%|████▌     | 46/100 [2:37:35<3:05:08, 205.71s/it]

	Train Loss:   2.320
	Valid Loss:   2.793


 47%|████▋     | 47/100 [2:41:01<3:01:50, 205.86s/it]

	Train Loss:   2.304
	Valid Loss:   2.789


 48%|████▊     | 48/100 [2:44:26<2:58:02, 205.43s/it]

	Train Loss:   2.291
	Valid Loss:   2.790


 49%|████▉     | 49/100 [2:47:52<2:54:46, 205.61s/it]

	Train Loss:   2.278
	Valid Loss:   2.785


 50%|█████     | 50/100 [2:51:17<2:51:13, 205.48s/it]

	Train Loss:   2.264
	Valid Loss:   2.785


 51%|█████     | 51/100 [2:54:43<2:47:56, 205.63s/it]

	Train Loss:   2.252
	Valid Loss:   2.783


 52%|█████▏    | 52/100 [2:58:08<2:44:20, 205.42s/it]

	Train Loss:   2.237
	Valid Loss:   2.783


 53%|█████▎    | 53/100 [3:01:34<2:40:58, 205.49s/it]

	Train Loss:   2.226
	Valid Loss:   2.782


 54%|█████▍    | 54/100 [3:05:00<2:37:37, 205.60s/it]

	Train Loss:   2.213
	Valid Loss:   2.782


 55%|█████▌    | 55/100 [3:08:25<2:34:10, 205.56s/it]

	Train Loss:   2.200
	Valid Loss:   2.782


 56%|█████▌    | 56/100 [3:11:50<2:30:43, 205.53s/it]

	Train Loss:   2.188
	Valid Loss:   2.783


 57%|█████▋    | 57/100 [3:15:16<2:27:15, 205.48s/it]

	Train Loss:   2.175
	Valid Loss:   2.783


 58%|█████▊    | 58/100 [3:18:42<2:23:54, 205.58s/it]

	Train Loss:   2.165
	Valid Loss:   2.783


 59%|█████▉    | 59/100 [3:22:07<2:20:30, 205.63s/it]

	Train Loss:   2.154
	Valid Loss:   2.780


 60%|██████    | 60/100 [3:25:33<2:16:59, 205.49s/it]

	Train Loss:   2.143
	Valid Loss:   2.781


 61%|██████    | 61/100 [3:28:58<2:13:30, 205.39s/it]

	Train Loss:   2.132
	Valid Loss:   2.781


 62%|██████▏   | 62/100 [3:32:23<2:10:06, 205.45s/it]

	Train Loss:   2.121
	Valid Loss:   2.784


 63%|██████▎   | 63/100 [3:35:49<2:06:38, 205.38s/it]

	Train Loss:   2.110
	Valid Loss:   2.784


 64%|██████▍   | 64/100 [3:39:15<2:03:23, 205.64s/it]

	Train Loss:   2.098
	Valid Loss:   2.786


 65%|██████▌   | 65/100 [3:42:40<1:59:57, 205.66s/it]

	Train Loss:   2.091
	Valid Loss:   2.785


 66%|██████▌   | 66/100 [3:46:06<1:56:27, 205.53s/it]

	Train Loss:   2.079
	Valid Loss:   2.787


 67%|██████▋   | 67/100 [3:49:31<1:52:59, 205.43s/it]

	Train Loss:   2.069
	Valid Loss:   2.787


 68%|██████▊   | 68/100 [3:52:57<1:49:35, 205.49s/it]

	Train Loss:   2.058
	Valid Loss:   2.787


 69%|██████▉   | 69/100 [3:56:22<1:46:08, 205.45s/it]

	Train Loss:   2.048
	Valid Loss:   2.789


 70%|███████   | 70/100 [3:59:48<1:42:46, 205.55s/it]

	Train Loss:   2.038
	Valid Loss:   2.790


 71%|███████   | 71/100 [4:03:13<1:39:22, 205.61s/it]

	Train Loss:   2.030
	Valid Loss:   2.792


 72%|███████▏  | 72/100 [4:06:38<1:35:51, 205.42s/it]

	Train Loss:   2.022
	Valid Loss:   2.791


 73%|███████▎  | 73/100 [4:10:04<1:32:30, 205.57s/it]

	Train Loss:   2.012
	Valid Loss:   2.794


 74%|███████▍  | 74/100 [4:13:29<1:29:01, 205.44s/it]

	Train Loss:   2.002
	Valid Loss:   2.794


 75%|███████▌  | 75/100 [4:16:55<1:25:36, 205.45s/it]

	Train Loss:   1.993
	Valid Loss:   2.796


 76%|███████▌  | 76/100 [4:20:20<1:22:05, 205.24s/it]

	Train Loss:   1.987
	Valid Loss:   2.800


 77%|███████▋  | 77/100 [4:23:45<1:18:43, 205.35s/it]

	Train Loss:   1.976
	Valid Loss:   2.798


 78%|███████▊  | 78/100 [4:27:11<1:15:17, 205.33s/it]

	Train Loss:   1.967
	Valid Loss:   2.801


 79%|███████▉  | 79/100 [4:30:36<1:11:52, 205.36s/it]

	Train Loss:   1.958
	Valid Loss:   2.804


 80%|████████  | 80/100 [4:34:01<1:08:27, 205.37s/it]

	Train Loss:   1.949
	Valid Loss:   2.805


 81%|████████  | 81/100 [4:37:27<1:05:03, 205.43s/it]

	Train Loss:   1.942
	Valid Loss:   2.808


 82%|████████▏ | 82/100 [4:40:52<1:01:38, 205.45s/it]

	Train Loss:   1.932
	Valid Loss:   2.811


 83%|████████▎ | 83/100 [4:44:18<58:13, 205.53s/it]  

	Train Loss:   1.922
	Valid Loss:   2.813


 84%|████████▍ | 84/100 [4:47:44<54:51, 205.73s/it]

	Train Loss:   1.918
	Valid Loss:   2.815


 85%|████████▌ | 85/100 [4:51:10<51:23, 205.58s/it]

	Train Loss:   1.911
	Valid Loss:   2.817


 86%|████████▌ | 86/100 [4:54:36<47:59, 205.69s/it]

	Train Loss:   1.902
	Valid Loss:   2.819


 87%|████████▋ | 87/100 [4:58:01<44:33, 205.65s/it]

	Train Loss:   1.894
	Valid Loss:   2.820


 88%|████████▊ | 88/100 [5:01:26<41:06, 205.56s/it]

	Train Loss:   1.886
	Valid Loss:   2.825


 89%|████████▉ | 89/100 [5:04:52<37:42, 205.66s/it]

	Train Loss:   1.879
	Valid Loss:   2.829


 90%|█████████ | 90/100 [5:08:18<34:16, 205.64s/it]

	Train Loss:   1.870
	Valid Loss:   2.831


 91%|█████████ | 91/100 [5:11:43<30:50, 205.60s/it]

	Train Loss:   1.866
	Valid Loss:   2.831


 92%|█████████▏| 92/100 [5:15:09<27:25, 205.65s/it]

	Train Loss:   1.857
	Valid Loss:   2.834


 93%|█████████▎| 93/100 [5:18:36<24:01, 205.93s/it]

	Train Loss:   1.850
	Valid Loss:   2.837


 94%|█████████▍| 94/100 [5:22:02<20:35, 205.99s/it]

	Train Loss:   1.844
	Valid Loss:   2.841


 95%|█████████▌| 95/100 [5:25:28<17:10, 206.06s/it]

	Train Loss:   1.836
	Valid Loss:   2.842


 96%|█████████▌| 96/100 [5:28:54<13:44, 206.00s/it]

	Train Loss:   1.828
	Valid Loss:   2.844


 97%|█████████▋| 97/100 [5:32:19<10:17, 205.74s/it]

	Train Loss:   1.823
	Valid Loss:   2.847


 98%|█████████▊| 98/100 [5:35:44<06:51, 205.59s/it]

	Train Loss:   1.817
	Valid Loss:   2.852


 99%|█████████▉| 99/100 [5:39:10<03:25, 205.67s/it]

	Train Loss:   1.809
	Valid Loss:   2.852


100%|██████████| 100/100 [5:42:36<00:00, 205.56s/it]

	Train Loss:   1.803
	Valid Loss:   2.857



