In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from encoders import CLIPResNetEncoder, CLIPViTEncoder
from transformers import AutoTokenizer, AutoModel

In [None]:
class CLIPGPT2Encoder(nn.Module):
    """Text encoder that maps GPT-2 hidden states into a CLIP-style embedding space.

    The token-level hidden states of a pretrained GPT-2 are mean-pooled over
    the attended positions, passed through a LayerNorm + Linear projection
    head, and L2-normalized.

    Args:
        projection_dim: Size of the output embedding.
    """

    def __init__(self, projection_dim=512):
        super().__init__()

        self.transformer = AutoModel.from_pretrained('gpt2')
        self.tokenizer = AutoTokenizer.from_pretrained('gpt2')

        # GPT-2 ships without a padding token, so padding a batch would raise.
        # Reuse EOS as the pad token; padded positions are excluded from the
        # pooling in ``forward`` via the attention mask.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # 768 is the hidden size of the base 'gpt2' checkpoint loaded above.
        self.projection = nn.Sequential(
            nn.LayerNorm(768),
            nn.Linear(768, projection_dim)
        )

        # Maximum token length callers are expected to pad/truncate to.
        self.max_length = 77

    def forward(self, input_ids, attention_mask=None):
        """Encode token ids into unit-length text embeddings.

        Args:
            input_ids: Token ids, indexed as (batch, sequence).
            attention_mask: Optional mask with the same shape as ``input_ids``
                (non-zero = real token, 0 = padding). When omitted, every
                position is treated as a real token.

        Returns:
            Tensor of shape (batch, projection_dim) with L2-normalized rows.
        """
        outputs = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        last_hidden_state = outputs.last_hidden_state

        # Without an explicit mask, pool over every position.
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        # Masked mean pooling: zero out padded positions, then divide by the
        # number of real tokens. The clamp keeps an all-padding row from
        # producing 0/0 = NaN.
        mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
        token_counts = mask.sum(dim=1).clamp(min=1)
        pooled_output = (last_hidden_state * mask).sum(dim=1) / token_counts

        projected = self.projection(pooled_output)
        # F.normalize guards the division with an epsilon, so a zero vector
        # stays zero instead of becoming NaN.
        text_features = F.normalize(projected, p=2, dim=-1)

        return text_features

In [None]:
class CLIPModel(nn.Module):
    """Dual-encoder contrastive model pairing an image encoder with a text encoder.

    Both encoders are expected to already include their projection heads and
    to return embeddings of the same dimensionality.

    Args:
        image_encoder: Module mapping a batch of images to embeddings.
        text_encoder: Module mapping a batch of texts to embeddings.
        temperature: Positive scaling constant the similarity logits are
            divided by (0.07 in CLIP).

    Raises:
        ValueError: If ``temperature`` is not strictly positive.
    """

    def __init__(self, image_encoder, text_encoder, temperature=0.07):
        super().__init__()

        # A zero temperature divides by zero and a negative one flips the
        # sign of the logits, so reject both up front.
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")

        # Assuming projection heads are already incorporated in encoders
        self.image_encoder = image_encoder
        self.text_encoder = text_encoder

        # Fixed (non-learned) temperature used to sharpen the softmax.
        self.temperature = temperature

    def forward(self, images, texts):
        """Encode a batch of images and texts into unit-length embeddings.

        Args:
            images: Input passed unchanged to ``image_encoder``.
            texts: Input passed unchanged to ``text_encoder``.

        Returns:
            Tuple ``(image_features, text_features)``, each L2-normalized
            along the last dimension.
        """
        image_features = self.image_encoder(images)
        text_features = self.text_encoder(texts)

        # Normalize so that dot products below are cosine similarities.
        image_features = F.normalize(image_features, p=2, dim=-1)
        text_features = F.normalize(text_features, p=2, dim=-1)

        return image_features, text_features

    def compute_loss(self, image_features, text_features):
        """Compute the symmetric contrastive (InfoNCE) loss for a batch.

        Row ``i`` of ``image_features`` is treated as the positive match for
        row ``i`` of ``text_features``; every other row in the batch acts as
        a negative.

        Args:
            image_features: Tensor of shape (batch_size, feature_dim).
            text_features: Tensor of shape (batch_size, feature_dim).

        Returns:
            Scalar tensor: the mean of the image-to-text and text-to-image
            cross-entropy losses.
        """
        # Pairwise similarities, shape (batch_size, batch_size). The
        # text-to-image logits are exactly the transpose, so a single matmul
        # serves both directions.
        logits_per_image = (image_features @ text_features.T) / self.temperature
        logits_per_text = logits_per_image.T

        # Matching pairs lie on the diagonal; build the labels directly on
        # the target device to avoid a host-to-device copy.
        labels = torch.arange(image_features.size(0), device=image_features.device)

        # Cross-entropy loss (image -> text and text -> image)
        loss_img2txt = F.cross_entropy(logits_per_image, labels)
        loss_txt2img = F.cross_entropy(logits_per_text, labels)

        # Total loss is the average of both directions
        return (loss_img2txt + loss_txt2img) / 2