In [1]:
from google.colab import drive
drive.mount('/content/drive', force_remount = True)

Mounted at /content/drive


In [2]:
import sys
sys.path.append('/content/drive/MyDrive')

In [3]:
import string
import numpy as np
import pandas as pd
from numpy import array
from PIL import Image
import pickle
from keras.utils import to_categorical
from keras.preprocessing.text import Tokenizer
import tensorflow as tf
import os
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import LambdaLR

import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")
from model import TransformerModel
from torch.utils.data import DataLoader
import torch.optim as optim
from torch import nn
import torch
from dataset import DataGenerator
import time
from tqdm import tqdm
import nltk
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler


from utils import *


tf.get_logger().setLevel('ERROR')
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"


In [4]:
captions_path = os.path.join('/content/drive/MyDrive/data/captions.txt')

In [5]:
top_k = 5000
caption_dict, word_to_index, index_to_word = preprocess_captions(captions_path)


In [6]:
keys = list(caption_dict.keys())
train_keys, test_keys = train_test_split(keys, test_size = 0.2, random_state = 42)
train_keys, val_keys = train_test_split(train_keys, test_size = 0.25, random_state = 42)

In [7]:
features_path = os.path.join('/content/drive/MyDrive/encodings_swin.pkl')
with open(features_path, 'rb') as f:
    features_dict = pickle.load(f)

In [8]:
train_captions = {k: caption_dict[k] for k in train_keys}
val_captions = {k: caption_dict[k] for k in val_keys}
test_captions = {k: caption_dict[k] for k in test_keys}

train_features = {k: features_dict[k] for k in train_keys}
val_features = {k: features_dict[k] for k in val_keys}
test_features = {k: features_dict[k] for k in test_keys}

In [9]:
tokenizer = Tokenizer(num_words = top_k,
                      oov_token="<unk>",
                      filters='!"#$%&()*+.,-/:;=?@[\]^_`{|}~ ')

In [10]:
tc = get_captions(train_captions)
tokenizer.fit_on_texts(tc)
tokenizer.word_index['<pad>'] = 0
tokenizer.index_word[0] = '<pad>'
max_seq_length = get_max_length(caption_dict)

In [11]:
train_data_generator = DataGenerator(train_captions, train_features, tokenizer, max_seq_length)
train_data = DataLoader(train_data_generator, batch_size=32)
val_data_generator = DataGenerator(val_captions, val_features, tokenizer, max_seq_length)
val_data = DataLoader(val_data_generator, batch_size=32)
test_data_generator = DataGenerator(test_captions, test_features, tokenizer, max_seq_length)
test_data = DataLoader(test_data_generator, batch_size=32)

In [12]:
model = TransformerModel(model_dim = 512, input_dim = 768, num_heads = 8, num_layers = 4, vocab_size = top_k + 1, sequence_length = top_k + 1, height = 7, width = 7)

In [13]:
model.load_state_dict(torch.load(os.path.join('/content/drive/MyDrive/model.pth')))

<All keys matched successfully>

In [15]:
class LRScheduler(_LRScheduler):
    def __init__(self,
                 optimizer: Optimizer,
                 dim_embed: int,
                 warmup_steps: int,
                 last_epoch: int=-1,
                 verbose: bool=False) -> None:

        self.dim_embed = dim_embed
        self.warmup_steps = warmup_steps
        self.num_param_groups = len(optimizer.param_groups)

        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self) -> float:
        lr = calc_lr(self._step_count, self.dim_embed, self.warmup_steps)
        return [lr] * self.num_param_groups


def calc_lr(step, dim_embed, warmup_steps):
    return dim_embed**(-0.5) * min(step**(-0.5), step * warmup_steps**(-1.5))

In [16]:
loss_fn = nn.CrossEntropyLoss(reduction = 'none')
optimizer = optim.Adam(model.parameters(), lr = 0.001, betas = (0.9, 0.98), eps = 1e-9)
lr_mul = 1.0 / 512 ** 0.5
scheduler = LRScheduler(optimizer, dim_embed = 512, warmup_steps = 4000)


In [17]:
num_epochs = 100
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

In [18]:
def loss_function(real, pred):
    mask = torch.logical_not(torch.eq(real, 0))
    loss_ = loss_fn(pred.reshape(-1, pred.size(-1)), real.reshape(-1))

    mask = mask.float()
    loss_ *= mask.view_as(loss_)

    return torch.sum(loss_) / torch.sum(mask)

In [19]:
def evaluate_step(img_tensor, tar):
    target_in = tar[:, :-1]
    target_real = tar[:, 1:]

    predictions = model(img_tensor, target_in)
    loss = loss_function(target_real, predictions)

    _, predicted_ids = torch.max(predictions, dim=-1)
    accuracy = torch.sum(predicted_ids == target_real).item() / target_real.numel()

    return loss, accuracy

In [None]:
for epoch in range(num_epochs):
    start = time.time()

    model.train()
    epoch_loss = 0.0
    epoch_accuracy = 0.0

    with tqdm(total=len(train_data), desc=f'Epoch {epoch + 1} Training') as pbar:
        for batch, (img_tensor, tar) in enumerate(train_data):
            img_tensor, tar = img_tensor.to(device), tar.to(device)
            img_tensor = img_tensor.float()
            tar = tar.long()
            optimizer.zero_grad()
            loss, batch_accuracy = evaluate_step(img_tensor, tar)
            batch_loss = loss.item()
            loss.backward()
            optimizer.step()
            scheduler.step()

            epoch_loss += batch_loss
            epoch_accuracy += batch_accuracy

            pbar.set_postfix({'Loss': batch_loss, 'Accuracy': batch_accuracy})
            pbar.update()

    mean_epoch_loss = epoch_loss / len(train_data)
    mean_epoch_accuracy = epoch_accuracy / len(train_data)

    print('Epoch {} Training Loss {:.4f} Accuracy {:.4f}'.format(epoch + 1, mean_epoch_loss, mean_epoch_accuracy))

    if (epoch + 1) % 5 == 0:
        torch.save(model.state_dict(), 'model.pth')

    model.eval()
    val_loss = 0.0
    val_accuracy = 0.0

    with torch.no_grad():
        for batch, (val_img_tensor, val_tar) in enumerate(val_data):
            val_img_tensor, val_tar = val_img_tensor.to(device), val_tar.to(device)
            val_img_tensor = val_img_tensor.float()
            val_tar = val_tar.long()
            val_batch_loss, val_batch_accuracy = evaluate_step(val_img_tensor, val_tar)

            val_loss += val_batch_loss
            val_accuracy += val_batch_accuracy

            pbar.set_postfix({'Loss': val_batch_loss, 'Accuracy': val_batch_accuracy})
            pbar.update()

    mean_val_loss = val_loss / len(val_data)
    mean_val_accuracy = val_accuracy / len(val_data)

    print('Epoch {} Validation Loss {:.4f} Validation Accuracy {:.4f}'.format(epoch + 1, mean_val_loss, mean_val_accuracy))

Epoch 1 Training: 100%|██████████| 152/152 [00:14<00:00, 10.44it/s, Loss=6.29, Accuracy=0.0344]


Epoch 1 Training Loss 7.2670 Accuracy 0.0265
Epoch 1 Validation Loss 6.1868 Validation Accuracy 0.0379


Epoch 2 Training: 100%|██████████| 152/152 [00:14<00:00, 10.34it/s, Loss=5.1, Accuracy=0.0675]


Epoch 2 Training Loss 5.6075 Accuracy 0.0539
Epoch 2 Validation Loss 4.9716 Validation Accuracy 0.0666


Epoch 3 Training: 100%|██████████| 152/152 [00:16<00:00,  9.43it/s, Loss=4.32, Accuracy=0.0799]


Epoch 3 Training Loss 4.8133 Accuracy 0.0715
Epoch 3 Validation Loss 4.4527 Validation Accuracy 0.0802


Epoch 4 Training: 100%|██████████| 152/152 [00:14<00:00, 10.28it/s, Loss=4.18, Accuracy=0.0813]


Epoch 4 Training Loss 4.4005 Accuracy 0.0815
Epoch 4 Validation Loss 4.1397 Validation Accuracy 0.0851


Epoch 5 Training: 100%|██████████| 152/152 [00:14<00:00, 10.56it/s, Loss=4.29, Accuracy=0.073]


Epoch 5 Training Loss 4.1696 Accuracy 0.0847
Epoch 5 Validation Loss 3.9696 Validation Accuracy 0.0895


Epoch 6 Training: 100%|██████████| 152/152 [00:14<00:00, 10.51it/s, Loss=4.08, Accuracy=0.0799]


Epoch 6 Training Loss 3.9694 Accuracy 0.0890
Epoch 6 Validation Loss 3.7741 Validation Accuracy 0.0923


Epoch 7 Training: 100%|██████████| 152/152 [00:14<00:00, 10.40it/s, Loss=3.53, Accuracy=0.103]


Epoch 7 Training Loss 3.7813 Accuracy 0.0938
Epoch 7 Validation Loss 3.7420 Validation Accuracy 0.0917


Epoch 8 Training: 100%|██████████| 152/152 [00:14<00:00, 10.39it/s, Loss=3.23, Accuracy=0.107]


Epoch 8 Training Loss 3.6787 Accuracy 0.0965
Epoch 8 Validation Loss 3.6489 Validation Accuracy 0.0940


Epoch 9 Training: 100%|██████████| 152/152 [00:14<00:00, 10.53it/s, Loss=3.33, Accuracy=0.0992]


Epoch 9 Training Loss 3.5434 Accuracy 0.1010
Epoch 9 Validation Loss 3.5812 Validation Accuracy 0.0967


Epoch 10 Training: 100%|██████████| 152/152 [00:14<00:00, 10.52it/s, Loss=3.56, Accuracy=0.0854]


Epoch 10 Training Loss 3.4748 Accuracy 0.1013
Epoch 10 Validation Loss 3.5290 Validation Accuracy 0.0996


Epoch 11 Training: 100%|██████████| 152/152 [00:14<00:00, 10.42it/s, Loss=3.44, Accuracy=0.106]


Epoch 11 Training Loss 3.3840 Accuracy 0.1040
Epoch 11 Validation Loss 3.4472 Validation Accuracy 0.0993


Epoch 12 Training: 100%|██████████| 152/152 [00:14<00:00, 10.55it/s, Loss=3.03, Accuracy=0.116]


Epoch 12 Training Loss 3.2847 Accuracy 0.1078
Epoch 12 Validation Loss 3.4570 Validation Accuracy 0.0993


Epoch 13 Training: 100%|██████████| 152/152 [00:14<00:00, 10.55it/s, Loss=2.89, Accuracy=0.121]


Epoch 13 Training Loss 3.2421 Accuracy 0.1087
Epoch 13 Validation Loss 3.4508 Validation Accuracy 0.0991


Epoch 14 Training: 100%|██████████| 152/152 [00:14<00:00, 10.53it/s, Loss=2.86, Accuracy=0.11]


Epoch 14 Training Loss 3.1554 Accuracy 0.1123
Epoch 14 Validation Loss 3.4300 Validation Accuracy 0.1005


Epoch 15 Training: 100%|██████████| 152/152 [00:14<00:00, 10.52it/s, Loss=3.18, Accuracy=0.103]


Epoch 15 Training Loss 3.1268 Accuracy 0.1116
Epoch 15 Validation Loss 3.3997 Validation Accuracy 0.1027


Epoch 16 Training: 100%|██████████| 152/152 [00:14<00:00, 10.51it/s, Loss=3.06, Accuracy=0.12]


Epoch 16 Training Loss 3.0576 Accuracy 0.1146
Epoch 16 Validation Loss 3.3693 Validation Accuracy 0.1009


Epoch 17 Training: 100%|██████████| 152/152 [00:14<00:00, 10.52it/s, Loss=2.71, Accuracy=0.129]


Epoch 17 Training Loss 2.9935 Accuracy 0.1166
Epoch 17 Validation Loss 3.4250 Validation Accuracy 0.1000


Epoch 18 Training: 100%|██████████| 152/152 [00:14<00:00, 10.53it/s, Loss=2.54, Accuracy=0.138]


Epoch 18 Training Loss 2.9703 Accuracy 0.1176
Epoch 18 Validation Loss 3.4356 Validation Accuracy 0.0999


Epoch 19 Training: 100%|██████████| 152/152 [00:14<00:00, 10.58it/s, Loss=2.54, Accuracy=0.121]


Epoch 19 Training Loss 2.9104 Accuracy 0.1198
Epoch 19 Validation Loss 3.4171 Validation Accuracy 0.1015


Epoch 20 Training: 100%|██████████| 152/152 [00:14<00:00, 10.60it/s, Loss=2.86, Accuracy=0.112]


Epoch 20 Training Loss 2.9000 Accuracy 0.1193
Epoch 20 Validation Loss 3.4354 Validation Accuracy 0.1028


Epoch 21 Training: 100%|██████████| 152/152 [00:14<00:00, 10.60it/s, Loss=2.9, Accuracy=0.116]


Epoch 21 Training Loss 2.8545 Accuracy 0.1214
Epoch 21 Validation Loss 3.3951 Validation Accuracy 0.1000


Epoch 22 Training: 100%|██████████| 152/152 [00:14<00:00, 10.43it/s, Loss=2.56, Accuracy=0.139]


Epoch 22 Training Loss 2.8228 Accuracy 0.1224
Epoch 22 Validation Loss 3.4505 Validation Accuracy 0.1001


Epoch 23 Training: 100%|██████████| 152/152 [00:14<00:00, 10.52it/s, Loss=2.4, Accuracy=0.131]


Epoch 23 Training Loss 2.8207 Accuracy 0.1212
Epoch 23 Validation Loss 3.4710 Validation Accuracy 0.0987


Epoch 24 Training: 100%|██████████| 152/152 [00:14<00:00, 10.60it/s, Loss=2.45, Accuracy=0.124]


Epoch 24 Training Loss 2.7854 Accuracy 0.1234
Epoch 24 Validation Loss 3.4595 Validation Accuracy 0.1009


Epoch 25 Training: 100%|██████████| 152/152 [00:14<00:00, 10.60it/s, Loss=2.74, Accuracy=0.123]


Epoch 25 Training Loss 2.7906 Accuracy 0.1228
Epoch 25 Validation Loss 3.5009 Validation Accuracy 0.1011


Epoch 26 Training: 100%|██████████| 152/152 [00:14<00:00, 10.64it/s, Loss=2.63, Accuracy=0.132]


Epoch 26 Training Loss 2.7972 Accuracy 0.1215
Epoch 26 Validation Loss 3.4682 Validation Accuracy 0.1004


Epoch 27 Training: 100%|██████████| 152/152 [00:14<00:00, 10.63it/s, Loss=2.51, Accuracy=0.138]


Epoch 27 Training Loss 2.7713 Accuracy 0.1226
Epoch 27 Validation Loss 3.4854 Validation Accuracy 0.1013


Epoch 28 Training: 100%|██████████| 152/152 [00:14<00:00, 10.63it/s, Loss=2.26, Accuracy=0.149]


Epoch 28 Training Loss 2.7665 Accuracy 0.1219
Epoch 28 Validation Loss 3.4881 Validation Accuracy 0.1007


Epoch 29 Training: 100%|██████████| 152/152 [00:14<00:00, 10.61it/s, Loss=2.4, Accuracy=0.118]


Epoch 29 Training Loss 2.6860 Accuracy 0.1262
Epoch 29 Validation Loss 3.4753 Validation Accuracy 0.1021


Epoch 30 Training: 100%|██████████| 152/152 [00:14<00:00, 10.63it/s, Loss=2.49, Accuracy=0.135]


Epoch 30 Training Loss 2.6921 Accuracy 0.1245
Epoch 30 Validation Loss 3.4924 Validation Accuracy 0.1007


Epoch 31 Training: 100%|██████████| 152/152 [00:14<00:00, 10.65it/s, Loss=2.58, Accuracy=0.123]


Epoch 31 Training Loss 2.6405 Accuracy 0.1271
Epoch 31 Validation Loss 3.4172 Validation Accuracy 0.1013


Epoch 32 Training: 100%|██████████| 152/152 [00:14<00:00, 10.59it/s, Loss=2.33, Accuracy=0.132]


Epoch 32 Training Loss 2.5984 Accuracy 0.1283
Epoch 32 Validation Loss 3.4552 Validation Accuracy 0.1044


Epoch 33 Training: 100%|██████████| 152/152 [00:14<00:00, 10.58it/s, Loss=2.13, Accuracy=0.142]


Epoch 33 Training Loss 2.5696 Accuracy 0.1302
Epoch 33 Validation Loss 3.5056 Validation Accuracy 0.1010


Epoch 34 Training: 100%|██████████| 152/152 [00:14<00:00, 10.57it/s, Loss=2.18, Accuracy=0.135]


Epoch 34 Training Loss 2.4954 Accuracy 0.1330
Epoch 34 Validation Loss 3.4649 Validation Accuracy 0.1041


Epoch 35 Training: 100%|██████████| 152/152 [00:14<00:00, 10.69it/s, Loss=2.27, Accuracy=0.134]


Epoch 35 Training Loss 2.4955 Accuracy 0.1322
Epoch 35 Validation Loss 3.4965 Validation Accuracy 0.1019


Epoch 36 Training: 100%|██████████| 152/152 [00:14<00:00, 10.68it/s, Loss=2.25, Accuracy=0.145]


Epoch 36 Training Loss 2.4403 Accuracy 0.1343
Epoch 36 Validation Loss 3.4067 Validation Accuracy 0.1023


Epoch 37 Training: 100%|██████████| 152/152 [00:14<00:00, 10.62it/s, Loss=2.15, Accuracy=0.163]


Epoch 37 Training Loss 2.4179 Accuracy 0.1358
Epoch 37 Validation Loss 3.4452 Validation Accuracy 0.1026


Epoch 38 Training: 100%|██████████| 152/152 [00:14<00:00, 10.57it/s, Loss=2.04, Accuracy=0.158]


Epoch 38 Training Loss 2.3727 Accuracy 0.1389
Epoch 38 Validation Loss 3.5882 Validation Accuracy 0.0994


Epoch 39 Training: 100%|██████████| 152/152 [00:14<00:00, 10.63it/s, Loss=2.05, Accuracy=0.131]


Epoch 39 Training Loss 2.3150 Accuracy 0.1402
Epoch 39 Validation Loss 3.4836 Validation Accuracy 0.1044


Epoch 40 Training: 100%|██████████| 152/152 [00:14<00:00, 10.68it/s, Loss=2.05, Accuracy=0.153]


Epoch 40 Training Loss 2.3135 Accuracy 0.1391
Epoch 40 Validation Loss 3.5340 Validation Accuracy 0.1020


Epoch 41 Training: 100%|██████████| 152/152 [00:14<00:00, 10.72it/s, Loss=1.94, Accuracy=0.156]


Epoch 41 Training Loss 2.2345 Accuracy 0.1430
Epoch 41 Validation Loss 3.4505 Validation Accuracy 0.1025


Epoch 42 Training: 100%|██████████| 152/152 [00:14<00:00, 10.68it/s, Loss=1.94, Accuracy=0.167]


Epoch 42 Training Loss 2.2206 Accuracy 0.1447
Epoch 42 Validation Loss 3.5160 Validation Accuracy 0.1030


Epoch 43 Training: 100%|██████████| 152/152 [00:14<00:00, 10.61it/s, Loss=1.78, Accuracy=0.165]


Epoch 43 Training Loss 2.1776 Accuracy 0.1470
Epoch 43 Validation Loss 3.5521 Validation Accuracy 0.1036


Epoch 44 Training: 100%|██████████| 152/152 [00:14<00:00, 10.61it/s, Loss=1.74, Accuracy=0.149]


Epoch 44 Training Loss 2.1291 Accuracy 0.1485
Epoch 44 Validation Loss 3.5483 Validation Accuracy 0.1043


Epoch 45 Training: 100%|██████████| 152/152 [00:14<00:00, 10.65it/s, Loss=1.91, Accuracy=0.161]


Epoch 45 Training Loss 2.1429 Accuracy 0.1467
Epoch 45 Validation Loss 3.5422 Validation Accuracy 0.1026


Epoch 46 Training: 100%|██████████| 152/152 [00:14<00:00, 10.64it/s, Loss=1.74, Accuracy=0.168]


Epoch 46 Training Loss 2.0649 Accuracy 0.1507
Epoch 46 Validation Loss 3.4938 Validation Accuracy 0.1023


Epoch 47 Training: 100%|██████████| 152/152 [00:14<00:00, 10.66it/s, Loss=1.94, Accuracy=0.161]


Epoch 47 Training Loss 2.0485 Accuracy 0.1525
Epoch 47 Validation Loss 3.5720 Validation Accuracy 0.1054


Epoch 48 Training: 100%|██████████| 152/152 [00:14<00:00, 10.68it/s, Loss=1.59, Accuracy=0.172]


Epoch 48 Training Loss 1.9940 Accuracy 0.1550
Epoch 48 Validation Loss 3.5987 Validation Accuracy 0.1019


Epoch 49 Training: 100%|██████████| 152/152 [00:14<00:00, 10.67it/s, Loss=1.52, Accuracy=0.167]


Epoch 49 Training Loss 1.9494 Accuracy 0.1579
Epoch 49 Validation Loss 3.6085 Validation Accuracy 0.1040


Epoch 50 Training: 100%|██████████| 152/152 [00:14<00:00, 10.64it/s, Loss=1.69, Accuracy=0.161]


Epoch 50 Training Loss 1.9573 Accuracy 0.1553
Epoch 50 Validation Loss 3.6067 Validation Accuracy 0.1008


Epoch 51 Training: 100%|██████████| 152/152 [00:14<00:00, 10.68it/s, Loss=1.51, Accuracy=0.18]


Epoch 51 Training Loss 1.8951 Accuracy 0.1591
Epoch 51 Validation Loss 3.5443 Validation Accuracy 0.1017


Epoch 52 Training: 100%|██████████| 152/152 [00:14<00:00, 10.71it/s, Loss=1.56, Accuracy=0.19]


Epoch 52 Training Loss 1.8804 Accuracy 0.1607
Epoch 52 Validation Loss 3.6479 Validation Accuracy 0.1039


Epoch 53 Training: 100%|██████████| 152/152 [00:14<00:00, 10.65it/s, Loss=1.41, Accuracy=0.18]


Epoch 53 Training Loss 1.8191 Accuracy 0.1641
Epoch 53 Validation Loss 3.7056 Validation Accuracy 0.1015


Epoch 54 Training: 100%|██████████| 152/152 [00:14<00:00, 10.63it/s, Loss=1.45, Accuracy=0.167]


Epoch 54 Training Loss 1.8029 Accuracy 0.1657
Epoch 54 Validation Loss 3.6369 Validation Accuracy 0.1055


Epoch 55 Training: 100%|██████████| 152/152 [00:14<00:00, 10.68it/s, Loss=1.49, Accuracy=0.18]


Epoch 55 Training Loss 1.7716 Accuracy 0.1659
Epoch 55 Validation Loss 3.7360 Validation Accuracy 0.0992


Epoch 56 Training: 100%|██████████| 152/152 [00:14<00:00, 10.64it/s, Loss=1.37, Accuracy=0.183]


Epoch 56 Training Loss 1.7181 Accuracy 0.1690
Epoch 56 Validation Loss 3.6604 Validation Accuracy 0.0995


Epoch 57 Training: 100%|██████████| 152/152 [00:14<00:00, 10.65it/s, Loss=1.39, Accuracy=0.194]


Epoch 57 Training Loss 1.7008 Accuracy 0.1708
Epoch 57 Validation Loss 3.7747 Validation Accuracy 0.1028


Epoch 58 Training: 100%|██████████| 152/152 [00:14<00:00, 10.65it/s, Loss=1.27, Accuracy=0.186]


Epoch 58 Training Loss 1.6569 Accuracy 0.1735
Epoch 58 Validation Loss 3.8196 Validation Accuracy 0.1017


Epoch 59 Training: 100%|██████████| 152/152 [00:14<00:00, 10.65it/s, Loss=1.24, Accuracy=0.193]


Epoch 59 Training Loss 1.6330 Accuracy 0.1753
Epoch 59 Validation Loss 3.7597 Validation Accuracy 0.1012


Epoch 60 Training: 100%|██████████| 152/152 [00:14<00:00, 10.64it/s, Loss=1.28, Accuracy=0.196]


Epoch 60 Training Loss 1.5994 Accuracy 0.1761
Epoch 60 Validation Loss 3.8546 Validation Accuracy 0.0979


Epoch 61 Training: 100%|██████████| 152/152 [00:14<00:00, 10.68it/s, Loss=1.13, Accuracy=0.207]


Epoch 61 Training Loss 1.5691 Accuracy 0.1774
Epoch 61 Validation Loss 3.7503 Validation Accuracy 0.1000


Epoch 62 Training: 100%|██████████| 152/152 [00:14<00:00, 10.69it/s, Loss=1.19, Accuracy=0.211]


Epoch 62 Training Loss 1.5545 Accuracy 0.1795
Epoch 62 Validation Loss 3.8717 Validation Accuracy 0.1010


Epoch 63 Training: 100%|██████████| 152/152 [00:14<00:00, 10.70it/s, Loss=1.14, Accuracy=0.213]


Epoch 63 Training Loss 1.5165 Accuracy 0.1813
Epoch 63 Validation Loss 3.9224 Validation Accuracy 0.1017


Epoch 64 Training: 100%|██████████| 152/152 [00:14<00:00, 10.64it/s, Loss=1.12, Accuracy=0.19]


Epoch 64 Training Loss 1.4905 Accuracy 0.1840
Epoch 64 Validation Loss 3.8406 Validation Accuracy 0.1025


Epoch 65 Training: 100%|██████████| 152/152 [00:14<00:00, 10.67it/s, Loss=1.17, Accuracy=0.197]


Epoch 65 Training Loss 1.4561 Accuracy 0.1846
Epoch 65 Validation Loss 3.9370 Validation Accuracy 0.0976


Epoch 66 Training: 100%|██████████| 152/152 [00:14<00:00, 10.66it/s, Loss=1.04, Accuracy=0.215]


Epoch 66 Training Loss 1.4196 Accuracy 0.1871
Epoch 66 Validation Loss 3.8627 Validation Accuracy 0.0989


Epoch 67 Training: 100%|██████████| 152/152 [00:14<00:00, 10.63it/s, Loss=1.07, Accuracy=0.233]


Epoch 67 Training Loss 1.4069 Accuracy 0.1886
Epoch 67 Validation Loss 4.0088 Validation Accuracy 0.1006


Epoch 68 Training: 100%|██████████| 152/152 [00:14<00:00, 10.65it/s, Loss=1.04, Accuracy=0.212]


Epoch 68 Training Loss 1.3707 Accuracy 0.1903
Epoch 68 Validation Loss 4.0408 Validation Accuracy 0.0998


Epoch 69 Training: 100%|██████████| 152/152 [00:14<00:00, 10.64it/s, Loss=0.938, Accuracy=0.207]


Epoch 69 Training Loss 1.3511 Accuracy 0.1929
Epoch 69 Validation Loss 3.9591 Validation Accuracy 0.0995


Epoch 70 Training: 100%|██████████| 152/152 [00:14<00:00, 10.63it/s, Loss=0.909, Accuracy=0.23]


Epoch 70 Training Loss 1.3298 Accuracy 0.1933
Epoch 70 Validation Loss 4.0722 Validation Accuracy 0.0971


Epoch 71 Training: 100%|██████████| 152/152 [00:14<00:00, 10.66it/s, Loss=0.849, Accuracy=0.226]


Epoch 71 Training Loss 1.2864 Accuracy 0.1964
Epoch 71 Validation Loss 4.0002 Validation Accuracy 0.1001


Epoch 72 Training: 100%|██████████| 152/152 [00:14<00:00, 10.64it/s, Loss=1.05, Accuracy=0.231]


Epoch 72 Training Loss 1.2723 Accuracy 0.1973
Epoch 72 Validation Loss 4.1511 Validation Accuracy 0.0981


Epoch 73 Training: 100%|██████████| 152/152 [00:14<00:00, 10.66it/s, Loss=0.808, Accuracy=0.231]


Epoch 73 Training Loss 1.2396 Accuracy 0.1997
Epoch 73 Validation Loss 4.1726 Validation Accuracy 0.0994


Epoch 74 Training: 100%|██████████| 152/152 [00:14<00:00, 10.66it/s, Loss=0.89, Accuracy=0.209]


Epoch 74 Training Loss 1.2401 Accuracy 0.2009
Epoch 74 Validation Loss 4.1325 Validation Accuracy 0.0993


Epoch 75 Training: 100%|██████████| 152/152 [00:14<00:00, 10.68it/s, Loss=0.872, Accuracy=0.231]


Epoch 75 Training Loss 1.2083 Accuracy 0.2015
Epoch 75 Validation Loss 4.1961 Validation Accuracy 0.0956


Epoch 76 Training: 100%|██████████| 152/152 [00:14<00:00, 10.67it/s, Loss=0.858, Accuracy=0.223]


Epoch 76 Training Loss 1.1922 Accuracy 0.2025
Epoch 76 Validation Loss 4.1116 Validation Accuracy 0.1000


Epoch 77 Training: 100%|██████████| 152/152 [00:14<00:00, 10.65it/s, Loss=0.878, Accuracy=0.241]


Epoch 77 Training Loss 1.1621 Accuracy 0.2059
Epoch 77 Validation Loss 4.2813 Validation Accuracy 0.0996


Epoch 78 Training: 100%|██████████| 152/152 [00:14<00:00, 10.63it/s, Loss=0.861, Accuracy=0.222]


Epoch 78 Training Loss 1.1434 Accuracy 0.2071
Epoch 78 Validation Loss 4.3142 Validation Accuracy 0.0974


Epoch 79 Training: 100%|██████████| 152/152 [00:14<00:00, 10.68it/s, Loss=0.812, Accuracy=0.219]


Epoch 79 Training Loss 1.1283 Accuracy 0.2083
Epoch 79 Validation Loss 4.2617 Validation Accuracy 0.0992


Epoch 80 Training: 100%|██████████| 152/152 [00:14<00:00, 10.72it/s, Loss=0.792, Accuracy=0.233]


Epoch 80 Training Loss 1.0933 Accuracy 0.2107
Epoch 80 Validation Loss 4.3258 Validation Accuracy 0.0963


Epoch 81 Training: 100%|██████████| 152/152 [00:14<00:00, 10.72it/s, Loss=0.681, Accuracy=0.249]


Epoch 81 Training Loss 1.1030 Accuracy 0.2090
Epoch 81 Validation Loss 4.2222 Validation Accuracy 0.0984


Epoch 82 Training: 100%|██████████| 152/152 [00:14<00:00, 10.66it/s, Loss=0.774, Accuracy=0.263]


Epoch 82 Training Loss 1.0805 Accuracy 0.2118
Epoch 82 Validation Loss 4.4332 Validation Accuracy 0.0983


Epoch 83 Training: 100%|██████████| 152/152 [00:14<00:00, 10.64it/s, Loss=0.739, Accuracy=0.227]


Epoch 83 Training Loss 1.0508 Accuracy 0.2138
Epoch 83 Validation Loss 4.4158 Validation Accuracy 0.0979


Epoch 84 Training: 100%|██████████| 152/152 [00:14<00:00, 10.63it/s, Loss=0.837, Accuracy=0.225]


Epoch 84 Training Loss 1.0331 Accuracy 0.2165
Epoch 84 Validation Loss 4.3703 Validation Accuracy 0.0963


Epoch 85 Training: 100%|██████████| 152/152 [00:14<00:00, 10.64it/s, Loss=0.79, Accuracy=0.23]


Epoch 85 Training Loss 1.0059 Accuracy 0.2175
Epoch 85 Validation Loss 4.4024 Validation Accuracy 0.0965


Epoch 86 Training: 100%|██████████| 152/152 [00:14<00:00, 10.68it/s, Loss=0.791, Accuracy=0.231]


Epoch 86 Training Loss 0.9933 Accuracy 0.2178
Epoch 86 Validation Loss 4.3491 Validation Accuracy 0.0987


Epoch 87 Training: 100%|██████████| 152/152 [00:14<00:00, 10.64it/s, Loss=0.69, Accuracy=0.256]


Epoch 87 Training Loss 0.9955 Accuracy 0.2190
Epoch 87 Validation Loss 4.5593 Validation Accuracy 0.0987


Epoch 88 Training: 100%|██████████| 152/152 [00:14<00:00, 10.65it/s, Loss=0.71, Accuracy=0.238]


Epoch 88 Training Loss 0.9764 Accuracy 0.2197
Epoch 88 Validation Loss 4.5670 Validation Accuracy 0.0992


Epoch 89 Training: 100%|██████████| 152/152 [00:14<00:00, 10.67it/s, Loss=0.787, Accuracy=0.219]


Epoch 89 Training Loss 0.9561 Accuracy 0.2229
Epoch 89 Validation Loss 4.4948 Validation Accuracy 0.0965


Epoch 90 Training: 100%|██████████| 152/152 [00:14<00:00, 10.72it/s, Loss=0.753, Accuracy=0.238]


Epoch 90 Training Loss 0.9377 Accuracy 0.2229
Epoch 90 Validation Loss 4.5708 Validation Accuracy 0.0917


Epoch 91 Training: 100%|██████████| 152/152 [00:14<00:00, 10.73it/s, Loss=0.656, Accuracy=0.249]


Epoch 91 Training Loss 0.9292 Accuracy 0.2231
Epoch 91 Validation Loss 4.4741 Validation Accuracy 0.0966


Epoch 92 Training: 100%|██████████| 152/152 [00:14<00:00, 10.72it/s, Loss=0.744, Accuracy=0.244]


Epoch 92 Training Loss 0.9216 Accuracy 0.2247
Epoch 92 Validation Loss 4.6817 Validation Accuracy 0.0968


Epoch 93 Training: 100%|██████████| 152/152 [00:14<00:00, 10.72it/s, Loss=0.628, Accuracy=0.241]


Epoch 93 Training Loss 0.8945 Accuracy 0.2263
Epoch 93 Validation Loss 4.6764 Validation Accuracy 0.0961


Epoch 94 Training: 100%|██████████| 152/152 [00:14<00:00, 10.72it/s, Loss=0.634, Accuracy=0.236]


Epoch 94 Training Loss 0.8927 Accuracy 0.2279
Epoch 94 Validation Loss 4.6066 Validation Accuracy 0.0966


Epoch 95 Training: 100%|██████████| 152/152 [00:14<00:00, 10.68it/s, Loss=0.645, Accuracy=0.242]


Epoch 95 Training Loss 0.8682 Accuracy 0.2293
Epoch 95 Validation Loss 4.6816 Validation Accuracy 0.0920


Epoch 96 Training: 100%|██████████| 152/152 [00:14<00:00, 10.69it/s, Loss=0.618, Accuracy=0.251]


Epoch 96 Training Loss 0.8618 Accuracy 0.2284
Epoch 96 Validation Loss 4.5859 Validation Accuracy 0.0965


Epoch 97 Training: 100%|██████████| 152/152 [00:14<00:00, 10.69it/s, Loss=0.746, Accuracy=0.244]


Epoch 97 Training Loss 0.8531 Accuracy 0.2300
Epoch 97 Validation Loss 4.7531 Validation Accuracy 0.0950


Epoch 98 Training: 100%|██████████| 152/152 [00:14<00:00, 10.69it/s, Loss=0.628, Accuracy=0.245]


Epoch 98 Training Loss 0.8363 Accuracy 0.2313
Epoch 98 Validation Loss 4.7661 Validation Accuracy 0.0956


Epoch 99 Training: 100%|██████████| 152/152 [00:14<00:00, 10.68it/s, Loss=0.623, Accuracy=0.238]


Epoch 99 Training Loss 0.8202 Accuracy 0.2336
Epoch 99 Validation Loss 4.7362 Validation Accuracy 0.0972


Epoch 100 Training: 100%|██████████| 152/152 [00:14<00:00, 10.66it/s, Loss=0.649, Accuracy=0.247]


Epoch 100 Training Loss 0.8069 Accuracy 0.2339
Epoch 100 Validation Loss 4.7815 Validation Accuracy 0.0913


In [26]:
def beam_search(model, img_tensor, start, end, k=3, max_len=50):
    batch_size = img_tensor.size(0)
    sequences_batch = [[[(start,), 1.0]] for _ in range(batch_size)]

    for _ in range(max_len):
        all_candidates_batch = [[] for _ in range(batch_size)]
        all_sequences_complete = True

        for batch in range(batch_size):
            sequences = sequences_batch[batch]
            all_candidates = all_candidates_batch[batch]
            for seq_and_score in sequences:
                seq, score = seq_and_score[0], seq_and_score[1]
                if len(seq) >= max_len or seq[-1] == end:
                    all_candidates.append(seq_and_score)
                    continue

                all_sequences_complete = False

                dec_input = torch.tensor([seq], dtype=torch.long, device=device)

                predictions = model(img_tensor[batch].unsqueeze(0), dec_input)
                predictions = predictions[:, -1:, :]
                predictions = torch.softmax(predictions, dim=-1)

                top_k_probs, top_k_tokens = torch.topk(predictions, k)
                top_k_log_probs = top_k_probs.log().squeeze()

                for j in range(k):
                    next_token = top_k_tokens[0][0][j].item()
                    if next_token == end:
                        all_candidates.append((seq + (next_token,), score + top_k_log_probs[j].item()))
                        continue
                    else:
                        candidate = (seq + (next_token,), score + top_k_log_probs[j].item())
                        all_candidates.append(candidate)

            ordered = sorted(all_candidates, key=lambda tup: tup[1], reverse=True)

            sequences_batch[batch] = ordered[:k]

        if all_sequences_complete:
            break

    return sequences_batch


In [27]:
def evaluate_bleu(references, hypothesis, weights):
    bleu_scores = []
    for ref in references:
        ref_tokens = [token for token in ref.split()[1:-1] if token != '<unk>']
        hyp_tokens = [token for token in hypothesis.split()[1:-1] if token != '<unk>']
        bleu_score = sentence_bleu([ref_tokens], hyp_tokens, weights=weights, smoothing_function=SmoothingFunction().method1)
        bleu_scores.append(bleu_score * 100)
    avg_bleu_score = sum(bleu_scores) / len(bleu_scores)
    return avg_bleu_score



weights_list = [(1.0, 0, 0, 0), (0.5, 0.5, 0, 0), (0.3, 0.3, 0.3, 0), (0.25, 0.25, 0.25, 0.25)]


total_bleu_scores = {weights: [] for weights in weights_list}

for batch, (img_tensor, tar) in enumerate(test_data):
    img_tensor, tar = img_tensor.to(device), tar.to(device)
    img_tensor = img_tensor.float()
    tar = tar.long()
    start = tokenizer.word_index['<start>']
    end = tokenizer.word_index['<end>']

    batch_bleu_scores = {weights: [] for weights in weights_list}

    output_sequences_batch = beam_search(model, img_tensor, start, end, k=3, max_len=50)

    for i, (output_sequences, target) in enumerate(zip(output_sequences_batch, tar)):
        target_sentence = tokenizer.sequences_to_texts([target.cpu().numpy()])[0]

        best_generated_sentences = {weights: None for weights in weights_list}
        best_scores = {weights: -1 for weights in weights_list}

        for j, (seq, _) in enumerate(output_sequences):
            generated_sentence = tokenizer.sequences_to_texts([seq])[0]

            for weights in weights_list:
                bleu_score = evaluate_bleu([target_sentence], generated_sentence, weights)

                if bleu_score > best_scores[weights]:
                    best_generated_sentences[weights] = generated_sentence
                    best_scores[weights] = bleu_score

        for weights in weights_list:
            batch_bleu_scores[weights].append(best_scores[weights])

    for weights in weights_list:
        avg_batch_bleu_score = np.mean(batch_bleu_scores[weights])
        total_bleu_scores[weights].append(avg_batch_bleu_score)
    break

overall_avg_bleu_scores = {weights: np.mean(scores) for weights, scores in total_bleu_scores.items()}

for weights, avg_score in overall_avg_bleu_scores.items():
    print(f"Average BLEU Score for weights {weights}: {avg_score}")



Average BLEU Score for weights (1.0, 0, 0, 0): 3.8530133027160964
Average BLEU Score for weights (0.5, 0.5, 0, 0): 2.050459568302278
Average BLEU Score for weights (0.3, 0.3, 0.3, 0): 1.6340909479885535
Average BLEU Score for weights (0.25, 0.25, 0.25, 0.25): 0.9452908998030194


In [28]:
def evaluate(img_tensor, target):
    start_token = tokenizer.word_index['<start>']
    end_token = tokenizer.word_index['<end>']

    beam_width = 3
    beam_result = beam_search(model, img_tensor, start_token, end_token, k=beam_width)

    output_sequence = beam_result[0][0][0][1:-1]
    output_sequence = [tokenizer.index_word[idx] for idx in output_sequence]

    target_seq = [tokenizer.index_word[int(idx)] for idx in target if idx not in [1, 2, 0, 3]]

    print('Target sequence: ', ' '.join(target_seq))
    print('Predicted sequence: ', ' '.join(output_sequence))

    bleu_score = sentence_bleu([target_seq], output_sequence,
                               weights=(0.25, 0.25, 0.25, 0.25),
                               smoothing_function=SmoothingFunction().method1) * 100

    print('Bleu Score:', bleu_score)


In [29]:
i = 0
for img_tensors, targets in test_data:
  target = targets[i]
  img_tensor = img_tensors[i]
  img_tensor = img_tensor[None, :, :]
  img_tensor, target = img_tensor.to(device), target.to(device)
  evaluate(img_tensor, target)
  i += 1
  if (i == 5):
    break

Target sequence:  black and white dog is jumping through field of brown grass
Predicted sequence:  black and white dog running through field
Bleu Score: 27.610369103579487
Target sequence:  little girl in bathing suit leaps up in the water
Predicted sequence:  young boy jumping into the ocean
Bleu Score: 2.096016611399374
Target sequence:  the two children watch the small yellow dog on the tree
Predicted sequence:  four dogs are playing around tree
Bleu Score: 1.774239756616722
Target sequence:  boy jumping rail on his skateboard
Predicted sequence:  young man wearing gray shirt and jeans skateboards down street
Bleu Score: 0
Target sequence:  two people on grassy plain are gathering parachute that one of the people just used
Predicted sequence:  two people are in the air above the grass
Bleu Score: 3.096277640217347
