## Before Running:
Please Install all from the requirements.txt (pip install -r requirements.txt).

## Set Hyper Parameters

In [1]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder_decoder_layers = 3
encoder_decoder_heads = 8
embedded_dim = 768 # Don't change
max_length = 32
coco_dataset_ratio = 50
coco_dataset_dir = "./coco"
vocab_size = 50257
batch_size = 64
num_epochs = 10
learning_rate = 1e-3
patience = 3
weight_decay = 1e-5
preprocess_swin_model = "microsoft/swin-tiny-patch4-window7-224"
encoder_model = "microsoft/swin-base-patch4-window7-224-in22k"
decoder_model = "gpt2"

## Downloading and Format datasets
This will take some time to finishing running the first time. It took me roughly 40 minutes.

This section does the following actions:
1. Downloads the Dataset
2. Keeps images with only 3 or 4 dim
3. Transforms the dataset 
4. Turns the data set into data loaders


In [44]:
import numpy as np
from datasets import load_dataset
from transformers import ViTImageProcessor, GPT2TokenizerFast, AutoImageProcessor, SwinModel
from torch.utils.data import DataLoader, Dataset
import torch

# Download the train, val and test splits of the COCO dataset
train_ds = load_dataset("HuggingFaceM4/COCO", split=f"train[:{coco_dataset_ratio}%]", cache_dir=coco_dataset_dir)
valid_ds = load_dataset("HuggingFaceM4/COCO", split=f"validation[:{coco_dataset_ratio}%]", cache_dir=coco_dataset_dir)
test_ds = load_dataset("HuggingFaceM4/COCO", split="test", cache_dir=coco_dataset_dir)

# Filter all non 3 or 4 dim images out
# Can change num_proc, but might be errors with np
train_ds = train_ds.filter(lambda item: np.array(item["image"]).ndim in [3, 4], num_proc=1)
valid_ds = valid_ds.filter(lambda item: np.array(item["image"]).ndim in [3, 4], num_proc=1)
test_ds = test_ds.filter(lambda item: np.array(item["image"]).ndim in [3, 4], num_proc=1)

# Does pre processing on the data set
# This includes pre-trained ViTimage feature extraction and tokenizing captions
tokenizer = GPT2TokenizerFast.from_pretrained(decoder_model)
tokenizer.pad_token = tokenizer.eos_token
image_processor = ViTImageProcessor.from_pretrained(encoder_model)
image_processor_swin = SwinModel.from_pretrained(preprocess_swin_model)

def preprocess(items):
    # Image pre-processing
    # use ViT and SWIN since no back prop
    pixel_values = image_processor(items["image"], return_tensors="pt").pixel_values
    with torch.no_grad():
        pixel_values = image_processor_swin (pixel_values)

    print (items)
    print (items["sentences"]['raw'])
    tokens = tokenizer(items["sentences"]['raw'],
                        max_length=max_length, padding="max_length", truncation=True, return_tensors="pt")
    print (tokens)

    # tokenize captions
    targets = tokenizer([sentence["raw"] for sentence in items["sentences"]],
                        max_length=max_length, padding="max_length", truncation=True, return_tensors="pt")
    
    # Keep image file

    return {'pixel_values': pixel_values, 'labels': targets["input_ids"]}

#for data in train_ds:
    #print (data)

test_dataset2 = test_ds.map(preprocess)
for dat in test_dataset2:    
    print (dat)

train_dataset = train_ds.with_transform(preprocess)
valid_dataset = valid_ds.with_transform(preprocess)
test_dataset = test_ds.with_transform(preprocess)


# Turns the dataset into a torch DataLoader
def collate_fn(batch):
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.stack([x['labels'] for x in batch])
    }

train_dataset_loader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=batch_size, shuffle=True)
valid_dataset_loader = DataLoader(valid_dataset, collate_fn=collate_fn, batch_size=batch_size, shuffle=False)
test_dataset_loader = DataLoader(test_dataset, collate_fn=collate_fn, batch_size=batch_size, shuffle=False)

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.
Map:   0%|          | 0/24920 [00:00<?, ? examples/s]

{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=640x360 at 0x1970BEC87D0>, 'filepath': 'COCO_val2014_000000391895.jpg', 'sentids': [770337, 771687, 772707, 776154, 781998], 'filename': 'COCO_val2014_000000391895.jpg', 'imgid': 0, 'split': 'test', 'sentences': {'tokens': ['a', 'man', 'with', 'a', 'red', 'helmet', 'on', 'a', 'small', 'moped', 'on', 'a', 'dirt', 'road'], 'raw': 'A man with a red helmet on a small moped on a dirt road. ', 'imgid': 0, 'sentid': 770337}, 'cocoid': 391895}
A man with a red helmet on a small moped on a dirt road. 





TypeError: string indices must be integers, not 'str'

## Define the Model
Creates the PureT model from the paper

This section does the following actions:
1. Creates the SWIN Transformer used by PureT
2. Creates the PureT encoder
3. Creates the PureT decoder
4. Creates the PureT model

Download the pre-trained SwinT weights from here https://drive.google.com/drive/folders/1HBw5NGGw8DjkyNurksCP5v8a5f0FG7zU and put them in this folder before running

In [33]:
import torch.nn as nn
import torch.nn.functional as F
import torch
import torchvision.models
from transformers import SwinModel

class BasicCaptionTransformer(torch.nn.Module):
    def __init__(self):
        super(BasicCaptionTransformer, self).__init__()

        # Image processing pre-done with ViTImageProcessor in the Downloading and Format datasets step
        self.swin = SwinModel.from_pretrained(preprocess_swin_model)

        # build encoder
        encoder_layer = nn.TransformerEncoderLayer(d_model=embedded_dim, nhead=encoder_decoder_heads, batch_first=True)
        self.encoders = nn.TransformerEncoder(encoder_layer, encoder_decoder_layers)

        # embeddings
        self.embeddings = nn.Embedding(vocab_size, embedded_dim)

        # build decoder
        decoder_layer = nn.TransformerDecoderLayer(d_model=embedded_dim, nhead=encoder_decoder_heads, batch_first=True)
        self.decoders = nn.TransformerDecoder(decoder_layer, encoder_decoder_layers)

        # final
        self.linear = nn.Linear(in_features=768, out_features=vocab_size)


    def forward(self, images, captions):
        x = images
        with torch.no_grad():
            x = self.swin(x)
        x = self.encoders(x.last_hidden_state)
        
        captions = self.embeddings(captions)
        x = self.decoders(tgt=captions.to(x.dtype), memory=x)
        return self.linear(x)
    


## Training loop
Trains the model and saves the best (lowest val error) and last model

This section does the following actions:
1. Creates the Basic transformer model
2. Sets up optimizer, scheduler, counter for training
3. Trains for num_epochs epochs
2. Each Epoch has valadation accuracy calculated
3. Save the model with the best valadation accuracy TODO
4. Save the model when the max number of epochs has been reached TODO

In [34]:
from tqdm import tqdm

# Loop setup
model = BasicCaptionTransformer()
model = model.to(device)

loss_function = nn.CrossEntropyLoss(ignore_index=-1)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)


# Train
predictions_per_batch = batch_size * max_length
stop_counter = 0
train_losses = []
val_losses = []
best_val_loss = float('inf')

train_len = len(train_dataset_loader)
val_len = len(valid_dataset_loader)


for epoch in range(num_epochs):
    model.train()

    # Loop through training data loader batches
    train_loss = 0.0
    train_dataloader_iter = tqdm(train_dataset_loader, desc=f'Epoch {epoch + 1}/{num_epochs}', leave=False)
    for i, data in enumerate(train_dataloader_iter):
        
        # Get values from data loader
        pixel_vals = data["pixel_values"].to(device)
        captions = data["labels"].to(device)

        optimizer.zero_grad()
        outputs = model(images=pixel_vals, captions=captions) # outputs are NAN... bad...

        loss = loss_function(outputs.permute(0,2,1), captions)
        loss.backward()
        optimizer.step()

        # save loss
        train_loss += loss.item()
        if i % 100 == 99:
            print ("Loss so far is: " + str (train_loss / i))
            print ("tensor: ", outputs)


    # Validation
    val_loss = 0.0
    valid_dataset_iter = tqdm(valid_dataset_loader,  desc=f'Epoch {epoch + 1}/{num_epochs}', leave=False)
    for i, data in enumerate(valid_dataset_iter):
        
        # Get values from data loader
        pixel_vals = data["pixel_values"].to(device)
        captions = data["labels"].to(device)

        outputs = model(images=pixel_vals, captions=captions)
        loss = loss_function(outputs.permute(0,2,1), captions)

        # save loss
        val_loss += loss.item()


    # Save model
    torch.save(model.state_dict(), f'./models/model_epoch_{epoch+1}.pt')

    train_losses.append(train_loss / train_len)
    val_losses.append(val_loss / val_len)

    # Print losses
    # divide by length of data loader... 
    print("\nEpoch: " + str(epoch) + 
        "\nTrain Loss: " + str(train_loss / train_len) +
        "\nVal Loss: " + str(val_loss / val_len))

print (train_losses)
print (val_losses)


Epoch 1/10:   0%|          | 0/4418 [00:00<?, ?it/s]

SwinModelOutput(last_hidden_state=tensor([[[ 3.0675e-01,  5.0945e-01,  1.7325e-01,  ...,  1.1788e-01,
          -3.1754e-01,  4.1218e-01],
         [-8.1517e-01,  2.1400e-01, -4.5147e-01,  ...,  8.6758e-03,
          -6.0718e-01,  2.1204e-01],
         [-6.3092e-01,  1.8830e-01, -5.9638e-01,  ...,  1.0285e-01,
          -5.0117e-01,  1.5747e-01],
         ...,
         [-4.8707e-01,  5.3982e-01, -4.8427e-01,  ..., -9.9385e-02,
          -6.2420e-01,  8.4653e-02],
         [-6.4340e-01,  3.8748e-01, -2.9780e-01,  ..., -2.2808e-01,
          -6.6858e-01,  4.5301e-02],
         [ 3.0318e-01,  4.8212e-01,  4.9809e-02,  ..., -1.5281e-02,
          -4.1127e-01,  4.0977e-01]],

        [[ 3.4669e-01,  1.0651e-01,  2.1525e-01,  ...,  1.8687e-01,
           2.9659e-02, -1.4759e-02],
         [ 2.5557e-01,  7.6197e-01, -7.0156e-01,  ..., -3.1938e+00,
           1.3024e-01,  3.0944e-01],
         [-1.4593e+00,  5.5855e-01,  6.8674e-01,  ..., -4.4938e-01,
           3.0534e-01, -5.4908e-01],
     

Epoch 1/10:   0%|          | 1/4418 [00:04<5:42:04,  4.65s/it]

SwinModelOutput(last_hidden_state=tensor([[[ 5.2467e-01, -1.6158e-01,  3.2696e-02,  ...,  2.9319e-01,
          -6.6524e-02,  1.4002e-01],
         [ 2.0795e-01, -9.2373e-01, -1.6926e+00,  ...,  1.2636e-01,
          -6.5223e-01,  1.0950e-01],
         [-3.6473e-01, -1.6991e+00, -1.4850e+00,  ..., -1.2395e-01,
          -6.0512e-01,  8.9079e-02],
         ...,
         [-9.9634e-02,  1.1863e+00,  1.4123e+00,  ...,  6.1957e+00,
           2.4381e-01, -1.8393e+00],
         [ 1.4298e-01,  1.1754e+00,  7.6745e-01,  ...,  5.5598e+00,
          -4.3329e-01, -2.0899e+00],
         [ 6.4411e-01, -4.9268e-01, -3.1999e-01,  ...,  2.8826e-01,
          -1.4891e-01,  1.3572e-01]],

        [[ 5.9943e-01,  6.7352e-02,  1.1703e-01,  ...,  4.4320e-01,
          -1.8730e-02,  2.1719e-01],
         [ 1.8858e-01,  1.1023e-01,  6.4541e-02,  ...,  2.3339e-01,
           2.1339e-01,  2.0589e-01],
         [-7.5265e-02,  1.1456e-01, -2.2233e-01,  ...,  2.8940e-01,
          -3.9021e-02,  8.4188e-02],
     

                                                              

KeyboardInterrupt: 

## Post Training Metrics
Computes the test loss and common test metrics

This section does the following actions:
1. Loads the specified model
2. Runs through the test set and reports loss
3. Runs through the val set for BLEU and ROUGE metrics

In [15]:
import evaluate
from transformers import EvalPrediction
from tqdm import tqdm

# Load model
MODEL_PATH = "./models/model_epoch_8.pt"
LOAD_MODEL = True

eval_model = None
if LOAD_MODEL:
       if "model" in locals():
              model.to('cpu')

       eval_model = BasicCaptionTransformer()
       eval_model.load_state_dict(torch.load(MODEL_PATH))
       eval_model = eval_model.to(device)

else:
       eval_model = model

# Eval setup
loss_function = nn.CrossEntropyLoss(ignore_index=-1)
test_loss = 0.0

predictions = []
labels = []

# Run through the test for test loss
with torch.no_grad():
       test_dataset_iter = tqdm(test_dataset_loader,  desc=f'Test Set Progress: ', leave=False)
       for data in test_dataset_iter:

              # get data from batch
              pixel_vals = data["pixel_values"].to(device)
              labels = data["labels"].to(device)

              # Predict captions
              outputs = eval_model(images=pixel_vals, captions=labels)
              test_loss += loss_function(outputs.permute(0,2,1), labels)
              print (test_loss)

print ("Test Loss: " + str(test_loss / len(test_dataset_loader)))


# Run through valadation set with best model
predictions = []
labels = []
with torch.no_grad():
       for data in valid_dataset_loader:

              # get data from batch
              pixel_vals = data["pixel_values"].to(device)
              labels = data["labels"].to(device)
       
              # Predict captions
              outputs = eval_model(images=pixel_vals, captions=labels)

              # Format labels
              logits = outputs.detach().cpu()
              predictions.extend(logits.argmax(dim=-1).tolist())
              labels.extend(labels.tolist())
    

# Format predictions into Hugging Face class
eval_predictions = EvalPrediction(predictions=predictions, label_ids=labels)

predictions = eval_predictions.predictions
labels = eval_predictions.label_ids

# Tokenize predictions and reference captions
predictions_str = tokenizer.batch_decode(predictions, skip_special_tokens=True)
labels_str = tokenizer.batch_decode(labels, skip_special_tokens=True)


# Load test evaluators
rouge = evaluate.load("rouge")
bleu = evaluate.load("bleu")

# Compute and print Rouge-1, Rogue-2, RougeL
rouge_result = rouge.compute(predictions=predictions_str, references=labels_str)
rouge_result = {k: round(v * 100, 4) for k, v in rouge_result.items()}
print ("ROUGE Metrics: \nROUGE-1: " + rouge_result.get("rouge1", 0) + 
       "\nROUGE-2: " + rouge_result.get("rouge2", 0) + 
       "\nROUGE-L: " + rouge_result.get("rougeL", 0))


# Compute and print BLEU metrics
bleu_result = bleu.compute(predictions=predictions_str, references=labels_str)
bleu_score = round(bleu_result["bleu"] * 100, 4)
print ("BLEU Metrics: " + bleu_score)

Test Set Progress:   0%|          | 1/390 [00:00<03:24,  1.90it/s]

tensor([[   32,   582,   351,  ..., 50256, 50256, 50256],
        [ 5124, 10311,   257,  ..., 50256, 50256, 50256],
        [   32,   582, 10311,  ..., 50256, 50256, 50256],
        ...,
        [   32,  7684,   286,  ..., 50256, 50256, 50256],
        [   64,   736, 12525,  ..., 50256, 50256, 50256],
        [   32,  1271,   286,  ..., 50256, 50256, 50256]], device='cuda:0')
tensor([[[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]],

        [[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]],

        [[nan, nan, nan, 

Test Set Progress:   1%|          | 2/390 [00:01<03:26,  1.87it/s]

tensor([[   64,  7684,   286,  ..., 50256, 50256, 50256],
        [   32, 39145,  2330,  ..., 50256, 50256, 50256],
        [   32, 28774,  5017,  ..., 50256, 50256, 50256],
        ...,
        [   32, 27638,   286,  ..., 50256, 50256, 50256],
        [14945, 15900, 40470,  ..., 50256, 50256, 50256],
        [13247,   286, 15900,  ..., 50256, 50256, 50256]], device='cuda:0')
tensor([[[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]],

        [[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]],

        [[nan, nan, nan, 

Test Set Progress:   1%|          | 3/390 [00:01<03:35,  1.80it/s]

tensor([[   32, 29556,   286,  ..., 50256, 50256, 50256],
        [ 3347,   538,   389,  ..., 50256, 50256, 50256],
        [   32, 29556,   286,  ..., 50256, 50256, 50256],
        ...,
        [   32,  9875,   319,  ..., 50256, 50256, 50256],
        [  464,  5044,   318,  ..., 50256, 50256, 50256],
        [   32,  7586,  9875,  ..., 50256, 50256, 50256]], device='cuda:0')
tensor([[[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]],

        [[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]],

        [[nan, nan, nan, 

Test Set Progress:   1%|          | 4/390 [00:02<03:39,  1.76it/s]

tensor([[   32,  9875,   318,  ..., 50256, 50256, 50256],
        [   32,  6473,  5055,  ..., 50256, 50256, 50256],
        [   64,  7586,   290,  ..., 50256, 50256, 50256],
        ...,
        [   32,  2415, 10718,  ..., 50256, 50256, 50256],
        [   32,  2415,  5586,  ..., 50256, 50256, 50256],
        [   32,  2415,  5586,  ..., 50256, 50256, 50256]], device='cuda:0')
tensor([[[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]],

        [[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]],

        [[nan, nan, nan, 

Test Set Progress:   1%|▏         | 5/390 [00:02<03:35,  1.78it/s]

tensor([[   64,  2415, 10718,  ..., 50256, 50256, 50256],
        [   32,  2576,   351,  ..., 50256, 50256, 50256],
        [   32,  2415,   351,  ..., 50256, 50256, 50256],
        ...,
        [   32,  2415,   318,  ..., 50256, 50256, 50256],
        [   32,  2415,   318,  ..., 50256, 50256, 50256],
        [   64,  2415,  1016,  ..., 50256, 50256, 50256]], device='cuda:0')
tensor([[[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]],

        [[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]],

        [[nan, nan, nan, 

                                                                  

tensor([[   32,   582, 10311,  ..., 50256, 50256, 50256],
        [   32, 17876,  7976,  ..., 50256, 50256, 50256],
        [   32,  8223,  1891,  ..., 50256, 50256, 50256],
        ...,
        [ 7571,  1862, 22647,  ..., 50256, 50256, 50256],
        [ 7571,  1862,  1450,  ..., 50256, 50256, 50256],
        [ 1858,   389,   734,  ..., 50256, 50256, 50256]], device='cuda:0')
tensor([[[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]],

        [[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]],

        [[nan, nan, nan, 



KeyboardInterrupt: 