In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
import os
import random
from typing import List
import matplotlib.pyplot as plt
from google.colab import drive


In [2]:
# Constants
embedding_dim = 1536
max_patches = 1000  # fixed length for padding
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Paths
gcs_bucket = "gs://bracs-dataset-bucket/Embeddings/train"
local_embedding_root = Path("/content/embeddings")
local_embedding_root.mkdir(parents=True, exist_ok=True)

In [34]:
# 📦 Step 1: Install dependencies
!pip install --quiet openslide-python
!apt-get install -y -qq openslide-tools
!pip install --upgrade google-cloud-storage

# 📂 Step 2: Set up GCS access
from google.colab import auth
auth.authenticate_user()

from google.cloud import storage
from pathlib import Path
import os



Selecting previously unselected package libopenslide0.
(Reading database ... 126102 files and directories currently installed.)
Preparing to unpack .../libopenslide0_3.4.1+dfsg-5build1_amd64.deb ...
Unpacking libopenslide0 (3.4.1+dfsg-5build1) ...
Selecting previously unselected package openslide-tools.
Preparing to unpack .../openslide-tools_3.4.1+dfsg-5build1_amd64.deb ...
Unpacking openslide-tools (3.4.1+dfsg-5build1) ...
Setting up libopenslide0 (3.4.1+dfsg-5build1) ...
Setting up openslide-tools (3.4.1+dfsg-5build1) ...
Processing triggers for man-db (2.10.2-1) ...
Processing triggers for libc-bin (2.35-0ubuntu3.8) ...
/sbin/ldconfig.real: /usr/local/lib/libtbbbind.so.3 is not a symbolic link

/sbin/ldconfig.real: /usr/local/lib/libtbbmalloc_proxy.so.2 is not a symbolic link

/sbin/ldconfig.real: /usr/local/lib/libtbbmalloc.so.2 is not a symbolic link

/sbin/ldconfig.real: /usr/local/lib/libhwloc.so.15 is not a symbolic link

/sbin/ldconfig.real: /usr/local/lib/libtcm.so.1 is not 

In [77]:
import subprocess
from pathlib import Path

# Set GCS and local paths
bucket_path = "gs://bracs-dataset-bucket/Embeddings/train"
local_path = Path("/content/embeddings/train")
local_path.mkdir(parents=True, exist_ok=True)

# Run the gsutil copy command (recursive, parallel)
print("🔽 Downloading all embeddings from GCS...")
subprocess.run([
    "gsutil", "-m", "cp", "-r", f"{bucket_path}/*", str(local_path)
], check=True)

print("✅ All embeddings downloaded to:", local_path)

🔽 Downloading all embeddings from GCS...
✅ All embeddings downloaded to: /content/embeddings/train


In [83]:
import os
import random
import torch
from pathlib import Path
from torch.utils.data import Dataset

class WSIPairDataset(Dataset):
    def __init__(self, embedding_root: str, label_dict: dict, seed: int = 42):
        """
        Dataset for training retrieval models using WSI tile embeddings.

        Args:
            embedding_root (str): Path to embedding root containing subfolders for each WSI.
            label_dict (dict): Dict mapping slide_id to list or set of label indices.
            seed (int): Random seed for reproducibility.
        """
        self.embedding_root = Path(embedding_root)
        self.label_dict = {k: set(v) for k, v in label_dict.items()}


        self.slide_ids = [
            slide_id for slide_id in self.label_dict.keys()
            if (self.embedding_root / slide_id / f"{slide_id}_embeddings.pt").exists()
        ]

        if not self.slide_ids:
            raise ValueError("No valid slides found in embedding_root.")

        random.seed(seed)

    def __len__(self):
        return 50  # Use high number for infinite-style sampling

    def __getitem__(self, idx):
        # Randomly sample query and document (≠ query)
        query_id = random.choice(self.slide_ids)
        doc_id = random.choice([sid for sid in self.slide_ids if sid != query_id])

        # Load embeddings
        query_data = torch.load(self.embedding_root / query_id / f"{query_id}_embeddings.pt")
        doc_data = torch.load(self.embedding_root / doc_id / f"{doc_id}_embeddings.pt")

        q_emb = query_data["embeddings"]  # [m, 1536]
        q_coord = query_data["coords"]    # [m, 2]

        d_emb = doc_data["embeddings"]    # [n, 1536]
        d_coord = doc_data["coords"]      # [n, 2]

        # Compute Jaccard Index
        labels_q = self.label_dict[query_id]
        labels_d = self.label_dict[doc_id]

        if not labels_q and not labels_d:
            jaccard = 1.0
        else:
            jaccard = len(labels_q & labels_d) / len(labels_q | labels_d)

        return {
            "query_id": query_id,
            "doc_id": doc_id,
            "query_embeds": q_emb,    # shape [m, d]
            "doc_embeds": d_emb,      # shape [n, d]
            "query_coords": q_coord,  # shape [m, 2]
            "doc_coords": d_coord,    # shape [n, 2]
            "jaccard": torch.tensor(jaccard, dtype=torch.float)
        }


In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
import random


In [72]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimilarityAggregator(nn.Module):
    def __init__(self, patch_count_max: int, hidden_dim: int = 256):
        super().__init__()
        self.row_mlp = nn.Sequential(
            nn.Linear(patch_count_max, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
        self.col_mlp = nn.Sequential(
            nn.Linear(patch_count_max, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, sim_matrix):
        """
        sim_matrix: shape [m, n] where m = # query patches, n = # doc patches
        returns scalar similarity score
        """
        row_scores = self.row_mlp(sim_matrix)      # [m, 1]
        col_scores = self.col_mlp(sim_matrix.T)    # [n, 1]

        score = row_scores.sum() + col_scores.sum()
        return score


In [44]:
def jaccard_contrastive_loss(sim_score, jaccard_sim, margin=0.3):
    """
    sim_score: scalar output of similarity aggregator
    jaccard_sim: float, similarity label in [0, 1]
    """
    sim = torch.sigmoid(sim_score)

    positive_term = jaccard_sim * (1 - sim) ** 2
    negative_term = (1 - jaccard_sim) * F.relu(sim - margin) ** 2
    return positive_term + negative_term


In [79]:
csv_content = """slide_id,subtypes
BRACS_1379,"ADH,FEA"
BRACS_1486,"PB,UDH,FEA,ADH"
BRACS_1494,"N,PB,UDH,FEA,ADH"
BRACS_1499,"PB,UDH,ADH"
BRACS_1616,"ADH,UDH"
BRACS_1622,"N,B,UDH,ADH"
BRACS_1794,"ADH,UDH"
BRACS_1795,"PB,UDH,ADH,FEA"
BRACS_1003728,"ADH"
"""

with open("labelset.csv", "w") as f:
    f.write(csv_content)

print("✅ labelset.csv written")

# Define the subtype vocabulary and index mapping
subtype_vocab = ["N", "PB", "UDH", "FEA", "ADH", "DCIS", "IC"]
subtype_to_idx = {label: idx for idx, label in enumerate(subtype_vocab)}

import pandas as pd
# Function to load the labels.csv file
def load_label_dict(csv_path="labelset.csv"):
    df = pd.read_csv(csv_path)
    label_dict = {}
    for _, row in df.iterrows():
        slide_id = row["slide_id"]
        subtypes = row["subtypes"].split(",")
        indices = [subtype_to_idx[s.strip()] for s in subtypes if s.strip() in subtype_to_idx]
        label_dict[slide_id] = set(indices)
    return label_dict

# Load it once
label_dict = load_label_dict()

✅ labelset.csv written


In [74]:
def preprocess_embeddings(query_embeds, doc_embeds, max_patches=500):
    """
    Truncate or pad query and doc embeddings to (max_patches, dim).
    This ensures similarity matrix is max_patches × max_patches.
    """
    d = query_embeds.size(1)  # embedding dim

    def adjust(x):
        n = x.size(0)
        if n >= max_patches:
            return x[:max_patches]
        else:
            pad = torch.zeros((max_patches - n, d), device=x.device, dtype=x.dtype)
            return torch.cat([x, pad], dim=0)

    return adjust(query_embeds), adjust(doc_embeds)

In [85]:
from torch.utils.data import DataLoader
import tqdm

# Hyperparameters
EPOCHS = 50
BATCH_SIZE = 1                # one WSI pair at a time
MAX_PATCHES = 500             # for padding/truncating
LEARNING_RATE = 1e-4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Dataset
dataset = WSIPairDataset("/content/embeddings/train", label_dict)
loader = DataLoader(dataset, batch_size=BATCH_SIZE)

# Model
model = SimilarityAggregator(patch_count_max=MAX_PATCHES).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)


In [86]:
for epoch in range(EPOCHS):
    print(f"\n- Epoch {epoch + 1}/{EPOCHS}")
    epoch_loss = 0

    for batch in tqdm.tqdm(loader, total=50):  # process 50 pairs per epoch
        q_embeds = batch["query_embeds"][0][:MAX_PATCHES].to(DEVICE)         # [m, d]
        d_embeds = batch["doc_embeds"][0][:MAX_PATCHES].to(DEVICE)           # [n, d]
        jaccard = batch["jaccard"].to(DEVICE)

        q_embeds = batch["query_embeds"][0].to(DEVICE)   # [m, d]
        d_embeds = batch["doc_embeds"][0].to(DEVICE)     # [n, d]

        # Normalize first
        q_norm = F.normalize(q_embeds, p=2, dim=1)
        d_norm = F.normalize(d_embeds, p=2, dim=1)

        # Truncate/pad both sides
        q_norm, d_norm = preprocess_embeddings(q_norm, d_norm, MAX_PATCHES)  # [500, d]

        # Compute similarity and forward pass
        sim_matrix = q_norm @ d_norm.T                                       # [500, 500]
        score = model(sim_matrix)
        similarity = torch.sigmoid(score)

        loss = jaccard_contrastive_loss(similarity, jaccard)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    print(f"📉 Epoch Loss: {epoch_loss:.4f}")



- Epoch 1/50


100%|██████████| 50/50 [00:08<00:00,  6.22it/s]


📉 Epoch Loss: 6.2027

- Epoch 2/50


100%|██████████| 50/50 [00:07<00:00,  6.93it/s]


📉 Epoch Loss: 5.9756

- Epoch 3/50


100%|██████████| 50/50 [00:08<00:00,  6.00it/s]


📉 Epoch Loss: 5.9837

- Epoch 4/50


100%|██████████| 50/50 [00:08<00:00,  5.58it/s]


📉 Epoch Loss: 6.2022

- Epoch 5/50


100%|██████████| 50/50 [00:08<00:00,  5.79it/s]


📉 Epoch Loss: 5.8691

- Epoch 6/50


100%|██████████| 50/50 [00:09<00:00,  5.13it/s]


📉 Epoch Loss: 6.1012

- Epoch 7/50


100%|██████████| 50/50 [00:06<00:00,  7.19it/s]


📉 Epoch Loss: 5.7606

- Epoch 8/50


100%|██████████| 50/50 [00:08<00:00,  5.95it/s]


📉 Epoch Loss: 5.3981

- Epoch 9/50


100%|██████████| 50/50 [00:10<00:00,  4.93it/s]


📉 Epoch Loss: 5.8037

- Epoch 10/50


100%|██████████| 50/50 [00:10<00:00,  4.98it/s]


📉 Epoch Loss: 5.9537

- Epoch 11/50


100%|██████████| 50/50 [00:08<00:00,  5.85it/s]


📉 Epoch Loss: 6.2428

- Epoch 12/50


100%|██████████| 50/50 [00:09<00:00,  5.32it/s]


📉 Epoch Loss: 5.8913

- Epoch 13/50


100%|██████████| 50/50 [00:10<00:00,  4.93it/s]


📉 Epoch Loss: 6.0928

- Epoch 14/50


100%|██████████| 50/50 [00:09<00:00,  5.10it/s]


📉 Epoch Loss: 6.1156

- Epoch 15/50


100%|██████████| 50/50 [00:08<00:00,  6.22it/s]


📉 Epoch Loss: 6.2958

- Epoch 16/50


100%|██████████| 50/50 [00:09<00:00,  5.23it/s]


📉 Epoch Loss: 6.1713

- Epoch 17/50


100%|██████████| 50/50 [00:09<00:00,  5.50it/s]


📉 Epoch Loss: 6.2253

- Epoch 18/50


100%|██████████| 50/50 [00:08<00:00,  5.92it/s]


📉 Epoch Loss: 6.0446

- Epoch 19/50


100%|██████████| 50/50 [00:07<00:00,  6.32it/s]


📉 Epoch Loss: 6.4250

- Epoch 20/50


100%|██████████| 50/50 [00:08<00:00,  6.18it/s]


📉 Epoch Loss: 5.9499

- Epoch 21/50


100%|██████████| 50/50 [00:08<00:00,  5.59it/s]


📉 Epoch Loss: 6.3169

- Epoch 22/50


100%|██████████| 50/50 [00:08<00:00,  5.70it/s]


📉 Epoch Loss: 6.4052

- Epoch 23/50


100%|██████████| 50/50 [00:08<00:00,  5.91it/s]


📉 Epoch Loss: 6.3007

- Epoch 24/50


100%|██████████| 50/50 [00:10<00:00,  4.74it/s]


📉 Epoch Loss: 6.0630

- Epoch 25/50


100%|██████████| 50/50 [00:09<00:00,  5.27it/s]


📉 Epoch Loss: 6.1023

- Epoch 26/50


100%|██████████| 50/50 [00:09<00:00,  5.05it/s]


📉 Epoch Loss: 6.3695

- Epoch 27/50


100%|██████████| 50/50 [00:09<00:00,  5.32it/s]


📉 Epoch Loss: 5.9592

- Epoch 28/50


100%|██████████| 50/50 [00:09<00:00,  5.26it/s]


📉 Epoch Loss: 5.8614

- Epoch 29/50


100%|██████████| 50/50 [00:08<00:00,  5.86it/s]


📉 Epoch Loss: 6.3575

- Epoch 30/50


100%|██████████| 50/50 [00:09<00:00,  5.55it/s]


📉 Epoch Loss: 6.3106

- Epoch 31/50


100%|██████████| 50/50 [00:08<00:00,  5.62it/s]


📉 Epoch Loss: 6.3301

- Epoch 32/50


100%|██████████| 50/50 [00:10<00:00,  4.94it/s]


📉 Epoch Loss: 6.0537

- Epoch 33/50


100%|██████████| 50/50 [00:09<00:00,  5.55it/s]


📉 Epoch Loss: 6.1112

- Epoch 34/50


100%|██████████| 50/50 [00:08<00:00,  6.17it/s]


📉 Epoch Loss: 6.3635

- Epoch 35/50


100%|██████████| 50/50 [00:08<00:00,  6.21it/s]


📉 Epoch Loss: 6.0769

- Epoch 36/50


100%|██████████| 50/50 [00:08<00:00,  6.01it/s]


📉 Epoch Loss: 6.2348

- Epoch 37/50


100%|██████████| 50/50 [00:09<00:00,  5.23it/s]


📉 Epoch Loss: 5.9941

- Epoch 38/50


100%|██████████| 50/50 [00:08<00:00,  5.75it/s]


📉 Epoch Loss: 5.8844

- Epoch 39/50


100%|██████████| 50/50 [00:07<00:00,  6.31it/s]


📉 Epoch Loss: 6.4056

- Epoch 40/50


100%|██████████| 50/50 [00:08<00:00,  5.74it/s]


📉 Epoch Loss: 6.4146

- Epoch 41/50


100%|██████████| 50/50 [00:08<00:00,  6.25it/s]


📉 Epoch Loss: 6.2887

- Epoch 42/50


100%|██████████| 50/50 [00:09<00:00,  5.19it/s]


📉 Epoch Loss: 5.9952

- Epoch 43/50


100%|██████████| 50/50 [00:08<00:00,  5.61it/s]


📉 Epoch Loss: 6.2741

- Epoch 44/50


100%|██████████| 50/50 [00:07<00:00,  6.55it/s]


📉 Epoch Loss: 6.1203

- Epoch 45/50


100%|██████████| 50/50 [00:10<00:00,  4.72it/s]


📉 Epoch Loss: 5.9408

- Epoch 46/50


100%|██████████| 50/50 [00:08<00:00,  5.70it/s]


📉 Epoch Loss: 6.0765

- Epoch 47/50


100%|██████████| 50/50 [00:08<00:00,  6.17it/s]


📉 Epoch Loss: 6.0465

- Epoch 48/50


100%|██████████| 50/50 [00:07<00:00,  6.59it/s]


📉 Epoch Loss: 6.3155

- Epoch 49/50


100%|██████████| 50/50 [00:08<00:00,  5.64it/s]


📉 Epoch Loss: 6.1140

- Epoch 50/50


100%|██████████| 50/50 [00:08<00:00,  5.62it/s]

📉 Epoch Loss: 6.3249



