# Architecture 1: Self-Reconstruction Multi-Modal Autoencoder

This notebook implements the original architecture where:
- Image features are reconstructed by the image autoencoder
- Text features are reconstructed by the text autoencoder
- Latent spaces are aligned using contrastive loss

**Prerequisites:** Run `Feature_Extraction_Batch.ipynb` first to generate the required .npy files.

5. Multi-Modal Autoencoder Architecture

In [37]:
# Now time to set up and train the actual multi-modal autoencoder
config = {
    "image_feat_path": "train_image_features.npy",   # we'll load proper files below
    "caption_feat_path": "train_caption_features.npy",
    "caption_to_image_path": "train_caption_to_image.npy",
    "val_image_feat_path": "val_image_features.npy",
    "val_caption_feat_path": "val_caption_features.npy",
    "val_caption_to_image_path": "val_caption_to_image.npy",
    "latent_dim": 512,
    "img_input_dim": 2048,
    "txt_input_dim": 768,
    "img_hidden": 1024,
    "txt_hidden": 512,
    "batch_size": 128,    # try 128; lower if memory limited (e.g., 64)
    "lr": 1e-3,
    "weight_decay": 1e-5,
    "epochs": 40,
    "lambda_align": 1.0,  # weight for latent alignment loss; tuneable
    "checkpoint_dir": "./corr_ae_checkpoints_contrastive",
    "seed": 42,
}

In [38]:
torch.manual_seed(config["seed"])
np.random.seed(config["seed"])

os.makedirs(config["checkpoint_dir"], exist_ok=True)

In [39]:
#I don't need to reload these if we run it all in the same notebook but I'm pasting the load functions here anyway
image_train = np.load(config["image_feat_path"])
caption_train = np.load(config["caption_feat_path"])
cap2img_train = np.load(config["caption_to_image_path"])

image_val = np.load(config["val_image_feat_path"])
caption_val = np.load(config["val_caption_feat_path"])
cap2img_val = np.load(config["val_caption_to_image_path"])

print("Shapes (train):", image_train.shape, caption_train.shape, cap2img_train.shape)
print("Shapes (val):", image_val.shape, caption_val.shape, cap2img_val.shape)

Shapes (train): (5663, 2048) (28315, 768) (28315,)
Shapes (val): (1214, 2048) (6070, 768) (6070,)


In [40]:
# 3) Compute train-set normalization (mean/std) and apply to all splits
# Normalize per-feature (column-wise) using training set statistics
img_mean = image_train.mean(axis=0, keepdims=True)
img_std = image_train.std(axis=0, keepdims=True) + 1e-6

txt_mean = caption_train.mean(axis=0, keepdims=True)
txt_std = caption_train.std(axis=0, keepdims=True) + 1e-6

In [41]:
def normalize_images(x):
    return (x - img_mean) / img_std

def normalize_texts(x):
    return (x - txt_mean) / txt_std

image_train = normalize_images(image_train)
image_val   = normalize_images(image_val)

caption_train = normalize_texts(caption_train)
caption_val   = normalize_texts(caption_val)

In [42]:
# 4) Dataset that returns paired (image_feat, caption_feat) for each caption
class CaptionImagePairedDataset(Dataset):
    """
    Iterates over captions. For index i, returns:
      caption_features[i], image_features[ caption_to_image_idx[i] ]
    """
    def __init__(self, caption_feats, image_feats, caption_to_image_idx):
        assert len(caption_feats) == len(caption_to_image_idx)
        self.caption_feats = caption_feats.astype(np.float32)
        self.image_feats = image_feats.astype(np.float32)
        self.cap2img = caption_to_image_idx.astype(np.int64)

    def __len__(self):
        return len(self.caption_feats)

    def __getitem__(self, idx):
        cap = self.caption_feats[idx]
        img = self.image_feats[self.cap2img[idx]]
        return {"image": torch.from_numpy(img), "caption": torch.from_numpy(cap)}

train_dataset = CaptionImagePairedDataset(caption_train, image_train, cap2img_train)
val_dataset   = CaptionImagePairedDataset(caption_val, image_val, cap2img_val)

train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True, num_workers=0)
val_loader   = DataLoader(val_dataset, batch_size=config["batch_size"], shuffle=False, num_workers=0)

In [None]:
# 5) Model: two autoencoders with shared latent dimension
class ImageAE(nn.Module):
    def __init__(self, input_dim=2048, hidden_dim=1024, latent_dim=512):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim)
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim)
        )

    def forward(self, x):
        z = self.encoder(x)
        recon = self.decoder(z)
        return z, recon

class TextAE(nn.Module):
    def __init__(self, input_dim=768, hidden_dim=512, latent_dim=512):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim)
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim)
        )

    def forward(self, x):
        z = self.encoder(x)
        recon = self.decoder(z)
        return z, recon

In [44]:
# contrastive loss function
def contrastive_loss(z_img, z_txt, temperature=0.07):
    # normalize both embeddings
    z_img_norm = F.normalize(z_img, dim=1)
    z_txt_norm = F.normalize(z_txt, dim=1)

    # compute similarity matrix (batch_size x batch_size )
    logits = torch.matmul(z_img_norm, z_txt_norm.T) / temperature

    # create labels
    labels = torch.arange(z_img.size(0)).to(z_img.device)

    # compute cross-entropy for image to text direction
    loss_i2t = F.cross_entropy(logits, labels)

    # compute cross-entropy for text-to-image direction
    loss_t2i = F.cross_entropy(logits.T, labels)
    
    # average the two losses
    return (loss_i2t + loss_t2i) /2

In [45]:
# Instantiate models
img_ae = ImageAE(
    input_dim=config["img_input_dim"],
    hidden_dim=config["img_hidden"],
    latent_dim=config["latent_dim"]
).to(device)

txt_ae = TextAE(
    input_dim=config["txt_input_dim"],
    hidden_dim=config["txt_hidden"],
    latent_dim=config["latent_dim"]
).to(device)

# 6) Losses and optimizer
recon_loss_fn = nn.MSELoss()    # reconstruction for both
# align_loss_fn = nn.MSELoss()    # align latents

params = list(img_ae.parameters()) + list(txt_ae.parameters())
optimizer = Adam(params, lr=config["lr"], weight_decay=config["weight_decay"])

6. Training

In [None]:
# 7) Training / validation loop
def run_epoch(loader, training=True):
    if training:
        img_ae.train(); txt_ae.train()
    else:
        img_ae.eval(); txt_ae.eval()

    total_recon_img = 0.0
    total_recon_txt = 0.0
    total_align = 0.0
    total_loss = 0.0
    n_samples = 0

    pbar = tqdm(loader, desc="train" if training else "val")
    with torch.set_grad_enabled(training):
        for batch in pbar:
            imgs = batch["image"].to(device)    # shape (B, img_dim)
            caps = batch["caption"].to(device)  # shape (B, txt_dim)
            batch_size = imgs.shape[0]

            # forward
            z_img, img_recon = img_ae(imgs)
            z_txt, txt_recon = txt_ae(caps)

            # losses
            L_img = recon_loss_fn(img_recon, imgs)
            L_txt = recon_loss_fn(txt_recon, caps)
            # L_align = align_loss_fn(z_img, z_txt)
            L_align = contrastive_loss(z_img, z_txt, temperature=0.7)

            loss = L_img + L_txt + config["lambda_align"] * L_align

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            total_recon_img += L_img.item() * batch_size
            total_recon_txt += L_txt.item() * batch_size
            total_align += L_align.item() * batch_size
            total_loss += loss.item() * batch_size
            n_samples += batch_size

            pbar.set_postfix({
                "loss": f"{total_loss / n_samples:.4f}",
                "Limg": f"{total_recon_img / n_samples:.4f}",
                "Ltxt": f"{total_recon_txt / n_samples:.4f}",
                "Lalign": f"{total_align / n_samples:.4f}"
            })

    return {
        "loss": total_loss / n_samples,
        "Limg": total_recon_img / n_samples,
        "Ltxt": total_recon_txt / n_samples,
        "Lalign": total_align / n_samples
    }

In [47]:
best_val_loss = float("inf")

for epoch in range(1, config["epochs"] + 1):
    print(f"\n=== Epoch {epoch}/{config['epochs']} ===")
    train_metrics = run_epoch(train_loader, training=True)
    val_metrics = run_epoch(val_loader, training=False)

    print(f"Train loss: {train_metrics['loss']:.4f} | Val loss: {val_metrics['loss']:.4f}")

    # Save checkpoint (every epoch)
    ckpt = {
        "epoch": epoch,
        "img_state": img_ae.state_dict(),
        "txt_state": txt_ae.state_dict(),
        "optimizer": optimizer.state_dict(),
        "train_metrics": train_metrics,
        "val_metrics": val_metrics,
        "config": config
    }
    ckpt_path = os.path.join(config["checkpoint_dir"], f"corr_ae_epoch{epoch}.pt")
    torch.save(ckpt, ckpt_path)

    # Keep best
    if val_metrics["loss"] < best_val_loss:
        best_val_loss = val_metrics["loss"]
        torch.save(ckpt, os.path.join(config["checkpoint_dir"], "corr_ae_best.pt"))
        print("Saved best checkpoint.")

print("Training finished.")


=== Epoch 1/40 ===


train: 100%|██████████| 222/222 [00:20<00:00, 11.06it/s, loss=4.8599, Limg=0.4567, Ltxt=0.3779, Lalign=4.0253]
val: 100%|██████████| 48/48 [00:00<00:00, 52.04it/s, loss=4.6222, Limg=0.3231, Ltxt=0.2561, Lalign=4.0430]


Train loss: 4.8599 | Val loss: 4.6222
Saved best checkpoint.

=== Epoch 2/40 ===


train: 100%|██████████| 222/222 [00:20<00:00, 10.97it/s, loss=4.3459, Limg=0.2331, Ltxt=0.2241, Lalign=3.8888]
val: 100%|██████████| 48/48 [00:00<00:00, 60.32it/s, loss=4.5043, Limg=0.2619, Ltxt=0.2119, Lalign=4.0305]


Train loss: 4.3459 | Val loss: 4.5043
Saved best checkpoint.

=== Epoch 3/40 ===


train: 100%|██████████| 222/222 [00:18<00:00, 11.78it/s, loss=4.2180, Limg=0.1850, Ltxt=0.1899, Lalign=3.8431]
val: 100%|██████████| 48/48 [00:00<00:00, 51.84it/s, loss=4.4421, Limg=0.2350, Ltxt=0.1860, Lalign=4.0211]


Train loss: 4.2180 | Val loss: 4.4421
Saved best checkpoint.

=== Epoch 4/40 ===


train: 100%|██████████| 222/222 [00:18<00:00, 11.90it/s, loss=4.1479, Limg=0.1612, Ltxt=0.1713, Lalign=3.8155]
val: 100%|██████████| 48/48 [00:00<00:00, 58.59it/s, loss=4.4190, Limg=0.2269, Ltxt=0.1724, Lalign=4.0197]


Train loss: 4.1479 | Val loss: 4.4190
Saved best checkpoint.

=== Epoch 5/40 ===


train: 100%|██████████| 222/222 [00:18<00:00, 12.07it/s, loss=4.1051, Limg=0.1481, Ltxt=0.1607, Lalign=3.7963]
val: 100%|██████████| 48/48 [00:00<00:00, 60.46it/s, loss=4.4039, Limg=0.2196, Ltxt=0.1662, Lalign=4.0181]


Train loss: 4.1051 | Val loss: 4.4039
Saved best checkpoint.

=== Epoch 6/40 ===


train: 100%|██████████| 222/222 [00:18<00:00, 11.89it/s, loss=4.0716, Limg=0.1383, Ltxt=0.1513, Lalign=3.7819]
val: 100%|██████████| 48/48 [00:00<00:00, 60.07it/s, loss=4.3838, Limg=0.2150, Ltxt=0.1548, Lalign=4.0140]


Train loss: 4.0716 | Val loss: 4.3838
Saved best checkpoint.

=== Epoch 7/40 ===


train: 100%|██████████| 222/222 [00:18<00:00, 12.01it/s, loss=4.0492, Limg=0.1318, Ltxt=0.1453, Lalign=3.7722]
val: 100%|██████████| 48/48 [00:00<00:00, 61.16it/s, loss=4.3830, Limg=0.2154, Ltxt=0.1537, Lalign=4.0139]


Train loss: 4.0492 | Val loss: 4.3830
Saved best checkpoint.

=== Epoch 8/40 ===


train: 100%|██████████| 222/222 [00:19<00:00, 11.68it/s, loss=4.0290, Limg=0.1271, Ltxt=0.1396, Lalign=3.7622]
val: 100%|██████████| 48/48 [00:00<00:00, 57.36it/s, loss=4.3710, Limg=0.2072, Ltxt=0.1456, Lalign=4.0181]


Train loss: 4.0290 | Val loss: 4.3710
Saved best checkpoint.

=== Epoch 9/40 ===


train: 100%|██████████| 222/222 [00:18<00:00, 11.85it/s, loss=4.0113, Limg=0.1220, Ltxt=0.1355, Lalign=3.7538]
val: 100%|██████████| 48/48 [00:00<00:00, 62.66it/s, loss=4.3649, Limg=0.2077, Ltxt=0.1426, Lalign=4.0146]


Train loss: 4.0113 | Val loss: 4.3649
Saved best checkpoint.

=== Epoch 10/40 ===


train: 100%|██████████| 222/222 [00:18<00:00, 12.09it/s, loss=3.9992, Limg=0.1191, Ltxt=0.1315, Lalign=3.7485]
val: 100%|██████████| 48/48 [00:00<00:00, 60.97it/s, loss=4.3698, Limg=0.2115, Ltxt=0.1425, Lalign=4.0159]


Train loss: 3.9992 | Val loss: 4.3698

=== Epoch 11/40 ===


train: 100%|██████████| 222/222 [00:18<00:00, 12.09it/s, loss=3.9872, Limg=0.1170, Ltxt=0.1273, Lalign=3.7429]
val: 100%|██████████| 48/48 [00:00<00:00, 58.54it/s, loss=4.3577, Limg=0.2080, Ltxt=0.1356, Lalign=4.0141]


Train loss: 3.9872 | Val loss: 4.3577
Saved best checkpoint.

=== Epoch 12/40 ===


train: 100%|██████████| 222/222 [00:18<00:00, 11.80it/s, loss=3.9767, Limg=0.1140, Ltxt=0.1254, Lalign=3.7373]
val: 100%|██████████| 48/48 [00:00<00:00, 52.62it/s, loss=4.3507, Limg=0.2033, Ltxt=0.1290, Lalign=4.0184]


Train loss: 3.9767 | Val loss: 4.3507
Saved best checkpoint.

=== Epoch 13/40 ===


train: 100%|██████████| 222/222 [00:18<00:00, 12.03it/s, loss=3.9642, Limg=0.1110, Ltxt=0.1215, Lalign=3.7318]
val: 100%|██████████| 48/48 [00:00<00:00, 54.66it/s, loss=4.3518, Limg=0.2028, Ltxt=0.1339, Lalign=4.0152]


Train loss: 3.9642 | Val loss: 4.3518

=== Epoch 14/40 ===


train: 100%|██████████| 222/222 [00:18<00:00, 11.97it/s, loss=3.9581, Limg=0.1100, Ltxt=0.1195, Lalign=3.7286]
val: 100%|██████████| 48/48 [00:00<00:00, 58.58it/s, loss=4.3530, Limg=0.2041, Ltxt=0.1291, Lalign=4.0198]


Train loss: 3.9581 | Val loss: 4.3530

=== Epoch 15/40 ===


train: 100%|██████████| 222/222 [00:20<00:00, 11.01it/s, loss=3.9493, Limg=0.1086, Ltxt=0.1164, Lalign=3.7243]
val: 100%|██████████| 48/48 [00:00<00:00, 52.72it/s, loss=4.3508, Limg=0.2014, Ltxt=0.1308, Lalign=4.0186]


Train loss: 3.9493 | Val loss: 4.3508

=== Epoch 16/40 ===


train: 100%|██████████| 222/222 [00:19<00:00, 11.59it/s, loss=3.9414, Limg=0.1060, Ltxt=0.1146, Lalign=3.7208]
val: 100%|██████████| 48/48 [00:00<00:00, 57.55it/s, loss=4.3467, Limg=0.2029, Ltxt=0.1221, Lalign=4.0217]


Train loss: 3.9414 | Val loss: 4.3467
Saved best checkpoint.

=== Epoch 17/40 ===


train: 100%|██████████| 222/222 [00:18<00:00, 11.99it/s, loss=3.9356, Limg=0.1057, Ltxt=0.1119, Lalign=3.7180]
val: 100%|██████████| 48/48 [00:00<00:00, 53.67it/s, loss=4.3500, Limg=0.2060, Ltxt=0.1210, Lalign=4.0230]


Train loss: 3.9356 | Val loss: 4.3500

=== Epoch 18/40 ===


train: 100%|██████████| 222/222 [00:19<00:00, 11.66it/s, loss=3.9326, Limg=0.1050, Ltxt=0.1106, Lalign=3.7170]
val: 100%|██████████| 48/48 [00:00<00:00, 58.00it/s, loss=4.3378, Limg=0.2011, Ltxt=0.1188, Lalign=4.0180]


Train loss: 3.9326 | Val loss: 4.3378
Saved best checkpoint.

=== Epoch 19/40 ===


train: 100%|██████████| 222/222 [00:18<00:00, 12.24it/s, loss=3.9244, Limg=0.1031, Ltxt=0.1076, Lalign=3.7138]
val: 100%|██████████| 48/48 [00:00<00:00, 59.13it/s, loss=4.3391, Limg=0.2036, Ltxt=0.1173, Lalign=4.0182]


Train loss: 3.9244 | Val loss: 4.3391

=== Epoch 20/40 ===


train: 100%|██████████| 222/222 [00:18<00:00, 12.01it/s, loss=3.9194, Limg=0.1024, Ltxt=0.1066, Lalign=3.7104]
val: 100%|██████████| 48/48 [00:00<00:00, 59.49it/s, loss=4.3401, Limg=0.2048, Ltxt=0.1129, Lalign=4.0225]


Train loss: 3.9194 | Val loss: 4.3401

=== Epoch 21/40 ===


train: 100%|██████████| 222/222 [00:18<00:00, 11.73it/s, loss=3.9142, Limg=0.1016, Ltxt=0.1039, Lalign=3.7087]
val: 100%|██████████| 48/48 [00:00<00:00, 57.38it/s, loss=4.3320, Limg=0.2037, Ltxt=0.1120, Lalign=4.0162]


Train loss: 3.9142 | Val loss: 4.3320
Saved best checkpoint.

=== Epoch 22/40 ===


train: 100%|██████████| 222/222 [00:18<00:00, 11.69it/s, loss=3.9088, Limg=0.1005, Ltxt=0.1020, Lalign=3.7062]
val: 100%|██████████| 48/48 [00:00<00:00, 55.40it/s, loss=4.3303, Limg=0.2013, Ltxt=0.1119, Lalign=4.0172]


Train loss: 3.9088 | Val loss: 4.3303
Saved best checkpoint.

=== Epoch 23/40 ===


train: 100%|██████████| 222/222 [00:18<00:00, 11.97it/s, loss=3.9044, Limg=0.0991, Ltxt=0.1006, Lalign=3.7048]
val: 100%|██████████| 48/48 [00:00<00:00, 60.32it/s, loss=4.3359, Limg=0.2025, Ltxt=0.1097, Lalign=4.0236]


Train loss: 3.9044 | Val loss: 4.3359

=== Epoch 24/40 ===


train: 100%|██████████| 222/222 [00:18<00:00, 12.09it/s, loss=3.9015, Limg=0.0995, Ltxt=0.0992, Lalign=3.7028]
val: 100%|██████████| 48/48 [00:00<00:00, 52.05it/s, loss=4.3281, Limg=0.2008, Ltxt=0.1078, Lalign=4.0196]


Train loss: 3.9015 | Val loss: 4.3281
Saved best checkpoint.

=== Epoch 25/40 ===


train: 100%|██████████| 222/222 [00:18<00:00, 11.98it/s, loss=3.8953, Limg=0.0981, Ltxt=0.0966, Lalign=3.7006]
val: 100%|██████████| 48/48 [00:00<00:00, 58.28it/s, loss=4.3261, Limg=0.1992, Ltxt=0.1079, Lalign=4.0190]


Train loss: 3.8953 | Val loss: 4.3261
Saved best checkpoint.

=== Epoch 26/40 ===


train: 100%|██████████| 222/222 [00:18<00:00, 12.05it/s, loss=3.8932, Limg=0.0981, Ltxt=0.0960, Lalign=3.6990]
val: 100%|██████████| 48/48 [00:00<00:00, 58.71it/s, loss=4.3216, Limg=0.2006, Ltxt=0.1048, Lalign=4.0162]


Train loss: 3.8932 | Val loss: 4.3216
Saved best checkpoint.

=== Epoch 27/40 ===


train: 100%|██████████| 222/222 [00:18<00:00, 11.96it/s, loss=3.8875, Limg=0.0967, Ltxt=0.0935, Lalign=3.6973]
val: 100%|██████████| 48/48 [00:00<00:00, 55.73it/s, loss=4.3244, Limg=0.2018, Ltxt=0.1044, Lalign=4.0181]


Train loss: 3.8875 | Val loss: 4.3244

=== Epoch 28/40 ===


train: 100%|██████████| 222/222 [00:19<00:00, 11.34it/s, loss=3.8846, Limg=0.0962, Ltxt=0.0923, Lalign=3.6961]
val: 100%|██████████| 48/48 [00:00<00:00, 54.19it/s, loss=4.3247, Limg=0.2013, Ltxt=0.1017, Lalign=4.0218]


Train loss: 3.8846 | Val loss: 4.3247

=== Epoch 29/40 ===


train: 100%|██████████| 222/222 [00:19<00:00, 11.67it/s, loss=3.8805, Limg=0.0952, Ltxt=0.0911, Lalign=3.6943]
val: 100%|██████████| 48/48 [00:00<00:00, 54.75it/s, loss=4.3243, Limg=0.2010, Ltxt=0.1027, Lalign=4.0206]


Train loss: 3.8805 | Val loss: 4.3243

=== Epoch 30/40 ===


train: 100%|██████████| 222/222 [00:20<00:00, 10.99it/s, loss=3.8762, Limg=0.0950, Ltxt=0.0896, Lalign=3.6916]
val: 100%|██████████| 48/48 [00:00<00:00, 57.81it/s, loss=4.3162, Limg=0.2034, Ltxt=0.0967, Lalign=4.0161]


Train loss: 3.8762 | Val loss: 4.3162
Saved best checkpoint.

=== Epoch 31/40 ===


train: 100%|██████████| 222/222 [00:18<00:00, 11.77it/s, loss=3.8740, Limg=0.0951, Ltxt=0.0874, Lalign=3.6915]
val: 100%|██████████| 48/48 [00:00<00:00, 60.36it/s, loss=4.3208, Limg=0.1991, Ltxt=0.0968, Lalign=4.0249]


Train loss: 3.8740 | Val loss: 4.3208

=== Epoch 32/40 ===


train: 100%|██████████| 222/222 [00:19<00:00, 11.68it/s, loss=3.8693, Limg=0.0936, Ltxt=0.0864, Lalign=3.6893]
val: 100%|██████████| 48/48 [00:00<00:00, 57.48it/s, loss=4.3238, Limg=0.2023, Ltxt=0.0970, Lalign=4.0244]


Train loss: 3.8693 | Val loss: 4.3238

=== Epoch 33/40 ===


train: 100%|██████████| 222/222 [00:19<00:00, 11.66it/s, loss=3.8678, Limg=0.0947, Ltxt=0.0852, Lalign=3.6879]
val: 100%|██████████| 48/48 [00:00<00:00, 58.93it/s, loss=4.3179, Limg=0.2032, Ltxt=0.0925, Lalign=4.0222]


Train loss: 3.8678 | Val loss: 4.3179

=== Epoch 34/40 ===


train: 100%|██████████| 222/222 [00:18<00:00, 11.99it/s, loss=3.8654, Limg=0.0936, Ltxt=0.0835, Lalign=3.6883]
val: 100%|██████████| 48/48 [00:00<00:00, 55.00it/s, loss=4.3167, Limg=0.2020, Ltxt=0.0925, Lalign=4.0222]


Train loss: 3.8654 | Val loss: 4.3167

=== Epoch 35/40 ===


train: 100%|██████████| 222/222 [00:18<00:00, 11.94it/s, loss=3.8628, Limg=0.0933, Ltxt=0.0829, Lalign=3.6865]
val: 100%|██████████| 48/48 [00:00<00:00, 58.60it/s, loss=4.3168, Limg=0.2015, Ltxt=0.0918, Lalign=4.0235]


Train loss: 3.8628 | Val loss: 4.3168

=== Epoch 36/40 ===


train: 100%|██████████| 222/222 [00:18<00:00, 12.02it/s, loss=3.8599, Limg=0.0929, Ltxt=0.0815, Lalign=3.6855]
val: 100%|██████████| 48/48 [00:00<00:00, 58.79it/s, loss=4.3125, Limg=0.1996, Ltxt=0.0875, Lalign=4.0255]


Train loss: 3.8599 | Val loss: 4.3125
Saved best checkpoint.

=== Epoch 37/40 ===


train: 100%|██████████| 222/222 [00:18<00:00, 11.98it/s, loss=3.8559, Limg=0.0923, Ltxt=0.0799, Lalign=3.6837]
val: 100%|██████████| 48/48 [00:00<00:00, 59.67it/s, loss=4.3139, Limg=0.2007, Ltxt=0.0894, Lalign=4.0238]


Train loss: 3.8559 | Val loss: 4.3139

=== Epoch 38/40 ===


train: 100%|██████████| 222/222 [00:18<00:00, 11.91it/s, loss=3.8536, Limg=0.0924, Ltxt=0.0789, Lalign=3.6823]
val: 100%|██████████| 48/48 [00:00<00:00, 60.26it/s, loss=4.3176, Limg=0.2019, Ltxt=0.0920, Lalign=4.0237]


Train loss: 3.8536 | Val loss: 4.3176

=== Epoch 39/40 ===


train: 100%|██████████| 222/222 [00:18<00:00, 11.90it/s, loss=3.8513, Limg=0.0911, Ltxt=0.0785, Lalign=3.6817]
val: 100%|██████████| 48/48 [00:00<00:00, 55.38it/s, loss=4.3109, Limg=0.1992, Ltxt=0.0868, Lalign=4.0248]


Train loss: 3.8513 | Val loss: 4.3109
Saved best checkpoint.

=== Epoch 40/40 ===


train: 100%|██████████| 222/222 [00:18<00:00, 11.90it/s, loss=3.8497, Limg=0.0912, Ltxt=0.0769, Lalign=3.6816]
val: 100%|██████████| 48/48 [00:00<00:00, 60.50it/s, loss=4.3101, Limg=0.2027, Ltxt=0.0871, Lalign=4.0203]


Train loss: 3.8497 | Val loss: 4.3101
Saved best checkpoint.
Training finished.


7. Evaluation

In [48]:
#Now time to evaluate on the validation set
#Probably don't need to reload the model, but I'm going to include the code again in case we break this up into more managable files
# --- Load best checkpoint ---
best_ckpt_path = os.path.join(config["checkpoint_dir"], "corr_ae_best.pt")
ckpt = torch.load(best_ckpt_path, map_location=device)

img_ae.load_state_dict(ckpt["img_state"])
txt_ae.load_state_dict(ckpt["txt_state"])
img_ae.eval(); txt_ae.eval()

print(f"Loaded best checkpoint from {best_ckpt_path} (epoch {ckpt['epoch']})")

Loaded best checkpoint from ./corr_ae_checkpoints_contrastive/corr_ae_best.pt (epoch 40)


In [49]:
# Encode into latent space
with torch.no_grad():
    # Encode images
    Z_imgs = []
    for i in range(0, image_val.shape[0], 256):
        batch = torch.from_numpy(image_val[i:i+256]).float().to(device)
        z, _ = img_ae(batch)
        Z_imgs.append(z.cpu().numpy())
    Z_imgs = np.concatenate(Z_imgs, axis=0)   # shape (N_images, latent_dim)

    # Encode captions
    Z_caps = []
    for i in range(0, caption_val.shape[0], 256):
        batch = torch.from_numpy(caption_val[i:i+256]).float().to(device)
        z, _ = txt_ae(batch)
        Z_caps.append(z.cpu().numpy())
    Z_caps = np.concatenate(Z_caps, axis=0)   # shape (N_captions, latent_dim)

print("Encoded latent shapes:", Z_imgs.shape, Z_caps.shape)

Encoded latent shapes: (1214, 512) (6070, 512)


In [50]:
#Use Recall@1/5/10 to evaluate hyperparameter performance
#Note that we are using cosine similarity
#Should we consider using L2 metric instead? Does this even make sense?
def retrieval_metrics(Z_caps, Z_imgs, caption_to_image_idx):
    sims = cosine_similarity(Z_caps, Z_imgs)  # (num_caps, num_imgs)
    ranks = []
    for i, true_img_idx in enumerate(caption_to_image_idx):
        sim_scores = sims[i]
        sorted_indices = np.argsort(-sim_scores)  # descending
        rank = np.where(sorted_indices == true_img_idx)[0][0] + 1
        ranks.append(rank)

    ranks = np.array(ranks)
    recall_at_1  = np.mean(ranks <= 1)
    recall_at_5  = np.mean(ranks <= 5)
    recall_at_10 = np.mean(ranks <= 10)
    med_rank = np.median(ranks)

    return {
        "Recall@1": recall_at_1,
        "Recall@5": recall_at_5,
        "Recall@10": recall_at_10,
        "MedianRank": med_rank
    }

metrics_val = retrieval_metrics(Z_caps, Z_imgs, cap2img_val)
for k, v in metrics_val.items():
    print(f"{k}: {v:.4f}")

Recall@1: 0.0465
Recall@5: 0.1588
Recall@10: 0.2631
MedianRank: 30.0000


In [51]:
# evaluate on test set
# Encode into latent space
with torch.no_grad():
    # Encode images
    Z_imgs = []
    for i in range(0, image_test.shape[0], 256):
        batch = torch.from_numpy(image_test[i:i+256]).float().to(device)
        z, _ = img_ae(batch)
        Z_imgs.append(z.cpu().numpy())
    Z_imgs = np.concatenate(Z_imgs, axis=0)   # shape (N_images, latent_dim)

    # Encode captions
    Z_caps = []
    for i in range(0, caption_test.shape[0], 256):
        batch = torch.from_numpy(caption_test[i:i+256]).float().to(device)
        z, _ = txt_ae(batch)
        Z_caps.append(z.cpu().numpy())
    Z_caps = np.concatenate(Z_caps, axis=0)   # shape (N_captions, latent_dim)

print("Encoded latent shapes:", Z_imgs.shape, Z_caps.shape)

def retrieval_metrics(Z_caps, Z_imgs, caption_to_image_idx):
    sims = cosine_similarity(Z_caps, Z_imgs)  # (num_caps, num_imgs)
    ranks = []
    for i, true_img_idx in enumerate(caption_to_image_idx):
        sim_scores = sims[i]
        sorted_indices = np.argsort(-sim_scores)  # descending
        rank = np.where(sorted_indices == true_img_idx)[0][0] + 1
        ranks.append(rank)

    ranks = np.array(ranks)
    recall_at_1  = np.mean(ranks <= 1)
    recall_at_5  = np.mean(ranks <= 5)
    recall_at_10 = np.mean(ranks <= 10)
    med_rank = np.median(ranks)

    return {
        "Recall@1": recall_at_1,
        "Recall@5": recall_at_5,
        "Recall@10": recall_at_10,
        "MedianRank": med_rank
    }

metrics_val = retrieval_metrics(Z_caps, Z_imgs, cap2img_val)
for k, v in metrics_val.items():
    print(f"{k}: {v:.4f}")

Encoded latent shapes: (1214, 512) (6070, 512)
Recall@1: 0.0007
Recall@5: 0.0035
Recall@10: 0.0086
MedianRank: 600.0000


In [None]:
# Prepare data for visualization: get actual caption texts and image filenames for validation set
# Extract the validation caption texts from the original dataframe
val_caption_texts = df["caption"].values[val_mask]

# Extract the validation image filenames
val_image_names = np.array([image_names[i] for i in val_idx])

print(f"Loaded {len(val_caption_texts)} caption texts and {len(val_image_names)} image filenames for validation")
print(f"Example caption: {val_caption_texts[0]}")
print(f"Example image: {val_image_names[0]}")

In [None]:
#Quick visualization of what images are retrieved by what caption:
image_dir = "/Users/sfowler14/Downloads/archive/Images"  # Update this if your images are elsewhere

def show_top_images_for_caption(caption_idx, top_k=5):
    """
    Show top-k retrieved validation images for a given caption index.
    Also displays the true image for comparison.
    """
    # Get the embedding for this caption
    caption_embedding = Z_caps[caption_idx].reshape(1, -1)
    sims = cosine_similarity(caption_embedding, Z_imgs)[0]
    top_img_indices = np.argsort(-sims)[:top_k]

    # Print caption text
    print(f"\nCAPTION: {val_caption_texts[caption_idx]}")
    true_img_idx = cap2img_val[caption_idx]
    print(f"TRUE IMAGE: {val_image_names[true_img_idx]} (index {true_img_idx})")
    
    # Display: true image + top k retrieved images
    plt.figure(figsize=(18, 4))
    
    # Show the true image first
    true_img_name = val_image_names[true_img_idx]
    true_img_path = os.path.join(image_dir, true_img_name)
    try:
        true_img = Image.open(true_img_path)
        plt.subplot(1, top_k + 1, 1)
        plt.imshow(true_img)
        plt.axis('off')
        plt.title("TRUE IMAGE", fontweight='bold', color='green')
    except Exception as e:
        print(f"Could not open true image {true_img_path}: {e}")
    
    # Show retrieved images
    for i, img_idx in enumerate(top_img_indices):
        img_name = val_image_names[img_idx]
        img_path = os.path.join(image_dir, img_name)
        try:
            img = Image.open(img_path)
        except Exception as e:
            print(f"Could not open {img_path}: {e}")
            continue
        
        plt.subplot(1, top_k + 1, i + 2)
        plt.imshow(img)
        plt.axis('off')
        
        # Highlight if this retrieved image matches the true image
        if img_idx == true_img_idx:
            plt.title(f"Rank {i+1} ✓", fontweight='bold', color='green')
        else:
            plt.title(f"Rank {i+1}")
    
    plt.tight_layout()
    plt.show()

In [None]:
import random
for i in random.sample(range(len(caption_val)), 3):
    show_top_images_for_caption(i, top_k=5)