In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from transformers import CLIPProcessor, CLIPModel, GPT2Tokenizer, GPT2LMHeadModel
import pandas as pd
import os
from PIL import Image
import numpy as np
import nltk

nltk.download("punkt")

# Configuration
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 16
EPOCHS = 15
LEARNING_RATE = 5e-5
SEED = 42

torch.manual_seed(SEED)

[nltk_data] Downloading package punkt to /usr/share/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


<torch._C.Generator at 0x783710d2ffd0>

In [2]:
import pandas as pd

# Function to load captions into a DataFrame
def load_captions_to_df(captions_file):
    """
    Load the captions file into a Pandas DataFrame.
    Args:
        captions_file: Path to the Flickr8k captions text file.
    Returns:
        DataFrame with two columns: 'image_id' and 'caption'.
    """
    df = pd.read_csv(captions_file, sep=',')
    return df

# Load the captions
CAPTIONS_FILE = "/kaggle/input/flickr8k/captions.txt"
IMAGE_FOLDER = "/kaggle/input/flickr8k/Images"
df_captions = load_captions_to_df(CAPTIONS_FILE)

print(df_captions.head())

                       image  \
0  1000268201_693b08cb0e.jpg   
1  1000268201_693b08cb0e.jpg   
2  1000268201_693b08cb0e.jpg   
3  1000268201_693b08cb0e.jpg   
4  1000268201_693b08cb0e.jpg   

                                             caption  
0  A child in a pink dress is climbing up a set o...  
1              A girl going into a wooden building .  
2   A little girl climbing into a wooden playhouse .  
3  A little girl climbing the stairs to her playh...  
4  A little girl in a pink dress going into a woo...  


In [3]:
class Flickr8kDataset(Dataset):
    def __init__(self, image_folder, captions_df, processor, tokenizer, transform=None):
        self.image_folder = image_folder
        self.captions_df = captions_df
        self.processor = processor
        self.tokenizer = tokenizer
        self.transform = transform

        # Group captions by image ID
        self.image_to_captions = captions_df.groupby("image")["caption"].apply(list).to_dict()
        self.image_files = list(self.image_to_captions.keys())

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        image_id = self.image_files[idx]
        captions = self.image_to_captions[image_id]

        # Load and preprocess image
        image_path = os.path.join(self.image_folder, image_id)
        image = Image.open(image_path).convert("RGB")
        if self.transform:
            image = self.transform(image)

        # CLIP processor
        image_inputs = self.processor(images=image, return_tensors="pt", padding=True)
        image_inputs = {k: v.squeeze(0) for k, v in image_inputs.items()}

        # Randomly select a caption
        caption = np.random.choice(captions)
        caption_tokens = self.tokenizer.encode(caption, return_tensors="pt").squeeze(0)

        return image_inputs, caption_tokens

def collate_fn(batch):
    image_inputs = [item[0] for item in batch]
    captions = [item[1] for item in batch]

    max_len = max(len(cap) for cap in captions)
    padded_captions = torch.full((len(captions), max_len), gpt_tokenizer.pad_token_id, dtype=torch.long)

    for i, cap in enumerate(captions):
        padded_captions[i, :len(cap)] = cap

    batch_image_inputs = {}
    for key in image_inputs[0].keys():
        batch_image_inputs[key] = torch.stack([img[key] for img in image_inputs])

    return batch_image_inputs, padded_captions


In [4]:
class ImageCaptioningModel(nn.Module):
    def __init__(self, clip_model_name="openai/clip-vit-base-patch32", gpt_model_name="gpt2"):
        super(ImageCaptioningModel, self).__init__()
        self.clip_model = CLIPModel.from_pretrained(clip_model_name).vision_model
        self.gpt_model = GPT2LMHeadModel.from_pretrained(gpt_model_name)
        self.clip_to_gpt = nn.Linear(768, self.gpt_model.config.hidden_size)

    def forward(self, image_inputs, captions):
        # Extract CLIP image embeddings
        image_features = self.clip_model(**image_inputs).pooler_output
        image_features = self.clip_to_gpt(image_features).unsqueeze(1)

        # Embed captions
        caption_embeddings = self.gpt_model.transformer.wte(captions)

        # Combine image and caption embeddings
        inputs_embeds = torch.cat([image_features, caption_embeddings], dim=1)
        outputs = self.gpt_model(inputs_embeds=inputs_embeds)
        return outputs

In [None]:
# Initialize Processor, Tokenizer, and Dataset
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
gpt_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
gpt_tokenizer.pad_token = gpt_tokenizer.eos_token

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

dataset = Flickr8kDataset(IMAGE_FOLDER, df_captions, clip_processor, gpt_tokenizer)
train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)

# Initialize Model, Optimizer, and Criterion
model = ImageCaptioningModel().to(DEVICE)
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss(ignore_index=gpt_tokenizer.pad_token_id)

# Training Loop
def train(model, dataloader, optimizer, criterion, epoch):
    model.train()
    total_loss = 0

    for batch in dataloader:
        image_inputs, captions = batch
        image_inputs = {k: v.to(DEVICE) for k, v in image_inputs.items()}
        captions = captions.to(DEVICE)

        optimizer.zero_grad()
        outputs = model(image_inputs, captions)

        # Get logits (model predictions)
        logits = outputs.logits  # Shape: [batch_size, sequence_length + 1, vocab_size]

        # Shift labels to the right
        labels = captions[:, 1:]  # Drop the first token (start token)

        # Trim logits to align with labels
        logits = logits[:, :labels.size(1), :]  # Match logits to labels' length

        # Flatten logits and labels for loss calculation
        logits = logits.reshape(-1, logits.size(-1))  # Shape: (batch_size * sequence_length, vocab_size)
        labels = labels.reshape(-1)  # Shape: (batch_size * sequence_length)

        # Compute loss, ignoring padding tokens
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    average_loss = total_loss / len(dataloader)
    print(f"Epoch [{epoch+1}/{EPOCHS}], Loss: {average_loss:.4f}")

    # Save the model at the end of the epoch
    model_save_path = f"image_captioning_model_epoch_{epoch+1}.pth"
    torch.save(model.state_dict(), model_save_path)
    print(f"Model saved at: {model_save_path}")

    return average_loss


# Run Training and Save the Model
print("Training Begins")
for epoch in range(EPOCHS):
    
    train_loss = train(model, train_loader, optimizer, criterion, epoch)

preprocessor_config.json:   0%|          | 0.00/316 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/592 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/862k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/525k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.22M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/389 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/4.19k [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/605M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Training Begins
Epoch [1/15], Loss: 4.8281
Model saved at: image_captioning_model_epoch_1.pth
Epoch [2/15], Loss: 4.3918
Model saved at: image_captioning_model_epoch_2.pth
Epoch [3/15], Loss: 4.2014
Model saved at: image_captioning_model_epoch_3.pth
Epoch [4/15], Loss: 4.0489
Model saved at: image_captioning_model_epoch_4.pth
Epoch [5/15], Loss: 3.9484
Model saved at: image_captioning_model_epoch_5.pth


In [None]:
def generate_caption(model, image_path, processor, tokenizer, max_length=20):
    model.eval()
    image = Image.open(image_path).convert("RGB")
    inputs = processor(images=image, return_tensors="pt").to(DEVICE)

    with torch.no_grad():
        image_features = model.clip_model(**inputs).pooler_output
        image_features = model.clip_to_gpt(image_features).unsqueeze(1)

        generated = torch.tensor([[tokenizer.bos_token_id]], device=DEVICE)
        for _ in range(max_length):
            gpt_input = torch.cat([image_features, model.gpt_model.transformer.wte(generated)], dim=1)
            outputs = model.gpt_model(inputs_embeds=gpt_input)
            next_token = torch.argmax(outputs.logits[:, -1, :], dim=-1)
            generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
            if next_token.item() == tokenizer.eos_token_id:
                break

    return tokenizer.decode(generated.squeeze(), skip_special_tokens=True)

# Test
test_image_path = "/kaggle/input/flickr8k/Images/1001773457_577c3a7d70.jpg"
caption = generate_caption(model, test_image_path, clip_processor, gpt_tokenizer)
print("Generated Caption:", caption)