In [1]:
import os
import json
import random
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from transformers import AutoTokenizer

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [3]:
class COCOImageCaptionDataset(Dataset):
    """COCO image-caption dataset that also returns a random negative caption.

    There is one item per entry of ``annotations['annotations']`` in the COCO
    captions JSON, so an image with several captions appears once per caption.

    Args:
        img_dir: Directory containing the image files.
        annotations_file: Path to a COCO captions JSON file whose
            ``'annotations'`` entries carry ``'image_id'`` and ``'caption'``.
        transform: Callable applied to the loaded RGB PIL image. When falsy, a
            default pipeline is used (resize to 224x224, random
            grayscale/rotation/horizontal flip/color jitter/blur, ``ToTensor``
            and ImageNet mean/std normalization).
        max_length: Number of tokens every caption is padded/truncated to.

    Each item is a dict with ``'image'``, ``'caption_ids'``, ``'caption_mask'``,
    ``'neg_caption_ids'`` and ``'neg_caption_mask'``.
    """

    def __init__(self, img_dir, annotations_file, transform=None, max_length=64):
        self.img_dir = img_dir
        self.max_length = max_length

        # JSON is UTF-8 by spec; don't depend on the platform's default encoding.
        with open(annotations_file, 'r', encoding='utf-8') as f:
            self.annotations = json.load(f)

        # NOTE(review): RandomHorizontalFlip can contradict captions mentioning
        # "left"/"right"; consider dropping it for image-text alignment.
        self.transform = transform or transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.RandomGrayscale(p=0.3),
            transforms.RandomRotation(5),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ColorJitter(brightness=0.1, contrast=0.1),
            transforms.GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 1.0)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        self.tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")

        # Image files present in img_dir. __getitem__ does not use this list (it
        # derives paths from image_id); kept for callers that may rely on it.
        self.image_files = [
            os.path.join(img_dir, file)
            for file in os.listdir(img_dir)
            if file.endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tiff'))
        ]

    def __len__(self):
        """Number of caption annotations (not number of images)."""
        return len(self.annotations['annotations'])

    def __getitem__(self, idx):
        """Return the transformed image, its caption and a negative caption as tensors."""
        ann = self.annotations['annotations'][idx]

        # Primary name is the 12-digit zero-padded image_id; fall back to the
        # unpadded id for other directory layouts.
        img_path = os.path.join(self.img_dir, f"{ann['image_id']:012d}.jpg")
        if not os.path.exists(img_path):
            img_path = os.path.join(self.img_dir, f"{ann['image_id']}.jpg")

        # The context manager closes the file handle once the pixels are decoded.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        encoding = self._encode(ann['caption'])

        neg_idx = self._sample_negative_index(idx)
        neg_encoding = self._encode(self.annotations['annotations'][neg_idx]['caption'])

        return {
            'image': image,
            'caption_ids': encoding['input_ids'].squeeze(0),
            'caption_mask': encoding['attention_mask'].squeeze(0),
            'neg_caption_ids': neg_encoding['input_ids'].squeeze(0),
            'neg_caption_mask': neg_encoding['attention_mask'].squeeze(0)
        }

    def _encode(self, caption):
        """Tokenize one caption, padded/truncated to ``self.max_length`` tokens."""
        return self.tokenizer(caption, padding='max_length',
                              truncation=True, max_length=self.max_length,
                              return_tensors='pt')

    def _sample_negative_index(self, idx, max_tries=10):
        """Pick a random annotation index to use as the negative caption for ``idx``.

        Draws uniformly from every index except ``idx`` in O(1), without
        materializing a list of all other indices on each call. A candidate
        whose ``image_id`` equals the anchor's is another caption of the same
        image (a false negative), so it is redrawn up to ``max_tries`` times;
        after that the last draw is returned. With at most one annotation there
        is nothing else to choose, so ``idx`` itself is returned.
        """
        anns = self.annotations['annotations']
        n = len(anns)
        if n <= 1:
            return idx

        anchor = idx % n  # normalize negative indices so the anchor is really excluded
        anchor_image = anns[anchor]['image_id']
        for _ in range(max(1, max_tries)):
            candidate = random.randrange(n - 1)
            if candidate >= anchor:
                candidate += 1  # shift past the anchor: uniform over the other n-1 indices
            if anns[candidate]['image_id'] != anchor_image:
                break
        return candidate


In [4]:
# Training split of COCO 2017 as mounted under the Kaggle input directory.
img_path = '/kaggle/input/coco-2017-dataset/coco2017/train2017'
cap_path = '/kaggle/input/coco-2017-dataset/coco2017/annotations/captions_train2017.json'

# Default transform and BERT tokenizer are set up inside the dataset.
train_data = COCOImageCaptionDataset(img_dir=img_path, annotations_file=cap_path)

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

In [5]:
sample = train_data[0]

# Sanity-check one item: all expected fields are present, the image is a
# 3-channel tensor, and positive/negative captions share the same token layout.
assert set(sample) == {'image', 'caption_ids', 'caption_mask',
                       'neg_caption_ids', 'neg_caption_mask'}
assert sample['image'].ndim == 3 and sample['image'].shape[0] == 3
assert (sample['caption_ids'].shape == sample['caption_mask'].shape
        == sample['neg_caption_ids'].shape == sample['neg_caption_mask'].shape)

sample['image'].shape

torch.Size([3, 224, 224])

In [6]:
from torch.utils.data import DataLoader

# pin_memory speeds up host->GPU copies when CUDA is used; persistent_workers
# keeps the worker processes alive across epochs instead of re-creating them
# at the start of every epoch.
train_loader = DataLoader(
    train_data,
    batch_size=32,
    shuffle=True,
    num_workers=4,
    pin_memory=torch.cuda.is_available(),
    persistent_workers=True,
)

In [7]:
#utility scripts
from cnn_img_encoder import ImgEncoder_CNN
from bert_encodings import TextEncoder

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 196MB/s] 


torch.Size([32, 512])


config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

In [8]:
class CLIPModel(nn.Module):
    """Dual encoder that scores each image against a positive and a negative caption.

    Args:
        embedding_dim: Projection size passed to both encoders.
        max_logit_scale: Upper bound on ``exp(logit_scale)`` (100, as in CLIP)
            so the learnable temperature cannot grow until the logits saturate.

    The module does not choose a device; move it with ``.to(device)`` so the
    encoders and ``logit_scale`` end up on the same device.
    """

    def __init__(self, embedding_dim=512, max_logit_scale=100.0):
        super().__init__()
        self.image_encoder = ImgEncoder_CNN(projection_dim=embedding_dim)
        self.text_encoder = TextEncoder(projection_dim=embedding_dim)
        # Learnable log-temperature, initialised to log(1/0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.max_logit_scale = max_logit_scale

    def forward(self, images, caption_ids, caption_mask, neg_caption_ids, neg_caption_mask):
        """Return ``(pos_logits, neg_logits)``: scaled cosine similarities.

        Features are L2-normalized along the last dim and dotted row-wise, so
        each logit lies in ``[-scale, scale]``. Assumes the encoders return
        ``(batch, embedding_dim)`` features, giving ``(batch,)`` logits — TODO
        confirm against the encoder modules.
        """
        image_features = F.normalize(self.image_encoder(images), p=2, dim=-1)
        text_features = F.normalize(self.text_encoder(caption_ids, caption_mask), p=2, dim=-1)
        neg_text_features = F.normalize(
            self.text_encoder(neg_caption_ids, neg_caption_mask), p=2, dim=-1
        )

        # Clamp the temperature so an ever-growing logit_scale cannot push the
        # logits far into the sigmoid's saturated region. getattr keeps
        # instances pickled before this attribute existed working.
        max_scale = getattr(self, 'max_logit_scale', 100.0)
        logit_scale = self.logit_scale.exp().clamp(max=max_scale)

        pos_logits = (image_features * text_features).sum(dim=-1) * logit_scale
        neg_logits = (image_features * neg_text_features).sum(dim=-1) * logit_scale
        return pos_logits, neg_logits

In [9]:
import torch.optim as optim


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CLIPModel().to(device)
learning_rate = 3e-4
weight_decay = 0.01  # AdamW's default, now applied to weight matrices only

# Following CLIP, biases, norm gains and the scalar logit_scale are not
# weight-decayed: AdamW's decoupled decay would otherwise keep pulling
# logit_scale toward 0 (i.e. the similarity scale toward 1).
decay_params, no_decay_params = [], []
for param in model.parameters():
    if not param.requires_grad:
        continue
    (decay_params if param.ndim >= 2 else no_decay_params).append(param)

optimizer = optim.AdamW(
    [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ],
    lr=learning_rate,
)


In [None]:
from tqdm.auto import tqdm

num_epochs = 10
loss_history = []


def pair_loss(pos_logits, neg_logits):
    """Binary logistic loss pushing positive pairs toward 1 and negatives toward 0.

    Mathematically equal to ``-mean(log(sigmoid(pos)) + log(1 - sigmoid(neg)))``,
    but ``logsigmoid`` stays finite for large-magnitude logits where the naive
    form hits ``log(0) = -inf`` (inf loss, NaN gradients).
    """
    return -(F.logsigmoid(pos_logits) + F.logsigmoid(-neg_logits)).mean()


for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0

    train_loader_tqdm = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")

    for batch in train_loader_tqdm:
        images = batch['image'].to(device)
        caption_ids = batch['caption_ids'].to(device)
        caption_mask = batch['caption_mask'].to(device)
        neg_caption_ids = batch['neg_caption_ids'].to(device)
        neg_caption_mask = batch['neg_caption_mask'].to(device)

        pos_logits, neg_logits = model(images, caption_ids, caption_mask, neg_caption_ids, neg_caption_mask)
        loss = pair_loss(pos_logits, neg_logits)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Single device->host sync per batch, reused for the total and the bar.
        batch_loss = loss.item()
        total_loss += batch_loss
        train_loader_tqdm.set_postfix(loss=batch_loss)

    avg_loss = total_loss / len(train_loader)
    loss_history.append(avg_loss)

    print(f"Epoch [{epoch+1}/{num_epochs}] - Avg Loss: {avg_loss:.4f}")


Epoch 1/10:   4%|▍         | 704/18493 [11:58<3:30:04,  1.41it/s, loss=0.747]

In [None]:
import matplotlib.pyplot as plt

# Plot only the epochs that actually finished: if training was interrupted,
# loss_history is shorter than num_epochs. Epochs are 1-based to match the log.
epochs_done = range(1, len(loss_history) + 1)

fig, ax = plt.subplots()
ax.plot(epochs_done, loss_history, color='red', marker='o', label='Training Loss')
ax.set_title("Training Loss")
ax.set_xlabel("Epochs")
ax.set_ylabel("Average loss")
ax.grid(True)
ax.legend()
plt.show()