In [1]:
"""
Feature extraction for CLIP-GO (UNI2 + CoCa)
===========================================

• Reuses the training-time wiring (CFG, UNI2Wrapper, CLIPGO).
• Loads the same AnnData + sentences CSV.
• Builds a non-shuffled DataLoader.
• Loads best/last checkpoint.
• Saves NPZ with:
    - cell_id
    - vision_raw  : UNI2 pooled (4 center tokens), dim=1536
    - vision_proj : projected + L2-normalized, dim=256
    - text_raw    : CoCa encode_text output, dim≈768/1024 (depends on model)
    - text_proj   : projected + L2-normalized, dim=256
"""

from __future__ import annotations

import os, glob, math
from typing import List
import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

import timm
import open_clip
import scanpy as sc
import openslide

  from .autonotebook import tqdm as notebook_tqdm
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


In [12]:
# -----------------------------------------------------------------------------
# CFG (match your training style)
# -----------------------------------------------------------------------------
class CFG:
    """Static configuration for feature extraction.

    All values are plain class attributes read by the rest of the notebook.
    ``out_npz`` is derived from ``ckpt_dir`` and ``cancer`` when the class body
    is executed, so changing ``cancer`` here no longer leaves the output
    pointing at the lung file. Mutating ``CFG.cancer`` *after* the class has
    been defined does not recompute ``out_npz``.
    """

    # data
    cancer = "lung"           # key into xenium_sample_dict
    ground_truth = "refined"  # used to build the preprocessed .h5ad path
    level = 0                 # UNI2 spatial token level (0 → 4-center tokens)
    batch_size = 256
    num_workers = 8

    # embeddings / model dims
    morph_emb_dims = 1536
    projection_dim = 256
    patch_size = 224

    # Text / CoCa
    coca_model = "coca_ViT-L-14"
    coca_pretrain = "laion2B-s13b-b90k"
    context_len = 76
    freeze_text = True  # only matters if you change the model; extraction is no-grad

    # paths (mirror your training)
    root = "/rsrch9/home/plm/idso_fa1_pathology/TIER1/paul-xenium/public_data/10x_genomics"
    xenium_sample_dict = {
        "lung":"Xenium_Prime_Human_Lung_Cancer_FFPE_outs",
        "breast": "Xenium_Prime_Breast_Cancer_FFPE_outs",
        "lymph_node": "Xenium_Prime_Human_Lymph_Node_Reactive_FFPE_outs",
        "prostate": "Xenium_Prime_Human_Prostate_FFPE_outs",
        "skin": "Xenium_Prime_Human_Skin_FFPE_outs",
        "ovarian": "Xenium_Prime_Ovarian_Cancer_FFPE_outs",
        "cervical": "Xenium_Prime_Cervical_Cancer_FFPE_outs",
    }
    go_dir = "/rsrch9/home/plm/idso_fa1_pathology/TIER2/paul-xenium/gene_ontology"
    model_dir = "/rsrch9/home/plm/idso_fa1_pathology/TIER2/paul-xenium/models/public/UNI2-h"
    ckpt_dir  = "/rsrch9/home/plm/idso_fa1_pathology/TIER2/paul-xenium/models/fine_tuned/GoCLIP"

    target_mpp = 0.5  # target µm/px (≈20×)

    # output: one NPZ per cancer type, written next to the checkpoints so a run
    # for another tissue cannot overwrite the lung features.
    out_npz = os.path.join(ckpt_dir, f"features_{cancer}.npz")

In [13]:
# -----------------------------------------------------------------------------
# Projection head (same as training)
# -----------------------------------------------------------------------------
class ProjectionHead(nn.Module):
    """Residual MLP head mapping encoder features into the shared embedding space.

    Computes ``LayerNorm(fc2(GELU(fc1(x))) + fc1(x))``. The output is *not*
    L2-normalised here; callers normalise it themselves.

    Args:
        in_dim: size of the last dimension of the incoming features.
        proj_dim: size of the projected embedding.
    """

    def __init__(self, in_dim: int, proj_dim: int = CFG.projection_dim):
        super().__init__()
        # Attribute names double as state_dict keys, so they must stay as-is
        # for checkpoints to load.
        self.fc1 = nn.Linear(in_dim, proj_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(proj_dim, proj_dim)
        self.ln  = nn.LayerNorm(proj_dim)

    def forward(self, x):
        # The first linear output is both the MLP input and the skip branch.
        skip = self.fc1(x)
        return self.ln(self.fc2(self.act(skip)) + skip)


# -----------------------------------------------------------------------------
# UNI2 wrapper (4-centre token pooling; same as training)
# -----------------------------------------------------------------------------
class UNI2Wrapper(nn.Module):
    """Pools a fixed set of spatial tokens from a UNI2 backbone into one vector.

    Args:
        uni2: backbone exposing ``forward_features(x)`` that returns a token
            tensor of shape (B, n_tokens, dim).
        centre_idx: indices (into the spatial tokens, i.e. after the prefix
            tokens have been dropped) of the tokens to average.
    """

    def __init__(self, uni2: nn.Module, centre_idx: List[int]):
        super().__init__()
        self.uni2 = uni2
        # Registered as a non-persistent buffer: it moves with .to(device) /
        # .cuda() instead of being copied host→device on every forward, and it
        # is left out of the state_dict so existing checkpoints load unchanged.
        self.register_buffer(
            "centre_idx",
            torch.tensor(centre_idx, dtype=torch.long),
            persistent=False,
        )
        # Number of leading non-spatial tokens skipped before indexing.
        self.prefix_tokens = 9

    def forward(self, x: torch.Tensor):
        """Return the mean of the selected spatial tokens, shape (B, dim)."""
        tok = self.uni2.forward_features(x)               # expected (B, 265, 1536)
        spatial = tok[:, self.prefix_tokens :, :]
        # .to() is a no-op once the module lives on the same device as the
        # tokens; kept so a wrapper left on CPU still works.
        idx = self.centre_idx.to(spatial.device)
        centre = spatial.index_select(1, idx).mean(1)
        return centre                                     # (B, dim)


# -----------------------------------------------------------------------------
# Dataset: cell-centred patch + GO sentence (returns cell_id too)
# -----------------------------------------------------------------------------
class CellPatchTextDataset(Dataset):
    """Yields one cell-centred image patch, its GO sentence and its cell id.

    Args:
        slide: object exposing ``read_region((x, y), level, (w, h))`` that
            returns an image with ``convert``/``resize`` (an OpenSlide handle
            in this notebook).
        cell_df: DataFrame indexed by cell id with ``x_centroid`` and
            ``y_centroid`` columns.
        sentences: Series indexed by cell id holding the text for each cell.
        transform: callable applied to the resized RGB patch.
        scale: number of level-0 pixels read per output pixel; a region of
            ``patch_size * scale`` pixels is read and resized to ``patch_size``.
        patch_size: side length (pixels) of the returned patch.

    Each item is ``{"image", "text", "cell_id"}``.
    """

    def __init__(self, slide, cell_df: pd.DataFrame, sentences: pd.Series,
                 transform, scale: float, patch_size: int = CFG.patch_size):
        self.slide = slide
        self.cells = cell_df.reset_index(drop=False)
        # Take ids straight from the index: reset_index() only names the new
        # column "index" when the index is unnamed, so looking it up by that
        # name breaks for a named index (e.g. "cell_id").
        self.cell_ids = cell_df.index
        # Centroids as one float array: avoids building a mixed-dtype row
        # Series for every item, and fails at construction if a column is missing.
        self._xy = cell_df[["x_centroid", "y_centroid"]].to_numpy(dtype=float)
        self.sentences = sentences
        self.tfm = transform
        self.scale = scale
        self.patch_size = patch_size

    def __len__(self):
        return len(self.cell_ids)

    def _read_patch(self, x, y):
        """Read a square level-0 region centred on (x, y), resized to patch_size."""
        # Clamp to at least one pixel so a very small scale cannot request an
        # empty region.
        big = max(1, int(self.patch_size * self.scale))
        # NOTE(review): (x, y) are passed as level-0 pixel coordinates — confirm
        # the centroids in cell_df are in pixels rather than microns.
        tlx, tly = int(x - big/2), int(y - big/2)
        patch = self.slide.read_region((tlx, tly), 0, (big, big)).convert("RGB")
        return patch.resize((self.patch_size, self.patch_size), Image.LANCZOS)

    def __getitem__(self, idx):
        x, y = self._xy[idx]
        patch = self._read_patch(x, y)
        img_t = self.tfm(patch)
        cell_id = self.cell_ids[idx]
        sent = self.sentences.loc[cell_id]
        # Missing text is encoded as an empty string.
        sent = "" if pd.isna(sent) else str(sent)
        return {"image": img_t, "text": sent, "cell_id": cell_id}


# -----------------------------------------------------------------------------
# CLIP-GO (UNI2 + CoCa) – same as training
# -----------------------------------------------------------------------------
class CLIPGO(nn.Module):
    """Container pairing a UNI2 vision branch with a CoCa text branch.

    Holds the two encoders, one ``ProjectionHead`` per modality and a learnable
    logit scale. No ``forward`` is defined here; callers use the sub-modules
    directly.

    Args:
        vision_backbone: module exposing ``.uni2.embed_dim``.
        coca_model: open_clip model name.
        coca_pretrain: open_clip pretrained tag.
        proj_dim: output size of both projection heads.
        context_len: token context length passed to the tokenizer.
        freeze_text: when True, disables gradients on every text-encoder parameter.
    """

    def __init__(self, vision_backbone: nn.Module,
                 coca_model: str = CFG.coca_model,
                 coca_pretrain: str = CFG.coca_pretrain,
                 proj_dim: int = CFG.projection_dim,
                 context_len: int = CFG.context_len,
                 freeze_text: bool = CFG.freeze_text):
        super().__init__()
        self.context_len = context_len
        self.freeze_text = freeze_text

        # --- vision branch ---------------------------------------------------
        self.vision_encoder = vision_backbone
        self.vision_proj = ProjectionHead(vision_backbone.uni2.embed_dim, proj_dim)

        # --- text branch -----------------------------------------------------
        # create_model_and_transforms returns a 3-tuple; only the model is kept.
        created = open_clip.create_model_and_transforms(
            coca_model,
            pretrained=coca_pretrain,
            cache_dir=os.path.expanduser("~/.cache/open_clip"),
        )
        self.text_encoder = created[0]
        self.tokenizer = open_clip.get_tokenizer(coca_model)
        if freeze_text:
            # Module-level switch: turns off grads for all text-encoder parameters.
            self.text_encoder.requires_grad_(False)

        # Probe the encoder once to learn the width of its text features.
        with torch.no_grad():
            probe_tokens = self.tokenizer(["dummy"], context_length=context_len)
            probe_feat = self.text_encoder.encode_text(probe_tokens)
        self.text_proj = ProjectionHead(probe_feat.shape[-1], proj_dim)

        # CLIP temperature; kept so the module's parameters match the
        # checkpoint, not used during extraction.
        self.logit_scale = nn.Parameter(torch.log(torch.tensor(1/0.07)))


# -----------------------------------------------------------------------------
# Helpers: slide/MPP + transforms + checkpoint
# -----------------------------------------------------------------------------
def _mpp_from_properties(props):
    """Return microns-per-pixel parsed from a slide property mapping, or None."""

    def _positive_float(key):
        # Missing, non-numeric, non-finite or non-positive values are unusable.
        if key not in props:
            return None
        try:
            val = float(props[key])
        except (TypeError, ValueError):
            return None
        return val if math.isfinite(val) and val > 0 else None

    # Properties that already hold microns per pixel.
    for key in ("openslide.mpp-x", "aperio.MPP"):
        mpp = _positive_float(key)
        if mpp is not None:
            return mpp

    # tiff.XResolution is pixels per resolution unit, NOT microns per pixel:
    # convert it using tiff.ResolutionUnit, and skip it if the unit is unknown.
    xres = _positive_float("tiff.XResolution")
    if xres is not None and "tiff.ResolutionUnit" in props:
        unit = str(props["tiff.ResolutionUnit"]).strip().lower()
        microns_per_unit = {"centimeter": 10000.0, "inch": 25400.0}.get(unit)
        if microns_per_unit is not None:
            return microns_per_unit / xres
    return None


def get_slide_and_mpp(slide_dir: str):
    """Open the first slide image found under ``slide_dir`` and resolve its MPP.

    Files matching ``*he_image_registered*.ome.tif`` are preferred; otherwise
    any ``*.tif`` under the directory (searched recursively) is used. The first
    path in sorted order wins.

    Args:
        slide_dir: directory to search.

    Returns:
        ``(slide, mpp, slide_path)`` where ``mpp`` is microns per pixel read
        from the slide properties, or ``CFG.target_mpp`` (with a warning) when
        the slide carries no usable resolution metadata.

    Raises:
        FileNotFoundError: if no ``.tif`` exists under ``slide_dir``.
    """
    tifs = sorted(glob.glob(os.path.join(slide_dir, "**", "*he_image_registered*.ome.tif"), recursive=True))
    if not tifs:
        tifs = sorted(glob.glob(os.path.join(slide_dir, "**", "*.tif"), recursive=True))
    if not tifs:
        # Explicit exception rather than `assert`, which is stripped under -O.
        raise FileNotFoundError(f"No slide tif found under {slide_dir}")
    slide_path = tifs[0]
    slide = openslide.open_slide(slide_path)

    mpp = _mpp_from_properties(slide.properties)
    if mpp is None:
        import warnings
        # Falling back means patches are read at native resolution (scale 1);
        # make that visible instead of doing it silently.
        warnings.warn(
            f"No usable MPP metadata in {slide_path}; assuming {CFG.target_mpp} µm/px."
        )
        mpp = CFG.target_mpp
    return slide, mpp, slide_path


def build_transforms():
    """Build the image preprocessing pipeline: tensor conversion, then normalisation.

    Returns:
        A ``T.Compose`` of ``T.ToTensor()`` followed by ``T.Normalize`` with the
        per-channel mean/std given below.
    """
    to_tensor = T.ToTensor()
    normalise = T.Normalize(mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225))
    return T.Compose([to_tensor, normalise])


def resolve_checkpoint():
    """Return the path of the first existing checkpoint file.

    Candidates are checked in this order: ``<ckpt_dir>/best.pth``,
    ``<ckpt_dir>/<cancer>/best.pth``, ``<ckpt_dir>/last.pth``.

    Raises:
        FileNotFoundError: if none of the candidates is a file.
    """
    base = CFG.ckpt_dir
    candidates = (
        os.path.join(base, "best.pth"),
        os.path.join(base, CFG.cancer, "best.pth"),
        os.path.join(base, "last.pth"),
    )
    found = next((path for path in candidates if os.path.isfile(path)), None)
    if found is None:
        raise FileNotFoundError(f"No checkpoint found in {CFG.ckpt_dir} (tried best.pth / last.pth).")
    return found

In [14]:
# -----------------------------------------------------------------------------
# Feature extraction
# -----------------------------------------------------------------------------
@torch.no_grad()
def extract_features(model: CLIPGO, loader: DataLoader, device: torch.device):
    model.eval().to(device)
    ids = []
    vis_raw, vis_proj = [], []
    txt_raw, txt_proj = [], []

    pbar = tqdm(loader, desc="Extracting features", dynamic_ncols=True)
    for batch in pbar:
        imgs  = batch["image"].to(device, non_blocking=True)
        texts = batch["text"]
        batch_ids = batch["cell_id"]
        ids.extend(batch_ids)

        # Vision
        v_raw  = model.vision_encoder(imgs)                    # (B, 1536)
        v_proj = F.normalize(model.vision_proj(v_raw), dim=-1) # (B, 256)

        # Text
        tokens = model.tokenizer(texts, context_length=model.context_len).to(device)
        t_raw  = model.text_encoder.encode_text(tokens)        # (B, text_dim)
        t_proj = F.normalize(model.text_proj(t_raw), dim=-1)   # (B, 256)

        vis_raw.append(v_raw.cpu());   vis_proj.append(v_proj.cpu())
        txt_raw.append(t_raw.cpu());   txt_proj.append(t_proj.cpu())

    vis_raw  = torch.cat(vis_raw, 0).numpy()
    vis_proj = torch.cat(vis_proj, 0).numpy()
    txt_raw  = torch.cat(txt_raw, 0).numpy()
    txt_proj = torch.cat(txt_proj, 0).numpy()
    return ids, vis_raw, vis_proj, txt_raw, txt_proj


def save_npz(path: str, cell_ids, vis_raw, vis_proj, txt_raw, txt_proj):
    """Write cell ids and the four embedding arrays to a compressed NPZ file.

    Stored keys: ``cell_id``, ``vision_raw``, ``vision_proj``, ``text_raw``,
    ``text_proj``. The parent directory is created when ``path`` has one.

    Args:
        path: destination file.
        cell_ids: sequence of ids, one per row of every array.
        vis_raw, vis_proj, txt_raw, txt_proj: arrays whose first dimension
            equals ``len(cell_ids)``.

    Raises:
        ValueError: if any array's row count differs from ``len(cell_ids)``.
    """
    ids = np.array(cell_ids)
    arrays = {
        "vision_raw": vis_raw,
        "vision_proj": vis_proj,
        "text_raw": txt_raw,
        "text_proj": txt_proj,
    }
    # Rows are matched to ids purely by position, so a length mismatch would
    # silently mislabel embeddings; refuse to write such a file.
    for key, arr in arrays.items():
        if len(arr) != len(ids):
            raise ValueError(
                f"{key} has {len(arr)} rows but there are {len(ids)} cell ids"
            )

    # dirname is "" for a bare filename, and os.makedirs("") raises.
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    np.savez_compressed(path, cell_id=ids, **arrays)
    print(f"✓ Saved → {path}")

In [15]:
# -----------------------------------------------------------------------------
# Main
# -----------------------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Resolve dataset paths
sample  = CFG.xenium_sample_dict[CFG.cancer]
sample_dir = os.path.join(CFG.root, sample)

# Load AnnData (expects x_centroid / y_centroid in .obs)
adata_path = os.path.join(
    sample_dir,
    "preprocessed",
    f"fine_tune_{CFG.ground_truth}_v2",
    f"processed_xenium_data_fine_tune_{CFG.ground_truth}_v2_annotated.h5ad",
)
adata = sc.read_h5ad(adata_path)
cell_df = adata.obs.copy()  # index = cell_id

# Sentences: one GO sentence per cell, aligned to the AnnData index.
sent_path = f"{CFG.go_dir}/{sample.replace('outs', 'GO.csv')}"
assert os.path.isfile(sent_path), f"Missing sentences file: {sent_path}"
# Fill gaps BEFORE casting to str: casting first turns NaN into the literal
# text "nan", which would then be tokenised as if it were a real sentence.
# NOTE(review): if training cast before filling, those cells saw "nan" there —
# confirm which behaviour the checkpoint expects.
sentences = (
    pd.read_csv(sent_path, index_col="cell_id")["go_sentences"]
    .reindex(cell_df.index)
    .fillna("")
    .astype(str)
)

# Slide + scale (level-0 pixels per output pixel)
slide, mpp, slide_path = get_slide_and_mpp(sample_dir)
scale_factor = max(CFG.target_mpp / float(mpp), 1e-6)

# Transforms and DataLoader (no shuffle, keep all)
transform = build_transforms()
dataset = CellPatchTextDataset(slide, cell_df, sentences, transform,
                               scale=scale_factor, patch_size=CFG.patch_size)
# persistent_workers is only valid with worker processes; DataLoader raises
# if it is True while num_workers == 0.
loader  = DataLoader(dataset, batch_size=CFG.batch_size, shuffle=False,
                     num_workers=CFG.num_workers, pin_memory=True,
                     persistent_workers=CFG.num_workers > 0)

# Build vision backbone (UNI2) and wrap
uni2_cfg = {
    'model_name':'vit_giant_patch14_224','img_size':CFG.patch_size,'patch_size':14,'depth':24,
    'num_heads':24,'init_values':1e-5,'embed_dim':CFG.morph_emb_dims,'mlp_ratio':2.66667*2,
    'num_classes':0,'no_embed_class':True,'mlp_layer':timm.layers.SwiGLUPacked,
    'act_layer':torch.nn.SiLU,'reg_tokens':8,'dynamic_img_size':True,
}

uni2 = timm.create_model(pretrained=False, **uni2_cfg)
# (Optional) load UNI2 weights (same as training init)
# SECURITY: torch.load unpickles the file; only load weights from trusted sources.
uni2_weights = os.path.join(CFG.model_dir, "pytorch_model.bin")
if os.path.isfile(uni2_weights):
    uni2.load_state_dict(torch.load(uni2_weights, map_location="cpu"), strict=False)
else:
    # Not fatal as long as the fine-tuned checkpoint below provides the vision
    # weights, but worth surfacing.
    print(f"[load] UNI2 weights not found at {uni2_weights}; backbone starts from random init")

# Token indices averaged by UNI2Wrapper: 4 tokens at level 0, otherwise 16.
centre_idx = [119, 120, 135, 136] if CFG.level == 0 else [
    102,103,104,105,118,119,120,121,134,135,136,137,150,151,152,153
]
vision_backbone = UNI2Wrapper(uni2, centre_idx)

# Build model and load trained checkpoint
model = CLIPGO(vision_backbone)
ckpt_path = resolve_checkpoint()
print(f"→ Loading checkpoint: {ckpt_path}")
# SECURITY: same pickle caveat as above.
state = torch.load(ckpt_path, map_location="cpu")
# Accept both {"model": state_dict, ...} and a bare state_dict.
state_dict = state["model"] if isinstance(state, dict) and "model" in state else state
missing, unexpected = model.load_state_dict(state_dict, strict=False)
if missing:   print("[load] missing keys:", missing)
if unexpected: print("[load] unexpected keys:", unexpected)

# Extract features
cell_ids, v_raw, v_proj, t_raw, t_proj = extract_features(model, loader, device)

# Quick sanity: alignment + cosine diag stats
assert len(cell_ids) == len(dataset), "Mismatch in number of extracted embeddings"
# Row-wise dot product of matched pairs. This equals the diagonal of
# v_proj @ t_proj.T without materialising the full N×N similarity matrix.
diag_cos = (v_proj * t_proj).sum(axis=1)
print(f"Diag cosine (proj): mean={diag_cos.mean():.3f}, std={diag_cos.std():.3f}")

# Save: done in the next cell via save_npz(CFG.out_npz, ...)


→ Loading checkpoint: /rsrch9/home/plm/idso_fa1_pathology/TIER2/paul-xenium/models/fine_tuned/GoCLIP/best.pth


Extracting features: 100%|██████████| 956/956 [1:28:51<00:00,  5.58s/it]


Diag cosine (proj): mean=0.486, std=0.054


In [17]:
# Persist the ids and the four embedding arrays returned by extract_features
# to the NPZ path configured in CFG.out_npz.
save_npz(CFG.out_npz, cell_ids, v_raw, v_proj, t_raw, t_proj)

✓ Saved → /rsrch9/home/plm/idso_fa1_pathology/TIER2/paul-xenium/models/fine_tuned/GoCLIP/features_lung.npz
