In [1]:
import torch
print(torch.__version__)


2.4.0


In [3]:
import sys
sys.path.append(r"C:\Users\vdsha\CLIP")
from clip import clip


In [5]:
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from clip import clip

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load CLIP model and preprocessing pipeline
clip_model, preprocess = clip.load("ViT-B/32", device=device)

# Load GPT-2 tokenizer and model
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
gpt2_model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)


  from .autonotebook import tqdm as notebook_tqdm


In [6]:
#@title Model
import torch.nn as nn
from torch.nn import functional as F
from transformers import GPT2LMHeadModel
from typing import Tuple, Optional

import torch
import torch.nn as nn
from torch.nn import functional as F
from transformers import GPT2LMHeadModel
from typing import Tuple, Optional

#@title Model
T = torch.Tensor
D = torch.device

class MLP(nn.Module):

    def forward(self, x: T) -> T:
        return self.model(x)

    def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
        super(MLP, self).__init__()
        layers = []
        for i in range(len(sizes) -1):
            layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
            if i < len(sizes) - 2:
                layers.append(act())
        self.model = nn.Sequential(*layers)


class ClipCaptionModel(nn.Module):

    #@functools.lru_cache #FIXME
    def get_dummy_token(self, batch_size: int, device: D) -> T:
        return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)

    def forward(self, tokens: T, prefix: T, mask: Optional[T] = None, labels: Optional[T] = None):
        embedding_text = self.gpt.transformer.wte(tokens)
        prefix_projections = self.clip_project(prefix).view(-1, self.prefix_length, self.gpt_embedding_size)
        
        #print(embedding_text.size()) #torch.Size([5, 67, 768])
        #print(prefix_projections.size()) #torch.Size([5, 1, 768])
        embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
        if labels is not None:
            dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
            labels = torch.cat((dummy_token, tokens), dim=1)
        out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
        return out

    def __init__(self, prefix_length: int, prefix_size: int = 512):
        super(ClipCaptionModel, self).__init__()
        self.prefix_length = prefix_length
        self.gpt = GPT2LMHeadModel.from_pretrained('gpt2')
        self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
        if prefix_length > 10:  # not enough memory
            self.clip_project = nn.Linear(prefix_size, self.gpt_embedding_size * prefix_length)
        else:
            self.clip_project = MLP((prefix_size, (self.gpt_embedding_size * prefix_length) // 2, self.gpt_embedding_size * prefix_length))


class ClipCaptionPrefix(ClipCaptionModel):

    def parameters(self, recurse: bool = True):
        return self.clip_project.parameters()

    def train(self, mode: bool = True):
        super(ClipCaptionPrefix, self).train(mode)
        self.gpt.eval()
        return self

In [9]:
from tqdm import tqdm
import json
import os
from PIL import Image
from torch.utils.data import DataLoader, Dataset

In [11]:
annotations_path = {'train':r"D:\vidisha\vizwiz\annotations\train.json"}
images_path = {'train':r"D:\vidisha\vizwiz\train"}

In [13]:
import json

# Path to the annotations JSON file
annotation_file = annotations_path['train']

# Load the JSON data
with open(annotation_file, 'r') as f:
    data = json.load(f)

# Check the type of the loaded data
print(f"Data type: {type(data)}")

# If it's a list, print the first element to inspect its structure
if isinstance(data, list):
    print(f"First element in the list: {data[0]}")
elif isinstance(data, dict):
    # If it's a dictionary, print the first key-value pair
    first_key = next(iter(data))
    print(f"First key: {first_key}")
    print(f"First value: {data[first_key]}")

else:
    print(f"Unexpected data format: {type(data)}")

Data type: <class 'dict'>
First key: info
First value: {'description': 'This dataset contains crowdsourced captions of images from VizWiz datasets. This file contains the train partition.', 'license': {'url': 'https://creativecommons.org/licenses/by/4.0/', 'name': 'Attribution 4.0 International (CC BY 4.0)'}, 'url': 'https://vizwiz.org', 'version': 'VizWiz-Captions 1.0', 'year': 2019, 'contributor': 'VizWiz-Captions Consortium', 'date_created': '2019-12-23'}


In [15]:
# Inspect the keys in the data dictionary to find where the captions are stored
print(data.keys())

dict_keys(['info', 'images', 'annotations'])


In [17]:
# Inspect the structure of the annotations
print(data['annotations'][:5])  # Print the first 5 entries to get an idea of the structure

[{'caption': 'ITS IS A BASIL LEAVES CONTAINER ITS CONTAINS THE NET WEIGHT TOO.', 'image_id': 0, 'is_precanned': False, 'is_rejected': False, 'id': 0, 'text_detected': True}, {'caption': 'A green and white plastic condiment bottle containing Basil leaves.', 'image_id': 0, 'is_precanned': False, 'is_rejected': False, 'id': 1, 'text_detected': True}, {'caption': 'Quality issues are too severe to recognize visual content.', 'image_id': 0, 'is_precanned': True, 'is_rejected': True, 'id': 2, 'text_detected': True}, {'caption': 'A bottle of spices in a plastic container laying on a surface.', 'image_id': 0, 'is_precanned': False, 'is_rejected': False, 'id': 3, 'text_detected': True}, {'caption': 'some basil leaves in a container on a counter', 'image_id': 0, 'is_precanned': False, 'is_rejected': False, 'id': 4, 'text_detected': True}]


In [19]:
import pandas as pd

# Extract the annotations from the data
annotations = data['annotations']

# Create the DataFrame from the annotations list
df = pd.DataFrame(annotations)

# Optionally, reset the index and drop the 'index' column if you need to
df = df.reset_index(drop=True)

# Display the DataFrame
df.head(10)

Unnamed: 0,caption,image_id,is_precanned,is_rejected,id,text_detected
0,ITS IS A BASIL LEAVES CONTAINER ITS CONTAINS T...,0,False,False,0,True
1,A green and white plastic condiment bottle con...,0,False,False,1,True
2,Quality issues are too severe to recognize vis...,0,True,True,2,True
3,A bottle of spices in a plastic container layi...,0,False,False,3,True
4,some basil leaves in a container on a counter,0,False,False,4,True
5,A can of Coca Cola on a counter is shown for w...,1,False,False,5,True
6,A black can of Coca Cola Zero calorie soda is ...,1,False,False,6,True
7,A kitchen counter the various items on top inc...,1,False,False,7,True
8,a black tin of Coca Cola placed on a black sur...,1,False,False,8,True
9,"Black counter with canisters, kettle and can o...",1,False,False,9,True


In [21]:
import os
import collections

# Initialize the defaultdict
image_path_to_caption = collections.defaultdict(list)

# Limit to the first 20,000 rows
limit = 40000

# Iterate over the rows of the DataFrame
for n in range(min(df.shape[0], limit)):
    # image_path = os.path.abspath('.') + images_path['train'] + str(df.iloc[n]['image_id']) + '.jpg'  # Assuming .jpg format
    image_path = os.path.join(
    images_path['train'], 
    f"VizWiz_train_{df.iloc[n]['image_id']:08d}.jpg"  # Ensures zero-padding to 8 digits
    )
    caption = f'<start> {df.iloc[n]["caption"]} <end>'
    if len(caption) > 300:
        continue
    image_path_to_caption[image_path].append(caption)

# Convert to lists as before
captions = []
img_name_vector = []
for image_path, caption_list in image_path_to_caption.items():
    captions.extend(caption_list)
    img_name_vector.extend([image_path] * len(caption_list))

# print("First 5 captions:", captions[:5])
# print("First 5 image paths:", img_name_vector[:5])

num_captions = len(captions)
num_images = len(img_name_vector)

print(f"Number of captions: {num_captions}")
print(f"Number of image paths: {num_images}")

# Print the first 5 entries of the dictionary for verification
for i, (image_path, captions) in enumerate(image_path_to_caption.items()):
    if i < 5:  # Limiting to first 5 entries
        print(image_path, captions)
    else:
        break


Number of captions: 39994
Number of image paths: 39994
D:\vidisha\vizwiz\train\VizWiz_train_00000000.jpg ['<start> ITS IS A BASIL LEAVES CONTAINER ITS CONTAINS THE NET WEIGHT TOO. <end>', '<start> A green and white plastic condiment bottle containing Basil leaves. <end>', '<start> Quality issues are too severe to recognize visual content. <end>', '<start> A bottle of spices in a plastic container laying on a surface. <end>', '<start> some basil leaves in a container on a counter <end>']
D:\vidisha\vizwiz\train\VizWiz_train_00000001.jpg ['<start> A can of Coca Cola on a counter is shown for when one can use a nice, cold drink. <end>', '<start> A black can of Coca Cola Zero calorie soda is on the counter near the coffee maker. <end>', '<start> A kitchen counter the various items on top including a can of Coca-Cola, metal containers, and a teapot. <end>', '<start> a black tin of Coca Cola placed on a black surface <end>', '<start> Black counter with canisters, kettle and can of soda. <end

In [22]:
# Now we separate the captions and image paths into two lists
captions = []
img_name_vector = []

for image_path, caption_list in image_path_to_caption.items():
    captions.extend(caption_list)
    img_name_vector.extend([image_path] * len(caption_list))  # Add image path multiple times for each caption

# Check the distribution of caption lengths
sizes_to_indices = collections.defaultdict(list)
for index, caption in enumerate(captions):
    size = len(caption)
    sizes_to_indices[size].append(index)

# Inspect caption lengths and their distribution
print("Caption sizes and their frequencies:")
for size, indices in sizes_to_indices.items():
    print(f"Size {size}: {len(indices)} captions")

Caption sizes and their frequencies:
Size 78: 621 captions
Size 81: 564 captions
Size 72: 5801 captions
Size 76: 794 captions
Size 59: 913 captions
Size 95: 244 captions
Size 97: 215 captions
Size 118: 87 captions
Size 64: 1093 captions
Size 67: 1010 captions
Size 112: 95 captions
Size 75: 869 captions
Size 68: 925 captions
Size 61: 998 captions
Size 57: 862 captions
Size 71: 922 captions
Size 63: 1080 captions
Size 84: 403 captions
Size 66: 988 captions
Size 146: 28 captions
Size 62: 1010 captions
Size 77: 660 captions
Size 70: 934 captions
Size 91: 298 captions
Size 69: 955 captions
Size 83: 489 captions
Size 58: 818 captions
Size 65: 1088 captions
Size 53: 460 captions
Size 50: 258 captions
Size 94: 233 captions
Size 51: 294 captions
Size 54: 572 captions
Size 49: 174 captions
Size 73: 838 captions
Size 89: 337 captions
Size 82: 472 captions
Size 92: 285 captions
Size 60: 992 captions
Size 88: 352 captions
Size 79: 653 captions
Size 90: 294 captions
Size 104: 155 captions
Size 96: 2

In [25]:
captions_and_images_validation = captions, img_name_vector

sizes_to_indices_again = collections.defaultdict(list)
for index, caption in enumerate(captions_and_images_validation[0]):
  size = len(caption)
  sizes_to_indices_again[size].append(index)
sorted(list(sizes_to_indices_again.keys()),reverse=True)[:]

[293,
 288,
 286,
 282,
 281,
 279,
 275,
 274,
 272,
 270,
 269,
 268,
 266,
 265,
 264,
 262,
 257,
 254,
 253,
 250,
 249,
 248,
 247,
 245,
 243,
 242,
 241,
 240,
 238,
 236,
 235,
 234,
 233,
 232,
 231,
 229,
 228,
 227,
 226,
 225,
 224,
 223,
 221,
 220,
 219,
 218,
 217,
 216,
 214,
 213,
 212,
 211,
 210,
 209,
 208,
 206,
 205,
 204,
 203,
 201,
 200,
 199,
 198,
 197,
 196,
 194,
 193,
 192,
 191,
 190,
 189,
 188,
 187,
 186,
 185,
 184,
 183,
 182,
 181,
 180,
 179,
 178,
 177,
 176,
 175,
 174,
 173,
 172,
 171,
 170,
 169,
 168,
 167,
 166,
 165,
 164,
 163,
 162,
 161,
 160,
 159,
 158,
 157,
 156,
 155,
 154,
 153,
 152,
 151,
 150,
 149,
 148,
 147,
 146,
 145,
 144,
 143,
 142,
 141,
 140,
 139,
 138,
 137,
 136,
 135,
 134,
 133,
 132,
 131,
 130,
 129,
 128,
 127,
 126,
 125,
 124,
 123,
 122,
 121,
 120,
 119,
 118,
 117,
 116,
 115,
 114,
 113,
 112,
 111,
 110,
 109,
 108,
 107,
 106,
 105,
 104,
 103,
 102,
 101,
 100,
 99,
 98,
 97,
 96,
 95,
 94,
 93,
 92,


In [27]:
total =  0
for key in sizes_to_indices.keys():
    total += len(sizes_to_indices[key])
print(total)

39994


In [29]:
from PIL import Image
from torch.utils.data import Dataset
import torch

class VizWizDataset(Dataset):
    def __init__(self, img_paths, captions, clip_model, preprocess, tokenizer, max_length=50):
        self.img_paths = img_paths  # List of image paths
        self.captions = captions  # List of captions corresponding to the images
        self.clip_model = clip_model
        self.preprocess = preprocess
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        # Get image path and caption at the specified index
        image_path = self.img_paths[idx]
        caption = self.captions[idx]

        # Load and preprocess image
        image = Image.open(image_path).convert("RGB")
        image = self.preprocess(image)

        # Tokenize caption
        tokens = self.tokenizer(
            caption,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        return (image, tokens['input_ids'], tokens['attention_mask'])

        # def __getitem__(self, idx):
        #     image_path = self.img_paths[idx]   # Access image path from img_paths
        #     caption = self.captions[idx]       # Access caption from captions
    
        #     # Load and preprocess image
        #     image = Image.open(image_path).convert("RGB")
        #     image = self.preprocess(image)
    
        #     # Tokenize caption
        #     tokens = self.tokenizer(
        #         caption,
        #         max_length=self.max_length,
        #         padding="max_length",
        #         truncation=True,
        #         return_tensors="pt"
        #     )
    
        #     # Return image tensor, tokenized input ids, and attention mask
        #     return image, tokens['input_ids'].squeeze(0), tokens['attention_mask'].squeeze(0)



In [31]:
# Load CLIP model and tokenizer
device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, preprocess = clip.load("ViT-B/32", device=device)

# Load GPT-2 tokenizer and add padding token
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token  # Set the padding token to be the same as EOS

# # Add special tokens BEFORE training
# tokenizer.add_special_tokens({'bos_token': '<start>', 'eos_token': '<end>', 'pad_token': '<pad>'})

# # Explicitly set a unique ID for pad_token if it's overlapping with eos_token
# tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids("<pad>")

# gpt2_tokenizer.add_special_tokens({'bos_token': '<start>', 'eos_token': '<end>', 'pad_token': '<pad>'})
# gpt2_tokenizer.pad_token_id = gpt2_tokenizer.convert_tokens_to_ids("<pad>")


# # Print token IDs to verify
# print("BOS Token ID:", gpt2_tokenizer.convert_tokens_to_ids("<start>"))
# print("EOS Token ID:", gpt2_tokenizer.convert_tokens_to_ids("<end>"))
# print("PAD Token ID:", gpt2_tokenizer.pad_token_id)


In [33]:
# Create Dataset
dataset = VizWizDataset(
    img_paths=img_name_vector,
    captions=captions,
    clip_model=clip_model,
    preprocess=preprocess,
    tokenizer=gpt2_tokenizer
)


# Test a sample
sample = dataset[0]
print("Image shape:", sample[0].shape)  # Image tensor
print("Input IDs shape:", sample[1].shape)  # Input IDs
print("Attention Mask shape:", sample[2].shape)  # Attention mask


Image shape: torch.Size([3, 224, 224])
Input IDs shape: torch.Size([1, 50])
Attention Mask shape: torch.Size([1, 50])


In [35]:
# from torch.utils.data import DataLoader

# # Create DataLoader for your dataset
# train_dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

# # Test DataLoader with one batch
# batch = next(iter(train_dataloader))

# # Accessing elements of the batch
# images, input_ids, attention_mask = batch

# # Print shapes of the batch components
# print("Images shape:", images.shape)
# print("Input IDs shape:", input_ids.shape)
# print("Attention Mask shape:", attention_mask.shape)

In [37]:
# # Test DataLoader with one batch
# batch = next(iter(train_dataloader))

# # Print the structure of the batch
# print(f"Batch type: {type(batch)}")  # Should be a tuple or dictionary
# if isinstance(batch, dict):
#     print(f"Batch keys: {batch.keys()}")  # Print keys if the batch is a dictionary
# elif isinstance(batch, tuple):
#     print(f"Batch length: {len(batch)}")  # Print length if the batch is a tuple
#     # Optionally print first few elements in the tuple to inspect
#     for i, item in enumerate(batch):
#         print(f"Batch item {i}: {item.shape if isinstance(item, torch.Tensor) else 'Not a tensor'}")

In [35]:
from torch.utils.data import DataLoader

# Define custom collate_fn to return tuple instead of list
def collate_fn(batch):
    images, input_ids, attention_masks = zip(*batch)  # Unzip the batch
    images = torch.stack(images)  # Stack images into a batch tensor
    input_ids = torch.stack(input_ids)  # Stack input_ids
    attention_masks = torch.stack(attention_masks)  # Stack attention_masks
    return images, input_ids, attention_masks

# Now create the DataLoader
train_dataloader = DataLoader(
    dataset, 
    batch_size=16, 
    shuffle=True, 
    collate_fn=collate_fn , # Use the custom collate_fn
    drop_last=True
)

# Test the batch structure
batch = next(iter(train_dataloader))
print(f"Batch type: {type(batch)}")
if isinstance(batch, tuple):
    print(f"Batch length: {len(batch)}")
    for i, item in enumerate(batch):
        print(f"Batch item {i}: {item.shape if isinstance(item, torch.Tensor) else 'Not a tensor'}")


Batch type: <class 'tuple'>
Batch length: 3
Batch item 0: torch.Size([16, 3, 224, 224])
Batch item 1: torch.Size([16, 1, 50])
Batch item 2: torch.Size([16, 1, 50])


**Setting up validation dataset**

In [38]:
annotations_path = {'train':r"D:\vidisha\vizwiz\annotations\train.json", 'val' :r"D:\vidisha\vizwiz\annotations\val.json"}
images_path = {'train':r"D:\vidisha\vizwiz\train", 'val' :r"D:\vidisha\vizwiz\val"}

In [40]:
import json

# Path to the annotations JSON file
annotation_file = annotations_path['val']

# Load the JSON data
with open(annotation_file, 'r') as f:
    data = json.load(f)

# Check the type of the loaded data
print(f"Data type: {type(data)}")

# If it's a list, print the first element to inspect its structure
if isinstance(data, list):
    print(f"First element in the list: {data[0]}")
elif isinstance(data, dict):
    # If it's a dictionary, print the first key-value pair
    first_key = next(iter(data))
    print(f"First key: {first_key}")
    print(f"First value: {data[first_key]}")

else:
    print(f"Unexpected data format: {type(data)}")

Data type: <class 'dict'>
First key: info
First value: {'description': 'This dataset contains crowdsourced captions of images from VizWiz datasets. This file contains the val partition.', 'license': {'url': 'https://creativecommons.org/licenses/by/4.0/', 'name': 'Attribution 4.0 International (CC BY 4.0)'}, 'url': 'https://vizwiz.org', 'version': 'VizWiz-Captions 1.0', 'year': 2019, 'contributor': 'VizWiz-Captions Consortium', 'date_created': '2019-12-23'}


In [42]:
# Inspect the keys in the data dictionary to find where the captions are stored
print(data.keys())

dict_keys(['info', 'images', 'annotations'])


In [44]:
# Inspect the structure of the annotations
print(data['annotations'][:5])  # Print the first 5 entries to get an idea of the structure

[{'caption': 'A computer screen shows a repair prompt on the screen.', 'image_id': 23431, 'is_precanned': False, 'is_rejected': False, 'id': 117155, 'text_detected': True}, {'caption': 'a computer screen with a repair automatically pop up', 'image_id': 23431, 'is_precanned': False, 'is_rejected': False, 'id': 117156, 'text_detected': True}, {'caption': 'partial computer screen showing the need of repairs', 'image_id': 23431, 'is_precanned': False, 'is_rejected': False, 'id': 117157, 'text_detected': True}, {'caption': 'Part of a computer monitor showing a computer repair message.', 'image_id': 23431, 'is_precanned': False, 'is_rejected': False, 'id': 117158, 'text_detected': True}, {'caption': 'The top of a laptop with a blue background and dark blue text.', 'image_id': 23431, 'is_precanned': False, 'is_rejected': False, 'id': 117159, 'text_detected': True}]


In [46]:
print(data['images'][:5]) 

[{'file_name': 'VizWiz_val_00000000.jpg', 'vizwiz_url': 'https://ivc.ischool.utexas.edu/VizWiz_visualization_img/VizWiz_val_00000000.jpg', 'id': 23431, 'text_detected': True}, {'file_name': 'VizWiz_val_00000001.jpg', 'vizwiz_url': 'https://ivc.ischool.utexas.edu/VizWiz_visualization_img/VizWiz_val_00000001.jpg', 'id': 23432, 'text_detected': True}, {'file_name': 'VizWiz_val_00000002.jpg', 'vizwiz_url': 'https://ivc.ischool.utexas.edu/VizWiz_visualization_img/VizWiz_val_00000002.jpg', 'id': 23433, 'text_detected': True}, {'file_name': 'VizWiz_val_00000003.jpg', 'vizwiz_url': 'https://ivc.ischool.utexas.edu/VizWiz_visualization_img/VizWiz_val_00000003.jpg', 'id': 23434, 'text_detected': True}, {'file_name': 'VizWiz_val_00000004.jpg', 'vizwiz_url': 'https://ivc.ischool.utexas.edu/VizWiz_visualization_img/VizWiz_val_00000004.jpg', 'id': 23435, 'text_detected': True}]


In [48]:
import pandas as pd
import os
import collections

# Create a mapping of image_id to file_name from the 'images' data
image_id_to_file_name = {img['id']: img['file_name'] for img in data['images']}

# Extract the annotations from the data
annotations = data['annotations']

# Create the DataFrame from the annotations list
df_val = pd.DataFrame(annotations)

# Reset the index
df_val = df_val.reset_index(drop=True)

# Initialize the defaultdict
image_path_to_caption = collections.defaultdict(list)

# Limit to the first 40,000 rows (or adjust as necessary)
limit = None

# Iterate over the rows of the DataFrame
for n in range(df_val.shape[0]):
    # Get the file name from the mapping
    file_name = image_id_to_file_name[df_val.iloc[n]['image_id']]
    
    # Construct the full image path
    image_path = os.path.join(images_path['val'], file_name)
    
    # Add the caption with <start> and <end> tokens
    caption = f"<start> {df_val.iloc[n]['caption']} <end>"
    image_path_to_caption[image_path].append(caption)

# Convert the dictionary to lists for further processing
captions_val = []
img_name_vector_val = []
for image_path, caption_list in image_path_to_caption.items():
    captions_val.extend(caption_list)
    img_name_vector_val.extend([image_path] * len(caption_list))

# Print the first 5 entries of the dictionary for verification
for i, (image_path, captions) in enumerate(image_path_to_caption.items()):
    if i < 5:  # Limiting to first 5 entries
        print(image_path, captions)
    else:
        break


D:\vidisha\vizwiz\val\VizWiz_val_00000000.jpg ['<start> A computer screen shows a repair prompt on the screen. <end>', '<start> a computer screen with a repair automatically pop up <end>', '<start> partial computer screen showing the need of repairs <end>', '<start> Part of a computer monitor showing a computer repair message. <end>', '<start> The top of a laptop with a blue background and dark blue text. <end>']
D:\vidisha\vizwiz\val\VizWiz_val_00000001.jpg ['<start> A person is holding a bottle that has medicine for the night time. <end>', '<start> A bottle of medication has a white twist top. <end>', '<start> night time medication bottle being held by someone <end>', '<start> a person holding a small black bottle of NIGHT TIME <end>', '<start> A bottle of what appears to be cough syrup held in hand. <end>']
D:\vidisha\vizwiz\val\VizWiz_val_00000002.jpg ['<start> a white paper showing an image of black and brown dog <end>', '<start> A library book with pictures of two dogs on the cov

In [50]:
# Check the number of captions and image paths
num_captions = len(captions_val)
num_images = len(img_name_vector_val)

# Print the results
print(f"Number of captions: {num_captions}")
print(f"Number of image paths: {num_images}")

Number of captions: 38750
Number of image paths: 38750


In [52]:
# Create the validation Dataset
val_dataset = VizWizDataset(
    img_paths=img_name_vector_val,  # List of image paths for validation
    captions=captions_val,  # List of captions for validation
    clip_model=clip_model,  # Your CLIP model (or any other feature extractor you're using)
    preprocess=preprocess,  # Your image preprocessing function
    tokenizer=gpt2_tokenizer  # Your tokenizer (e.g., GPT-2 tokenizer)
)

# Check dataset length (just to verify)
print(f"Validation Dataset Size: {len(val_dataset)}")


Validation Dataset Size: 38750


In [54]:
# Create the validation DataLoader
val_dataloader = DataLoader(
    val_dataset,
    batch_size=16,  # Adjust the batch size if needed
    shuffle=False,  # No need to shuffle the validation data
    collate_fn=collate_fn  # Leave as None unless you need to customize batching
)

# Test the DataLoader with one batch
batch = next(iter(val_dataloader))

# Print the structure of the batch
print(f"Batch type: {type(batch)}")
if isinstance(batch, tuple):
    print(f"Batch length: {len(batch)}")
    for i, item in enumerate(batch):
        print(f"Batch item {i}: {item.shape if isinstance(item, torch.Tensor) else 'Not a tensor'}")


Batch type: <class 'tuple'>
Batch length: 3
Batch item 0: torch.Size([16, 3, 224, 224])
Batch item 1: torch.Size([16, 1, 50])
Batch item 2: torch.Size([16, 1, 50])


In [66]:
# Initialize the model and optimizer
prefix_length = 10  # Define your prefix length based on your needs
model = ClipCaptionModel(prefix_length=prefix_length).to(device)

# Define optimizer (e.g., Adam)
# optimizer = torch.optim.Adam(model.parameters(), lr=5e-6)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=1e-2)


# Define your loss function (e.g., CrossEntropyLoss for text generation)
criterion = torch.nn.CrossEntropyLoss()


In [64]:
# for epoch in range(1):
#     model.train()
#     total_loss = 0
    
#     for step, (images, input_ids, attention_masks) in enumerate(tqdm(train_dataloader, desc=f"Epoch {epoch + 1}")):
#         images = images.to(device)
#         input_ids = input_ids.squeeze(1).to(device)
#         attention_mask = attention_masks.squeeze(1).to(device)
        

#         # Extract features using CLIP
#         with torch.no_grad():
#             clip_features = clip_model.encode_image(images).float()  # Shape: [batch_size, clip_feature_size]
#             print("clip features:",clip_features.size())
        
#         batch_size=16
        
        
#         # Concatenate the flattened clip features with the input_ids
#         input_ids = torch.cat([clip_features, input_ids], dim=1)
#         # Ensure input_ids are of type Long (torch.int64)
#         input_ids = input_ids.long()
#         print("input ids size:",input_ids.size())


#         # Check the shape of attention_mask before squeezing
#         print("Original attention mask shape:", attention_mask.shape)


        
#         # Squeeze the middle singleton dimension (axis 1)
#         attention_mask_2d = attention_mask.squeeze(1)  # Convert to [batch_size, sequence_length]

#         # Check the new shape of the attention mask
#         print("Attention mask 2d shape:", attention_mask_2d.shape)
        
#         # Now concatenate the tensors
#         adjusted_attention_mask = torch.cat([torch.ones(batch_size, prefix_length, device=attention_mask_2d.device), attention_mask_2d], dim=1)

#         print("Adjusted attention mask size:", adjusted_attention_mask.size())
        
#         # Ensure the adjusted attention mask has the same length as input_ids
#         # assert adjusted_attention_mask.size(1) == input_ids.size(1)


        


#         print("---------------------------------")
#         # # Debugging outputs
#         # print("Prefix embeddings size:", clip_features.size())  # Should be [16, prefix_length, embedding_dim]
#         # print("Input IDs size:", input_ids.size())  # Should be [16, 50]
#         # print("Attention masks size:", attention_masks.size())  # Should be [16, 50]
#         # print("Adjusted attention mask size:", adjusted_attention_mask.size())  # Should match total sequence length

#         # Obtain GPT-2 embeddings for input tokens
#         # gpt2_embeddings = model.gpt2.transformer.wte(input_ids)  # Shape: [batch_size, sequence_length, hidden_size]
#         # Access GPT-2 embeddings during training
#         gpt2_embeddings = model.gpt.transformer.wte(input_ids)  # Shape: [batch_size, sequence_length, hidden_size]

#         print(f"GPT-2 embeddings size: {gpt2_embeddings.size()}")
        
#         # # Concatenate prefix with GPT-2 embeddings
#         # inputs_embeds = torch.cat([prefix, gpt2_embeddings], dim=1)  # Concatenate along sequence length
#         # print(f"Concatenated inputs_embeds size: {inputs_embeds.size()}")


#         # Define hidden_size based on the GPT-2 model configuration
#         hidden_size = 768
#         print(f"Hidden size: {hidden_size}")

#         print(f"clip_features shape: {clip_features.size()}")


#         # Get the output from clip_project
#         clip_project_output = model.clip_project(clip_features)
#         print(f"clip_project output size: {clip_project_output.size()}")
        
#         # Reshape clip_project output
#         prefix = clip_project_output.view(batch_size, prefix_length, hidden_size)
#         print(f"Reshaped Prefix size: {prefix.size()}")
        
#         # Concatenate prefix with GPT-2 embeddings
#         inputs_embeds = torch.cat([prefix, gpt2_embeddings], dim=1)
#         print(f"Concatenated inputs_embeds size: {inputs_embeds.size()}")

        
#         # Extend attention mask
#         prefix_mask = torch.ones(batch_size, prefix_length).to(attention_mask.device)
#         print(f"Prefix mask shape: {prefix_mask.size()}")
#         print(f"Attention mask shape: {attention_mask.size()}")

#         extended_attention_mask = torch.cat([prefix_mask, attention_mask], dim=1)
#         print(f"Extended Attention Mask size: {extended_attention_mask.size()}")





#         # Forward pass through the model
#         outputs = model(
#             tokens=input_ids,
#             prefix=clip_features,
#             mask= adjusted_attention_mask,  # Use adjusted mask
#             labels=input_ids,
#         )

#         # Compute loss
#         loss = outputs.loss / gradient_accumulation_steps

#         # Backward pass
#         loss.backward()

#         # Gradient accumulation
#         if (step + 1) % gradient_accumulation_steps == 0 or (step + 1) == len(train_dataloader):
#             torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
#             optimizer.step()
#             optimizer.zero_grad()

#         total_loss += loss.item() * gradient_accumulation_steps

#         # Logging
#         if (step + 1) % log_interval == 0:
#             avg_loss = total_loss / (step + 1)
#             print(f"Step {step + 1}, Loss: {avg_loss:.4f}")

#     # Validation
#     model.eval()
#     val_loss = 0
#     with torch.no_grad():
#         for images, input_ids, attention_masks in tqdm(val_dataloader, desc="Validation"):
#             images = images.to(device)
#             input_ids = input_ids.squeeze(1).to(device)
#             attention_masks = attention_masks.squeeze(1).to(device)

#             clip_features = clip_model.encode_image(images).float()

#             # Adjust attention mask for validation
#             prefix_length = clip_features.shape[1]
#             batch_size, token_length = input_ids.shape
#             prefix_mask = torch.ones(batch_size, prefix_length, device=input_ids.device)
#             adjusted_attention_mask = torch.cat([prefix_mask, attention_masks], dim=1)

#             outputs = model(
#                 tokens=input_ids,
#                 prefix=clip_features,
#                 mask=adjusted_attention_mask,  # Use adjusted mask
#                 labels=input_ids,
#             )
#             val_loss += outputs.loss.item()

#     avg_train_loss = total_loss / len(train_dataloader)
#     avg_val_loss = val_loss / len(val_dataloader)

#     print(f"Epoch {epoch + 1} completed. Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

# print("Training complete.")

In [66]:
# from torch.utils.data import Subset, DataLoader

# batch_size=16
# hidden_size=768
# gradient_accumulation_steps = 1  # Number of steps to accumulate gradients
# max_grad_norm = 1.0  # For gradient clipping
# log_interval = 10  # Steps after which to log training progress
# validation_interval = 50


# # Create a subset of the validation dataset with 500 samples
# subset_indices = list(range(500))  # Take the first 500 samples
# val_subset = Subset(val_dataset, subset_indices)

# # Create a dataloader for the subset
# val_subset_dataloader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)

# for epoch in range(1):
#     model.train()
#     total_loss = 0
#     for step, (images, input_ids, attention_masks) in enumerate(tqdm(train_dataloader, desc=f"Epoch {epoch + 1}")):
#         images = images.to(device)
#         input_ids = input_ids.squeeze(1).to(device)
#         attention_mask = attention_masks.squeeze(1).to(device)
    
#         # Extract CLIP features
#         with torch.no_grad():
#             clip_features = clip_model.encode_image(images).float()  # Shape: [batch_size, clip_feature_size]
    
#         # Project CLIP features into prefix embeddings
#         clip_project_output = model.clip_project(clip_features)  # Shape: [batch_size, gpt_embedding_size * prefix_length]
#         prefix = clip_project_output.view(batch_size, prefix_length, hidden_size)  # Shape: [batch_size, prefix_length, hidden_size]
    
#         # Obtain GPT-2 embeddings for tokens
#         gpt2_embeddings = model.gpt.transformer.wte(input_ids)  # Shape: [batch_size, sequence_length, hidden_size]
    
#         # Concatenate prefix with GPT-2 embeddings
#         inputs_embeds = torch.cat([prefix, gpt2_embeddings], dim=1)  # Shape: [batch_size, prefix_length + sequence_length, hidden_size]
    
#         # Adjust attention mask
#         prefix_mask = torch.ones(batch_size, prefix_length, device=attention_mask.device)
#         adjusted_attention_mask = torch.cat([prefix_mask, attention_mask], dim=1)
    
#         # Forward pass
#         outputs = model(
#             tokens=input_ids,  # Pass None to use inputs_embeds
#             prefix=clip_features,  # Not needed here since we pass inputs_embeds
#             mask=adjusted_attention_mask,
#             labels=input_ids,
#         )
    
#         # Compute loss
#         loss = outputs.loss / gradient_accumulation_steps
       
    
#         # Backward pass
#         loss.backward()
    
#         # Gradient accumulation
#         if (step + 1) % gradient_accumulation_steps == 0 or (step + 1) == len(train_dataloader):
#             torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
#             optimizer.step()
#             optimizer.zero_grad()
    
#         total_loss += loss.item() * gradient_accumulation_steps
    
#         # Logging
#         if (step + 1) % log_interval == 0:
#             avg_loss = total_loss / (step + 1)
#             print(f"Step {step + 1}, Loss: {avg_loss:.4f}")

#         # Perform validation after every 'validation_interval' steps
#         if (step + 1) % validation_interval == 0:
#             model.eval()
#             val_loss = 0
#             with torch.no_grad():
#                 for val_images, val_input_ids, val_attention_masks in tqdm(val_subset_dataloader, desc="Validation", leave = False, disable =True):
#                     val_images = val_images.to(device)
#                     val_input_ids = val_input_ids.squeeze(1).to(device)
#                     val_attention_masks = val_attention_masks.squeeze(1).to(device)

#                     val_clip_features = clip_model.encode_image(val_images).float()

#                     # Adjust validation attention mask
#                     val_prefix_mask = torch.ones(val_images.size(0), prefix_length, device=val_attention_masks.device)
#                     val_adjusted_attention_mask = torch.cat([val_prefix_mask, val_attention_masks], dim=1)

#                     val_outputs = model(
#                         tokens=val_input_ids,
#                         prefix=val_clip_features,
#                         mask=val_adjusted_attention_mask,
#                         labels=val_input_ids,
#                     )
#                     val_loss += val_outputs.loss.item()

#             avg_val_loss = val_loss / len(val_dataloader)
#             print(f"Validation after Step {step + 1}: Avg Val Loss: {avg_val_loss:.4f}")

#             model.train()  # Switch back to training mode

#     # Epoch summary
#     avg_train_loss = total_loss / len(train_dataloader)
#     print(f"Epoch {epoch + 1} completed. Train Loss: {avg_train_loss:.4f}")
    


In [68]:
# Debug: Check if Adam's internal state (momentum, variance estimates) is restored
if optimizer.state_dict()["state"]:
    for group in optimizer.state_dict()["state"].values():
        print(group.keys())  # Should include "exp_avg" and "exp_avg_sq"

In [None]:
import os
import torch
from torch.utils.data import Subset, DataLoader

batch_size=16
hidden_size=768
gradient_accumulation_steps = 1  # Number of steps to accumulate gradients
max_grad_norm = 1.0  # For gradient clipping
log_interval = 10  # Steps after which to log training progress
validation_interval = 100

# Create a subset of the validation dataset with 500 samples
subset_indices = list(range(500))  # Take the first 500 samples
val_subset = Subset(val_dataset, subset_indices)

# Create a dataloader for the subset
val_subset_dataloader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)

# Checkpoint path
checkpoint_path = r"C:\Users\vdsha\Downloads\model_checkpoint_epoch4.pth"

# Function to save checkpoint
def save_checkpoint(model, optimizer, epoch, step, total_loss, checkpoint_path):
    checkpoint = {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "epoch": epoch,
        "step": step,
        "total_loss": total_loss,
    }
    torch.save(checkpoint, checkpoint_path)
    print(f"Checkpoint saved at epoch {epoch}, step {step}.")

# Function to load checkpoint
def load_checkpoint(checkpoint_path, model, optimizer):
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        epoch = checkpoint["epoch"]
        step = checkpoint["step"]
        total_loss = checkpoint["total_loss"]
        print(f"Checkpoint loaded. Resuming from epoch {epoch}, step {step}.")
        return epoch, step, total_loss
    else:
        print("No checkpoint found. Starting from scratch.")
        return 0, 0, 0.0

# Load checkpoint if exists
start_epoch, start_step, total_loss = load_checkpoint(checkpoint_path, model, optimizer)
# # Debug: Check if Adam's internal state (momentum, variance estimates) is restored
# if optimizer.state_dict()["state"]:
#     for group in optimizer.state_dict()["state"].values():
#         print(group.keys())  # Should include "exp_avg" and "exp_avg_sq"


try:
    for epoch in range(start_epoch, 8):  # Adjust range as needed
        model.train()
        total_loss = 0.0  # Reset total loss for the new epoch
        if epoch == start_epoch:
            start = start_step
        else:
            start = 0

        for step, (images, input_ids, attention_masks) in enumerate(
            tqdm(train_dataloader, desc=f"Epoch {epoch + 1}")
        ):
            if step < start:
                continue  # Skip steps already completed

            images = images.to(device)
            input_ids = input_ids.squeeze(1).to(device)
            attention_mask = attention_masks.squeeze(1).to(device)

            # Extract CLIP features
            with torch.no_grad():
                clip_features = clip_model.encode_image(images).float()

            # Project CLIP features into prefix embeddings
            clip_project_output = model.clip_project(clip_features)
            prefix = clip_project_output.view(batch_size, prefix_length, hidden_size)

            # Obtain GPT-2 embeddings for tokens
            gpt2_embeddings = model.gpt.transformer.wte(input_ids)

            # Concatenate prefix with GPT-2 embeddings
            inputs_embeds = torch.cat([prefix, gpt2_embeddings], dim=1)

            # Adjust attention mask
            prefix_mask = torch.ones(batch_size, prefix_length, device=attention_mask.device)
            adjusted_attention_mask = torch.cat([prefix_mask, attention_mask], dim=1)




            # Forward pass
            outputs = model(
                tokens=input_ids,
                prefix=clip_features,
                mask=adjusted_attention_mask,
                labels=input_ids,
            )

            # Compute loss
            loss = outputs.loss / gradient_accumulation_steps

            # Backward pass
            loss.backward()

            # # Debugging: Check gradient norms
            # for name, param in model.named_parameters():
            #     if param.grad is not None:
            #         print(f"{name}: grad norm = {torch.norm(param.grad)}")

            # Gradient accumulation
            if (step + 1) % gradient_accumulation_steps == 0 or (step + 1) == len(train_dataloader):
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
                optimizer.step()
                optimizer.zero_grad()

            total_loss += loss.item() * gradient_accumulation_steps

            # Logging
            if (step + 1) % log_interval == 0:
                avg_loss = total_loss / (step + 1)
                print(f"Step {step + 1}, Loss: {avg_loss:.4f}")

            # Perform validation
            if (step + 1) % validation_interval == 0:
                model.eval()
                val_loss = 0
                with torch.no_grad():
                    for val_images, val_input_ids, val_attention_masks in tqdm(val_subset_dataloader, desc="Validation", leave = False, disable =True):
                        val_images = val_images.to(device)
                        val_input_ids = val_input_ids.squeeze(1).to(device)
                        val_attention_masks = val_attention_masks.squeeze(1).to(device)
    
                        val_clip_features = clip_model.encode_image(val_images).float()
    
                        # Adjust validation attention mask
                        val_prefix_mask = torch.ones(val_images.size(0), prefix_length, device=val_attention_masks.device)
                        val_adjusted_attention_mask = torch.cat([val_prefix_mask, val_attention_masks], dim=1)
    
                        val_outputs = model(
                            tokens=val_input_ids,
                            prefix=val_clip_features,
                            mask=val_adjusted_attention_mask,
                            labels=val_input_ids,
                        )
                        val_loss += val_outputs.loss.item()
    
                avg_val_loss = val_loss / len(val_subset_dataloader)
                print(f"Validation after Step {step + 1}: Avg Val Loss: {avg_val_loss:.4f}")
    
                model.train()  # Switch back to training mode


        # Epoch summary
        avg_train_loss = total_loss / len(train_dataloader)
        print(f"Epoch {epoch + 1} completed. Train Loss: {avg_train_loss:.4f}")

except KeyboardInterrupt:
    print("Training interrupted. Saving checkpoint...")
    save_checkpoint(model, optimizer, epoch, step, total_loss, checkpoint_path)

# save_checkpoint(model, optimizer, epoch, step, total_loss, checkpoint_path)



  checkpoint = torch.load(checkpoint_path)


Checkpoint loaded. Resuming from epoch 6, step 1320.


Epoch 7:  53%|████████████████████████████████████▋                                | 1330/2499 [09:13<14:49,  1.31it/s]

Step 1330, Loss: 0.0045


Epoch 7:  54%|████████████████████████████████████▉                                | 1340/2499 [09:21<14:51,  1.30it/s]

Step 1340, Loss: 0.0089


Epoch 7:  54%|█████████████████████████████████████▎                               | 1350/2499 [09:28<13:38,  1.40it/s]

Step 1350, Loss: 0.0131


Epoch 7:  54%|█████████████████████████████████████▌                               | 1360/2499 [09:35<14:35,  1.30it/s]

Step 1360, Loss: 0.0171


Epoch 7:  55%|█████████████████████████████████████▊                               | 1370/2499 [09:43<14:33,  1.29it/s]

Step 1370, Loss: 0.0215


Epoch 7:  55%|██████████████████████████████████████                               | 1380/2499 [09:51<14:35,  1.28it/s]

Step 1380, Loss: 0.0254


Epoch 7:  56%|██████████████████████████████████████▍                              | 1390/2499 [09:59<14:38,  1.26it/s]

Step 1390, Loss: 0.0290


Epoch 7:  56%|██████████████████████████████████████▋                              | 1399/2499 [10:05<13:58,  1.31it/s]

Step 1400, Loss: 0.0327


Epoch 7:  56%|████████████████████████████████████▍                            | 1400/2499 [21:59<65:30:16, 214.57s/it]

Validation after Step 1400: Avg Val Loss: 0.6762


Epoch 7:  56%|█████████████████████████████████████▊                             | 1410/2499 [22:07<2:03:52,  6.82s/it]

Step 1410, Loss: 0.0363


Epoch 7:  57%|███████████████████████████████████████▏                             | 1420/2499 [22:15<16:58,  1.06it/s]

Step 1420, Loss: 0.0402


Epoch 7:  57%|███████████████████████████████████████▍                             | 1430/2499 [22:22<13:30,  1.32it/s]

Step 1430, Loss: 0.0441


Epoch 7:  58%|███████████████████████████████████████▊                             | 1440/2499 [22:30<13:37,  1.30it/s]

Step 1440, Loss: 0.0477


Epoch 7:  58%|████████████████████████████████████████                             | 1450/2499 [22:37<13:04,  1.34it/s]

Step 1450, Loss: 0.0511


Epoch 7:  58%|████████████████████████████████████████▎                            | 1460/2499 [22:45<12:51,  1.35it/s]

Step 1460, Loss: 0.0548


Epoch 7:  59%|████████████████████████████████████████▌                            | 1470/2499 [22:52<12:43,  1.35it/s]

Step 1470, Loss: 0.0582


Epoch 7:  59%|████████████████████████████████████████▊                            | 1480/2499 [23:00<13:13,  1.28it/s]

Step 1480, Loss: 0.0618


Epoch 7:  60%|█████████████████████████████████████████▏                           | 1490/2499 [23:08<12:49,  1.31it/s]

Step 1490, Loss: 0.0651


Epoch 7:  60%|█████████████████████████████████████████▍                           | 1499/2499 [23:15<13:06,  1.27it/s]

Step 1500, Loss: 0.0684


Epoch 7:  60%|████████████████████████████████████████▏                          | 1500/2499 [23:33<1:40:33,  6.04s/it]

Validation after Step 1500: Avg Val Loss: 0.6785


Epoch 7:  60%|█████████████████████████████████████████▋                           | 1510/2499 [23:41<14:50,  1.11it/s]

Step 1510, Loss: 0.0715


Epoch 7:  61%|█████████████████████████████████████████▉                           | 1520/2499 [23:48<12:10,  1.34it/s]

Step 1520, Loss: 0.0744


Epoch 7:  61%|██████████████████████████████████████████▏                          | 1530/2499 [23:56<12:15,  1.32it/s]

Step 1530, Loss: 0.0772


Epoch 7:  62%|██████████████████████████████████████████▌                          | 1540/2499 [24:03<12:29,  1.28it/s]

Step 1540, Loss: 0.0802


Epoch 7:  62%|██████████████████████████████████████████▊                          | 1550/2499 [24:11<12:05,  1.31it/s]

Step 1550, Loss: 0.0833


Epoch 7:  62%|███████████████████████████████████████████                          | 1560/2499 [24:18<12:04,  1.30it/s]

Step 1560, Loss: 0.0861


Epoch 7:  63%|███████████████████████████████████████████▎                         | 1570/2499 [24:26<11:13,  1.38it/s]

Step 1570, Loss: 0.0894


Epoch 7:  63%|███████████████████████████████████████████▋                         | 1580/2499 [24:33<11:20,  1.35it/s]

Step 1580, Loss: 0.0923


Epoch 7:  64%|███████████████████████████████████████████▉                         | 1590/2499 [24:41<11:06,  1.36it/s]

Step 1590, Loss: 0.0952


Epoch 7:  64%|████████████████████████████████████████████▏                        | 1599/2499 [24:47<10:47,  1.39it/s]

Step 1600, Loss: 0.0980


Epoch 7:  64%|██████████████████████████████████████████▉                        | 1600/2499 [25:05<1:28:34,  5.91s/it]

Validation after Step 1600: Avg Val Loss: 0.6770


Epoch 7:  64%|████████████████████████████████████████████▍                        | 1610/2499 [25:13<12:44,  1.16it/s]

Step 1610, Loss: 0.1010


Epoch 7:  65%|████████████████████████████████████████████▋                        | 1620/2499 [25:20<11:48,  1.24it/s]

Step 1620, Loss: 0.1037


Epoch 7:  65%|█████████████████████████████████████████████                        | 1630/2499 [25:28<10:58,  1.32it/s]

Step 1630, Loss: 0.1065


Epoch 7:  66%|█████████████████████████████████████████████▎                       | 1640/2499 [25:35<10:15,  1.40it/s]

Step 1640, Loss: 0.1090


Epoch 7:  66%|█████████████████████████████████████████████▌                       | 1650/2499 [25:43<11:01,  1.28it/s]

Step 1650, Loss: 0.1115


Epoch 7:  66%|█████████████████████████████████████████████▊                       | 1660/2499 [25:51<10:45,  1.30it/s]

Step 1660, Loss: 0.1142


Epoch 7:  67%|██████████████████████████████████████████████                       | 1670/2499 [25:58<10:14,  1.35it/s]

Step 1670, Loss: 0.1170


Epoch 7:  67%|██████████████████████████████████████████████▍                      | 1680/2499 [26:06<10:01,  1.36it/s]

Step 1680, Loss: 0.1199


Epoch 7:  68%|██████████████████████████████████████████████▋                      | 1690/2499 [26:14<10:37,  1.27it/s]

Step 1690, Loss: 0.1226


Epoch 7:  68%|██████████████████████████████████████████████▉                      | 1699/2499 [26:20<09:32,  1.40it/s]

Step 1700, Loss: 0.1250


Epoch 7:  68%|█████████████████████████████████████████████▌                     | 1700/2499 [26:38<1:19:41,  5.98s/it]

Validation after Step 1700: Avg Val Loss: 0.6759


Epoch 7:  68%|███████████████████████████████████████████████▏                     | 1710/2499 [26:46<11:34,  1.14it/s]

Step 1710, Loss: 0.1273


Epoch 7:  69%|███████████████████████████████████████████████▍                     | 1720/2499 [26:54<10:08,  1.28it/s]

Step 1720, Loss: 0.1300


Epoch 7:  69%|███████████████████████████████████████████████▊                     | 1730/2499 [27:01<09:42,  1.32it/s]

Step 1730, Loss: 0.1327


Epoch 7:  70%|████████████████████████████████████████████████                     | 1740/2499 [27:09<09:30,  1.33it/s]

Step 1740, Loss: 0.1353


Epoch 7:  70%|████████████████████████████████████████████████▎                    | 1750/2499 [27:16<09:14,  1.35it/s]

Step 1750, Loss: 0.1376


Epoch 7:  70%|████████████████████████████████████████████████▌                    | 1760/2499 [27:24<09:42,  1.27it/s]

Step 1760, Loss: 0.1397


Epoch 7:  71%|████████████████████████████████████████████████▊                    | 1770/2499 [27:31<08:55,  1.36it/s]

Step 1770, Loss: 0.1421


Epoch 7:  71%|█████████████████████████████████████████████████▏                   | 1780/2499 [27:39<09:01,  1.33it/s]

Step 1780, Loss: 0.1445


Epoch 7:  72%|█████████████████████████████████████████████████▍                   | 1790/2499 [27:46<09:03,  1.30it/s]

Step 1790, Loss: 0.1465


Epoch 7:  72%|█████████████████████████████████████████████████▋                   | 1799/2499 [27:53<08:33,  1.36it/s]

Step 1800, Loss: 0.1490


Epoch 7:  72%|████████████████████████████████████████████████▎                  | 1800/2499 [28:12<1:10:21,  6.04s/it]

Validation after Step 1800: Avg Val Loss: 0.6777


Epoch 7:  72%|█████████████████████████████████████████████████▉                   | 1810/2499 [28:19<10:21,  1.11it/s]

Step 1810, Loss: 0.1513


Epoch 7:  73%|██████████████████████████████████████████████████▎                  | 1820/2499 [28:27<08:49,  1.28it/s]

Step 1820, Loss: 0.1537


Epoch 7:  73%|██████████████████████████████████████████████████▌                  | 1830/2499 [28:35<08:38,  1.29it/s]

Step 1830, Loss: 0.1561


Epoch 7:  74%|██████████████████████████████████████████████████▊                  | 1840/2499 [28:42<08:08,  1.35it/s]

Step 1840, Loss: 0.1583


Epoch 7:  74%|███████████████████████████████████████████████████                  | 1850/2499 [28:50<07:54,  1.37it/s]

Step 1850, Loss: 0.1605


Epoch 7:  74%|███████████████████████████████████████████████████▎                 | 1860/2499 [28:57<07:57,  1.34it/s]

Step 1860, Loss: 0.1630


Epoch 7:  75%|███████████████████████████████████████████████████▋                 | 1870/2499 [29:05<07:57,  1.32it/s]

Step 1870, Loss: 0.1651


Epoch 7:  75%|███████████████████████████████████████████████████▉                 | 1880/2499 [29:12<07:50,  1.32it/s]

Step 1880, Loss: 0.1672


Epoch 7:  76%|████████████████████████████████████████████████████▏                | 1890/2499 [29:20<07:53,  1.29it/s]

Step 1890, Loss: 0.1694


Epoch 7:  76%|████████████████████████████████████████████████████▍                | 1899/2499 [29:27<07:48,  1.28it/s]

Step 1900, Loss: 0.1714


Epoch 7:  76%|██████████████████████████████████████████████████▉                | 1900/2499 [29:46<1:01:03,  6.12s/it]

Validation after Step 1900: Avg Val Loss: 0.6768


Epoch 7:  76%|████████████████████████████████████████████████████▋                | 1910/2499 [29:53<09:04,  1.08it/s]

Step 1910, Loss: 0.1730


Epoch 7:  77%|█████████████████████████████████████████████████████                | 1920/2499 [30:01<07:54,  1.22it/s]

Step 1920, Loss: 0.1753


Epoch 7:  77%|█████████████████████████████████████████████████████▎               | 1930/2499 [30:09<07:29,  1.27it/s]

Step 1930, Loss: 0.1770


Epoch 7:  78%|█████████████████████████████████████████████████████▌               | 1940/2499 [30:17<07:31,  1.24it/s]

Step 1940, Loss: 0.1791


Epoch 7:  78%|█████████████████████████████████████████████████████▊               | 1950/2499 [30:24<06:59,  1.31it/s]

Step 1950, Loss: 0.1810


Epoch 7:  78%|██████████████████████████████████████████████████████               | 1960/2499 [30:32<06:47,  1.32it/s]

Step 1960, Loss: 0.1828


Epoch 7:  79%|██████████████████████████████████████████████████████▍              | 1970/2499 [30:40<06:46,  1.30it/s]

Step 1970, Loss: 0.1845


Epoch 7:  79%|██████████████████████████████████████████████████████▋              | 1980/2499 [30:47<06:33,  1.32it/s]

Step 1980, Loss: 0.1866


Epoch 7:  80%|██████████████████████████████████████████████████████▉              | 1990/2499 [30:55<06:33,  1.29it/s]

Step 1990, Loss: 0.1882


Epoch 7:  80%|███████████████████████████████████████████████████████▏             | 1999/2499 [31:02<06:46,  1.23it/s]

Step 2000, Loss: 0.1902


Epoch 7:  80%|███████████████████████████████████████████████████████▏             | 2000/2499 [31:20<50:36,  6.08s/it]

Validation after Step 2000: Avg Val Loss: 0.6783


Epoch 7:  80%|███████████████████████████████████████████████████████▍             | 2010/2499 [31:28<07:34,  1.08it/s]

Step 2010, Loss: 0.1919


Epoch 7:  81%|███████████████████████████████████████████████████████▊             | 2020/2499 [31:36<06:07,  1.30it/s]

Step 2020, Loss: 0.1937


Epoch 7:  81%|████████████████████████████████████████████████████████             | 2030/2499 [31:43<05:54,  1.32it/s]

Step 2030, Loss: 0.1956


Epoch 7:  82%|████████████████████████████████████████████████████████▎            | 2040/2499 [31:51<05:53,  1.30it/s]

Step 2040, Loss: 0.1976


Epoch 7:  82%|████████████████████████████████████████████████████████▌            | 2050/2499 [31:59<05:41,  1.32it/s]

Step 2050, Loss: 0.1992


Epoch 7:  82%|████████████████████████████████████████████████████████▉            | 2060/2499 [32:06<05:43,  1.28it/s]

Step 2060, Loss: 0.2010


Epoch 7:  83%|█████████████████████████████████████████████████████████▏           | 2070/2499 [32:14<05:22,  1.33it/s]

Step 2070, Loss: 0.2029


Epoch 7:  83%|█████████████████████████████████████████████████████████▍           | 2080/2499 [32:21<05:13,  1.34it/s]

Step 2080, Loss: 0.2045


Epoch 7:  84%|█████████████████████████████████████████████████████████▋           | 2090/2499 [32:29<05:04,  1.34it/s]

Step 2090, Loss: 0.2061


Epoch 7:  84%|█████████████████████████████████████████████████████████▉           | 2099/2499 [32:36<05:06,  1.30it/s]

Step 2100, Loss: 0.2078


In [35]:
# print("prefix size:", prefix.size())
# print("input_ids size:", input_ids.size())
# print("attention_masks size:", attention_masks.size())
# print("extended_attention_masks size:", extended_attention_masks.size())


In [72]:
save_checkpoint(model, optimizer, epoch, step, total_loss, checkpoint_path)


Checkpoint saved at epoch 4, step 2498.


In [43]:
# import torch

# checkpoint = torch.load(r"C:\Users\vdsha\Downloads\model_checkpoint_epoch1.pth")
# model.load_state_dict(checkpoint['model_state_dict'])


# # Set the model to evaluation mode
# model.eval()


In [45]:
from transformers import GPT2Tokenizer

# Load the tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")  # Replace "gpt2" with your tokenizer if customized


In [37]:
import torch
from PIL import Image
from transformers import GPT2Tokenizer
from torchvision import transforms

# Paths
image_path = r"D:\vidisha\vizwiz\real time test imgs\nivea.jpg"
checkpoint_path = r"C:\Users\vdsha\Downloads\model_checkpoint_epoch4_new.pth"
# Load the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ClipCaptionModel(prefix_length=10).to(device)  # Adjust prefix_length as per your training
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)  # Use the same optimizer as training

# Load the checkpoint
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

# Load the CLIP model
from transformers import CLIPProcessor, CLIPModel
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Preprocess the image
image = Image.open(image_path).convert("RGB")
preprocessed_image = clip_processor(images=image, return_tensors="pt")["pixel_values"].to(device)

# Extract CLIP features
with torch.no_grad():
    clip_features = clip_model.get_image_features(preprocessed_image).float()



# Load GPT-2 tokenizer and add special tokens
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Add BOS, EOS, and PAD tokens
special_tokens = {'bos_token': '<start>', 'eos_token': '<end>', 'pad_token': '<pad>'}
gpt2_tokenizer.add_special_tokens(special_tokens)

# Update pad token ID
gpt2_tokenizer.pad_token_id = gpt2_tokenizer.convert_tokens_to_ids("<pad>")

# Ensure the model knows about the new tokens
model.gpt.resize_token_embeddings(len(gpt2_tokenizer))

# Use special tokens
start_token_id = gpt2_tokenizer.bos_token_id
end_token_id = gpt2_tokenizer.eos_token_id

max_length = 100  # Adjust this as needed

# Generate caption
# with torch.no_grad():
#     tokens = torch.tensor([[start_token_id]], device=device)
#     prefix = model.clip_project(clip_features).view(1, model.prefix_length, model.gpt_embedding_size)

#     for _ in range(max_length):
#         outputs = model.gpt(inputs_embeds=torch.cat((prefix, model.gpt.transformer.wte(tokens)), dim=1))
#         logits = outputs.logits
#         next_token = torch.argmax(logits[:, -1, :], dim=-1)
#         tokens = torch.cat((tokens, next_token.unsqueeze(0)), dim=1)

#         if next_token.item() == end_token_id:
#             break

# caption = gpt2_tokenizer.decode(tokens.squeeze().tolist(), skip_special_tokens=True)
# print(f"Generated Caption: {caption}")

# Generate caption
with torch.no_grad():
    tokens = torch.tensor([[start_token_id]], device=device)
    prefix = model.clip_project(clip_features).view(1, model.prefix_length, model.gpt_embedding_size)

    for _ in range(max_length):
        outputs = model.gpt(inputs_embeds=torch.cat((prefix, model.gpt.transformer.wte(tokens)), dim=1))
        logits = outputs.logits
        next_token = torch.argmax(logits[:, -1, :], dim=-1)

        # Stop if the <end> token is generated
        if next_token.item() == end_token_id:
            break

        tokens = torch.cat((tokens, next_token.unsqueeze(0)), dim=1)

# Decode and remove unnecessary repetitions
caption = gpt2_tokenizer.decode(tokens.squeeze().tolist(), skip_special_tokens=True)
caption = caption.split("<end>")[0].strip()  # Keep only the first valid caption
caption = caption.replace("<start>", "").strip()

# Debugging output
# print("Generated Token IDs:", tokens.tolist())
print(f"Generated Caption: {caption}")


  checkpoint = torch.load(checkpoint_path, map_location=device)
  attn_output = torch.nn.functional.scaled_dot_product_attention(
The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


Generated Caption: A blue bottle of shampoo is on a table.


In [47]:
# import torch
# from PIL import Image
# import easyocr
# from transformers import GPT2Tokenizer, CLIPProcessor, CLIPModel

# def load_models():
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
#     # Load fine-tuned CLIP-GPT2 model
#     model = ClipCaptionModel(prefix_length=10).to(device)
#     checkpoint_path = r"C:\Users\vdsha\Downloads\model_checkpoint_epoch4_new.pth"  # Update with actual path
#     checkpoint = torch.load(checkpoint_path, map_location=device)
#     model.load_state_dict(checkpoint["model_state_dict"])
#     model.eval()
    
#     # Load CLIP model
#     clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
#     clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    
#     # Load tokenizer
#     tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
#     tokenizer.add_special_tokens({'bos_token': '<start>', 'eos_token': '<end>', 'pad_token': '<pad>'})
#     tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids("<pad>")
#     model.gpt.resize_token_embeddings(len(tokenizer))
    
#     return model, clip_model, clip_processor, tokenizer, device

# def get_ocr_text(image_path):
#     reader = easyocr.Reader(["en"])  
#     results = reader.readtext(image_path)
#     extracted_text = " ".join([res[1] for res in results])  

#     # Clean up detected text (remove numbers, extra words, etc.)
#     words = extracted_text.split()
#     filtered_words = [word for word in words if len(word) > 2 and not word.isnumeric()]
#     extracted_text = " ".join(filtered_words[:4])  # Limit to 4 words
    
#     return extracted_text.strip()

# def generate_caption(image_path, model, clip_model, clip_processor, tokenizer, ocr_text, device):
#     image = Image.open(image_path).convert("RGB")
#     preprocessed_image = clip_processor(images=image, return_tensors="pt")["pixel_values"].to(device)
    
#     with torch.no_grad():
#         clip_features = clip_model.get_image_features(preprocessed_image).float()
    
#     start_token_id = tokenizer.bos_token_id
#     end_token_id = tokenizer.eos_token_id
#     max_length = 50  

#     # Convert OCR text into tokens and filter duplication
#     ocr_tokens = tokenizer.encode(ocr_text, add_special_tokens=False) if ocr_text else []
    
#     with torch.no_grad():
#         prefix = model.clip_project(clip_features).view(1, model.prefix_length, model.gpt_embedding_size)
#         tokens = torch.tensor([[start_token_id]], device=device)

#         for _ in range(max_length):
#             outputs = model.gpt(inputs_embeds=torch.cat((prefix, model.gpt.transformer.wte(tokens)), dim=1))
#             logits = outputs.logits[:, -1, :]
#             next_token = torch.argmax(logits, dim=-1)

#             if next_token.item() == end_token_id:
#                 break
#             tokens = torch.cat((tokens, next_token.unsqueeze(0)), dim=1)
    
#     caption = tokenizer.decode(tokens.squeeze().tolist(), skip_special_tokens=True)
#     caption = caption.replace("<start>", "").split("<end>")[0].strip()

#     # **Fix duplication: Remove repeated OCR words in caption**
#     if ocr_text:
#         ocr_words = set(ocr_text.lower().split())
#         caption_words = caption.split()
#         caption = " ".join([word for word in caption_words if word.lower() not in ocr_words])

#         # Reinsert OCR text in a natural way
#         caption = f"{ocr_text} {caption}".strip()

#     return caption

# def main(image_path):
#     model, clip_model, clip_processor, tokenizer, device = load_models()
    
#     # Get detected text from EasyOCR
#     ocr_text = get_ocr_text(image_path)
    
#     # Generate caption with **seamless OCR integration**
#     final_caption = generate_caption(image_path, model, clip_model, clip_processor, tokenizer, ocr_text, device)

#     print("Final Caption:", final_caption)
#     return final_caption

# # Example usage
# image_path = r"D:\vidisha\vizwiz\test\VizWiz_test_00000046.jpg"  # Update with actual path
# main(image_path)


  checkpoint = torch.load(checkpoint_path, map_location=device)
  net.load_state_dict(copyStateDict(torch.load(trained_model, map_location=device)))
  model.load_state_dict(torch.load(model_path, map_location=device))


Final Caption: SLICED PEACHEY A can of minced garlic sitting on a countertop.


'SLICED PEACHEY A can of minced garlic sitting on a countertop.'

In [37]:
import torch
from PIL import Image
import PIL
if not hasattr(PIL.Image, "ANTIALIAS"):  
    PIL.Image.ANTIALIAS = PIL.Image.LANCZOS  # Redirect ANTIALIAS to LANCZOS

import easyocr
from transformers import GPT2Tokenizer, CLIPProcessor, CLIPModel

def load_models():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Load fine-tuned CLIP-GPT2 model
    model = ClipCaptionModel(prefix_length=10).to(device)
    checkpoint_path = r"C:\Users\vdsha\Downloads\model_checkpoint_epoch4_new.pth" # Update path
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only= True)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
    
    # Load CLIP model
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    
    # Load tokenizer
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    tokenizer.add_special_tokens({'bos_token': '<start>', 'eos_token': '<end>', 'pad_token': '<pad>'})
    tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids("<pad>")
    model.gpt.resize_token_embeddings(len(tokenizer))
    
    return model, clip_model, clip_processor, tokenizer, device

def get_ocr_text(image_path):
    reader = easyocr.Reader(["en"])  # Load EasyOCR for English
    results = reader.readtext(image_path)
    extracted_text = " ".join([res[1] for res in results])  # Combine detected words

    # Clean up detected text (remove single characters, numbers, etc.)
    words = extracted_text.split()
    filtered_words = [word for word in words if len(word) > 2 and not word.isnumeric()]
    extracted_text = " ".join(filtered_words[:5])  # Limit to 5 words for better integration
    
    return extracted_text.strip()

def generate_caption(image_path, model, clip_model, clip_processor, tokenizer, ocr_text, device):
    image = Image.open(image_path).convert("RGB")
    preprocessed_image = clip_processor(images=image, return_tensors="pt")["pixel_values"].to(device)
    
    with torch.no_grad():
        clip_features = clip_model.get_image_features(preprocessed_image).float()
    
    start_token_id = tokenizer.bos_token_id
    end_token_id = tokenizer.eos_token_id
    max_length = 50  

    # **Inject OCR text into token sequence**
    tokens = [start_token_id]  # Start sequence
    
    if ocr_text:  
        ocr_tokens = tokenizer.encode(ocr_text, add_special_tokens=False)  # Convert OCR text to tokens
        tokens.extend(ocr_tokens)  # Merge OCR tokens **before** caption generation

    tokens = torch.tensor([tokens], device=device)  # Convert to tensor

    with torch.no_grad():
        prefix = model.clip_project(clip_features).view(1, model.prefix_length, model.gpt_embedding_size)
        
        for _ in range(max_length):
            outputs = model.gpt(inputs_embeds=torch.cat((prefix, model.gpt.transformer.wte(tokens)), dim=1))
            logits = outputs.logits[:, -1, :]
            next_token = torch.argmax(logits, dim=-1)
            
            if next_token.item() == end_token_id:
                break
            tokens = torch.cat((tokens, next_token.unsqueeze(0)), dim=1)
    
    caption = tokenizer.decode(tokens.squeeze().tolist(), skip_special_tokens=True)
    caption = caption.replace("<start>", "").split("<end>")[0].strip()

    return caption

def main(image_path):
    model, clip_model, clip_processor, tokenizer, device = load_models()
    
    # Get detected text from EasyOCR
    ocr_text = get_ocr_text(image_path)
    
    # Generate caption with **embedded OCR text**
    final_caption = generate_caption(image_path, model, clip_model, clip_processor, tokenizer, ocr_text, device)

    # print("Final Caption:", final_caption)
    return final_caption

# Example usage: "D:\vidisha\vizwiz\test\VizWiz_test_00000474.jpg","D:\vidisha\vizwiz\real time test imgs\nivea.jpg", "D:\vidisha\vizwiz\test\VizWiz_test_00000046.jpg", "D:\vidisha\vizwiz\test\VizWiz_test_00000148.jpg", "D:\vidisha\vizwiz\test\VizWiz_test_00000181.jpg"
image_path = r"D:\vidisha\vizwiz\real time test imgs\nivea.jpg"# Update with actual path
main(image_path)


The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`
  net.load_state_dict(copyStateDict(torch.load(trained_model, map_location=device)))
  model.load_state_dict(torch.load(model_path, map_location=device))
  attn_output = torch.nn.functional.scaled_dot_product_attention(


'NIVEA Body Mllk Shea Smoothie  A bottle of body lotion is on a table.'

In [135]:
import torch
from PIL import Image
import easyocr
from transformers import GPT2Tokenizer, CLIPProcessor, CLIPModel
import PIL
if not hasattr(PIL.Image, "ANTIALIAS"):  
    PIL.Image.ANTIALIAS = PIL.Image.LANCZOS  # Redirect ANTIALIAS to LANCZOS


def load_models():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Load fine-tuned CLIP-GPT2 model
    model = ClipCaptionModel(prefix_length=10).to(device)
    checkpoint_path = r"C:\Users\vdsha\Downloads\model_checkpoint_epoch4_new.pth" # Update path
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only= True)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
    
    # Load CLIP model
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    
    # Load tokenizer
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    tokenizer.add_special_tokens({'bos_token': '<start>', 'eos_token': '<end>', 'pad_token': '<pad>'})
    tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids("<pad>")
    model.gpt.resize_token_embeddings(len(tokenizer))
    
    return model, clip_model, clip_processor, tokenizer, device

def get_ocr_text(image_path):
    reader = easyocr.Reader(["en"])  # Load EasyOCR for English
    results = reader.readtext(image_path)
    extracted_text = " ".join([res[1] for res in results])  # Combine detected words

    # Clean up detected text (remove single characters, numbers, etc.)
    words = extracted_text.split()
    filtered_words = [word for word in words if len(word) > 2 and not word.isnumeric()]
    extracted_text = " ".join(filtered_words[:5])  # Limit to 5 words for better integration
    
    return extracted_text.strip()

# def generate_caption(image_path, model, clip_model, clip_processor, tokenizer, ocr_text, device):
#     image = Image.open(image_path).convert("RGB")
#     preprocessed_image = clip_processor(images=image, return_tensors="pt")["pixel_values"].to(device)
    
#     with torch.no_grad():
#         clip_features = clip_model.get_image_features(preprocessed_image).float()
    
#     start_token_id = tokenizer.bos_token_id
#     end_token_id = tokenizer.eos_token_id
#     max_length = 50  

#     # **Inject OCR text into token sequence**
#     tokens = [start_token_id]  # Start sequence
    
#     if ocr_text:  
#         ocr_tokens = tokenizer.encode(ocr_text, add_special_tokens=False)  # Convert OCR text to tokens
#         tokens.extend(ocr_tokens)  # Merge OCR tokens **before** caption generation

#     tokens = torch.tensor([tokens], device=device)  # Convert to tensor

#     with torch.no_grad():
#         prefix = model.clip_project(clip_features).view(1, model.prefix_length, model.gpt_embedding_size)
        
#         for _ in range(max_length):
#             outputs = model.gpt(inputs_embeds=torch.cat((prefix, model.gpt.transformer.wte(tokens)), dim=1))
#             logits = outputs.logits[:, -1, :]
#             next_token = torch.argmax(logits, dim=-1)
            
#             if next_token.item() == end_token_id:
#                 break
#             tokens = torch.cat((tokens, next_token.unsqueeze(0)), dim=1)
    
#     caption = tokenizer.decode(tokens.squeeze().tolist(), skip_special_tokens=True)
#     caption = caption.replace("<start>", "").split("<end>")[0].strip()

#     return caption

# def generate_caption(image_path, model, clip_model, clip_processor, tokenizer, ocr_text, device):
#     image = Image.open(image_path).convert("RGB")
#     preprocessed_image = clip_processor(images=image, return_tensors="pt")["pixel_values"].to(device)

#     with torch.no_grad():
#         clip_features = clip_model.get_image_features(preprocessed_image).float()

#     start_token_id = tokenizer.bos_token_id
#     end_token_id = tokenizer.eos_token_id
#     max_length = 50  

#     tokens = torch.tensor([[start_token_id]], device=device)

#     with torch.no_grad():
#         prefix = model.clip_project(clip_features).view(1, model.prefix_length, model.gpt_embedding_size)

#         for i in range(max_length):
#             outputs = model.gpt(inputs_embeds=torch.cat((prefix, model.gpt.transformer.wte(tokens)), dim=1))
#             logits = outputs.logits[:, -1, :]
#             next_token = torch.argmax(logits, dim=-1)

#             # Insert OCR text **after the first few tokens** for natural integration
#             if i == 2 and ocr_text:  
#                 ocr_tokens = tokenizer.encode(ocr_text, add_special_tokens=False)  # Convert OCR text to tokens
#                 ocr_tokens = torch.tensor([ocr_tokens], device=device)
#                 tokens = torch.cat((tokens, ocr_tokens), dim=1)

#             if next_token.item() == end_token_id:
#                 break

#             tokens = torch.cat((tokens, next_token.unsqueeze(0)), dim=1)

    # caption = tokenizer.decode(tokens.squeeze().tolist(), skip_special_tokens=True)
    # caption = caption.replace("<start>", "").split("<end>")[0].strip()

    # return caption

def generate_caption(image_path, model, clip_model, clip_processor, tokenizer, ocr_text, device):
    image = Image.open(image_path).convert("RGB")
    preprocessed_image = clip_processor(images=image, return_tensors="pt")["pixel_values"].to(device)

    with torch.no_grad():
        clip_features = clip_model.get_image_features(preprocessed_image).float()

    start_token_id = tokenizer.bos_token_id
    end_token_id = tokenizer.eos_token_id
    max_length = 50  

    tokens = [start_token_id]  # Start sequence

    if ocr_text:  
        # 🔹 If OCR text is available, let GPT-2 use it as the **main subject**
        ocr_tokens = tokenizer.encode(ocr_text, add_special_tokens=False)
        tokens.extend(ocr_tokens)
    else:
        # 🔹 No OCR? Let GPT-2 generate normally from the start
        tokens = [start_token_id]

    tokens = torch.tensor([tokens], device=device)

    with torch.no_grad():
        prefix = model.clip_project(clip_features).view(1, model.prefix_length, model.gpt_embedding_size)

        for _ in range(max_length):
            outputs = model.gpt(inputs_embeds=torch.cat((prefix, model.gpt.transformer.wte(tokens)), dim=1))
            logits = outputs.logits[:, -1, :]
            next_token = torch.argmax(logits, dim=-1)

            if next_token.item() == end_token_id:
                break
            tokens = torch.cat((tokens, next_token.unsqueeze(0)), dim=1)

    caption = tokenizer.decode(tokens.squeeze().tolist(), skip_special_tokens=True)
    caption = caption.replace("<start>", "").split("<end>")[0].strip()

    return caption

def main(image_path):
    model, clip_model, clip_processor, tokenizer, device = load_models()
    
    # Get detected text from EasyOCR
    ocr_text = get_ocr_text(image_path)
    
    # Generate caption with **embedded OCR text**
    final_caption = generate_caption(image_path, model, clip_model, clip_processor, tokenizer, ocr_text, device)

    print("Final Caption:", final_caption)
    return final_caption

# Example usage
image_path = r"D:\vidisha\vizwiz\real time test imgs\nivea.jpg"# Update with actual path
main(image_path)


  net.load_state_dict(copyStateDict(torch.load(trained_model, map_location=device)))
  model.load_state_dict(torch.load(model_path, map_location=device))


Final Caption: NIVEA Body Mllk Shea Smoothie  A bottle of body lotion is on a table.


'NIVEA Body Mllk Shea Smoothie  A bottle of body lotion is on a table.'

In [129]:
# def generate_caption(image_path, model, clip_model, gpt2_tokenizer, device, max_length=50):
#     """Generate image caption using CLIP-GPT2 with OCR conditioning."""

#     # Load and preprocess the image
#     image = Image.open(image_path).convert("RGB")
#     preprocessed_image = clip_processor(images=image, return_tensors="pt")["pixel_values"].to(device)

#     # Extract CLIP features
#     with torch.no_grad():
#         clip_features = clip_model.get_image_features(preprocessed_image).float()

#     # Extract OCR text
#     ocr_text = get_ocr_text(image_path)
#     print(f"[OCR]: {ocr_text}")

#     # Construct the prompt for GPT-2
#     if ocr_text.strip():
#         prompt = f"Detected text: {ocr_text}. Caption:"
#     else:
#         prompt = "Caption:"

#     # Tokenize the prompt
#     input_ids = gpt2_tokenizer(prompt, return_tensors="pt")["input_ids"].to(device)

#     # Project CLIP features
#     prefix = model.clip_project(clip_features).view(1, model.prefix_length, model.gpt_embedding_size)

#     # Generate caption
#     with torch.no_grad():
#         for _ in range(max_length):
#             # Embed input tokens
#             gpt2_embeddings = model.gpt.transformer.wte(input_ids)
#             inputs_embeds = torch.cat((prefix, gpt2_embeddings), dim=1)

#             # Generate logits
#             outputs = model.gpt(inputs_embeds=inputs_embeds)
#             logits = outputs.logits[:, -1, :]
#             next_token = torch.argmax(logits, dim=-1)

#             # Stop if the <end> token is generated
#             if next_token.item() == gpt2_tokenizer.eos_token_id:
#                 break

#             # Append new token
#             input_ids = torch.cat((input_ids, next_token.unsqueeze(0)), dim=1)

#     # Decode and clean the caption
#     caption = gpt2_tokenizer.decode(input_ids.squeeze().tolist(), skip_special_tokens=True)
    
#     # **Fix: Stop at the first <end> token and remove redundant <start>**
#     caption = caption.split("<end>")[0].strip()  # Keep only the first valid caption
#     caption = caption.replace("<start>", "").strip()  # Remove <start> token if present

#     # **Fix: Prevent redundant OCR text in the final output**
#     if ocr_text.strip().lower() in caption.lower():
#         final_caption = caption
#     else:
#         final_caption = f"Detected text: {ocr_text}. {caption}"

#     return final_caption

import re  # Import regex for OCR text cleanup

def clean_ocr_text(ocr_text):
    """Removes noisy characters and keeps only readable words."""
    ocr_text = re.sub(r"[^a-zA-Z0-9\s]", "", ocr_text)  # Remove symbols
    words = ocr_text.split()
    return " ".join(words[:5])  # Keep only the first 6 words for brevity

def generate_caption(image_path, model, clip_model, gpt2_tokenizer, device, max_length=50):
    """Generate image caption using CLIP-GPT2 with OCR conditioning (removes repetition & filters OCR)."""

    # Load and preprocess the image
    image = Image.open(image_path).convert("RGB")
    preprocessed_image = clip_processor(images=image, return_tensors="pt")["pixel_values"].to(device)

    # Extract CLIP features
    with torch.no_grad():
        clip_features = clip_model.get_image_features(preprocessed_image).float()

    # Extract OCR text
    ocr_text = get_ocr_text(image_path)
    print(f"[OCR]: {ocr_text}")

    # 🔹 **Clean the OCR text**  
    clean_text = clean_ocr_text(ocr_text)

    # 🔹 **New Prompt Style**  
    if clean_text.strip():
        prompt = f"Text detected: {clean_text}. Describe the scene:"
    else:
        prompt = "Describe the scene:"

    # Tokenize the prompt
    input_ids = gpt2_tokenizer(prompt, return_tensors="pt")["input_ids"].to(device)

    # Project CLIP features
    prefix = model.clip_project(clip_features).view(1, model.prefix_length, model.gpt_embedding_size)

    # Generate caption
    with torch.no_grad():
        for _ in range(max_length):
            # Embed input tokens
            gpt2_embeddings = model.gpt.transformer.wte(input_ids)
            inputs_embeds = torch.cat((prefix, gpt2_embeddings), dim=1)

            # Generate logits
            outputs = model.gpt(inputs_embeds=inputs_embeds)
            logits = outputs.logits[:, -1, :]
            next_token = torch.argmax(logits, dim=-1)

            # Stop if the <end> token is generated
            if next_token.item() == gpt2_tokenizer.eos_token_id:
                break

            # Append new token
            input_ids = torch.cat((input_ids, next_token.unsqueeze(0)), dim=1)

    # Decode and clean the caption
    caption = gpt2_tokenizer.decode(input_ids.squeeze().tolist(), skip_special_tokens=True)

    # 🔹 **Remove repetition**  
    unique_sentences = list(dict.fromkeys(caption.split(". ")))  # Remove duplicate sentences
    caption = ". ".join(unique_sentences)  # Reconstruct without repetition

    # 🔹 **Clean up and finalize the caption**  
    caption = caption.replace("Describe the scene:", "").strip()
    caption = caption.replace("Text detected:", "").strip()
    caption = caption.replace("<start>", "").strip()
    caption = caption.split("<end>")[0].strip()  # Stop at first `<end>`

    # 🔹 **Append cleaned OCR text only if needed**  
    if clean_text.strip().lower() not in caption.lower():
        final_caption = f"{caption}. (Detected text: {clean_text})"
    else:
        final_caption = caption  # No need to re-add OCR text

    return final_caption


# Example usage:
image_path = r"D:\vidisha\vizwiz\real time test imgs\anvi.jpg"
caption = generate_caption(image_path, model, clip_model, gpt2_tokenizer, device)
print(f"Final Caption: {caption}")


  net.load_state_dict(copyStateDict(torch.load(trained_model, map_location=device)))
  model.load_state_dict(torch.load(model_path, map_location=device))


[OCR]: Smile
Final Caption: Smile.  A woman is wearing a pink t-shirt with a woman in the background.  A woman is wearing a blue t-shirt with a woman in the background.


In [127]:
#improving second approach of prompt tuning for gpt2
import re
import torch

def clean_ocr_text(ocr_text):
    """Removes noisy characters and keeps only readable words."""
    ocr_text = re.sub(r"[^a-zA-Z0-9\s]", "", ocr_text)  # Remove symbols
    words = ocr_text.split()
    return " ".join(words[:10])  # Keep only the first 5 words for brevity

def remove_hallucinations(caption, image_objects):
    """Removes hallucinated objects that aren't detected by CLIP."""
    hallucinated_objects = ["bookcase", "water", "person", "chair", "sofa", "lamp"]  # Common hallucinations
    filtered_caption = " ".join([word for word in caption.split() if word.lower() not in hallucinated_objects])
    return filtered_caption.strip()

def generate_caption(image_path, model, clip_model, gpt2_tokenizer, device, max_length=50):
    """Generate image caption using CLIP-GPT2 with OCR conditioning (prevents hallucination)."""

    # Load and preprocess the image
    image = Image.open(image_path).convert("RGB")
    preprocessed_image = clip_processor(images=image, return_tensors="pt")["pixel_values"].to(device)

    # Extract CLIP features
    with torch.no_grad():
        clip_features = clip_model.get_image_features(preprocessed_image).float()

    # Extract and clean OCR text
    ocr_text = get_ocr_text(image_path)
    clean_text = clean_ocr_text(ocr_text)
    print(f"[OCR]: {clean_text}")

    # 🔹 **Better OCR Integration**
    if clean_text.strip():
        prompt = f"Image contains text: '{clean_text}'. What else is in the image?"
    else:
        prompt = "What do you see in the image?"

    # Tokenize the refined prompt
    input_ids = gpt2_tokenizer(prompt, return_tensors="pt")["input_ids"].to(device)

    # Project CLIP features
    prefix = model.clip_project(clip_features).view(1, model.prefix_length, model.gpt_embedding_size)

    # Generate caption
    with torch.no_grad():
        for _ in range(max_length):
            # Embed input tokens
            gpt2_embeddings = model.gpt.transformer.wte(input_ids)
            inputs_embeds = torch.cat((prefix, gpt2_embeddings), dim=1)

            # Generate logits
            outputs = model.gpt(inputs_embeds=inputs_embeds)
            logits = outputs.logits[:, -1, :]
            next_token = torch.argmax(logits, dim=-1)

            # Stop at the first `<eos>` token
            if next_token.item() == gpt2_tokenizer.eos_token_id:
                break

            # Append new token
            input_ids = torch.cat((input_ids, next_token.unsqueeze(0)), dim=1)

    # Decode and clean the caption
    caption = gpt2_tokenizer.decode(input_ids.squeeze().tolist(), skip_special_tokens=True)

    # 🔹 **Final Cleanup**
    caption = caption.replace(prompt, "").strip()  # Remove any leftover prompt text
    caption = caption.replace("<start>", "").strip()
    caption = caption.split("<end>")[0].strip()  # Keep text before first `<end>`

    # 🔹 **Hallucination Removal**
    caption = remove_hallucinations(caption, clip_features)

    # 🔹 **Ensure OCR text is naturally included**
    if clean_text.strip().lower() not in caption.lower():
        caption = f"{clean_text}. {caption}"

    return caption


# Example usage:
image_path = r"D:\vidisha\vizwiz\real time test imgs\anvi.jpg"
caption = generate_caption(image_path, model, clip_model, gpt2_tokenizer, device)
print(f"Final Caption: {caption}")


  net.load_state_dict(copyStateDict(torch.load(trained_model, map_location=device)))
  model.load_state_dict(torch.load(model_path, map_location=device))


[OCR]: Smile
Final Caption: Smile. A woman wearing a pink t-shirt with a blue flower on it.


In [87]:
print("CLIP Features Shape:", clip_features.shape)
print("CLIP Features Sample:", clip_features[0, :5])  # Print first 5 values
prefix = model.clip_project(clip_features).view(1, model.prefix_length, model.gpt_embedding_size)
print("Prefix Shape:", prefix.shape)
print("Prefix Sample:", prefix[0, 0, :5])  # Print first 5 values of the first prefix
print("Tokens:", tokens)


CLIP Features Shape: torch.Size([1, 512])
CLIP Features Sample: tensor([-0.2584, -0.0551,  0.0754, -0.0036, -0.0451], device='cuda:0')
Prefix Shape: torch.Size([1, 10, 768])
Prefix Sample: tensor([-0.0038, -0.1391,  0.2918,  0.1991,  0.0606], device='cuda:0',
       grad_fn=<SliceBackward0>)
Tokens: tensor([[50256,     0,     2,     0,     1,     3,     0,    20,     1,     5,
             1,     0,     3,     0,     0,    21,     7,     0,     4,     0,
             0,     0,     0,     1,     1,     2,     0,    11,     0,     0,
             3,     0,     5,     3,     0,     9,     0,     2,     0,    12,
             0,     0,     0,    13,     4,     0,     0,     0,     0,     0,
             0]], device='cuda:0')
