In [None]:
# This script is a refactored, production-quality implementation of the image captioning project.
#
# --- NEW in this version: Full Model Fine-Tuning ---
# Based on the user's request, this version removes all layer-freezing logic. The entirety
# of both encoders (Swin and ViT) and both decoders (T5-small) are now fully fine-tuned.
# This grants the models maximum flexibility to adapt to the dataset, though it also
# significantly increases the risk of overfitting.

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim import AdamW
from torchvision import transforms
from transformers import (
    SwinModel, ViTModel, T5ForConditionalGeneration, AutoImageProcessor, AutoTokenizer, AutoConfig
)
from PIL import Image
import evaluate
import warnings
import os
from collections import defaultdict
from tqdm import tqdm
import nltk
from zipfile import BadZipFile
import shutil

from pycocoevalcap.cider.cider import Cider
from pycocoevalcap.spice.spice import Spice
from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer

warnings.filterwarnings("ignore")

# --- 1. Configuration ---
# Device every model and batch is moved to; falls back to CPU when CUDA is unavailable.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

# Root folder of the Flickr8k data. The default is the Colab location this
# notebook was written for; set the FLICKR8K_DATA_DIR environment variable to
# run against a copy stored elsewhere without editing the notebook.
DATA_DIR = os.environ.get("FLICKR8K_DATA_DIR", "/content/flickr8k_data")
IMAGES_PATH = os.path.join(DATA_DIR, "Images")
CAPTIONS_FILE = os.path.join(DATA_DIR, "captions.txt")

# --- Model & Training Hyperparameters ---
# Hugging Face checkpoint names for the two image encoders and the shared text decoder.
ENCODER_SWIN = "microsoft/swin-base-patch4-window7-224-in22k"
ENCODER_VIT = "google/vit-base-patch16-224-in21k"
DECODER_NAME = "t5-small"

TOTAL_EPOCHS = 40        # epochs each of the two models is trained for
BATCH_SIZE = 16          # batch size of every DataLoader
LEARNING_RATE = 1e-5     # AdamW learning rate
WEIGHT_DECAY = 0.2       # AdamW weight decay
MAX_LENGTH = 50          # caption length: tokenizer padding/truncation and generation cap
BEAM_WIDTH = 10          # number of beams used when generating captions
T5_DROPOUT_RATE = 0.5    # dropout_rate written into the T5 config before loading the decoders



In [None]:
# --- 2. Flickr8k Dataset ---
class Flickr8kDataset(Dataset):
    """Flickr8k images paired with their captions, one item per image.

    The captions file is read as ``image_file,caption`` lines (an optional
    ``image,caption`` header is skipped). Every caption of an image is kept in
    ``captions_map``; the first one is tokenized as the training target.

    Args:
        images_path: Directory that contains the image files.
        captions_file: Path of the ``image_file,caption`` text file.
        image_processor: Callable turning a PIL image into an object exposing
            ``pixel_values`` (called with ``return_tensors="pt"``).
        tokenizer: Callable tokenizer returning an object exposing ``input_ids``.
        max_length: Padded/truncated label length. Defaults to the module-level
            ``MAX_LENGTH`` when omitted.
    """

    # Label value ignored by the cross-entropy loss of Hugging Face seq2seq models.
    IGNORE_INDEX = -100

    def __init__(self, images_path, captions_file, image_processor, tokenizer, max_length=None):
        self.images_path = images_path
        self.image_processor = image_processor
        self.tokenizer = tokenizer
        # Only fall back to the global constant when no explicit length is given.
        self.max_length = MAX_LENGTH if max_length is None else max_length
        self.captions_map = defaultdict(list)
        self.image_files = []
        with open(captions_file, 'r', encoding='utf-8') as f:
            first_line = f.readline()
            # The first line is data unless it is the CSV header.
            if 'image,caption' not in first_line:
                self._process_line(first_line)
            for line in f:
                self._process_line(line)

    def _process_line(self, line):
        """Register one ``image_file,caption`` line; lines without a comma are ignored."""
        # Split on the first comma only, so commas inside the caption are preserved.
        parts = line.strip().split(',', 1)
        if len(parts) < 2:
            return
        image_file, caption = parts
        # Membership is tested before appending so image_files keeps first-seen order.
        if image_file not in self.captions_map:
            self.image_files.append(image_file)
        self.captions_map[image_file].append(caption.strip())

    def _encode_caption(self, caption):
        """Tokenize ``caption`` to a fixed-length 1-D label tensor with padding masked out.

        Padding positions are replaced by ``IGNORE_INDEX`` so the decoder loss
        is computed on real caption tokens only, not on the padding.
        """
        labels = self.tokenizer(
            caption, padding="max_length", max_length=self.max_length,
            truncation=True, return_tensors="pt",
        ).input_ids.squeeze(0)
        pad_token_id = getattr(self.tokenizer, "pad_token_id", None)
        if pad_token_id is not None:
            labels = labels.masked_fill(labels == pad_token_id, self.IGNORE_INDEX)
        return labels

    def __len__(self):
        """Number of distinct images listed in the captions file."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return pixel values, masked labels, the image file name and all raw captions."""
        image_name = self.image_files[idx]
        image_path = os.path.join(self.images_path, image_name)
        try:
            image = Image.open(image_path).convert("RGB")
        except FileNotFoundError:
            # Deliberate best-effort: a missing file yields a red placeholder image
            # instead of aborting the whole epoch.
            image = Image.new('RGB', (224, 224), color='red')
        pixel_values = self.image_processor(image, return_tensors="pt").pixel_values
        captions = self.captions_map[image_name]
        # Only the first caption is used as the training target.
        labels = self._encode_caption(captions[0])
        return {"pixel_values": pixel_values.squeeze(0), "labels": labels,
                "image_id": image_name, "raw_captions_list": captions}

def custom_collate_fn(batch):
    """Merge a list of dataset items into one batch dictionary.

    The tensor fields ("pixel_values", "labels") are stacked along a new
    leading dimension; "image_id" and "raw_captions_list" are gathered into
    plain Python lists, one entry per item.
    """
    collated = {
        field: torch.stack([sample[field] for sample in batch])
        for field in ("pixel_values", "labels")
    }
    collated["image_id"] = [sample["image_id"] for sample in batch]
    collated["raw_captions_list"] = [sample["raw_captions_list"] for sample in batch]
    return collated



In [None]:
# --- 3. Model Architecture ---
class MappingLayer(nn.Module):
    """Single linear layer projecting features from ``input_dim`` to ``output_dim``."""

    def __init__(self, input_dim, output_dim):
        super(MappingLayer, self).__init__()
        # Attribute name is kept as `proj` so checkpoint keys stay `proj.weight` / `proj.bias`.
        self.proj = nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, x):
        """Apply the projection to the last dimension of ``x``."""
        projected = self.proj(x)
        return projected



In [None]:
# --- 4. Training and Evaluation Functions ---
def _captioning_batch_loss(encoder, decoder, mapping, batch, device):
    """Run one batch through encoder -> mapping -> decoder and return the decoder loss."""
    pixel_values = batch["pixel_values"].to(device)
    labels = batch["labels"].to(device)
    encoder_outputs = encoder(pixel_values=pixel_values)
    mapped_encoder_outputs = mapping(encoder_outputs.last_hidden_state)
    # The projected image features are handed to the decoder as its `encoder_outputs`.
    return decoder(encoder_outputs=(mapped_encoder_outputs,), labels=labels).loss


def train_model(encoder, decoder, mapping, train_loader, val_loader, optimizer, device, epochs, model_name=""):
    """Train encoder, mapping layer and decoder jointly and report the loss per epoch.

    Args:
        encoder: Image encoder called with ``pixel_values=``; its output must
            expose ``last_hidden_state``.
        decoder: Model called with ``encoder_outputs=`` and ``labels=``; its
            output must expose ``loss``.
        mapping: Module projecting the encoder hidden states for the decoder.
        train_loader: Iterable of batches with "pixel_values" and "labels" tensors.
        val_loader: Iterable of batches with the same keys, used without gradients.
        optimizer: Optimizer already holding the parameters to update.
        device: Device the batch tensors are moved to.
        epochs: Number of passes over ``train_loader``.
        model_name: Label used in progress bars and log lines.

    Returns:
        A list with one dict per epoch: ``{"epoch", "train_loss", "val_loss"}``.
        A loss is ``nan`` when the corresponding loader yielded no batches
        (instead of raising ``ZeroDivisionError``).
    """
    print(f"\n--- Training {model_name} ---")
    history = []
    for epoch in range(epochs):
        encoder.train(); decoder.train(); mapping.train()
        train_loss_sum, train_batches = 0.0, 0
        desc = f"Training {model_name} Epoch {epoch + 1}/{epochs}"
        for batch in tqdm(train_loader, desc=desc):
            optimizer.zero_grad()
            loss = _captioning_batch_loss(encoder, decoder, mapping, batch, device)
            loss.backward()
            optimizer.step()
            train_loss_sum += loss.item()
            train_batches += 1
        # Batches are counted while iterating so an empty loader cannot divide by zero.
        avg_train_loss = train_loss_sum / train_batches if train_batches else float("nan")

        encoder.eval(); decoder.eval(); mapping.eval()
        val_loss_sum, val_batches = 0.0, 0
        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Validation Epoch {epoch + 1}/{epochs}"):
                val_loss_sum += _captioning_batch_loss(encoder, decoder, mapping, batch, device).item()
                val_batches += 1
        avg_val_loss = val_loss_sum / val_batches if val_batches else float("nan")
        print(f"Epoch {epoch+1}/{epochs} -> Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")
        history.append({"epoch": epoch + 1, "train_loss": avg_train_loss, "val_loss": avg_val_loss})
    return history



In [None]:
def evaluate_ensemble(
    encoder1, decoder1, mapping1, image_processor1,
    encoder2, decoder2, mapping2, image_processor2,
    dataloader, tokenizer, device, split_name
):
    """Caption every image of ``dataloader`` with both models and score the chosen captions.

    For each image both models run beam search; the caption whose beam score
    (``sequences_scores``) is higher is kept (model 2 wins ties). Images are
    re-read from ``IMAGES_PATH`` by file name because each model needs its own
    image processor; the batch is only used for "image_id" and
    "raw_captions_list".

    Args:
        encoder1, decoder1, mapping1, image_processor1: Components of model 1.
        encoder2, decoder2, mapping2, image_processor2: Components of model 2.
        dataloader: Iterable of batches with "image_id" and "raw_captions_list".
        tokenizer: Tokenizer used to decode generated token ids.
        device: Device the pixel values are moved to.
        split_name: Label used in progress bars and log lines.

    Returns:
        Dict of metrics rounded to three decimals: "B1".."B4" (cumulative
        BLEU-n, i.e. BLEU computed with ``max_order=n``), "ROUGE-L", "METEOR",
        "CIDEr" and "SPICE".
    """
    # Imported once per call rather than once per batch.
    from transformers.modeling_outputs import BaseModelOutput

    print(f"\n--- Evaluating Ensemble on {split_name} set ---")
    for module in (encoder1, decoder1, mapping1, encoder2, decoder2, mapping2):
        module.eval()

    predictions_map, ground_truths_map = {}, {}
    gen_kwargs = {"max_length": MAX_LENGTH, "num_beams": BEAM_WIDTH, "return_dict_in_generate": True, "output_scores": True}

    def _beam_search(encoder, mapping, decoder, image_processor, images):
        """Return (decoded captions, per-sequence beam scores) of one model for ``images``."""
        # Use the `device` argument so the function honours what the caller passed.
        pixel_values = image_processor(images, return_tensors="pt").pixel_values.to(device)
        mapped = mapping(encoder(pixel_values=pixel_values).last_hidden_state)
        outputs = decoder.generate(encoder_outputs=BaseModelOutput(last_hidden_state=mapped), **gen_kwargs)
        captions = tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
        return captions, outputs.sequences_scores.tolist()

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        for batch in tqdm(dataloader, desc=f"Evaluating Ensemble on {split_name}"):
            image_ids = batch["image_id"]
            raw_captions = batch["raw_captions_list"]

            images = []
            for img_id in image_ids:
                try:
                    img_path = os.path.join(IMAGES_PATH, img_id)
                    images.append(Image.open(img_path).convert("RGB"))
                except FileNotFoundError:
                    # Best-effort placeholder for a missing file.
                    images.append(Image.new('RGB', (224, 224), color='red'))

            with torch.no_grad():
                captions1, scores1 = _beam_search(encoder1, mapping1, decoder1, image_processor1, images)
                captions2, scores2 = _beam_search(encoder2, mapping2, decoder2, image_processor2, images)

            for i, image_id in enumerate(image_ids):
                final_caption = captions1[i].strip() if scores1[i] > scores2[i] else captions2[i].strip()
                # The scorers cannot handle an empty hypothesis, so fall back to ".".
                predictions_map[image_id] = final_caption if final_caption else "."
                if image_id not in ground_truths_map: ground_truths_map[image_id] = raw_captions[i]

    # CIDEr / SPICE work on PTB-tokenized captions in COCO format.
    gts_coco_format = {img_id: [{'caption': c} for c in caps] for img_id, caps in ground_truths_map.items()}
    res_coco_format = {img_id: [{'caption': cap}] for img_id, cap in predictions_map.items()}
    ptb_tokenizer = PTBTokenizer()
    gts_tokenized, res_tokenized = ptb_tokenizer.tokenize(gts_coco_format), ptb_tokenizer.tokenize(res_coco_format)
    cider_scorer, spice_scorer = Cider(), Spice()
    cider_score, _ = cider_scorer.compute_score(gts_tokenized, res_tokenized)
    spice_score, _ = spice_scorer.compute_score(gts_tokenized, res_tokenized)

    preds_list, refs_list = list(predictions_map.values()), [ground_truths_map[k] for k in predictions_map.keys()]
    bleu, rouge, meteor = evaluate.load("bleu"), evaluate.load("rouge"), evaluate.load("meteor")
    # B-n is the cumulative BLEU-n score (brevity penalty and geometric mean of the
    # 1..n-gram precisions), not the isolated n-gram precision.
    bleu_by_order = {
        order: bleu.compute(predictions=preds_list, references=refs_list, max_order=order)["bleu"]
        for order in range(1, 5)
    }
    rouge_score = rouge.compute(predictions=preds_list, references=refs_list)
    meteor_score = meteor.compute(predictions=preds_list, references=refs_list)

    def _round3(value):
        return float(f"{value:.3f}")

    results = {
        "B1": _round3(bleu_by_order[1]), "B2": _round3(bleu_by_order[2]),
        "B3": _round3(bleu_by_order[3]), "B4": _round3(bleu_by_order[4]),
        "ROUGE-L": _round3(rouge_score['rougeL']), "METEOR": _round3(meteor_score['meteor']),
        "CIDEr": _round3(cider_score), "SPICE": _round3(spice_score)
    }
    print(f"Results for ENSEMBLE {split_name}: {results}")
    return results



In [None]:
def generate_final_caption_with_ensemble(
    image_path, encoder1, decoder1, mapping1, image_processor1,
    encoder2, decoder2, mapping2, image_processor2, tokenizer, device=None
):
    """Caption one image with both models and keep the higher-scoring candidate.

    Both models run beam search on the image; the caption whose beam score
    (``sequences_scores``) is higher is selected (model 2 wins ties). Both
    candidates and the selection are printed.

    Args:
        image_path: Path of the image file to caption.
        encoder1, decoder1, mapping1, image_processor1: Components of model 1
            (printed as "Swin-T5").
        encoder2, decoder2, mapping2, image_processor2: Components of model 2
            (printed as "ViT-T5").
        tokenizer: Tokenizer used to decode generated token ids.
        device: Device the pixel values are moved to; defaults to the
            module-level ``DEVICE`` when omitted.

    Returns:
        The selected caption string.
    """
    from transformers.modeling_outputs import BaseModelOutput

    if device is None:
        device = DEVICE

    print("\n--- Uncertainty-based Fusion Ensemble Inference ---")
    print(f"Image: {os.path.basename(image_path)}")

    # Context manager releases the file handle as soon as the pixels are converted.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")
    pixel_values1 = image_processor1(image, return_tensors="pt").pixel_values.to(device)
    pixel_values2 = image_processor2(image, return_tensors="pt").pixel_values.to(device)

    for module in (encoder1, decoder1, mapping1, encoder2, decoder2, mapping2):
        module.eval()

    gen_kwargs = {"max_length": MAX_LENGTH, "num_beams": BEAM_WIDTH, "return_dict_in_generate": True, "output_scores": True}

    def _best_candidate(encoder, mapping, decoder, pixel_values):
        """Return (caption, beam score) of the top beam of one model."""
        mapped = mapping(encoder(pixel_values=pixel_values).last_hidden_state)
        output = decoder.generate(encoder_outputs=BaseModelOutput(last_hidden_state=mapped), **gen_kwargs)
        caption = tokenizer.decode(output.sequences[0], skip_special_tokens=True)
        return caption, output.sequences_scores[0].item()

    with torch.no_grad():
        caption1, score1 = _best_candidate(encoder1, mapping1, decoder1, pixel_values1)
        caption2, score2 = _best_candidate(encoder2, mapping2, decoder2, pixel_values2)

    print(f"Swin-T5 Candidate: '{caption1}' (Confidence Score: {score1:.2f})")
    print(f"ViT-T5  Candidate: '{caption2}' (Confidence Score: {score2:.2f})")

    final_caption = caption1 if score1 > score2 else caption2
    # The emoji was stored as mis-decoded UTF-8 bytes; restored to the intended character.
    print(f"\n✨ Final Selected Caption: '{final_caption}'")
    return final_caption




In [None]:
# --- 5. Main Execution ---
if __name__ == '__main__':
    # Seed used for the dataset split, weight initialisation of the mapping
    # layers and DataLoader shuffling, so a fresh "Run All" reproduces the run.
    SEED = 42
    torch.manual_seed(SEED)

    # NLTK resources checked before training starts (presumably required by the
    # METEOR metric loaded during evaluation -- TODO confirm).
    nltk_resources = {'punkt': 'tokenizers/punkt', 'wordnet': 'corpora/wordnet', 'omw-1.4': 'corpora/omw-1.4'}
    try:
        for resource_path in nltk_resources.values():
            nltk.data.find(resource_path)
    except (OSError, LookupError, BadZipFile) as nltk_error:
        print("NLTK data not found, downloading...")
        # Only a corrupted archive justifies wiping the primary data directory;
        # a resource that is merely missing must not delete other installed corpora.
        if isinstance(nltk_error, BadZipFile) and nltk.data.path:
            shutil.rmtree(nltk.data.path[0], ignore_errors=True)
        for package_name in nltk_resources:
            nltk.download(package_name, quiet=True)

    decoder_config = AutoConfig.from_pretrained(DECODER_NAME)
    decoder_config.dropout_rate = T5_DROPOUT_RATE
    tokenizer = AutoTokenizer.from_pretrained(DECODER_NAME)

    # --- MODEL 1: Swin-T5 ---
    print("--- Initializing Model 1: Swin-T5 ---")
    encoder_swin = SwinModel.from_pretrained(ENCODER_SWIN).to(DEVICE)
    decoder_t5_1 = T5ForConditionalGeneration.from_pretrained(DECODER_NAME, config=decoder_config).to(DEVICE)
    mapping_swin = MappingLayer(encoder_swin.config.hidden_size, decoder_t5_1.config.d_model).to(DEVICE)
    image_processor_swin = AutoImageProcessor.from_pretrained(ENCODER_SWIN)

    full_ds_swin = Flickr8kDataset(IMAGES_PATH, CAPTIONS_FILE, image_processor_swin, tokenizer)
    total_size = len(full_ds_swin)
    train_size, val_size = int(0.8 * total_size), int(0.1 * total_size)
    test_size = total_size - train_size - val_size  # remainder, so the three sizes always sum to total_size
    print(f"\n--- Dataset Split ---\nTotal: {total_size}, Train: {train_size}, Val: {val_size}, Test: {test_size}\n")

    # Split plain indices once (with a dedicated seeded generator) so both models
    # train, validate and test on exactly the same images.
    split_generator = torch.Generator().manual_seed(SEED)
    train_split, val_split, test_split = random_split(
        range(total_size), [train_size, val_size, test_size], generator=split_generator
    )
    # The split dataset is range(total_size), so each Subset's `.indices` are the dataset indices themselves.
    train_indices, val_indices, test_indices = train_split.indices, val_split.indices, test_split.indices
    train_ds_swin = torch.utils.data.Subset(full_ds_swin, train_indices)
    val_ds_swin = torch.utils.data.Subset(full_ds_swin, val_indices)
    test_ds_swin = torch.utils.data.Subset(full_ds_swin, test_indices)

    train_loader_swin = DataLoader(train_ds_swin, batch_size=BATCH_SIZE, shuffle=True, collate_fn=custom_collate_fn)
    val_loader_swin = DataLoader(val_ds_swin, batch_size=BATCH_SIZE, collate_fn=custom_collate_fn)
    test_loader_swin = DataLoader(test_ds_swin, batch_size=BATCH_SIZE, collate_fn=custom_collate_fn)

    print("Fine-tuning ALL layers of the Swin encoder.")
    optimizer_swin = AdamW(
        list(encoder_swin.parameters()) + list(decoder_t5_1.parameters()) + list(mapping_swin.parameters()),
        lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY,
    )
    train_model(encoder_swin, decoder_t5_1, mapping_swin, train_loader_swin, val_loader_swin, optimizer_swin, DEVICE, TOTAL_EPOCHS, "Swin-T5")

    # --- MODEL 2: ViT-T5 ---
    print("\n--- Initializing Model 2: ViT-T5 ---")
    encoder_vit = ViTModel.from_pretrained(ENCODER_VIT).to(DEVICE)
    decoder_t5_2 = T5ForConditionalGeneration.from_pretrained(DECODER_NAME, config=decoder_config).to(DEVICE)
    mapping_vit = MappingLayer(encoder_vit.config.hidden_size, decoder_t5_2.config.d_model).to(DEVICE)
    image_processor_vit = AutoImageProcessor.from_pretrained(ENCODER_VIT)

    print("Fine-tuning ALL layers of the ViT encoder.")
    full_ds_vit = Flickr8kDataset(IMAGES_PATH, CAPTIONS_FILE, image_processor_vit, tokenizer)
    train_ds_vit = torch.utils.data.Subset(full_ds_vit, train_indices)
    val_ds_vit = torch.utils.data.Subset(full_ds_vit, val_indices)
    test_ds_vit = torch.utils.data.Subset(full_ds_vit, test_indices)

    train_loader_vit = DataLoader(train_ds_vit, batch_size=BATCH_SIZE, shuffle=True, collate_fn=custom_collate_fn)
    val_loader_vit = DataLoader(val_ds_vit, batch_size=BATCH_SIZE, collate_fn=custom_collate_fn)
    test_loader_vit = DataLoader(test_ds_vit, batch_size=BATCH_SIZE, collate_fn=custom_collate_fn)

    optimizer_vit = AdamW(
        list(encoder_vit.parameters()) + list(decoder_t5_2.parameters()) + list(mapping_vit.parameters()),
        lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY,
    )
    train_model(encoder_vit, decoder_t5_2, mapping_vit, train_loader_vit, val_loader_vit, optimizer_vit, DEVICE, TOTAL_EPOCHS, "ViT-T5")

    # --- Final Ensemble Evaluation ---
    # evaluate_ensemble only reads image ids and raw captions from the batches,
    # so the Swin loaders serve for both models.
    ensemble_components = (
        encoder_swin, decoder_t5_1, mapping_swin, image_processor_swin,
        encoder_vit, decoder_t5_2, mapping_vit, image_processor_vit,
    )
    for split_name, split_loader in (
        ("Train", train_loader_swin), ("Validation", val_loader_swin), ("Test", test_loader_swin)
    ):
        evaluate_ensemble(*ensemble_components, split_loader, tokenizer, DEVICE, split_name)

    # Show the ensemble decision on the first test image (skipped when the test split is empty).
    if test_indices:
        sample_image_path = os.path.join(IMAGES_PATH, full_ds_swin.image_files[test_indices[0]])
        generate_final_caption_with_ensemble(
            sample_image_path,
            encoder_swin, decoder_t5_1, mapping_swin, image_processor_swin,
            encoder_vit, decoder_t5_2, mapping_vit, image_processor_vit,
            tokenizer
        )

