## Before Running:
Please Install all from the requirements.txt (pip install -r requirements.txt).

## Set Hyper Parameters

In [2]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder_decoder_layers = 3
encoder_decoder_heads = 8
embedded_dim = 768 # Don't change
max_length = 32 
coco_dataset_ratio = 50
coco_dataset_dir = "./coco"
vocab_size = 50257 # Don't change
batch_size = 64
num_epochs = 2
learning_rate = 1e-4
patience = 3
weight_decay = 1e-5
preprocess_swin_model = "microsoft/swin-tiny-patch4-window7-224"
encoder_model = "microsoft/swin-base-patch4-window7-224-in22k"
decoder_model = "gpt2"

## Downloading and Format datasets
This will take some time to finishing running the first time. It took me roughly 40 minutes.

This section does the following actions:
1. Downloads the Dataset
2. Keeps images with only 3 or 4 dim
3. Transforms the dataset 
4. Turns the data set into data loaders


In [7]:
import numpy as np
from datasets import load_dataset
from transformers import ViTImageProcessor, GPT2TokenizerFast, AutoImageProcessor, SwinModel
from torch.utils.data import DataLoader, Dataset
import torch
import os

# Download the train, val and test splits of the COCO dataset
train_ds = load_dataset("HuggingFaceM4/COCO", split=f"train[:{coco_dataset_ratio}%]", cache_dir=coco_dataset_dir)
valid_ds = load_dataset("HuggingFaceM4/COCO", split=f"validation[:{coco_dataset_ratio}%]", cache_dir=coco_dataset_dir)
test_ds = load_dataset("HuggingFaceM4/COCO", split="test", cache_dir=coco_dataset_dir)

# Filter all non 3 or 4 dim images out
# Can change num_proc, but might be errors with np
train_ds = train_ds.filter(lambda item: np.array(item["image"]).ndim in [3, 4], num_proc=1)
valid_ds = valid_ds.filter(lambda item: np.array(item["image"]).ndim in [3, 4], num_proc=1)
test_ds = test_ds.filter(lambda item: np.array(item["image"]).ndim in [3, 4], num_proc=1)

# Does pre processing on the data set
# This includes pre-trained ViTimage feature extraction and tokenizing captions
tokenizer = GPT2TokenizerFast.from_pretrained(decoder_model)
tokenizer.pad_token = tokenizer.eos_token
image_processor = ViTImageProcessor.from_pretrained(encoder_model)
image_processor_swin = SwinModel.from_pretrained(preprocess_swin_model).to(device)

def preprocess(items):
    # Image pre-processing
    # use ViT and SWIN since no back prop
    pixel_values = image_processor(items["image"], return_tensors="pt").pixel_values.to(device)
    with torch.no_grad():
        pixel_values = image_processor_swin (pixel_values).last_hidden_state
    pixel_values = pixel_values.to('cpu')

    # tokenize
    targets = tokenizer(items["sentences"]['raw'],
                        max_length=max_length, padding="max_length", truncation=True, return_tensors="pt")
    
    # Keep image file for easy examples later
    img_file = items['filepath']
    return {'pixel_values': pixel_values, 'labels': targets["input_ids"], 'image_file': img_file}


TRAIN_SET_PATH = './preprocessed/train_set.pt'
TEST_SET_PATH = './preprocessed/val_set.pt'
VALID_SET_PATH = './preprocessed/test_set.pt'

# Pre process train dataset if it doesn't exist
train_dataset = None
if os.path.isfile(TRAIN_SET_PATH):
    train_dataset = torch.load(TRAIN_SET_PATH)

else:
    train_dataset = train_ds.map(preprocess)
    torch.save(train_dataset, TRAIN_SET_PATH)
    
# Pre process val dataset if it doesn't exist
valid_dataset = None
if os.path.isfile(VALID_SET_PATH):
    valid_dataset = torch.load(VALID_SET_PATH)

else:
    valid_dataset = valid_ds.map(preprocess)
    torch.save(valid_dataset, VALID_SET_PATH)

# Pre process test dataset if it doesn't exist
test_dataset = None
if os.path.isfile(TEST_SET_PATH):
    test_dataset = torch.load(TEST_SET_PATH)

else:
    test_dataset = test_ds.map(preprocess)
    torch.save(test_dataset, TEST_SET_PATH)



# Turns the dataset into a torch DataLoader
def collate_fn(batch):
    return {
        'pixel_values': torch.stack([torch.tensor(x['pixel_values']) for x in batch]),
        'labels': torch.stack([torch.tensor(x['labels']) for x in batch]),
        'image_file': [x["image_file"] for x in batch]
    }

train_dataset_loader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=batch_size, shuffle=True)
valid_dataset_loader = DataLoader(valid_dataset, collate_fn=collate_fn, batch_size=batch_size, shuffle=False)
test_dataset_loader = DataLoader(test_dataset, collate_fn=collate_fn, batch_size=batch_size, shuffle=False)

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


## Define the Model
Creates the PureT model from the paper

This section does the following actions:
1. Creates the SWIN Transformer used by PureT
2. Creates the PureT encoder
3. Creates the PureT decoder
4. Creates the PureT model

Download the pre-trained SwinT weights from here https://drive.google.com/drive/folders/1HBw5NGGw8DjkyNurksCP5v8a5f0FG7zU and put them in this folder before running

In [11]:
import torch.nn as nn
import torch.nn.functional as F
import torch
from transformers import SwinModel

class BasicCaptionTransformer(torch.nn.Module):
    def __init__(self):
        super(BasicCaptionTransformer, self).__init__()

        # Image processing pre-done with ViTImageProcessor and Swin in the Downloading and Format datasets step
        #self.swin = SwinModel.from_pretrained(preprocess_swin_model)

        # build encoder
        encoder_layer = nn.TransformerEncoderLayer(d_model=embedded_dim, nhead=encoder_decoder_heads, batch_first=True)
        self.encoders = nn.TransformerEncoder(encoder_layer, encoder_decoder_layers)

        # embeddings
        self.embeddings = nn.Embedding(vocab_size, embedded_dim)

        # build decoder
        decoder_layer = nn.TransformerDecoderLayer(d_model=embedded_dim, nhead=encoder_decoder_heads, batch_first=True)
        self.decoders = nn.TransformerDecoder(decoder_layer, encoder_decoder_layers)
        
        # decoder mask
        mask = torch.BoolTensor(max_length, max_length)
        for i in range(max_length):
            for j in range(max_length):
                if (i >= j):
                    mask[i][j] = False
                else:
                    mask[i][j] = True
        self.mask = mask

        # final
        self.linear = nn.Linear(in_features=embedded_dim, out_features=vocab_size)


    def forward(self, images, captions):
        x = images
        #with torch.no_grad():
        #    x = self.swin(x)
        x = self.encoders(x)
        
        captions = self.embeddings(captions)
        x = self.decoders(tgt=captions.to(x.dtype), memory=x, tgt_mask=self.mask)
        return self.linear(x)
    
mod = BasicCaptionTransformer()

tensor([[False,  True,  True,  ...,  True,  True,  True],
        [False, False,  True,  ...,  True,  True,  True],
        [False, False, False,  ...,  True,  True,  True],
        ...,
        [False, False, False,  ..., False,  True,  True],
        [False, False, False,  ..., False, False,  True],
        [False, False, False,  ..., False, False, False]])


## Training loop
Trains the model and saves the best (lowest val error) and last model

This section does the following actions:
1. Creates the Basic transformer model
2. Sets up optimizer, scheduler, counter for training
3. Trains for num_epochs epochs
2. Each Epoch has valadation accuracy calculated
3. Save the model with the best valadation accuracy TODO
4. Save the model when the max number of epochs has been reached TODO

In [9]:
from tqdm import tqdm
from torch.optim.lr_scheduler import ExponentialLR

# Loop setup
model = BasicCaptionTransformer()
model = model.to(device)

loss_function = nn.CrossEntropyLoss(ignore_index=-1)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
scheduler = ExponentialLR(optimizer, gamma=0.9)


# Train
predictions_per_batch = batch_size * max_length
stop_counter = 0
train_losses = []
val_losses = []
best_val_loss = float('inf')

train_len = len(train_dataset_loader)
val_len = len(valid_dataset_loader)


for epoch in range(num_epochs):
    model.train()

    # Loop through training data loader batches
    train_loss = 0.0
    train_dataloader_iter = tqdm(train_dataset_loader, desc=f'Epoch {epoch + 1}/{num_epochs}', leave=False)
    for i, data in enumerate(train_dataloader_iter):
        
        # Get values from data loader
        pixel_vals = data["pixel_values"].squeeze(1).to(device)
        captions = data["labels"].squeeze(1).to(device)

        optimizer.zero_grad()
        outputs = model(images=pixel_vals, captions=captions)

        loss = loss_function(outputs.permute(0,2,1), captions)
        loss.backward()
        optimizer.step()

        # save loss
        train_loss += loss.item()
        if i % 100 == 99:
            print ("Loss so far is: " + str (train_loss / i))
            #print ("tensor: ", outputs)
    
    scheduler.step()


    # Validation
    val_loss = 0.0
    valid_dataset_iter = tqdm(valid_dataset_loader,  desc=f'Epoch {epoch + 1}/{num_epochs}', leave=False)
    with torch.no_grad():
        for i, data in enumerate(valid_dataset_iter):
            
            # Get values from data loader
            pixel_vals = data["pixel_values"].squeeze(1).to(device)
            captions = data["labels"].squeeze(1).to(device)

            outputs = model(images=pixel_vals, captions=captions)
            loss = loss_function(outputs.permute(0,2,1), captions)

            # save loss
            val_loss += loss.item()


    # Save model
    torch.save(model.state_dict(), f'./models/model_epoch_{epoch+1}.pt')

    train_losses.append(train_loss / train_len)
    val_losses.append(val_loss / val_len)

    # Print losses
    # divide by length of data loader... 
    print("\nEpoch: " + str(epoch) + 
        "\nTrain Loss: " + str(train_loss / train_len) +
        "\nVal Loss: " + str(val_loss / val_len))

print (train_losses)
print (val_losses)


Epoch 1/10:   2%|▏         | 100/4418 [02:14<1:28:56,  1.24s/it]

Loss so far is: 2.2383471829722628


Epoch 1/10:   5%|▍         | 200/4418 [04:22<1:30:26,  1.29s/it]

Loss so far is: 1.473855088553836


Epoch 1/10:   7%|▋         | 300/4418 [06:39<1:36:12,  1.40s/it]

Loss so far is: 1.1302580444709114


Epoch 1/10:   9%|▉         | 400/4418 [08:51<1:26:47,  1.30s/it]

Loss so far is: 0.9281628365654097


Epoch 1/10:  11%|█▏        | 500/4418 [11:03<1:24:49,  1.30s/it]

Loss so far is: 0.7941833448732545


Epoch 1/10:  14%|█▎        | 600/4418 [13:12<1:21:37,  1.28s/it]

Loss so far is: 0.6976284710612639


Epoch 1/10:  16%|█▌        | 700/4418 [15:22<1:20:01,  1.29s/it]

Loss so far is: 0.6235939171884705


Epoch 1/10:  18%|█▊        | 800/4418 [17:33<1:20:50,  1.34s/it]

Loss so far is: 0.5662909750272992


Epoch 1/10:  20%|██        | 900/4418 [19:42<1:16:24,  1.30s/it]

Loss so far is: 0.5187142032703117


Epoch 1/10:  23%|██▎       | 1000/4418 [21:56<1:25:33,  1.50s/it]

Loss so far is: 0.47915351732789696


Epoch 1/10:  25%|██▍       | 1100/4418 [24:21<1:10:53,  1.28s/it]

Loss so far is: 0.44607896525226365


Epoch 1/10:  27%|██▋       | 1200/4418 [26:48<1:21:38,  1.52s/it]

Loss so far is: 0.4175792997633049


Epoch 1/10:  29%|██▉       | 1300/4418 [29:04<1:12:35,  1.40s/it]

Loss so far is: 0.39258438328695444


Epoch 1/10:  32%|███▏      | 1400/4418 [31:17<1:05:03,  1.29s/it]

Loss so far is: 0.3707203237290592


Epoch 1/10:  34%|███▍      | 1500/4418 [33:30<1:02:45,  1.29s/it]

Loss so far is: 0.3514191595100617


Epoch 1/10:  36%|███▌      | 1600/4418 [35:43<1:10:52,  1.51s/it]

Loss so far is: 0.33407198980315506


Epoch 1/10:  38%|███▊      | 1700/4418 [37:55<57:34,  1.27s/it]  

Loss so far is: 0.318456441044518


Epoch 1/10:  41%|████      | 1800/4418 [40:03<59:45,  1.37s/it]  

Loss so far is: 0.3042808462538989


Epoch 1/10:  43%|████▎     | 1900/4418 [42:13<52:53,  1.26s/it]

Loss so far is: 0.2915240293844412


Epoch 1/10:  45%|████▌     | 2000/4418 [44:18<50:30,  1.25s/it]

Loss so far is: 0.2798953792518224


Epoch 1/10:  48%|████▊     | 2100/4418 [46:24<48:29,  1.26s/it]

Loss so far is: 0.2692509806218454


Epoch 1/10:  50%|████▉     | 2200/4418 [48:30<46:12,  1.25s/it]

Loss so far is: 0.2593470756019575


Epoch 1/10:  52%|█████▏    | 2300/4418 [50:37<44:28,  1.26s/it]

Loss so far is: 0.250236824315273


Epoch 1/10:  54%|█████▍    | 2400/4418 [52:42<42:37,  1.27s/it]

Loss so far is: 0.24183109724484989


Epoch 1/10:  57%|█████▋    | 2500/4418 [54:49<40:40,  1.27s/it]

Loss so far is: 0.23397880769809898


Epoch 1/10:  59%|█████▉    | 2600/4418 [56:54<38:09,  1.26s/it]

Loss so far is: 0.2266582198137874


Epoch 1/10:  61%|██████    | 2700/4418 [59:00<38:31,  1.35s/it]

Loss so far is: 0.2198542242482575


Epoch 1/10:  63%|██████▎   | 2800/4418 [1:01:17<34:46,  1.29s/it]

Loss so far is: 0.2133803132646286


Epoch 1/10:  66%|██████▌   | 2900/4418 [1:03:27<33:17,  1.32s/it]

Loss so far is: 0.20729510254516076


Epoch 1/10:  68%|██████▊   | 3000/4418 [1:05:36<30:24,  1.29s/it]

Loss so far is: 0.20164634901598572


Epoch 1/10:  70%|███████   | 3100/4418 [1:07:45<28:22,  1.29s/it]

Loss so far is: 0.19626700603509836


Epoch 1/10:  72%|███████▏  | 3200/4418 [1:09:54<25:54,  1.28s/it]

Loss so far is: 0.1912099159836231


Epoch 1/10:  75%|███████▍  | 3300/4418 [1:12:04<24:15,  1.30s/it]

Loss so far is: 0.18644391618644754


Epoch 1/10:  77%|███████▋  | 3400/4418 [1:14:13<21:50,  1.29s/it]

Loss so far is: 0.18197275173819397


Epoch 1/10:  79%|███████▉  | 3500/4418 [1:16:22<19:37,  1.28s/it]

Loss so far is: 0.17766370519020983


Epoch 1/10:  81%|████████▏ | 3600/4418 [1:18:31<17:28,  1.28s/it]

Loss so far is: 0.17357952231409302


Epoch 1/10:  84%|████████▎ | 3700/4418 [1:20:41<15:26,  1.29s/it]

Loss so far is: 0.16969148371815299


Epoch 1/10:  86%|████████▌ | 3800/4418 [1:22:50<13:28,  1.31s/it]

Loss so far is: 0.16600091625215835


Epoch 1/10:  88%|████████▊ | 3900/4418 [1:24:59<11:11,  1.30s/it]

Loss so far is: 0.16251021785750913


Epoch 1/10:  91%|█████████ | 4000/4418 [1:27:08<08:57,  1.29s/it]

Loss so far is: 0.15914839280666293


Epoch 1/10:  93%|█████████▎| 4100/4418 [1:29:17<06:52,  1.30s/it]

Loss so far is: 0.1559599619132707


Epoch 1/10:  95%|█████████▌| 4200/4418 [1:31:27<04:39,  1.28s/it]

Loss so far is: 0.15293128268492026


Epoch 1/10:  97%|█████████▋| 4300/4418 [1:33:37<02:35,  1.31s/it]

Loss so far is: 0.150001976362766


Epoch 1/10: 100%|█████████▉| 4400/4418 [1:35:47<00:23,  1.29s/it]

Loss so far is: 0.14714507172666455


                                                                 


Epoch: 0
Train Loss: 0.1466198884141328
Val Loss: 0.025249600248076976


Epoch 2/10:   2%|▏         | 100/4418 [02:06<1:29:45,  1.25s/it]

Loss so far is: 0.019475546486750997


Epoch 2/10:   5%|▍         | 200/4418 [04:10<1:28:06,  1.25s/it]

Loss so far is: 0.01951482949132596


Epoch 2/10:   7%|▋         | 300/4418 [06:17<1:25:52,  1.25s/it]

Loss so far is: 0.01994005465579459


Epoch 2/10:   9%|▉         | 400/4418 [08:26<1:23:49,  1.25s/it]

Loss so far is: 0.019820754440503668


Epoch 2/10:  11%|█▏        | 500/4418 [10:43<1:25:12,  1.30s/it]

Loss so far is: 0.019577765179413533


Epoch 2/10:  14%|█▎        | 600/4418 [12:52<1:20:30,  1.27s/it]

Loss so far is: 0.01957675511015887


Epoch 2/10:  16%|█▌        | 700/4418 [15:16<3:44:52,  3.63s/it]

Loss so far is: 0.01968293181901667


Epoch 2/10:  18%|█▊        | 800/4418 [17:40<1:14:19,  1.23s/it]

Loss so far is: 0.019717961093568208


Epoch 2/10:  20%|██        | 900/4418 [19:43<1:11:59,  1.23s/it]

Loss so far is: 0.019532507648189097


Epoch 2/10:  23%|██▎       | 1000/4418 [21:47<1:10:13,  1.23s/it]

Loss so far is: 0.01949796915015633


Epoch 2/10:  25%|██▍       | 1100/4418 [23:51<1:08:20,  1.24s/it]

Loss so far is: 0.019512600913407403


Epoch 2/10:  27%|██▋       | 1200/4418 [25:55<1:06:48,  1.25s/it]

Loss so far is: 0.01944839896197688


Epoch 2/10:  29%|██▉       | 1300/4418 [27:59<1:04:42,  1.25s/it]

Loss so far is: 0.019276568136092638


Epoch 2/10:  32%|███▏      | 1400/4418 [30:04<1:03:28,  1.26s/it]

Loss so far is: 0.019106618420795415


Epoch 2/10:  34%|███▍      | 1500/4418 [32:09<1:00:36,  1.25s/it]

Loss so far is: 0.018998923099191586


Epoch 2/10:  36%|███▌      | 1600/4418 [34:13<58:52,  1.25s/it]  

Loss so far is: 0.018885166608519372


Epoch 2/10:  38%|███▊      | 1700/4418 [36:19<56:23,  1.24s/it]

Loss so far is: 0.018857284027986063


Epoch 2/10:  41%|████      | 1800/4418 [38:25<55:45,  1.28s/it]

Loss so far is: 0.018836522576805015


Epoch 2/10:  43%|████▎     | 1900/4418 [40:31<53:33,  1.28s/it]

Loss so far is: 0.01871010717273047


Epoch 2/10:  45%|████▌     | 2000/4418 [42:37<50:59,  1.27s/it]

Loss so far is: 0.018678491031357356


Epoch 2/10:  48%|████▊     | 2100/4418 [44:43<48:41,  1.26s/it]

Loss so far is: 0.018635275621404636


Epoch 2/10:  50%|████▉     | 2200/4418 [46:50<46:41,  1.26s/it]

Loss so far is: 0.01854689058726175


Epoch 2/10:  52%|█████▏    | 2300/4418 [48:56<43:59,  1.25s/it]

Loss so far is: 0.01845265369504413


Epoch 2/10:  54%|█████▍    | 2400/4418 [51:02<42:22,  1.26s/it]

Loss so far is: 0.018365945538699822


Epoch 2/10:  57%|█████▋    | 2500/4418 [53:09<40:24,  1.26s/it]

Loss so far is: 0.018204787751056384


Epoch 2/10:  59%|█████▉    | 2600/4418 [55:17<37:48,  1.25s/it]

Loss so far is: 0.018112277766489347


Epoch 2/10:  61%|██████    | 2700/4418 [57:23<36:03,  1.26s/it]

Loss so far is: 0.018046267156476427


Epoch 2/10:  63%|██████▎   | 2800/4418 [59:31<34:26,  1.28s/it]

Loss so far is: 0.01797824882060055


Epoch 2/10:  66%|██████▌   | 2900/4418 [1:01:37<32:06,  1.27s/it]

Loss so far is: 0.01787244088551182


Epoch 2/10:  68%|██████▊   | 3000/4418 [1:03:45<29:43,  1.26s/it]

Loss so far is: 0.017799410517099237


Epoch 2/10:  70%|███████   | 3100/4418 [1:05:52<28:07,  1.28s/it]

Loss so far is: 0.017741614514632766


Epoch 2/10:  72%|███████▏  | 3200/4418 [1:07:59<25:43,  1.27s/it]

Loss so far is: 0.017688980452752304


Epoch 2/10:  75%|███████▍  | 3300/4418 [1:10:06<23:41,  1.27s/it]

Loss so far is: 0.017652169907811825


Epoch 2/10:  77%|███████▋  | 3400/4418 [1:12:12<21:12,  1.25s/it]

Loss so far is: 0.017593012174110415


Epoch 2/10:  79%|███████▉  | 3500/4418 [1:14:19<19:33,  1.28s/it]

Loss so far is: 0.01754961338122241


Epoch 2/10:  81%|████████▏ | 3600/4418 [1:16:26<17:22,  1.27s/it]

Loss so far is: 0.017469396630765147


Epoch 2/10:  84%|████████▎ | 3700/4418 [1:18:32<15:07,  1.26s/it]

Loss so far is: 0.017387536800228295


Epoch 2/10:  86%|████████▌ | 3800/4418 [1:20:39<12:54,  1.25s/it]

Loss so far is: 0.017323796819778078


Epoch 2/10:  88%|████████▊ | 3900/4418 [1:22:46<10:58,  1.27s/it]

Loss so far is: 0.01724914757300299


Epoch 2/10:  91%|█████████ | 4000/4418 [1:24:52<08:49,  1.27s/it]

Loss so far is: 0.01721749418666591


Epoch 2/10:  93%|█████████▎| 4100/4418 [1:26:59<06:50,  1.29s/it]

Loss so far is: 0.017185278738724243


Epoch 2/10:  95%|█████████▌| 4200/4418 [1:29:06<04:37,  1.27s/it]

Loss so far is: 0.017102785878497034


Epoch 2/10:  97%|█████████▋| 4300/4418 [1:31:12<02:30,  1.28s/it]

Loss so far is: 0.0170251180828065


Epoch 2/10: 100%|█████████▉| 4400/4418 [1:33:19<00:22,  1.27s/it]

Loss so far is: 0.016945736534569274


                                                                 


Epoch: 1
Train Loss: 0.016925680291919945
Val Loss: 0.01608021901263736


Epoch 3/10:   2%|▏         | 100/4418 [02:02<1:27:58,  1.22s/it]

Loss so far is: 0.011716289217630871


Epoch 3/10:   5%|▍         | 200/4418 [04:04<1:25:20,  1.21s/it]

Loss so far is: 0.01131978802366566


Epoch 3/10:   7%|▋         | 300/4418 [06:11<1:23:43,  1.22s/it]

Loss so far is: 0.011066823577731038


Epoch 3/10:   9%|▉         | 400/4418 [08:15<1:24:01,  1.25s/it]

Loss so far is: 0.01122155762408839


Epoch 3/10:  11%|█▏        | 500/4418 [10:18<1:19:47,  1.22s/it]

Loss so far is: 0.011239349839505263


Epoch 3/10:  14%|█▎        | 600/4418 [12:20<1:18:13,  1.23s/it]

Loss so far is: 0.011241838235512239


Epoch 3/10:  16%|█▌        | 700/4418 [14:23<1:16:27,  1.23s/it]

Loss so far is: 0.011181405421016


Epoch 3/10:  18%|█▊        | 800/4418 [16:26<1:15:04,  1.25s/it]

Loss so far is: 0.01108693949289874


Epoch 3/10:  20%|██        | 900/4418 [18:31<1:12:51,  1.24s/it]

Loss so far is: 0.011026438951553948


Epoch 3/10:  23%|██▎       | 1000/4418 [20:34<1:10:10,  1.23s/it]

Loss so far is: 0.011046157061139605


Epoch 3/10:  25%|██▍       | 1100/4418 [22:39<1:08:25,  1.24s/it]

Loss so far is: 0.01106677797988008


Epoch 3/10:  27%|██▋       | 1200/4418 [24:43<1:06:34,  1.24s/it]

Loss so far is: 0.011135843557789548


Epoch 3/10:  29%|██▉       | 1300/4418 [26:47<1:04:25,  1.24s/it]

Loss so far is: 0.011070838756169234


Epoch 3/10:  32%|███▏      | 1400/4418 [28:52<1:02:39,  1.25s/it]

Loss so far is: 0.01111622751414013


Epoch 3/10:  34%|███▍      | 1500/4418 [30:57<1:01:32,  1.27s/it]

Loss so far is: 0.011083858939261424


Epoch 3/10:  36%|███▌      | 1600/4418 [33:02<59:20,  1.26s/it]  

Loss so far is: 0.011056698746727331


Epoch 3/10:  38%|███▊      | 1700/4418 [35:08<56:42,  1.25s/it]

Loss so far is: 0.011077722290846267


Epoch 3/10:  41%|████      | 1800/4418 [37:13<54:34,  1.25s/it]

Loss so far is: 0.011090581501367398


Epoch 3/10:  43%|████▎     | 1900/4418 [39:19<52:26,  1.25s/it]

Loss so far is: 0.011068285327147754


Epoch 3/10:  45%|████▌     | 2000/4418 [41:25<50:44,  1.26s/it]

Loss so far is: 0.011132632907276872


Epoch 3/10:  48%|████▊     | 2100/4418 [43:32<48:40,  1.26s/it]

Loss so far is: 0.011161062083987537


Epoch 3/10:  50%|████▉     | 2200/4418 [45:38<47:07,  1.27s/it]

Loss so far is: 0.01119055732462738


Epoch 3/10:  52%|█████▏    | 2300/4418 [47:45<43:59,  1.25s/it]

Loss so far is: nan


Epoch 3/10:  54%|█████▍    | 2400/4418 [49:51<42:24,  1.26s/it]

Loss so far is: nan


Epoch 3/10:  57%|█████▋    | 2500/4418 [51:56<40:00,  1.25s/it]

Loss so far is: nan


Epoch 3/10:  59%|█████▉    | 2600/4418 [54:02<37:37,  1.24s/it]

Loss so far is: nan


Epoch 3/10:  61%|██████    | 2700/4418 [56:08<36:03,  1.26s/it]

Loss so far is: nan


Epoch 3/10:  63%|██████▎   | 2800/4418 [58:14<34:13,  1.27s/it]

Loss so far is: nan


Epoch 3/10:  66%|██████▌   | 2900/4418 [1:00:20<32:07,  1.27s/it]

Loss so far is: nan


Epoch 3/10:  68%|██████▊   | 3000/4418 [1:02:26<29:56,  1.27s/it]

Loss so far is: nan


Epoch 3/10:  70%|███████   | 3100/4418 [1:04:32<28:00,  1.28s/it]

Loss so far is: nan


Epoch 3/10:  72%|███████▏  | 3200/4418 [1:06:38<25:22,  1.25s/it]

Loss so far is: nan


Epoch 3/10:  75%|███████▍  | 3300/4418 [1:08:44<23:14,  1.25s/it]

Loss so far is: nan


Epoch 3/10:  77%|███████▋  | 3400/4418 [1:10:51<21:22,  1.26s/it]

Loss so far is: nan


Epoch 3/10:  79%|███████▉  | 3500/4418 [1:12:57<19:41,  1.29s/it]

Loss so far is: nan


Epoch 3/10:  81%|████████▏ | 3600/4418 [1:15:03<17:08,  1.26s/it]

Loss so far is: nan


Epoch 3/10:  84%|████████▎ | 3700/4418 [1:17:09<15:04,  1.26s/it]

Loss so far is: nan


Epoch 3/10:  86%|████████▌ | 3800/4418 [1:19:17<13:02,  1.27s/it]

Loss so far is: nan


Epoch 3/10:  88%|████████▊ | 3900/4418 [1:21:24<10:55,  1.26s/it]

Loss so far is: nan


Epoch 3/10:  91%|█████████ | 4000/4418 [1:23:30<08:51,  1.27s/it]

Loss so far is: nan


Epoch 3/10:  93%|█████████▎| 4100/4418 [1:25:37<06:43,  1.27s/it]

Loss so far is: nan


Epoch 3/10:  95%|█████████▌| 4200/4418 [1:27:42<04:34,  1.26s/it]

Loss so far is: nan


Epoch 3/10:  97%|█████████▋| 4300/4418 [1:29:49<02:31,  1.28s/it]

Loss so far is: nan


Epoch 3/10: 100%|█████████▉| 4400/4418 [1:31:55<00:22,  1.26s/it]

Loss so far is: nan


                                                                 


Epoch: 2
Train Loss: nan
Val Loss: nan


Epoch 4/10:   2%|▏         | 100/4418 [02:02<1:28:24,  1.23s/it]

Loss so far is: nan


Epoch 4/10:   5%|▍         | 200/4418 [04:04<1:25:01,  1.21s/it]

Loss so far is: nan


Epoch 4/10:   7%|▋         | 300/4418 [06:06<1:23:15,  1.21s/it]

Loss so far is: nan


Epoch 4/10:   9%|▉         | 400/4418 [08:08<1:23:12,  1.24s/it]

Loss so far is: nan


Epoch 4/10:  11%|█▏        | 500/4418 [10:11<1:19:41,  1.22s/it]

Loss so far is: nan


Epoch 4/10:  14%|█▎        | 600/4418 [12:12<1:17:14,  1.21s/it]

Loss so far is: nan


Epoch 4/10:  16%|█▌        | 700/4418 [14:15<1:16:33,  1.24s/it]

Loss so far is: nan


Epoch 4/10:  18%|█▊        | 800/4418 [16:17<1:13:05,  1.21s/it]

Loss so far is: nan


Epoch 4/10:  20%|██        | 900/4418 [18:20<1:12:32,  1.24s/it]

Loss so far is: nan


Epoch 4/10:  23%|██▎       | 1000/4418 [20:23<1:09:52,  1.23s/it]

Loss so far is: nan


Epoch 4/10:  25%|██▍       | 1100/4418 [22:26<1:08:15,  1.23s/it]

Loss so far is: nan


Epoch 4/10:  27%|██▋       | 1200/4418 [24:29<1:07:04,  1.25s/it]

Loss so far is: nan


Epoch 4/10:  29%|██▉       | 1300/4418 [26:33<1:03:19,  1.22s/it]

Loss so far is: nan


Epoch 4/10:  32%|███▏      | 1400/4418 [28:36<1:03:09,  1.26s/it]

Loss so far is: nan


Epoch 4/10:  34%|███▍      | 1500/4418 [30:40<1:00:30,  1.24s/it]

Loss so far is: nan


Epoch 4/10:  36%|███▌      | 1600/4418 [32:44<58:36,  1.25s/it]  

Loss so far is: nan


Epoch 4/10:  38%|███▊      | 1700/4418 [34:49<56:52,  1.26s/it]

Loss so far is: nan


Epoch 4/10:  41%|████      | 1800/4418 [36:54<54:18,  1.24s/it]

Loss so far is: nan


Epoch 4/10:  43%|████▎     | 1900/4418 [38:58<52:03,  1.24s/it]

Loss so far is: nan


Epoch 4/10:  45%|████▌     | 2000/4418 [41:03<50:17,  1.25s/it]

Loss so far is: nan


Epoch 4/10:  48%|████▊     | 2100/4418 [43:09<48:11,  1.25s/it]

Loss so far is: nan


Epoch 4/10:  50%|████▉     | 2200/4418 [45:14<46:43,  1.26s/it]

Loss so far is: nan


Epoch 4/10:  52%|█████▏    | 2300/4418 [47:20<44:20,  1.26s/it]

Loss so far is: nan


Epoch 4/10:  54%|█████▍    | 2400/4418 [49:25<41:56,  1.25s/it]

Loss so far is: nan


Epoch 4/10:  57%|█████▋    | 2500/4418 [51:35<39:50,  1.25s/it]  

Loss so far is: nan


Epoch 4/10:  59%|█████▉    | 2600/4418 [53:41<38:22,  1.27s/it]

Loss so far is: nan


Epoch 4/10:  61%|██████    | 2700/4418 [55:47<36:08,  1.26s/it]

Loss so far is: nan


Epoch 4/10:  63%|██████▎   | 2800/4418 [57:54<33:40,  1.25s/it]

Loss so far is: nan


Epoch 4/10:  66%|██████▌   | 2900/4418 [59:59<31:40,  1.25s/it]

Loss so far is: nan


Epoch 4/10:  68%|██████▊   | 3000/4418 [1:02:05<29:16,  1.24s/it]

Loss so far is: nan


Epoch 4/10:  70%|███████   | 3100/4418 [1:04:11<27:27,  1.25s/it]

Loss so far is: nan


Epoch 4/10:  72%|███████▏  | 3200/4418 [1:06:28<25:52,  1.27s/it]  

Loss so far is: nan


Epoch 4/10:  75%|███████▍  | 3300/4418 [1:08:34<23:33,  1.26s/it]

Loss so far is: nan


Epoch 4/10:  77%|███████▋  | 3400/4418 [1:10:41<22:04,  1.30s/it]

Loss so far is: nan


Epoch 4/10:  79%|███████▉  | 3500/4418 [1:12:47<19:19,  1.26s/it]

Loss so far is: nan


Epoch 4/10:  81%|████████▏ | 3600/4418 [1:14:52<17:04,  1.25s/it]

Loss so far is: nan


Epoch 4/10:  84%|████████▎ | 3700/4418 [1:17:04<15:08,  1.27s/it]

Loss so far is: nan


Epoch 4/10:  86%|████████▌ | 3800/4418 [1:19:10<13:01,  1.26s/it]

Loss so far is: nan


Epoch 4/10:  88%|████████▊ | 3900/4418 [1:21:16<10:59,  1.27s/it]

Loss so far is: nan


Epoch 4/10:  91%|█████████ | 4000/4418 [1:23:23<08:45,  1.26s/it]

Loss so far is: nan


Epoch 4/10:  93%|█████████▎| 4100/4418 [1:25:29<06:49,  1.29s/it]

Loss so far is: nan


Epoch 4/10:  95%|█████████▌| 4200/4418 [1:27:35<04:29,  1.24s/it]

Loss so far is: nan


Epoch 4/10:  97%|█████████▋| 4300/4418 [1:29:41<02:28,  1.26s/it]

Loss so far is: nan


Epoch 4/10: 100%|█████████▉| 4400/4418 [1:31:47<00:22,  1.27s/it]

Loss so far is: nan


                                                                 


Epoch: 3
Train Loss: nan
Val Loss: nan


Epoch 5/10:   2%|▏         | 100/4418 [02:01<1:27:41,  1.22s/it]

Loss so far is: nan


Epoch 5/10:   5%|▍         | 200/4418 [04:03<1:25:05,  1.21s/it]

Loss so far is: nan


Epoch 5/10:   5%|▍         | 215/4418 [04:21<1:27:18,  1.25s/it]

## Post Training Metrics
Computes the test loss and common test metrics

This section does the following actions:
1. Loads the specified model
2. Runs through the test set and reports loss
3. Runs through the val set for BLEU and ROUGE metrics
4. Gives some images titles and saves them

In [1]:
import evaluate
from transformers import EvalPrediction
from tqdm import tqdm
import matplotlib.pyplot as plt

# Load model
MODEL_PATH = "./models/model_epoch_8.pt"
LOAD_MODEL = True

eval_model = None
if LOAD_MODEL:
       if "model" in locals():
              model.to('cpu')

       eval_model = BasicCaptionTransformer()
       eval_model.load_state_dict(torch.load(MODEL_PATH))
       eval_model = eval_model.to(device)

else:
       eval_model = model

# Eval setup
loss_function = nn.CrossEntropyLoss(ignore_index=-1)
test_loss = 0.0

predictions = []
labels = []

# Run through the test for test loss
with torch.no_grad():
       test_dataset_iter = tqdm(test_dataset_loader,  desc=f'Test Set Progress: ', leave=False)
       for data in test_dataset_iter:

              # get data from batch
              pixel_vals = data["pixel_values"].squeeze(1).to(device)
              labels = data["labels"].squeeze(1).to(device)

              # Predict captions
              outputs = eval_model(images=pixel_vals, captions=labels)
              test_loss += loss_function(outputs.permute(0,2,1), labels)
              print (test_loss)

print ("Test Loss: " + str(test_loss / len(test_dataset_loader)))


# Run through valadation set with best model
predictions = []
labels = []
with torch.no_grad():
       valid_dataset_iter = tqdm(valid_dataset_loader,  desc=f'Epoch {epoch + 1}/{num_epochs}', leave=False)
       for data in valid_dataset_iter:

              # get data from batch
              pixel_vals = data["pixel_values"].squeeze(1).to(device)
              labels = data["labels"].squeeze(1).to(device)
       
              # Predict captions
              outputs = eval_model(images=pixel_vals, captions=labels)

              # Format labels
              logits = outputs.detach().cpu()
              predictions.extend(logits.argmax(dim=-1).tolist())
              labels.extend(labels.tolist())
    

# Format predictions into Hugging Face class
eval_predictions = EvalPrediction(predictions=predictions, label_ids=labels)

predictions = eval_predictions.predictions
labels = eval_predictions.label_ids

# Tokenize predictions and reference captions
predictions_str = tokenizer.batch_decode(predictions, skip_special_tokens=True)
labels_str = tokenizer.batch_decode(labels, skip_special_tokens=True)


# Load test evaluators
rouge = evaluate.load("rouge")
bleu = evaluate.load("bleu")

# Compute and print Rouge-1, Rogue-2, RougeL
rouge_result = rouge.compute(predictions=predictions_str, references=labels_str)
rouge_result = {k: round(v * 100, 4) for k, v in rouge_result.items()}
print ("ROUGE Metrics: \nROUGE-1: " + rouge_result.get("rouge1", 0) + 
       "\nROUGE-2: " + rouge_result.get("rouge2", 0) + 
       "\nROUGE-L: " + rouge_result.get("rougeL", 0))


# Compute and print BLEU metrics
bleu_result = bleu.compute(predictions=predictions_str, references=labels_str)
bleu_score = round(bleu_result["bleu"] * 100, 4)
print ("BLEU Metrics: " + bleu_score)


# Get first 16 images and give them captiosn
for i in range(16):
       file_path = test_dataset[i]["image_file"]
       model_input =  test_dataset[i]["pixel_values"]

       # Make caption just beginning of sentence token (50256)
       # It's also the padding token, maybe an issue?
       sos_caption = torch.tensor([1, 2, 3])
       output = model(model_input, sos_caption)


       fig, ax = plt.subplot_mosaic([
       ['hopper', 'mri']
       ], figsize=(7, 3.5))

SyntaxError: invalid syntax (131980854.py, line 98)