# Image Captioning Training Notebook

This notebook demonstrates how to train the image captioning model on datasets like Flickr8k, Flickr30k, or COCO Captions.

## Overview
- Data loading and preprocessing
- Model training setup
- Training loop implementation
- Evaluation and visualization

In [None]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
import os
import json
import numpy as np
from collections import Counter
import matplotlib.pyplot as plt
from tqdm import tqdm
import pickle

# Select the compute device once for the whole notebook: the default CUDA GPU
# when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

## Dataset Preparation

For this demo, we'll create a mock dataset. In practice, you would use:
- Flickr8k: 8,000 images with 5 captions each
- Flickr30k: 31,000 images with 5 captions each
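- COCO Captions: ~123,000 captioned images (2014 train + val splits) with at least 5 captions each (COCO's ~330,000 figure counts all of its images, most without public captions)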

In [None]:
class ImageCaptionDataset(Dataset):
    """Map-style dataset yielding ``(image, caption_ids)`` pairs for captioning.

    The captions file is a JSON object mapping an image filename (relative to
    ``image_dir``) to a list of caption strings. Each ``__getitem__`` call picks
    one of that image's captions at random, so repeated visits to the same
    image can see different captions.

    Args:
        image_dir: Directory containing the image files named by the JSON keys.
        captions_file: Path to the UTF-8 JSON captions file.
        transform: Optional callable applied to the loaded RGB PIL image.
        vocab: Optional ``{token: id}`` mapping. When omitted, one is built
            from the loaded captions with :meth:`build_vocab`. A supplied vocab
            must contain ``'<start>'``, ``'<end>'`` and ``'<unk>'``.
    """

    def __init__(self, image_dir, captions_file, transform=None, vocab=None):
        self.image_dir = image_dir
        self.transform = transform

        # Load captions: {image_filename: [caption, ...]}. JSON is UTF-8 by
        # spec, so don't depend on the platform's locale encoding.
        with open(captions_file, 'r', encoding='utf-8') as f:
            self.captions = json.load(f)

        # Build vocabulary if not provided
        self.vocab = self.build_vocab() if vocab is None else vocab

        # Freeze the key order so an index always maps to the same image.
        self.image_paths = list(self.captions.keys())

    @staticmethod
    def _tokenize(caption):
        """Lower-case and whitespace-split one caption (shared by vocab building and lookup)."""
        return str(caption).lower().split()

    def build_vocab(self, threshold=5):
        """Build a ``{token: id}`` vocabulary from all loaded captions.

        Ids 0-3 are reserved for ``<pad>``, ``<start>``, ``<end>`` and ``<unk>``.
        Every other token seen at least ``threshold`` times gets the next
        consecutive id, in first-seen order. Caption tokens that spell one of
        the reserved tokens never overwrite the reserved id.
        """
        counter = Counter()
        for captions in self.captions.values():
            for caption in captions:
                counter.update(self._tokenize(caption))

        vocab = {'<pad>': 0, '<start>': 1, '<end>': 2, '<unk>': 3}
        for word, count in counter.items():
            # `word not in vocab` keeps a literal "<unk>"/"<end>" in the data
            # from reassigning a special token's id (and leaving a gap).
            if count >= threshold and word not in vocab:
                vocab[word] = len(vocab)

        return vocab

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_name = self.image_paths[idx]
        img_path = os.path.join(self.image_dir, image_name)

        # convert() returns a new, fully decoded image, so the source file
        # can be closed immediately instead of waiting for garbage collection.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        # Pick one of this image's reference captions at random.
        captions = self.captions[image_name]
        caption = captions[np.random.randint(len(captions))]

        # Wrap with sentence markers and map unknown words to <unk>.
        tokens = ['<start>'] + self._tokenize(caption) + ['<end>']
        unk_id = self.vocab['<unk>']
        caption_ids = [self.vocab.get(token, unk_id) for token in tokens]

        return image, torch.tensor(caption_ids)

# Create mock dataset for demonstration
def create_mock_dataset(num_samples=100):
    """Return a synthetic ``{image_filename: [captions]}`` mapping for demos.

    Keys are ``image_0000.jpg``, ``image_0001.jpg``, ... (``num_samples`` of
    them). Each maps to a list of 5 captions drawn with replacement, via the
    global NumPy RNG, from a fixed pool of ten sentences.
    """
    caption_pool = [
        "A dog is running in the park",
        "A cat is sitting on the couch",
        "People are walking on the beach",
        "A car is driving on the road",
        "Trees are standing in the forest",
        "A bird is flying in the sky",
        "Children are playing in the playground",
        "A house is built on the hill",
        "Flowers are blooming in the garden",
        "A boat is sailing on the water"
    ]

    # One draw of 5 captions per image, in index order.
    return {
        f"image_{index:04d}.jpg": np.random.choice(caption_pool, 5, replace=True).tolist()
        for index in range(num_samples)
    }

# Save mock dataset: the captions JSON plus placeholder images, so the
# dataset and DataLoader cells below can open every file they reference.
mock_data = create_mock_dataset(100)
with open('mock_captions.json', 'w', encoding='utf-8') as f:
    json.dump(mock_data, f, indent=2)

# Captions alone are not enough: the dataset cell reads images from
# 'sample_images/<key>' for every key in the JSON. Write a random-noise RGB
# image for each file that does not exist yet (existing real images are kept).
mock_image_dir = 'sample_images'
os.makedirs(mock_image_dir, exist_ok=True)
for img_name in mock_data:
    img_path = os.path.join(mock_image_dir, img_name)
    if not os.path.exists(img_path):
        pixels = np.random.randint(0, 256, size=(224, 224, 3), dtype=np.uint8)
        Image.fromarray(pixels).save(img_path)

print("Mock dataset created with 100 images and 5 captions each")

## Data Transforms and DataLoader

In [None]:
# Define transforms
# Training-time preprocessing: resize to a fixed 224x224, randomly flip
# horizontally (augmentation), convert to a float tensor in [0, 1], then
# normalize with the ImageNet mean/std used by torchvision's pretrained models.
# NOTE: a horizontal flip mirrors the scene, so captions that mention
# "left"/"right" stop matching the image; drop the flip if that matters.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Deterministic counterpart for evaluation and inference (e.g. the
# `transform` argument of generate_caption): same resize and normalization,
# but no random flip, so the same image always yields the same tensor.
eval_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Create dataset; image_dir must contain every filename listed in the JSON.
dataset = ImageCaptionDataset(
    image_dir='sample_images',  # You would use actual image directory
    captions_file='mock_captions.json',
    transform=transform
)

print(f"Vocabulary size: {len(dataset.vocab)}")
print(f"Dataset size: {len(dataset)}")

# Create data loader
def collate_fn(data):
    """Batch ``(image, caption_ids)`` samples, right-padding captions with 0.

    Returns:
        images: the sample images stacked along a new leading batch dimension.
        padded_captions: ``(batch, max_len)`` int64 tensor. Positions past a
            caption's end hold 0, which is the ``<pad>`` id.
        lengths: list of the unpadded caption lengths, in batch order.
    """
    batch_images, batch_captions = zip(*data)

    stacked = torch.stack(batch_images, 0)

    lengths = list(map(len, batch_captions))
    padded = torch.zeros(len(batch_captions), max(lengths), dtype=torch.long)
    # Each `row` is a view into `padded`, so slice assignment fills it in place.
    for row, caption in zip(padded, batch_captions):
        row[:len(caption)] = caption

    return stacked, padded, lengths

# Create DataLoader.
# Worker processes must be able to unpickle `dataset` and `collate_fn`. Under
# the 'spawn'/'forkserver' start methods (spawn is the default on Windows and
# macOS) that means re-importing them from __main__, which typically fails for
# objects defined in a notebook cell. So only use worker processes when the
# start method is 'fork'; otherwise load batches in the main process.
import multiprocessing

num_workers = 2 if multiprocessing.get_start_method() == 'fork' else 0

data_loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=num_workers,
    collate_fn=collate_fn
)

print(f"DataLoader created with batch size 32")

## Model Initialization

In [None]:
# Import model components
from model import ImageCaptioningModel

# Initialize model from one keyword dictionary so the hyper-parameters sit in one place.
# fine_tune_encoder=False presumably freezes the pretrained encoder; the
# trainable-parameter count printed below shows the actual effect.
model_kwargs = dict(
    encoder_type='resnet',
    decoder_type='lstm',  # or 'transformer'
    embed_size=256,
    hidden_size=512,
    vocab_size=len(dataset.vocab),
    fine_tune_encoder=False,
)
model = ImageCaptioningModel(**model_kwargs).to(device)

all_params = list(model.parameters())
n_total = sum(p.numel() for p in all_params)
n_trainable = sum(p.numel() for p in all_params if p.requires_grad)
print(f"Model initialized with {n_total} parameters")
print(f"Trainable parameters: {n_trainable}")

## Training Setup

In [None]:
# Loss: cross-entropy over vocabulary logits. Targets equal to the <pad> id
# are ignored, so padding positions contribute neither loss nor gradient.
pad_idx = dataset.vocab['<pad>']
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)

# Optimizer: Adam over only the parameters that require gradients, so frozen
# weights are skipped. Adam's weight_decay is a coupled L2 penalty.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=0.001, weight_decay=1e-4)

# Scheduler: multiply the learning rate by 0.1 every 5 scheduler steps.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

# Training parameters: epochs, batch-logging interval, checkpoint interval (epochs).
num_epochs, print_every, save_every = 20, 10, 5

print("Training setup completed")

## Training Loop

In [None]:
def train_epoch(model, data_loader, criterion, optimizer, device,
                log_every=None, max_grad_norm=None):
    """Run one teacher-forced training pass over ``data_loader``.

    For each batch the model receives the captions without their last token
    (``captions[:, :-1]``). It is scored against the captions shifted left by
    one (``captions[:, 1:]``), i.e. it predicts the next token at every position.

    Args:
        model: captioning model called as ``model(images, input_captions, lengths)``.
            Its output is reshaped as ``(batch, seq_len, vocab_size)`` logits, and
            ``seq_len`` must equal ``captions.size(1) - 1`` for the target reshape to succeed.
        data_loader: iterable of ``(images, padded_captions, lengths)`` batches.
        criterion: loss taking ``(N, vocab_size)`` logits and ``(N,)`` targets.
        optimizer: optimizer over the model's trainable parameters.
        device: device the image and caption tensors are moved to.
        log_every: log the batch loss every this many batches. ``None``
            (default) uses the notebook-level ``print_every`` setting.
        max_grad_norm: if given, clip the global gradient L2 norm to this
            value before each optimizer step. ``None`` (default) disables clipping.

    Returns:
        Mean of the per-batch losses.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    every = print_every if log_every is None else log_every

    model.train()
    total_loss = 0.0
    num_batches = 0

    progress = tqdm(data_loader)
    for batch_idx, (images, captions, lengths) in enumerate(progress):
        images = images.to(device)
        captions = captions.to(device)

        # Forward pass with teacher forcing: feed tokens [0, T-1), predict [1, T).
        # NOTE(review): `lengths` are the full caption lengths (including
        # <end>), one longer than the input actually fed here. Confirm this is
        # what ImageCaptioningModel expects.
        outputs = model(images, captions[:, :-1], lengths)

        # Flatten batch and time so each position is one classification example.
        batch_size, seq_len, vocab_size = outputs.shape
        outputs = outputs.reshape(batch_size * seq_len, vocab_size)
        targets = captions[:, 1:].reshape(batch_size * seq_len)

        loss = criterion(outputs, targets)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        if max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()

        # Single .item() call: one device sync per batch instead of two.
        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1

        if batch_idx % every == 0:
            # tqdm.write keeps the message from corrupting the progress bar.
            progress.write(f'Batch [{batch_idx}/{len(data_loader)}], Loss: {batch_loss:.4f}')

    if num_batches == 0:
        raise ValueError("data_loader yielded no batches")
    return total_loss / num_batches

def validate(model, data_loader, criterion, device):
    """Return the mean per-batch loss of ``model`` on ``data_loader``, without training.

    The model is switched to eval mode (and left there) and autograd is
    disabled. The same next-token objective as training is applied: inputs
    are ``captions[:, :-1]`` and targets are ``captions[:, 1:]``. The summed
    batch losses are divided by ``len(data_loader)``.
    """
    model.eval()
    batch_losses = []

    with torch.no_grad():
        for batch_images, batch_captions, batch_lengths in data_loader:
            batch_images = batch_images.to(device)
            batch_captions = batch_captions.to(device)

            logits = model(batch_images, batch_captions[:, :-1], batch_lengths)

            # Flatten (batch, time) into one axis of classification examples.
            n_rows, n_steps, n_classes = logits.shape
            flat_logits = logits.reshape(n_rows * n_steps, n_classes)
            flat_targets = batch_captions[:, 1:].reshape(n_rows * n_steps)

            batch_losses.append(criterion(flat_logits, flat_targets).item())

    return sum(batch_losses) / len(data_loader)

print("Training functions defined")

## Start Training

In [None]:
# Training loop
# NOTE: this demo has no held-out split. `validate` is called on the same
# `data_loader` used for training, so "Val Loss" is the training-set loss
# measured in eval mode without gradient updates. For real training, build a
# separate validation DataLoader (for example with torch.utils.data.random_split)
# and pass it to `validate` instead.
train_losses = []
val_losses = []

# Architecture settings passed to ImageCaptioningModel in the model cell.
# They are stored in every checkpoint so a saved file records how to rebuild
# the model; keep them in sync if the model cell changes.
saved_config = {
    'encoder_type': 'resnet',
    'decoder_type': 'lstm',
    'embed_size': 256,
    'hidden_size': 512,
    'vocab_size': len(dataset.vocab),
    'fine_tune_encoder': False,
}

print("Starting training...")
print("Note: This is a demonstration with mock data.")
print("For real training, use actual image datasets.")

for epoch in range(num_epochs):
    print(f'\nEpoch [{epoch+1}/{num_epochs}]')
    print('-' * 50)

    # Train
    train_loss = train_epoch(model, data_loader, criterion, optimizer, device)
    train_losses.append(train_loss)

    # Validate (reuses the training loader; see NOTE above)
    val_loss = validate(model, data_loader, criterion, device)
    val_losses.append(val_loss)

    # Update learning rate: StepLR is stepped once per epoch here, so the LR
    # decays every `step_size` epochs.
    scheduler.step()

    print(f'Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')
    print(f'Learning Rate: {scheduler.get_last_lr()[0]:.6f}')

    # Periodic checkpoint with everything needed to resume training:
    # model/optimizer/scheduler state, vocab, architecture config, loss history.
    if (epoch + 1) % save_every == 0:
        checkpoint = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'vocab': dataset.vocab,
            'config': saved_config,
            'train_losses': train_losses,
            'val_losses': val_losses
        }
        torch.save(checkpoint, f'checkpoint_epoch_{epoch+1}.pth')
        print(f'Checkpoint saved: checkpoint_epoch_{epoch+1}.pth')

print('\nTraining completed!')

# Save final model: weights, vocab, and the full architecture config.
torch.save({
    'model_state_dict': model.state_dict(),
    'vocab': dataset.vocab,
    'config': saved_config,
}, 'final_model.pth')

print('Final model saved: final_model.pth')

## Training Visualization

In [None]:
# Plot training curves. Left panel: training and validation loss overlaid.
# Right panel: training loss alone. The x-axis is the 0-based list index,
# i.e. one point per completed epoch.
fig, (ax_both, ax_train) = plt.subplots(1, 2, figsize=(12, 4))

ax_both.plot(train_losses, label='Training Loss')
ax_both.plot(val_losses, label='Validation Loss')
ax_both.set_xlabel('Epoch')
ax_both.set_ylabel('Loss')
ax_both.set_title('Training and Validation Loss')
ax_both.legend()
ax_both.grid(True)

ax_train.plot(train_losses, label='Training Loss', color='blue')
ax_train.set_xlabel('Epoch')
ax_train.set_ylabel('Loss')
ax_train.set_title('Training Loss')
ax_train.grid(True)

fig.tight_layout()
plt.show()

print("Training curves plotted")

## Inference Demo

In [None]:
# Load trained model for inference
def load_model(checkpoint_path, vocab):
    """Rebuild an ImageCaptioningModel from a saved checkpoint, ready for inference.

    Architecture settings come from the checkpoint's ``'config'`` dict when
    present. Any missing key, or a checkpoint with no ``'config'`` at all,
    falls back to the settings this notebook trains with: ResNet encoder,
    LSTM decoder, embed 256, hidden 512, frozen encoder.

    Args:
        checkpoint_path: path to a file written by ``torch.save`` that contains
            at least ``'model_state_dict'``.
        vocab: the ``{token: id}`` vocabulary the model was trained with. Its
            size sets ``vocab_size``.

    Returns:
        The model on ``device``, in eval mode.
    """
    # NOTE: depending on the PyTorch version's `weights_only` default,
    # torch.load may unpickle arbitrary objects; only load trusted files.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    config = checkpoint.get('config') or {}

    model = ImageCaptioningModel(
        encoder_type=config.get('encoder_type', 'resnet'),
        decoder_type=config.get('decoder_type', 'lstm'),
        embed_size=config.get('embed_size', 256),
        hidden_size=config.get('hidden_size', 512),
        vocab_size=len(vocab),
        fine_tune_encoder=config.get('fine_tune_encoder', False),
    ).to(device)

    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    return model

# Generate caption for an image
def generate_caption(model, image_path, vocab, transform, max_length=20):
    """Generate a caption string for one image file.

    Args:
        model: object exposing ``generate_caption(image_tensor, max_length)``
            that returns token ids, either as a sequence of ints or as a
            tensor (which is flattened).
        image_path: path to the image file.
        vocab: the ``{token: id}`` vocabulary used in training.
        transform: preprocessing applied to the RGB PIL image before adding a
            batch dimension. It should be deterministic (no random
            augmentation) if you want reproducible captions.
        max_length: maximum caption length, forwarded to the model.

    Returns:
        The decoded words, capitalized, with a trailing period. Decoding stops
        at the first ``<end>``. ``<start>``, ``<pad>``, ``<unk>`` and ids missing
        from ``vocab`` are dropped.
    """
    # Load and preprocess image; the context manager closes the file handle
    # once convert() has produced a fully decoded copy.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    image_tensor = transform(image).unsqueeze(0).to(device)

    # Generate caption
    with torch.no_grad():
        caption_ids = model.generate_caption(image_tensor, max_length)

    # Normalize ids to plain Python ints. Tensor elements hash by identity,
    # so they would never match the int keys of idx_to_word.
    if isinstance(caption_ids, torch.Tensor):
        caption_ids = caption_ids.flatten().tolist()

    # Convert IDs to words
    idx_to_word = {v: k for k, v in vocab.items()}
    skipped = {'<start>', '<pad>', '<unk>'}
    caption_words = []

    for idx in caption_ids:
        word = idx_to_word.get(int(idx))
        if word is None:
            continue
        if word == '<end>':
            break
        if word not in skipped:
            caption_words.append(word)

    return ' '.join(caption_words).capitalize() + '.'

print("Inference functions defined")
print("Note: For actual inference, you need:")
print("1. Trained model checkpoint")
print("2. Real images in sample_images directory")
print("3. Vocabulary from training")

## Evaluation Metrics

In [None]:
# BLEU Score calculation (simplified)
from collections import Counter
import math

def bleu_score(reference, candidate, n=4):
    """Compute sentence-level BLEU for one candidate (simplified, unsmoothed).

    Sentences are lower-cased and whitespace-tokenized.

    Args:
        reference: one reference sentence (str), or an iterable of reference
            sentences. With several references, each candidate n-gram count is
            clipped by its maximum count in any single reference, and the
            brevity penalty uses the reference length closest to the
            candidate's (ties go to the shorter reference). A single string
            gives the same result as before.
        candidate: the generated sentence.
        n: highest n-gram order. Orders ``1..n`` are weighted uniformly (``1/n``).

    Returns:
        ``BP * exp(mean_k log p_k)``, where ``p_k`` is the clipped n-gram
        precision of order ``k``. Returns 0.0 for an empty candidate, for no
        references, or whenever any ``p_k`` is 0 (no smoothing is applied).

    Raises:
        ValueError: if ``n < 1``.
    """
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}")

    def get_ngrams(tokens, order):
        """Count the contiguous n-grams of length ``order`` in ``tokens``."""
        return Counter(tuple(tokens[i:i + order])
                       for i in range(len(tokens) - order + 1))

    references = [reference] if isinstance(reference, str) else list(reference)
    ref_token_lists = [ref.lower().split() for ref in references]
    cand_tokens = candidate.lower().split()

    if not cand_tokens or not ref_token_lists:
        return 0.0

    precisions = []
    for order in range(1, n + 1):
        cand_ngrams = get_ngrams(cand_tokens, order)

        # Candidate shorter than this order: no n-grams, zero precision.
        if not cand_ngrams:
            precisions.append(0.0)
            continue

        # Counter `|` keeps the per-n-gram maximum across references;
        # `&` then takes min(candidate count, that maximum), i.e. clipping.
        max_ref_ngrams = Counter()
        for ref_tokens in ref_token_lists:
            max_ref_ngrams |= get_ngrams(ref_tokens, order)
        overlap = sum((cand_ngrams & max_ref_ngrams).values())

        precisions.append(overlap / sum(cand_ngrams.values()))

    # Brevity penalty against the closest reference length.
    cand_len = len(cand_tokens)
    ref_len = min((len(tokens) for tokens in ref_token_lists),
                  key=lambda length: (abs(length - cand_len), length))
    bp = math.exp(1 - ref_len / cand_len) if cand_len < ref_len else 1.0

    # Geometric mean of precisions (undefined if any is zero -> score 0).
    if not all(p > 0 for p in precisions):
        return 0.0
    return bp * math.exp(sum(math.log(p) for p in precisions) / n)

# Example usage: score a loose paraphrase against a single reference caption.
# The two sentences share no 4-gram, so this unsmoothed BLEU-4 prints 0.0000.
reference = "a brown dog is running in the grass"
candidate = "a dog runs on the grass"

bleu = bleu_score(reference, candidate)
print(f"BLEU Score: {bleu:.4f}")

# Pointers to production-grade caption metrics.
for _tip in ("\nNote: For comprehensive evaluation, use:",
             "- NLTK for BLEU, METEOR",
             "- pycocoevalcap for CIDEr, SPICE",
             "- Multiple reference captions"):
    print(_tip)

## Next Steps

### For Real Training:
1. **Download Datasets**:
   - Flickr8k: https://www.kaggle.com/datasets/adityajn105/flickr8k
   - COCO Captions: https://cocodataset.org/#download

2. **Data Preparation**:
   - Extract images and captions
   - Split into train/val/test sets
   - Build proper vocabulary

3. **Training Configuration**:
   - Adjust hyperparameters
   - Use GPU for faster training
   - Implement early stopping

4. **Model Improvements**:
   - Add attention mechanisms
   - Implement beam search
   - Use transformer decoder
   - Fine-tune CNN encoder

5. **Evaluation**:
   - Calculate BLEU, METEOR, CIDEr scores
   - Visualize attention maps
   - Test on diverse images

### Deployment Options:
- Web application with Flask/FastAPI
- Mobile app with ExecuTorch (PyTorch's on-device runtime), or TensorFlow Lite after converting the PyTorch model (e.g. via ONNX)
- Cloud service deployment
- Edge device optimization