## Before Running:
Please Install all from the requirements.txt (pip install -r requirements.txt).

## Set Hyper Parameters

In [4]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder_decoder_layers = 3
encoder_decoder_heads = 8
embedded_dim = 768 # Don't change
max_length = 32 
coco_dataset_ratio = 50
coco_dataset_dir = "./coco"
vocab_size = 50257 # Don't change
batch_size = 64
num_epochs = 10
learning_rate = 1e-4 # was 1e-3, not the problem
patience = 3
weight_decay = 1e-5 # was 1e-5, not the problem
preprocess_swin_model = "microsoft/swin-tiny-patch4-window7-224"
encoder_model = "microsoft/swin-base-patch4-window7-224-in22k"
decoder_model = "gpt2"

## Downloading and Format datasets
This will take some time to finishing running the first time. It took me roughly 40 minutes.

This section does the following actions:
1. Downloads the Dataset
2. Keeps images with only 3 or 4 dim
3. Transforms the dataset 
4. Turns the data set into data loaders


In [2]:
import numpy as np
from datasets import load_dataset
from transformers import ViTImageProcessor, GPT2TokenizerFast, AutoImageProcessor, SwinModel
from torch.utils.data import DataLoader, Dataset
import torch
import os

# Download the train, val and test splits of the COCO dataset
train_ds = load_dataset("HuggingFaceM4/COCO", split=f"train[:{coco_dataset_ratio}%]", cache_dir=coco_dataset_dir)
valid_ds = load_dataset("HuggingFaceM4/COCO", split=f"validation[:{coco_dataset_ratio}%]", cache_dir=coco_dataset_dir)
test_ds = load_dataset("HuggingFaceM4/COCO", split="test", cache_dir=coco_dataset_dir)

# Filter all non 3 or 4 dim images out
# Can change num_proc, but might be errors with np
train_ds = train_ds.filter(lambda item: np.array(item["image"]).ndim in [3, 4], num_proc=1)
valid_ds = valid_ds.filter(lambda item: np.array(item["image"]).ndim in [3, 4], num_proc=1)
test_ds = test_ds.filter(lambda item: np.array(item["image"]).ndim in [3, 4], num_proc=1)

# Does pre processing on the data set
# This includes pre-trained ViTimage feature extraction and tokenizing captions
tokenizer = GPT2TokenizerFast.from_pretrained(decoder_model)
tokenizer.pad_token = tokenizer.eos_token
image_processor = ViTImageProcessor.from_pretrained(encoder_model)
image_processor_swin = SwinModel.from_pretrained(preprocess_swin_model).to(device)

def preprocess(items):
    # Image pre-processing
    # use ViT and SWIN since no back prop
    pixel_values = image_processor(items["image"], return_tensors="pt").pixel_values.to(device)
    with torch.no_grad():
        pixel_values = image_processor_swin (pixel_values).last_hidden_state
    pixel_values = pixel_values.to('cpu')

    # tokenize
    targets = tokenizer(items["sentences"]['raw'],
                        max_length=max_length, padding="max_length", truncation=True, return_tensors="pt")
    
    # Keep image file for easy examples later
    img_file = items['filepath']
    return {'pixel_values': pixel_values, 'labels': targets["input_ids"], 'image_file': img_file}


TRAIN_SET_PATH = './preprocessed/train_set.pt'
TEST_SET_PATH = './preprocessed/val_set.pt'
VALID_SET_PATH = './preprocessed/test_set.pt'

# Pre process train dataset if it doesn't exist
train_dataset = None
if os.path.isfile(TRAIN_SET_PATH):
    train_dataset = torch.load(TRAIN_SET_PATH)

else:
    train_dataset = train_ds.map(preprocess)
    torch.save(train_dataset, TRAIN_SET_PATH)
    
# Pre process val dataset if it doesn't exist
valid_dataset = None
if os.path.isfile(VALID_SET_PATH):
    valid_dataset = torch.load(VALID_SET_PATH)

else:
    valid_dataset = valid_ds.map(preprocess)
    torch.save(valid_dataset, VALID_SET_PATH)

# Pre process test dataset if it doesn't exist
test_dataset = None
if os.path.isfile(TEST_SET_PATH):
    test_dataset = torch.load(TEST_SET_PATH)

else:
    test_dataset = test_ds.map(preprocess)
    torch.save(test_dataset, TEST_SET_PATH)



# Turns the dataset into a torch DataLoader
def collate_fn(batch):
    return {
        'pixel_values': torch.stack([torch.tensor(x['pixel_values']) for x in batch]),
        'labels': torch.stack([torch.tensor(x['labels']) for x in batch]),
        'image_file': [x["image_file"] for x in batch]
    }

train_dataset_loader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=batch_size, shuffle=True)
valid_dataset_loader = DataLoader(valid_dataset, collate_fn=collate_fn, batch_size=batch_size, shuffle=False)
test_dataset_loader = DataLoader(test_dataset, collate_fn=collate_fn, batch_size=batch_size, shuffle=False)

  from .autonotebook import tqdm as notebook_tqdm
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


## Define the Model
Creates the PureT model from the paper

This section does the following actions:
1. Creates the SWIN Transformer used by PureT
2. Creates the PureT encoder
3. Creates the PureT decoder
4. Creates the PureT model

Download the pre-trained SwinT weights from here https://drive.google.com/drive/folders/1HBw5NGGw8DjkyNurksCP5v8a5f0FG7zU and put them in this folder before running

In [3]:
import torch.nn as nn
import torch.nn.functional as F
import torch
from transformers import SwinModel

class BasicCaptionTransformer(torch.nn.Module):
    def __init__(self):
        super(BasicCaptionTransformer, self).__init__()

        # Image processing pre-done with ViTImageProcessor and Swin in the Downloading and Format datasets step
        #self.swin = SwinModel.from_pretrained(preprocess_swin_model)

        # build encoder
        encoder_layer = nn.TransformerEncoderLayer(d_model=embedded_dim, nhead=encoder_decoder_heads, batch_first=True)
        self.encoders = nn.TransformerEncoder(encoder_layer, encoder_decoder_layers)

        # embeddings
        self.embeddings = nn.Embedding(vocab_size, embedded_dim)

        # build decoder
        decoder_layer = nn.TransformerDecoderLayer(d_model=embedded_dim, nhead=encoder_decoder_heads, batch_first=True)
        self.decoders = nn.TransformerDecoder(decoder_layer, encoder_decoder_layers)

        # final
        self.linear = nn.Linear(in_features=768, out_features=vocab_size)


    def forward(self, images, captions):
        x = images
        #with torch.no_grad():
        #    x = self.swin(x)
        x = self.encoders(x)
        
        captions = self.embeddings(captions)
        x = self.decoders(tgt=captions.to(x.dtype), memory=x)
        return self.linear(x)
    


## Training loop
Trains the model and saves the best (lowest val error) and last model

This section does the following actions:
1. Creates the Basic transformer model
2. Sets up optimizer, scheduler, counter for training
3. Trains for num_epochs epochs
2. Each Epoch has valadation accuracy calculated
3. Save the model with the best valadation accuracy TODO
4. Save the model when the max number of epochs has been reached TODO

In [5]:
from tqdm import tqdm

# Loop setup
model = BasicCaptionTransformer()
model = model.to(device)

loss_function = nn.CrossEntropyLoss(ignore_index=-1)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)


# Train
predictions_per_batch = batch_size * max_length
stop_counter = 0
train_losses = []
val_losses = []
best_val_loss = float('inf')

train_len = len(train_dataset_loader)
val_len = len(valid_dataset_loader)


for epoch in range(num_epochs):
    model.train()

    # Loop through training data loader batches
    train_loss = 0.0
    train_dataloader_iter = tqdm(train_dataset_loader, desc=f'Epoch {epoch + 1}/{num_epochs}', leave=False)
    for i, data in enumerate(train_dataloader_iter):
        
        # Get values from data loader
        pixel_vals = data["pixel_values"].squeeze(1).to(device)
        captions = data["labels"].squeeze(1).to(device)

        optimizer.zero_grad()
        outputs = model(images=pixel_vals, captions=captions)

        loss = loss_function(outputs.permute(0,2,1), captions)
        loss.backward()
        optimizer.step()

        # save loss
        train_loss += loss.item()
        if i % 100 == 99:
            print ("Loss so far is: " + str (train_loss / i))
            #print ("tensor: ", outputs)


    # Validation
    val_loss = 0.0
    valid_dataset_iter = tqdm(valid_dataset_loader,  desc=f'Epoch {epoch + 1}/{num_epochs}', leave=False)
    with torch.no_grad():
        for i, data in enumerate(valid_dataset_iter):
            
            # Get values from data loader
            pixel_vals = data["pixel_values"].squeeze(1).to(device)
            captions = data["labels"].squeeze(1).to(device)

            outputs = model(images=pixel_vals, captions=captions)
            loss = loss_function(outputs.permute(0,2,1), captions)

            # save loss
            val_loss += loss.item()


    # Save model
    torch.save(model.state_dict(), f'./models/model_epoch_{epoch+1}.pt')

    train_losses.append(train_loss / train_len)
    val_losses.append(val_loss / val_len)

    # Print losses
    # divide by length of data loader... 
    print("\nEpoch: " + str(epoch) + 
        "\nTrain Loss: " + str(train_loss / train_len) +
        "\nVal Loss: " + str(val_loss / val_len))

print (train_losses)
print (val_losses)


  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
Epoch 1/10:   2%|▏         | 100/4418 [02:11<1:31:11,  1.27s/it]

Loss so far is: 3.5211477351911142


Epoch 1/10:   5%|▍         | 200/4418 [04:16<1:27:34,  1.25s/it]

Loss so far is: 3.291331408610895


Epoch 1/10:   7%|▋         | 300/4418 [06:19<1:24:05,  1.23s/it]

Loss so far is: 3.2253024968813895


Epoch 1/10:   9%|▉         | 400/4418 [08:22<1:21:42,  1.22s/it]

Loss so far is: 3.1825974674750688


Epoch 1/10:  11%|█▏        | 500/4418 [10:33<1:34:15,  1.44s/it]

Loss so far is: 3.1494571727836775


Epoch 1/10:  14%|█▎        | 600/4418 [12:56<1:25:16,  1.34s/it]

Loss so far is: 3.1233302512033556


Epoch 1/10:  16%|█▌        | 700/4418 [15:20<1:26:00,  1.39s/it]

Loss so far is: 3.1015545070086086


                                                                

KeyboardInterrupt: 

## Post Training Metrics
Computes the test loss and common test metrics

This section does the following actions:
1. Loads the specified model
2. Runs through the test set and reports loss
3. Runs through the val set for BLEU and ROUGE metrics
4. Gives some images titles and saves them

In [1]:
import evaluate
from transformers import EvalPrediction
from tqdm import tqdm
import matplotlib.pyplot as plt

# Load model
MODEL_PATH = "./models/model_epoch_8.pt"
LOAD_MODEL = True

eval_model = None
if LOAD_MODEL:
       if "model" in locals():
              model.to('cpu')

       eval_model = BasicCaptionTransformer()
       eval_model.load_state_dict(torch.load(MODEL_PATH))
       eval_model = eval_model.to(device)

else:
       eval_model = model

# Eval setup
loss_function = nn.CrossEntropyLoss(ignore_index=-1)
test_loss = 0.0

predictions = []
labels = []

# Run through the test for test loss
with torch.no_grad():
       test_dataset_iter = tqdm(test_dataset_loader,  desc=f'Test Set Progress: ', leave=False)
       for data in test_dataset_iter:

              # get data from batch
              pixel_vals = data["pixel_values"].squeeze(1).to(device)
              labels = data["labels"].squeeze(1).to(device)

              # Predict captions
              outputs = eval_model(images=pixel_vals, captions=labels)
              test_loss += loss_function(outputs.permute(0,2,1), labels)
              print (test_loss)

print ("Test Loss: " + str(test_loss / len(test_dataset_loader)))


# Run through valadation set with best model
predictions = []
labels = []
with torch.no_grad():
       valid_dataset_iter = tqdm(valid_dataset_loader,  desc=f'Epoch {epoch + 1}/{num_epochs}', leave=False)
       for data in valid_dataset_iter:

              # get data from batch
              pixel_vals = data["pixel_values"].squeeze(1).to(device)
              labels = data["labels"].squeeze(1).to(device)
       
              # Predict captions
              outputs = eval_model(images=pixel_vals, captions=labels)

              # Format labels
              logits = outputs.detach().cpu()
              predictions.extend(logits.argmax(dim=-1).tolist())
              labels.extend(labels.tolist())
    

# Format predictions into Hugging Face class
eval_predictions = EvalPrediction(predictions=predictions, label_ids=labels)

predictions = eval_predictions.predictions
labels = eval_predictions.label_ids

# Tokenize predictions and reference captions
predictions_str = tokenizer.batch_decode(predictions, skip_special_tokens=True)
labels_str = tokenizer.batch_decode(labels, skip_special_tokens=True)


# Load test evaluators
rouge = evaluate.load("rouge")
bleu = evaluate.load("bleu")

# Compute and print Rouge-1, Rogue-2, RougeL
rouge_result = rouge.compute(predictions=predictions_str, references=labels_str)
rouge_result = {k: round(v * 100, 4) for k, v in rouge_result.items()}
print ("ROUGE Metrics: \nROUGE-1: " + rouge_result.get("rouge1", 0) + 
       "\nROUGE-2: " + rouge_result.get("rouge2", 0) + 
       "\nROUGE-L: " + rouge_result.get("rougeL", 0))


# Compute and print BLEU metrics
bleu_result = bleu.compute(predictions=predictions_str, references=labels_str)
bleu_score = round(bleu_result["bleu"] * 100, 4)
print ("BLEU Metrics: " + bleu_score)


# Get first 16 images and give them captiosn
for i in range(16):
       file_path = test_dataset[i]["image_file"]
       model_input =  test_dataset[i]["pixel_values"]

       # Make caption just beginning of sentence token (50256)
       # It's also the padding token, maybe an issue?
       sos_caption = torch.tensor([1, 2, 3])
       output = model(model_input, sos_caption)


       fig, ax = plt.subplot_mosaic([
       ['hopper', 'mri']
       ], figsize=(7, 3.5))

SyntaxError: invalid syntax (131980854.py, line 98)