Introducing manual compression of image captions on stale (offline) data

In [151]:
from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer
from transformers import AdamW
from datasets import load_dataset
import torch
from collections import Counter
import fiftyone
import torch.nn as nn
from tqdm.auto import tqdm
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as T
import torch.nn.functional as F
import numpy as np
import os
import time
import glob

In [155]:
# Load the pre-trained model and its components
model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
feature_extractor = ViTFeatureExtractor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")



In [156]:
# Load a dataset (for example, a subset of the COCO dataset)
# TODO: Potential datasets with repititive nature that can be used: MS COCO, Flickr30k, Visual Genome, SBU Captions 

# load small part of the coco dataset from all the .jpg images in datasets/mscoco/test2015
dataset = load_dataset("datasets/mscoco/test2015/", split="test[:10]")

Resolving data files: 100%|██████████| 81434/81434 [00:00<00:00, 250560.97it/s]


In [157]:
def generate_caption(image):
    inputs = feature_extractor(images=image, return_tensors="pt")
    output_ids = model.generate(inputs["pixel_values"])
    caption = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return caption

In [158]:
# Iterate over the dataset and generate captions
generated_captions = []

for data in dataset:
    image = data['image']
    caption = generate_caption(image)
    generated_captions.append(caption)

In [159]:
def update_encoding_dict(captions, encoding_dict):
    for caption in captions:
        words = caption.split()
        encoding_dict.update(words) # purpose of update is to add the words to the dictionary if they don't exist
    return encoding_dict

In [160]:
encoding_dict = Counter() # Counter is a subclass of dictionary for counting hashable objects
threshold = 2 # threshold for word frequency # TODO: find a good threshold

update_encoding_dict(generated_captions, encoding_dict)

# Optionally, create a more compressed form based on frequency
compressed_dict = {word: idx for idx, (word, freq) in enumerate(encoding_dict.items()) if freq > threshold}

In [162]:
class CaptionDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, captions):
        self.encodings = encodings
        self.captions = captions

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.captions[idx])
        return item

    def __len__(self):
        return len(self.captions)

In [163]:
# Assuming `dataset` is your dataset containing images and captions
images = [data['image'] for data in dataset]
captions = generated_captions

# Process images and captions
inputs = feature_extractor(images=images, return_tensors="pt")
outputs = tokenizer(captions, padding="max_length", truncation=True, max_length=128, return_tensors="pt")

# Create dataset and dataloader
train_dataset = CaptionDataset(inputs, outputs["input_ids"])
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

In [164]:
class LoRALayer(nn.Module):
    def __init__(self, original_weight, rank):
        super(LoRALayer, self).__init__()
        self.original_weight = original_weight
        self.rank = rank
        self.U = nn.Parameter(torch.Tensor(self.original_weight.size(0), self.rank))
        self.V = nn.Parameter(torch.Tensor(self.rank, self.original_weight.size(1)))
        nn.init.xavier_uniform_(self.U)
        nn.init.xavier_uniform_(self.V)

    def forward(self):
        return self.original_weight + self.U @ self.V

In [165]:
# Modify the first attention layer of the encoder
# TODO: Try modifying other layers as well and check the results
lora_layers = []

with torch.no_grad():
    original_weight = model.encoder.encoder.layer[0].attention.output.dense.weight
    lora_layer = LoRALayer(original_weight, rank=10).forward()  # Choose an appropriate rank
    # assign the new layer to the model
    model.encoder.encoder.layer[0].attention.output.dense.weight.copy_(lora_layer)
    # add the layer of the model to the list of LoRA layers
    lora_layers.append(model.encoder.encoder.layer[0].attention.output.dense)

In [166]:
def is_lora_param(param, lora_layer):
    # check if the parameter is part of the LoRA layer
    print (lora_layer.parameters())
    print ("nuj")
    print (param)
    return param in lora_layer.parameters()


In [167]:
def custom_loss(outputs, batch, encoding_dict, lora_layers, lambda_val=0.1, lora_lambda_val = 0.01):
    # Standard captioning loss
    standard_loss = outputs.loss

    # Additional compression loss
    compression_loss = 0
    for word_id in batch['labels'].view(-1):
        word = tokenizer.decode([word_id])
        if word in encoding_dict:
            compression_loss += lambda_val / (encoding_dict[word] + 1) # purpose of compression loss is to penalize the model for using words that are not in the encoding dictionary

    # Optionally, add a term for LoRA regularization if needed
    lora_regularization = 0
    # for param in model.parameters():
    #     for lora_layer in lora_layers:
    #         if is_lora_param(param, lora_layer):
    #             lora_regularization += torch.norm(param)

    return standard_loss + compression_loss + lora_lambda_val * lora_regularization

In [168]:
# Fine tuning using custom loss

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.train()

lr = 5e-5
num_epochs = 10

optimizer = AdamW([param for param in model.parameters() if param.requires_grad], lr=lr)

for epoch in range(num_epochs):
    loop = tqdm(train_loader, leave=True)
    for batch in loop:
        # Move batch to device
        batch = {k: v.to(device) for k, v in batch.items()}

        # Forward pass
        outputs = model(**batch)
        loss = custom_loss(outputs, batch, encoding_dict, lora_layers)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Update progress bar
        loop.set_description(f"Epoch {epoch}")
        loop.set_postfix(loss=loss.item())

  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
  item['labels'] = torch.tensor(self.captions[idx])
Epoch 0: 100%|██████████| 1/1 [00:03<00:00,  3.84s/it, loss=4.56]
Epoch 1: 100%|██████████| 1/1 [00:04<00:00,  4.85s/it, loss=0.196]
Epoch 2: 100%|██████████| 1/1 [00:04<00:00,  4.36s/it, loss=0.17]
Epoch 3: 100%|██████████| 1/1 [00:04<00:00,  4.78s/it, loss=0.161]
Epoch 4: 100%|██████████| 1/1 [00:04<00:00,  4.24s/it, loss=0.157]
Epoch 5: 100%|██████████| 1/1 [00:04<00:00,  4.16s/it, loss=0.157]
Epoch 6: 100%|██████████| 1/1 [00:04<00:00,  4.58s/it, loss=0.15]
Epoch 7: 100%|██████████| 1/1 [00:04<00:00,  4.19s/it, loss=0.148]
Epoch 8: 100%|██████████| 1/1 [00:04<00:00,  4.70s/it, loss=0.142]
Epoch 9: 100%|██████████| 1/1 [00:04<00:00,  4.26s/it, loss=0.142]


In [169]:
# create directory to save the model if it doesn't exist
if not os.path.exists("models"):
    os.mkdir("models")
# save model checkpoint to models directory using current timestamp and date
torch.save(model.state_dict(), f"models/{time.strftime('%Y%m%d-%H%M%S')}.pth")


In [170]:
# load latest model checkpoint among all the saved models
latest_model = torch.load(max(glob.glob('models/*.pth'), key=os.path.getctime))
# load the model with the latest checkpoint
model.load_state_dict(latest_model)

<All keys matched successfully>

In [171]:
# Generate captions for the test dataset
generated_captions_custom_model = []
# Iterate over the dataset and generate captions
for data in dataset:
    image = data['image']
    caption = generate_caption(image)
    generated_captions_custom_model.append(caption)




In [None]:
# Encode compressed dictionary word using manual huffman encoding

In [None]:
# Replace compressed_dict words occurring in the generated_captions_custom_model with their corresponding huffman encoding

In [None]:
# compare encoded generated_captions_custom_model + huffman encoding dictionary information with the original generated_captions to calculate compression ratio

In [172]:
# print generated_captions and generated_captions_custom_model elementwise to compare the results
for i in range(len(generated_captions)):
    print (generated_captions[i], generated_captions_custom_model[i])

a green truck parked next to a curb  a green truck and a white truck 
a man is walking down the street with a skateboard  a man standing next to a street sign 
a baseball player swinging a bat at a ball  a baseball player swinging a bat at a ball 
a cow is standing in a field of grass  a cow standing in a grass field 
a black dog sitting in the back of a truck  a black dog sitting in a car 
a man wearing a bow tie and glasses  a man wearing a tie and a bow tie 
a dining room table with a large bowl of food  a bar with a table and a sink 
a man standing next to a wall with a bunch of guitars  a man standing in front of a shelf holding a record 
a man is playing tennis on a clay court  a man playing tennis on a court 
a man and a woman playing a game of frisbee  a baseball player and a ball 
