Introducing manual compression of image captions on stale (offline) data

In [1]:
from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer
from torch.nn import TransformerEncoderLayer, TransformerEncoder, TransformerDecoderLayer, TransformerDecoder
from transformers import AdamW
from datasets import load_dataset
import torch
import torch.optim as optim
from collections import Counter
import fiftyone
import torch.nn as nn
from tqdm.auto import tqdm
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as T
import torch.nn.functional as F
import numpy as np
import os
import time
import glob
import math

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the pre-trained model and its components
model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
feature_extractor = ViTFeatureExtractor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")



In [3]:
# Load a dataset (for example, a subset of the COCO dataset)
# TODO: Potential datasets with repititive nature that can be used: MS COCO, Flickr30k, Visual Genome, SBU Captions 

# load small part of the coco dataset from all the .jpg images in datasets/mscoco/test2015
dataset = load_dataset("datasets/mscoco/test2015/", split="test[:10]")

Resolving data files: 100%|██████████| 81434/81434 [00:00<00:00, 1139829.25it/s]


In [4]:
def generate_caption(image, max_length=128):
    inputs = feature_extractor(images=image, return_tensors="pt")
    output_ids = model.generate(inputs["pixel_values"], max_length=max_length)
    caption = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return caption

In [5]:
# Iterate over the dataset and generate captions
max_length = 128
generated_captions = []

for data in dataset:
    image = data['image']
    caption = generate_caption(image, max_length=max_length)
    generated_captions.append(caption)

We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.


In [6]:
class CaptionAutoencoder(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, max_seq_length):
        super(CaptionAutoencoder, self).__init__()

        # Encoder
        self.encoder_embedding = nn.Embedding(vocab_size, embedding_dim) # input shape has to be (batch_size, sequence_length), output shape is (batch_size, sequence_length, embedding_dim)
        self.encoder_rnn = nn.GRU(embedding_dim, hidden_dim, batch_first=True) # output shape is (batch_size, sequence_length, hidden_dim)
        self.max_seq_length = max_seq_length

        # Decoder
        self.decoder_rnn = nn.GRU(hidden_dim, hidden_dim, batch_first=True) # output shape is (batch_size, sequence_length, hidden_dim)
        self.decoder_output = nn.Linear(hidden_dim, vocab_size)

    def encode(self, captions):
        embedded = self.encoder_embedding(captions)
        encoded, _ = self.encoder_rnn(embedded)
        return encoded[:, -1, :]

    def decode(self, encoded):
        # Repeat the encoded state across the sequence length
        repeated_encoded = encoded.unsqueeze(1).repeat(1, self.max_seq_length, 1) 
        decoded, _ = self.decoder_rnn(repeated_encoded)
        output = self.decoder_output(decoded)
        return output

    def forward(self, captions):
        encoded = self.encode(captions)
        decoded = self.decode(encoded)
        return decoded


In [7]:
class TransformerAutoencoder(nn.Module):
    def __init__(self, vocab_size, embed_size, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, max_seq_length):
        super(TransformerAutoencoder, self).__init__()
        self.embed_size = embed_size
        self.vocab_size = vocab_size
        self.max_seq_length = max_seq_length
        
        # Embedding layer
        self.embedding = nn.Embedding(vocab_size, embed_size)
        
        # Transformer Encoder
        encoder_layers = TransformerEncoderLayer(d_model=embed_size, nhead=nhead, dim_feedforward=dim_feedforward)
        self.transformer_encoder = TransformerEncoder(encoder_layers, num_layers=num_encoder_layers)
        
        # Transformer Decoder
        decoder_layers = TransformerDecoderLayer(d_model=embed_size, nhead=nhead, dim_feedforward=dim_feedforward)
        self.transformer_decoder = TransformerDecoder(decoder_layers, num_layers=num_decoder_layers)
        
        # Output layer
        self.out = nn.Linear(embed_size, vocab_size)
        
    def forward(self, src):
        # For a generic forward pass, we'll assume encoding and then decoding
        memory = self.encode(src)
        output = self.decode(memory)
        return output
    
    def encode(self, src):
        # Embed input tokens and scale
        src = self.embedding(src) * math.sqrt(self.embed_size)
        # Encoder
        memory = self.transformer_encoder(src)
        return memory
    
    def decode(self, memory):
        # Decoder
        output = self.transformer_decoder(memory, memory)
        # Pass through the output layer
        output = self.out(output)
        return output

In [8]:
class CaptionDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, captions):
        self.encodings = encodings
        self.captions = captions

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.captions[idx])
        return item

    def __len__(self):
        return len(self.captions)

In [9]:
# Assuming `dataset` is your dataset containing images and captions
images = [data['image'] for data in dataset]
captions = generated_captions

# Process images and captions
inputs = feature_extractor(images=images, return_tensors="pt")
outputs = tokenizer(captions, padding="max_length", truncation=True, max_length=128, return_tensors="pt")

# Assuming 'captions' is a tensor of tokenized captions generated by VLM
vocab_size = 50257  # Size of your vocabulary
embed_size = 4096  # Embedding dimension
nhead = 8  # Number of attention heads
num_encoder_layers = 6  # Number of encoder layers
num_decoder_layers = 6  # Number of decoder layers
dim_feedforward = 2048  # Dimension of the feedforward network
max_seq_length = 128  # Maximum sequence length
# autoencoder = CaptionAutoencoder(vocab_size, embedding_dim, hidden_dim, max_seq_length)
autoencoder = TransformerAutoencoder(vocab_size, embed_size, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, max_seq_length)
autoencoder_output = autoencoder(outputs["input_ids"])

# Create dataset and dataloader
train_dataset = CaptionDataset(inputs, outputs["input_ids"])
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)



In [10]:
# print autoencoder_output dtype
print(autoencoder_output.dtype)
# print autoencoder_output shape
print(autoencoder_output.shape)
# print autoencoder_output[0]
print(autoencoder_output[1])

torch.float32
torch.Size([10, 128, 50257])
tensor([[ 0.1738, -0.9830, -0.7030,  ...,  0.2738,  0.0443,  0.3444],
        [-0.4892, -0.2859, -0.4371,  ...,  0.8710, -0.2428, -0.0778],
        [-0.4789, -0.1542, -0.3004,  ...,  0.0101,  0.3103,  0.4744],
        ...,
        [-0.7164, -0.0073,  0.4418,  ...,  0.3863,  0.1172,  0.5472],
        [-1.1427,  0.1207,  0.4334,  ...,  0.6281, -0.0603,  0.3994],
        [-0.8056,  0.1040,  0.6135,  ...,  0.2913,  0.0521,  0.8215]],
       grad_fn=<SelectBackward0>)


In [11]:
# Convert logits to token IDs
token_ids = torch.argmax(autoencoder_output, dim=-1)
# Convert token IDs to text for each sequence in the batch
detokenized_texts = [tokenizer.decode(ids, skip_special_tokens=True) for ids in token_ids]
# Example output
print (detokenized_texts[0])
print (captions[0])

risk Gig Vern comrades coordinator Yenithing iTunes TwistedYOUocoboCrimeocobo reasoned reasoned reasoned reasoned reasoned reasoned reasoned reasoned reasoned Characters Charactersstars reasoned reasoned Characters Characters Characters reasoned reasoned Characters reasoned Ips Characters reasoned Citiz reasoned reasoned Characters reasoned reasoned Characters Characters reasoned reasoned Ips reasoned reasoned reasoned reasoned reasoned reasoned reasoned reasonedalyst Characters reasoned reasoned Characters reasoned reasoned reasoned reasoned reasoned reasoned reasoned reasoned reasoned Characters Characters Characters reasonedstars reasoned reasoned reasoned reasoned reasoned reasonedstars reasoned reasoned reasoned Characters reasoned reasoned Characters reasoned Characters reasoned Characters reasoned reasoned reasoned reasoned reasoned reasoned Characters Characters reasoned reasoned reasoned reasoned reasoned reasoned Characters reasoned reasoned reasoned reasoned Characters reaso

In [12]:
class LoRALayer(nn.Module):
    def __init__(self, original_weight, rank):
        super(LoRALayer, self).__init__()
        self.original_weight = original_weight
        self.rank = rank
        self.U = nn.Parameter(torch.Tensor(self.original_weight.size(0), self.rank))
        self.V = nn.Parameter(torch.Tensor(self.rank, self.original_weight.size(1)))
        nn.init.xavier_uniform_(self.U)
        nn.init.xavier_uniform_(self.V)

    def forward(self):
        return self.original_weight + self.U @ self.V

In [13]:
# Modify the first attention layer of the encoder
# TODO: Try modifying other layers as well and check the results
lora_layers = []

with torch.no_grad():
    original_weight = model.encoder.encoder.layer[0].attention.output.dense.weight
    lora_layer = LoRALayer(original_weight, rank=10).forward()  # Choose an appropriate rank
    # assign the new layer to the model
    model.encoder.encoder.layer[0].attention.output.dense.weight.copy_(lora_layer)
    # add the layer of the model to the list of LoRA layers
    lora_layers.append(model.encoder.encoder.layer[0].attention.output.dense)

In [14]:
def is_lora_param(param, lora_layer):
    # check if the parameter is part of the LoRA layer
    print (lora_layer.parameters())
    print ("nuj")
    print (param)
    return param in lora_layer.parameters()

In [15]:
def custom_loss(outputs, batch, lora_layers, autoencoder, standard_lambda_val = 0.001, lora_lambda_val = 0.01, compression_lambda_val = 1):
    # Standard captioning loss
    standard_loss = outputs.loss

    # Autoencoder compression reward
    captions = batch['labels']
    compressed_captions = autoencoder.encode(captions)
    # Measure the sparsity of the compressed representation (e.g., using L1 norm) # TODO: Try other measures
    compression_reward = torch.norm(compressed_captions, p=1)
    # Adjust the reward: lower norm (more sparse) should lead to lower loss (higher reward)
    compression_loss = -compression_reward

    # Optionally, add a term for LoRA regularization if needed
    lora_regularization = 0
    # for param in model.parameters():
    #     for lora_layer in lora_layers:
    #         if is_lora_param(param, lora_layer):
    #             lora_regularization += torch.norm(param)
    return standard_lambda_val* standard_loss + compression_lambda_val * compression_loss + lora_lambda_val * lora_regularization

In [16]:
ae_criterion1 = nn.CrossEntropyLoss()

def ae_criterion2 (reconstructed_caption, original_caption, end_of_text_token_id):

    loss= 0

    for i in range(len(original_caption)):

        # remove all end of text tokens from the right in original caption
        trim_index = 0
        for j in range(len(original_caption[i])-1, -1, -1):
            if original_caption[i][j] != end_of_text_token_id:
                trim_index = j
                break
        trim_index += 1

        # Trim the trailing spaces
        trimmed_original = original_caption[i][:trim_index]
        trimmed_reconstructed = reconstructed_caption[i][:trim_index]

        # Calculate the loss (assuming cross-entropy loss)
        loss += ae_criterion1(trimmed_reconstructed, trimmed_original)

    return loss

In [17]:
# Fine tuning using custom loss

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.train()

vlm_lr = 5e-5 # TODO: find a good learning rate
ae_lr = 1e-3 # TODO: find a good learning rate
num_epochs = 20 # TODO: find a good number of epochs

vlm_optimizer = AdamW([param for param in model.parameters() if param.requires_grad], lr=vlm_lr)
ae_optimizer = optim.Adam(autoencoder.parameters(), lr=ae_lr) 

for epoch in range(num_epochs):
    loop = tqdm(train_loader, leave=True)
    for batch in loop:
        # Move batch to device
        batch = {k: v.to(device) for k, v in batch.items()}

        # Fine tune VLM with custom loss
        # Forward pass
        # model.zero_grad()
        # outputs = model(**batch)
        # vlm_loss = custom_loss(outputs, batch, lora_layers, autoencoder)
        # # Backward pass and optimization
        # vlm_optimizer.zero_grad()
        # vlm_loss.backward()
        # vlm_optimizer.step()

        # Train the autoencoder
        autoencoder.zero_grad()
        captions = batch['labels']
        compressed_captions = autoencoder.encode(captions)
        reconstructed_captions = autoencoder.decode(compressed_captions)
        # reconstructed_flat = reconstructed_captions.view(-1, reconstructed_captions.size(-1))
        # captions_flat = captions.view(-1)
        end_of_text_token_id = tokenizer.encode('<|endoftext|>')[0]
        ae_loss = ae_criterion2(reconstructed_captions, captions, end_of_text_token_id)
        ae_loss.backward()
        ae_optimizer.step()

        # TODO: change loss as combination of vlm_loss and ae_loss instead of individual losses

        # Update progress bar
        loop.set_description(f"Epoch {epoch}")
        # loop.set_postfix(vlm_loss=vlm_loss.item(), ae_loss=ae_loss.item())
        loop.set_postfix( ae_loss=ae_loss.item())

  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
  item['labels'] = torch.tensor(self.captions[idx])
Epoch 0: 100%|██████████| 1/1 [00:41<00:00, 41.09s/it, ae_loss=110]
Epoch 1: 100%|██████████| 1/1 [01:19<00:00, 79.66s/it, ae_loss=127]
Epoch 2: 100%|██████████| 1/1 [01:12<00:00, 72.72s/it, ae_loss=226]
Epoch 3: 100%|██████████| 1/1 [01:18<00:00, 78.00s/it, ae_loss=213]
Epoch 4: 100%|██████████| 1/1 [01:15<00:00, 75.72s/it, ae_loss=114]
Epoch 5: 100%|██████████| 1/1 [01:07<00:00, 67.53s/it, ae_loss=103]
Epoch 6: 100%|██████████| 1/1 [01:15<00:00, 75.97s/it, ae_loss=86.3]
Epoch 7: 100%|██████████| 1/1 [01:04<00:00, 64.64s/it, ae_loss=83.9]
Epoch 8: 100%|██████████| 1/1 [01:15<00:00, 75.32s/it, ae_loss=85.7]
Epoch 9: 100%|██████████| 1/1 [01:03<00:00, 63.02s/it, ae_loss=85]
Epoch 10: 100%|██████████| 1/1 [01:01<00:00, 61.82s/it, ae_loss=78.2]
Epoch 11: 100%|██████████| 1/1 [05:34<00:00, 334.09s/it, ae_loss=67.4]
Epoch 12: 100%|██████████| 1/1 [02:01<00:00, 12

In [18]:
# create directory to save the model if it doesn't exist
if not os.path.exists("models_auto_compress_online_data"):
    os.mkdir("models_auto_compress_online_data")
# save model checkpoint to models directory using current timestamp and date
torch.save(model.state_dict(), f"models_auto_compress_online_data/{time.strftime('%Y%m%d-%H%M%S')}.pth")


In [19]:
# load latest model checkpoint among all the saved models
latest_model = torch.load(max(glob.glob('models_auto_compress_online_data/*.pth'), key=os.path.getctime))
# load the model with the latest checkpoint
model.load_state_dict(latest_model)

<All keys matched successfully>

In [20]:
# Generate captions for the test dataset
generated_captions_custom_model = []
generated_captions_custom_model_pre_compression = []
# Iterate over the dataset and generate captions
for data in dataset:
    image = data['image']
    # use autoencoder to encode and decode the caption
    caption = generate_caption(image)
    generated_captions_custom_model_pre_compression.append(caption)
    caption = tokenizer(caption, padding="max_length", truncation=True, max_length=128, return_tensors="pt")
    caption = caption['input_ids']
    caption = caption.to(device)
    print (caption)
    compressed_caption = autoencoder.encode(caption)
    compressed_caption = compressed_caption.to(device)
    # print (compressed_caption)
    reconstructed_caption = autoencoder.decode(compressed_caption)
    reconstructed_caption = reconstructed_caption.to(device)
    reconstructed_caption = reconstructed_caption.cpu()
    reconstructed_caption = reconstructed_caption.detach().numpy()
    reconstructed_caption = np.argmax(reconstructed_caption, axis=2)
    print (reconstructed_caption)
    reconstructed_caption = tokenizer.decode(reconstructed_caption[0], skip_special_tokens=True)
    generated_captions_custom_model.append(reconstructed_caption)

tensor([[   64,  4077,  7779, 19584,  1097,   319,   257,  4675,  1306,   284,
           257, 13990,   220, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 5

In [21]:
# Encode compressed dictionary word using manual huffman encoding

In [22]:
# Replace compressed_dict words occurring in the generated_captions_custom_model with their corresponding huffman encoding

In [23]:
# compare encoded generated_captions_custom_model + huffman encoding dictionary information with the original generated_captions to calculate compression ratio

In [24]:
# print generated_captions and generated_captions_custom_model elementwise to compare the results
for i in range(len(generated_captions)-1):
    print (generated_captions[i], generated_captions_custom_model_pre_compression[i], generated_captions_custom_model[i])

a green truck parked next to a curb  a green truck parked car on a street next to a fence   with with with with with with truck with to with with with with truck with with with truck with truck to with with to truck with with with with with with with with with truck witha with with with truck with truck with truck with to with with with with with with with with with with truck with with with with truck with with truck truck with with with with truck with with truck truck with with with with with with with with truck truck truck with with with with with with with truck with with with with with to truck with with truck with with truck with with witha with truck with with with with toa with with with with with with with with
a man is walking down the street with a skateboard  a man walking down a street with a car   with with with with with with with with with with with with with with witha with with with with with with the with with with with with playing with with with with with with wi

: 

improve autoencoder architecture, use better semantic meaning preserving metric instead of simply cross entropy