In [81]:
import os
import zipfile
import pandas as pd
import torch
import torch.nn as nn
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
import torchvision.transforms as transforms
from torch.optim import Adam
from tqdm import tqdm
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics.pairwise import cosine_similarity
import matplotlib.pyplot as plt

In [2]:
# Path to your zip file
zip_path = os.path.expanduser('~/Downloads/archive(1).zip')
extract_dir = './flickr8k_data'

# Extract only if not already done
if not os.path.exists(extract_dir):
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(extract_dir)

print("Extracted to:", extract_dir)
print("Contents:", os.listdir(extract_dir))

Extracted to: ./flickr8k_data
Contents: ['captions.txt', 'Images']


In [3]:
captions_path = os.path.join(extract_dir, 'captions.txt')
df = pd.read_csv(captions_path)

print(df.head())
print(f"\nNumber of unique images: {df['image'].nunique()}")

                       image  \
0  1000268201_693b08cb0e.jpg   
1  1000268201_693b08cb0e.jpg   
2  1000268201_693b08cb0e.jpg   
3  1000268201_693b08cb0e.jpg   
4  1000268201_693b08cb0e.jpg   

                                             caption  
0  A child in a pink dress is climbing up a set o...  
1              A girl going into a wooden building .  
2   A little girl climbing into a wooden playhouse .  
3  A little girl climbing the stairs to her playh...  
4  A little girl in a pink dress going into a woo...  

Number of unique images: 8091


In [4]:
# --- Image transform (standard ImageNet normalization) ---
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# --- Dataset class for images ---
class FlickrImageDataset(Dataset):
    def __init__(self, image_dir, image_filenames, transform=None):
        self.image_dir = image_dir
        self.image_filenames = list(image_filenames)
        self.transform = transform

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        img_name = self.image_filenames[idx]
        path = os.path.join(self.image_dir, img_name)
        # wrap in try-except to catch corrupt images
        try:
            image = Image.open(path).convert("RGB")
        except Exception as e:
            # if image fails to open, create a black image instead and log
            print(f"Failed to open {path}: {e}")
            image = Image.new('RGB', (224,224))
        if self.transform:
            image = self.transform(image)
        return image, img_name

In [5]:
# --- Model setup ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

resnet = models.resnet50(pretrained=True)
resnet = torch.nn.Sequential(*(list(resnet.children())[:-1]))  # remove FC layer
resnet.to(device)
resnet.eval()

# --- Create dataset and dataloader ---
image_dir = os.path.join(extract_dir, "Images")
unique_images = df["image"].unique()
image_dataset = FlickrImageDataset(image_dir, unique_images, transform)
image_loader = DataLoader(image_dataset, batch_size=32, shuffle=False, num_workers=2)



In [11]:
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])

# test model forward
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
resnet = models.resnet50(pretrained=True)
resnet = torch.nn.Sequential(*(list(resnet.children())[:-1])).to(device).eval()



In [13]:
image_filenames = sorted(os.listdir(image_dir))   # better to use df['image'].unique() if you want same order
dataset = FlickrImageDataset(image_dir, image_filenames, transform)
loader = DataLoader(dataset, batch_size=32, shuffle=False, num_workers=0)  # num_workers=0 is safest

image_features_list = []
image_names = []

In [15]:
with torch.no_grad():
    for imgs, names in tqdm(loader, desc="Extracting image features"):
        imgs = imgs.to(device)
        feats = resnet(imgs)               # (B, 2048, 1, 1)
        feats = feats.view(feats.size(0), -1)  # (B, 2048)
        image_features_list.append(feats.cpu())
        image_names.extend(names)

image_features = torch.cat(image_features_list, dim=0).numpy()
print("Done. image_features shape:", image_features.shape)

Extracting image features: 100%|██████████| 253/253 [03:10<00:00,  1.33it/s]

Done. image_features shape: (8091, 2048)





In [17]:
from transformers import BertTokenizer, BertModel

# --- Device and model setup ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert = BertModel.from_pretrained('bert-base-uncased')
bert.to(device)
bert.eval()

# --- Dataset for captions ---
class FlickrCaptionDataset(Dataset):
    def __init__(self, captions):
        self.captions = captions
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, idx):
        text = self.captions[idx]
        # Return plain text; we’ll tokenize in collate_fn for batching
        return text

# --- Custom collate_fn to batch tokenize ---
def collate_fn(batch_texts):
    return tokenizer(batch_texts, return_tensors='pt',
                     truncation=True, padding=True, max_length=64)


In [19]:
# --- Create dataset and dataloader ---
caption_dataset = FlickrCaptionDataset(df["caption"].tolist())
caption_loader = DataLoader(
    caption_dataset,
    batch_size=32,         
    shuffle=False,
    num_workers=0,         
    collate_fn=collate_fn  
)

In [21]:
# --- Extract features ---
caption_features = []

with torch.no_grad():
    for batch in tqdm(caption_loader, desc="Extracting caption features"):
        inputs = {k: v.to(device) for k, v in batch.items()}
        outputs = bert(**inputs)
        cls_embeddings = outputs.last_hidden_state[:, 0, :]  # (B, 768)
        caption_features.append(cls_embeddings.cpu())

caption_features = torch.cat(caption_features, dim=0).numpy()

print("Caption features shape:", caption_features.shape)

Extracting caption features: 100%|██████████| 1265/1265 [05:01<00:00,  4.19it/s]

Caption features shape: (40455, 768)





In [23]:
# Map image filename to its index
image_to_idx = {name: i for i, name in enumerate(image_names)}

# For each caption row, find which image it corresponds to
caption_to_image_idx = df["image"].map(image_to_idx).values

print("caption_to_image_idx shape:", caption_to_image_idx.shape)
print("Example mapping:", list(zip(df['caption'][:3], caption_to_image_idx[:3])))

caption_to_image_idx shape: (40455,)
Example mapping: [('A child in a pink dress is climbing up a set of stairs in an entry way .', 0), ('A girl going into a wooden building .', 0), ('A little girl climbing into a wooden playhouse .', 0)]


In [29]:
np.save("flickr8k_image_features.npy", image_features)
np.save("flickr8k_caption_features.npy", caption_features)
np.save("flickr8k_caption_to_image.npy", caption_to_image_idx)
np.save("flickr8k_image_names.npy", np.array(image_names))

In [33]:
# Split on images
n_images = len(image_names)
indices = np.arange(n_images)

In [39]:
# 70/15/15 split
train_idx, temp_idx = train_test_split(indices, test_size=0.30, random_state=42, shuffle=True)
val_idx, test_idx = train_test_split(temp_idx, test_size=0.5, random_state=42, shuffle=True)
print(f"Train images: {len(train_idx)}, Val: {len(val_idx)}, Test: {len(test_idx)}")

Train images: 5663, Val: 1214, Test: 1214


In [41]:
# Create masks for captions 
caption_to_image_idx = caption_to_image_idx.astype(int)

train_mask = np.isin(caption_to_image_idx, train_idx)
val_mask = np.isin(caption_to_image_idx, val_idx)
test_mask = np.isin(caption_to_image_idx, test_idx)

# Split image features 
image_train = image_features[train_idx]
image_val   = image_features[val_idx]
image_test  = image_features[test_idx]

# Split captions (and keep their alignment)
caption_train = caption_features[train_mask]
caption_val   = caption_features[val_mask]
caption_test  = caption_features[test_mask]

In [43]:
# Link captions to local image indices within each split
def remap_caption_indices(global_indices, split_indices):
    """
    Convert global image indices in caption_to_image_idx to 0..len(split_indices)-1 within that split.
    """
    mapping = {g: i for i, g in enumerate(split_indices)}
    return np.array([mapping[i] for i in global_indices if i in mapping])

In [45]:
caption_to_train_img = remap_caption_indices(caption_to_image_idx[train_mask], train_idx)
caption_to_val_img   = remap_caption_indices(caption_to_image_idx[val_mask], val_idx)
caption_to_test_img  = remap_caption_indices(caption_to_image_idx[test_mask], test_idx)

In [47]:
# Sanity check
print("Train split shapes:")
print("  Image features:", image_train.shape)
print("  Caption features:", caption_train.shape)
print("  Caption→Image indices:", caption_to_train_img.shape)

Train split shapes:
  Image features: (5663, 2048)
  Caption features: (28315, 768)
  Caption→Image indices: (28315,)


In [49]:
np.save("train_image_features.npy", image_train)
np.save("val_image_features.npy", image_val)
np.save("test_image_features.npy", image_test)

np.save("train_caption_features.npy", caption_train)
np.save("val_caption_features.npy", caption_val)
np.save("test_caption_features.npy", caption_test)

np.save("train_caption_to_image.npy", caption_to_train_img)
np.save("val_caption_to_image.npy", caption_to_val_img)
np.save("test_caption_to_image.npy", caption_to_test_img)

In [57]:
# Now time to set up and train the actual multi-modal autoencoder
config = {
    "image_feat_path": "train_image_features.npy",   # we'll load proper files below
    "caption_feat_path": "train_caption_features.npy",
    "caption_to_image_path": "train_caption_to_image.npy",
    "val_image_feat_path": "val_image_features.npy",
    "val_caption_feat_path": "val_caption_features.npy",
    "val_caption_to_image_path": "val_caption_to_image.npy",
    "latent_dim": 512,
    "img_input_dim": 2048,
    "txt_input_dim": 768,
    "img_hidden": 1024,
    "txt_hidden": 512,
    "batch_size": 128,    # try 128; lower if memory limited (e.g., 64)
    "lr": 1e-3,
    "weight_decay": 1e-5,
    "epochs": 40,
    "lambda_align": 1.0,  # weight for latent alignment loss; tuneable
    "checkpoint_dir": "./corr_ae_checkpoints",
    "seed": 42,
}

In [59]:
torch.manual_seed(config["seed"])
np.random.seed(config["seed"])

os.makedirs(config["checkpoint_dir"], exist_ok=True)

In [61]:
#I don't need to reload these if we run it all in the same notebook but I'm pasting the load functions here anyway
image_train = np.load(config["image_feat_path"])
caption_train = np.load(config["caption_feat_path"])
cap2img_train = np.load(config["caption_to_image_path"])

image_val = np.load(config["val_image_feat_path"])
caption_val = np.load(config["val_caption_feat_path"])
cap2img_val = np.load(config["val_caption_to_image_path"])

print("Shapes (train):", image_train.shape, caption_train.shape, cap2img_train.shape)
print("Shapes (val):", image_val.shape, caption_val.shape, cap2img_val.shape)

Shapes (train): (5663, 2048) (28315, 768) (28315,)
Shapes (val): (1214, 2048) (6070, 768) (6070,)


In [63]:
# 3) Compute train-set normalization (mean/std) and apply to all splits
# Normalize per-feature (column-wise) using training set statistics
img_mean = image_train.mean(axis=0, keepdims=True)
img_std = image_train.std(axis=0, keepdims=True) + 1e-6

txt_mean = caption_train.mean(axis=0, keepdims=True)
txt_std = caption_train.std(axis=0, keepdims=True) + 1e-6

In [65]:
def normalize_images(x):
    return (x - img_mean) / img_std

def normalize_texts(x):
    return (x - txt_mean) / txt_std

image_train = normalize_images(image_train)
image_val   = normalize_images(image_val)

caption_train = normalize_texts(caption_train)
caption_val   = normalize_texts(caption_val)

In [67]:
# 4) Dataset that returns paired (image_feat, caption_feat) for each caption
class CaptionImagePairedDataset(Dataset):
    """
    Iterates over captions. For index i, returns:
      caption_features[i], image_features[ caption_to_image_idx[i] ]
    """
    def __init__(self, caption_feats, image_feats, caption_to_image_idx):
        assert len(caption_feats) == len(caption_to_image_idx)
        self.caption_feats = caption_feats.astype(np.float32)
        self.image_feats = image_feats.astype(np.float32)
        self.cap2img = caption_to_image_idx.astype(np.int64)

    def __len__(self):
        return len(self.caption_feats)

    def __getitem__(self, idx):
        cap = self.caption_feats[idx]
        img = self.image_feats[self.cap2img[idx]]
        return {"image": torch.from_numpy(img), "caption": torch.from_numpy(cap)}

train_dataset = CaptionImagePairedDataset(caption_train, image_train, cap2img_train)
val_dataset   = CaptionImagePairedDataset(caption_val, image_val, cap2img_val)

train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True, num_workers=0)
val_loader   = DataLoader(val_dataset, batch_size=config["batch_size"], shuffle=False, num_workers=0)

In [73]:
# 5) Model: two autoencoders with shared latent dimension
class ImageAE(nn.Module):
    def __init__(self, input_dim=2048, hidden_dim=1024, latent_dim=512):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim)
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim)
        )

    def forward(self, x):
        z = self.encoder(x)
        recon = self.decoder(z)
        return z, recon

class TextAE(nn.Module):
    def __init__(self, input_dim=768, hidden_dim=512, latent_dim=512):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim)
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim)
        )

    def forward(self, x):
        z = self.encoder(x)
        recon = self.decoder(z)
        return z, recon

In [75]:
# Instantiate models
img_ae = ImageAE(
    input_dim=config["img_input_dim"],
    hidden_dim=config["img_hidden"],
    latent_dim=config["latent_dim"]
).to(device)

txt_ae = TextAE(
    input_dim=config["txt_input_dim"],
    hidden_dim=config["txt_hidden"],
    latent_dim=config["latent_dim"]
).to(device)

# 6) Losses and optimizer
recon_loss_fn = nn.MSELoss()    # reconstruction for both
align_loss_fn = nn.MSELoss()    # align latents

params = list(img_ae.parameters()) + list(txt_ae.parameters())
optimizer = Adam(params, lr=config["lr"], weight_decay=config["weight_decay"])

In [77]:
# 7) Training / validation loop
def run_epoch(loader, training=True):
    if training:
        img_ae.train(); txt_ae.train()
    else:
        img_ae.eval(); txt_ae.eval()

    total_recon_img = 0.0
    total_recon_txt = 0.0
    total_align = 0.0
    total_loss = 0.0
    n_samples = 0

    pbar = tqdm(loader, desc="train" if training else "val")
    with torch.set_grad_enabled(training):
        for batch in pbar:
            imgs = batch["image"].to(device)    # shape (B, img_dim)
            caps = batch["caption"].to(device)  # shape (B, txt_dim)
            batch_size = imgs.shape[0]

            # forward
            z_img, img_recon = img_ae(imgs)
            z_txt, txt_recon = txt_ae(caps)

            # losses
            L_img = recon_loss_fn(img_recon, imgs)
            L_txt = recon_loss_fn(txt_recon, caps)
            L_align = align_loss_fn(z_img, z_txt)

            loss = L_img + L_txt + config["lambda_align"] * L_align

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            total_recon_img += L_img.item() * batch_size
            total_recon_txt += L_txt.item() * batch_size
            total_align += L_align.item() * batch_size
            total_loss += loss.item() * batch_size
            n_samples += batch_size

            pbar.set_postfix({
                "loss": f"{total_loss / n_samples:.4f}",
                "Limg": f"{total_recon_img / n_samples:.4f}",
                "Ltxt": f"{total_recon_txt / n_samples:.4f}",
                "Lalign": f"{total_align / n_samples:.4f}"
            })

    return {
        "loss": total_loss / n_samples,
        "Limg": total_recon_img / n_samples,
        "Ltxt": total_recon_txt / n_samples,
        "Lalign": total_align / n_samples
    }

In [79]:
best_val_loss = float("inf")

for epoch in range(1, config["epochs"] + 1):
    print(f"\n=== Epoch {epoch}/{config['epochs']} ===")
    train_metrics = run_epoch(train_loader, training=True)
    val_metrics = run_epoch(val_loader, training=False)

    print(f"Train loss: {train_metrics['loss']:.4f} | Val loss: {val_metrics['loss']:.4f}")

    # Save checkpoint (every epoch)
    ckpt = {
        "epoch": epoch,
        "img_state": img_ae.state_dict(),
        "txt_state": txt_ae.state_dict(),
        "optimizer": optimizer.state_dict(),
        "train_metrics": train_metrics,
        "val_metrics": val_metrics,
        "config": config
    }
    ckpt_path = os.path.join(config["checkpoint_dir"], f"corr_ae_epoch{epoch}.pt")
    torch.save(ckpt, ckpt_path)

    # Keep best
    if val_metrics["loss"] < best_val_loss:
        best_val_loss = val_metrics["loss"]
        torch.save(ckpt, os.path.join(config["checkpoint_dir"], "corr_ae_best.pt"))
        print("Saved best checkpoint.")

print("Training finished.")


=== Epoch 1/40 ===


train: 100%|██████████| 222/222 [00:05<00:00, 42.26it/s, loss=0.9268, Limg=0.4700, Ltxt=0.3571, Lalign=0.0997]
val: 100%|██████████| 48/48 [00:00<00:00, 116.35it/s, loss=0.6464, Limg=0.3226, Ltxt=0.2340, Lalign=0.0897]


Train loss: 0.9268 | Val loss: 0.6464
Saved best checkpoint.

=== Epoch 2/40 ===


train: 100%|██████████| 222/222 [00:04<00:00, 47.15it/s, loss=0.5338, Limg=0.2490, Ltxt=0.2050, Lalign=0.0798]
val: 100%|██████████| 48/48 [00:00<00:00, 114.00it/s, loss=0.5140, Limg=0.2508, Ltxt=0.1895, Lalign=0.0738]


Train loss: 0.5338 | Val loss: 0.5140
Saved best checkpoint.

=== Epoch 3/40 ===


train: 100%|██████████| 222/222 [00:04<00:00, 46.68it/s, loss=0.4416, Limg=0.2014, Ltxt=0.1732, Lalign=0.0670]
val: 100%|██████████| 48/48 [00:00<00:00, 110.97it/s, loss=0.4616, Limg=0.2276, Ltxt=0.1696, Lalign=0.0643]


Train loss: 0.4416 | Val loss: 0.4616
Saved best checkpoint.

=== Epoch 4/40 ===


train: 100%|██████████| 222/222 [00:04<00:00, 47.91it/s, loss=0.3897, Limg=0.1776, Ltxt=0.1530, Lalign=0.0592]
val: 100%|██████████| 48/48 [00:00<00:00, 130.01it/s, loss=0.4230, Limg=0.2118, Ltxt=0.1534, Lalign=0.0578]


Train loss: 0.3897 | Val loss: 0.4230
Saved best checkpoint.

=== Epoch 5/40 ===


train: 100%|██████████| 222/222 [00:04<00:00, 48.13it/s, loss=0.3593, Limg=0.1646, Ltxt=0.1404, Lalign=0.0543]
val: 100%|██████████| 48/48 [00:00<00:00, 109.35it/s, loss=0.3924, Limg=0.2004, Ltxt=0.1386, Lalign=0.0534]


Train loss: 0.3593 | Val loss: 0.3924
Saved best checkpoint.

=== Epoch 6/40 ===


train: 100%|██████████| 222/222 [00:04<00:00, 47.70it/s, loss=0.3339, Limg=0.1543, Ltxt=0.1290, Lalign=0.0506]
val: 100%|██████████| 48/48 [00:00<00:00, 111.17it/s, loss=0.3802, Limg=0.1980, Ltxt=0.1317, Lalign=0.0504]


Train loss: 0.3339 | Val loss: 0.3802
Saved best checkpoint.

=== Epoch 7/40 ===


train: 100%|██████████| 222/222 [00:04<00:00, 48.45it/s, loss=0.3163, Limg=0.1485, Ltxt=0.1197, Lalign=0.0481]
val: 100%|██████████| 48/48 [00:00<00:00, 111.88it/s, loss=0.3626, Limg=0.1934, Ltxt=0.1213, Lalign=0.0479]


Train loss: 0.3163 | Val loss: 0.3626
Saved best checkpoint.

=== Epoch 8/40 ===


train: 100%|██████████| 222/222 [00:04<00:00, 47.10it/s, loss=0.3018, Limg=0.1432, Ltxt=0.1124, Lalign=0.0462]
val: 100%|██████████| 48/48 [00:00<00:00, 97.88it/s, loss=0.3533, Limg=0.1909, Ltxt=0.1161, Lalign=0.0464]


Train loss: 0.3018 | Val loss: 0.3533
Saved best checkpoint.

=== Epoch 9/40 ===


train: 100%|██████████| 222/222 [00:04<00:00, 46.66it/s, loss=0.2883, Limg=0.1386, Ltxt=0.1050, Lalign=0.0447]
val: 100%|██████████| 48/48 [00:00<00:00, 110.59it/s, loss=0.3379, Limg=0.1872, Ltxt=0.1061, Lalign=0.0446]


Train loss: 0.2883 | Val loss: 0.3379
Saved best checkpoint.

=== Epoch 10/40 ===


train: 100%|██████████| 222/222 [00:04<00:00, 48.29it/s, loss=0.2769, Limg=0.1345, Ltxt=0.0989, Lalign=0.0434]
val: 100%|██████████| 48/48 [00:00<00:00, 111.83it/s, loss=0.3334, Limg=0.1871, Ltxt=0.1027, Lalign=0.0435]


Train loss: 0.2769 | Val loss: 0.3334
Saved best checkpoint.

=== Epoch 11/40 ===


train: 100%|██████████| 222/222 [00:04<00:00, 48.35it/s, loss=0.2675, Limg=0.1315, Ltxt=0.0935, Lalign=0.0426]
val: 100%|██████████| 48/48 [00:00<00:00, 106.79it/s, loss=0.3213, Limg=0.1851, Ltxt=0.0937, Lalign=0.0425]


Train loss: 0.2675 | Val loss: 0.3213
Saved best checkpoint.

=== Epoch 12/40 ===


train: 100%|██████████| 222/222 [00:04<00:00, 48.08it/s, loss=0.2609, Limg=0.1297, Ltxt=0.0893, Lalign=0.0420]
val: 100%|██████████| 48/48 [00:00<00:00, 108.92it/s, loss=0.3142, Limg=0.1839, Ltxt=0.0879, Lalign=0.0423]


Train loss: 0.2609 | Val loss: 0.3142
Saved best checkpoint.

=== Epoch 13/40 ===


train: 100%|██████████| 222/222 [00:04<00:00, 48.53it/s, loss=0.2524, Limg=0.1264, Ltxt=0.0847, Lalign=0.0413]
val: 100%|██████████| 48/48 [00:00<00:00, 114.46it/s, loss=0.3065, Limg=0.1790, Ltxt=0.0861, Lalign=0.0414]


Train loss: 0.2524 | Val loss: 0.3065
Saved best checkpoint.

=== Epoch 14/40 ===


train: 100%|██████████| 222/222 [00:04<00:00, 46.92it/s, loss=0.2467, Limg=0.1249, Ltxt=0.0810, Lalign=0.0407]
val: 100%|██████████| 48/48 [00:00<00:00, 86.83it/s, loss=0.3114, Limg=0.1811, Ltxt=0.0892, Lalign=0.0411]


Train loss: 0.2467 | Val loss: 0.3114

=== Epoch 15/40 ===


train: 100%|██████████| 222/222 [00:04<00:00, 47.60it/s, loss=0.2414, Limg=0.1230, Ltxt=0.0780, Lalign=0.0404]
val: 100%|██████████| 48/48 [00:00<00:00, 114.80it/s, loss=0.3132, Limg=0.1773, Ltxt=0.0954, Lalign=0.0404]


Train loss: 0.2414 | Val loss: 0.3132

=== Epoch 16/40 ===


train: 100%|██████████| 222/222 [00:04<00:00, 48.51it/s, loss=0.2366, Limg=0.1204, Ltxt=0.0762, Lalign=0.0400]
val: 100%|██████████| 48/48 [00:00<00:00, 111.87it/s, loss=0.3002, Limg=0.1755, Ltxt=0.0844, Lalign=0.0403]


Train loss: 0.2366 | Val loss: 0.3002
Saved best checkpoint.

=== Epoch 17/40 ===


train: 100%|██████████| 222/222 [00:04<00:00, 48.14it/s, loss=0.2324, Limg=0.1199, Ltxt=0.0727, Lalign=0.0399]
val: 100%|██████████| 48/48 [00:00<00:00, 106.69it/s, loss=0.2845, Limg=0.1731, Ltxt=0.0716, Lalign=0.0398]


Train loss: 0.2324 | Val loss: 0.2845
Saved best checkpoint.

=== Epoch 18/40 ===


train: 100%|██████████| 222/222 [00:04<00:00, 47.54it/s, loss=0.2267, Limg=0.1163, Ltxt=0.0710, Lalign=0.0393]
val: 100%|██████████| 48/48 [00:00<00:00, 119.36it/s, loss=0.2853, Limg=0.1747, Ltxt=0.0710, Lalign=0.0396]


Train loss: 0.2267 | Val loss: 0.2853

=== Epoch 19/40 ===


train: 100%|██████████| 222/222 [00:04<00:00, 48.02it/s, loss=0.2242, Limg=0.1165, Ltxt=0.0685, Lalign=0.0392]
val: 100%|██████████| 48/48 [00:00<00:00, 113.21it/s, loss=0.2837, Limg=0.1734, Ltxt=0.0709, Lalign=0.0394]


Train loss: 0.2242 | Val loss: 0.2837
Saved best checkpoint.

=== Epoch 20/40 ===


train: 100%|██████████| 222/222 [00:04<00:00, 46.76it/s, loss=0.2215, Limg=0.1147, Ltxt=0.0676, Lalign=0.0392]
val: 100%|██████████| 48/48 [00:00<00:00, 97.89it/s, loss=0.2884, Limg=0.1777, Ltxt=0.0712, Lalign=0.0395]


Train loss: 0.2215 | Val loss: 0.2884

=== Epoch 21/40 ===


train: 100%|██████████| 222/222 [00:05<00:00, 44.35it/s, loss=0.2202, Limg=0.1149, Ltxt=0.0661, Lalign=0.0393]
val: 100%|██████████| 48/48 [00:00<00:00, 117.63it/s, loss=0.2816, Limg=0.1757, Ltxt=0.0666, Lalign=0.0393]


Train loss: 0.2202 | Val loss: 0.2816
Saved best checkpoint.

=== Epoch 22/40 ===


train: 100%|██████████| 222/222 [00:04<00:00, 47.45it/s, loss=0.2154, Limg=0.1120, Ltxt=0.0644, Lalign=0.0389]
val: 100%|██████████| 48/48 [00:00<00:00, 113.98it/s, loss=0.2793, Limg=0.1730, Ltxt=0.0673, Lalign=0.0390]


Train loss: 0.2154 | Val loss: 0.2793
Saved best checkpoint.

=== Epoch 23/40 ===


train: 100%|██████████| 222/222 [00:04<00:00, 47.42it/s, loss=0.2136, Limg=0.1108, Ltxt=0.0641, Lalign=0.0388]
val: 100%|██████████| 48/48 [00:00<00:00, 104.69it/s, loss=0.2799, Limg=0.1739, Ltxt=0.0668, Lalign=0.0393]


Train loss: 0.2136 | Val loss: 0.2799

=== Epoch 24/40 ===


train: 100%|██████████| 222/222 [00:04<00:00, 46.82it/s, loss=0.2100, Limg=0.1103, Ltxt=0.0609, Lalign=0.0387]
val: 100%|██████████| 48/48 [00:00<00:00, 110.96it/s, loss=0.2750, Limg=0.1729, Ltxt=0.0633, Lalign=0.0388]


Train loss: 0.2100 | Val loss: 0.2750
Saved best checkpoint.

=== Epoch 25/40 ===


train: 100%|██████████| 222/222 [00:04<00:00, 47.69it/s, loss=0.2114, Limg=0.1113, Ltxt=0.0612, Lalign=0.0389]
val: 100%|██████████| 48/48 [00:00<00:00, 102.85it/s, loss=0.2793, Limg=0.1690, Ltxt=0.0711, Lalign=0.0393]


Train loss: 0.2114 | Val loss: 0.2793

=== Epoch 26/40 ===


train: 100%|██████████| 222/222 [00:04<00:00, 48.80it/s, loss=0.2079, Limg=0.1082, Ltxt=0.0611, Lalign=0.0387]
val: 100%|██████████| 48/48 [00:00<00:00, 122.12it/s, loss=0.2748, Limg=0.1705, Ltxt=0.0651, Lalign=0.0392]


Train loss: 0.2079 | Val loss: 0.2748
Saved best checkpoint.

=== Epoch 27/40 ===


train: 100%|██████████| 222/222 [00:04<00:00, 49.03it/s, loss=0.2057, Limg=0.1080, Ltxt=0.0591, Lalign=0.0386]
val: 100%|██████████| 48/48 [00:00<00:00, 130.02it/s, loss=0.2739, Limg=0.1708, Ltxt=0.0641, Lalign=0.0390]


Train loss: 0.2057 | Val loss: 0.2739
Saved best checkpoint.

=== Epoch 28/40 ===


train: 100%|██████████| 222/222 [00:04<00:00, 49.24it/s, loss=0.2045, Limg=0.1078, Ltxt=0.0580, Lalign=0.0387]
val: 100%|██████████| 48/48 [00:00<00:00, 138.62it/s, loss=0.2735, Limg=0.1696, Ltxt=0.0650, Lalign=0.0389]


Train loss: 0.2045 | Val loss: 0.2735
Saved best checkpoint.

=== Epoch 29/40 ===


train: 100%|██████████| 222/222 [00:04<00:00, 48.66it/s, loss=0.2020, Limg=0.1054, Ltxt=0.0580, Lalign=0.0385]
val: 100%|██████████| 48/48 [00:00<00:00, 109.11it/s, loss=0.2719, Limg=0.1698, Ltxt=0.0633, Lalign=0.0389]


Train loss: 0.2020 | Val loss: 0.2719
Saved best checkpoint.

=== Epoch 30/40 ===


train: 100%|██████████| 222/222 [00:04<00:00, 47.33it/s, loss=0.2016, Limg=0.1053, Ltxt=0.0579, Lalign=0.0385]
val: 100%|██████████| 48/48 [00:00<00:00, 115.56it/s, loss=0.2679, Limg=0.1685, Ltxt=0.0606, Lalign=0.0388]


Train loss: 0.2016 | Val loss: 0.2679
Saved best checkpoint.

=== Epoch 31/40 ===


train: 100%|██████████| 222/222 [00:04<00:00, 47.84it/s, loss=0.2002, Limg=0.1053, Ltxt=0.0563, Lalign=0.0385]
val: 100%|██████████| 48/48 [00:00<00:00, 125.28it/s, loss=0.2698, Limg=0.1701, Ltxt=0.0609, Lalign=0.0388]


Train loss: 0.2002 | Val loss: 0.2698

=== Epoch 32/40 ===


train: 100%|██████████| 222/222 [00:04<00:00, 47.63it/s, loss=0.1970, Limg=0.1033, Ltxt=0.0554, Lalign=0.0384]
val: 100%|██████████| 48/48 [00:00<00:00, 114.55it/s, loss=0.2625, Limg=0.1655, Ltxt=0.0586, Lalign=0.0385]


Train loss: 0.1970 | Val loss: 0.2625
Saved best checkpoint.

=== Epoch 33/40 ===


train: 100%|██████████| 222/222 [00:04<00:00, 46.00it/s, loss=0.1982, Limg=0.1043, Ltxt=0.0555, Lalign=0.0385]
val: 100%|██████████| 48/48 [00:00<00:00, 104.15it/s, loss=0.2699, Limg=0.1715, Ltxt=0.0595, Lalign=0.0389]


Train loss: 0.1982 | Val loss: 0.2699

=== Epoch 34/40 ===


train: 100%|██████████| 222/222 [00:04<00:00, 46.78it/s, loss=0.1973, Limg=0.1030, Ltxt=0.0558, Lalign=0.0385]
val: 100%|██████████| 48/48 [00:00<00:00, 118.48it/s, loss=0.2631, Limg=0.1671, Ltxt=0.0572, Lalign=0.0387]


Train loss: 0.1973 | Val loss: 0.2631

=== Epoch 35/40 ===


train: 100%|██████████| 222/222 [00:04<00:00, 47.34it/s, loss=0.1950, Limg=0.1030, Ltxt=0.0535, Lalign=0.0385]
val: 100%|██████████| 48/48 [00:00<00:00, 110.22it/s, loss=0.2648, Limg=0.1681, Ltxt=0.0580, Lalign=0.0387]


Train loss: 0.1950 | Val loss: 0.2648

=== Epoch 36/40 ===


train: 100%|██████████| 222/222 [00:04<00:00, 46.88it/s, loss=0.1928, Limg=0.1009, Ltxt=0.0536, Lalign=0.0382]
val: 100%|██████████| 48/48 [00:00<00:00, 119.23it/s, loss=0.2637, Limg=0.1667, Ltxt=0.0584, Lalign=0.0386]


Train loss: 0.1928 | Val loss: 0.2637

=== Epoch 37/40 ===


train: 100%|██████████| 222/222 [00:04<00:00, 47.52it/s, loss=0.1929, Limg=0.1010, Ltxt=0.0536, Lalign=0.0383]
val: 100%|██████████| 48/48 [00:00<00:00, 116.62it/s, loss=0.2615, Limg=0.1643, Ltxt=0.0586, Lalign=0.0385]


Train loss: 0.1929 | Val loss: 0.2615
Saved best checkpoint.

=== Epoch 38/40 ===


train: 100%|██████████| 222/222 [00:04<00:00, 47.45it/s, loss=0.1920, Limg=0.1010, Ltxt=0.0527, Lalign=0.0383]
val: 100%|██████████| 48/48 [00:00<00:00, 114.42it/s, loss=0.2672, Limg=0.1679, Ltxt=0.0605, Lalign=0.0388]


Train loss: 0.1920 | Val loss: 0.2672

=== Epoch 39/40 ===


train: 100%|██████████| 222/222 [00:04<00:00, 48.06it/s, loss=0.1927, Limg=0.1002, Ltxt=0.0541, Lalign=0.0384]
val: 100%|██████████| 48/48 [00:00<00:00, 123.06it/s, loss=0.2606, Limg=0.1656, Ltxt=0.0562, Lalign=0.0389]


Train loss: 0.1927 | Val loss: 0.2606
Saved best checkpoint.

=== Epoch 40/40 ===


train: 100%|██████████| 222/222 [00:04<00:00, 48.57it/s, loss=0.1897, Limg=0.0996, Ltxt=0.0518, Lalign=0.0383]
val: 100%|██████████| 48/48 [00:00<00:00, 119.64it/s, loss=0.2623, Limg=0.1653, Ltxt=0.0581, Lalign=0.0388]


Train loss: 0.1897 | Val loss: 0.2623
Training finished.


In [83]:
#Now time to evaluate on the validation set
#Probably don't need to reload the model, but I'm going to include the code again in case we break this up into more managable files
# --- Load best checkpoint ---
best_ckpt_path = os.path.join(config["checkpoint_dir"], "corr_ae_best.pt")
ckpt = torch.load(best_ckpt_path, map_location=device)

img_ae.load_state_dict(ckpt["img_state"])
txt_ae.load_state_dict(ckpt["txt_state"])
img_ae.eval(); txt_ae.eval()

print(f"Loaded best checkpoint from {best_ckpt_path} (epoch {ckpt['epoch']})")

Loaded best checkpoint from ./corr_ae_checkpoints\corr_ae_best.pt (epoch 39)


In [85]:
# Encode into latent space
with torch.no_grad():
    # Encode images
    Z_imgs = []
    for i in range(0, image_val.shape[0], 256):
        batch = torch.from_numpy(image_val[i:i+256]).float().to(device)
        z, _ = img_ae(batch)
        Z_imgs.append(z.cpu().numpy())
    Z_imgs = np.concatenate(Z_imgs, axis=0)   # shape (N_images, latent_dim)

    # Encode captions
    Z_caps = []
    for i in range(0, caption_val.shape[0], 256):
        batch = torch.from_numpy(caption_val[i:i+256]).float().to(device)
        z, _ = txt_ae(batch)
        Z_caps.append(z.cpu().numpy())
    Z_caps = np.concatenate(Z_caps, axis=0)   # shape (N_captions, latent_dim)

print("Encoded latent shapes:", Z_imgs.shape, Z_caps.shape)

Encoded latent shapes: (1214, 512) (6070, 512)


In [87]:
#Use Recall@1/5/10 to evaluate hyperparameter performance
#Note that we are using cosine similarity
#Should we consider using L2 metric instead? Does this even make sense?
def retrieval_metrics(Z_caps, Z_imgs, caption_to_image_idx):
    sims = cosine_similarity(Z_caps, Z_imgs)  # (num_caps, num_imgs)
    ranks = []
    for i, true_img_idx in enumerate(caption_to_image_idx):
        sim_scores = sims[i]
        sorted_indices = np.argsort(-sim_scores)  # descending
        rank = np.where(sorted_indices == true_img_idx)[0][0] + 1
        ranks.append(rank)

    ranks = np.array(ranks)
    recall_at_1  = np.mean(ranks <= 1)
    recall_at_5  = np.mean(ranks <= 5)
    recall_at_10 = np.mean(ranks <= 10)
    med_rank = np.median(ranks)

    return {
        "Recall@1": recall_at_1,
        "Recall@5": recall_at_5,
        "Recall@10": recall_at_10,
        "MedianRank": med_rank
    }

metrics_val = retrieval_metrics(Z_caps, Z_imgs, cap2img_val)
for k, v in metrics_val.items():
    print(f"{k}: {v:.4f}")

Recall@1: 0.0524
Recall@5: 0.1633
Recall@10: 0.2540
MedianRank: 38.0000


In [93]:
#Quick visualization of what images are retrieved by what caption:
image_dir = "path/to/Flickr8k/Images"  # e.g. "/content/Flickr8k/Images"

def show_top_images_for_caption(caption_idx, top_k=5):
    """
    Show top-k retrieved validation images for a given caption index.
    """
    # Get the embedding for this caption
    caption_embedding = Z_caps[caption_idx].reshape(1, -1)
    sims = cosine_similarity(caption_embedding, Z_imgs)[0]
    top_img_indices = np.argsort(-sims)[:top_k]

    # Print caption text
    print(f"\nCAPTION: {caption_val[caption_idx]}")
    true_img_idx = cap2img_val[caption_idx]
    print(f"TRUE IMAGE: {image_val[true_img_idx]} (index {true_img_idx})")
    print(f"Top {top_k} retrieved images:")

    # Display images
    plt.figure(figsize=(15, 4))
    for i, img_idx in enumerate(top_img_indices):
        img_name = image_val[img_idx]  # this assumes image_val contains filenames
        img_path = os.path.join(image_dir, img_name)
        try:
            img = Image.open(img_path)
        except Exception as e:
            print(f"Could not open {img_path}: {e}")
            continue
        plt.subplot(1, top_k, i+1)
        plt.imshow(img)
        plt.axis('off')
        plt.title(f"Rank {i+1}")
    plt.show()

In [95]:
for i in random.sample(range(len(caption_val)), 3):
    show_top_images_for_caption(i, top_k=5)


CAPTION: [-4.40105468e-01  2.62257129e-01 -1.82232273e+00  1.71197605e+00
  1.61076617e+00  1.87989962e+00 -1.62736058e+00  1.18145227e+00
 -2.05603933e+00  1.02855766e+00  1.01329672e+00 -7.16462910e-01
 -9.71074820e-01  8.97090316e-01 -3.40079486e-01  6.10017776e-01
 -9.27111983e-01 -1.01340437e+00 -4.90179360e-01 -2.48719901e-01
 -4.23995793e-01  1.04763293e+00 -1.71742356e+00  6.69775307e-01
 -4.73519862e-01  1.22269928e+00  6.58022404e-01 -8.49948704e-01
  1.39369681e-01  1.06526160e+00  3.10615599e-01 -1.74752653e+00
  7.66426086e-01  6.06213868e-01 -3.13661933e-01 -1.16403329e+00
 -1.45711675e-01  2.63895333e-01 -1.61355817e+00 -9.26913738e-01
  1.97226977e+00  3.24242890e-01 -5.30107796e-01 -1.21021128e+00
  1.24355778e-01 -5.41140616e-01  5.43014467e-01  9.64143932e-01
  1.34140587e+00  8.16520154e-01 -7.91564763e-01 -5.19880056e-01
  5.68725288e-01  3.08898352e-02  6.28872156e-01  1.61347345e-01
  1.35990477e+00 -6.41580582e-01  2.30055881e+00  4.66104656e-01
 -3.53251338e-0

TypeError: join() argument must be str, bytes, or os.PathLike object, not 'ndarray'

<Figure size 1500x400 with 0 Axes>