In [None]:
pip install transformers

In [None]:
pip install pot

In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import tensorflow as tf
import os
import json
import random
from PIL import Image

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from transformers import BertTokenizer, BertModel
from torch.utils.data import DataLoader, Dataset
import ot

In [None]:
def draw_samples(df, n_samples=5):
    """Plot the first ``n_samples`` images of ``df`` next to their captions.

    Captions are grouped per ``image_filepath``; each figure row shows the
    image in the left column and its captions as text in the right column.

    Args:
        df: DataFrame with an ``image_filepath`` column (paths readable by
            PIL) and a ``caption`` column.
        n_samples: Number of rows in the figure. Defaults to 5, which
            reproduces the original fixed 5x2, 12x24 inch layout.

    Raises:
        ValueError: If ``n_samples`` is smaller than 1.
    """
    if n_samples < 1:
        raise ValueError("n_samples must be at least 1")

    captions_by_image = (
        df.groupby("image_filepath")["caption"].agg(list).iloc[:n_samples]
    )

    # 4.8 inches per row keeps the original 12x24 figure for five rows.
    # squeeze=False guarantees a 2-D axes array even when n_samples == 1.
    fig, axes = plt.subplots(
        n_samples, 2, figsize=(12, 4.8 * n_samples), squeeze=False
    )

    for row, (image_path, captions) in enumerate(captions_by_image.items()):
        # Close the file handle as soon as the pixels have been handed over.
        with Image.open(image_path) as img:
            axes[row, 0].imshow(img)

        # Up to five captions keep the original 0.2 spacing; longer lists are
        # compressed so the last line stays inside the axes instead of
        # spilling into the row below.
        line_step = min(0.2, 1.0 / max(len(captions), 1))
        for j, cap in enumerate(captions):
            axes[row, 1].text(0, 0.9 - line_step * j, cap, fontsize=14)

    # Hide every axis, including rows left empty when df holds fewer than
    # n_samples distinct images.
    for ax in axes.ravel():
        ax.axis("off")

In [None]:
# Directory containing the images used for the train/validation split below.
trainval_image_dir = '/kaggle/input/coco-image-caption/train2014/train2014'
# NOTE: despite the `_dir` suffix this is the path of a JSON annotation file,
# not a directory (it is passed straight to open() when loading captions).
trainval_captions_dir = '/kaggle/input/coco-image-caption/annotations_trainval2014/annotations/captions_train2014.json'

# Held-out test images and their caption annotation file (again a JSON file path).
test_image_dir = '/kaggle/input/coco-image-caption/val2017/val2017'
test_captions_dir = '/kaggle/input/coco-image-caption/annotations_trainval2017/annotations/captions_val2017.json'

In [None]:
# Seed for the train/validation shuffle so that Restart & Run All always
# produces the same split.
SPLIT_SEED = 42

# os.listdir returns entries in arbitrary order, so sort first: together with
# the seeded permutation this makes the split fully deterministic.
all_filepaths = np.array(
    sorted(os.path.join(trainval_image_dir, f) for f in os.listdir(trainval_image_dir))
)
rand_indices = np.random.default_rng(SPLIT_SEED).permutation(len(all_filepaths))

# 80% of the images go to training, the remaining 20% to validation.
split = int(len(all_filepaths)*0.8)

train_filepaths, valid_filepaths = all_filepaths[rand_indices[:split]], all_filepaths[rand_indices[split:]]

print(f"Train dataset size: {len(train_filepaths)}")
print(f"Valid dataset size: {len(valid_filepaths)}")

In [None]:
# Read the train/val caption annotations and flatten the "annotations"
# records into one row per caption.
with open(trainval_captions_dir, 'r') as annotations_file:
    trainval_data = json.load(annotations_file)


def _trainval_image_path(image_id):
    """Return the image path for ``image_id``: prefix + 12-digit zero-padded id + '.jpg'."""
    file_name = 'COCO_train2014_' + format(image_id, '012d') + '.jpg'
    return os.path.join(trainval_image_dir, file_name)


trainval_captions_df = pd.json_normalize(trainval_data, "annotations")
trainval_captions_df["image_filepath"] = trainval_captions_df["image_id"].apply(_trainval_image_path)

def preprocess_captions(df):
    """Return a copy of ``df`` with a normalised ``preprocessed_caption`` column.

    Each caption is lower-cased and every character that is neither a word
    character nor whitespace (i.e. punctuation) is removed.

    The input frame is left untouched: a new DataFrame is returned, so the
    function can be applied safely to filtered slices of another frame
    without triggering pandas' chained-assignment warning.

    Args:
        df: DataFrame with a string ``caption`` column.

    Returns:
        A new DataFrame containing all columns of ``df`` plus
        ``preprocessed_caption``.
    """
    # Raw string: '\w' and '\s' are regex classes, not Python string escapes.
    cleaned = df["caption"].str.lower().str.replace(r'[^\w\s]', '', regex=True)
    return df.assign(preprocessed_caption=cleaned)

# Route every caption row to the split its image file was assigned to,
# then add the normalised caption column to each split.
in_train_split = trainval_captions_df["image_filepath"].isin(train_filepaths)
train_captions_df = preprocess_captions(trainval_captions_df[in_train_split])

in_valid_split = trainval_captions_df["image_filepath"].isin(valid_filepaths)
valid_captions_df = preprocess_captions(trainval_captions_df[in_valid_split])

# Read the test caption annotations, flatten them to one row per caption,
# attach the image path and add the normalised caption column.
with open(test_captions_dir, 'r') as annotations_file:
    test_data = json.load(annotations_file)


def _test_image_path(image_id):
    """Return the image path for ``image_id``: 12-digit zero-padded id + '.jpg'."""
    file_name = format(image_id, '012d') + '.jpg'
    return os.path.join(test_image_dir, file_name)


test_captions_df = pd.json_normalize(test_data, "annotations")
test_captions_df["image_filepath"] = test_captions_df["image_id"].apply(_test_image_path)
test_captions_df = preprocess_captions(test_captions_df)

In [None]:
# Display the training captions table (one row per caption) as a sanity check.
train_captions_df

In [None]:
# Visual sanity check: a few training images side by side with their captions.
draw_samples(train_captions_df)

In [None]:
# Image encoder: pretrained ResNet-50 with its last child module dropped, so
# the network outputs pooled features rather than class scores.
_pretrained_resnet = torchvision.models.resnet50(pretrained=True)
_backbone_layers = list(_pretrained_resnet.children())[:-1]
resnet = nn.Sequential(*_backbone_layers)
resnet.train()

# Text encoder: BERT tokenizer and model loaded from the same checkpoint.
_BERT_CHECKPOINT = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(_BERT_CHECKPOINT)
bert_model = BertModel.from_pretrained(_BERT_CHECKPOINT)

# Image preprocessing: fixed 224x224 resize, tensor conversion, then
# per-channel normalisation.
_resize = transforms.Resize((224, 224))
_to_tensor = transforms.ToTensor()
_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
transform = transforms.Compose([_resize, _to_tensor, _normalize])

In [None]:
class ImageCaptioningDataset(Dataset):
    """Dataset yielding one (image, caption tokens) pair per DataFrame row.

    Args:
        df: DataFrame with ``image_filepath`` and ``preprocessed_caption``
            columns; each row is one sample.
        tokenizer: Callable tokenizer invoked as
            ``tokenizer(text, return_tensors="pt", padding="max_length",
            truncation=True, max_length=...)`` whose result exposes
            ``input_ids`` and ``attention_mask``.
        transform: Callable applied to the RGB PIL image.
        max_length: Length every caption is padded/truncated to.
            Defaults to 128 (the previously hard-coded value).
    """

    def __init__(self, df, tokenizer, transform, max_length=128):
        self.df = df
        self.tokenizer = tokenizer
        self.transform = transform
        self.max_length = max_length

    def __len__(self):
        """Return the number of caption rows."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, input_ids, attention_mask)`` for row ``idx``."""
        # Look the row up once instead of once per column.
        row = self.df.iloc[idx]

        # Context manager closes the image file deterministically, which
        # matters when many samples are loaded in a long-running loop.
        with Image.open(row['image_filepath']) as img:
            image = self.transform(img.convert('RGB'))

        inputs = self.tokenizer(
            row['preprocessed_caption'],
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
        )

        # squeeze(0) removes only the leading batch axis; a bare squeeze()
        # would also drop the sequence axis if max_length were 1.
        return image, inputs.input_ids.squeeze(0), inputs.attention_mask.squeeze(0)

In [None]:
class CaptioningModel(nn.Module):
    """Two-tower encoder mapping images and captions into a shared 768-d space.

    Args:
        resnet: Image backbone; its output is flattened to ``(batch, 2048)``
            before the projection layer.
        bert_model: Text encoder called with ``input_ids`` and
            ``attention_mask`` and returning an object with a
            ``last_hidden_state`` of shape ``(batch, seq_len, hidden)``.
    """

    def __init__(self, resnet, bert_model):
        super().__init__()
        self.resnet = resnet
        self.bert = bert_model
        # Projects the 2048-d image features to the 768-d text embedding size.
        self.fc = nn.Linear(2048, 768)

    def forward(self, image, input_ids, attention_mask):
        """Encode a batch of images and captions.

        Args:
            image: Image batch fed to the backbone.
            input_ids: Token ids, shape ``(batch, seq_len)``.
            attention_mask: 1 for real tokens, 0 for padding, same shape as
                ``input_ids``.

        Returns:
            Tuple ``(image_features, text_embeddings)``; ``image_features``
            has shape ``(batch, 768)`` and ``text_embeddings`` has shape
            ``(batch, hidden)``.
        """
        image_features = self.resnet(image).flatten(start_dim=1)
        image_features = self.fc(image_features)

        bert_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        token_states = bert_outputs.last_hidden_state

        # Masked mean pooling: average only over real tokens. A plain
        # mean(dim=1) also averages the padding positions, so with captions
        # padded to a fixed length the embedding would be dominated by
        # pad-token states and would depend on the padding length.
        mask = attention_mask.unsqueeze(-1).to(token_states.dtype)
        token_counts = mask.sum(dim=1).clamp(min=1.0)  # guard all-padding rows
        text_embeddings = (token_states * mask).sum(dim=1) / token_counts

        return image_features, text_embeddings

In [None]:
def optimal_transport_loss(image_features, text_embeddings):
    """Differentiable exact optimal-transport distance between two embedding sets.

    Both sets are L2-normalised, the pairwise Euclidean cost matrix is built,
    and the earth mover's distance under uniform marginals is returned.

    The transport plan is solved with POT on a detached copy of the costs and
    treated as a constant; the loss is then ``sum(plan * cost_matrix)`` with
    ``cost_matrix`` still attached to the autograd graph. This equals the
    value ``ot.emd2`` would return, but gradients now flow back to both
    encoders. Wrapping the NumPy result in a fresh
    ``torch.tensor(..., requires_grad=True)`` creates a leaf tensor that is
    disconnected from the model, so ``backward()`` never updates any weight.

    Args:
        image_features: Tensor of shape ``(n_images, dim)``.
        text_embeddings: Tensor of shape ``(n_texts, dim)``.

    Returns:
        Scalar tensor holding the OT distance, on the inputs' device.
    """
    # F.normalize clamps the norm with a small epsilon, so an all-zero
    # embedding yields zeros instead of NaNs from a division by zero.
    image_features = torch.nn.functional.normalize(image_features, p=2, dim=-1)
    text_embeddings = torch.nn.functional.normalize(text_embeddings, p=2, dim=-1)

    cost_matrix = torch.cdist(image_features, text_embeddings, p=2)

    n_images, n_texts = cost_matrix.shape
    a = np.full(n_images, 1.0 / n_images)
    b = np.full(n_texts, 1.0 / n_texts)

    # The solver runs on CPU/NumPy; only the resulting plan is brought back.
    plan = ot.emd(a, b, cost_matrix.detach().cpu().double().numpy())
    plan = torch.as_tensor(plan, dtype=cost_matrix.dtype, device=cost_matrix.device)

    return (plan * cost_matrix).sum()


In [None]:
# Train on the first 8000 caption rows only to keep the run short;
# assign the whole `train_captions_df` to `subset` for a full run.
subset = train_captions_df.iloc[:8000]
dataset = ImageCaptioningDataset(df=subset, tokenizer=tokenizer, transform=transform)
dataloader = DataLoader(dataset=dataset, batch_size=128, shuffle=True)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
captioning_model = CaptioningModel(resnet, bert_model).to(device)
optimizer = torch.optim.Adam(captioning_model.parameters(), lr=0.0001)

num_epochs = 20
for epoch in range(num_epochs):
    # Put the whole model (both encoders) in training mode explicitly. Without
    # this the mode depends on hidden state: the text encoder is never switched
    # to train mode anywhere else, and re-running this cell after the
    # evaluation cell would train with everything still in eval mode.
    captioning_model.train()

    running_loss = 0.0
    num_batches = 0

    for images, input_ids, attention_mask in dataloader:
        images = images.to(device)
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)

        optimizer.zero_grad()

        image_features, text_embeddings = captioning_model(images, input_ids, attention_mask)

        loss = optimal_transport_loss(image_features, text_embeddings)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    # Report the mean over all batches: the last mini-batch alone is a noisy
    # figure, and `loss` would be undefined if the loader were empty.
    epoch_loss = running_loss / max(num_batches, 1)
    print(f"Epoch {epoch + 1}/{num_epochs}, OT Loss: {epoch_loss}")

In [None]:
# Score a single (image, caption) pair with the trained model.
captioning_model.eval()

img_file_path = subset.iloc[0]['image_filepath']
img_caption = subset.iloc[0]['preprocessed_caption']

image_ori = Image.open(img_file_path).convert('RGB')
image = transform(image_ori)
image = image.to(device).unsqueeze(0)

# Replace `text` with any free-form sentence to score a different caption
# against the same image.
text = img_caption

# Tokenise exactly as ImageCaptioningDataset does during training (fixed-length
# padding to 128 tokens). With padding=True a single sentence is not padded at
# all, so the text encoder would see a differently shaped input than the one
# it was trained on.
tokenized = tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=128)
input_ids = tokenized["input_ids"].to(device)
attention_mask = tokenized["attention_mask"].to(device)

# Inference only: keep both the forward pass and the loss out of autograd.
with torch.no_grad():
    image_features, text_embeddings = captioning_model(image, input_ids, attention_mask)
    pair_loss = optimal_transport_loss(image_features, text_embeddings).item()

plt.imshow(image_ori)
plt.axis('off')
plt.show()
print("Caption:", text)
print("Loss:", pair_loss)