# DL2023 Assignment 3

VH: Textblocks that start with 'VH' means that I inserted that block.


---

VH: The structure of the helper files and images has to be of the following structure in order to run properly:
```bash
- code
    |- flickr8k_split (unziped)
    |- test_examples
    |- assignment3.ipynb
    |- evaluate.py
    |- get_loader.py
    |- model.py
    |- utils.py
```

VH: Model description

### Encoder
Type: Convolutional Neutral Network
Input: Images
Output: Image features

The EncoderCNN takes an image as an input, applies a pre-trained CNN model - ResNet-50 - with pretrained weights models.ResNet50_Weights.IMAGENET1K_V2 from the ImageNet. It does output the image features that can be further used by the captioning model (the decoder) for generating captions.

More precisly, in the 'forward' method, the input 'images' are passed through the pre-trained ResNet-50 model. The output features are then batch-normalized, passed through the ReLU activation function, and finally, dropout is applied to the features before returning them.

##### Changes compared to the original model:
- Using the pretrained weights IMAGENET1K_V2 of the resnet50 model to improve the accuracy.
- Added Batch-Normlization to speed up training time and improving the model's performance by normalizing the inputs to each layer.


### Decoder
Type: Recurrent Neural Network with Gated Recurrent Unit (GRU) architecture.
Input: Image features and captions
Output: Predicted word scores for each time step

The DecoderRNN class takes image features and captions as input and generates predicted word scores at each time step using the GRU architecture. These scores can be used for further steps like generating captions or computing the loss during training.

More precisly, in the 'forward' method, the input 'features' (image features) and 'captions' are passed through the embedding layer to get the embedded captions. Important here is to mention, that during training, the teacher forcing mechanism is used, ie the predictions of the previous step are not considered for the current step. The image features are concatenated with the embedded captions. Theses embeddings are then passed through the GRU layer, producing hidden states at each time step. Finally, the hidden states are passed through the linear layer to obtain the predicted word scores.

##### Changes compared to the original model:
- Using a different architecture such as the GRU network instead of the RNN network.

### Training Process
This training procedure trains the model over 301 epochs, using batches of size 256. The model is optimized using the Adam optimizer with a learning rate of 1e-4. The loss is computed using CrossEntropyLoss, and gradient clipping is applied to prevent unstable gradient updates. The procedure iterates through the dataset, updating the model's parameters to minimize the loss and improve the model's performance. However fine-tuning the CNN is not used. 

#### Changes compared to the original training process
- Gradient clipping is used to prevent exploding gradients and improve the stability of the training process.

---

VH: To load the python scripts, we need to add the folder of the python scripts to the system .

In [1]:
import sys
sys.path.append('code')

### Warning

The code for this baseline is adopted mainly from the above link:
https://github.com/aladdinpersson/Machine-Learning-Collection/tree/master/ML/Pytorch/more_advanced/image_captioning


In [2]:
import os 
import json
import torch
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from utils import save_checkpoint, load_checkpoint, print_examples
from get_loader import get_loader, build_vocab
from get_loader import transform_data as transform
from model import CNNtoRNN
from evaluate import evaluate_dataset_new

### Define vocabulary

In [3]:
vocab = build_vocab("flickr8k_split/captions.txt", freq_threshold = 5)

### Define hyperparameters

In [4]:
torch.backends.cudnn.benchmark = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
load_model = False
save_model = True
train_CNN = False

# Hyperparameters
batch_size = 256
embed_size = 128
hidden_size = 128
vocab_size = len(vocab)
num_layers = 1
learning_rate = 1e-4
num_epochs = 301

cuda


In [5]:
run_name = 'baseline'
output_dir = os.path.join('checkpoints', run_name)
os.makedirs(output_dir , exist_ok= True)
step = 0

### Define train and test loaders

In [6]:
train_loader, dataset = get_loader(
    root_folder="flickr8k_split/train_images",
    annotation_file="flickr8k_split/train_captions.txt",
    transform=transform,
    vocab = vocab, # None
    num_workers=8,
    batch_size = batch_size
)


val_loader, _ = get_loader(
    root_folder="flickr8k_split/val_images",
    annotation_file="flickr8k_split/val_captions.txt",
    transform=transform,
    vocab = vocab, #dataset.vocab,
    num_workers=1,
    batch_size = 5,
    shuffle = False,
) 

### Define another train loder that is used to evaluate BLEU-1 score

In [7]:
train_loader_eval, _ = get_loader(
    root_folder="flickr8k_split/train_images",
    annotation_file="flickr8k_split/train_captions.txt",
    transform=transform,
    vocab = vocab, # None
    num_workers=1,
    batch_size = 5, # each image has 5 different GT captions, for each batch we will load only one image and corresponding captions
    shuffle = False,
)

### Define your model and optimizers. 

In [8]:
# initialize model, loss etc
model = CNNtoRNN(embed_size, hidden_size, vocab_size, num_layers).to(device)
criterion = nn.CrossEntropyLoss(ignore_index=vocab.stoi["<PAD>"]) # it will ignore 'PAD' words 
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [9]:
# Only finetune the CNN
for name, param in model.encoderCNN.resnet50.named_parameters():
    if "fc.weight" in name or "fc.bias" in name:
        param.requires_grad = True
    else:
        param.requires_grad = train_CNN

In [10]:
if load_model:
    step = load_checkpoint(torch.load('SPECIFIY_PATH.pth.tar'), model, optimizer)

### Training loop

In [11]:
model.train()
for epoch in range(num_epochs):
    # Uncomment the line below to see a couple of test cases
    if epoch %50 ==0: # modify for your case
        print_examples(model, device, dataset)
    
    if epoch % 50 ==0: # modify for your case
        # calculate BLUE score on the validation set
        blue_score_val = evaluate_dataset_new(model, val_loader, vocab, device)
        print('BLUE SCORES validation ', epoch, blue_score_val)
        # calculate BLUE score on the training set
        blue_score_train = evaluate_dataset_new(model, train_loader_eval, vocab, device)
        print('BLUE SCORES TRAIN', epoch, blue_score_train)
        
        # logging
        log_stats = {'BLUE-1-VAL': blue_score_val,
                     'BLUE-1-TRAIN': blue_score_train,
                    'epoch': epoch}
        f = open(os.path.join(output_dir, "log_blue.txt"), "a+")
        f.write(json.dumps(log_stats) + "\n")
        f.close()

        # save model every epoch - modify it you do not want to save it every epoch
        if save_model:
            checkpoint = {
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "step": step,
            }
            save_checkpoint(checkpoint, filename = os.path.join(output_dir, 'checkpoints_'+str(epoch)+'.pth.tar'))
    
    for idx, (imgs, captions) in tqdm(
        enumerate(train_loader), total=len(train_loader)
    ):

        imgs = imgs.to(device)
        captions = captions.to(device)
        outputs = model(imgs, captions[:-1])
        loss = criterion(
            outputs.reshape(-1, outputs.shape[2]), captions.reshape(-1)
        )

        step += 1

        optimizer.zero_grad()
        loss.backward(loss)
        # use clip grad if needed
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm = 10.)
        optimizer.step()
        # logging
        log_stats = {'loss': loss.item(),
                    'epoch': epoch,
                    'step': step}
        f = open(os.path.join(output_dir, "log_loss.txt"), "a+")
        f.write(json.dumps(log_stats) + "\n")
        f.close()
    print("Training loss", loss.item())

Example 1 CORRECT: Dog on a beach by the ocean
Example 1 OUTPUT: overhead traffic brother breaks parlor handbag carts beside hugs struggle backward surfer bears wait shirtless video sparklers stomach lockers unseen earring along face hovering papers nails university league bird saber chased heavy eat rails cowboy cowboy stars carriage desert lunges wrestlers shallow piece well piece stroller papers downtown officer chased
Example 2 CORRECT: Child holding red frisbee outdoors
Example 2 OUTPUT: mohawk return groom lean tall paved marx bending shirt tries tackled bag motorbike rural higher wet hiking overlooking band sell dad coat without posters officer form hold overlooks hamburgers masks athlete sort butter butter pants skirts hear cream silly skateboarder there wakeboard obama shepherd breaks snap hay armenian piano sun
Example 3 CORRECT: Bus driving by parked cars
Example 3 OUTPUT: coach dim self sparklers canoe squats bandanna surfing cowboy clad baseman unique crouched crouched wai

                                                 

BLUE SCORES validation  0 0.5922767116796966


                                                   

BLUE SCORES TRAIN 0 0.7109332422297603
=> Saving checkpoint


100%|██████████| 139/139 [00:56<00:00,  2.47it/s]


Training loss 6.267791271209717


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 4.973031044006348


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 4.67972469329834


100%|██████████| 139/139 [00:51<00:00,  2.67it/s]


Training loss 4.3025031089782715


100%|██████████| 139/139 [00:52<00:00,  2.67it/s]


Training loss 4.237387657165527


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 4.092973232269287


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 3.9617254734039307


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 3.8952505588531494


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 3.818037509918213


100%|██████████| 139/139 [00:51<00:00,  2.67it/s]


Training loss 3.723111391067505


100%|██████████| 139/139 [00:51<00:00,  2.68it/s]


Training loss 3.6773362159729004


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 3.65189528465271


100%|██████████| 139/139 [00:52<00:00,  2.67it/s]


Training loss 3.6105690002441406


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 3.4929091930389404


100%|██████████| 139/139 [00:51<00:00,  2.68it/s]


Training loss 3.46366286277771


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 3.545114278793335


100%|██████████| 139/139 [00:51<00:00,  2.72it/s]


Training loss 3.4591662883758545


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 3.5257599353790283


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 3.300682306289673


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 3.4367640018463135


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 3.243307113647461


100%|██████████| 139/139 [00:52<00:00,  2.66it/s]


Training loss 3.379138231277466


100%|██████████| 139/139 [00:51<00:00,  2.68it/s]


Training loss 3.5169870853424072


100%|██████████| 139/139 [00:51<00:00,  2.67it/s]


Training loss 3.2696869373321533


100%|██████████| 139/139 [00:51<00:00,  2.68it/s]


Training loss 3.25646710395813


100%|██████████| 139/139 [00:51<00:00,  2.68it/s]


Training loss 3.301893472671509


100%|██████████| 139/139 [00:51<00:00,  2.68it/s]


Training loss 3.1601812839508057


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 3.2086949348449707


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 3.261551856994629


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 3.2296547889709473


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 3.126885175704956


100%|██████████| 139/139 [00:52<00:00,  2.66it/s]


Training loss 3.1088643074035645


100%|██████████| 139/139 [00:51<00:00,  2.68it/s]


Training loss 3.07033634185791


100%|██████████| 139/139 [00:51<00:00,  2.68it/s]


Training loss 3.1574556827545166


100%|██████████| 139/139 [00:52<00:00,  2.66it/s]


Training loss 3.0515034198760986


100%|██████████| 139/139 [00:51<00:00,  2.68it/s]


Training loss 2.9039928913116455


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 3.065908908843994


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 3.0890865325927734


100%|██████████| 139/139 [00:51<00:00,  2.68it/s]


Training loss 3.137471914291382


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 3.0320146083831787


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 3.006582021713257


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.9351084232330322


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.869105577468872


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.8928699493408203


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 3.0230793952941895


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 3.066803216934204


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.807677745819092


100%|██████████| 139/139 [00:51<00:00,  2.68it/s]


Training loss 3.0107805728912354


100%|██████████| 139/139 [00:51<00:00,  2.68it/s]


Training loss 2.8413891792297363


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.914557933807373
Example 1 CORRECT: Dog on a beach by the ocean
Example 1 OUTPUT: <SOS> a dog is running through the water . <EOS>
Example 2 CORRECT: Child holding red frisbee outdoors
Example 2 OUTPUT: <SOS> a young boy in a blue shirt is jumping into the water . <EOS>
Example 3 CORRECT: Bus driving by parked cars
Example 3 OUTPUT: <SOS> a man in a white shirt is standing on a bench . <EOS>
Example 4 CORRECT: A small boat in the ocean
Example 4 OUTPUT: <SOS> a man in a blue shirt is jumping over a <UNK> . <EOS>
Example 5 CORRECT: A cowboy riding a horse in the desert
Example 5 OUTPUT: <SOS> a man in a red shirt is standing on a rock . <EOS>


                                                 

BLUE SCORES validation  50 61.07663873257


                                                   

BLUE SCORES TRAIN 50 62.69672777091198
=> Saving checkpoint


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.823176383972168


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 3.024751663208008


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.848554849624634


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.706678867340088


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.9319820404052734


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.598647356033325


100%|██████████| 139/139 [00:51<00:00,  2.68it/s]


Training loss 2.7646706104278564


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.7127366065979004


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.791078805923462


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.6893866062164307


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.733637571334839


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.8283538818359375


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.665097713470459


100%|██████████| 139/139 [00:51<00:00,  2.68it/s]


Training loss 2.7903800010681152


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.7123091220855713


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.66802978515625


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.7056515216827393


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.6939163208007812


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.7213892936706543


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.5831515789031982


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.78045654296875


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.726997137069702


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.6675240993499756


100%|██████████| 139/139 [00:50<00:00,  2.73it/s]


Training loss 2.6909427642822266


100%|██████████| 139/139 [00:51<00:00,  2.72it/s]


Training loss 2.7372493743896484


100%|██████████| 139/139 [00:51<00:00,  2.72it/s]


Training loss 2.6928493976593018


100%|██████████| 139/139 [00:51<00:00,  2.72it/s]


Training loss 2.692370891571045


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.6794257164001465


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.784769058227539


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.7259297370910645


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.6495678424835205


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.7048816680908203


100%|██████████| 139/139 [00:51<00:00,  2.72it/s]


Training loss 2.6970832347869873


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.693446159362793


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.5900628566741943


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.60422945022583


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.665759563446045


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.638641119003296


100%|██████████| 139/139 [00:50<00:00,  2.73it/s]


Training loss 2.6671180725097656


100%|██████████| 139/139 [00:51<00:00,  2.72it/s]


Training loss 2.6911396980285645


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.7386910915374756


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.796226739883423


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.7117819786071777


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.5444610118865967


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.55952787399292


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.6550252437591553


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.692101240158081


100%|██████████| 139/139 [00:51<00:00,  2.72it/s]


Training loss 2.568418264389038


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.641014575958252


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.651546001434326
Example 1 CORRECT: Dog on a beach by the ocean
Example 1 OUTPUT: <SOS> a dog is running through the water . <EOS>
Example 2 CORRECT: Child holding red frisbee outdoors
Example 2 OUTPUT: <SOS> a little boy in a blue shirt is playing with a ball in a pool . <EOS>
Example 3 CORRECT: Bus driving by parked cars
Example 3 OUTPUT: <SOS> a man in a red shirt is standing on a bench . <EOS>
Example 4 CORRECT: A small boat in the ocean
Example 4 OUTPUT: <SOS> a man in a blue shirt is jumping over a large wave . <EOS>
Example 5 CORRECT: A cowboy riding a horse in the desert
Example 5 OUTPUT: <SOS> two people are walking through a snowy area . <EOS>


                                                 

BLUE SCORES validation  100 63.3771674625641


                                                   

BLUE SCORES TRAIN 100 65.93491540482744
=> Saving checkpoint


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.6528961658477783


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.538637638092041


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.563232660293579


100%|██████████| 139/139 [00:51<00:00,  2.72it/s]


Training loss 2.6499111652374268


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.5447118282318115


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.529703378677368


100%|██████████| 139/139 [00:51<00:00,  2.72it/s]


Training loss 2.5917177200317383


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.469534158706665


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.5565032958984375


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.4288952350616455


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.593369483947754


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.5447921752929688


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.462951898574829


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.440250873565674


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.6021769046783447


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.5184476375579834


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.5021042823791504


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.4639928340911865


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.5113770961761475


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.5436480045318604


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.5274741649627686


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.5027053356170654


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.575165033340454


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.4078269004821777


100%|██████████| 139/139 [00:51<00:00,  2.72it/s]


Training loss 2.564199447631836


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.496821641921997


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.4894933700561523


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.470036506652832


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.374208927154541


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.541630744934082


100%|██████████| 139/139 [00:52<00:00,  2.65it/s]


Training loss 2.505770444869995


100%|██████████| 139/139 [00:51<00:00,  2.68it/s]


Training loss 2.435981512069702


100%|██████████| 139/139 [00:52<00:00,  2.66it/s]


Training loss 2.5687479972839355


100%|██████████| 139/139 [00:52<00:00,  2.65it/s]


Training loss 2.493943214416504


100%|██████████| 139/139 [00:52<00:00,  2.66it/s]


Training loss 2.436210870742798


100%|██████████| 139/139 [00:51<00:00,  2.68it/s]


Training loss 2.483510971069336


100%|██████████| 139/139 [00:52<00:00,  2.67it/s]


Training loss 2.461099624633789


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.446641445159912


100%|██████████| 139/139 [00:52<00:00,  2.66it/s]


Training loss 2.473747491836548


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.5056376457214355


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.477027177810669


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.593710422515869


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.5151376724243164


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.4060606956481934


100%|██████████| 139/139 [00:52<00:00,  2.67it/s]


Training loss 2.4842989444732666


100%|██████████| 139/139 [00:51<00:00,  2.68it/s]


Training loss 2.433788299560547


100%|██████████| 139/139 [00:51<00:00,  2.68it/s]


Training loss 2.465688705444336


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.36578631401062


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.481297016143799


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.524923324584961
Example 1 CORRECT: Dog on a beach by the ocean
Example 1 OUTPUT: <SOS> a brown dog is running through the water . <EOS>
Example 2 CORRECT: Child holding red frisbee outdoors
Example 2 OUTPUT: <SOS> a boy in a blue shirt is playing in the water . <EOS>
Example 3 CORRECT: Bus driving by parked cars
Example 3 OUTPUT: <SOS> a man in a red shirt and a red shirt is standing on a sidewalk . <EOS>
Example 4 CORRECT: A small boat in the ocean
Example 4 OUTPUT: <SOS> a man in a blue shirt is jumping into the air . <EOS>
Example 5 CORRECT: A cowboy riding a horse in the desert
Example 5 OUTPUT: <SOS> two people are walking through a snowy area . <EOS>


                                                 

BLUE SCORES validation  150 64.2905059579


                                                   

BLUE SCORES TRAIN 150 67.87989267952358
=> Saving checkpoint


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.5233101844787598


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.4357705116271973


100%|██████████| 139/139 [00:51<00:00,  2.72it/s]


Training loss 2.535033702850342


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.462601661682129


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.401498317718506


100%|██████████| 139/139 [00:51<00:00,  2.72it/s]


Training loss 2.5184342861175537


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.383225440979004


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.399661064147949


100%|██████████| 139/139 [00:51<00:00,  2.72it/s]


Training loss 2.3319642543792725


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.3599488735198975


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.404921531677246


100%|██████████| 139/139 [00:51<00:00,  2.72it/s]


Training loss 2.4001991748809814


100%|██████████| 139/139 [00:51<00:00,  2.72it/s]


Training loss 2.433535575866699


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.463986396789551


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.4650614261627197


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.4032649993896484


100%|██████████| 139/139 [00:51<00:00,  2.72it/s]


Training loss 2.367421865463257


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.4218456745147705


100%|██████████| 139/139 [00:51<00:00,  2.72it/s]


Training loss 2.4769539833068848


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.3962392807006836


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.355398416519165


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.3802011013031006


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.4072654247283936


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.3769893646240234


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.416626214981079


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.3835554122924805


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.49267840385437


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.4046499729156494


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.463934898376465


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.4518511295318604


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.3267135620117188


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.4607973098754883


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.409794569015503


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.368373155593872


100%|██████████| 139/139 [00:51<00:00,  2.72it/s]


Training loss 2.337284803390503


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.3610785007476807


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.457327127456665


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.347020149230957


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.4238085746765137


100%|██████████| 139/139 [00:51<00:00,  2.72it/s]


Training loss 2.390522003173828


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.4089856147766113


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.3911032676696777


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.309317111968994


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.3337175846099854


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.344980001449585


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.3519539833068848


100%|██████████| 139/139 [00:51<00:00,  2.73it/s]


Training loss 2.5023953914642334


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.333765745162964


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.2632832527160645


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.3752493858337402
Example 1 CORRECT: Dog on a beach by the ocean
Example 1 OUTPUT: <SOS> a brown dog is running through the water . <EOS>
Example 2 CORRECT: Child holding red frisbee outdoors
Example 2 OUTPUT: <SOS> a little boy in a blue shirt is playing with a ball in a park . <EOS>
Example 3 CORRECT: Bus driving by parked cars
Example 3 OUTPUT: <SOS> a man in a red shirt is standing on a motorcycle . <EOS>
Example 4 CORRECT: A small boat in the ocean
Example 4 OUTPUT: <SOS> a man in a wetsuit is surfing on a surfboard . <EOS>
Example 5 CORRECT: A cowboy riding a horse in the desert
Example 5 OUTPUT: <SOS> two people are walking through the snow . <EOS>


                                                 

BLUE SCORES validation  200 64.52307460937206


                                                   

BLUE SCORES TRAIN 200 68.69659456233428
=> Saving checkpoint


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.3077292442321777


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.3796467781066895


100%|██████████| 139/139 [00:51<00:00,  2.72it/s]


Training loss 2.378504991531372


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.3503928184509277


100%|██████████| 139/139 [00:51<00:00,  2.72it/s]


Training loss 2.3587281703948975


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.2373712062835693


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.332265853881836


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.256037712097168


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.4043171405792236


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.241027355194092


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.341268539428711


100%|██████████| 139/139 [00:51<00:00,  2.72it/s]


Training loss 2.3136866092681885


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.1780500411987305


100%|██████████| 139/139 [00:51<00:00,  2.72it/s]


Training loss 2.241220474243164


100%|██████████| 139/139 [00:51<00:00,  2.72it/s]


Training loss 2.462735176086426


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.3558907508850098


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.36376690864563


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.300534725189209


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.2805471420288086


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.35023832321167


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.2375895977020264


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.304307222366333


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.2640843391418457


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.2371487617492676


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.1661088466644287


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.249478578567505


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.421264410018921


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.3742384910583496


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.374037504196167


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.358518123626709


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.313727617263794


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.1876955032348633


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.3432705402374268


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.373948097229004


100%|██████████| 139/139 [00:51<00:00,  2.72it/s]


Training loss 2.294029474258423


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.3534884452819824


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.326349973678589


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.261746406555176


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.372321605682373


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.2443010807037354


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.1975014209747314


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.212489128112793


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.3954100608825684


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.3685147762298584


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.24245285987854


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.374094247817993


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.2033472061157227


100%|██████████| 139/139 [00:51<00:00,  2.72it/s]


Training loss 2.2624425888061523


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.4395973682403564


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.3247246742248535
Example 1 CORRECT: Dog on a beach by the ocean
Example 1 OUTPUT: <SOS> a brown dog is running through the water . <EOS>
Example 2 CORRECT: Child holding red frisbee outdoors
Example 2 OUTPUT: <SOS> a young boy in a blue shirt is playing with a ball in a park . <EOS>
Example 3 CORRECT: Bus driving by parked cars
Example 3 OUTPUT: <SOS> a man in a red shirt and a red shirt is standing on a motorcycle . <EOS>
Example 4 CORRECT: A small boat in the ocean
Example 4 OUTPUT: <SOS> a man is standing on a dock . <EOS>
Example 5 CORRECT: A cowboy riding a horse in the desert
Example 5 OUTPUT: <SOS> a man in a black jacket is standing on a mountain . <EOS>


                                                 

BLUE SCORES validation  250 65.09102946487012


                                                   

BLUE SCORES TRAIN 250 69.44434649023182
=> Saving checkpoint


100%|██████████| 139/139 [00:51<00:00,  2.72it/s]


Training loss 2.2920594215393066


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.3256149291992188


100%|██████████| 139/139 [00:51<00:00,  2.72it/s]


Training loss 2.2652511596679688


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.271575927734375


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.375351667404175


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.325971841812134


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.3899080753326416


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.273435354232788


100%|██████████| 139/139 [00:51<00:00,  2.72it/s]


Training loss 2.2382802963256836


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.2739806175231934


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.302678346633911


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.3667030334472656


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.261029005050659


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.323481798171997


100%|██████████| 139/139 [00:51<00:00,  2.72it/s]


Training loss 2.157039165496826


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.3702948093414307


100%|██████████| 139/139 [00:51<00:00,  2.73it/s]


Training loss 2.2434778213500977


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.3072509765625


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.306252956390381


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.2315750122070312


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.18969988822937


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.381415605545044


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.1956372261047363


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.272514581680298


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.307643413543701


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.2728796005249023


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.219266176223755


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.2538270950317383


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.17399001121521


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.2617580890655518


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.3451087474823


100%|██████████| 139/139 [00:51<00:00,  2.72it/s]


Training loss 2.185770273208618


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.187364339828491


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.236971855163574


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.394460916519165


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.2880587577819824


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.258474826812744


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.2384026050567627


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.2262814044952393


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.3279542922973633


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.2900123596191406


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.2142927646636963


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.2086429595947266


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.2121670246124268


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.2426416873931885


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.3595991134643555


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]


Training loss 2.1584460735321045


100%|██████████| 139/139 [00:51<00:00,  2.71it/s]


Training loss 2.098145008087158


100%|██████████| 139/139 [00:51<00:00,  2.72it/s]


Training loss 2.063641309738159


100%|██████████| 139/139 [00:51<00:00,  2.69it/s]


Training loss 2.1832804679870605
Example 1 CORRECT: Dog on a beach by the ocean
Example 1 OUTPUT: <SOS> a brown dog is running through the water . <EOS>
Example 2 CORRECT: Child holding red frisbee outdoors
Example 2 OUTPUT: <SOS> a little girl in a pink shirt is playing with a ball in a park . <EOS>
Example 3 CORRECT: Bus driving by parked cars
Example 3 OUTPUT: <SOS> a man in a red shirt and a red shirt is standing on a motorcycle . <EOS>
Example 4 CORRECT: A small boat in the ocean
Example 4 OUTPUT: <SOS> a man is standing on a dock . <EOS>
Example 5 CORRECT: A cowboy riding a horse in the desert
Example 5 OUTPUT: <SOS> a man in a red jacket is standing on a mountain . <EOS>


                                                 

BLUE SCORES validation  300 65.66297116545734


                                                   

BLUE SCORES TRAIN 300 70.03197163048456
=> Saving checkpoint


100%|██████████| 139/139 [00:51<00:00,  2.70it/s]

Training loss 2.2188913822174072





VH: Hence, the training blue score is 70.03 and the validation blue score is 65.66.

VH: Run the model for the test set to generate the captions.

In [28]:
import csv
from PIL import Image
import fnmatch

# Open the output file in write mode
with open('output.txt', mode='w') as file:
    model.eval()

    # Create a CSV writer object
    writer = csv.writer(file)

    # Write the header row
    writer.writerow(['image', 'caption'])
    directory = "flickr8k_split/test_images"
    test_dir_size = count = len(fnmatch.filter(os.listdir(directory), '*.*'))

    # Loop through the test images
    for i, image_name in enumerate(os.listdir(directory)):
        # Load the test image
        image_path = os.path.join(directory, image_name)
        test_img = transform(Image.open(image_path).convert("RGB")).unsqueeze(0)

        # Generate the caption using the model
        caption = " ".join(model.caption_image(test_img.to(device), dataset.vocab))
        caption = caption[len('<SOS> '): -len(' <EOS>')].strip().capitalize()

        # Write the image path and caption to the CSV file
        writer.writerow([image_name, caption])

        # Print the progress
        print(f'{i+1}/{test_dir_size}: {image_path} => {caption}')

1/500: flickr8k_split/test_images\130211457_be3f6b335d.jpg => A woman in a red shirt and a white shirt is standing on a sidewalk .
2/500: flickr8k_split/test_images\131632409_4de0d4e710.jpg => A girl in a pink bathing suit is standing on a beach .
3/500: flickr8k_split/test_images\132489044_3be606baf7.jpg => A woman in a white shirt and a white shirt is sitting on a bench .
4/500: flickr8k_split/test_images\133189853_811de6ab2a.jpg => A man in a black shirt and sunglasses is standing in front of a crowd .
5/500: flickr8k_split/test_images\133905560_9d012b47f3.jpg => A white dog is running through the snow .
6/500: flickr8k_split/test_images\134724228_30408cd77f.jpg => A man and a woman are standing in the middle of a large body of water .
7/500: flickr8k_split/test_images\134894450_dadea45d65.jpg => A boy in a red shirt is jumping off a swing .
8/500: flickr8k_split/test_images\135235570_5698072cd4.jpg => A man in a black shirt and white shorts is standing on a sidewalk with a woman in