# Cross-Modal Autoencoder for Image-Caption Retrieval

This notebook implements a cross-modal autoencoder where:
- Image latents decode into text feature space
- Text latents decode into image feature space

This forces the shared latent space to capture cross-modal information.

In [1]:
# Imports
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from tqdm import tqdm
from sklearn.metrics.pairwise import cosine_similarity
import matplotlib.pyplot as plt
from PIL import Image
import random

## 1. Load Pre-extracted Features

We'll use the features already extracted from the previous notebook:
- Image features: ResNet50 (2048-dim)
- Caption features: BERT (768-dim)

In [2]:
# Load the pre-split data (pre-extracted features plus, for every caption,
# the row index of its image inside the same split's image-feature matrix).
image_train = np.load("train_image_features.npy")
caption_train = np.load("train_caption_features.npy")
cap2img_train = np.load("train_caption_to_image.npy")

image_val = np.load("val_image_features.npy")
caption_val = np.load("val_caption_features.npy")
cap2img_val = np.load("val_caption_to_image.npy")

image_test = np.load("test_image_features.npy")
caption_test = np.load("test_caption_features.npy")
cap2img_test = np.load("test_caption_to_image.npy")

# Fail fast on misaligned files: later cells index image features with the
# caption->image arrays, so a stale or mismatched file would otherwise only
# surface as an IndexError (or as silently wrong pairs) deep inside training.
for _split, _imgs, _caps, _c2i in (
    ("train", image_train, caption_train, cap2img_train),
    ("val", image_val, caption_val, cap2img_val),
    ("test", image_test, caption_test, cap2img_test),
):
    assert _imgs.ndim == 2 and _caps.ndim == 2, f"{_split}: feature arrays must be 2-D"
    assert len(_caps) == len(_c2i), f"{_split}: one image index is required per caption"
    assert len(_c2i) == 0 or (0 <= _c2i.min() and _c2i.max() < len(_imgs)), (
        f"{_split}: caption-to-image indices fall outside the image feature matrix"
    )

print("Train shapes:", image_train.shape, caption_train.shape, cap2img_train.shape)
print("Val shapes:", image_val.shape, caption_val.shape, cap2img_val.shape)
print("Test shapes:", image_test.shape, caption_test.shape, cap2img_test.shape)

Train shapes: (5663, 2048) (28315, 768) (28315,)
Val shapes: (1214, 2048) (6070, 768) (6070,)
Test shapes: (1214, 2048) (6070, 768) (6070,)


## 2. Normalize Features

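Standardize every feature dimension to zero mean and unit variance, using statistics computed on the training set only so that no validation or test information leaks into preprocessing.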

In [3]:
# Per-dimension statistics, estimated on the training split only.
# The small epsilon keeps the division finite for constant feature dimensions.
img_mean = image_train.mean(axis=0, keepdims=True)
img_std = image_train.std(axis=0, keepdims=True) + 1e-6

txt_mean = caption_train.mean(axis=0, keepdims=True)
txt_std = caption_train.std(axis=0, keepdims=True) + 1e-6


def normalize_images(x):
    """Standardize image features with the training-set image mean/std."""
    centered = x - img_mean
    return centered / img_std


def normalize_texts(x):
    """Standardize caption features with the training-set caption mean/std."""
    centered = x - txt_mean
    return centered / txt_std


# Every split is transformed with the *training* statistics.
image_train, image_val, image_test = (
    normalize_images(split) for split in (image_train, image_val, image_test)
)
caption_train, caption_val, caption_test = (
    normalize_texts(split) for split in (caption_train, caption_val, caption_test)
)

print("Normalization complete!")

Normalization complete!


## 3. Dataset and DataLoader

Each sample pairs one caption with the image it describes. Because every image has several captions, the same image appears in several samples.

In [4]:
class CaptionImagePairedDataset(Dataset):
    """
    Caption-indexed dataset of (image, caption) feature pairs.

    Item ``i`` is the ``i``-th caption together with the image row that
    ``caption_to_image_idx[i]`` points to, so an image is returned once for
    each caption that references it.
    """

    def __init__(self, caption_feats, image_feats, caption_to_image_idx):
        n_captions, n_links = len(caption_feats), len(caption_to_image_idx)
        assert n_captions == n_links
        # astype() copies, so the dataset never aliases the caller's arrays.
        self.caption_feats = caption_feats.astype(np.float32)
        self.image_feats = image_feats.astype(np.float32)
        self.cap2img = caption_to_image_idx.astype(np.int64)

    def __len__(self):
        # One sample per caption.
        return len(self.caption_feats)

    def __getitem__(self, idx):
        caption_tensor = torch.from_numpy(self.caption_feats[idx])
        image_row = self.cap2img[idx]
        image_tensor = torch.from_numpy(self.image_feats[image_row])
        return {"image": image_tensor, "caption": caption_tensor}

## 4. Configuration

In [5]:
# All tunable hyper-parameters and paths live in this one dictionary; it is
# also stored inside every checkpoint written by the training cell.
config = {
    "latent_dim": 512,
    "img_input_dim": 2048,
    "txt_input_dim": 768,
    "img_encoder_hidden": 1024,
    "txt_encoder_hidden": 512,
    "batch_size": 128,
    "lr": 1e-3,
    "weight_decay": 1e-5,
    "epochs": 40,
    "lambda_recon": 1.0,      # Weight for cross-modal reconstruction
    "lambda_contrastive": 1.0,  # Weight for contrastive loss
    "temperature": 0.07,       # Temperature for contrastive loss
    "checkpoint_dir": "./cross_modal_checkpoints",
    "seed": 42,
}

# Seed every RNG this notebook draws from. Python's `random` is included
# because the visualization cell picks its examples with random.sample().
random.seed(config["seed"])
np.random.seed(config["seed"])
torch.manual_seed(config["seed"])

# Create checkpoint directory
os.makedirs(config["checkpoint_dir"], exist_ok=True)

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cpu


## 5. Cross-Modal Autoencoder Architecture

Key idea:
- Image encoder → latent → Text decoder (outputs text features)
- Text encoder → latent → Image decoder (outputs image features)

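Both encoders map into a latent space of the same dimension. Decoding each latent into the *other* modality, together with the contrastive loss, is what pushes the two modalities toward a shared representation.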

In [6]:
class ImageEncoder(nn.Module):
    """Two-layer MLP mapping image features (default 2048-dim) to the shared latent space (default 512-dim)."""

    def __init__(self, input_dim=2048, hidden_dim=1024, latent_dim=512):
        super().__init__()
        # Linear -> ReLU -> Dropout(0.3) -> Linear. The attribute name and the
        # layer positions determine the state_dict keys used by checkpoints.
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, latent_dim),
        ]
        self.encoder = nn.Sequential(*layers)

    def forward(self, x):
        latent = self.encoder(x)
        return latent


class TextEncoder(nn.Module):
    """Two-layer MLP mapping caption features (default 768-dim) to the shared latent space (default 512-dim)."""

    def __init__(self, input_dim=768, hidden_dim=512, latent_dim=512):
        super().__init__()
        # Linear -> ReLU -> Dropout(0.3) -> Linear. The attribute name and the
        # layer positions determine the state_dict keys used by checkpoints.
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, latent_dim),
        ]
        self.encoder = nn.Sequential(*layers)

    def forward(self, x):
        latent = self.encoder(x)
        return latent


class ImageDecoder(nn.Module):
    """Two-layer MLP mapping a latent vector (default 512-dim) to IMAGE feature space (default 2048-dim)."""

    def __init__(self, latent_dim=512, hidden_dim=1024, output_dim=2048):
        super().__init__()
        # Linear -> ReLU -> Linear (no dropout on the decoding side).
        layers = [
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.decoder = nn.Sequential(*layers)

    def forward(self, z):
        reconstruction = self.decoder(z)
        return reconstruction


class TextDecoder(nn.Module):
    """Two-layer MLP mapping a latent vector (default 512-dim) to TEXT feature space (default 768-dim)."""

    def __init__(self, latent_dim=512, hidden_dim=512, output_dim=768):
        super().__init__()
        # Linear -> ReLU -> Linear (no dropout on the decoding side).
        layers = [
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.decoder = nn.Sequential(*layers)

    def forward(self, z):
        reconstruction = self.decoder(z)
        return reconstruction


print("Model architecture defined!")

Model architecture defined!


## 6. Loss Functions

In [7]:
def contrastive_loss(z_img, z_txt, temperature=0.07, image_ids=None):
    """
    Symmetric contrastive (InfoNCE-style) loss aligning image and text embeddings.

    Row ``i`` of ``z_img`` and row ``i`` of ``z_txt`` form the positive pair;
    every other row in the batch is treated as a negative.

    Args:
        z_img: (B, D) image latents.
        z_txt: (B, D) text latents.
        temperature: Divisor applied to the cosine-similarity logits.
        image_ids: Optional length-B sequence/tensor identifying the source
            image of each row. A caption-indexed batch can contain several
            captions of the same image; without this argument those pairs are
            penalized as negatives although they are valid matches. When given,
            off-diagonal pairs that share an image id are excluded from the
            softmax. ``None`` (default) keeps the plain in-batch behavior.

    Returns:
        Scalar tensor: mean of the image->text and text->image cross-entropies.

    Raises:
        ValueError: If ``image_ids`` does not contain exactly B entries.
    """
    # Normalize embeddings so the dot product is a cosine similarity
    z_img_norm = F.normalize(z_img, dim=1)
    z_txt_norm = F.normalize(z_txt, dim=1)

    # Compute similarity matrix: logits[i, j] = cos(img_i, txt_j) / temperature
    logits = torch.matmul(z_img_norm, z_txt_norm.T) / temperature
    batch_size = z_img.size(0)

    if image_ids is not None:
        ids = torch.as_tensor(image_ids, device=logits.device).reshape(-1)
        if ids.numel() != batch_size:
            raise ValueError(
                f"image_ids has {ids.numel()} entries but the batch has {batch_size} rows"
            )
        same_image = ids.unsqueeze(0) == ids.unsqueeze(1)
        off_diagonal = ~torch.eye(batch_size, dtype=torch.bool, device=logits.device)
        # -inf logits contribute exp(-inf) = 0 to the softmax denominator.
        # The mask is symmetric, so it is also valid for logits.T below.
        logits = logits.masked_fill(same_image & off_diagonal, float("-inf"))

    # Labels: diagonal elements are positive pairs (created directly on the
    # target device instead of on the CPU followed by a copy)
    labels = torch.arange(batch_size, device=z_img.device)

    # Cross-entropy in both directions
    loss_i2t = F.cross_entropy(logits, labels)
    loss_t2i = F.cross_entropy(logits.T, labels)

    return (loss_i2t + loss_t2i) / 2


# Reconstruction loss
recon_loss_fn = nn.MSELoss()

print("Loss functions defined!")

Loss functions defined!


## 7. Initialize Models and Optimizer

In [8]:
# Build the four networks. The construction order (image encoder, text encoder,
# image decoder, text decoder) is kept so seeded weight initialization is unchanged.
img_encoder = ImageEncoder(
    config["img_input_dim"], config["img_encoder_hidden"], config["latent_dim"]
).to(device)

txt_encoder = TextEncoder(
    config["txt_input_dim"], config["txt_encoder_hidden"], config["latent_dim"]
).to(device)

img_decoder = ImageDecoder(
    config["latent_dim"], config["img_encoder_hidden"], config["img_input_dim"]
).to(device)

txt_decoder = TextDecoder(
    config["latent_dim"], config["txt_encoder_hidden"], config["txt_input_dim"]
).to(device)

# A single Adam optimizer updates every parameter of all four networks.
params = [
    p
    for module in (img_encoder, txt_encoder, img_decoder, txt_decoder)
    for p in module.parameters()
]
optimizer = Adam(params, lr=config["lr"], weight_decay=config["weight_decay"])

print(f"Models initialized with {sum(p.numel() for p in params):,} total parameters")

  from .autonotebook import tqdm as notebook_tqdm


Models initialized with 6,560,512 total parameters


## 8. Create DataLoaders

In [9]:
# Caption-indexed datasets: one sample per caption, paired with its image.
train_dataset = CaptionImagePairedDataset(caption_train, image_train, cap2img_train)
val_dataset = CaptionImagePairedDataset(caption_val, image_val, cap2img_val)

# Training batches are reshuffled every epoch; validation keeps file order.
train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=config["batch_size"], shuffle=False, num_workers=0)

print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")

Train batches: 222, Val batches: 48


## 9. Training Loop

Cross-modal reconstruction:
- Image → img_encoder → z_img → txt_decoder → reconstructed text features
- Text → txt_encoder → z_txt → img_decoder → reconstructed image features

Plus contrastive loss to align z_img and z_txt

In [10]:
def run_epoch(loader, training=True):
    """
    Run one pass over ``loader`` and return sample-weighted average losses.

    With ``training`` truthy the four networks are put in train mode and the
    optimizer takes one step per batch; otherwise they are put in eval mode and
    gradients are disabled.

    Returns a dict with keys "loss", "L_i2t", "L_t2i" and "L_contrast".
    """
    # eval() is train(False), so a single call covers both modes.
    for module in (img_encoder, txt_encoder, img_decoder, txt_decoder):
        module.train(bool(training))

    # Running sums, each weighted by the number of samples in the batch.
    sums = {"loss": 0.0, "L_i2t": 0.0, "L_t2i": 0.0, "L_contrast": 0.0}
    seen = 0

    progress = tqdm(loader, desc="train" if training else "val")
    with torch.set_grad_enabled(training):
        for batch in progress:
            imgs = batch["image"].to(device)    # (B, 2048)
            caps = batch["caption"].to(device)  # (B, 768)
            n_in_batch = imgs.shape[0]

            # Encode each modality, then decode into the *other* modality.
            z_img, z_txt = img_encoder(imgs), txt_encoder(caps)
            txt_recon, img_recon = txt_decoder(z_img), img_decoder(z_txt)

            # Cross-modal reconstruction terms and latent alignment term.
            L_i2t = recon_loss_fn(txt_recon, caps)
            L_t2i = recon_loss_fn(img_recon, imgs)
            L_contrast = contrastive_loss(z_img, z_txt, temperature=config["temperature"])

            loss = (
                config["lambda_recon"] * (L_i2t + L_t2i) +
                config["lambda_contrastive"] * L_contrast
            )

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Accumulate metrics as Python floats.
            batch_values = {"loss": loss, "L_i2t": L_i2t, "L_t2i": L_t2i, "L_contrast": L_contrast}
            for key, value in batch_values.items():
                sums[key] += value.item() * n_in_batch
            seen += n_in_batch

            progress.set_postfix({
                "loss": f"{sums['loss'] / seen:.4f}",
                "i2t": f"{sums['L_i2t'] / seen:.4f}",
                "t2i": f"{sums['L_t2i'] / seen:.4f}",
                "contrast": f"{sums['L_contrast'] / seen:.4f}"
            })

    return {
        "loss": sums["loss"] / seen,
        "L_i2t": sums["L_i2t"] / seen,
        "L_t2i": sums["L_t2i"] / seen,
        "L_contrast": sums["L_contrast"] / seen
    }

## 10. Train the Model

In [11]:
# Lowest validation loss seen so far; decides when the "best" checkpoint is overwritten.
best_val_loss = float("inf")

for epoch in range(1, config["epochs"] + 1):
    print(f"\n=== Epoch {epoch}/{config['epochs']} ===")
    train_metrics = run_epoch(train_loader, training=True)
    val_metrics = run_epoch(val_loader, training=False)

    print(f"Train loss: {train_metrics['loss']:.4f} | Val loss: {val_metrics['loss']:.4f}")

    # Full snapshot (weights, optimizer state, metrics, config), written once per epoch.
    ckpt = {
        "epoch": epoch,
        "img_encoder": img_encoder.state_dict(),
        "txt_encoder": txt_encoder.state_dict(),
        "img_decoder": img_decoder.state_dict(),
        "txt_decoder": txt_decoder.state_dict(),
        "optimizer": optimizer.state_dict(),
        "train_metrics": train_metrics,
        "val_metrics": val_metrics,
        "config": config
    }
    ckpt_path = os.path.join(config["checkpoint_dir"], f"cross_modal_epoch{epoch}.pt")
    torch.save(ckpt, ckpt_path)

    # The same snapshot also becomes the "best" checkpoint on a strict improvement.
    if best_val_loss > val_metrics["loss"]:
        best_val_loss = val_metrics["loss"]
        torch.save(ckpt, os.path.join(config["checkpoint_dir"], "cross_modal_best.pt"))
        print("Saved best checkpoint.")

print("Training finished!")


=== Epoch 1/40 ===


train: 100%|██████████| 222/222 [00:20<00:00, 10.58it/s, loss=4.5261, i2t=0.9371, t2i=0.8927, contrast=2.6962]
val: 100%|██████████| 48/48 [00:01<00:00, 42.69it/s, loss=4.6912, i2t=0.9260, t2i=0.8727, contrast=2.8926]


Train loss: 4.5261 | Val loss: 4.6912
Saved best checkpoint.

=== Epoch 2/40 ===


train: 100%|██████████| 222/222 [00:20<00:00, 10.77it/s, loss=3.6650, i2t=0.9141, t2i=0.8541, contrast=1.8968]
val: 100%|██████████| 48/48 [00:00<00:00, 57.95it/s, loss=4.6332, i2t=0.9258, t2i=0.8605, contrast=2.8469]


Train loss: 3.6650 | Val loss: 4.6332
Saved best checkpoint.

=== Epoch 3/40 ===


train: 100%|██████████| 222/222 [00:20<00:00, 10.58it/s, loss=3.2950, i2t=0.9044, t2i=0.8403, contrast=1.5503]
val: 100%|██████████| 48/48 [00:00<00:00, 56.06it/s, loss=4.5974, i2t=0.9203, t2i=0.8513, contrast=2.8258]


Train loss: 3.2950 | Val loss: 4.5974
Saved best checkpoint.

=== Epoch 4/40 ===


train: 100%|██████████| 222/222 [00:19<00:00, 11.16it/s, loss=3.0624, i2t=0.8980, t2i=0.8295, contrast=1.3349]
val: 100%|██████████| 48/48 [00:00<00:00, 54.91it/s, loss=4.6093, i2t=0.9205, t2i=0.8496, contrast=2.8392]


Train loss: 3.0624 | Val loss: 4.6093

=== Epoch 5/40 ===


train: 100%|██████████| 222/222 [00:19<00:00, 11.20it/s, loss=2.9081, i2t=0.8932, t2i=0.8227, contrast=1.1922]
val: 100%|██████████| 48/48 [00:00<00:00, 55.87it/s, loss=4.5938, i2t=0.9219, t2i=0.8471, contrast=2.8249]


Train loss: 2.9081 | Val loss: 4.5938
Saved best checkpoint.

=== Epoch 6/40 ===


train: 100%|██████████| 222/222 [00:20<00:00, 10.70it/s, loss=2.7760, i2t=0.8899, t2i=0.8177, contrast=1.0684]
val: 100%|██████████| 48/48 [00:00<00:00, 54.44it/s, loss=4.6020, i2t=0.9249, t2i=0.8463, contrast=2.8308]


Train loss: 2.7760 | Val loss: 4.6020

=== Epoch 7/40 ===


train: 100%|██████████| 222/222 [00:19<00:00, 11.22it/s, loss=2.6814, i2t=0.8876, t2i=0.8130, contrast=0.9808]
val: 100%|██████████| 48/48 [00:00<00:00, 58.74it/s, loss=4.6061, i2t=0.9270, t2i=0.8435, contrast=2.8356]


Train loss: 2.6814 | Val loss: 4.6061

=== Epoch 8/40 ===


train: 100%|██████████| 222/222 [00:18<00:00, 11.71it/s, loss=2.6092, i2t=0.8858, t2i=0.8097, contrast=0.9137]
val: 100%|██████████| 48/48 [00:00<00:00, 57.98it/s, loss=4.5893, i2t=0.9232, t2i=0.8426, contrast=2.8236]


Train loss: 2.6092 | Val loss: 4.5893
Saved best checkpoint.

=== Epoch 9/40 ===


train: 100%|██████████| 222/222 [00:18<00:00, 11.81it/s, loss=2.5446, i2t=0.8841, t2i=0.8054, contrast=0.8551]
val: 100%|██████████| 48/48 [00:00<00:00, 53.80it/s, loss=4.5992, i2t=0.9232, t2i=0.8396, contrast=2.8365]


Train loss: 2.5446 | Val loss: 4.5992

=== Epoch 10/40 ===


train: 100%|██████████| 222/222 [00:19<00:00, 11.62it/s, loss=2.4956, i2t=0.8823, t2i=0.8030, contrast=0.8103]
val: 100%|██████████| 48/48 [00:00<00:00, 56.84it/s, loss=4.6096, i2t=0.9228, t2i=0.8405, contrast=2.8464]


Train loss: 2.4956 | Val loss: 4.6096

=== Epoch 11/40 ===


train: 100%|██████████| 222/222 [00:19<00:00, 11.43it/s, loss=2.4505, i2t=0.8813, t2i=0.7995, contrast=0.7697]
val: 100%|██████████| 48/48 [00:00<00:00, 56.69it/s, loss=4.6146, i2t=0.9220, t2i=0.8423, contrast=2.8504]


Train loss: 2.4505 | Val loss: 4.6146

=== Epoch 12/40 ===


train: 100%|██████████| 222/222 [00:19<00:00, 11.53it/s, loss=2.4070, i2t=0.8801, t2i=0.7970, contrast=0.7300]
val: 100%|██████████| 48/48 [00:00<00:00, 57.50it/s, loss=4.5835, i2t=0.9251, t2i=0.8405, contrast=2.8178]


Train loss: 2.4070 | Val loss: 4.5835
Saved best checkpoint.

=== Epoch 13/40 ===


train: 100%|██████████| 222/222 [00:19<00:00, 11.48it/s, loss=2.3720, i2t=0.8787, t2i=0.7942, contrast=0.6990]
val: 100%|██████████| 48/48 [00:00<00:00, 56.57it/s, loss=4.6264, i2t=0.9245, t2i=0.8432, contrast=2.8588]


Train loss: 2.3720 | Val loss: 4.6264

=== Epoch 14/40 ===


train: 100%|██████████| 222/222 [00:19<00:00, 11.43it/s, loss=2.3443, i2t=0.8781, t2i=0.7919, contrast=0.6743]
val: 100%|██████████| 48/48 [00:00<00:00, 53.69it/s, loss=4.6065, i2t=0.9249, t2i=0.8398, contrast=2.8419]


Train loss: 2.3443 | Val loss: 4.6065

=== Epoch 15/40 ===


train: 100%|██████████| 222/222 [00:20<00:00, 10.91it/s, loss=2.3100, i2t=0.8770, t2i=0.7896, contrast=0.6433]
val: 100%|██████████| 48/48 [00:00<00:00, 56.96it/s, loss=4.6115, i2t=0.9247, t2i=0.8408, contrast=2.8460]


Train loss: 2.3100 | Val loss: 4.6115

=== Epoch 16/40 ===


train: 100%|██████████| 222/222 [00:19<00:00, 11.24it/s, loss=2.2812, i2t=0.8765, t2i=0.7881, contrast=0.6166]
val: 100%|██████████| 48/48 [00:00<00:00, 53.85it/s, loss=4.6248, i2t=0.9262, t2i=0.8403, contrast=2.8582]


Train loss: 2.2812 | Val loss: 4.6248

=== Epoch 17/40 ===


train: 100%|██████████| 222/222 [00:19<00:00, 11.65it/s, loss=2.2604, i2t=0.8750, t2i=0.7850, contrast=0.6004]
val: 100%|██████████| 48/48 [00:00<00:00, 54.45it/s, loss=4.6035, i2t=0.9250, t2i=0.8401, contrast=2.8384]


Train loss: 2.2604 | Val loss: 4.6035

=== Epoch 18/40 ===


train: 100%|██████████| 222/222 [00:19<00:00, 11.57it/s, loss=2.2440, i2t=0.8751, t2i=0.7841, contrast=0.5848]
val: 100%|██████████| 48/48 [00:00<00:00, 54.63it/s, loss=4.5950, i2t=0.9253, t2i=0.8424, contrast=2.8273]


Train loss: 2.2440 | Val loss: 4.5950

=== Epoch 19/40 ===


train: 100%|██████████| 222/222 [00:19<00:00, 11.61it/s, loss=2.2208, i2t=0.8737, t2i=0.7824, contrast=0.5647]
val: 100%|██████████| 48/48 [00:00<00:00, 53.50it/s, loss=4.6107, i2t=0.9254, t2i=0.8392, contrast=2.8460]


Train loss: 2.2208 | Val loss: 4.6107

=== Epoch 20/40 ===


train: 100%|██████████| 222/222 [00:20<00:00, 11.01it/s, loss=2.2059, i2t=0.8730, t2i=0.7802, contrast=0.5527]
val: 100%|██████████| 48/48 [00:00<00:00, 58.80it/s, loss=4.6005, i2t=0.9263, t2i=0.8394, contrast=2.8347]


Train loss: 2.2059 | Val loss: 4.6005

=== Epoch 21/40 ===


train: 100%|██████████| 222/222 [00:18<00:00, 11.88it/s, loss=2.1927, i2t=0.8728, t2i=0.7784, contrast=0.5414]
val: 100%|██████████| 48/48 [00:00<00:00, 54.58it/s, loss=4.6181, i2t=0.9251, t2i=0.8415, contrast=2.8515]


Train loss: 2.1927 | Val loss: 4.6181

=== Epoch 22/40 ===


train: 100%|██████████| 222/222 [00:19<00:00, 11.56it/s, loss=2.1754, i2t=0.8723, t2i=0.7770, contrast=0.5260]
val: 100%|██████████| 48/48 [00:00<00:00, 54.38it/s, loss=4.6254, i2t=0.9268, t2i=0.8404, contrast=2.8581]


Train loss: 2.1754 | Val loss: 4.6254

=== Epoch 23/40 ===


train: 100%|██████████| 222/222 [00:18<00:00, 11.87it/s, loss=2.1590, i2t=0.8709, t2i=0.7747, contrast=0.5133]
val: 100%|██████████| 48/48 [00:00<00:00, 55.57it/s, loss=4.6273, i2t=0.9292, t2i=0.8392, contrast=2.8590]


Train loss: 2.1590 | Val loss: 4.6273

=== Epoch 24/40 ===


train: 100%|██████████| 222/222 [00:18<00:00, 11.91it/s, loss=2.1476, i2t=0.8709, t2i=0.7742, contrast=0.5025]
val: 100%|██████████| 48/48 [00:00<00:00, 58.81it/s, loss=4.6232, i2t=0.9269, t2i=0.8501, contrast=2.8462]


Train loss: 2.1476 | Val loss: 4.6232

=== Epoch 25/40 ===


train: 100%|██████████| 222/222 [00:18<00:00, 11.95it/s, loss=2.1310, i2t=0.8702, t2i=0.7710, contrast=0.4898]
val: 100%|██████████| 48/48 [00:00<00:00, 58.48it/s, loss=4.6205, i2t=0.9282, t2i=0.8404, contrast=2.8519]


Train loss: 2.1310 | Val loss: 4.6205

=== Epoch 26/40 ===


train: 100%|██████████| 222/222 [00:18<00:00, 11.82it/s, loss=2.1261, i2t=0.8697, t2i=0.7717, contrast=0.4847]
val: 100%|██████████| 48/48 [00:00<00:00, 58.07it/s, loss=4.6293, i2t=0.9268, t2i=0.8413, contrast=2.8612]


Train loss: 2.1261 | Val loss: 4.6293

=== Epoch 27/40 ===


train: 100%|██████████| 222/222 [00:18<00:00, 11.72it/s, loss=2.1080, i2t=0.8702, t2i=0.7688, contrast=0.4690]
val: 100%|██████████| 48/48 [00:00<00:00, 56.47it/s, loss=4.6160, i2t=0.9251, t2i=0.8383, contrast=2.8526]


Train loss: 2.1080 | Val loss: 4.6160

=== Epoch 28/40 ===


train: 100%|██████████| 222/222 [00:18<00:00, 11.76it/s, loss=2.0934, i2t=0.8689, t2i=0.7675, contrast=0.4569]
val: 100%|██████████| 48/48 [00:00<00:00, 56.80it/s, loss=4.6373, i2t=0.9277, t2i=0.8421, contrast=2.8674]


Train loss: 2.0934 | Val loss: 4.6373

=== Epoch 29/40 ===


train: 100%|██████████| 222/222 [00:18<00:00, 11.80it/s, loss=2.0866, i2t=0.8680, t2i=0.7662, contrast=0.4524]
val: 100%|██████████| 48/48 [00:00<00:00, 57.52it/s, loss=4.6144, i2t=0.9247, t2i=0.8398, contrast=2.8499]


Train loss: 2.0866 | Val loss: 4.6144

=== Epoch 30/40 ===


train: 100%|██████████| 222/222 [00:18<00:00, 11.86it/s, loss=2.0766, i2t=0.8675, t2i=0.7655, contrast=0.4437]
val: 100%|██████████| 48/48 [00:00<00:00, 58.84it/s, loss=4.6304, i2t=0.9261, t2i=0.8419, contrast=2.8624]


Train loss: 2.0766 | Val loss: 4.6304

=== Epoch 31/40 ===


train: 100%|██████████| 222/222 [00:18<00:00, 11.96it/s, loss=2.0636, i2t=0.8674, t2i=0.7620, contrast=0.4343]
val: 100%|██████████| 48/48 [00:00<00:00, 59.00it/s, loss=4.6456, i2t=0.9230, t2i=0.8456, contrast=2.8769]


Train loss: 2.0636 | Val loss: 4.6456

=== Epoch 32/40 ===


train: 100%|██████████| 222/222 [00:18<00:00, 11.86it/s, loss=2.0605, i2t=0.8663, t2i=0.7620, contrast=0.4323]
val: 100%|██████████| 48/48 [00:00<00:00, 59.67it/s, loss=4.6425, i2t=0.9288, t2i=0.8408, contrast=2.8729]


Train loss: 2.0605 | Val loss: 4.6425

=== Epoch 33/40 ===


train: 100%|██████████| 222/222 [00:19<00:00, 11.66it/s, loss=2.0443, i2t=0.8658, t2i=0.7599, contrast=0.4186]
val: 100%|██████████| 48/48 [00:00<00:00, 54.43it/s, loss=4.6322, i2t=0.9260, t2i=0.8419, contrast=2.8643]


Train loss: 2.0443 | Val loss: 4.6322

=== Epoch 34/40 ===


train: 100%|██████████| 222/222 [00:18<00:00, 11.82it/s, loss=2.0386, i2t=0.8655, t2i=0.7593, contrast=0.4138]
val: 100%|██████████| 48/48 [00:00<00:00, 59.19it/s, loss=4.6434, i2t=0.9288, t2i=0.8487, contrast=2.8659]


Train loss: 2.0386 | Val loss: 4.6434

=== Epoch 35/40 ===


train: 100%|██████████| 222/222 [00:20<00:00, 10.90it/s, loss=2.0313, i2t=0.8656, t2i=0.7569, contrast=0.4087]
val: 100%|██████████| 48/48 [00:00<00:00, 52.86it/s, loss=4.6449, i2t=0.9278, t2i=0.8429, contrast=2.8742]


Train loss: 2.0313 | Val loss: 4.6449

=== Epoch 36/40 ===


train: 100%|██████████| 222/222 [00:21<00:00, 10.51it/s, loss=2.0229, i2t=0.8649, t2i=0.7562, contrast=0.4018]
val: 100%|██████████| 48/48 [00:00<00:00, 56.70it/s, loss=4.6356, i2t=0.9288, t2i=0.8381, contrast=2.8687]


Train loss: 2.0229 | Val loss: 4.6356

=== Epoch 37/40 ===


train: 100%|██████████| 222/222 [00:19<00:00, 11.44it/s, loss=2.0208, i2t=0.8645, t2i=0.7559, contrast=0.4005]
val: 100%|██████████| 48/48 [00:00<00:00, 55.01it/s, loss=4.6364, i2t=0.9276, t2i=0.8405, contrast=2.8683]


Train loss: 2.0208 | Val loss: 4.6364

=== Epoch 38/40 ===


train: 100%|██████████| 222/222 [00:19<00:00, 11.61it/s, loss=2.0076, i2t=0.8640, t2i=0.7542, contrast=0.3894]
val: 100%|██████████| 48/48 [00:00<00:00, 56.78it/s, loss=4.6464, i2t=0.9346, t2i=0.8403, contrast=2.8714]


Train loss: 2.0076 | Val loss: 4.6464

=== Epoch 39/40 ===


train: 100%|██████████| 222/222 [00:18<00:00, 11.77it/s, loss=1.9990, i2t=0.8639, t2i=0.7531, contrast=0.3821]
val: 100%|██████████| 48/48 [00:00<00:00, 52.79it/s, loss=4.6359, i2t=0.9281, t2i=0.8413, contrast=2.8664]


Train loss: 1.9990 | Val loss: 4.6359

=== Epoch 40/40 ===


train: 100%|██████████| 222/222 [00:18<00:00, 11.81it/s, loss=1.9940, i2t=0.8622, t2i=0.7516, contrast=0.3801]
val: 100%|██████████| 48/48 [00:00<00:00, 55.76it/s, loss=4.6306, i2t=0.9277, t2i=0.8418, contrast=2.8612]


Train loss: 1.9940 | Val loss: 4.6306
Training finished!


## 11. Evaluation: Load Best Model and Compute Retrieval Metrics

In [12]:
# Restore the encoder weights from the best-validation checkpoint.
# Only the two encoders are restored here; the decoders are left as they are.
best_ckpt_path = os.path.join(config["checkpoint_dir"], "cross_modal_best.pt")
ckpt = torch.load(best_ckpt_path, map_location=device)

for module, state_key in ((img_encoder, "img_encoder"), (txt_encoder, "txt_encoder")):
    module.load_state_dict(ckpt[state_key])
    module.eval()  # inference mode: disables dropout

print(f"Loaded best checkpoint from epoch {ckpt['epoch']}")

Loaded best checkpoint from epoch 12


In [13]:
def encode_features(image_feats, caption_feats, batch_size=256):
    """
    Encode image and caption features into the shared latent space.

    Args:
        image_feats: (N_img, D_img) NumPy array of image features.
        caption_feats: (N_cap, D_txt) NumPy array of caption features.
        batch_size: Number of rows sent through an encoder at a time.

    Returns:
        Tuple ``(Z_imgs, Z_caps)`` of NumPy arrays holding the latents produced
        by ``img_encoder`` and ``txt_encoder`` respectively.

    The encoders are switched to eval mode for the duration of the call (so
    dropout cannot perturb the embeddings, whatever mode an earlier cell left
    them in) and are restored to their previous mode afterwards.
    """
    def _encode(encoder, feats):
        # Batched forward pass of one encoder over one feature matrix.
        was_training = encoder.training
        encoder.eval()
        try:
            chunks = []
            # max(..., 1) yields one (empty) batch for empty input, so the
            # result is a (0, latent_dim) array instead of a concatenate error.
            for start in range(0, max(feats.shape[0], 1), batch_size):
                batch = torch.from_numpy(feats[start:start + batch_size]).float().to(device)
                chunks.append(encoder(batch).cpu().numpy())
        finally:
            encoder.train(was_training)
        return np.concatenate(chunks, axis=0)

    with torch.no_grad():
        Z_imgs = _encode(img_encoder, image_feats)
        Z_caps = _encode(txt_encoder, caption_feats)

    return Z_imgs, Z_caps

In [14]:
def retrieval_metrics(Z_caps, Z_imgs, caption_to_image_idx):
    """
    Caption -> image retrieval quality.

    Every caption ranks all images by cosine similarity. The 1-based position
    of its ground-truth image is its rank, from which Recall@1/5/10 and the
    median rank are derived.
    """
    similarity = cosine_similarity(Z_caps, Z_imgs)  # (num_caps, num_imgs)

    # 1-based rank of the ground-truth image for each caption.
    ranks = np.array([
        np.where(np.argsort(-similarity[cap_idx]) == target_img)[0][0] + 1
        for cap_idx, target_img in enumerate(caption_to_image_idx)
    ])

    def recall_at(k):
        # Fraction of captions whose true image is among the top k.
        return np.mean(ranks <= k)

    return {
        "Recall@1": recall_at(1),
        "Recall@5": recall_at(5),
        "Recall@10": recall_at(10),
        "MedianRank": np.median(ranks)
    }

### Validation Set Metrics

In [15]:
# Embed the validation split, then score caption -> image retrieval on it.
print("Encoding validation set...")
Z_imgs_val, Z_caps_val = encode_features(image_val, caption_val)
print(f"Encoded shapes: {Z_imgs_val.shape}, {Z_caps_val.shape}")

metrics_val = retrieval_metrics(Z_caps_val, Z_imgs_val, cap2img_val)
print("\nValidation Set Metrics:")
for metric_name, metric_value in metrics_val.items():
    print(f"  {metric_name}: {metric_value:.4f}")

Encoding validation set...
Encoded shapes: (1214, 512), (6070, 512)

Validation Set Metrics:
  Recall@1: 0.1333
  Recall@5: 0.3629
  Recall@10: 0.4906
  MedianRank: 11.0000


### Test Set Metrics

In [16]:
# Embed the held-out test split, then score caption -> image retrieval on it.
print("Encoding test set...")
Z_imgs_test, Z_caps_test = encode_features(image_test, caption_test)
print(f"Encoded shapes: {Z_imgs_test.shape}, {Z_caps_test.shape}")

metrics_test = retrieval_metrics(Z_caps_test, Z_imgs_test, cap2img_test)
print("\nTest Set Metrics:")
for metric_name, metric_value in metrics_test.items():
    print(f"  {metric_name}: {metric_value:.4f}")

Encoding test set...
Encoded shapes: (1214, 512), (6070, 512)

Test Set Metrics:
  Recall@1: 0.1336
  Recall@5: 0.3718
  Recall@10: 0.5033
  MedianRank: 10.0000


## 12. Visualization (Optional)

Visualize some retrieval results

In [None]:
# Load additional data for visualization
import pandas as pd
from sklearn.model_selection import train_test_split

# Load original dataframe and metadata
df = pd.read_csv('./flickr8k_data/captions.txt')
image_names = np.load('flickr8k_image_names.npy')

# Recompute the train/val/test image split.
# NOTE(review): this presumably mirrors the split made in the feature-extraction
# notebook (same sizes and random_state); confirm there. The assertions at the
# end only verify that the resulting sizes agree with the loaded features.
n_images = len(image_names)
indices = np.arange(n_images)
train_idx, temp_idx = train_test_split(indices, test_size=0.30, random_state=42, shuffle=True)
val_idx, test_idx = train_test_split(temp_idx, test_size=0.5, random_state=42, shuffle=True)

# Map every caption row to the global index of its image. An unmapped name
# would become NaN and then be cast to a meaningless integer, so reject it here.
name_to_idx = {name: i for i, name in enumerate(image_names)}
mapped_idx = df["image"].map(name_to_idx)
assert mapped_idx.notna().all(), "captions.txt references images missing from flickr8k_image_names.npy"
caption_to_image_idx = mapped_idx.values.astype(int)
val_mask = np.isin(caption_to_image_idx, val_idx)

# Get validation caption texts and image names
val_caption_texts = df["caption"].values[val_mask]
val_image_names = np.array([image_names[i] for i in val_idx])

# The visualization indexes these arrays with positions taken from the loaded
# validation features, so their sizes must agree.
assert len(val_caption_texts) == len(caption_val), "validation caption texts do not match caption_val"
assert len(val_image_names) == len(image_val), "validation image names do not match image_val"

print(f"Loaded {len(val_caption_texts)} validation captions")

In [None]:
image_dir = "./flickr8k_data/Images"


def _draw_image_file(ax, img_name, failure_label=None):
    """Draw ``image_dir/img_name`` on ``ax``; print a message and return False if it cannot be read."""
    img_path = os.path.join(image_dir, img_name)
    try:
        # The context manager closes the file handle once imshow has copied the pixels.
        with Image.open(img_path) as img:
            ax.imshow(img)
    except OSError as e:  # covers missing, unreadable and unidentifiable files
        print(f"Could not open {failure_label or img_path}: {e}")
        return False
    return True


def show_top_images_for_caption(caption_idx, top_k=5):
    """
    Show top-k retrieved validation images for a given caption index.

    The first panel shows the ground-truth image; the following panels show
    the retrieved images in rank order, with the correct one (if present)
    marked. Panels whose image file cannot be opened are left blank.

    Args:
        caption_idx: Row index into the validation caption embeddings.
        top_k: Number of retrieved images to display.
    """
    caption_embedding = Z_caps_val[caption_idx].reshape(1, -1)
    sims = cosine_similarity(caption_embedding, Z_imgs_val)[0]
    top_img_indices = np.argsort(-sims)[:top_k]

    true_img_idx = cap2img_val[caption_idx]
    true_img_name = val_image_names[true_img_idx]
    print(f"\nCAPTION: {val_caption_texts[caption_idx]}")
    print(f"TRUE IMAGE: {true_img_name} (index {true_img_idx})")

    # squeeze=False always returns a 2-D axes array, even when top_k == 0.
    fig, axes = plt.subplots(1, top_k + 1, figsize=(18, 4), squeeze=False)
    axes = axes[0]
    for ax in axes:
        ax.axis('off')

    # Show true image
    if _draw_image_file(axes[0], true_img_name, failure_label="true image"):
        axes[0].set_title("TRUE IMAGE", fontweight='bold', color='green')

    # Show retrieved images
    for i, img_idx in enumerate(top_img_indices):
        ax = axes[i + 1]
        if not _draw_image_file(ax, val_image_names[img_idx]):
            continue

        if img_idx == true_img_idx:
            ax.set_title(f"Rank {i+1} ✓", fontweight='bold', color='green')
        else:
            ax.set_title(f"Rank {i+1}")

    fig.tight_layout()
    plt.show()

# Display retrieval results for three validation captions drawn at random.
print("\n=== Retrieval Examples ===")
for caption_idx in random.sample(range(len(Z_caps_val)), 3):
    show_top_images_for_caption(caption_idx, top_k=5)