In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import os
import numpy as np
import pandas as pd
from PIL import Image
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader
from collections import Counter


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        # Linear projections for Q, K, V and output
        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)
        self.wo = nn.Linear(d_model, d_model)
        
        self.dropout = nn.Dropout(dropout)
        
    def split_heads(self, x):
        # x shape: (batch_size, seq_len, d_model)
        batch_size, seq_len = x.size(0), x.size(1)
        
        # Reshape to (batch_size, seq_len, num_heads, d_k)
        x = x.view(batch_size, seq_len, self.num_heads, self.d_k)
        
        # Transpose to (batch_size, num_heads, seq_len, d_k)
        return x.transpose(1, 2)
    
    def merge_heads(self, x):
        # x shape: (batch_size, num_heads, seq_len, d_k)
        batch_size, _, seq_len = x.size(0), x.size(1), x.size(2)
        
        # Transpose to (batch_size, seq_len, num_heads, d_k)
        x = x.transpose(1, 2)
        
        # Reshape to (batch_size, seq_len, d_model)
        return x.reshape(batch_size, seq_len, self.d_model)
    
    def forward(self, q, k, v, mask=None):
        batch_size = q.size(0)
        
        # Linear projections and split heads
        q = self.split_heads(self.wq(q))  # (batch_size, num_heads, seq_len_q, d_k)
        k = self.split_heads(self.wk(k))  # (batch_size, num_heads, seq_len_k, d_k)
        v = self.split_heads(self.wv(v))  # (batch_size, num_heads, seq_len_v, d_k)
        
        # Compute attention scores
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        # scores shape: (batch_size, num_heads, seq_len_q, seq_len_k)
        
        # Apply mask if provided
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        
        # Apply softmax to get attention weights
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        
        # Apply attention weights to values
        context = torch.matmul(attn_weights, v)  # (batch_size, num_heads, seq_len_q, d_k)
        
        # Merge heads and apply output projection
        context = self.merge_heads(context)  # (batch_size, seq_len_q, d_model)
        output = self.wo(context)  # (batch_size, seq_len_q, d_model)
        
        return output




In [2]:
class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        # x shape: (batch_size, seq_len, d_model)
        return self.linear2(self.dropout(F.relu(self.linear1(x))))




In [3]:
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        
        # Self-attention layer
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.self_attn_norm = nn.LayerNorm(d_model)
        
        # Cross-attention layer (encoder-decoder attention)
        self.cross_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.cross_attn_norm = nn.LayerNorm(d_model)
        
        # Feed-forward network
        self.ffn = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.ffn_norm = nn.LayerNorm(d_model)
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, enc_output, self_attn_mask=None, cross_attn_mask=None):
        # Self-attention with residual connection and layer norm
        residual = x
        x = self.self_attn_norm(x)
        x = residual + self.dropout(self.self_attn(x, x, x, self_attn_mask))
        
        # Cross-attention with residual connection and layer norm
        residual = x
        x = self.cross_attn_norm(x)
        x = residual + self.dropout(self.cross_attn(x, enc_output, enc_output, cross_attn_mask))
        
        # Feed-forward network with residual connection and layer norm
        residual = x
        x = self.ffn_norm(x)
        x = residual + self.dropout(self.ffn(x))
        
        return x




In [4]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        
        # Create positional encoding matrix
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        
        # Register as buffer (not a parameter but part of the module)
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        # x shape: (batch_size, seq_len, d_model)
        return x + self.pe[:, :x.size(1), :]




In [5]:
class TransformerDecoder(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_len, dropout=0.1):
        super().__init__()
        
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_len)
        
        self.decoder_layers = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        
        self.final_norm = nn.LayerNorm(d_model)
        self.output_projection = nn.Linear(d_model, vocab_size)
        
        self.dropout = nn.Dropout(dropout)
        
        # Initialize parameters
        self._init_parameters()
        
    def _init_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
    
    def forward(self, x, enc_output, self_attn_mask=None, cross_attn_mask=None):
        # x shape: (batch_size, seq_len)
        
        # Embedding and positional encoding
        x = self.embedding(x) * math.sqrt(self.d_model)  # (batch_size, seq_len, d_model)
        x = self.positional_encoding(x)
        x = self.dropout(x)
        
        # Apply decoder layers
        for layer in self.decoder_layers:
            x = layer(x, enc_output, self_attn_mask, cross_attn_mask)
        
        # Final layer norm
        x = self.final_norm(x)
        
        # Output projection
        logits = self.output_projection(x)
        
        return logits




In [6]:
def generate_square_subsequent_mask(sz):
    """Generate a square mask for the sequence. The masked positions are filled with float('-inf').
    Unmasked positions are filled with float(0.0).
    """
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask




In [7]:

class Flickr8kDataset(Dataset):
    def __init__(self, captions_file, images_dir, feature_dir=None, transform=None, max_len=50):
        """
        Args:
            captions_file (string): Path to the captions file.
            images_dir (string): Directory with all the images.
            feature_dir (string): Directory with precomputed ResNet50 features.
            transform (callable, optional): Optional transform to be applied on an image.
            max_len (int): Maximum caption length.
        """
        self.images_dir = images_dir
        self.feature_dir = feature_dir
        self.transform = transform
        self.max_len = max_len
        
        # Read captions from text file
        self.captions_data = []
        with open(captions_file, 'r') as f:
            for line in f:
                parts = line.strip().split('\t')
                if len(parts) >= 2:  # Ensure there's an image ID and caption
                    # In Flickr8k.token.txt, the format is "image_id#caption_number caption_text"
                    image_id_with_num = parts[0]
                    image_id = image_id_with_num.split('#')[0]  # Remove caption number
                    caption = parts[1]
                    self.captions_data.append({
                        'image': image_id,  # Store just the image ID without the caption number
                        'caption': caption
                    })
        
        # Build vocabulary
        self.word_to_idx, self.idx_to_word, self.vocab_size = self._build_vocab()
        
        # Special tokens
        self.start_token = "<START>"
        self.end_token = "<END>"
        self.pad_token = "<PAD>"
        self.unk_token = "<UNK>"
        
        # Process captions
        self.captions = self._preprocess_captions()
    
    def _build_vocab(self, threshold=4):
        # Count word frequency
        counter = Counter()
        for item in self.captions_data:
            counter.update(item['caption'].lower().split())

        # Filter words below threshold
        words = [word for word, count in counter.items() if count >= threshold]

        # Create mappings
        word_to_idx = {
            "<PAD>": 0,
            "<START>": 1,
            "<END>": 2,
            "<UNK>": 3
        }

        # Add words to dictionary
        for i, word in enumerate(words):
            word_to_idx[word] = i + 4

        idx_to_word = {idx: word for word, idx in word_to_idx.items()}
        vocab_size = len(word_to_idx)

        return word_to_idx, idx_to_word, vocab_size
    
    def _preprocess_captions(self):
        processed_captions = []
        
        for item in self.captions_data:
            caption = item['caption'].lower().split()
            
            # Add start and end tokens
            caption = [self.start_token] + caption + [self.end_token]
            
            # Convert words to indices
            caption_indices = []
            for word in caption:
                if word in self.word_to_idx:
                    caption_indices.append(self.word_to_idx[word])
                else:
                    caption_indices.append(self.word_to_idx[self.unk_token])
            
            # Pad caption if necessary
            if len(caption_indices) < self.max_len:
                caption_indices.extend([self.word_to_idx[self.pad_token]] * (self.max_len - len(caption_indices)))
            else:
                caption_indices = caption_indices[:self.max_len]
                if caption_indices[-1] != self.word_to_idx[self.end_token]:
                    caption_indices[-1] = self.word_to_idx[self.end_token]
            
            processed_captions.append({
                'image_id': item['image'],
                'caption': torch.tensor(caption_indices)
            })
        
        return processed_captions
    
    def __len__(self):
        return len(self.captions)
    
    def __getitem__(self, idx):
        item = self.captions_data[idx]
        image_id = item['image']
        caption = item['caption']

        try:
            image_name = image_id.split('#')[0]

            if self.feature_dir:
                feature_path = os.path.join(self.feature_dir, f"{os.path.splitext(image_name)[0]}.npy")
                img_feature = np.load(feature_path)
                img_tensor = torch.tensor(img_feature, dtype=torch.float)
            else:
                image_path = os.path.join(self.images_dir, image_name)
                image = Image.open(image_path).convert('RGB')
                if self.transform:
                    image = self.transform(image)
                img_tensor = image

            tokens = caption.lower().split()
            tokens = [self.start_token] + tokens + [self.end_token]
            caption_idx = [self.word_to_idx.get(word, self.word_to_idx[self.unk_token]) for word in tokens]

            if len(caption_idx) < self.max_len:
                caption_idx += [self.word_to_idx[self.pad_token]] * (self.max_len - len(caption_idx))
            else:
                caption_idx = caption_idx[:self.max_len]

            caption_tensor = torch.tensor(caption_idx)

            return img_tensor, caption_tensor

        except FileNotFoundError:
            print(f"[WARNING] File not found, skipping: {image_id}")
            return None

    
    def _extract_image_features(self, image_id):
        """Extract features from an image using ResNet50."""
        image_path = os.path.join(self.images_dir, image_id)
        image = Image.open(image_path).convert('RGB')
        
        # Apply transformations
        if self.transform:
            image = self.transform(image)
        else:
            # Default transformation
            transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
            image = transform(image)
        
        # Add batch dimension and extract features
        with torch.no_grad():
            model = models.resnet50(pretrained=True)
            # Remove the final fully connected layer
            model = nn.Sequential(*list(model.children())[:-1])
            model.eval()
            
            # Extract features
            features = model(image.unsqueeze(0))
            features = features.squeeze()
            
        return features

def custom_collate_fn(batch):
    # Skip Nones
    batch = [item for item in batch if item is not None]
    if len(batch) == 0:
        return None, None
    images, captions = zip(*batch)
    return torch.stack(images), torch.stack(captions)




In [8]:
class ImageCaptioningModel(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_len, dropout=0.1):
        super().__init__()
        
        # Feature projection layer (from ResNet50 features to decoder dimension)
        self.feature_projection = nn.Linear(2048, d_model)
        
        # Decoder
        self.decoder = TransformerDecoder(vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_len, dropout)
    
    def forward(self, img_features, captions, self_attn_mask=None):
        # Project image features to d_model dimensions
        # Reshape img_features to ensure it's a tensor with [batch_size, feature_dim]
        batch_size = img_features.size(0)
        if img_features.dim() == 3:  # If features have spatial dimensions [batch, spatial, feature_dim]
            img_features = img_features.mean(dim=1)  # Average pooling over spatial dimensions
        
        if img_features.dim() == 1:  # Single feature vector (no batch)
            img_features = img_features.unsqueeze(0)  # Add batch dimension
            
        enc_output = self.feature_projection(img_features)
        
        # Add sequence dimension for encoder output - one "token" per image
        enc_output = enc_output.unsqueeze(1)  # [batch_size, 1, d_model]
        
        # Create cross attention mask (all ones, as we attend to the single image feature vector)
        cross_attn_mask = torch.ones(captions.size(1), 1)
        
        # Forward through decoder
        output = self.decoder(captions, enc_output, self_attn_mask, cross_attn_mask)
        
        return output


def prepare_flickr8k_data():
    """Prepare Flickr8k dataset for image captioning."""
    # Define paths
    captions_file = "Flickr8k_text/Flickr8k.token.txt"
    images_dir = "Flickr8k_Dataset/"
    feature_dir = "Features/"
    
    # Create feature directory if it doesn't exist
    os.makedirs(feature_dir, exist_ok=True)
    
    # Define image transformations
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    # Create dataset
    dataset = Flickr8kDataset(captions_file, images_dir, feature_dir, transform)
    
    # Extract features for all images if not already done
    print("Extracting features from images...")
    model = models.resnet50(pretrained=True)
    # Remove the final fully connected layer
    feature_extractor = nn.Sequential(*list(model.children())[:-1])
    feature_extractor.eval()
    
    # Get unique image IDs
    image_ids = set()
    for item in dataset.captions_data:
        image_ids.add(item['image'])
    image_ids = list(image_ids)
    
    for image_id in image_ids:
        feature_path = os.path.join(feature_dir, f"{image_id.split('.')[0]}.npy")
        if not os.path.exists(feature_path):
            try:
                # Load and transform image
                image_path = os.path.join(images_dir, image_id)
                image = Image.open(image_path).convert('RGB')
                image = transform(image)
                
                # Extract features
                with torch.no_grad():
                    features = feature_extractor(image.unsqueeze(0))
                    features = features.squeeze().numpy()
                
                # Save features
                np.save(feature_path, features)
            except FileNotFoundError:
                print(f"Warning: Could not find image file {image_path}")
                continue
    
    print(f"Features extracted for {len(image_ids)} images")
    
    # Create data loaders
    train_size = int(0.8 * len(dataset))
    valid_size = len(dataset) - train_size
    train_dataset, valid_dataset = torch.utils.data.random_split(dataset, [train_size, valid_size])
    
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0, collate_fn= custom_collate_fn)
    valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=False, num_workers=0, collate_fn= custom_collate_fn)
    
    return train_loader, valid_loader, dataset.vocab_size


def train_image_captioning_model():
    """Train the image captioning model."""
    # Prepare data
    train_loader, valid_loader, vocab_size = prepare_flickr8k_data()
    
    # Model parameters
    d_model = 512
    num_heads = 8
    num_layers = 6
    d_ff = 2048
    max_seq_len = 50
    dropout = 0.1
    
    # Create model
    model = ImageCaptioningModel(vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_len, dropout)
    
    # Define loss function and optimizer
    criterion = nn.CrossEntropyLoss(ignore_index=0)  # Ignore padding tokens
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
    
    # Training loop
    num_epochs = 1
    device = torch.device('cpu')
    model = model.to(device)
    
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        
        for img_features, captions in train_loader:
            if img_features is None or captions is None:
                continue  # Skip batch if it's empty due to all files being missing
            img_features = img_features.to(device)
            captions = captions.to(device)
            
            # Create target (shifted by 1 to the right)
            targets = captions[:, 1:].contiguous()
            inputs = captions[:, :-1].contiguous()
            
            # Create self-attention mask (causal mask for autoregressive generation)
            self_attn_mask = generate_square_subsequent_mask(inputs.size(1)).to(device)
            
            # Forward pass
            outputs = model(img_features, inputs, self_attn_mask)
            
            # Calculate loss
            loss = criterion(outputs.reshape(-1, vocab_size), targets.reshape(-1))
            
            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
        
        # Validation
        model.eval()
        valid_loss = 0
        
        with torch.no_grad():
            for img_features, captions in valid_loader:
                img_features = img_features.to(device)
                captions = captions.to(device)
                
                # Create target (shifted by 1 to the right)
                targets = captions[:, 1:].contiguous()
                inputs = captions[:, :-1].contiguous()
                
                # Create self-attention mask
                self_attn_mask = generate_square_subsequent_mask(inputs.size(1)).to(device)
                
                # Forward pass
                outputs = model(img_features, inputs, self_attn_mask)
                
                # Calculate loss
                loss = criterion(outputs.reshape(-1, vocab_size), targets.reshape(-1))
                valid_loss += loss.item()
        
        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss/len(train_loader):.4f}, Valid Loss: {valid_loss/len(valid_loader):.4f}")
    
    # Save model
    torch.save(model.state_dict(), "image_captioning_model.pth")
    print("Training complete!")




In [9]:
def generate_caption(model, image_path, dataset, device, max_length=50):
    """Generate a caption for a given image."""
    # Load and transform image
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    image = Image.open(image_path).convert('RGB')
    image = transform(image)
    
    # Extract features
    with torch.no_grad():
        feature_extractor = nn.Sequential(*list(models.resnet50(pretrained=True).children())[:-1])
        feature_extractor.eval()
        features = feature_extractor(image.unsqueeze(0))
        features = features.squeeze()
    
    # Move to device
    model = model.to(device)
    features = features.to(device)
    
    # Generate caption
    model.eval()
    with torch.no_grad():
        # Start with start token
        caption = [dataset.word_to_idx["<START>"]]
        
        for i in range(max_length):
            # Convert caption tokens to tensor
            caption_tensor = torch.tensor([caption]).to(device)
            
            # Create mask
            mask = generate_square_subsequent_mask(len(caption)).to(device)
            
            # Predict next word
            output = model(features.unsqueeze(0), caption_tensor, mask)
            predictions = output[:, -1, :]
            predicted_id = torch.argmax(predictions, dim=-1).item()
            
            # Add prediction to caption
            caption.append(predicted_id)
            
            # Stop if end token is predicted
            if predicted_id == dataset.word_to_idx["<END>"]:
                break
    
    # Convert indices to words
    words = [dataset.idx_to_word[idx] for idx in caption if idx not in [dataset.word_to_idx["<START>"], dataset.word_to_idx["<END>"], dataset.word_to_idx["<PAD>"]]]
    
    return " ".join(words)




In [None]:
# Main execution
if __name__ == "__main__":
    # Prepare data and train model
    train_image_captioning_model()
    
    # Load trained model for inference
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Define model parameters
    d_model = 512
    num_heads = 8
    num_layers = 6
    d_ff = 2048
    max_seq_len = 50
    dropout = 0.1
    
    # Create dataset to get vocabulary info
    captions_file = "Flickr8k_text/Flickr8k.token.txt"
    images_dir = "Flickr8k_Dataset/"
    dataset = Flickr8kDataset(captions_file, images_dir)
    
    # Load model
    model = ImageCaptioningModel(dataset.vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_len, dropout)
    model.load_state_dict(torch.load("image_captioning_model.pth"))
    
    # Generate a caption for a test image
    test_image_path = os.path.join(images_dir, "241347664_4a3e7e5be7.jpg")
    caption = generate_caption(model, test_image_path, dataset, device)
    print(f"Generated caption: {caption}")

Generated caption: three three three three three three three three three three three three three three three three three three three three three three three three three three three three three three three three three three three three three three three three three three three three three three three three three three
