In [1]:
#%%
import os
import numpy as np 
import pandas as pd 
import scanpy as sc 
import matplotlib.pyplot as plt
from tqdm import tqdm  # Import tqdm
import pickle 

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
from torchmetrics import Accuracy

from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.metrics import ConfusionMatrixDisplay

import torch.backends.cuda
torch.backends.cuda.enable_flash_sdp(False)

  from .autonotebook import tqdm as notebook_tqdm
  _torch_pytree._register_pytree_node(


In [2]:
print(torch.cuda.is_available())
print(torch.version.cuda)
print(torch.cuda.device_count())


True
12.1
1


In [3]:
#%%
# Select platform
platform = "xenium" # xenium or visium 
ground_truth = "refined"  # refined or cellvit
level = 0
filtered_genes = False
label_source = "singleR"  # singleR, celltypist, aistil, combined
morph_version = "v2"
use_qc = False
limit_classes = True  # Set to False to use all classes


if platform == "xenium":
    cancer = "lung"
    xenium_folder_dict = {"lung": "Xenium_Prime_Human_Lung_Cancer_FFPE_outs",
                          "breast":"Xenium_Prime_Breast_Cancer_FFPE_outs",
                          "lymph_node": "Xenium_Prime_Human_Lymph_Node_Reactive_FFPE_outs",
                          "prostate": "Xenium_Prime_Human_Prostate_FFPE_outs",
                          "skin": "Xenium_Prime_Human_Skin_FFPE_outs",
                          "ovarian": "Xenium_Prime_Ovarian_Cancer_FFPE_outs",
                          "cervical": "Xenium_Prime_Cervical_Cancer_FFPE_outs"
                          }

    xenium_folder = xenium_folder_dict[cancer]
    celltypist_data_path = f"/rsrch9/home/plm/idso_fa1_pathology/TIER1/paul-xenium/public_data/10x_genomics/{xenium_folder}/preprocessed/fine_tune_{ground_truth}_v2/processed_xenium_data_fine_tune_{ground_truth}_ImmuneHigh_v2.h5ad"
    singleR_data_path = f"/rsrch9/home/plm/idso_fa1_pathology/TIER1/paul-xenium/public_data/10x_genomics/{xenium_folder}/preprocessed/fine_tune_{ground_truth}_v2/processed_xenium_data_fine_tune_{ground_truth}_v2_annotated.h5ad"
    
    embedding_dir = f"/rsrch9/home/plm/idso_fa1_pathology/TIER2/paul-xenium/embeddings/public_data/{xenium_folder}"

    gene_emb_path = f"{embedding_dir}/scGPT/scGPT_CP.h5ad"

    # if filtered_genes:
    #     gene_embedding_file = f"/rsrch9/home/plm/idso_fa1_pathology/TIER2/paul-xenium/embeddings/public_data/{xenium_folder}/processed_xenium_refined_clustering_filtered_v2.csv"
    # else:
    #     gene_embedding_file = f"/rsrch9/home/plm/idso_fa1_pathology/TIER2/paul-xenium/embeddings/public_data/{xenium_folder}/processed_xenium_{ground_truth}_v2.csv"
    
    # Load Morphological Embeddings

        
elif platform == "visium":
    data_path = "/rsrch9/home/plm/idso_fa1_pathology/TIER1/paul-xenium/public_data/10x_genomics/Visium_HD_Human_Lung_Cancer_post_Xenium_Prime_5k_Experiment2/binned_outputs/square_002um/preprocessed/bin2cell/to_tokenize/corrected_cells_matched_preprocessed_refined_v2.h5ad"
    singleR_data_path = "/rsrch9/home/plm/idso_fa1_pathology/TIER1/paul-xenium/public_data/10x_genomics/Visium_HD_Human_Lung_Cancer_post_Xenium_Prime_5k_Experiment2/binned_outputs/square_002um/preprocessed/bin2cell/corrected_cells_matched_preprocessed_refined_v2_annotated.h5ad"

    # gene_embedding_file = "/rsrch9/home/plm/idso_fa1_pathology/TIER2/paul-xenium/embeddings/public_data/Visium_HD_Human_Lung_Cancer_post_Xenium_Prime_5k_Experiment2/bin2cell/embeddings_output/processed_visium_hd_bin2cell.csv"
    embedding_dir = "/rsrch9/home/plm/idso_fa1_pathology/TIER2/paul-xenium/embeddings/public_data/Visium_HD_Human_Lung_Cancer_post_Xenium_Prime_5k_Experiment2/"
    gene_emb_path = "/rsrch9/home/plm/idso_fa1_pathology/TIER2/paul-xenium/embeddings/public_data/Visium_HD_Human_Lung_Cancer_post_Xenium_Prime_5k_Experiment2/b2c_scGPT_WH.h5ad"


# Load AnnData
if label_source == "singleR":
    data_path = singleR_data_path
    adata = sc.read_h5ad(data_path)
    
    
elif label_source == "celltypist":
    data_path = celltypist_data_path
    adata = sc.read_h5ad(data_path)
    
elif label_source == "combined":
    adata = sc.read_h5ad(singleR_data_path)
    bdata = sc.read_h5ad(celltypist_data_path)
    adata.obs["majority_voting"] = bdata.obs["majority_voting"]
    adata.obs["qc_celltypist"] = bdata.obs["qc_celltypist"]
    cell_data = adata.obs

cell_data = adata.obs
print("Cell data shape:", cell_data.shape)
# Load Morphology Embeddings 
if morph_version == "v1":
    morph_embedding_csv = os.path.join(embedding_dir, "UNI2_cell_representation",f"level_{level}","morphology_embeddings_v2.csv")  
else:
    morph_embedding_csv = os.path.join(embedding_dir, "UNI2_cell_representation",f"level_{level}","uni2_pretrained_embeddings.csv") 

morph_embeddings = pd.read_csv(morph_embedding_csv, index_col="Unnamed: 0")

# Load gene Embeddings 
# gene_embeddings = pd.read_csv(gene_embedding_file, index_col="Unnamed: 0")
gdata = sc.read_h5ad(gene_emb_path)


if platform == "visium":
    # Ensure index alignment
    cell_data.index = cell_data.index.astype(str)
    gdata.obs_names = gdata.obs_names.astype(str)
    morph_embeddings.index = morph_embeddings.index.astype(str)
    
    # Filter cell_data to match gene embeddings
    cell_data = cell_data.loc[gdata.obs_names]

    # Align morphology embeddings to gene_embeddings (which should already be filtered)
    morph_embeddings = morph_embeddings.loc[cell_data.index]


# Now create gene_embeddings with matching index
gene_embeddings = pd.DataFrame(gdata.obsm["X_scGPT"], index=gdata.obs_names)

assert (morph_embeddings.index == gene_embeddings.index).all(), "Indices are not aligned!"

# Spatial Information 
spatial_coords = cell_data[['x_centroid', 'y_centroid']].rename(columns={'x_centroid': 'x', 'y_centroid': 'y'})


if label_source=="singleR":
    print("Using labels from SingleR.")
    singleR_to_class_map = {
        "Smooth muscle": "fibroblast",
        "Fibroblasts": "fibroblast",
        "Endothelial cells": "endothelial",
        "CD4+ T-cells": "t_cell",
        "CD8+ T-cells": "t_cell",
        "B-cells": "b_cell",
        "Macrophages": "macrophage",
        "Epithelial cells": "epithelial",
    }
    
    target_classes = ["fibroblast", "endothelial",
                      "t_cell", "b_cell", "macrophage",
                      "epithelial"]
    
    # Map SingleR labels to 7-class system
    cell_data[label_source] = cell_data["singleR_class"].map(singleR_to_class_map)
    
    # Drop cells that are nan (if any)
    # cell_data = cell_data.dropna(subset=[label_source])
    
    # Keep only those 7 classes
    # cell_data = cell_data[cell_data[label_source].isin(target_classes)]
    print("Cell data shape:", cell_data.shape)

    
    if use_qc:
        cell_data = cell_data[cell_data["qc_singleR"]==1]

    
    # Reindex embeddings/coords
    gene_embeddings = gene_embeddings.reindex(cell_data.index)
    morph_embeddings = morph_embeddings.reindex(cell_data.index)
    spatial_coords = spatial_coords.reindex(cell_data.index)
    
    
elif label_source=="aistil":
    print("Using AISTIL labels")
    label_source = "class"
    target_classes = ["f", "l", "t"]  # Modify this list to restrict classification to specific classes
    if limit_classes:
        num_classes = len(target_classes)
        cell_data = cell_data[cell_data[label_source].isin(target_classes)]

        # Change index type for Visium data to match embeddings Idxs 
        if platform == "visium":
            morph_embeddings.index = morph_embeddings.index.astype(str)

        # Update corresponding embeddings and spatial coordinates
        gene_embeddings = gene_embeddings.reindex(cell_data.index)
        morph_embeddings = morph_embeddings.reindex(cell_data.index)
        spatial_coords = spatial_coords.reindex(cell_data.index)
    else:
        target_classes = ["f","l","o","t"]
        
elif label_source=="celltypist":
    print("Using CellTypist Labels")
    celltypist_to_class_map = {
        "Fibroblasts": "fibroblast",
        "Endothelial cells": "endothelial",
        "T cells": "t_cell",
        "B cells": "b_cell",
        "Macrophages": "macrophage",
        "Epithelial cells": "epithelial",
    }
    target_classes = ["fibroblast", "endothelial",
                      "t_cell", "b_cell", "macrophage",
                      "epithelial"]

    # Map SingleR labels to 7-class system
    cell_data[label_source] = cell_data["majority_voting"].map(celltypist_to_class_map)
    
    # Drop cells that are nan (if any)
    # cell_data = cell_data.dropna(subset=[])
    
    # Keep only those 7 classes
    # cell_data = cell_data[cell_data[label_source].isin(target_classes)]

    if use_qc:
        cell_data = cell_data[cell_data["qc_celltypist"]==1]

    # Reindex embeddings/coords
    gene_embeddings = gene_embeddings.reindex(cell_data.index)
    morph_embeddings = morph_embeddings.reindex(cell_data.index)
    spatial_coords = spatial_coords.reindex(cell_data.index)
    
elif label_source == "combined":
    print("Using combined SingleR and CellTypist labels (agreement only)")

    # Define the shared label map and target classes
    shared_class_map = {
        "Fibroblasts": "fibroblast",
        "Smooth muscle": "fibroblast",
        "Endothelial cells": "endothelial",
        "CD4+ T-cells": "t_cell",
        "CD8+ T-cells": "t_cell",
        "T cells": "t_cell",
        "B cells": "b_cell",
        "B-cells": "b_cell",
        "Macrophages": "macrophage",
        "Epithelial cells": "epithelial",
    }
    
    target_classes = ["fibroblast", "endothelial", "t_cell", "b_cell", "macrophage", "epithelial"]

    # First map the labels (these are safe)
    cell_data["singleR_mapped"] = cell_data["singleR_class"].map(shared_class_map)
    cell_data["celltypist_mapped"] = cell_data["majority_voting"].map(shared_class_map)
    
    # Then immediately filter with a properly aligned mask
    cell_data = cell_data[
        cell_data["singleR_mapped"].notnull() &
        cell_data["celltypist_mapped"].notnull() &
        (cell_data["singleR_mapped"] == cell_data["celltypist_mapped"])
    ].copy()
    

    # Rename the final label column
    cell_data["combined"] = cell_data["singleR_mapped"]
    
    if use_qc:
        qc_mask = cell_data["qc_singleR"] == 1
        if "qc_celltypist" in cell_data.columns:
            qc_mask &= cell_data["qc_celltypist"] == 1
        cell_data = cell_data[qc_mask]

    # Reindex everything to the filtered cells
    gene_embeddings = gene_embeddings.reindex(cell_data.index)
    morph_embeddings = morph_embeddings.reindex(cell_data.index)
    spatial_coords = spatial_coords.reindex(cell_data.index)



num_classes = len(target_classes)
label_mapping = {cls_name: i for i, cls_name in enumerate(target_classes)}
labels = pd.Series(cell_data[label_source].map(label_mapping))

Cell data shape: (244659, 40)
Using labels from SingleR.
Cell data shape: (244659, 41)


In [4]:
# GPU setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [5]:
# Convert dataframes to tensors
gene_tensor = torch.tensor(gene_embeddings.values, dtype=torch.float32).to(device)
morph_tensor = torch.tensor(morph_embeddings.values, dtype=torch.float32).to(device)
print(gene_tensor.shape)
# Confirm shapes match
assert gene_tensor.shape[0] == morph_tensor.shape[0], "Mismatch in cell counts."


torch.Size([244659, 512])


In [168]:
# ───────────────────────── configuration ─────────────────────────
proj_dim       = 256
lr_peak        = 1e-4
epochs         = 50
batch_size     = 2048
tau            = 0.15
theta          = 0.20          # keep a low, fixed mask threshold
warmup_epochs  = 5             # pure InfoNCE before blending
alpha_end      = 0.20          # final weight of InfoNCE in the blend
grad_clip      = 5.0
device         = "cuda" if torch.cuda.is_available() else "cpu"


In [169]:
class ProjectionHead(nn.Module):
    def __init__(self, in_dim, proj_dim=256, drop=0.3):
        super().__init__()
        self.proj  = nn.Linear(in_dim, proj_dim)
        self.act   = nn.GELU()
        self.fc    = nn.Linear(proj_dim, proj_dim)
        self.drop  = nn.Dropout(drop)
        self.norm  = nn.LayerNorm(proj_dim)

    def forward(self, x):
        y = self.proj(x)
        z = self.act(y)
        z = self.fc(z)
        z = self.drop(z)
        z = z + y            # residual
        return self.norm(z)


# Initialize projection networks

gene_proj_net = ProjectionHead(in_dim=gene_tensor.shape[1]).to(device)
morph_proj_net = ProjectionHead(in_dim=morph_tensor.shape[1]).to(device)

if weight_init:
    def init_proj(net):
        for m in net.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight, gain=weight_gain)  # bigger variance
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
    
    gene_proj_net.apply(init_proj)
    morph_proj_net.apply(init_proj)



In [170]:
# ─────────────────────── loss functions ─────────────────────────
def infoNCE(z_a, z_b, tau):
    z_a, z_b = F.normalize(z_a, dim=1), F.normalize(z_b, dim=1)
    logits   = z_a @ z_b.T / tau
    labels   = torch.arange(z_a.size(0), device=z_a.device)
    return 0.5 * (
        F.cross_entropy(logits, labels) + F.cross_entropy(logits.T, labels)
    )

def masked_bleep(z_a, z_b, tau, theta):
    z_a, z_b = F.normalize(z_a, dim=1), F.normalize(z_b, dim=1)
    logits   = (z_a @ z_b.T) / tau
    with torch.no_grad():
        mask = ((z_a @ z_a.T) > theta) & ((z_b @ z_b.T) > theta)
        mask.fill_diagonal_(True)
        targets = mask.float()
        targets /= targets.sum(1, keepdim=True)
    log_q = F.log_softmax(logits, dim=-1)
    loss  = -(targets * log_q ).sum(-1).mean()
    loss += -(targets * log_q.T).sum(-1).mean()
    return 0.5 * loss



In [171]:

# ───────────────────── diagnostics helper (C) ───────────────────
def quick_stats(z_g, z_m):
    with torch.no_grad():
        z_g, z_m = F.normalize(z_g, dim=1), F.normalize(z_m, dim=1)
        pos = (z_g * z_m).sum(dim=1).mean().item()
        neg = (z_g @ z_m.T).mean().item()
    return pos, neg
    
def row_alignment_score(raw_a, raw_b):
    """
    Quick test *before training*: 1.0 = perfectly aligned rows,
    0.0 = completely random. If this is ~0 your tensors are scrambled.
    """
    raw_a = F.normalize(raw_a, dim=1)
    raw_b = F.normalize(raw_b, dim=1)
    r1 = ((raw_a @ raw_b.T).argmax(dim=1) ==
          torch.arange(raw_a.size(0), device=raw_a.device)).float().mean()
    return r1.item()

# ──────────────────────────────────────────────────────────
#  Training loop with *inline* diagnostics
# ──────────────────────────────────────────────────────────
def train_bleep(gene, morph):
    g_proj, m_proj = (ProjectionHead(gene.shape[1], proj_dim).to(device),
                      ProjectionHead(morph.shape[1], proj_dim).to(device))
    params  = list(g_proj.parameters()) + list(m_proj.parameters())
    opt     = torch.optim.AdamW(params, lr=lr_peak, weight_decay=0.0)

    N, idx = gene.size(0), np.arange(gene.size(0))
    for ep in range(1, epochs + 1):
        np.random.shuffle(idx); running = 0.0
        # weight of InfoNCE in the blended loss (linear decay)
        if ep <= warmup_epochs:
            alpha = 1.0
        else:
            frac  = (ep - warmup_epochs) / (epochs - warmup_epochs)
            alpha = max(1 - (1 - alpha_end) * frac, 0.60)   # ← keep ≥0.60
        for s in range(0, N, batch_size):
            b   = idx[s : s + batch_size]
            z_g = g_proj(gene [b].to(device))
            z_m = m_proj(morph[b].to(device))

            l_infonce = infoNCE(z_g, z_m, tau)
            l_bleep   = masked_bleep(z_g, z_m, tau, theta)
            loss      = alpha * l_infonce + (1 - alpha) * l_bleep

            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(params, grad_clip)
            opt.step()
            running += loss.item()

        # diagnostics
        zg_n, zm_n = F.normalize(z_g, dim=1), F.normalize(z_m, dim=1)
        pos = (zg_n * zm_n).sum(1).mean().item()
        neg = (zg_n @ zm_n.T).mean().item()
        print(
            f"E{ep:03d} | loss {running/(N//batch_size):.4f} | "
            f"⟨cos⁺⟩={pos:.3f} ⟨cos⁻⟩={neg:.3f} | α={alpha:.2f}"
        )


In [172]:

train_bleep(gene_tensor, morph_tensor)

E001 | loss 6.5798 | ⟨cos⁺⟩=0.497 ⟨cos⁻⟩=0.032 | α=1.00
E002 | loss 5.9673 | ⟨cos⁺⟩=0.529 ⟨cos⁻⟩=0.022 | α=1.00
E003 | loss 5.8158 | ⟨cos⁺⟩=0.552 ⟨cos⁻⟩=0.021 | α=1.00
E004 | loss 5.7244 | ⟨cos⁺⟩=0.564 ⟨cos⁻⟩=0.012 | α=1.00
E005 | loss 5.6606 | ⟨cos⁺⟩=0.565 ⟨cos⁻⟩=0.015 | α=1.00
E006 | loss 5.6369 | ⟨cos⁺⟩=0.569 ⟨cos⁻⟩=0.017 | α=0.98
E007 | loss 5.6254 | ⟨cos⁺⟩=0.566 ⟨cos⁻⟩=0.011 | α=0.96
E008 | loss 5.6209 | ⟨cos⁺⟩=0.576 ⟨cos⁻⟩=0.009 | α=0.95
E009 | loss 5.6216 | ⟨cos⁺⟩=0.575 ⟨cos⁻⟩=0.014 | α=0.93
E010 | loss 5.6255 | ⟨cos⁺⟩=0.583 ⟨cos⁻⟩=0.011 | α=0.91
E011 | loss 5.6331 | ⟨cos⁺⟩=0.592 ⟨cos⁻⟩=0.015 | α=0.89
E012 | loss 5.6424 | ⟨cos⁺⟩=0.586 ⟨cos⁻⟩=0.011 | α=0.88
E013 | loss 5.6539 | ⟨cos⁺⟩=0.598 ⟨cos⁻⟩=0.009 | α=0.86
E014 | loss 5.6680 | ⟨cos⁺⟩=0.576 ⟨cos⁻⟩=0.015 | α=0.84
E015 | loss 5.6822 | ⟨cos⁺⟩=0.593 ⟨cos⁻⟩=0.012 | α=0.82
E016 | loss 5.6975 | ⟨cos⁺⟩=0.593 ⟨cos⁻⟩=0.009 | α=0.80
E017 | loss 5.7155 | ⟨cos⁺⟩=0.588 ⟨cos⁻⟩=0.012 | α=0.79
E018 | loss 5.7328 | ⟨cos⁺⟩=0.596 ⟨cos⁻⟩=0.007 |

In [None]:
import torch, torch.nn.functional as F

def neighbourhood_alignment(z_gene, z_morph, k=20):
    # z_* : (N,D) tensors, L2 normalised
    z_gene  = F.normalize(z_gene, dim=1)
    z_morph = F.normalize(z_morph, dim=1)

    nn_gene  = torch.topk(z_gene  @ z_gene .T, k+1, dim=-1).indices[:,1:]
    nn_morph = torch.topk(z_morph @ z_morph.T, k+1, dim=-1).indices[:,1:]

    overlap = (nn_gene == nn_morph).float().sum(dim=1).mean() / k
    return overlap.item()          # fraction of shared neighbours

print("k-NN overlap =", neighbourhood_alignment(gene_tensor, morph_tensor))


In [None]:
# After training, compute final embeddings

with torch.no_grad():
    gene_emb_final = gene_proj_net(gene_tensor).cpu().numpy()
    morph_emb_final = morph_proj_net(morph_tensor).cpu().numpy()

# Combine embeddings (e.g., average)
joint_embeddings = (gene_emb_final + morph_emb_final) / 2

# Save embeddings
gene_emb_df = pd.DataFrame(gene_emb_final, index=gene_embeddings.index)
morph_emb_df = pd.DataFrame(morph_emb_final, index=gene_embeddings.index)
joint_emb_df = pd.DataFrame(joint_embeddings, index=gene_embeddings.index)

# os.makedirs(os.path.join(embedding_dir,"contrastive_learning"), exist_ok=True)
# gene_emb_df.to_csv(os.path.join(embedding_dir,"contrastive_learning", f"gene_projection_embeddings_{morph_version}.csv"))
# morph_emb_df.to_csv(os.path.join(embedding_dir,"contrastive_learning", f"morph_projection_embeddings_{morph_version}.csv"))
# joint_emb_df.to_csv(os.path.join(embedding_dir,"contrastive_learning", f"joint_embeddings_{morph_version}.csv"))

In [None]:
edata = sc.AnnData(joint_emb_df.values, obs=pd.DataFrame(index=joint_emb_df.index))

# Compute neighborhood graph (for clustering)
sc.pp.neighbors(edata, n_neighbors=15, use_rep='X')

# UMAP dimensionality reduction
sc.tl.umap(edata)

# Leiden clustering
sc.tl.leiden(edata, resolution=0.5)

# Plot UMAP with clusters
sc.pl.umap(edata, color='leiden', size=20, legend_loc='right margin', frameon=False)

# Extract Scanpy's cluster colors from UMAP plot
leiden_palette = edata.uns["leiden_colors"]
leiden_ids = sorted(edata.obs["leiden"].astype(int).unique())
cluster_color_map = {str(i): c for i, c in zip(leiden_ids, leiden_palette)}

In [None]:
# Assuming spatial_coords dataframe matches joint embeddings index
edata.obs['x'] = spatial_coords.loc[edata.obs_names, 'x']
edata.obs['y'] = spatial_coords.loc[edata.obs_names, 'y']

cluster_key="leiden"
cluster_colors = edata.obs[cluster_key].astype(str).map(cluster_color_map)

# Spatial cluster visualization
plt.figure(figsize=(8, 6))
plt.scatter(edata.obs['x'], edata.obs['y'], c=cluster_colors, s=1)
plt.gca().invert_yaxis()  # Adjust as needed
plt.axis('equal')
plt.title('Spatial Visualization of Clusters')
plt.xlabel('X Coordinate')
plt.ylabel('Y Coordinate')
plt.show()


In [None]:
# Cell Classificatin network

class ContrastiveDataset(Dataset):
    def __init__(self, gene_128, morph_128, labels):
        self.gene   = torch.tensor(gene_128,  dtype=torch.float32)
        self.morph  = torch.tensor(morph_128, dtype=torch.float32)
        self.labels = torch.tensor(labels.values, dtype=torch.long)
    def __len__(self): 
        return len(self.labels)
    def __getitem__(self, idx):
        return self.gene[idx], self.morph[idx], self.labels[idx]



class DualFeatureTransformer(nn.Module):
    def __init__(self, d_in=128, d_model=128, heads=4, num_layers=4, n_cls=3):
        super().__init__()
        self.gene_proj  = nn.Linear(d_in, d_model)
        self.morph_proj = nn.Linear(d_in, d_model)
        self.gene_type  = nn.Parameter(torch.randn(1, d_model))
        self.morph_type = nn.Parameter(torch.randn(1, d_model))
        self.layers = nn.ModuleList(
            nn.TransformerEncoderLayer(
                d_model=d_model, nhead=heads, batch_first=True
            ) for _ in range(num_layers)
        )
        self.classifier = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, n_cls)
        )
    def forward(self, g, m):
        g = self.gene_proj(g)  + self.gene_type   # [B,d_model]
        m = self.morph_proj(m) + self.morph_type
        x = torch.stack([g, m], dim=1)            # [B,2,d_model]
        for layer in self.layers: 
            x = layer(x)
        return self.classifier(x.mean(1))


In [None]:
# ──────────────────────────────────────────────────────────
# 0  Choose features depending on fusion‑decoder flag
# ──────────────────────────────────────────────────────────
if fusion_decoder:
    # gene_list and morph_list were never built — let's make them explicit
    gene_list, morph_list = [], []

    fusion_dec.eval()
    with torch.no_grad():
        for s in range(0, N, eval_batch):
            e = min(s + eval_batch, N)
            g_proj = gene_proj_net(gene_tensor[s:e])       # (b,128)
            m_proj = morph_proj_net(morph_tensor[s:e])     # (b,128)

            # fused representations (one step for each direction)
            g_dec = fusion_dec(g_proj.unsqueeze(1), m_proj.unsqueeze(1)).squeeze(1)
            m_dec = fusion_dec(m_proj.unsqueeze(1), g_proj.unsqueeze(1)).squeeze(1)

            gene_list.append(g_dec.cpu())
            morph_list.append(m_dec.cpu())

    gene_feat  = torch.cat(gene_list).numpy()   # (N,128)
    morph_feat = torch.cat(morph_list).numpy()  # (N,128)

else:
    gene_feat  = gene_proj_net(gene_tensor).detach().cpu().numpy()
    morph_feat = morph_proj_net(morph_tensor).detach().cpu().numpy()


# ──────────────────────────────────────────────────────────
# 1  Create train / val split (80 / 20)
# ──────────────────────────────────────────────────────────
dataset_conch = ContrastiveDataset(gene_feat, morph_feat, labels)
split_save_dir = "/rsrch5/home/plm/phacosta/xenium_project/Code/data_files"

with open(f"{split_save_dir}/train_test_indices.pkl", "rb") as f:
    idx_dict = pickle.load(f)
train_idx = np.asarray(idx_dict["train_idx"])
test_idx  = np.asarray(idx_dict["test_idx"])


train_dataset = Subset(dataset_conch, train_idx)   # ← same split
test_dataset  = Subset(dataset_conch, test_idx)

train_loader  = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader   = DataLoader(test_dataset,  batch_size=64, shuffle=False)

# ──────────────────────────────────────────────────────────
# 2  Instantiate classifier
# ──────────────────────────────────────────────────────────
num_classes = labels.nunique()
clf = DualFeatureTransformer(
    d_in=128, d_model=128, heads=4, num_layers=4, n_cls=num_classes
).to(device)

opt = torch.optim.AdamW(clf.parameters(), lr=3e-4)
best_acc = 0.0

# ──────────────────────────────────────────────────────────
# 3  Training loop
# ──────────────────────────────────────────────────────────
for epoch in range(1, 21):
    clf.train()
    epoch_loss = 0
    for g,m,y in train_loader:
        g,m,y = g.to(device), m.to(device), y.to(device)
        pred  = clf(g,m)
        loss  = F.cross_entropy(pred, y)
        opt.zero_grad(); loss.backward(); opt.step()
        epoch_loss += loss.item()

    # ---- validation ----
    clf.eval()
    correct = total = 0
    with torch.no_grad():
        for g,m,y in test_loader:
            g,m,y = g.to(device), m.to(device), y.to(device)
            logits = clf(g,m)
            correct += (logits.argmax(1) == y).sum().item()
            total   += y.size(0)
    acc = correct / total
    print(f"Epoch {epoch:02d} | loss={epoch_loss/len(train_loader):.4f} | val acc={acc:.3f}")

    if acc > best_acc:
        best_acc = acc
        # torch.save(clf.state_dict(), "best_celltype_transformer.pth")
        # print("✓ new best model saved")

print("Training finished.  Best validation accuracy:", best_acc)

In [None]:
clf.eval()
all_preds, all_labels = [], []
with torch.no_grad():
    for g, m, y in test_loader:
        g, m = g.to(device), m.to(device)
        logits = clf(g, m)
        all_preds.append(logits.argmax(1).cpu())
        all_labels.append(y)                 # already on CPU
all_preds  = torch.cat(all_preds).numpy()
all_labels = torch.cat(all_labels).numpy()


# target_classes = (
#     labels.cat.categories.tolist()
#     if hasattr(labels, "cat")
#     else sorted(np.unique(labels))
# )
# target_classes = [str(c) for c in target_classes]   # ← make them strings

target_classes = ["fibroblast", "endothelial",
                      "t_cell", "b_cell", "macrophage",
                      "epithelial"]
cm  = confusion_matrix(all_labels, all_preds, normalize="true")
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=target_classes)
disp.plot(cmap="viridis", xticks_rotation="vertical")
plt.title("Confusion Matrix – Contrastive features")
plt.tight_layout()
plt.show()

print("Classification Report")
print(classification_report(all_labels, all_preds, target_names=target_classes))