In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets, models
from transformers import BertTokenizer, BertModel
from torch.utils.data import DataLoader
from PIL import Image

# Contrastive Loss Implementation (from previous step)
class ContrastiveLoss(nn.Module):
    def __init__(self, temperature=0.07):
        super(ContrastiveLoss, self).__init__()
        self.temperature = temperature

    def forward(self, text_embeddings, image_embeddings):
        text_embeddings = F.normalize(text_embeddings, p=2, dim=1)
        image_embeddings = F.normalize(image_embeddings, p=2, dim=1)
        similarity_matrix = torch.matmul(text_embeddings, image_embeddings.T)
        positives = torch.diag(similarity_matrix)

        text_to_image_loss = -torch.log(torch.exp(positives / self.temperature) /
                                        torch.exp(similarity_matrix / self.temperature).sum(dim=1))
        image_to_text_loss = -torch.log(torch.exp(positives / self.temperature) /
                                        torch.exp(similarity_matrix.T / self.temperature).sum(dim=1))
        loss = (text_to_image_loss + image_to_text_loss).mean()
        return loss

# Data Preparation
# Example pre-processing for Flickr30k images
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize images
    transforms.ToTensor(),          # Convert to Tensor
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # Normalize
])

# Dummy Flickr30k-like Dataset Loader
class TextImageDataset(torch.utils.data.Dataset):
    def __init__(self, image_paths, captions, transform, tokenizer):
        self.image_paths = image_paths
        self.captions = captions
        self.transform = transform
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert('RGB')
        image = self.transform(image)
        caption = self.captions[idx]
        tokens = self.tokenizer(caption, padding='max_length', max_length=32, truncation=True, return_tensors='pt')
        return image, tokens['input_ids'].squeeze(0), tokens['attention_mask'].squeeze(0)

# Example Data
image_paths = ["path_to_image1.jpg", "path_to_image2.jpg", "path_to_image3.jpg"]
captions = ["A man riding a bike.", "A dog playing with a ball.", "A group of people hiking."]

# Tokenizer and Dataset
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
dataset = TextImageDataset(image_paths, captions, transform, tokenizer)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

# Image Encoder (ResNet) and Text Encoder (BERT)
image_encoder = models.resnet18(pretrained=True)
image_encoder.fc = nn.Identity()  # Remove final classification layer

text_encoder = BertModel.from_pretrained('bert-base-uncased')

# Training Loop
contrastive_loss = ContrastiveLoss()
optimizer = torch.optim.Adam(list(image_encoder.parameters()) + list(text_encoder.parameters()), lr=1e-4)

for epoch in range(5):  # Example for 5 epochs
    for images, input_ids, attention_masks in dataloader:
        # Image embeddings
        image_features = image_encoder(images)

        # Text embeddings
        text_features = text_encoder(input_ids, attention_mask=attention_masks).pooler_output

        # Compute Loss
        loss = contrastive_loss(text_features, image_features)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"Epoch {epoch + 1}, Loss: {loss.item():.4f}")

FileNotFoundError: [Errno 2] No such file or directory: '/content/path_to_image1.jpg'