## Summary
The goal of this project is to build a multimodal image-captioning model that combines two pretrained networks:
- ResNet50, a deep convolutional network used here to extract features from images.
- GPT-2, a language model that generates the caption text.

The model is trained and fine-tuned on the English-language Flickr30k dataset, a large collection of images paired with textual descriptions. <br>
Specifically, ResNet50 (with its final classification layer removed) extracts a feature vector from each image, and GPT-2 is fine-tuned to generate captions conditioned on these extracted features.

### Extract features using ResNet50

In [None]:
import os

import torch
from torchvision import models, transforms
from PIL import Image

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

# ImageNet-pretrained ResNet50 with its last child module (the fc classifier)
# dropped, so a forward pass yields pooled image features instead of logits.
# nn.Module.to()/.eval() act in place, so no re-assignment is needed.
resnet = torch.nn.Sequential(*tuple(models.resnet50(pretrained=True).children())[:-1])
resnet.to(device)
resnet.eval()

# Preprocessing: resize to 224x224, convert to a float tensor, then normalize
# each RGB channel with the standard ImageNet mean/std.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

#extract features
def extract_image_features(image_path):
    """Return the ResNet50 feature vector for a single image.

    The image is converted to RGB, preprocessed with ``transform`` and passed
    through ``resnet`` (ResNet50 without its final layer) on ``device``.

    Args:
        image_path: Path to an image file that PIL can open.

    Returns:
        A 1-D tensor of shape ``(2048,)`` on ``device``. Callers add their own
        batch dimension, so the singleton batch dim is dropped here.
    """
    # Use the file as a context manager so the handle is closed; convert()
    # returns a fully loaded copy that stays valid after the file closes.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    image_tensor = transform(image).unsqueeze(0).to(device)  # (1, 3, 224, 224)
    with torch.no_grad():
        # The backbone output is (1, 2048, 1, 1); flatten() drops the batch
        # and spatial dims (squeezing only the last two left shape (1, 2048)).
        features = resnet(image_tensor).flatten()
    return features



### Tokenize dataset

In [None]:
from transformers import GPT2Tokenizer

# Load the pretrained GPT-2 tokenizer.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# GPT-2 defines no padding token; reuse the end-of-text token so captions can
# be padded to a fixed length (padding="max_length" in the dataset below).
tokenizer.pad_token = tokenizer.eos_token

#tokenizing the dataset
from torch.utils.data import Dataset
class ImageCaptionDataset(Dataset):
    """Image/caption pairs read from a pipe-separated caption file.

    Each non-blank line of ``caption_file`` is split on ``'|'``. The first
    field is the image file name (relative to ``image_dir``) and the last field
    is the caption; surrounding whitespace is stripped from both. This accepts
    plain ``image|caption`` lines as well as a three-column
    ``image| index| caption`` layout.

    NOTE(review): a header line, if the file has one, is not skipped and will
    be loaded as a sample — confirm against the actual result.csv.

    Args:
        image_dir: Directory containing the image files.
        caption_file: Path to the UTF-8 caption file.
        max_length: Token length every caption is padded/truncated to.

    Raises:
        ValueError: if a non-blank line contains no ``'|'`` separator.
    """

    def __init__(self, image_dir, caption_file, max_length=30):
        self.image_dir = image_dir
        self.max_length = max_length
        self.samples = []
        with open(caption_file, 'r', encoding='utf-8') as f:
            for lineno, line in enumerate(f, start=1):
                line = line.strip()
                if not line:
                    continue  # tolerate blank lines, e.g. at the end of the file
                fields = line.split('|')
                if len(fields) < 2:
                    raise ValueError(
                        f"{caption_file}:{lineno}: expected 'image|caption', got {line!r}"
                    )
                self.samples.append((fields[0].strip(), fields[-1].strip()))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(img_feat, input_ids, attention_mask)`` for sample *idx*.

        ``img_feat`` is whatever ``extract_image_features`` returns; the two
        token tensors are 1-D of length ``max_length``.
        """
        img_name, caption = self.samples[idx]
        image_path = os.path.join(self.image_dir, img_name)
        # Features are recomputed on every access (no caching).
        img_feat = extract_image_features(image_path)

        tokens = tokenizer(
            caption,
            return_tensors='pt',
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
        )
        return img_feat, tokens['input_ids'].squeeze(0), tokens['attention_mask'].squeeze(0)



## Combine ResNet and GPT-2

In [None]:
import torch.nn as nn
from transformers import GPT2LMHeadModel

class ImageCaptioningModel(nn.Module):
    """GPT-2 captioner conditioned on a projected ResNet50 image feature.

    The 2048-d image feature is linearly projected into GPT-2's embedding
    space and prepended to the caption token embeddings as one prefix position.

    Args:
        gpt2_name_or_path: Model id or local directory passed to
            ``GPT2LMHeadModel.from_pretrained``; defaults to ``'gpt2'``.
    """

    # File name (inside a save directory) for the image-projection weights.
    IMAGE_PROJ_WEIGHTS = 'image_proj.pth'

    def __init__(self, gpt2_name_or_path='gpt2'):
        super().__init__()
        self.gpt2 = GPT2LMHeadModel.from_pretrained(gpt2_name_or_path)
        self.gpt2.resize_token_embeddings(len(tokenizer))

        self.image_proj = nn.Linear(2048, self.gpt2.config.n_embd)
        self.prefix_length = 1

    @staticmethod
    def _build_labels(input_ids, attention_mask, prefix_length):
        """Build LM labels aligned with ``[prefix] + input_ids``.

        Prefix positions and padding get -100 (ignored by the loss). The first
        padding position is kept, because the pad token is the EOS token and
        the model must learn to stop. Assumes right padding. Returns a new
        (B, prefix_length + T) tensor; *input_ids* is not modified.
        """
        labels = input_ids.clone()
        pad = attention_mask == 0
        first_pad = pad & (pad.long().cumsum(dim=1) == 1)
        labels[pad & ~first_pad] = -100
        prefix = torch.full(
            (input_ids.size(0), prefix_length), -100,
            dtype=labels.dtype, device=labels.device,
        )
        return torch.cat([prefix, labels], dim=1)

    def forward(self, image_features, input_ids, attention_mask):
        """Run GPT-2 on ``[image prefix] + caption tokens``.

        Args:
            image_features: one 2048-vector per example, as (B, 2048); extra
                singleton dims such as (B, 1, 2048) are flattened away.
            input_ids: (B, T) caption token ids.
            attention_mask: (B, T) mask, 1 for real tokens and 0 for padding.

        Returns:
            The ``GPT2LMHeadModel`` output; ``.loss`` is the causal-LM loss
            over the caption tokens (the image position is never a target).
        """
        batch_size = image_features.size(0)
        image_features = image_features.reshape(batch_size, -1)
        prefix_embeddings = self.image_proj(image_features).unsqueeze(1)  # (B, 1, D)
        token_embeddings = self.gpt2.transformer.wte(input_ids)  # (B, T, D)
        full_embeddings = torch.cat([prefix_embeddings, token_embeddings], dim=1)

        # The image prefix is always attended to; keep the caller's mask dtype.
        prefix_mask = torch.ones(
            (batch_size, self.prefix_length),
            dtype=attention_mask.dtype, device=attention_mask.device,
        )
        extended_attention_mask = torch.cat([prefix_mask, attention_mask], dim=1)

        # Labels must span the same 1 + T positions as inputs_embeds; passing
        # the bare input_ids caused a shape mismatch inside the GPT-2 loss.
        labels = self._build_labels(input_ids, attention_mask, self.prefix_length)
        return self.gpt2(
            inputs_embeds=full_embeddings,
            attention_mask=extended_attention_mask,
            labels=labels,
        )

    def save_pretrained(self, save_directory):
        """Save GPT-2 (Hugging Face format) and the image projection to *save_directory*."""
        os.makedirs(save_directory, exist_ok=True)
        self.gpt2.save_pretrained(save_directory)
        torch.save(
            self.image_proj.state_dict(),
            os.path.join(save_directory, self.IMAGE_PROJ_WEIGHTS),
        )

    @classmethod
    def from_pretrained(cls, load_directory, map_location='cpu'):
        """Rebuild a model previously written by :meth:`save_pretrained`."""
        model = cls(gpt2_name_or_path=load_directory)
        state = torch.load(
            os.path.join(load_directory, cls.IMAGE_PROJ_WEIGHTS),
            map_location=map_location,
        )
        model.image_proj.load_state_dict(state)
        return model


## Train the model

In [None]:
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
import os

# Mount Google Drive so the /content/drive paths used below are reachable.
from google.colab import drive
drive.mount('/content/drive')

# Train on the GPU when CUDA is available.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Dataset location and where training checkpoints are written.
dataPath = '/content/drive/My Drive/Dataset/flickrDataset'
savePath = '/content/drive/My Drive/AI Models/ImageCaptioning'
checkpoint_path = savePath + "/caption_checkpoint.pth"

# Shuffled mini-batches of 8 (image feature, token ids, attention mask).
dataset = ImageCaptionDataset(
    image_dir=f"{dataPath}/flickr30k_images",
    caption_file=f"{dataPath}/result.csv",
)
dataloader = DataLoader(dataset, shuffle=True, batch_size=8)

# The captioning model lives on `device` (nn.Module.to works in place);
# AdamW updates all of its parameters.
model = ImageCaptioningModel()
model.to(device)
optimizer = AdamW(model.parameters(), lr=5e-5)

# === Load checkpoint if available ===
# Checkpoints record the 0-based index of the last finished epoch, so training
# resumes with the epoch after it. Printed messages number epochs from 1.
start_epoch = 0
if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    print(f" Loaded checkpoint from epoch {checkpoint['epoch'] + 1}")

# === Training Loop ===
num_epochs = 5  # total epochs, including any completed before a resume
for epoch in range(start_epoch, num_epochs):
    model.train()
    total_loss = 0.0
    num_batches = 0

    for img_feat, input_ids, attn_mask in dataloader:
        img_feat = img_feat.to(device)
        input_ids = input_ids.to(device)
        attn_mask = attn_mask.to(device)

        outputs = model(img_feat, input_ids, attn_mask)
        loss = outputs.loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    # Guard against an empty dataloader instead of dividing by zero.
    avg_loss = total_loss / max(num_batches, 1)
    print(f" Epoch {epoch + 1}, Loss: {avg_loss:.4f}")

    # === Save checkpoint ===
    # Write to a temporary file and then swap it in, so an interrupted save
    # (e.g. a Colab disconnect) cannot corrupt the previous checkpoint.
    tmp_checkpoint_path = f"{checkpoint_path}.tmp"
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }, tmp_checkpoint_path)
    os.replace(tmp_checkpoint_path, checkpoint_path)
    print(f" Saved checkpoint at epoch {epoch + 1}")


## Save model

In [None]:
# Export the fine-tuned model and its tokenizer into the same directory so they
# can be reloaded together.
# NOTE(review): torch.nn.Module has no save_pretrained of its own, so this call
# relies on ImageCaptioningModel defining one.
savePath = '/content/drive/My Drive/AI Models/ImageCaptioning/Model'
for artifact in (model, tokenizer):
    artifact.save_pretrained(savePath)

## Generate caption

In [None]:
# Reload the exported captioner and tokenizer. The model is moved to `device`
# because generate_caption places its inputs there; without this the freshly
# loaded model would be on the CPU.
# NOTE(review): torch.nn.Module has no from_pretrained of its own, so this call
# relies on ImageCaptioningModel defining one.
modelPath = '/content/drive/My Drive/AI Models/ImageCaptioning/Model'
model = ImageCaptioningModel.from_pretrained(modelPath).to(device)
tokenizer = GPT2Tokenizer.from_pretrained(modelPath)

def generate_caption(image_path, max_length=30):
    """Greedily decode a caption for the image at *image_path*.

    Decoding starts from the image prefix alone, matching training: there the
    image position is followed directly by the first caption token, with no
    BOS token in between. Decoding stops at the first EOS token or after
    *max_length* generated tokens.

    Args:
        image_path: Path to the image to caption.
        max_length: Maximum number of tokens to generate.

    Returns:
        The decoded caption string (EOS and other special tokens excluded).
    """
    model.eval()
    token_ids = []
    with torch.no_grad():  # inference only: don't build an autograd graph
        # reshape(1, -1) yields a (1, 2048) batch whether the extractor
        # returns (2048,) or (1, 2048).
        img_feat = extract_image_features(image_path).reshape(1, -1).to(device)
        embeds = model.image_proj(img_feat).unsqueeze(1)  # (1, 1, n_embd)
        for _ in range(max_length):
            attention_mask = torch.ones(embeds.shape[:2], dtype=torch.long, device=device)
            outputs = model.gpt2(inputs_embeds=embeds, attention_mask=attention_mask)
            next_token = torch.argmax(outputs.logits[:, -1, :], dim=-1)  # (1,)
            if next_token.item() == tokenizer.eos_token_id:
                break
            token_ids.append(next_token.item())
            next_embed = model.gpt2.transformer.wte(next_token).unsqueeze(1)  # (1, 1, n_embd)
            embeds = torch.cat([embeds, next_embed], dim=1)
    return tokenizer.decode(token_ids, skip_special_tokens=True)


import matplotlib.pyplot as plt
import matplotlib.image as mpimg
# Display a sample image, then print the caption the model generates for it.
# NOTE(review): this image sits in flickrDataset/images, not the
# flickr30k_images folder used for training — confirm that is intended.
sample_path = '/content/drive/My Drive/Dataset/flickrDataset/images/1000268201_693b08cb0e.jpg'
img = mpimg.imread(sample_path)
plt.imshow(img)
plt.axis('off')
plt.show()
print(generate_caption(sample_path))
