## Before Running:
Please Install all from the requirements.txt (pip install -r requirements.txt)

## Set Hyper Parameters

In [20]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
max_length = 32
coco_dataset_ratio = 50
coco_dataset_dir = "./coco"
batch_size = 32
num_epochs = 10
learning_rate = 1e-3
patience = 3
weight_decay = 1e-5
encoder_model = "microsoft/swin-base-patch4-window7-224-in22k"
decoder_model = "gpt2"

## Downloading and Format datasets
This will take some time to finishing running the first time. It took me roughly 40 minutes.

This section does the following actions:
1. Downloads the Dataset
2. Keeps images with only 3 or 4 dim
3. Transforms the dataset 
4. Turns the data set into data loaders


In [34]:
import numpy as np
from datasets import load_dataset
from transformers import ViTImageProcessor, GPT2TokenizerFast
from torch.utils.data import DataLoader
import torch

# Download the train, val and test splits of the COCO dataset
train_ds = load_dataset("HuggingFaceM4/COCO", split=f"train[:{coco_dataset_ratio}%]", cache_dir=coco_dataset_dir)
valid_ds = load_dataset("HuggingFaceM4/COCO", split=f"validation[:{coco_dataset_ratio}%]", cache_dir=coco_dataset_dir)
test_ds = load_dataset("HuggingFaceM4/COCO", split="test", cache_dir=coco_dataset_dir)


# Filter all non 3 or 4 dim images out
# Can change num_proc, but might be errors with np
train_ds = train_ds.filter(lambda item: np.array(item["image"]).ndim in [3, 4], num_proc=1)
valid_ds = valid_ds.filter(lambda item: np.array(item["image"]).ndim in [3, 4], num_proc=1)
test_ds = test_ds.filter(lambda item: np.array(item["image"]).ndim in [3, 4], num_proc=1)


# Does pre processing on the data set
# This includes pre-trained ViTimage feature extraction and tokenizing captions
# I am unsure if the paper does any of this pre processing
tokenizer = GPT2TokenizerFast.from_pretrained(decoder_model)
image_processor = ViTImageProcessor.from_pretrained(encoder_model)

def preprocess(items):
    pixel_values = image_processor(items["image"], return_tensors="pt").pixel_values.to(device)
    targets = tokenizer([sentence["raw"] for sentence in items["sentences"]],
                        max_length=max_length, padding="max_length", truncation=True, return_tensors="pt").to(device)
    return {'pixel_values': pixel_values, 'labels': targets["input_ids"]}

train_dataset = train_ds.with_transform(preprocess)
valid_dataset = valid_ds.with_transform(preprocess)
test_dataset = test_ds.with_transform(preprocess)


# Turns the dataset into a torch DataLoader
def collate_fn(batch):
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.stack([x['labels'] for x in batch])
    }

train_dataset_loader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=batch_size, shuffle=True)
valid_dataset_loader = DataLoader(valid_dataset, collate_fn=collate_fn, batch_size=batch_size, shuffle=False)
test_dataset_loader = DataLoader(test_dataset, collate_fn=collate_fn, batch_size=batch_size, shuffle=False)

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


## Creating the Model
Creates the PureT model from the paper

This section does the following actions:
1. Defines the PureT model from the paper

In [37]:
import torch.nn as nn
import torch.nn.functional as F
import torch
import torchvision
from torch import nn, einsum
import numpy as np
from einops import rearrange, repeat


class PureT(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_heads, window_size):
        super(PureT, self).__init__()

        # Swin Transformer
        self.swin_t = torchvision.models.swin_b()

        # Before Refining Encoder (bre)
        #self.bre_linear = nn.Linear() #[[add dime, just 512]]
        #self.bre_avg_pool = nn.AvgPool1d() #[[change dim?]]

        # Refining Encoder TODO
        self.refine_encoder = nn.Sequential(

        )


    def forward(self, images, captions):
        swin_output = self.swin_t(images)
        print (swin_output.shape)

        #TODO


        #return output


## Training loop
Trains the model and saves the best (lowest val error) and last model

This section does the following actions:
1. Creates the PureT model
2. Sets up optimizer, scheduler, counter for training
3. Trains for num_epochs epochs
2. Each Epoch has valadation accuracy calculated TODO
3. Save the model with the best valadation accuracy TODO
4. Save the model when the max number of epochs has been reached TODO

In [38]:
import torch
import numpy as np
import os
import argparse

from torch.optim import AdamW
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm
from transformers import VisionEncoderDecoderModel
from torchvision import transforms


# Model setup
# Not sure where vocab_size (maybe from the gpt-2 tokenizer?) or embed_dim come from yet
# But num_heads and window_size are from Table 6 of the paper https://arxiv.org/pdf/2203.15350
model = PureT(vocab_size=30522, embed_dim=768, num_heads=3, window_size=12)

# Setup for training
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
scheduler = ReduceLROnPlateau(optimizer, patience=patience)

# Values to remember training performance
stop_counter = 0
train_losses = []
best_val_loss = float('inf')
val_losses = []

# Training loop
for epoch in range(num_epochs):
    # Epoch setup
    model.train()
    train_loss = 0.0

    # Loop through data loader batches
    train_dataloader_iter = tqdm(train_dataset_loader, desc=f'Epoch {epoch + 1}/{num_epochs}', leave=False)
    for i, data in enumerate(train_dataloader_iter):
        
        # Get values from data loader
        pixel_vals = batch["pixel_values"].to(device)
        captions = batch["labels"].to(device)

        # Generate outputs
        optimizer.zero_grad()
        outputs = model(pixel_vals=pixel_vals, captions=captions)
        loss = outputs.loss

        # Grad descent
        loss.backwards()
        optimizer.step()

        train_loss += loss.item()
    
    



                                                    

ValueError: Asking to pad but the tokenizer does not have a padding token. Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` or add a new pad token via `tokenizer.add_special_tokens({'pad_token': '[PAD]'})`.

## Post training Metrics
TODO ALL
Evalutes the best model on BLEU, ROUGE, and SPICE

This section does the following actions:
1. Loads the model with the highest valadation accuracy
2. Calculates ROUGE score
3. Calculates BLEU score
4. Calculates SPICE score

In [None]:
for i in range (10):
    print ("hello")