# Self-Reconstruction Multi-Modal Autoencoder

This notebook implements the original architecture where:
- Image features are reconstructed by the image autoencoder
- Text features are reconstructed by the text autoencoder  
- Latent spaces are aligned using **either MSE or contrastive loss**

This forces the shared latent space to capture aligned representations.

**Prerequisites:** Run `Feature_Extraction_Batch.ipynb` first to generate the required .npy files.

In [None]:
# Imports
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from tqdm import tqdm
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from PIL import Image
import pandas as pd
import random

## 1. Load Pre-extracted Features

We'll use the features already extracted by the `Feature_Extraction_Batch.ipynb` notebook:
- Image features: ResNet50 (2048-dim)
- Caption features: BERT (768-dim)

In [None]:
# Read the pre-split arrays from the working directory, one split at a time.
# Every split provides three files: image features, caption features, and the
# caption -> image index array used to pair each caption with an image row.
image_train, caption_train, cap2img_train = (
    np.load("train_image_features.npy"),
    np.load("train_caption_features.npy"),
    np.load("train_caption_to_image.npy"),
)

image_val, caption_val, cap2img_val = (
    np.load("val_image_features.npy"),
    np.load("val_caption_features.npy"),
    np.load("val_caption_to_image.npy"),
)

image_test, caption_test, cap2img_test = (
    np.load("test_image_features.npy"),
    np.load("test_caption_features.npy"),
    np.load("test_caption_to_image.npy"),
)

# Quick sanity check of what was loaded.
print("Train shapes:", image_train.shape, caption_train.shape, cap2img_train.shape)
print("Val shapes:", image_val.shape, caption_val.shape, cap2img_val.shape)
print("Test shapes:", image_test.shape, caption_test.shape, cap2img_test.shape)

## 2. Normalize Features

Normalize using training set statistics

In [None]:
# Per-feature standardisation statistics, estimated on the training split only.
# The small epsilon keeps the division finite for zero-variance features.
img_mean = image_train.mean(axis=0, keepdims=True)
img_std = image_train.std(axis=0, keepdims=True) + 1e-6

txt_mean = caption_train.mean(axis=0, keepdims=True)
txt_std = caption_train.std(axis=0, keepdims=True) + 1e-6


def normalize_images(x):
    """Standardise image features with the training-set mean and std."""
    centered = x - img_mean
    return centered / img_std


def normalize_texts(x):
    """Standardise caption features with the training-set mean and std."""
    centered = x - txt_mean
    return centered / txt_std


# Apply the training-set statistics to every split (val/test reuse them as-is).
image_train, image_val, image_test = (
    normalize_images(split) for split in (image_train, image_val, image_test)
)
caption_train, caption_val, caption_test = (
    normalize_texts(split) for split in (caption_train, caption_val, caption_test)
)

print("Normalization complete!")

## 3. Dataset and DataLoader

Pairs captions with their corresponding images

In [None]:
class CaptionImagePairedDataset(Dataset):
    """
    Caption-driven dataset of (image, caption) feature pairs.

    The dataset has one item per caption. For index i it returns
      caption_features[i], image_features[caption_to_image_idx[i]]
    so an image referenced by several captions is returned once per caption.

    Args:
        caption_feats: array of caption features, one row per caption.
        image_feats: array of image features, one row per image.
        caption_to_image_idx: integer array with one entry per caption giving
            the row of `image_feats` that the caption describes.

    Raises:
        ValueError: if the number of captions and index entries differ.
        IndexError: if any index does not address a row of `image_feats`.
    """
    def __init__(self, caption_feats, image_feats, caption_to_image_idx):
        # Explicit raise instead of `assert`: asserts vanish under `python -O`.
        if len(caption_feats) != len(caption_to_image_idx):
            raise ValueError(
                f"caption_feats has {len(caption_feats)} rows but "
                f"caption_to_image_idx has {len(caption_to_image_idx)} entries"
            )

        # astype() copies, so the dataset never aliases the caller's arrays.
        self.caption_feats = np.asarray(caption_feats).astype(np.float32)
        self.image_feats = np.asarray(image_feats).astype(np.float32)
        self.cap2img = np.asarray(caption_to_image_idx).astype(np.int64)

        # Validate once up front. Without this an out-of-range index only fails
        # when that item is fetched, and a negative index silently wraps around
        # to the wrong image.
        n_images = len(self.image_feats)
        if self.cap2img.size and (
            self.cap2img.min() < 0 or self.cap2img.max() >= n_images
        ):
            raise IndexError(
                f"caption_to_image_idx must lie in [0, {n_images}); got range "
                f"[{self.cap2img.min()}, {self.cap2img.max()}]"
            )

    def __len__(self):
        # One item per caption, not per image.
        return len(self.caption_feats)

    def __getitem__(self, idx):
        cap = self.caption_feats[idx]
        img = self.image_feats[self.cap2img[idx]]
        return {"image": torch.from_numpy(img), "caption": torch.from_numpy(cap)}

## 4. Configuration

**Choose your loss function here:** Set `LOSS_TYPE` to either `"mse"` or `"contrastive"`

In [None]:
# ============================================================
# CHOOSE YOUR LOSS FUNCTION
# ============================================================
LOSS_TYPE = "mse"  # Change to "contrastive" to use contrastive loss

# Set checkpoint directory based on loss type, so runs with different
# alignment losses never overwrite each other's checkpoints.
if LOSS_TYPE == "mse":
    checkpoint_dir = "./corr_ae_checkpoints_mse"
elif LOSS_TYPE == "contrastive":
    checkpoint_dir = "./corr_ae_checkpoints_contrastive"
else:
    raise ValueError("LOSS_TYPE must be 'mse' or 'contrastive'")

print(f"Using {LOSS_TYPE.upper()} loss")
print(f"Checkpoints: {checkpoint_dir}")

# ============================================================
# Model Configuration
# ============================================================
config = {
    "latent_dim": 512,       # Size of the latent code of both autoencoders
    "img_input_dim": 2048,   # Image feature size fed to the image autoencoder
    "txt_input_dim": 768,    # Caption feature size fed to the text autoencoder
    "img_hidden": 1024,      # Hidden width of the image autoencoder
    "txt_hidden": 512,       # Hidden width of the text autoencoder
    "batch_size": 128,
    "lr": 1e-3,              # Adam learning rate
    "weight_decay": 1e-5,    # Adam weight decay
    "epochs": 40,
    "lambda_align": 1.0,     # Weight of the alignment term in the total loss
    "temperature": 0.07,  # Only used for contrastive loss
    "checkpoint_dir": checkpoint_dir,
    "loss_type": LOSS_TYPE,
    "seed": 42,
}

# Set seed for every RNG this notebook draws from. Python's `random` is seeded
# too because the retrieval examples are chosen with random.sample; without it
# the displayed examples change from run to run.
random.seed(config["seed"])
torch.manual_seed(config["seed"])
np.random.seed(config["seed"])

# Create checkpoint directory
os.makedirs(config["checkpoint_dir"], exist_ok=True)

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## 5. Self-Reconstruction Autoencoder Architecture

Key idea:
- Image encoder → latent → Image decoder (outputs image features)
- Text encoder → latent → Text decoder (outputs text features)

Both autoencoders encode into latent spaces of the same dimensionality; the alignment loss then pulls the image and text latents of each matching pair together.

In [None]:
class ImageAE(nn.Module):
    """
    Two-layer MLP autoencoder for image feature vectors.

    encoder: input_dim -> hidden_dim -> latent_dim
    decoder: latent_dim -> hidden_dim -> input_dim
    """
    def __init__(self, input_dim=2048, hidden_dim=1024, latent_dim=512):
        super().__init__()
        # Layers are created in encoder-then-decoder order and kept under the
        # `encoder` / `decoder` attributes so state_dict keys stay stable.
        encoder_layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = [
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Return `(latent, reconstruction)` for the input batch `x`."""
        latent = self.encoder(x)
        return latent, self.decoder(latent)


class TextAE(nn.Module):
    """
    Two-layer MLP autoencoder for caption feature vectors.

    encoder: input_dim -> hidden_dim -> latent_dim
    decoder: latent_dim -> hidden_dim -> input_dim
    """
    def __init__(self, input_dim=768, hidden_dim=512, latent_dim=512):
        super().__init__()
        # Layers are created in encoder-then-decoder order and kept under the
        # `encoder` / `decoder` attributes so state_dict keys stay stable.
        encoder_layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = [
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Return `(latent, reconstruction)` for the input batch `x`."""
        latent = self.encoder(x)
        return latent, self.decoder(latent)


print("Model architecture defined!")

## 6. Loss Functions

In [None]:
def contrastive_loss(z_img, z_txt, temperature=0.07):
    """
    Symmetric cross-entropy loss over the batch's image/text similarity matrix.

    Row i of `z_img` and row i of `z_txt` form the positive pair; every other
    row in the batch is treated as a negative, including rows that happen to
    hold the same image or caption more than once.

    Returns the mean of the image->text and text->image cross-entropies.
    """
    # Unit-length embeddings, so the dot products below are cosine similarities.
    img_unit = F.normalize(z_img, dim=1)
    txt_unit = F.normalize(z_txt, dim=1)

    # Entry (i, j) compares image i with text j, sharpened by the temperature.
    similarity = (img_unit @ txt_unit.T) / temperature

    # The matching pair sits on the diagonal, so the target class of row i is i.
    targets = torch.arange(z_img.size(0), device=z_img.device)

    image_to_text = F.cross_entropy(similarity, targets)
    text_to_image = F.cross_entropy(similarity.T, targets)

    return (image_to_text + text_to_image) / 2


# Reconstruction loss (used for both the image and the text autoencoder)
recon_loss_fn = nn.MSELoss()

# Alignment loss - picked from config["loss_type"].
# For anything other than "mse" there is no loss module: the training loop
# calls the contrastive_loss function directly instead.
if config["loss_type"] != "mse":
    align_loss_fn = None
    print("Using contrastive loss for alignment")
else:
    align_loss_fn = nn.MSELoss()
    print("Using MSE for alignment loss")

print("Loss functions defined!")

## 7. Initialize Models and Optimizer

In [None]:
# Build both autoencoders from the config and move them to the chosen device.
# Constructor arguments are (input_dim, hidden_dim, latent_dim).
img_ae = ImageAE(
    config["img_input_dim"], config["img_hidden"], config["latent_dim"]
).to(device)

txt_ae = TextAE(
    config["txt_input_dim"], config["txt_hidden"], config["latent_dim"]
).to(device)

# A single optimizer updates the parameters of both models jointly.
params = [*img_ae.parameters(), *txt_ae.parameters()]
optimizer = Adam(params, lr=config["lr"], weight_decay=config["weight_decay"])

print(f"Models initialized with {sum(p.numel() for p in params):,} total parameters")

## 8. Create DataLoaders

In [None]:
# Training data: reshuffled every epoch.
train_dataset = CaptionImagePairedDataset(caption_train, image_train, cap2img_train)
train_loader = DataLoader(
    train_dataset,
    batch_size=config["batch_size"],
    shuffle=True,
    num_workers=0,
)

# Validation data: fixed order.
val_dataset = CaptionImagePairedDataset(caption_val, image_val, cap2img_val)
val_loader = DataLoader(
    val_dataset,
    batch_size=config["batch_size"],
    shuffle=False,
    num_workers=0,
)

print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")

## 9. Training Loop

Self-reconstruction:
- Image → img_ae → z_img, img_recon (reconstructs itself)
- Text → txt_ae → z_txt, txt_recon (reconstructs itself)

Plus alignment loss (MSE or contrastive) to align z_img and z_txt

In [None]:
def run_epoch(loader, training=True):
    """
    Run one full pass over `loader` with the global image/text autoencoders.

    With `training=True` the models are put in train mode and the shared
    optimizer takes one step per batch. Otherwise the models are put in eval
    mode and gradients are disabled.

    Returns a dict with the sample-weighted mean of each loss over the pass:
    "loss" (total), "Limg", "Ltxt" (reconstruction) and "Lalign" (alignment).
    """
    # train(False) is the same switch that eval() flips.
    img_ae.train(training)
    txt_ae.train(training)

    # Running sums, each weighted by batch size so a short last batch is not
    # over-represented in the averages.
    running = {"loss": 0.0, "Limg": 0.0, "Ltxt": 0.0, "Lalign": 0.0}
    seen = 0

    progress = tqdm(loader, desc="train" if training else "val")
    with torch.set_grad_enabled(training):
        for batch in progress:
            imgs = batch["image"].to(device)
            caps = batch["caption"].to(device)
            n_in_batch = imgs.shape[0]

            # Each modality is encoded and reconstructed by its own autoencoder.
            z_img, img_recon = img_ae(imgs)
            z_txt, txt_recon = txt_ae(caps)

            L_img = recon_loss_fn(img_recon, imgs)
            L_txt = recon_loss_fn(txt_recon, caps)

            # Pull the two latent codes of a pair together.
            if config["loss_type"] == "mse":
                L_align = align_loss_fn(z_img, z_txt)
            else:  # contrastive
                L_align = contrastive_loss(z_img, z_txt, temperature=config["temperature"])

            loss = L_img + L_txt + config["lambda_align"] * L_align

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            batch_values = (
                ("loss", loss),
                ("Limg", L_img),
                ("Ltxt", L_txt),
                ("Lalign", L_align),
            )
            for name, value in batch_values:
                running[name] += value.item() * n_in_batch
            seen += n_in_batch

            # Show the running averages on the progress bar.
            progress.set_postfix(
                {name: f"{total / seen:.4f}" for name, total in running.items()}
            )

    return {name: total / seen for name, total in running.items()}

## 10. Train the Model

In [None]:
# Lowest validation loss seen so far; decides when the "best" checkpoint is refreshed.
best_val_loss = float("inf")

for epoch in range(1, config["epochs"] + 1):
    print(f"\n=== Epoch {epoch}/{config['epochs']} ===")

    # One optimisation pass over the training data, then a gradient-free
    # pass over the validation data.
    train_metrics = run_epoch(train_loader, training=True)
    val_metrics = run_epoch(val_loader, training=False)

    print(f"Train loss: {train_metrics['loss']:.4f} | Val loss: {val_metrics['loss']:.4f}")

    # Everything stored for this epoch: weights, optimizer state, metrics, config.
    ckpt = dict(
        epoch=epoch,
        img_state=img_ae.state_dict(),
        txt_state=txt_ae.state_dict(),
        optimizer=optimizer.state_dict(),
        train_metrics=train_metrics,
        val_metrics=val_metrics,
        config=config,
    )

    # A per-epoch checkpoint is always written.
    ckpt_path = os.path.join(config["checkpoint_dir"], f"corr_ae_epoch{epoch}.pt")
    torch.save(ckpt, ckpt_path)

    # The "best" file is overwritten only on a strict improvement of the val loss.
    if val_metrics["loss"] < best_val_loss:
        best_val_loss = val_metrics["loss"]
        torch.save(ckpt, os.path.join(config["checkpoint_dir"], "corr_ae_best.pt"))
        print("Saved best checkpoint.")

print("Training finished!")

## 11. Evaluation: Load Best Model and Compute Retrieval Metrics

In [None]:
# Restore the weights of the checkpoint with the lowest validation loss.
best_ckpt_path = os.path.join(config["checkpoint_dir"], "corr_ae_best.pt")
ckpt = torch.load(best_ckpt_path, map_location=device)

# Load each model and switch it to eval mode for the evaluation cells below.
img_ae.load_state_dict(ckpt["img_state"])
img_ae.eval()

txt_ae.load_state_dict(ckpt["txt_state"])
txt_ae.eval()

print(f"Loaded best checkpoint from epoch {ckpt['epoch']}")

In [None]:
def _encode_in_batches(model, feats, batch_size):
    """Run `feats` through `model` chunk by chunk and stack the latent codes."""
    latents = []
    with torch.no_grad():
        for start in range(0, feats.shape[0], batch_size):
            chunk = torch.from_numpy(feats[start:start + batch_size]).float().to(device)
            z, _ = model(chunk)  # the reconstruction is not needed here
            latents.append(z.cpu().numpy())
    return np.concatenate(latents, axis=0)


def encode_features(image_feats, caption_feats, batch_size=256):
    """
    Encode images and captions into latent space.

    Uses the global `img_ae` / `txt_ae` models as they are; their train/eval
    mode is not changed here. Gradients are disabled during encoding.

    Args:
        image_feats: NumPy array of image features, one row per image.
        caption_feats: NumPy array of caption features, one row per caption.
        batch_size: number of rows sent through a model at once (default 256).

    Returns:
        (Z_imgs, Z_caps): NumPy arrays of latent codes, where row i is the
        latent code of input row i.

    Raises:
        ValueError: if either input has no rows (nothing to concatenate) or
            `batch_size` is zero.
    """
    Z_imgs = _encode_in_batches(img_ae, image_feats, batch_size)
    Z_caps = _encode_in_batches(txt_ae, caption_feats, batch_size)
    return Z_imgs, Z_caps

In [None]:
def retrieval_metrics(Z_caps, Z_imgs, caption_to_image_idx):
    """
    Caption -> image retrieval quality in the latent space.

    Each caption ranks all images by cosine similarity. The 1-based position of
    its true image (given by `caption_to_image_idx`) is that caption's rank.

    Returns a dict with "Recall@1", "Recall@5", "Recall@10" (fraction of
    captions whose true image is ranked within the top K) and "MedianRank".
    """
    sims = cosine_similarity(Z_caps, Z_imgs)

    # Sort each row by descending similarity and find where the true image lands.
    ranks = np.array([
        np.where(np.argsort(-sims[row]) == target)[0][0] + 1
        for row, target in enumerate(caption_to_image_idx)
    ])

    metrics = {f"Recall@{k}": np.mean(ranks <= k) for k in (1, 5, 10)}
    metrics["MedianRank"] = np.median(ranks)
    return metrics

### Validation Set Metrics

In [None]:
# Project the validation split into the latent space ...
print("Encoding validation set...")
Z_imgs_val, Z_caps_val = encode_features(image_val, caption_val)
print(f"Encoded shapes: {Z_imgs_val.shape}, {Z_caps_val.shape}")

# ... and report caption -> image retrieval quality on it.
metrics_val = retrieval_metrics(Z_caps_val, Z_imgs_val, cap2img_val)
print("\nValidation Set Metrics:")
for metric_name, metric_value in metrics_val.items():
    print(f"  {metric_name}: {metric_value:.4f}")

### Test Set Metrics

In [None]:
# Project the test split into the latent space ...
print("Encoding test set...")
Z_imgs_test, Z_caps_test = encode_features(image_test, caption_test)
print(f"Encoded shapes: {Z_imgs_test.shape}, {Z_caps_test.shape}")

# ... and report caption -> image retrieval quality on it.
metrics_test = retrieval_metrics(Z_caps_test, Z_imgs_test, cap2img_test)
print("\nTest Set Metrics:")
for metric_name, metric_value in metrics_test.items():
    print(f"  {metric_name}: {metric_value:.4f}")

## 12. Visualization (Optional)

Visualize some retrieval results

In [None]:
# Load additional data for visualization
df = pd.read_csv('./flickr8k_data/captions.txt')
image_names = np.load('flickr8k_image_names.npy')

# Recreate splits to get validation masks
n_images = len(image_names)
indices = np.arange(n_images)
train_idx, temp_idx = train_test_split(indices, test_size=0.30, random_state=42, shuffle=True)
val_idx, test_idx = train_test_split(temp_idx, test_size=0.5, random_state=42, shuffle=True)

# Create masks
caption_to_image_idx = df["image"].map({name: i for i, name in enumerate(image_names)}).values.astype(int)
val_mask = np.isin(caption_to_image_idx, val_idx)

# Get validation caption texts and image names
val_caption_texts = df["caption"].values[val_mask]
val_image_names = np.array([image_names[i] for i in val_idx])

print(f"Loaded {len(val_caption_texts)} validation captions")

In [None]:
image_dir = "./flickr8k_data/Images"

def show_top_images_for_caption(caption_idx, top_k=5):
    """
    Plot the true image of a validation caption next to its top-k retrievals.

    The caption's latent code is compared with every validation image code by
    cosine similarity. The figure shows the true image in the first panel,
    followed by the `top_k` most similar images in rank order. Images that
    cannot be opened are reported on stdout and skipped.
    """
    query = Z_caps_val[caption_idx].reshape(1, -1)
    scores = cosine_similarity(query, Z_imgs_val)[0]
    ranked = np.argsort(-scores)[:top_k]  # highest similarity first

    print(f"\nCAPTION: {val_caption_texts[caption_idx]}")
    true_img_idx = cap2img_val[caption_idx]
    print(f"TRUE IMAGE: {val_image_names[true_img_idx]} (index {true_img_idx})")

    plt.figure(figsize=(18, 4))
    n_panels = top_k + 1  # one extra panel for the ground-truth image

    # Panel 1: the ground-truth image
    true_img_path = os.path.join(image_dir, val_image_names[true_img_idx])
    try:
        true_img = Image.open(true_img_path)
        plt.subplot(1, n_panels, 1)
        plt.imshow(true_img)
        plt.axis('off')
        plt.title("TRUE IMAGE", fontweight='bold', color='green')
    except Exception as e:
        print(f"Could not open true image: {e}")

    # Panels 2..top_k+1: retrieved images in rank order
    for panel, img_idx in enumerate(ranked, start=2):
        img_path = os.path.join(image_dir, val_image_names[img_idx])
        try:
            img = Image.open(img_path)
        except Exception as e:
            print(f"Could not open {img_path}: {e}")
            continue

        plt.subplot(1, n_panels, panel)
        plt.imshow(img)
        plt.axis('off')

        rank = panel - 1
        if img_idx == true_img_idx:
            # Highlight a correct retrieval.
            plt.title(f"Rank {rank} ✓", fontweight='bold', color='green')
        else:
            plt.title(f"Rank {rank}")

    plt.tight_layout()
    plt.show()

# Plot retrievals for three distinct, randomly drawn validation captions
print("\n=== Retrieval Examples ===")
example_indices = random.sample(range(len(Z_caps_val)), 3)
for caption_idx in example_indices:
    show_top_images_for_caption(caption_idx, top_k=5)