In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer
from torchvision import transforms
from PIL import Image
import random
import json
from torchvision.datasets import CocoCaptions
from torchvision.models import vit_b_16, ViT_B_16_Weights

# Directory holding the COCO 2017 training images (path under /kaggle/input).
image_dir = '/kaggle/input/coco-2017-dataset/coco2017/train2017'
# JSON file with the COCO 2017 training-split caption annotations.
ann_file = '/kaggle/input/coco-2017-dataset/coco2017/annotations/captions_train2017.json'


class COCOAlignmentDataset(Dataset):
    """COCO image/caption pairs, each bundled with one random negative caption.

    One item is produced per entry of the ``'annotations'`` list in the
    captions JSON file. Every item carries the transformed image, the
    tokenized caption of that annotation, and the tokenized caption of a
    randomly drawn annotation that serves as a negative example.

    Args:
        img_dir: Directory containing the images, named as the zero-padded
            12-digit ``image_id`` followed by ``.jpg``.
        annotations_file: Path to the captions JSON file.
        transform: Optional callable applied to the RGB ``PIL`` image. When
            omitted, a resize / flip / padded-crop / normalize pipeline is
            used.
    """

    # Channel statistics that torchvision documents for its ImageNet-pretrained
    # weights (including ViT_B_16_Weights.IMAGENET1K_V1, which ImageEncoder in
    # this file loads and freezes). Normalizing with 0.5/0.5 instead would feed
    # the frozen backbone inputs on a scale it was not trained on.
    _IMAGENET_MEAN = (0.485, 0.456, 0.406)
    _IMAGENET_STD = (0.229, 0.224, 0.225)

    # Captions are padded/truncated to this many tokens.
    _MAX_LENGTH = 64

    # Upper bound on re-draws while looking for a caption of a different image.
    _MAX_NEGATIVE_TRIES = 32

    def __init__(self, img_dir, annotations_file, transform=None):
        self.img_dir = img_dir
        with open(annotations_file, 'r', encoding='utf-8') as f:
            self.annotations = json.load(f)
        if transform is None:
            transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomCrop(224, padding=4),
                transforms.ToTensor(),
                transforms.Normalize(mean=list(self._IMAGENET_MEAN),
                                     std=list(self._IMAGENET_STD)),
            ])
        self.transform = transform
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def __len__(self):
        """Return the number of caption annotations."""
        return len(self.annotations['annotations'])

    def _tokenize(self, text):
        """Tokenize *text* and return ``(input_ids, attention_mask)``.

        Both tensors have the leading batch axis (added by
        ``return_tensors='pt'``) squeezed away.
        """
        encoding = self.tokenizer(text, padding='max_length',
                                  truncation=True, max_length=self._MAX_LENGTH,
                                  return_tensors='pt')
        return (encoding['input_ids'].squeeze(0),
                encoding['attention_mask'].squeeze(0))

    def _sample_negative_index(self, idx: int) -> int:
        """Pick a random annotation index to use as the negative for *idx*.

        The returned index is never *idx* itself. Candidates that describe the
        same image as *idx* are rejected and re-drawn, because another caption
        of the same picture is not a true negative; after
        ``_MAX_NEGATIVE_TRIES`` unsuccessful draws the last candidate is
        returned anyway so the call always terminates.

        Raises:
            IndexError: If the dataset holds fewer than two annotations.
        """
        records = self.annotations['annotations']
        total = len(records)
        if total < 2:
            raise IndexError(
                "cannot draw a negative caption from fewer than two annotations"
            )
        if idx < 0:
            # Normalize list-style negative indices so the exclusion below
            # removes the right position.
            idx += total
        anchor_image = records[idx]['image_id']
        candidate = idx
        for _ in range(self._MAX_NEGATIVE_TRIES):
            # Uniform draw over every index except idx in O(1): sample from
            # total - 1 slots and shift the upper part past idx. This avoids
            # materializing a list of all other indices for every item.
            candidate = random.randrange(total - 1)
            if candidate >= idx:
                candidate += 1
            if records[candidate]['image_id'] != anchor_image:
                break
        return candidate

    def __getitem__(self, idx):
        """Return the image, its caption and a negative caption for *idx*.

        Returns:
            A dict with keys ``'image'``, ``'caption_ids'``,
            ``'caption_mask'``, ``'neg_caption_ids'`` and
            ``'neg_caption_mask'``.
        """
        records = self.annotations['annotations']
        ann = records[idx]
        img_path = f"{self.img_dir}/{ann['image_id']:012d}.jpg"
        # convert() returns an independent, fully loaded image, so the file
        # handle can be closed as soon as the block exits.
        with Image.open(img_path) as img:
            rgb = img.convert('RGB')
        image = self.transform(rgb)

        # Anchor caption
        caption_ids, caption_mask = self._tokenize(ann['caption'])

        # Negative caption: a random annotation, preferably of another image
        neg_ann = records[self._sample_negative_index(idx)]
        neg_caption_ids, neg_caption_mask = self._tokenize(neg_ann['caption'])

        return {
            'image': image,
            'caption_ids': caption_ids,
            'caption_mask': caption_mask,
            'neg_caption_ids': neg_caption_ids,
            'neg_caption_mask': neg_caption_mask,
        }

class ImageEncoder(nn.Module):
    """ViT-B/16 with frozen pretrained weights and a new trainable head.

    The classification head of the pretrained network is swapped for a
    linear layer followed by layer normalization, so the module emits
    vectors of size ``out_dim``. Only that new head is left trainable.

    Args:
        out_dim: Size of the output embedding.
    """

    def __init__(self, out_dim):
        super().__init__()
        self.model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)

        # Freeze every pretrained parameter. The head is created afterwards,
        # so its parameters keep the default requires_grad=True.
        self.model.requires_grad_(False)

        backbone_width = self.model.heads.head.in_features
        to_out_dim = nn.Linear(backbone_width, out_dim)
        stabilise = nn.LayerNorm(out_dim)
        self.model.heads.head = nn.Sequential(to_out_dim, stabilise)

    def forward(self, x):
        """Run *x* through the ViT and the replacement head."""
        embedding = self.model(x)
        return embedding

class TextEncoder(nn.Module):
    """BERT (``bert-base-uncased``) followed by a projection to ``out_dim``.

    The pooled BERT output is passed through a linear layer and layer
    normalization.

    Args:
        out_dim: Size of the output embedding.
    """

    def __init__(self, out_dim):
        super().__init__()
        self.model = BertModel.from_pretrained('bert-base-uncased')
        # 768 is the input width of the projection, i.e. the size of the
        # pooled BERT vector it consumes in forward().
        to_out_dim = nn.Linear(768, out_dim)
        normalise = nn.LayerNorm(out_dim)
        self.projection = nn.Sequential(to_out_dim, normalise)

    def forward(self, input_ids, attention_mask):
        """Encode a batch of token ids and return the projected pooled output."""
        bert_out = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        pooled = bert_out.pooler_output
        return self.projection(pooled)