Introducing manual compression of image captions on stale (offline) data

In [37]:
from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer
from transformers import AdamW
from datasets import load_dataset
import torch
from collections import Counter
import fiftyone
import torch.nn as nn
from tqdm.auto import tqdm
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as T
import torch.nn.functional as F
import numpy as np
import os
import time
import glob

# uncommon features  - events of interest
# loss less compression -  sudden more bits indicates anomaly can be flagged, alerts when anomaly detected - may shift to lossy video streaming
# lossy compression of noisy data varying distortion rate - accuracy is increasing
# video to video lossy reconstruction possibility
# image frame to image frame on a need basis - human satisfaction metric, GPT based comparison, RLHF based comparison

In [38]:
# Load the pre-trained model and its components
model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
feature_extractor = ViTFeatureExtractor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")



In [39]:
# Load a dataset (for example, a subset of the COCO dataset)
# TODO: Potential datasets with repititive nature that can be used: MS COCO, Flickr30k, Visual Genome, SBU Captions 

# load small part of the coco dataset from all the .jpg images in datasets/mscoco/test2015
dataset = load_dataset("datasets/mscoco/test2015/", split="test[:2]")

Resolving data files: 100%|██████████| 81434/81434 [00:00<00:00, 486913.97it/s]


In [54]:
def generate_caption_with_logits(image, max_length=128):
    # Prepare the inputs
    inputs = feature_extractor(images=image, return_tensors="pt")
    pixel_values = inputs.pixel_values

    model.eval()
    with torch.no_grad():
        # Perform a forward pass to get the logits
        encoder_outputs = model.encoder(pixel_values=pixel_values)
        encoder_hidden_states = encoder_outputs.last_hidden_state
        
        # Prepare decoder input_ids. Typically, you start with the start-of-sentence token
        decoder_input_ids = torch.tensor([tokenizer.bos_token_id]).unsqueeze(0).to(encoder_hidden_states.device)
        decoder_attention_mask = torch.ones_like(decoder_input_ids)
        
        # Initialize an empty tensor for logits (for simplicity, accumulating logits for each step)
        logits_list = []
        
        for i in range(max_length):
            decoder_outputs = model.decoder(input_ids=decoder_input_ids,
                                            attention_mask=decoder_attention_mask,
                                            encoder_hidden_states=encoder_hidden_states)
            logits = decoder_outputs.logits[:, -1, :]  # Get the logits for the last token generated
            logits_list.append(logits)
            
            predicted_id = torch.argmax(logits, dim=-1).unsqueeze(-1)
            # Check if EOS token is generated
            if predicted_id[0, 0] == tokenizer.eos_token_id:
                print ("EOS has been generated")
            
            # Append predicted token ID to decoder_input_ids for generating next token
            decoder_input_ids = torch.cat([decoder_input_ids, predicted_id], dim=-1)
            decoder_attention_mask = torch.cat([decoder_attention_mask, torch.ones_like(predicted_id)], dim=-1)
            
        # Concatenate logits from each step to get the final logits tensor
        # make all elements of logits_list 3D by adding a dimension in the middle
        logits_list = [logits.unsqueeze(1) for logits in logits_list]
        logits = torch.cat(logits_list, dim=1)
        # Decode the generated token IDs to get the caption
        predicted_ids = torch.argmax(logits, dim=-1)
        caption = tokenizer.decode(predicted_ids[0], skip_special_tokens=True)
        
    return logits, predicted_ids, caption

# Example usage
# image: A PIL image or a tensor representing your input image
# logits, predicted_ids, caption = generate_caption_with_logits(image, model, feature_extractor, tokenizer)


In [55]:
# Iterate over the dataset and generate captions
generated_captions = []
generated_logits = []
generated_predicted_ids = []

for data in dataset:
    image = data['image']
    logits, predicted_ids, caption = generate_caption_with_logits(image)
    generated_captions.append(caption)
    generated_logits.append(logits)
    generated_predicted_ids.append(predicted_ids)

# concatenate generated logits along first dimension to make 3D tensor
generated_logits = torch.cat(generated_logits, dim=0)
print (generated_logits.shape)

EOS has been generated
EOS has been generated
EOS has been generated
EOS has been generated
EOS has been generated
EOS has been generated
EOS has been generated
EOS has been generated
EOS has been generated
EOS has been generated
EOS has been generated
EOS has been generated
EOS has been generated
EOS has been generated
EOS has been generated
EOS has been generated
EOS has been generated
EOS has been generated
EOS has been generated
EOS has been generated
EOS has been generated
EOS has been generated
EOS has been generated
EOS has been generated
torch.Size([2, 128, 50257])


In [56]:
generated_captions

['a green truck parked next to a curb a green truck parked next to a fence a green truck parked next to a fence a parking meter on a street a green truck parked next to a fence a green truck parked next to a fence a green truck parked next to a fence a green truck parked next to a fence a green truck parked next to a fence a green truck parked next to a fence a green truck parked next to a fence a green truck parked next to a fence a green truck parked next to a fence ',
 'a man is walking down the street with a skateboard a man is crossing the street in front of a traffic light a man is crossing the street with a bicycle a man is crossing the street with a car a man is crossing the street with a car a man is crossing the street with a bike a man is crossing the street with a car a man is crossing the street with a car a man is crossing the street with a car a man is crossing the street with a car a man is crossing the street with a car a man']

In [6]:
def update_encoding_dict(captions, encoding_dict):
    for caption in captions:
        words = caption.split()
        encoding_dict.update(words) # purpose of update is to add the words to the dictionary if they don't exist
    return encoding_dict

In [7]:
encoding_dict = Counter() # Counter is a subclass of dictionary for counting hashable objects
threshold = 0 # threshold for word frequency # TODO: find a good threshold

update_encoding_dict(generated_captions, encoding_dict)

print (encoding_dict)

# Optionally, create a more compressed form based on frequency
compressed_dict = {word: idx for idx, (word, freq) in enumerate(encoding_dict.items()) if freq > threshold}

# Create the dictionary of entropy values from encoding_dict
entropy_dict = {word: -np.log(encoding_dict[word] / sum(encoding_dict.values())) 
                for word in encoding_dict}

print (entropy_dict)
# print 1/elem for elem in encoding_dict.values()
reciprocal_dict = {word: 1/(encoding_dict[word]+1) for word in encoding_dict}
print (reciprocal_dict)

Counter({'a': 49, 'green': 12, 'truck': 12, 'parked': 12, 'next': 12, 'to': 12, 'street': 12, 'man': 12, 'fence': 11, 'is': 11, 'the': 11, 'with': 10, 'crossing': 10, 'car': 7, 'curb': 1, 'parking': 1, 'meter': 1, 'on': 1, 'walking': 1, 'down': 1, 'skateboard': 1, 'in': 1, 'front': 1, 'of': 1, 'traffic': 1, 'light': 1, 'bicycle': 1, 'bike': 1})
{'a': 1.4408984951547426, 'green': 2.847812143477369, 'truck': 2.847812143477369, 'parked': 2.847812143477369, 'next': 2.847812143477369, 'to': 2.847812143477369, 'curb': 5.332718793265369, 'fence': 2.9348235204669986, 'parking': 5.332718793265369, 'meter': 5.332718793265369, 'on': 5.332718793265369, 'street': 2.847812143477369, 'man': 2.847812143477369, 'is': 2.9348235204669986, 'walking': 5.332718793265369, 'down': 5.332718793265369, 'the': 2.9348235204669986, 'with': 3.0301337002713233, 'skateboard': 5.332718793265369, 'crossing': 3.0301337002713233, 'in': 5.332718793265369, 'front': 5.332718793265369, 'of': 5.332718793265369, 'traffic': 5.33

In [8]:
class CaptionDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, captions):
        self.encodings = encodings
        self.captions = captions

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.captions[idx])
        return item

    def __len__(self):
        return len(self.captions)

In [9]:
# Assuming `dataset` is your dataset containing images and captions
images = [data['image'] for data in dataset]
caption_ids = generated_predicted_ids

# Process images and captions
inputs = feature_extractor(images=images, return_tensors="pt")

# Create dataset and dataloader
train_dataset = CaptionDataset(inputs, caption_ids)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

In [10]:
class LoRALayer(nn.Module):
    def __init__(self, original_weight, rank):
        super(LoRALayer, self).__init__()
        self.original_weight = original_weight
        self.rank = rank
        self.U = nn.Parameter(torch.Tensor(self.original_weight.size(0), self.rank))
        self.V = nn.Parameter(torch.Tensor(self.rank, self.original_weight.size(1)))
        nn.init.xavier_uniform_(self.U)
        nn.init.xavier_uniform_(self.V)

    def forward(self):
        return self.original_weight + self.U @ self.V

In [11]:
# Modify the first attention layer of the encoder
# TODO: Try modifying other layers as well and check the results
lora_layers = []

with torch.no_grad():
    original_weight = model.encoder.encoder.layer[0].attention.output.dense.weight
    lora_layer = LoRALayer(original_weight, rank=10).forward()  # Choose an appropriate rank
    # assign the new layer to the model
    model.encoder.encoder.layer[0].attention.output.dense.weight.copy_(lora_layer)
    # add the layer of the model to the list of LoRA layers
    lora_layers.append(model.encoder.encoder.layer[0].attention.output.dense)

In [12]:
def is_lora_param(param, lora_layer):
    # check if the parameter is part of the LoRA layer
    print (lora_layer.parameters())
    print ("nuj")
    print (param)
    return param in lora_layer.parameters()


In [13]:
# add two extra dimensions to generated_logits
generated_probs = F.softmax(generated_logits, dim=-1)
generated_probs_expanded = generated_probs.unsqueeze(0).unsqueeze(0)

In [14]:
def calculate_entropy_elbo_difference (prob_differences, D):
    sigma = 0.01
    # reduce prob_differences to 4D from 5D by taking norm square along the last dimension
    prob_differences = torch.norm(prob_differences, dim=-1)
    print (prob_differences.shape)
    # do elementwise for prob_differences: suqare
    prob_differences = prob_differences**2
    # take sum of all elements of prob_differences, hence scalar, then divide by 2*sigma^2*D
    return torch.sum(prob_differences) / (2*sigma**2*D)

In [15]:
def calculate_entropy_elbo_cross_entropy (prob_differences, D):
    pass

In [16]:
def custom_loss(outputs, batch, encoding_dict, lora_layers, lambda_val=1, lora_lambda_val = 0.01):
    # Standard captioning loss
    standard_loss = outputs.loss

    # Additional compression loss
    compression_loss = 0
    # add two dimensions to output probs at 2 and 3
    outputs_probs = F.softmax(outputs.logits, dim=-1)
    outputs_probs_expanded = outputs_probs.squeeze(1).unsqueeze(2).unsqueeze(3)
    prob_differences = generated_probs_expanded - outputs_probs_expanded
    print ("prob_differences.shape = ", outputs_probs.shape, generated_probs_expanded.shape, outputs_probs_expanded.shape, prob_differences.shape)
    # calculate the compression loss
    # find number of elements in generated_predicted_logits
    D = generated_probs.numel()
    compression_loss = lambda_val* calculate_entropy_elbo_difference (prob_differences, D)
     

    # Optionally, add a term for LoRA regularization if needed
    lora_regularization = 0
    # for param in model.parameters():
    #     for lora_layer in lora_layers:
    #         if is_lora_param(param, lora_layer):
    #             lora_regularization += torch.norm(param)
    print (standard_loss, compression_loss)

    return standard_loss + compression_loss + lora_lambda_val * lora_regularization

In [17]:
# Fine tuning using custom loss

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.train()

lr = 1e-4
num_epochs = 30

optimizer = AdamW([param for param in model.parameters() if param.requires_grad], lr=lr)

for epoch in range(num_epochs):
    loop = tqdm(train_loader, leave=True)
    for batch in loop:
        # Move batch to device
        batch = {k: v.to(device) for k, v in batch.items()}

        # Forward pass
        outputs = model(**batch)
        loss = custom_loss(outputs, batch, encoding_dict, lora_layers)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Update progress bar
        loop.set_description(f"Epoch {epoch}")
        loop.set_postfix(loss=loss.item())

  0%|          | 0/1 [00:00<?, ?it/s]

  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
  item['labels'] = torch.tensor(self.captions[idx])
We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.


prob_differences.shape =  torch.Size([2, 1, 128, 50257]) torch.Size([1, 1, 2, 128, 50257]) torch.Size([2, 128, 1, 1, 50257]) torch.Size([2, 128, 2, 128, 50257])
torch.Size([2, 128, 2, 128])
tensor(5.6478, grad_fn=<NllLossBackward0>) tensor(13.5347, grad_fn=<MulBackward0>)


Epoch 0: 100%|██████████| 1/1 [02:21<00:00, 141.98s/it, loss=19.2]
  0%|          | 0/1 [00:00<?, ?it/s]

prob_differences.shape =  torch.Size([2, 1, 128, 50257]) torch.Size([1, 1, 2, 128, 50257]) torch.Size([2, 128, 1, 1, 50257]) torch.Size([2, 128, 2, 128, 50257])
torch.Size([2, 128, 2, 128])
tensor(6.8851, grad_fn=<NllLossBackward0>) tensor(12.0999, grad_fn=<MulBackward0>)


Epoch 1: 100%|██████████| 1/1 [02:11<00:00, 131.35s/it, loss=19]
  0%|          | 0/1 [00:00<?, ?it/s]

prob_differences.shape =  torch.Size([2, 1, 128, 50257]) torch.Size([1, 1, 2, 128, 50257]) torch.Size([2, 128, 1, 1, 50257]) torch.Size([2, 128, 2, 128, 50257])
torch.Size([2, 128, 2, 128])
tensor(6.8987, grad_fn=<NllLossBackward0>) tensor(11.3557, grad_fn=<MulBackward0>)


Epoch 2: 100%|██████████| 1/1 [02:04<00:00, 124.16s/it, loss=18.3]
  0%|          | 0/1 [00:00<?, ?it/s]

prob_differences.shape =  torch.Size([2, 1, 128, 50257]) torch.Size([1, 1, 2, 128, 50257]) torch.Size([2, 128, 1, 1, 50257]) torch.Size([2, 128, 2, 128, 50257])
torch.Size([2, 128, 2, 128])
tensor(6.3236, grad_fn=<NllLossBackward0>) tensor(11.3238, grad_fn=<MulBackward0>)


Epoch 3: 100%|██████████| 1/1 [01:59<00:00, 119.10s/it, loss=17.6]
  0%|          | 0/1 [00:00<?, ?it/s]

prob_differences.shape =  torch.Size([2, 1, 128, 50257]) torch.Size([1, 1, 2, 128, 50257]) torch.Size([2, 128, 1, 1, 50257]) torch.Size([2, 128, 2, 128, 50257])
torch.Size([2, 128, 2, 128])
tensor(5.9397, grad_fn=<NllLossBackward0>) tensor(11.2387, grad_fn=<MulBackward0>)


Epoch 4: 100%|██████████| 1/1 [02:12<00:00, 132.64s/it, loss=17.2]
  0%|          | 0/1 [00:00<?, ?it/s]

prob_differences.shape =  torch.Size([2, 1, 128, 50257]) torch.Size([1, 1, 2, 128, 50257]) torch.Size([2, 128, 1, 1, 50257]) torch.Size([2, 128, 2, 128, 50257])
torch.Size([2, 128, 2, 128])
tensor(5.5605, grad_fn=<NllLossBackward0>) tensor(11.1801, grad_fn=<MulBackward0>)


Epoch 5: 100%|██████████| 1/1 [02:48<00:00, 168.98s/it, loss=16.7]
  0%|          | 0/1 [00:00<?, ?it/s]

prob_differences.shape =  torch.Size([2, 1, 128, 50257]) torch.Size([1, 1, 2, 128, 50257]) torch.Size([2, 128, 1, 1, 50257]) torch.Size([2, 128, 2, 128, 50257])
torch.Size([2, 128, 2, 128])
tensor(5.1987, grad_fn=<NllLossBackward0>) tensor(11.1533, grad_fn=<MulBackward0>)


Epoch 6: 100%|██████████| 1/1 [02:05<00:00, 125.18s/it, loss=16.4]
  0%|          | 0/1 [00:00<?, ?it/s]

prob_differences.shape =  torch.Size([2, 1, 128, 50257]) torch.Size([1, 1, 2, 128, 50257]) torch.Size([2, 128, 1, 1, 50257]) torch.Size([2, 128, 2, 128, 50257])
torch.Size([2, 128, 2, 128])
tensor(4.8477, grad_fn=<NllLossBackward0>) tensor(11.0438, grad_fn=<MulBackward0>)


Epoch 7: 100%|██████████| 1/1 [01:45<00:00, 105.74s/it, loss=15.9]
  0%|          | 0/1 [00:00<?, ?it/s]

prob_differences.shape =  torch.Size([2, 1, 128, 50257]) torch.Size([1, 1, 2, 128, 50257]) torch.Size([2, 128, 1, 1, 50257]) torch.Size([2, 128, 2, 128, 50257])
torch.Size([2, 128, 2, 128])
tensor(4.5807, grad_fn=<NllLossBackward0>) tensor(11.0147, grad_fn=<MulBackward0>)


Epoch 8: 100%|██████████| 1/1 [02:19<00:00, 139.08s/it, loss=15.6]
  0%|          | 0/1 [00:00<?, ?it/s]

prob_differences.shape =  torch.Size([2, 1, 128, 50257]) torch.Size([1, 1, 2, 128, 50257]) torch.Size([2, 128, 1, 1, 50257]) torch.Size([2, 128, 2, 128, 50257])
torch.Size([2, 128, 2, 128])
tensor(4.3275, grad_fn=<NllLossBackward0>) tensor(10.9365, grad_fn=<MulBackward0>)


Epoch 9: 100%|██████████| 1/1 [02:03<00:00, 123.94s/it, loss=15.3]
  0%|          | 0/1 [00:00<?, ?it/s]

prob_differences.shape =  torch.Size([2, 1, 128, 50257]) torch.Size([1, 1, 2, 128, 50257]) torch.Size([2, 128, 1, 1, 50257]) torch.Size([2, 128, 2, 128, 50257])
torch.Size([2, 128, 2, 128])
tensor(4.1134, grad_fn=<NllLossBackward0>) tensor(11.0169, grad_fn=<MulBackward0>)


Epoch 10: 100%|██████████| 1/1 [02:17<00:00, 137.83s/it, loss=15.1]
  0%|          | 0/1 [00:00<?, ?it/s]

prob_differences.shape =  torch.Size([2, 1, 128, 50257]) torch.Size([1, 1, 2, 128, 50257]) torch.Size([2, 128, 1, 1, 50257]) torch.Size([2, 128, 2, 128, 50257])
torch.Size([2, 128, 2, 128])
tensor(3.8899, grad_fn=<NllLossBackward0>) tensor(10.9574, grad_fn=<MulBackward0>)


Epoch 11: 100%|██████████| 1/1 [02:45<00:00, 165.31s/it, loss=14.8]
  0%|          | 0/1 [00:00<?, ?it/s]

prob_differences.shape =  torch.Size([2, 1, 128, 50257]) torch.Size([1, 1, 2, 128, 50257]) torch.Size([2, 128, 1, 1, 50257]) torch.Size([2, 128, 2, 128, 50257])
torch.Size([2, 128, 2, 128])
tensor(3.6850, grad_fn=<NllLossBackward0>) tensor(10.8558, grad_fn=<MulBackward0>)


Epoch 12: 100%|██████████| 1/1 [02:08<00:00, 128.93s/it, loss=14.5]
  0%|          | 0/1 [00:00<?, ?it/s]

prob_differences.shape =  torch.Size([2, 1, 128, 50257]) torch.Size([1, 1, 2, 128, 50257]) torch.Size([2, 128, 1, 1, 50257]) torch.Size([2, 128, 2, 128, 50257])
torch.Size([2, 128, 2, 128])
tensor(3.5318, grad_fn=<NllLossBackward0>) tensor(10.9320, grad_fn=<MulBackward0>)


Epoch 13: 100%|██████████| 1/1 [02:04<00:00, 124.18s/it, loss=14.5]
  0%|          | 0/1 [00:00<?, ?it/s]

prob_differences.shape =  torch.Size([2, 1, 128, 50257]) torch.Size([1, 1, 2, 128, 50257]) torch.Size([2, 128, 1, 1, 50257]) torch.Size([2, 128, 2, 128, 50257])
torch.Size([2, 128, 2, 128])
tensor(3.4499, grad_fn=<NllLossBackward0>) tensor(10.9252, grad_fn=<MulBackward0>)


Epoch 14: 100%|██████████| 1/1 [02:01<00:00, 121.36s/it, loss=14.4]
  0%|          | 0/1 [00:00<?, ?it/s]

prob_differences.shape =  torch.Size([2, 1, 128, 50257]) torch.Size([1, 1, 2, 128, 50257]) torch.Size([2, 128, 1, 1, 50257]) torch.Size([2, 128, 2, 128, 50257])
torch.Size([2, 128, 2, 128])
tensor(3.3459, grad_fn=<NllLossBackward0>) tensor(11.0252, grad_fn=<MulBackward0>)


Epoch 15: 100%|██████████| 1/1 [02:23<00:00, 143.34s/it, loss=14.4]
  0%|          | 0/1 [00:00<?, ?it/s]

prob_differences.shape =  torch.Size([2, 1, 128, 50257]) torch.Size([1, 1, 2, 128, 50257]) torch.Size([2, 128, 1, 1, 50257]) torch.Size([2, 128, 2, 128, 50257])
torch.Size([2, 128, 2, 128])
tensor(3.2740, grad_fn=<NllLossBackward0>) tensor(10.9399, grad_fn=<MulBackward0>)


Epoch 16: 100%|██████████| 1/1 [02:58<00:00, 178.81s/it, loss=14.2]
  0%|          | 0/1 [00:00<?, ?it/s]

prob_differences.shape =  torch.Size([2, 1, 128, 50257]) torch.Size([1, 1, 2, 128, 50257]) torch.Size([2, 128, 1, 1, 50257]) torch.Size([2, 128, 2, 128, 50257])
torch.Size([2, 128, 2, 128])
tensor(3.2406, grad_fn=<NllLossBackward0>) tensor(10.9488, grad_fn=<MulBackward0>)


Epoch 17: 100%|██████████| 1/1 [02:32<00:00, 152.88s/it, loss=14.2]
  0%|          | 0/1 [00:00<?, ?it/s]

prob_differences.shape =  torch.Size([2, 1, 128, 50257]) torch.Size([1, 1, 2, 128, 50257]) torch.Size([2, 128, 1, 1, 50257]) torch.Size([2, 128, 2, 128, 50257])
torch.Size([2, 128, 2, 128])
tensor(3.1527, grad_fn=<NllLossBackward0>) tensor(10.9427, grad_fn=<MulBackward0>)


Epoch 18: 100%|██████████| 1/1 [02:16<00:00, 136.06s/it, loss=14.1]
  0%|          | 0/1 [00:00<?, ?it/s]

prob_differences.shape =  torch.Size([2, 1, 128, 50257]) torch.Size([1, 1, 2, 128, 50257]) torch.Size([2, 128, 1, 1, 50257]) torch.Size([2, 128, 2, 128, 50257])
torch.Size([2, 128, 2, 128])
tensor(3.1560, grad_fn=<NllLossBackward0>) tensor(10.9131, grad_fn=<MulBackward0>)


Epoch 19: 100%|██████████| 1/1 [02:06<00:00, 126.11s/it, loss=14.1]
  0%|          | 0/1 [00:00<?, ?it/s]

prob_differences.shape =  torch.Size([2, 1, 128, 50257]) torch.Size([1, 1, 2, 128, 50257]) torch.Size([2, 128, 1, 1, 50257]) torch.Size([2, 128, 2, 128, 50257])
torch.Size([2, 128, 2, 128])
tensor(3.1316, grad_fn=<NllLossBackward0>) tensor(10.9292, grad_fn=<MulBackward0>)


Epoch 20: 100%|██████████| 1/1 [02:57<00:00, 177.76s/it, loss=14.1]
  0%|          | 0/1 [00:00<?, ?it/s]

prob_differences.shape =  torch.Size([2, 1, 128, 50257]) torch.Size([1, 1, 2, 128, 50257]) torch.Size([2, 128, 1, 1, 50257]) torch.Size([2, 128, 2, 128, 50257])
torch.Size([2, 128, 2, 128])
tensor(3.1251, grad_fn=<NllLossBackward0>) tensor(10.9089, grad_fn=<MulBackward0>)


Epoch 21: 100%|██████████| 1/1 [02:16<00:00, 136.54s/it, loss=14]
  0%|          | 0/1 [00:00<?, ?it/s]

prob_differences.shape =  torch.Size([2, 1, 128, 50257]) torch.Size([1, 1, 2, 128, 50257]) torch.Size([2, 128, 1, 1, 50257]) torch.Size([2, 128, 2, 128, 50257])
torch.Size([2, 128, 2, 128])
tensor(3.0652, grad_fn=<NllLossBackward0>) tensor(10.9100, grad_fn=<MulBackward0>)


Epoch 22: 100%|██████████| 1/1 [02:06<00:00, 126.33s/it, loss=14]
  0%|          | 0/1 [00:00<?, ?it/s]

prob_differences.shape =  torch.Size([2, 1, 128, 50257]) torch.Size([1, 1, 2, 128, 50257]) torch.Size([2, 128, 1, 1, 50257]) torch.Size([2, 128, 2, 128, 50257])
torch.Size([2, 128, 2, 128])
tensor(3.0368, grad_fn=<NllLossBackward0>) tensor(10.9073, grad_fn=<MulBackward0>)


Epoch 23: 100%|██████████| 1/1 [01:38<00:00, 98.22s/it, loss=13.9]
  0%|          | 0/1 [00:00<?, ?it/s]

prob_differences.shape =  torch.Size([2, 1, 128, 50257]) torch.Size([1, 1, 2, 128, 50257]) torch.Size([2, 128, 1, 1, 50257]) torch.Size([2, 128, 2, 128, 50257])
torch.Size([2, 128, 2, 128])
tensor(3.0146, grad_fn=<NllLossBackward0>) tensor(10.8928, grad_fn=<MulBackward0>)


Epoch 24: 100%|██████████| 1/1 [01:50<00:00, 110.25s/it, loss=13.9]
  0%|          | 0/1 [00:00<?, ?it/s]

prob_differences.shape =  torch.Size([2, 1, 128, 50257]) torch.Size([1, 1, 2, 128, 50257]) torch.Size([2, 128, 1, 1, 50257]) torch.Size([2, 128, 2, 128, 50257])
torch.Size([2, 128, 2, 128])
tensor(3.0214, grad_fn=<NllLossBackward0>) tensor(10.9012, grad_fn=<MulBackward0>)


Epoch 25: 100%|██████████| 1/1 [02:29<00:00, 149.22s/it, loss=13.9]
  0%|          | 0/1 [00:00<?, ?it/s]

prob_differences.shape =  torch.Size([2, 1, 128, 50257]) torch.Size([1, 1, 2, 128, 50257]) torch.Size([2, 128, 1, 1, 50257]) torch.Size([2, 128, 2, 128, 50257])
torch.Size([2, 128, 2, 128])
tensor(3.0015, grad_fn=<NllLossBackward0>) tensor(10.8813, grad_fn=<MulBackward0>)


Epoch 26: 100%|██████████| 1/1 [02:52<00:00, 172.23s/it, loss=13.9]
  0%|          | 0/1 [00:00<?, ?it/s]

prob_differences.shape =  torch.Size([2, 1, 128, 50257]) torch.Size([1, 1, 2, 128, 50257]) torch.Size([2, 128, 1, 1, 50257]) torch.Size([2, 128, 2, 128, 50257])
torch.Size([2, 128, 2, 128])
tensor(3.0193, grad_fn=<NllLossBackward0>) tensor(10.8859, grad_fn=<MulBackward0>)


Epoch 27: 100%|██████████| 1/1 [02:20<00:00, 140.95s/it, loss=13.9]
  0%|          | 0/1 [00:00<?, ?it/s]

prob_differences.shape =  torch.Size([2, 1, 128, 50257]) torch.Size([1, 1, 2, 128, 50257]) torch.Size([2, 128, 1, 1, 50257]) torch.Size([2, 128, 2, 128, 50257])
torch.Size([2, 128, 2, 128])
tensor(3.0348, grad_fn=<NllLossBackward0>) tensor(10.8798, grad_fn=<MulBackward0>)


Epoch 28: 100%|██████████| 1/1 [02:21<00:00, 141.93s/it, loss=13.9]
  0%|          | 0/1 [00:00<?, ?it/s]

prob_differences.shape =  torch.Size([2, 1, 128, 50257]) torch.Size([1, 1, 2, 128, 50257]) torch.Size([2, 128, 1, 1, 50257]) torch.Size([2, 128, 2, 128, 50257])
torch.Size([2, 128, 2, 128])
tensor(3.0162, grad_fn=<NllLossBackward0>) tensor(10.8406, grad_fn=<MulBackward0>)


Epoch 29: 100%|██████████| 1/1 [02:10<00:00, 130.32s/it, loss=13.9]


In [21]:
# create directory to save the model if it doesn't exist
if not os.path.exists("models"):
    os.mkdir("models")
# save model checkpoint to models directory using current timestamp and date
torch.save(model.state_dict(), f"models/{time.strftime('%Y%m%d-%H%M%S')}.pth")


In [22]:
# load latest model checkpoint among all the saved models
latest_model = torch.load(max(glob.glob('models/*.pth'), key=os.path.getctime))
# load the model with the latest checkpoint
model.load_state_dict(latest_model)

<All keys matched successfully>

In [23]:
# Generate captions for the test dataset
generated_captions_custom_model = []
# Iterate over the dataset and generate captions
for data in dataset:
    image = data['image']
    generated_logits, generated_predicted_ids, caption = generate_caption_with_logits(image)
    generated_captions_custom_model.append(caption)


In [None]:
# Encode compressed dictionary word using manual huffman encoding

In [None]:
# Replace compressed_dict words occurring in the generated_captions_custom_model with their corresponding huffman encoding

In [None]:
# compare encoded generated_captions_custom_model + huffman encoding dictionary information with the original generated_captions to calculate compression ratio

In [36]:
# print generated_captions and generated_captions_custom_model elementwise to compare the results
for i in range(len(generated_captions)):
    print (generated_captions[i], "WAIT", generated_captions_custom_model[i])

a green fence next to a street next to a fence  WAIT a green fence next to a street next to a fence 
man crossing the street with a street light  WAIT man crossing the street with a street light 
