In [21]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
import pandas as pd
import numpy as np

# === Dataset ===
class TileCaptionDatasetFromNpy(Dataset):
    """Pairs precomputed tile (image) embeddings with captions for CLIP-style training.

    Embeddings and filenames are loaded from two parallel ``.npy`` files (row ``i`` of the
    embeddings belongs to ``filenames[i]``). Captions come from a CSV with ``filepath`` and
    ``caption`` columns. Only tiles whose filename has a non-missing caption are kept.
    Captions are tokenized lazily in ``__getitem__``.

    Args:
        embedding_path: path to the ``.npy`` array of tile embeddings.
        filename_path: path to the ``.npy`` array of tile filenames, same length as the
            embeddings.
        caption_csv: CSV file with at least ``filepath`` and ``caption`` columns.
        tokenizer_name: Hugging Face tokenizer loaded when ``tokenizer`` is not given.
        max_length: token length every caption is padded/truncated to.
        path_replace: ``(old, new)`` literal substring substitution applied to CSV
            filepaths so they match the stored filenames, or ``None`` to use them verbatim.
        tokenizer: optional pre-built tokenizer (a callable with the Hugging Face tokenizer
            call signature); when given, ``tokenizer_name`` is not loaded.

    Raises:
        ValueError: if the embedding and filename arrays differ in length.
    """

    def __init__(
        self,
        embedding_path,
        filename_path,
        caption_csv,
        tokenizer_name="emilyalsentzer/Bio_ClinicalBERT",
        max_length=64,
        path_replace=("/train_250k/", "/train/"),
        tokenizer=None,
    ):
        # NOTE(review): allow_pickle=True lets a crafted .npy execute arbitrary code on load;
        # only point this at trusted, locally produced files.
        self.embeddings = np.load(embedding_path, allow_pickle=True)  # presumably [N, embed_dim]
        self.filenames = np.load(filename_path, allow_pickle=True)    # presumably [N] strings
        if len(self.embeddings) != len(self.filenames):
            raise ValueError(
                f"embeddings ({len(self.embeddings)}) and filenames "
                f"({len(self.filenames)}) must have the same length"
            )
        self.max_length = max_length

        # Build the filepath -> caption mapping (duplicate filepaths: last row wins).
        caption_df = pd.read_csv(caption_csv)
        if path_replace is not None:
            old, new = path_replace
            caption_df["filepath"] = caption_df["filepath"].str.replace(old, new, regex=False)
        # Rows without a caption would otherwise hand NaN to the tokenizer and fail there.
        caption_df = caption_df.dropna(subset=["caption"])
        caption_map = dict(zip(caption_df["filepath"], caption_df["caption"].astype(str)))

        # Keep only tiles that have both an embedding and a caption, in filename order.
        self.samples = [
            (self.embeddings[i], caption_map[str(fname)])
            for i, fname in enumerate(self.filenames)
            if str(fname) in caption_map
        ]

        print(f"[INFO] Matched {len(self.samples)} image-caption pairs out of {len(self.embeddings)}")

        self.tokenizer = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained(tokenizer_name)

    def __len__(self):
        """Number of matched (embedding, caption) pairs."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the tile embedding and the tokenized caption for sample ``idx``.

        Returns:
            dict with ``"tile"`` (float32 tensor), and ``"input_ids"`` / ``"attention_mask"``
            (1-D tensors of length ``max_length``).
        """
        tile_embed, caption = self.samples[idx]
        tokens = self.tokenizer(
            caption,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )
        # return_tensors="pt" yields a leading batch dim of 1; drop it so DataLoader can stack.
        return {
            "tile": torch.tensor(tile_embed, dtype=torch.float32),
            "input_ids": tokens["input_ids"].squeeze(0),
            "attention_mask": tokens["attention_mask"].squeeze(0)
        }

# === ClinicalBERT Text Encoder ===
class ClinicalBERTEmbedder(nn.Module):
    """Text encoder: a pretrained transformer plus a linear projection of its first token.

    Args:
        out_dim: size of the projected text embedding. It must equal the size of the tile
            embeddings it is contrasted against, because the loss takes their dot product.
        model_name: Hugging Face model loaded with ``AutoModel.from_pretrained``.
    """

    def __init__(self, out_dim=128, model_name="emilyalsentzer/Bio_ClinicalBERT"):
        super().__init__()
        self.model = AutoModel.from_pretrained(model_name)
        # Take the backbone width from its config rather than hard-coding 768,
        # so a different (larger/smaller) backbone does not break the projection.
        hidden_size = self.model.config.hidden_size
        self.proj = nn.Linear(hidden_size, out_dim)

    def forward(self, input_ids, attention_mask):
        """Encode token ids into ``[batch, out_dim]`` embeddings.

        Args:
            input_ids: token id tensor, presumably ``[batch, seq_len]``.
            attention_mask: mask tensor with the same shape as ``input_ids``.
        """
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        # Position 0 is the [CLS] token for BERT-style tokenizers such as the default model.
        cls_token = output.last_hidden_state[:, 0]
        return self.proj(cls_token)

# === CLIP-style Contrastive Loss ===
class CLIPLoss(nn.Module):
    """Symmetric contrastive (InfoNCE) loss as in CLIP, with a learnable temperature.

    Row ``i`` of the image batch and row ``i`` of the text batch form the positive pair.
    Every other row in the batch serves as a negative.

    Args:
        temperature: initial softmax temperature. The learnable parameter stores
            ``log(1 / temperature)``.
        max_logit_scale: upper bound applied to ``exp(logit_scale)`` in ``forward`` (CLIP
            uses 100) so the learnable scale cannot diverge. ``None`` disables the bound.
            While clamped, ``logit_scale`` receives no gradient.
    """

    def __init__(self, temperature=0.07, max_logit_scale=100.0):
        super().__init__()
        self.logit_scale = nn.Parameter(torch.tensor(np.log(1 / temperature), dtype=torch.float32))
        # Plain attribute (not a buffer) so state_dict keys stay unchanged.
        self.max_log_scale = None if max_logit_scale is None else float(np.log(max_logit_scale))

    def forward(self, img_embeds, text_embeds):
        """Return the mean of the image->text and text->image cross-entropy losses.

        Args:
            img_embeds: ``[batch, dim]`` image embeddings (normalized here).
            text_embeds: ``[batch, dim]`` text embeddings (normalized here).

        Raises:
            ValueError: if the two batches have different sizes.
        """
        if img_embeds.size(0) != text_embeds.size(0):
            raise ValueError(
                f"batch size mismatch: {img_embeds.size(0)} images vs {text_embeds.size(0)} texts"
            )
        img_embeds = F.normalize(img_embeds, dim=-1)
        text_embeds = F.normalize(text_embeds, dim=-1)

        log_scale = self.logit_scale
        if self.max_log_scale is not None:
            log_scale = log_scale.clamp(max=self.max_log_scale)
        logit_scale = log_scale.exp()

        logits_per_image = img_embeds @ text_embeds.t() * logit_scale
        # Text->image logits are exactly the transpose; avoid a second matmul.
        logits_per_text = logits_per_image.t()

        targets = torch.arange(img_embeds.size(0), device=img_embeds.device)

        loss_i2t = F.cross_entropy(logits_per_image, targets)
        loss_t2i = F.cross_entropy(logits_per_text, targets)
        return (loss_i2t + loss_t2i) / 2

# === Training Loop ===
def train_one_epoch(dataloader, text_encoder, loss_fn, optimizer, device):
    """Run one optimization pass over ``dataloader`` and return the mean batch loss.

    Only the text side is encoded here. ``batch["tile"]`` is passed to ``loss_fn`` as-is,
    as the image-side embedding.

    Args:
        dataloader: iterable of dicts with ``"tile"``, ``"input_ids"`` and
            ``"attention_mask"`` tensors. It need not define ``__len__``; batches are
            counted while iterating.
        text_encoder: module called as ``text_encoder(input_ids=..., attention_mask=...)``.
        loss_fn: callable ``loss_fn(tile, text_embeds)`` returning a scalar tensor.
        optimizer: optimizer stepped once per batch after ``loss.backward()``.
        device: device every batch tensor is moved to.

    Returns:
        Unweighted mean of the per-batch losses, or ``nan`` if ``dataloader`` yielded no
        batches.
    """
    text_encoder.train()
    total_loss = 0.0
    num_batches = 0

    for batch in dataloader:
        tile = batch["tile"].to(device)
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        text_embeds = text_encoder(input_ids=input_ids, attention_mask=attention_mask)
        loss = loss_fn(tile, text_embeds)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    # An empty loader (e.g. drop_last with fewer samples than one batch) would otherwise
    # raise ZeroDivisionError.
    return total_loss / num_batches if num_batches else float("nan")

In [None]:
# === Main ===
def main(
    embedding_path="/gpfs/home/yb2612/dl4med_25/dl_project/results/hpl/train/image_embeddings.npy",
    filename_path="/gpfs/home/yb2612/dl4med_25/dl_project/results/hpl/train/image_filenames.npy",
    caption_csv="/gpfs/home/yb2612/dl4med_25/dl_project/data/scratch_data/hpl-clip/long_consistent_captions/lung_250k_filepath_caption.csv",
    batch_size=64,
    num_epochs=5,
    lr=2e-5,
    seed=42,
    checkpoint_path=None,
):
    """Train the ClinicalBERT text encoder against fixed tile embeddings with a CLIP loss.

    Args:
        embedding_path: ``.npy`` file of precomputed tile embeddings.
        filename_path: ``.npy`` file of tile filenames parallel to ``embedding_path``.
        caption_csv: CSV with ``filepath`` and ``caption`` columns.
        batch_size: number of in-batch pairs; every other pair in a batch is a negative.
        num_epochs: number of passes over the data.
        lr: AdamW learning rate for the text encoder and the loss's logit scale.
        seed: seed for torch (CPU and CUDA) and NumPy, for reproducible shuffling/init.
        checkpoint_path: if given, the trained text-encoder and loss state dicts are saved
            there with ``torch.save`` after training.

    Returns:
        The trained text encoder.

    Raises:
        RuntimeError: if no tile embedding matched a caption.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Dataset + Dataloader
    dataset = TileCaptionDatasetFromNpy(embedding_path, filename_path, caption_csv)
    if len(dataset) == 0:
        raise RuntimeError("No tile embeddings matched a caption; check the filepath mapping.")
    # The text projection must output the same width as the fixed tile embeddings.
    tile_dim = int(np.asarray(dataset.samples[0][0]).shape[-1])
    # Drop a short trailing batch: it has fewer in-batch negatives, which skews the
    # contrastive loss (a batch of 1 gives a trivial zero loss). Keep it only when the
    # whole dataset is smaller than one batch, so an epoch is never empty.
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=len(dataset) >= batch_size,
    )

    # Model & Loss
    text_encoder = ClinicalBERTEmbedder(out_dim=tile_dim).to(device)
    loss_fn = CLIPLoss().to(device)
    optimizer = torch.optim.AdamW(list(text_encoder.parameters()) + list(loss_fn.parameters()), lr=lr)

    # Train
    for epoch in range(num_epochs):
        avg_loss = train_one_epoch(dataloader, text_encoder, loss_fn, optimizer, device)
        print(f"[Epoch {epoch+1}/{num_epochs}] Loss: {avg_loss:.4f}")

    # Persist the result; otherwise the trained weights are lost when the function returns.
    if checkpoint_path is not None:
        os.makedirs(os.path.dirname(checkpoint_path) or ".", exist_ok=True)
        torch.save(
            {"text_encoder": text_encoder.state_dict(), "loss_fn": loss_fn.state_dict()},
            checkpoint_path,
        )

    return text_encoder

# In a Jupyter kernel __name__ is "__main__", so running this cell starts training.
if __name__ == "__main__":
    main()

[INFO] Matched 236635 image-caption pairs out of 551613
[Epoch 1/5] Loss: 3.2077
