In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
import os
import random
from typing import List
import matplotlib.pyplot as plt
from google.colab import drive
!pip install gcsfs




In [5]:
!pip install -q gcsfs

# Step 2: Authenticate to access Google Cloud Storage
from google.colab import auth
auth.authenticate_user()

# Step 3: Use gcsfs to interact with your bucket
import gcsfs

# Replace with your actual project ID if needed
fs = gcsfs.GCSFileSystem()

# Path to the training embeddings folder
embedding_base_path = 'bracs-dataset-bucket/patch-embeddings/test'

# Step 4: List WSI folders
wsi_folders = fs.ls(embedding_base_path)

# Only keep directories (some may include .pt files directly)
wsi_dirs = [path for path in wsi_folders]

print(f"✅ Number of WSI folders: {len(wsi_dirs)}\n")

# Print first few folder names
print("📂 Sample WSI Folders:")
for folder in wsi_dirs[:10]:
    print("-", folder)


✅ Number of WSI folders: 63

📂 Sample WSI Folders:
- bracs-dataset-bucket/patch-embeddings/test/BRACS_1003691_embeddings.pt
- bracs-dataset-bucket/patch-embeddings/test/BRACS_1003694_embeddings.pt
- bracs-dataset-bucket/patch-embeddings/test/BRACS_1228_embeddings.pt
- bracs-dataset-bucket/patch-embeddings/test/BRACS_1283_embeddings.pt
- bracs-dataset-bucket/patch-embeddings/test/BRACS_1330_embeddings.pt
- bracs-dataset-bucket/patch-embeddings/test/BRACS_1334_embeddings.pt
- bracs-dataset-bucket/patch-embeddings/test/BRACS_1412_embeddings.pt
- bracs-dataset-bucket/patch-embeddings/test/BRACS_1416_embeddings.pt
- bracs-dataset-bucket/patch-embeddings/test/BRACS_1473_embeddings.pt
- bracs-dataset-bucket/patch-embeddings/test/BRACS_1474_embeddings.pt


In [7]:
class SeparableLITEScorer(nn.Module):
    """Scores a query/document pair from their patch-similarity matrix.

    The (query x doc) similarity matrix ``S`` is mixed first along the doc axis
    (an MLP applied to every row), then along the query axis (an MLP applied to
    every column). The mixed matrix is flattened and linearly projected to a
    single scalar.

    ``S`` must be exactly ``(max_query_patches, max_doc_patches)``: the
    LayerNorm/Linear widths and the final projection are sized from those values.
    """

    def __init__(self, max_query_patches=256, max_doc_patches=256, hidden_dim=128):
        super().__init__()
        self.max_query_patches = max_query_patches
        self.max_doc_patches = max_doc_patches

        # Creation order (row MLP, column MLP, projection) and attribute names
        # determine the state_dict keys and the seeded initialisation, so keep them.
        self.row_mlp = self._axis_mlp(max_doc_patches, hidden_dim)    # acts on each row (doc axis)
        self.col_mlp = self._axis_mlp(max_query_patches, hidden_dim)  # acts on each column (query axis)
        self.final_proj = nn.Linear(max_query_patches * max_doc_patches, 1)

    @staticmethod
    def _axis_mlp(width, hidden_dim):
        """LayerNorm -> Linear -> ReLU -> Linear bottleneck over the last axis of size ``width``."""
        return nn.Sequential(
            nn.LayerNorm(width),
            nn.Linear(width, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, width),
        )

    def forward(self, S, q_mask, d_mask):
        """Return a 0-d score tensor for the similarity matrix ``S`` of shape [m, n].

        ``q_mask`` and ``d_mask`` are accepted for interface compatibility but are
        NOT used: padded rows/columns of ``S`` flow through the MLPs unchanged.
        """
        mixed_rows = self.row_mlp(S)                    # [m, n]
        mixed_both = self.col_mlp(mixed_rows.T).T       # [m, n]
        flattened = mixed_both.flatten().unsqueeze(0)   # [1, m*n]
        return self.final_proj(flattened).squeeze()
import torch

def prepare_patches(embeds, max_len=256):
    """
    Truncates or zero-pads patch embeddings to exactly ``max_len`` rows.

    Args:
        embeds: 2-D tensor of shape (L, D), one row per patch.
        max_len: number of rows in the result.

    Returns:
        Tensor of shape (max_len, D). If L > max_len the first ``max_len`` rows are
        kept; otherwise zero rows are appended. The result keeps ``embeds``'s
        dtype and device.
    """
    L, D = embeds.shape
    if L > max_len:
        return embeds[:max_len]
    # new_zeros inherits dtype *and* device. A plain torch.zeros(...) is float32,
    # which forces type promotion in torch.cat (e.g. float16 -> float32), or an
    # error on older PyTorch.
    padding = embeds.new_zeros((max_len - L, D))
    return torch.cat([embeds, padding], dim=0)

def prepare_mask(length, max_len=256):
    """
    Builds a boolean validity mask of shape (max_len,).

    The first ``min(length, max_len)`` entries are True (real patches); the
    remaining entries are False (padding).
    """
    n_valid = min(length, max_len)
    mask = torch.zeros(max_len, dtype=torch.bool)
    mask[:n_valid] = True
    return mask


In [55]:
# Download model from GCS manually first (if not already)
!gsutil cp gs://bracs-dataset-bucket/checkpoints/scorer_epoch50.pt lite_scorer.pt
import torch

# Re-declare model class (SeparableLITEScorer) here as before
model = SeparableLITEScorer()
model.load_state_dict(torch.load("lite_scorer.pt", map_location=torch.device("cpu")))
model.eval()
print("✅ Model loaded successfully.")

Copying gs://bracs-dataset-bucket/checkpoints/scorer_epoch50.pt...
- [1 files][779.5 KiB/779.5 KiB]                                                
Operation completed over 1 objects/779.5 KiB.                                    
✅ Model loaded successfully.


In [11]:
import pandas as pd

# Slide-level table: sheet "WSI_Information" of the BRACS spreadsheet.
xls = pd.ExcelFile("BRACS_BRACS.xlsx")
df_info = pd.read_excel(xls, "WSI_Information")

# The seven lesion classes; their position in this list is the class index (0..6).
lesions = ['N', 'PB', 'UDH', 'FEA', 'ADH', 'DCIS', 'IC']
label2idx = {lesion: idx for idx, lesion in enumerate(lesions)}

# WSI filename -> class index. Rows whose 'WSI label' is not one of the seven
# classes are skipped; on duplicate filenames the later row wins.
global_labels = {
    wsi_name: label2idx[wsi_label]
    for wsi_name, wsi_label in zip(df_info['WSI Filename'], df_info['WSI label'])
    if wsi_label in label2idx
}


In [12]:
import gcsfs
import torch

fs = gcsfs.GCSFileSystem()

test_embedding_path = "bracs-dataset-bucket/patch-embeddings/test"

# fs.find walks everything below the prefix; keep only per-slide embedding files.
test_slide_paths = [
    obj_path
    for obj_path in fs.find(test_embedding_path)
    if obj_path.endswith("_embeddings.pt")
]

print(f"✅ Found {len(test_slide_paths)} test WSIs.")

✅ Found 63 test WSIs.


In [13]:
def load_embedding_from_gcs(path):
    """
    Loads one torch-serialised object (a slide's patch embeddings) from GCS.

    Args:
        path: bucket-relative object path, opened via the module-level ``fs``
            (the gcsfs filesystem created earlier in the notebook).

    Returns:
        The deserialised object, with tensors mapped onto the CPU.
    """
    with fs.open(path, 'rb') as f:
        # weights_only=True restricts unpickling to tensors and primitive
        # containers, so a tampered object in the bucket cannot run code on load.
        return torch.load(f, map_location='cpu', weights_only=True)

def get_global_label_vector(slide_id):
    """
    Returns the one-hot float vector (length 7) for ``slide_id``'s WSI label.

    The class index is looked up in the module-level ``global_labels`` dict.
    Unknown slides produce a printed warning and ``None``.
    """
    if slide_id not in global_labels:
        print(f"Warning: No global label found for slide {slide_id}.")
        return None
    one_hot = torch.zeros(7)
    one_hot[global_labels[slide_id]] = 1.0
    return one_hot

In [37]:
test_data = {}

for path in test_slide_paths:
    # ".../BRACS_1228_embeddings.pt" -> "BRACS_1228"
    slide_id = path.rsplit("/", 1)[-1].split("_embeddings.pt")[0]
    label_vec = get_global_label_vector(slide_id)

    # Slides without a usable global label are skipped entirely (not downloaded).
    if label_vec is not None:
        emb = load_embedding_from_gcs(path)
        mask = torch.ones(len(emb))  # every stored patch is real; nothing is padded yet
        test_data[slide_id] = {"emb": emb, "mask": mask, "label": label_vec}

print(f"✅ Valid test WSIs with labels: {len(test_data)}")

✅ Valid test WSIs with labels: 63


In [15]:
import torch.nn.functional as F
from collections import defaultdict

import torch
import torch.nn.functional as F

def prepare_patches(embeds, max_len=256):
    """
    Truncate or zero-pad patch embeddings to exactly ``max_len`` rows.

    Args:
        embeds: 2-D tensor of shape (L, D), one row per patch.
        max_len: number of rows in the result.

    Returns:
        Tensor of shape (max_len, D). If L > max_len the first ``max_len`` rows are
        kept; otherwise zero rows are appended. The result keeps ``embeds``'s
        dtype and device.
    """
    L, D = embeds.shape
    if L > max_len:
        return embeds[:max_len]
    # new_zeros inherits dtype *and* device. A plain torch.zeros(...) is float32,
    # which forces type promotion in torch.cat (e.g. float16 -> float32), or an
    # error on older PyTorch.
    padding = embeds.new_zeros((max_len - L, D))
    return torch.cat([embeds, padding], dim=0)

def prepare_mask(length, max_len=256):
    """
    Return a (max_len,) bool tensor: True for the first min(length, max_len)
    positions (real patches), False for the padded tail.
    """
    keep = min(length, max_len)
    valid = torch.zeros((max_len,), dtype=torch.bool)
    valid[0:keep] = True
    return valid

def get_similarity(q_emb, d_emb, scorer):
    """
    Score how similar two slides are according to ``scorer``.

    Both patch-embedding matrices are L2-normalised row-wise, then truncated or
    zero-padded to the scorer's input size. They are combined into the
    similarity matrix ``S = q @ d.T``, which is passed to ``scorer`` together
    with validity masks.

    Args:
        q_emb: (Lq, D) query-slide patch embeddings.
        d_emb: (Ld, D) candidate-slide patch embeddings.
        scorer: module called as ``scorer(S, q_mask, d_mask)`` that returns a
            single-element score. Its ``max_query_patches`` / ``max_doc_patches``
            attributes set the padded sizes (256 when the attribute is absent).

    Returns:
        float: ``sigmoid(score)``, in [0, 1].
    """
    device = next(scorer.parameters()).device
    # Pad to the sizes the scorer was built for instead of a hard-coded 256.
    q_len = getattr(scorer, "max_query_patches", 256)
    d_len = getattr(scorer, "max_doc_patches", 256)

    # Inference only: no autograd graph is needed for the score.
    with torch.no_grad():
        q = prepare_patches(F.normalize(q_emb, dim=-1).to(device), max_len=q_len)
        d = prepare_patches(F.normalize(d_emb, dim=-1).to(device), max_len=d_len)

        q_mask = prepare_mask(q_emb.shape[0], max_len=q_len).to(device)
        d_mask = prepare_mask(d_emb.shape[0], max_len=d_len).to(device)

        S = torch.matmul(q, d.T)  # [q_len, d_len]
        score = scorer(S, q_mask, d_mask)

    return torch.sigmoid(score).item()


def evaluate_top_k(test_data, model, K=5):
    """
    Leave-one-out retrieval accuracy over ``test_data``.

    Every slide is used once as the query. All *other* slides are ranked by
    ``get_similarity(query_emb, candidate_emb, model)``, highest first. A query
    is a Top-1 hit if the best-ranked slide has the same class (argmax of its
    one-hot ``"label"``), and a Top-K hit if any of the K best-ranked slides does.

    Args:
        test_data: mapping slide_id -> {"emb": tensor, "label": one-hot tensor, ...}.
        model: scorer passed straight through to ``get_similarity``.
        K: length of the retrieval list for the Top-K metric (must be >= 1).

    Returns:
        (top1_acc, topk_acc) as floats. Queries with no other slide to compare
        against are skipped; if no query can be evaluated, both values are 0.0.

    Raises:
        ValueError: if K < 1.
    """
    if K < 1:
        raise ValueError(f"K must be >= 1, got {K}")

    slide_ids = list(test_data)
    # Class index of every slide, computed once instead of once per (query, candidate) pair.
    class_of = {sid: torch.argmax(test_data[sid]["label"]).item() for sid in slide_ids}

    correct_at_1 = 0
    correct_at_k = 0
    total = 0

    for q_id in slide_ids:
        q_emb = test_data[q_id]["emb"]
        sims = [
            (d_id, get_similarity(q_emb, test_data[d_id]["emb"], model))
            for d_id in slide_ids
            if d_id != q_id
        ]
        if not sims:
            # Single-slide dataset: there is nothing to retrieve for this query.
            continue

        # list.sort is stable (also with reverse=True), so ties keep insertion order.
        sims.sort(key=lambda item: item[1], reverse=True)
        top_labels = [class_of[sid] for sid, _ in sims[:K]]

        true_label = class_of[q_id]
        correct_at_1 += top_labels[0] == true_label
        correct_at_k += true_label in top_labels
        total += 1

    # Guard against ZeroDivisionError for empty / single-slide datasets.
    top1_acc = correct_at_1 / total if total else 0.0
    topk_acc = correct_at_k / total if total else 0.0

    print(f"✅ Top-1 Accuracy: {top1_acc:.4f}")
    print(f"✅ Top-{K} Accuracy: {topk_acc:.4f}")
    return top1_acc, topk_acc


In [58]:


# Keep only slides with at least MIN_PATCHES patch embeddings.
# NOTE(review): the rationale for 230 is not recorded — presumably to keep slides
# that nearly fill the scorer's 256-patch input; confirm and document.
# This rebinds `test_data`, so later cells see the filtered set only.
MIN_PATCHES = 230
test_data = {sid: info for sid, info in test_data.items() if info["emb"].shape[0] >= MIN_PATCHES}
len(test_data)


41

In [59]:
# Runs on whatever `test_data` currently holds. After the In[58] filter cell
# above, that is the 41 slides with >= 230 patch embeddings.
evaluate_top_k(test_data, model, K=5)

✅ Top-1 Accuracy: 0.1220
✅ Top-5 Accuracy: 0.4878


In [54]:
# NOTE(review): the output below is stale. It came from execution In[54], i.e.
# before the >= 230-patch filter ran in In[58]. 0.2222 is not k/41 for any k
# (9/41 = 0.2195, 10/41 = 0.2439), so it was not computed on the filtered
# 41-slide set. It is consistent with all 63 test slides (14/63, 35/63).
# Re-run after the filter cell for numbers comparable to the K=5 result.
evaluate_top_k(test_data, model, K=3)

✅ Top-1 Accuracy: 0.2222
✅ Top-3 Accuracy: 0.5556
