## Before Running:
Please Install all from the requirements.txt (pip install -r requirements.txt).

## Set Hyper Parameters

In [2]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder_decoder_layers = 3
encoder_decoder_heads = 8
embedded_dim = 768 # Don't change
max_length = 256*2*2*2
coco_dataset_ratio = 50
coco_dataset_dir = "./coco"
vocab_size = 50257
batch_size = 64
num_epochs = 10
learning_rate = 1e-2 # was 1e-3
patience = 3
weight_decay = 1e-5
preprocess_swin_model = "microsoft/swin-tiny-patch4-window7-224"
encoder_model = "microsoft/swin-base-patch4-window7-224-in22k"
decoder_model = "gpt2"

## Downloading and Format datasets
This will take some time to finishing running the first time. It took me roughly 40 minutes.

This section does the following actions:
1. Downloads the Dataset
2. Keeps images with only 3 or 4 dim
3. Transforms the dataset 
4. Turns the data set into data loaders


In [4]:
import numpy as np
from datasets import load_dataset
from transformers import ViTImageProcessor, GPT2TokenizerFast, AutoImageProcessor, SwinModel
from torch.utils.data import DataLoader, Dataset
import torch
import os

# Download the train, val and test splits of the COCO dataset
train_ds = load_dataset("HuggingFaceM4/COCO", split=f"train[:{coco_dataset_ratio}%]", cache_dir=coco_dataset_dir)
valid_ds = load_dataset("HuggingFaceM4/COCO", split=f"validation[:{coco_dataset_ratio}%]", cache_dir=coco_dataset_dir)
test_ds = load_dataset("HuggingFaceM4/COCO", split="test", cache_dir=coco_dataset_dir)

# Filter all non 3 or 4 dim images out
# Can change num_proc, but might be errors with np
train_ds = train_ds.filter(lambda item: np.array(item["image"]).ndim in [3, 4], num_proc=1)
valid_ds = valid_ds.filter(lambda item: np.array(item["image"]).ndim in [3, 4], num_proc=1)
test_ds = test_ds.filter(lambda item: np.array(item["image"]).ndim in [3, 4], num_proc=1)

# Does pre processing on the data set
# This includes pre-trained ViTimage feature extraction and tokenizing captions
tokenizer = GPT2TokenizerFast.from_pretrained(decoder_model)
tokenizer.pad_token = tokenizer.eos_token
image_processor = ViTImageProcessor.from_pretrained(encoder_model)
image_processor_swin = SwinModel.from_pretrained(preprocess_swin_model).to(device)

def preprocess(items):
    # Image pre-processing
    # use ViT and SWIN since no back prop
    pixel_values = image_processor(items["image"], return_tensors="pt").pixel_values.to(device)
    with torch.no_grad():
        pixel_values = image_processor_swin (pixel_values).last_hidden_state
    pixel_values = pixel_values.to('cpu')

    # tokenize
    targets = tokenizer(items["sentences"]['raw'],
                        max_length=max_length, padding="max_length", truncation=True, return_tensors="pt")
    
    # Keep image file for easy examples later
    img_file = items['filepath']
    return {'pixel_values': pixel_values, 'labels': targets["input_ids"], 'image_file': img_file}


TRAIN_SET_PATH = './preprocessed/train_set.pt'
TEST_SET_PATH = './preprocessed/val_set.pt'
VALID_SET_PATH = './preprocessed/test_set.pt'

# Pre process train dataset if it doesn't exist
train_dataset = None
if os.path.isfile(TRAIN_SET_PATH):
    train_dataset = torch.load(TRAIN_SET_PATH)

else:
    train_dataset = train_ds.map(preprocess)
    torch.save(train_dataset, TRAIN_SET_PATH)
    
# Pre process val dataset if it doesn't exist
valid_dataset = None
if os.path.isfile(VALID_SET_PATH):
    valid_dataset = torch.load(VALID_SET_PATH)

else:
    valid_dataset = valid_ds.map(preprocess)
    torch.save(valid_dataset, VALID_SET_PATH)

# Pre process test dataset if it doesn't exist
test_dataset = None
if os.path.isfile(TEST_SET_PATH):
    test_dataset = torch.load(TEST_SET_PATH)

else:
    test_dataset = test_ds.map(preprocess)
    torch.save(test_dataset, TEST_SET_PATH)



# Turns the dataset into a torch DataLoader
def collate_fn(batch):
    return {
        'pixel_values': torch.stack([torch.tensor(x['pixel_values']) for x in batch]),
        'labels': torch.stack([torch.tensor(x['labels']) for x in batch]),
        'image_file': [x["image_file"] for x in batch]
    }

train_dataset_loader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
valid_dataset_loader = DataLoader(valid_dataset, collate_fn=collate_fn, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)
test_dataset_loader = DataLoader(test_dataset, collate_fn=collate_fn, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

## Define the Model
Creates the PureT model from the paper

This section does the following actions:
1. Creates the SWIN Transformer used by PureT
2. Creates the PureT encoder
3. Creates the PureT decoder
4. Creates the PureT model

Download the pre-trained SwinT weights from here https://drive.google.com/drive/folders/1HBw5NGGw8DjkyNurksCP5v8a5f0FG7zU and put them in this folder before running

In [5]:
import torch.nn as nn
import torch.nn.functional as F
import torch
from transformers import SwinModel

class BasicCaptionTransformer(torch.nn.Module):
    def __init__(self):
        super(BasicCaptionTransformer, self).__init__()

        # Image processing pre-done with ViTImageProcessor and Swin in the Downloading and Format datasets step
        #self.swin = SwinModel.from_pretrained(preprocess_swin_model)

        # build encoder
        encoder_layer = nn.TransformerEncoderLayer(d_model=embedded_dim, nhead=encoder_decoder_heads, batch_first=True)
        self.encoders = nn.TransformerEncoder(encoder_layer, encoder_decoder_layers)

        # embeddings
        self.embeddings = nn.Embedding(vocab_size, embedded_dim)

        # build decoder
        decoder_layer = nn.TransformerDecoderLayer(d_model=embedded_dim, nhead=encoder_decoder_heads, batch_first=True)
        self.decoders = nn.TransformerDecoder(decoder_layer, encoder_decoder_layers)

        # final
        self.linear = nn.Linear(in_features=768, out_features=vocab_size)


    def forward(self, images, captions):
        x = images
        #with torch.no_grad():
        #    x = self.swin(x)
        x = self.encoders(x)
        
        captions = self.embeddings(captions)
        x = self.decoders(tgt=captions.to(x.dtype), memory=x)
        return self.linear(x)
    


## Training loop
Trains the model and saves the best (lowest val error) and last model

This section does the following actions:
1. Creates the Basic transformer model
2. Sets up optimizer, scheduler, counter for training
3. Trains for num_epochs epochs
2. Each Epoch has valadation accuracy calculated
3. Save the model with the best valadation accuracy TODO
4. Save the model when the max number of epochs has been reached TODO

In [5]:
from tqdm import tqdm

# Loop setup
model = BasicCaptionTransformer()
model = model.to(device)

loss_function = nn.CrossEntropyLoss(ignore_index=-1)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)


# Train
predictions_per_batch = batch_size * max_length
stop_counter = 0
train_losses = []
val_losses = []
best_val_loss = float('inf')

train_len = len(train_dataset_loader)
val_len = len(valid_dataset_loader)


for epoch in range(num_epochs):
    model.train()

    # Loop through training data loader batches
    train_loss = 0.0
    train_dataloader_iter = tqdm(train_dataset_loader, desc=f'Epoch {epoch + 1}/{num_epochs}', leave=False)
    for i, data in enumerate(train_dataloader_iter):
        
        # Get values from data loader
        pixel_vals = data["pixel_values"].squeeze(1).to(device)
        captions = data["labels"].squeeze(1).to(device)

        optimizer.zero_grad()
        outputs = model(images=pixel_vals, captions=captions)

        loss = loss_function(outputs.permute(0,2,1), captions)
        loss.backward()
        optimizer.step()

        # save loss
        train_loss += loss.item()
        if i % 100 == 99:
            print ("Loss so far is: " + str (train_loss / i))
            #print ("tensor: ", outputs)


    # Validation
    val_loss = 0.0
    valid_dataset_iter = tqdm(valid_dataset_loader,  desc=f'Epoch {epoch + 1}/{num_epochs}', leave=False)
    with torch.no_grad():
        for i, data in enumerate(valid_dataset_iter):
            
            # Get values from data loader
            pixel_vals = data["pixel_values"].squeeze(1).to(device)
            captions = data["labels"].squeeze(1).to(device)

            outputs = model(images=pixel_vals, captions=captions)
            loss = loss_function(outputs.permute(0,2,1), captions)

            # save loss
            val_loss += loss.item()


    # Save model
    torch.save(model.state_dict(), f'./models/model_epoch_{epoch+1}.pt')

    train_losses.append(train_loss / train_len)
    val_losses.append(val_loss / val_len)

    # Print losses
    # divide by length of data loader... 
    print("\nEpoch: " + str(epoch) + 
        "\nTrain Loss: " + str(train_loss / train_len) +
        "\nVal Loss: " + str(val_loss / val_len))

print (train_losses)
print (val_losses)


Epoch 1/10:   2%|▏         | 100/4418 [02:29<1:37:03,  1.35s/it]

Loss so far is: 3.246267870219067


Epoch 1/10:   5%|▍         | 200/4418 [04:42<1:35:44,  1.36s/it]

Loss so far is: 3.0962029869232945


Epoch 1/10:   7%|▋         | 300/4418 [06:55<1:34:06,  1.37s/it]

Loss so far is: 3.0494830353204225


Epoch 1/10:   9%|▉         | 400/4418 [09:06<1:26:07,  1.29s/it]

Loss so far is: 3.0180560484864656


Epoch 1/10:  11%|█▏        | 500/4418 [11:14<1:23:56,  1.29s/it]

Loss so far is: 3.0021382305091753


Epoch 1/10:  14%|█▎        | 600/4418 [13:23<1:21:12,  1.28s/it]

Loss so far is: 2.9899919745520083


Epoch 1/10:  16%|█▌        | 700/4418 [15:30<1:19:42,  1.29s/it]

Loss so far is: 2.979787701359805


Epoch 1/10:  18%|█▊        | 800/4418 [17:38<1:16:23,  1.27s/it]

Loss so far is: 2.972508473748409


Epoch 1/10:  20%|██        | 900/4418 [19:45<1:14:25,  1.27s/it]

Loss so far is: 2.966315914446838


Epoch 1/10:  23%|██▎       | 1000/4418 [21:53<1:12:37,  1.27s/it]

Loss so far is: 2.960764994731059


Epoch 1/10:  25%|██▍       | 1100/4418 [24:01<1:10:41,  1.28s/it]

Loss so far is: 2.9568673440605213


Epoch 1/10:  27%|██▋       | 1200/4418 [26:09<1:09:09,  1.29s/it]

Loss so far is: 2.9530338728000363


Epoch 1/10:  29%|██▉       | 1300/4418 [28:17<1:07:00,  1.29s/it]

Loss so far is: 2.948928836495074


Epoch 1/10:  32%|███▏      | 1400/4418 [30:26<1:04:44,  1.29s/it]

Loss so far is: 2.9457534688809157


Epoch 1/10:  34%|███▍      | 1500/4418 [32:35<1:02:49,  1.29s/it]

Loss so far is: 2.943090080499172


Epoch 1/10:  36%|███▌      | 1600/4418 [34:45<1:00:32,  1.29s/it]

Loss so far is: 2.941162039295743


Epoch 1/10:  38%|███▊      | 1700/4418 [36:54<59:08,  1.31s/it]  

Loss so far is: 2.940208911895752


Epoch 1/10:  41%|████      | 1800/4418 [39:04<56:32,  1.30s/it]

Loss so far is: 2.9384285588606387


Epoch 1/10:  43%|████▎     | 1900/4418 [41:14<54:32,  1.30s/it]

Loss so far is: 2.9367279799251946


Epoch 1/10:  45%|████▌     | 2000/4418 [43:24<52:51,  1.31s/it]

Loss so far is: 2.9351540810469094


Epoch 1/10:  48%|████▊     | 2100/4418 [45:34<50:21,  1.30s/it]

Loss so far is: 2.9337767230266047


Epoch 1/10:  50%|████▉     | 2200/4418 [47:45<48:00,  1.30s/it]

Loss so far is: 2.9326644720083586


Epoch 1/10:  52%|█████▏    | 2300/4418 [49:55<46:27,  1.32s/it]

Loss so far is: 2.932642977462742


Epoch 1/10:  54%|█████▍    | 2400/4418 [52:05<43:55,  1.31s/it]

Loss so far is: 2.9319104932655438


Epoch 1/10:  57%|█████▋    | 2500/4418 [54:16<41:25,  1.30s/it]

Loss so far is: 2.931264463640681


Epoch 1/10:  59%|█████▉    | 2600/4418 [56:27<39:43,  1.31s/it]

Loss so far is: 2.930085625268717


Epoch 1/10:  61%|██████    | 2700/4418 [58:38<37:15,  1.30s/it]

Loss so far is: 2.929685355469491


Epoch 1/10:  63%|██████▎   | 2800/4418 [1:00:48<35:15,  1.31s/it]

Loss so far is: 2.9292165134241857


Epoch 1/10:  66%|██████▌   | 2900/4418 [1:02:59<33:08,  1.31s/it]

Loss so far is: 2.9280968900794364


Epoch 1/10:  68%|██████▊   | 3000/4418 [1:05:09<30:52,  1.31s/it]

Loss so far is: 2.926806284769967


Epoch 1/10:  70%|███████   | 3100/4418 [1:07:34<28:42,  1.31s/it]  

Loss so far is: 2.926285882810117


Epoch 1/10:  72%|███████▏  | 3200/4418 [1:09:44<26:32,  1.31s/it]

Loss so far is: 2.925504752865953


Epoch 1/10:  75%|███████▍  | 3300/4418 [1:11:55<24:25,  1.31s/it]

Loss so far is: 2.9246527859571305


Epoch 1/10:  77%|███████▋  | 3400/4418 [1:14:05<21:56,  1.29s/it]

Loss so far is: 2.9242315274121027


Epoch 1/10:  79%|███████▉  | 3500/4418 [1:16:16<19:50,  1.30s/it]

Loss so far is: 2.9236253104437484


Epoch 1/10:  81%|████████▏ | 3600/4418 [1:18:26<17:37,  1.29s/it]

Loss so far is: 2.9224916320603636


Epoch 1/10:  84%|████████▎ | 3700/4418 [1:20:36<15:31,  1.30s/it]

Loss so far is: 2.921131334230042


Epoch 1/10:  86%|████████▌ | 3800/4418 [1:22:46<13:23,  1.30s/it]

Loss so far is: 2.920829012005478


Epoch 1/10:  88%|████████▊ | 3900/4418 [1:24:57<11:22,  1.32s/it]

Loss so far is: 2.920508117913038


Epoch 1/10:  91%|█████████ | 4000/4418 [1:27:08<09:06,  1.31s/it]

Loss so far is: 2.9205526217546245


Epoch 1/10:  93%|█████████▎| 4100/4418 [1:29:18<06:53,  1.30s/it]

Loss so far is: 2.920360750906582


Epoch 1/10:  95%|█████████▌| 4200/4418 [1:31:29<04:45,  1.31s/it]

Loss so far is: 2.9201722737294373


Epoch 1/10:  97%|█████████▋| 4300/4418 [1:33:39<02:35,  1.32s/it]

Loss so far is: 2.9197919944963724


Epoch 1/10: 100%|█████████▉| 4400/4418 [1:35:50<00:23,  1.31s/it]

Loss so far is: 2.919723436436238


                                                                 


Epoch: 0
Train Loss: 2.9189488146843656
Val Loss: 2.9067098959898336


Epoch 2/10:   2%|▏         | 100/4418 [02:05<1:30:36,  1.26s/it]

Loss so far is: 2.933494970051929


Epoch 2/10:   5%|▍         | 200/4418 [04:11<1:28:33,  1.26s/it]

Loss so far is: 2.921930007599107


Epoch 2/10:   7%|▋         | 300/4418 [06:18<1:27:06,  1.27s/it]

Loss so far is: 2.916942979181092


Epoch 2/10:   9%|▉         | 400/4418 [08:25<1:24:55,  1.27s/it]

Loss so far is: 2.9153878527476373


Epoch 2/10:  11%|█▏        | 500/4418 [10:32<1:22:38,  1.27s/it]

Loss so far is: 2.914412021636963


Epoch 2/10:  14%|█▎        | 600/4418 [12:38<1:20:53,  1.27s/it]

Loss so far is: 2.9107060133913323


Epoch 2/10:  16%|█▌        | 700/4418 [14:46<1:18:03,  1.26s/it]

Loss so far is: 2.905799565908735


Epoch 2/10:  18%|█▊        | 800/4418 [16:53<1:17:21,  1.28s/it]

Loss so far is: 2.9013913513274305


Epoch 2/10:  20%|██        | 900/4418 [19:00<1:14:38,  1.27s/it]

Loss so far is: 2.896780491936061


Epoch 2/10:  23%|██▎       | 1000/4418 [21:07<1:12:34,  1.27s/it]

Loss so far is: 2.8931444077878385


Epoch 2/10:  25%|██▍       | 1100/4418 [23:15<1:11:05,  1.29s/it]

Loss so far is: 2.890477360542304


Epoch 2/10:  27%|██▋       | 1200/4418 [25:23<1:08:49,  1.28s/it]

Loss so far is: 2.8925099420587257


Epoch 2/10:  29%|██▉       | 1300/4418 [27:31<1:06:43,  1.28s/it]

Loss so far is: 2.8923405652049876


Epoch 2/10:  32%|███▏      | 1400/4418 [29:39<1:04:09,  1.28s/it]

Loss so far is: 2.8928085730023008


Epoch 2/10:  34%|███▍      | 1500/4418 [31:47<1:02:08,  1.28s/it]

Loss so far is: 2.894107369123259


Epoch 2/10:  36%|███▌      | 1600/4418 [33:56<1:00:28,  1.29s/it]

Loss so far is: 2.894778394639455


Epoch 2/10:  38%|███▊      | 1700/4418 [36:05<58:00,  1.28s/it]  

Loss so far is: 2.8951566535069566


Epoch 2/10:  41%|████      | 1800/4418 [38:14<55:41,  1.28s/it]

Loss so far is: 2.895473777485795


Epoch 2/10:  43%|████▎     | 1900/4418 [40:23<54:23,  1.30s/it]

Loss so far is: 2.8965829038946422


Epoch 2/10:  45%|████▌     | 2000/4418 [42:32<52:29,  1.30s/it]

Loss so far is: 2.896658944868934


Epoch 2/10:  48%|████▊     | 2100/4418 [44:42<50:16,  1.30s/it]

Loss so far is: 2.896775919008732


Epoch 2/10:  50%|████▉     | 2200/4418 [46:57<47:54,  1.30s/it]  

Loss so far is: 2.8964492426833655


Epoch 2/10:  52%|█████▏    | 2300/4418 [49:21<54:21,  1.54s/it]  

Loss so far is: 2.8960590024469415


Epoch 2/10:  54%|█████▍    | 2400/4418 [51:37<45:03,  1.34s/it]

Loss so far is: 2.8961651152300307


Epoch 2/10:  57%|█████▋    | 2500/4418 [53:48<41:44,  1.31s/it]

Loss so far is: 2.896670117670176


Epoch 2/10:  59%|█████▉    | 2600/4418 [55:59<39:46,  1.31s/it]

Loss so far is: 2.8972747511201384


Epoch 2/10:  61%|██████    | 2700/4418 [58:10<37:19,  1.30s/it]

Loss so far is: nan


Epoch 2/10:  63%|██████▎   | 2779/4418 [59:53<36:39,  1.34s/it]Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x0000015C6BD43AD0>>
Traceback (most recent call last):
  File "c:\Programming\Python3-12-3\Lib\site-packages\ipykernel\ipkernel.py", line 790, in _clean_thread_parent_frames
    active_threads = {thread.ident for thread in threading.enumerate()}
                                                 ^^^^^^^^^^^^^^^^^^^^^
  File "c:\Programming\Python3-12-3\Lib\threading.py", line 1533, in enumerate
    def enumerate():
    
KeyboardInterrupt: 
Epoch 2/10:  63%|██████▎   | 2787/4418 [1:00:04<36:33,  1.35s/it]

## Post Training Metrics
Computes the test loss and common test metrics

This section does the following actions:
1. Loads the specified model
2. Runs through the test set and reports loss
3. Runs through the val set for BLEU and ROUGE metrics
4. Gives some images titles and saves them

In [1]:
import evaluate
from transformers import EvalPrediction
from tqdm import tqdm
import matplotlib.pyplot as plt

# Load model
MODEL_PATH = "./models/model_epoch_8.pt"
LOAD_MODEL = True

eval_model = None
if LOAD_MODEL:
       if "model" in locals():
              model.to('cpu')

       eval_model = BasicCaptionTransformer()
       eval_model.load_state_dict(torch.load(MODEL_PATH))
       eval_model = eval_model.to(device)

else:
       eval_model = model

# Eval setup
loss_function = nn.CrossEntropyLoss(ignore_index=-1)
test_loss = 0.0

predictions = []
labels = []

# Run through the test for test loss
with torch.no_grad():
       test_dataset_iter = tqdm(test_dataset_loader,  desc=f'Test Set Progress: ', leave=False)
       for data in test_dataset_iter:

              # get data from batch
              pixel_vals = data["pixel_values"].squeeze(1).to(device)
              labels = data["labels"].squeeze(1).to(device)

              # Predict captions
              outputs = eval_model(images=pixel_vals, captions=labels)
              test_loss += loss_function(outputs.permute(0,2,1), labels)
              print (test_loss)

print ("Test Loss: " + str(test_loss / len(test_dataset_loader)))


# Run through valadation set with best model
predictions = []
labels = []
with torch.no_grad():
       valid_dataset_iter = tqdm(valid_dataset_loader,  desc=f'Epoch {epoch + 1}/{num_epochs}', leave=False)
       for data in valid_dataset_iter:

              # get data from batch
              pixel_vals = data["pixel_values"].squeeze(1).to(device)
              labels = data["labels"].squeeze(1).to(device)
       
              # Predict captions
              outputs = eval_model(images=pixel_vals, captions=labels)

              # Format labels
              logits = outputs.detach().cpu()
              predictions.extend(logits.argmax(dim=-1).tolist())
              labels.extend(labels.tolist())
    

# Format predictions into Hugging Face class
eval_predictions = EvalPrediction(predictions=predictions, label_ids=labels)

predictions = eval_predictions.predictions
labels = eval_predictions.label_ids

# Tokenize predictions and reference captions
predictions_str = tokenizer.batch_decode(predictions, skip_special_tokens=True)
labels_str = tokenizer.batch_decode(labels, skip_special_tokens=True)


# Load test evaluators
rouge = evaluate.load("rouge")
bleu = evaluate.load("bleu")

# Compute and print Rouge-1, Rogue-2, RougeL
rouge_result = rouge.compute(predictions=predictions_str, references=labels_str)
rouge_result = {k: round(v * 100, 4) for k, v in rouge_result.items()}
print ("ROUGE Metrics: \nROUGE-1: " + rouge_result.get("rouge1", 0) + 
       "\nROUGE-2: " + rouge_result.get("rouge2", 0) + 
       "\nROUGE-L: " + rouge_result.get("rougeL", 0))


# Compute and print BLEU metrics
bleu_result = bleu.compute(predictions=predictions_str, references=labels_str)
bleu_score = round(bleu_result["bleu"] * 100, 4)
print ("BLEU Metrics: " + bleu_score)


# Get first 16 images and give them captiosn
for i in range(16):
       file_path = test_dataset[i]["image_file"]
       model_input =  test_dataset[i]["pixel_values"]

       fig, ax = plt.subplot_mosaic([
       ['hopper', 'mri']
       ], figsize=(7, 3.5))

SyntaxError: invalid syntax (131980854.py, line 98)