Introducing manual compression of image captions on stale (offline) data

In [17]:
from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer, AutoProcessor, LlavaForConditionalGeneration
from transformers import AdamW
from transformers import BitsAndBytesConfig
from datasets import load_dataset
import torch
import torch.optim as optim
import torch.quantization
from torch.quantization import quantize_dynamic
from collections import Counter
import fiftyone
import torch.nn as nn
from tqdm.auto import tqdm
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as T
import torch.nn.functional as F
import numpy as np
from PIL import Image
import os
import time
import glob

In [18]:
# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Quantization config
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16
)
# Load the pre-trained model and its components
model = LlavaForConditionalGeneration.from_pretrained("bczhou/tiny-llava-v1-hf")
# # Quantize the model for CPU (dynamic quantization as an example)
# model = quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
# model.to(device)
processor = AutoProcessor.from_pretrained("bczhou/tiny-llava-v1-hf")
# feature_extractor = ViTFeatureExtractor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
# tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]Error while downloading from https://cdn-lfs-us-1.huggingface.co/repos/7f/40/7f40566b133188de11a64c8f928ec8db7ca3f4c63f5999a58e99c448903980e2/5757b71668b0f69d61185ff018ada04d490dac417f998bbd6455ec8a84fa1c72?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27model-00001-of-00002.safetensors%3B+filename%3D%22model-00001-of-00002.safetensors%22%3B&Expires=1708102769&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcwODEwMjc2OX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy11cy0xLmh1Z2dpbmdmYWNlLmNvL3JlcG9zLzdmLzQwLzdmNDA1NjZiMTMzMTg4ZGUxMWE2NGM4ZjkyOGVjOGRiN2NhM2Y0YzYzZjU5OTlhNThlOTljNDQ4OTAzOTgwZTIvNTc1N2I3MTY2OGIwZjY5ZDYxMTg1ZmYwMThhZGEwNGQ0OTBkYWM0MTdmOTk4YmJkNjQ1NWVjOGE4NGZhMWM3Mj9yZXNwb25zZS1jb250ZW50LWRpc3Bvc2l0aW9uPSoifV19&Signature=iAJ32D2aVhHZbQ%7EPGB4bafMAgyFKwuo5pgmulCCrGSVXBM0f-JCdfwcYQUW8bfuehCX%7EGb5jrT6TfXLwTNHp51f-0dgB2iDubnV8VADDiZ90kpkZXst2Lja3sCjAbZDYPhqh%7ErNSICseE%

ConnectionError: (MaxRetryError("HTTPSConnectionPool(host='cdn-lfs-us-1.huggingface.co', port=443): Max retries exceeded with url: /repos/7f/40/7f40566b133188de11a64c8f928ec8db7ca3f4c63f5999a58e99c448903980e2/5757b71668b0f69d61185ff018ada04d490dac417f998bbd6455ec8a84fa1c72?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27model-00001-of-00002.safetensors%3B+filename%3D%22model-00001-of-00002.safetensors%22%3B&Expires=1708102769&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcwODEwMjc2OX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy11cy0xLmh1Z2dpbmdmYWNlLmNvL3JlcG9zLzdmLzQwLzdmNDA1NjZiMTMzMTg4ZGUxMWE2NGM4ZjkyOGVjOGRiN2NhM2Y0YzYzZjU5OTlhNThlOTljNDQ4OTAzOTgwZTIvNTc1N2I3MTY2OGIwZjY5ZDYxMTg1ZmYwMThhZGEwNGQ0OTBkYWM0MTdmOTk4YmJkNjQ1NWVjOGE4NGZhMWM3Mj9yZXNwb25zZS1jb250ZW50LWRpc3Bvc2l0aW9uPSoifV19&Signature=iAJ32D2aVhHZbQ~PGB4bafMAgyFKwuo5pgmulCCrGSVXBM0f-JCdfwcYQUW8bfuehCX~Gb5jrT6TfXLwTNHp51f-0dgB2iDubnV8VADDiZ90kpkZXst2Lja3sCjAbZDYPhqh~rNSICseE~kYpZ3n911Y4ZrkSssuTY66ML8rWp9S93FmjD2UMYI7dODPKLCg1GYUEiylVtL11DOfAB70MZvOD6nMYkqepd1wvNtXtGdLdm~Q-~z3D~2Uin514XPJTXkPuhAW6xZjuwgMihFq~NBo0NWN~wG68F2ikl8Ovc8C~Gzp9aOR-lg~-~crK1HSUxhvloixodNtwhITfKazgQ__&Key-Pair-Id=KCD77M1F0VK2B (Caused by NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x2dcd962b0>: Failed to establish a new connection: [Errno 8] nodename nor servname provided, or not known'))"), '(Request ID: 567aa697-becc-4526-9798-d0595aa728fd)')

In [3]:
# Load a dataset (for example, a subset of the COCO dataset)
# TODO: Potential datasets with repititive nature that can be used: MS COCO, Flickr30k, Visual Genome, SBU Captions 
# load small part of the coco dataset from all the .jpg images in datasets/mscoco/test2015
dataset = load_dataset("datasets/mscoco/test2015/", split="test[:1]")

Resolving data files: 100%|██████████| 81434/81434 [00:00<00:00, 1152607.01it/s]


In [4]:
def generate_caption(image, processor, max_length=128):
    prompt = "<image>\nUSER: What's the content of the image?\nASSISTANT:"
    image = Image.open(image)
    inputs = processor(text=prompt, images=image, return_tensors="pt")
    # Generate
    # TODO: Vary generation parameters to get different results
    generate_ids = model.generate(**inputs, max_length=max_length, num_return_sequences=1, temperature=0.9, top_k=50, top_p=0.95, 
                                  do_sample=True,
                                #   pad_token_id=processor.tokenizer.pad_token_id, 
                                #   eos_token_id=processor.tokenizer.eos_token_id, bos_token_id=processor.tokenizer.bos_token_id, 
                                #   use_cache=True, num_beams=5, length_penalty=0.8, no_repeat_ngram_size=3, early_stopping=True, 
                                #   num_beam_groups=3, diversity_penalty=0.5
                                  )
    outputs = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    print ("output is: ", outputs)
    return outputs

In [5]:
# Iterate over the dataset and generate captions
max_length = 128
generated_captions = []

for data in dataset:
    image = data['image']
    caption = generate_caption(image, processor, max_length=max_length)
    generated_captions.append(caption)

We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.


In [6]:
# Define the Transformer-based Autoencoder using PyTorch
class TransformerAutoencoder(nn.Module):
    def __init__(self, input_dim, model_dim, num_layers, num_heads, ff_dim, dropout=0.1):
        super(TransformerAutoencoder, self).__init__()
        # Encoder
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=model_dim, nhead=num_heads, dim_feedforward=ff_dim, dropout=dropout)
        self.encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)
        # Decoder
        self.decoder_layer = nn.TransformerDecoderLayer(d_model=model_dim, nhead=num_heads, dim_feedforward=ff_dim, dropout=dropout)
        self.decoder = nn.TransformerDecoder(self.decoder_layer, num_layers=num_layers)
        # Project back to token space
        self.output_layer = nn.Linear(model_dim, input_dim)

    def encode(self, src, src_mask=None, src_key_padding_mask=None):
        return self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
    
    def decode(self, memory, tgt, tgt_mask=None, tgt_key_padding_mask=None):
        return self.decoder(tgt, memory, tgt_mask=tgt_mask, tgt_key_padding_mask=tgt_key_padding_mask)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        # Assume src is already tokenized and prepared as model input
        memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
        output = self.decoder(memory, memory, tgt_mask=src_mask, tgt_key_padding_mask=src_key_padding_mask)
        output = self.output_layer(output)
        return output

In [7]:
class CaptionDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, captions):
        self.encodings = encodings
        self.captions = captions

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.captions[idx])
        return item

    def __len__(self):
        return len(self.captions)

In [8]:
# Assuming `dataset` is your dataset containing images and captions
images = [data['image'] for data in dataset]
captions = generated_captions

# Process images and captions
inputs = feature_extractor(images=images, return_tensors="pt")
outputs = tokenizer(captions, padding="max_length", truncation=True, max_length=128, return_tensors="pt")

# Assuming 'captions' is a tensor of tokenized captions generated by VLM
vocab_size = tokenizer.vocab_size
embedding_dim = 50257
hidden_dim = 512
max_seq_length = 128
autoencoder = CaptionAutoencoder(vocab_size, embedding_dim, hidden_dim, max_seq_length)
autoencoder_output = autoencoder(outputs["input_ids"])

# Create dataset and dataloader
train_dataset = CaptionDataset(inputs, outputs["input_ids"])
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

In [9]:
class LoRALayer(nn.Module):
    def __init__(self, original_weight, rank):
        super(LoRALayer, self).__init__()
        self.original_weight = original_weight
        self.rank = rank
        self.U = nn.Parameter(torch.Tensor(self.original_weight.size(0), self.rank))
        self.V = nn.Parameter(torch.Tensor(self.rank, self.original_weight.size(1)))
        nn.init.xavier_uniform_(self.U)
        nn.init.xavier_uniform_(self.V)

    def forward(self):
        return self.original_weight + self.U @ self.V

In [10]:
# Modify the first attention layer of the encoder
# TODO: Try modifying other layers as well and check the results
lora_layers = []

with torch.no_grad():
    original_weight = model.encoder.encoder.layer[0].attention.output.dense.weight
    lora_layer = LoRALayer(original_weight, rank=10).forward()  # Choose an appropriate rank
    # assign the new layer to the model
    model.encoder.encoder.layer[0].attention.output.dense.weight.copy_(lora_layer)
    # add the layer of the model to the list of LoRA layers
    lora_layers.append(model.encoder.encoder.layer[0].attention.output.dense)

In [11]:
def is_lora_param(param, lora_layer):
    # check if the parameter is part of the LoRA layer
    print (lora_layer.parameters())
    print ("nuj")
    print (param)
    return param in lora_layer.parameters()

In [19]:
def custom_loss(outputs, batch, lora_layers, autoencoder, standard_lambda_val = 1, lora_lambda_val = 0.01, compression_lambda_val = 0.01):
    # Standard captioning loss
    standard_loss = outputs.loss

    # Autoencoder compression reward
    captions = batch['labels']
    compressed_captions = autoencoder.encode(captions)
    # Measure the sparsity of the compressed representation (e.g., using L1 norm) # TODO: Try other measures
    compression_reward = torch.norm(compressed_captions, p=1)
    # Adjust the reward: lower norm (more sparse) should lead to lower loss (higher reward)
    compression_loss = -compression_reward

    # Optionally, add a term for LoRA regularization if needed
    lora_regularization = 0
    # for param in model.parameters():
    #     for lora_layer in lora_layers:
    #         if is_lora_param(param, lora_layer):
    #             lora_regularization += torch.norm(param)
    return standard_lambda_val* standard_loss + compression_lambda_val * compression_loss + lora_lambda_val * lora_regularization

In [24]:
ae_criterion1 = nn.CrossEntropyLoss()

def ae_criterion2 (reconstructed_caption, original_caption, end_of_text_token_id):

    loss= 0

    for i in range(len(original_caption)):

        # remove all end of text tokens from the right in original caption
        trim_index = 0
        for j in range(len(original_caption[i])-1, -1, -1):
            if original_caption[i][j] != end_of_text_token_id:
                trim_index = j
                break
        trim_index += 1

        # Trim the trailing spaces
        trimmed_original = original_caption[i][:trim_index]
        trimmed_reconstructed = reconstructed_caption[i][:trim_index]

        # Calculate the loss (assuming cross-entropy loss)
        loss += ae_criterion1(trimmed_reconstructed, trimmed_original)

    return loss

In [25]:
# Fine tuning using custom loss

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.train()

vlm_lr = 5e-5 # TODO: find a good learning rate
ae_lr = 1e-3 # TODO: find a good learning rate
num_epochs = 50 # TODO: find a good number of epochs

vlm_optimizer = AdamW([param for param in model.parameters() if param.requires_grad], lr=vlm_lr)
ae_optimizer = optim.Adam(autoencoder.parameters(), lr=ae_lr) 

for epoch in range(num_epochs):
    loop = tqdm(train_loader, leave=True)
    for batch in loop:
        # Move batch to device
        batch = {k: v.to(device) for k, v in batch.items()}

        # Fine tune VLM with custom loss
        # Forward pass
        model.zero_grad()
        outputs = model(**batch)
        vlm_loss = custom_loss(outputs, batch, lora_layers, autoencoder)
        # Backward pass and optimization
        vlm_optimizer.zero_grad()
        vlm_loss.backward()
        vlm_optimizer.step()

        # Train the autoencoder
        autoencoder.zero_grad()
        captions = batch['labels']
        compressed_captions = autoencoder.encode(captions)
        reconstructed_captions = autoencoder.decode(compressed_captions)
        # reconstructed_flat = reconstructed_captions.view(-1, reconstructed_captions.size(-1))
        # captions_flat = captions.view(-1)
        end_of_text_token_id = tokenizer.encode('<|endoftext|>')[0]
        ae_loss = ae_criterion2(reconstructed_captions, captions, end_of_text_token_id)
        ae_loss.backward()
        ae_optimizer.step()

        # TODO: change loss as combination of vlm_loss and ae_loss instead of individual losses

        # Update progress bar
        loop.set_description(f"Epoch {epoch}")
        loop.set_postfix(vlm_loss=vlm_loss.item(), ae_loss=ae_loss.item())

  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
  item['labels'] = torch.tensor(self.captions[idx])
Epoch 0: 100%|██████████| 1/1 [00:35<00:00, 35.97s/it, ae_loss=101, vlm_loss=-48.3]
Epoch 1: 100%|██████████| 1/1 [00:44<00:00, 44.13s/it, ae_loss=90.9, vlm_loss=-50]
Epoch 2: 100%|██████████| 1/1 [00:39<00:00, 39.88s/it, ae_loss=79.5, vlm_loss=-50.4]
Epoch 3: 100%|██████████| 1/1 [00:44<00:00, 44.66s/it, ae_loss=69.3, vlm_loss=-50.8]
Epoch 4: 100%|██████████| 1/1 [00:43<00:00, 43.62s/it, ae_loss=60.5, vlm_loss=-50.9]
Epoch 5: 100%|██████████| 1/1 [00:42<00:00, 42.38s/it, ae_loss=52.9, vlm_loss=-51]
Epoch 6: 100%|██████████| 1/1 [00:41<00:00, 41.86s/it, ae_loss=46.9, vlm_loss=-51.1]
Epoch 7: 100%|██████████| 1/1 [00:41<00:00, 41.77s/it, ae_loss=42.1, vlm_loss=-51.1]
Epoch 8: 100%|██████████| 1/1 [00:42<00:00, 42.63s/it, ae_loss=38.4, vlm_loss=-51.1]
Epoch 9: 100%|██████████| 1/1 [00:45<00:00, 45.54s/it, ae_loss=35.9, vlm_loss=-51.1]
Epoch 10: 100%|██████████

In [26]:
# create directory to save the model if it doesn't exist
if not os.path.exists("models_auto_compress_online_data"):
    os.mkdir("models_auto_compress_online_data")
# save model checkpoint to models directory using current timestamp and date
torch.save(model.state_dict(), f"models_auto_compress_online_data/{time.strftime('%Y%m%d-%H%M%S')}.pth")


In [27]:
# load latest model checkpoint among all the saved models
latest_model = torch.load(max(glob.glob('models_auto_compress_online_data/*.pth'), key=os.path.getctime))
# load the model with the latest checkpoint
model.load_state_dict(latest_model)

<All keys matched successfully>

In [32]:
# Generate captions for the test dataset
generated_captions_custom_model = []
generated_captions_custom_model_pre_compression = []
# Iterate over the dataset and generate captions
for data in dataset:
    image = data['image']
    # use autoencoder to encode and decode the caption
    caption = generate_caption(image)
    generated_captions_custom_model_pre_compression.append(caption)
    caption = tokenizer(caption, padding="max_length", truncation=True, max_length=128, return_tensors="pt")
    caption = caption['input_ids']
    caption = caption.to(device)
    print (caption)
    compressed_caption = autoencoder.encode(caption)
    compressed_caption = compressed_caption.to(device)
    # print (compressed_caption)
    reconstructed_caption = autoencoder.decode(compressed_caption)
    reconstructed_caption = reconstructed_caption.to(device)
    reconstructed_caption = reconstructed_caption.cpu()
    reconstructed_caption = reconstructed_caption.detach().numpy()
    reconstructed_caption = np.argmax(reconstructed_caption, axis=2)
    print (reconstructed_caption)
    reconstructed_caption = tokenizer.decode(reconstructed_caption[0], skip_special_tokens=True)
    generated_captions_custom_model.append(reconstructed_caption)

tensor([[   64,  4077,  7779, 19584,  1306,   284,   257, 20799,   220, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 5

In [None]:
# Encode compressed dictionary word using manual huffman encoding

In [None]:
# Replace compressed_dict words occurring in the generated_captions_custom_model with their corresponding huffman encoding

In [None]:
# compare encoded generated_captions_custom_model + huffman encoding dictionary information with the original generated_captions to calculate compression ratio

In [33]:
# print generated_captions and generated_captions_custom_model elementwise to compare the results
for i in range(len(generated_captions)-1):
    print (generated_captions[i], generated_captions_custom_model_pre_compression[i], generated_captions_custom_model[i])

a green truck parked next to a curb  a green truck parked next to a curb  a green green parked a a a                                                                                                                         
a man is walking down the street with a skateboard  a man is walking down the street with a skateboard  a man is is a a a a a a                                                                                                                      
a baseball player swinging a bat at a ball  a baseball player swinging a bat at a ball  a man player a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a a
a cow is standing in a field of grass  a cow is standing in a field of grass  a man is standing a a a a                                                                              

improve autoencoder architecture, use better semantic meaning preserving metric instead of simply cross entropy