In [1]:
import os
import csv
import json
import torch
from torch import nn, optim
from torch.cuda.amp import autocast
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, models
from transformers import AutoTokenizer
from nltk.translate.bleu_score import corpus_bleu
from PIL import Image
from tqdm import tqdm
from pycocoevalcap.cider.cider import Cider
from pycocoevalcap.meteor.meteor import Meteor
from pycocoevalcap.rouge.rouge import Rouge
from bert_score import score as bert_score

In [2]:
feature_dim = 32  # Dimension for CNN encoder output and transformer decoder
num_layers = 1  # Number of layers in the transformer decoder
nhead = 2  # Number of attention heads in the transformer decoder
dim_feedforward = 32  # Dimension of the feedforward network in the transformer decoder
batch_size = 8  # Batch size for data loaders
learning_rate = 1e-4  # Learning rate for the optimizer
num_epochs = 20  # Number of epochs for training
max_length = 50  # Maximum length for captions

In [3]:
# --- Dataset Definition ---
class ImageCaptionDataset(Dataset):
    """Custom Dataset for loading images and captions."""

    def __init__(self, image_dir, captions_file, tokenizer, max_length=max_length):
        self.image_dir = image_dir
        with open(captions_file, "r") as f:
            self.data = json.load(f)  # Expected format: {"image1.jpg": "caption", ...}
        self.image_filenames = list(self.data.keys())
        self.captions = list(self.data.values())
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.transform = transforms.Compose(
            [
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        image_path = os.path.join(self.image_dir, self.image_filenames[idx])
        image = Image.open(image_path).convert("RGB")
        image = self.transform(image)
        caption = self.captions[idx]
        tokenized = self.tokenizer(
            caption,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        return (
            image,
            tokenized.input_ids.squeeze(),
            tokenized.attention_mask.squeeze(),
        )

In [4]:
# --- CNN Encoder ---
class CNNEncoder(nn.Module):
    """Encoder using EfficientNet-B0 to extract image features."""

    def __init__(self, feature_dim=feature_dim):
        super(CNNEncoder, self).__init__()
        efficientnet = models.efficientnet_b0(pretrained=True)
        self.features = efficientnet.features  # Extract up to last conv layer
        self.projection = nn.Conv2d(1280, feature_dim, kernel_size=1, stride=1)
        self.feature_dim = feature_dim  # Store feature_dim as an instance variable

    def forward(self, images):
        features = self.features(images)  # Shape: (batch_size, 1280, 7, 7)
        features = self.projection(features)  # Shape: (batch_size, feature_dim, 7, 7)
        batch_size = features.size(0)
        features = features.permute(0, 2, 3, 1).reshape(
            batch_size, 49, self.feature_dim
        )
        return features  # Shape: (batch_size, 49, feature_dim)

In [5]:
# --- Transformer Decoder ---
class TransformerDecoder(nn.Module):
    """Lightweight Transformer decoder for caption generation."""

    def __init__(
        self,
        feature_dim,
        vocab_size,
        num_layers=num_layers,
        nhead=nhead,
        dim_feedforward=dim_feedforward,
    ):
        super(TransformerDecoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, feature_dim)
        self.decoder_layer = nn.TransformerDecoderLayer(
            d_model=feature_dim,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            batch_first=False,
        )
        self.decoder = nn.TransformerDecoder(self.decoder_layer, num_layers=num_layers)
        self.fc = nn.Linear(feature_dim, vocab_size)

    def forward(self, encoder_features, captions):
        embeddings = self.embedding(captions).permute(
            1, 0, 2
        )  # (seq_len, batch_size, feature_dim)
        memory = encoder_features.permute(1, 0, 2)  # (49, batch_size, feature_dim)
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(captions.size(1)).to(
            captions.device
        )
        output = self.decoder(embeddings, memory, tgt_mask=tgt_mask)
        logits = self.fc(output.permute(1, 0, 2))  # (batch_size, seq_len, vocab_size)
        return logits


# --- Combined Image Captioning Model ---
class ImageCaptioningModel(nn.Module):
    """Combines CNN encoder and Transformer decoder."""

    def __init__(self, encoder, decoder):
        super(ImageCaptioningModel, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, images, captions):
        features = self.encoder(images)
        outputs = self.decoder(features, captions)
        return outputs

In [6]:
# --- Validation Function ---
def validate_model(model, dataloader, criterion, device):
    """Compute validation loss."""
    model.eval()
    total_val_loss = 0.0
    with torch.no_grad():
        for images, input_ids, _ in tqdm(dataloader, desc="Validation", leave=False):
            images, input_ids = images.to(device), input_ids.to(device)
            with autocast():
                outputs = model(images, input_ids[:, :-1])
                loss = criterion(
                    outputs.reshape(-1, outputs.size(-1)), input_ids[:, 1:].reshape(-1)
                )
            total_val_loss += loss.item()
    return total_val_loss / len(dataloader)

In [7]:
def train_model(
    model,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    scheduler,
    num_epochs,
    device,
    model_save_path,
    patience=3,
    min_delta=0.001,
):
    """Train the model in full precision with early stopping."""
    best_loss = float("inf")
    epochs_no_improve = 0

    with open("training_results.csv", "w", newline="") as csv_file:
        csv_writer = csv.writer(csv_file)
        csv_writer.writerow(["epoch", "train_loss", "val_loss", "best_val_loss", "lr"])
        csv_file.flush()

        # Outer tqdm loop for epochs
        with tqdm(range(num_epochs), desc="Training Progress") as epoch_pbar:
            for epoch in epoch_pbar:
                model.train()
                total_loss = 0.0
                # Inner tqdm loop for batches, completes before outer bar updates
                for images, input_ids, _ in tqdm(train_loader, desc=f"Epoch {epoch+1}", leave=False):
                    images, input_ids = images.to(device), input_ids.to(device)
                    optimizer.zero_grad()
                    outputs = model(images, input_ids[:, :-1])
                    loss = criterion(
                        outputs.reshape(-1, outputs.size(-1)), input_ids[:, 1:].reshape(-1)
                    )
                    loss.backward()
                    optimizer.step()
                    total_loss += loss.item()

                # Compute metrics after the full epoch
                avg_train_loss = total_loss / len(train_loader)
                avg_val_loss = validate_model(model, val_loader, criterion, device)
                scheduler.step(avg_val_loss)
                current_lr = optimizer.param_groups[0]["lr"]

                # Update postfix only after the epoch completes
                epoch_pbar.set_postfix({
                    "Train Loss": f"{avg_train_loss:.4f}",
                    "Val Loss": f"{avg_val_loss:.4f}",
                    "Best Val Loss": f"{best_loss:.4f}",
                    "LR": f"{current_lr:.6f}"
                })

                # Check for improvement and save model
                if best_loss - avg_val_loss > min_delta:
                    best_loss = avg_val_loss
                    epochs_no_improve = 0
                    torch.save(model.state_dict(), model_save_path)
                    print(f"Best model saved with val loss {best_loss:.4f}")
                else:
                    epochs_no_improve += 1

                csv_writer.writerow(
                    [epoch + 1, avg_train_loss, avg_val_loss, best_loss, current_lr]
                )
                csv_file.flush()

                if epochs_no_improve >= patience:
                    print("Early stopping triggered!")
                    break

In [8]:
# --- Caption Generation ---
def generate_caption(model, image, tokenizer, max_length=max_length, device="cuda"):
    """Generate a caption for a single image."""
    model.eval()
    with torch.no_grad():
        image = image.unsqueeze(0).to(device)
        features = model.encoder(image)  # (1, 49, feature_dim)
        generated = torch.tensor([tokenizer.bos_token_id], device=device).unsqueeze(0)
        for _ in range(max_length - 1):
            logits = model.decoder(features, generated)
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            generated = torch.cat([generated, next_token], dim=1)
            if next_token.item() == tokenizer.eos_token_id:
                break
        return tokenizer.decode(generated[0], skip_special_tokens=True)

In [9]:
# --- Evaluation Function ---
def evaluate_model(model, test_loader, tokenizer, device):
    """Evaluate the model using BLEU, CIDEr, METEOR, ROUGE-L, and BERT scores."""
    model.eval()
    refs, hyps = {}, {}
    with torch.no_grad():
        for i, (images, input_ids, _) in enumerate(test_loader):
            generated = [
                generate_caption(model, img, tokenizer, device=device) for img in images
            ]
            references = [
                tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids
            ]
            for j, (ref, hyp) in enumerate(zip(references, generated)):
                idx = i * test_loader.batch_size + j
                refs[idx] = [ref]
                hyps[idx] = [hyp]

    # CIDEr
    cider_scorer = Cider()
    cider_score, _ = cider_scorer.compute_score(refs, hyps)

    # METEOR
    meteor_scorer = Meteor()
    meteor_score, _ = meteor_scorer.compute_score(refs, hyps)

    # ROUGE-L
    rouge_scorer = Rouge()
    rouge_score, _ = rouge_scorer.compute_score(refs, hyps)

    # BERT Score
    ref_list = [r[0] for r in refs.values()]
    hyp_list = [h[0] for h in hyps.values()]
    P, R, F1 = bert_score(hyp_list, ref_list, lang="en", verbose=True)
    bert_f1 = F1.mean().item()

    # BLEU
    bleu_score = corpus_bleu(
        [[r.split()] for r in ref_list], [h.split() for h in hyp_list]
    )

    print(
        f"BLEU: {bleu_score:.4f}, CIDEr: {cider_score:.4f}, METEOR: {meteor_score:.4f}, "
        f"ROUGE-L: {rouge_score:.4f}, BERT Score: {bert_f1:.4f}"
    )
    return bleu_score, cider_score, meteor_score, rouge_score, bert_f1

In [10]:
# Configuration
IMAGE_DIR = "train2017_20k_processed"  # Replace with your image directory
CAPTIONS_FILE = "merged_captions.json"  # Replace with your captions JSON file
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [11]:
# Dataset and DataLoaders
dataset = ImageCaptionDataset(
    IMAGE_DIR, CAPTIONS_FILE, tokenizer, max_length=max_length
)
train_size = int(0.7 * len(dataset))
val_size = int(0.15 * len(dataset))
test_size = len(dataset) - train_size - val_size
train_dataset, val_dataset, test_dataset = random_split(
    dataset, [train_size, val_size, test_size]
)
train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True
)
val_loader = DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True
)
test_loader = DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True
)


In [12]:
# Model Initialization
vocab_size = tokenizer.vocab_size
encoder = CNNEncoder(feature_dim=feature_dim)
decoder = TransformerDecoder(
    feature_dim=feature_dim,
    vocab_size=vocab_size,
    num_layers=num_layers,
    nhead=nhead,
    dim_feedforward=dim_feedforward,
)
model = ImageCaptioningModel(encoder, decoder).to(device)

# Training Setup
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.1, patience=2
)



In [None]:
# Train the Model
train_model(
    model,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    scheduler,
    num_epochs=num_epochs,
    device=device,
    model_save_path="my_model.pth",
    patience=3,
    min_delta=0.001,
)


Training Progress:   0%|          | 0/20 [00:00<?, ?it/s]

In [None]:
# Load Best Model and Evaluate
model.load_state_dict(torch.load("my_model.pth"))
evaluate_model(model, test_loader, tokenizer, device)

  model.load_state_dict(torch.load("my_model.pth"))


tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/482 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]



model.safetensors:   0%|          | 0.00/1.42G [00:00<?, ?B/s]

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


calculating scores...
computing bert embedding.


  0%|          | 0/94 [00:00<?, ?it/s]

computing greedy matching.


  0%|          | 0/47 [00:00<?, ?it/s]

done in 25.07 seconds, 119.65 sentences/sec
BLEU: 0.0240, CIDEr: 0.0622, METEOR: 0.1259, ROUGE-L: 0.1938, BERT Score: 0.8458


(0.02400378036626487,
 0.062165027043621306,
 0.12588561906337187,
 0.19379292239336057,
 0.8457581400871277)