## Before Running
Install all dependencies listed in `requirements.txt` (`pip install -r requirements.txt`).

## Set Hyperparameters

In [20]:
import random

import torch

# Compute device: use the GPU when PyTorch can see one, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Captions are padded/truncated to this many tokens by the tokenizer.
max_length = 32
# Percentage of the COCO train/validation splits to load (e.g. split="train[:50%]").
coco_dataset_ratio = 50
# Local cache directory for the downloaded dataset.
coco_dataset_dir = "./coco"
# Number of samples per DataLoader batch.
batch_size = 32
# NOTE(review): the values below are not used yet; presumably they configure the
# (not yet written) training loop — confirm once it exists.
num_epochs = 10
learning_rate = 1e-3
patience = 3  # presumably early-stopping patience, in epochs
weight_decay = 1e-5
# Pretrained checkpoints: image encoder (image processor) and caption decoder (tokenizer).
encoder_model = "microsoft/swin-base-patch4-window7-224-in22k"
decoder_model = "gpt2"

# Seed the RNGs so DataLoader shuffling and any later torch randomness are reproducible.
seed = 42
random.seed(seed)
torch.manual_seed(seed)

## Download and Format the Datasets
This takes a while the first time it runs (roughly 40 minutes in my case).

This section does the following:
1. Downloads the dataset
2. Keeps only images with 3 or 4 array dimensions (i.e. drops 2-D, single-channel images)
3. Attaches the preprocessing transform (image processing and caption tokenization)
4. Wraps each split in a DataLoader


In [27]:
import numpy as np
from datasets import load_dataset
from transformers import ViTImageProcessor, GPT2TokenizerFast
from torch.utils.data import DataLoader
import os
import torch

# Download the train, validation and test splits of the COCO dataset.
# Only the first `coco_dataset_ratio` percent of train/validation is loaded; test is loaded in full.
train_ds = load_dataset("HuggingFaceM4/COCO", split=f"train[:{coco_dataset_ratio}%]", cache_dir=coco_dataset_dir)
valid_ds = load_dataset("HuggingFaceM4/COCO", split=f"validation[:{coco_dataset_ratio}%]", cache_dir=coco_dataset_dir)
test_ds = load_dataset("HuggingFaceM4/COCO", split="test", cache_dir=coco_dataset_dir)


def _has_channel_axis(image):
    """Return True when ``np.array(image)`` would be 3-D, i.e. the image has more than one band.

    Equivalent to ``np.array(image).ndim in [3, 4]``: Pillow exposes a
    single-band image (modes such as "L" or "P") as an (H, W) array and a
    multi-band image as (H, W, bands). Counting bands avoids copying every
    image's pixels into a NumPy array just to read ``ndim``.
    Assumes ``image`` is a ``PIL.Image.Image`` (what the ``datasets`` Image
    feature decodes to) — TODO confirm if the column type changes.
    """
    return len(image.getbands()) > 1


# Drop single-channel images (e.g. grayscale). Filtering runs in a single process
# (num_proc=1); raising num_proc was noted to possibly cause errors.
# torch's CPU thread count is pinned to 1 only while filtering, then restored so
# later CPU work in the notebook is not silently single-threaded.
previous_num_threads = torch.get_num_threads()
torch.set_num_threads(1)
try:
    train_ds = train_ds.filter(_has_channel_axis, input_columns="image", num_proc=1)
    valid_ds = valid_ds.filter(_has_channel_axis, input_columns="image", num_proc=1)
    test_ds = test_ds.filter(_has_channel_axis, input_columns="image", num_proc=1)
finally:
    torch.set_num_threads(previous_num_threads)


# Preprocessing components: the tokenizer turns captions into GPT-2 token ids and
# the image processor turns images into the encoder's `pixel_values`.
tokenizer = GPT2TokenizerFast.from_pretrained(decoder_model)
# GPT-2 ships without a padding token, so padding="max_length" in `preprocess`
# would raise a ValueError. Reuse the end-of-sequence token for padding.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
image_processor = ViTImageProcessor.from_pretrained(encoder_model)

def preprocess(items):
    """Turn a batch of raw dataset rows into model inputs.

    Args:
        items: batch dict from ``datasets`` with an ``"image"`` list and a
            ``"sentences"`` list whose entries hold the caption under ``"raw"``.

    Returns:
        dict with ``'pixel_values'`` (image-processor output) and ``'labels'``
        (caption token ids padded/truncated to ``max_length``), both moved to ``device``.
    """
    # The ndim filter still lets through multi-band images that are not RGB
    # (e.g. RGBA, CMYK, LA). Convert them so every sample has the same 3-channel
    # layout and `collate_fn` can stack them.
    images = [image if image.mode == "RGB" else image.convert("RGB") for image in items["image"]]
    pixel_values = image_processor(images, return_tensors="pt").pixel_values.to(device)
    captions = [sentence["raw"] for sentence in items["sentences"]]
    # NOTE(review): moving tensors to `device` inside the transform runs per sample
    # and would break with CUDA in DataLoader worker processes (num_workers > 0).
    targets = tokenizer(captions, max_length=max_length, padding="max_length",
                        truncation=True, return_tensors="pt").to(device)
    return {'pixel_values': pixel_values, 'labels': targets["input_ids"]}

# Attach `preprocess` as an on-the-fly transform to each split (train, valid, test).
train_dataset, valid_dataset, test_dataset = (
    split.with_transform(preprocess) for split in (train_ds, valid_ds, test_ds)
)


# Batch-assembly function used by the DataLoaders below.
def collate_fn(batch):
    """Stack per-sample tensors into batch tensors.

    Only the ``'pixel_values'`` and ``'labels'`` entries of each sample are kept.
    Each is stacked along a new leading (batch) dimension with ``torch.stack``.
    """
    stacked = {}
    for key in ('pixel_values', 'labels'):
        stacked[key] = torch.stack([sample[key] for sample in batch])
    return stacked

def _make_loader(dataset, shuffle):
    """Wrap `dataset` in a DataLoader using the shared `collate_fn` and `batch_size`."""
    return DataLoader(dataset, collate_fn=collate_fn, batch_size=batch_size, shuffle=shuffle)


# Only the training split is shuffled; validation and test keep dataset order.
train_dataset_loader = _make_loader(train_dataset, shuffle=True)
valid_dataset_loader = _make_loader(valid_dataset, shuffle=False)
test_dataset_loader = _make_loader(test_dataset, shuffle=False)

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


3


Filter:  50%|█████     | 142000/283374 [09:44<09:41, 243.14 examples/s]


KeyboardInterrupt: 

## Creating the Model
Creates the PureT model from the paper.

This section does the following:
1. Builds the model

In [None]:
# TODO: build the PureT captioning model here. Presumably it pairs the
# `encoder_model` image backbone with the `decoder_model` language model from
# the hyperparameter cell — confirm against the paper.
# Until then, this cell only reports that the step is missing.
print("TODO: model construction is not implemented yet")

## Training Loop
Trains the model and saves both the best model (lowest validation error) and the last model.

This section does the following:
1. Trains the model for `num_epochs` epochs
2. Computes the validation accuracy after each epoch
3. Saves the model with the best validation accuracy
4. Saves the model when the maximum number of epochs has been reached

In [None]:
# TODO: implement the training loop: train for `num_epochs` epochs, evaluate on
# `valid_dataset_loader` after each epoch, and save both the best and the last model.
# Until then, this cell only reports that the step is missing.
print("TODO: training loop is not implemented yet")

## Post-training Metrics
Evaluates the best model on BLEU, ROUGE, and SPICE.

This section does the following:
1. Loads the model with the highest validation accuracy
2. Calculates the ROUGE score
3. Calculates the BLEU score
4. Calculates the SPICE score

In [None]:
# TODO: load the best saved model and compute ROUGE, BLEU and SPICE on
# `test_dataset_loader`.
# Until then, this cell only reports that the step is missing.
print("TODO: post-training metrics are not implemented yet")