In [3]:
#%%
import os
import numpy as np 
import pandas as pd 
import scanpy as sc 
import matplotlib.pyplot as plt
from tqdm import tqdm  # Import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
from torchmetrics import Accuracy

from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.metrics import ConfusionMatrixDisplay

  from .autonotebook import tqdm as notebook_tqdm
  _torch_pytree._register_pytree_node(


In [4]:
#%%
# ---- Data parameters ----
platform = "xenium"        # "xenium" or "visium"
ground_truth = "refined"   # "refined" or "cellvit"
level = 0                  # selects the level_{level} morphology-embedding folder
filtered_genes = False     # xenium only: use the filtered-gene embedding CSV
granular_labels = True     # True: SingleR-derived classes, False: AISTIL "class" column
use_singleR_qc = True      # granular labels only: keep cells with qc_singleR == 1
limit_classes = True  # Set to False to use all classes

# ---- Training parameters ----
use_focal_loss = False


# Resolve the three input locations for the selected platform:
#   data_path           - annotated AnnData (.h5ad) with per-cell metadata
#   gene_embedding_file - CSV of gene-expression embeddings
#   morph_embedding_dir - root folder of the morphology embeddings
if platform == "xenium":
    cancer = "lung"
    xenium_folder_dict = {"lung": "Xenium_Prime_Human_Lung_Cancer_FFPE_outs",
                          "breast":"Xenium_Prime_Breast_Cancer_FFPE_outs",
                          "lymph_node": "Xenium_Prime_Human_Lymph_Node_Reactive_FFPE_outs",
                          "prostate": "Xenium_Prime_Human_Prostate_FFPE_outs",
                          "skin": "Xenium_Prime_Human_Skin_FFPE_outs",
                          "ovarian": "Xenium_Prime_Ovarian_Cancer_FFPE_outs",
                          "cervical": "Xenium_Prime_Cervical_Cancer_FFPE_outs"
                          }

    xenium_folder = xenium_folder_dict[cancer]

    data_path = f"/rsrch9/home/plm/idso_fa1_pathology/TIER1/paul-xenium/public_data/10x_genomics/{xenium_folder}/preprocessed/fine_tune_{ground_truth}_v2/processed_xenium_data_fine_tune_{ground_truth}_v2_annotated.h5ad"

    if filtered_genes:
        gene_embedding_file = f"/rsrch9/home/plm/idso_fa1_pathology/TIER2/paul-xenium/embeddings/public_data/{xenium_folder}/processed_xenium_refined_clustering_filtered_v2.csv"
    else:
        gene_embedding_file = f"/rsrch9/home/plm/idso_fa1_pathology/TIER2/paul-xenium/embeddings/public_data/{xenium_folder}/processed_xenium_{ground_truth}_v2.csv"

    # Morphological embeddings live next to the gene embeddings of the same sample.
    morph_embedding_dir = f"/rsrch9/home/plm/idso_fa1_pathology/TIER2/paul-xenium/embeddings/public_data/{xenium_folder}"

elif platform == "visium":
    data_path = "/rsrch9/home/plm/idso_fa1_pathology/TIER1/paul-xenium/public_data/10x_genomics/Visium_HD_Human_Lung_Cancer_post_Xenium_Prime_5k_Experiment2/binned_outputs/square_002um/preprocessed/bin2cell/to_tokenize/corrected_cells_matched_preprocessed_refined_v2_annotated.h5ad"

    gene_embedding_file = "/rsrch9/home/plm/idso_fa1_pathology/TIER2/paul-xenium/embeddings/public_data/Visium_HD_Human_Lung_Cancer_post_Xenium_Prime_5k_Experiment2/bin2cell/embeddings_output/processed_visium_hd_bin2cell.csv"
    morph_embedding_dir = f"/rsrch9/home/plm/idso_fa1_pathology/TIER2/paul-xenium/embeddings/public_data/Visium_HD_Human_Lung_Cancer_post_Xenium_Prime_5k_Experiment2"

else:
    # Fail loudly: otherwise a typo would silently reuse paths left in the kernel.
    raise ValueError(f"Unknown platform {platform!r}; expected 'xenium' or 'visium'")

# Load AnnData
adata = sc.read_h5ad(data_path)
# Work on a copy: the label cell adds columns to cell_data, and aliasing
# adata.obs would silently modify the AnnData object as well.
cell_data = adata.obs.copy()

# Spatial information (cell centroids, renamed to x / y)
spatial_coords = cell_data[['x_centroid', 'y_centroid']].rename(columns={'x_centroid': 'x', 'y_centroid': 'y'})

# Load gene embeddings from obsm["X_scGPT"] of a separate AnnData file.
# The CSV at gene_embedding_file is currently not used.
# The path is derived from the selected sample folder instead of being pinned
# to the lung dataset; for the lung Xenium config it is the same file as before.
# NOTE(review): assumes every sample folder contains its own scGPT_CP.h5ad - confirm.
scGPT_path = os.path.join(morph_embedding_dir, "scGPT_CP.h5ad")
scGPT_adata = sc.read_h5ad(scGPT_path)
gene_embeddings = pd.DataFrame(scGPT_adata.obsm["X_scGPT"])
if not scGPT_adata.obs_names.equals(adata.obs_names):
    # The assignment below pairs rows purely by position.
    print("WARNING: cell IDs in scGPT_CP.h5ad differ from the AnnData cell IDs; "
          "gene embeddings are matched to cells by row position.")
# Positional relabel; pandas raises ValueError if the row counts differ.
gene_embeddings.index = cell_data.index

# Load morphology embeddings
morph_embedding_csv = os.path.join(morph_embedding_dir, "UNI2_cell_representation",f"level_{level}","morphology_embeddings_v2.csv")
morph_embeddings = pd.read_csv(morph_embedding_csv, index_col="Unnamed: 0")


# Cast the Visium morphology index to str so it can be aligned with the AnnData
# cell IDs. This used to happen only in the AISTIL/limit_classes branch.
if platform == "visium":
    morph_embeddings.index = morph_embeddings.index.astype(str)

if granular_labels:
    print("Using granular labels from SingleR.")
    label_key = "granular_class"
    # Collapse SingleR cell types into the coarse training classes.
    # Endothelial cells are folded into "fibroblast" (no separate class).
    singleR_to_class_map = {
        "Smooth muscle": "fibroblast",
        "Fibroblasts": "fibroblast",
        "Endothelial cells": "fibroblast",
        "CD4+ T-cells": "t_cell",
        "CD8+ T-cells": "t_cell",
        "B-cells": "b_cell",
        "Macrophages": "macrophage",
        "Epithelial cells": "tumor",
    }

    target_classes = ["fibroblast", #"endothelial",
                      "t_cell", "b_cell", "macrophage",
                      "tumor"]

    # Map SingleR labels to the coarse classes without mutating the source frame.
    cell_data = cell_data.assign(
        **{label_key: cell_data["singleR_class"].map(singleR_to_class_map)}
    )

    # Drop cells whose SingleR type has no coarse class, then keep target classes only.
    cell_data = cell_data.dropna(subset=[label_key])
    cell_data = cell_data[cell_data[label_key].isin(target_classes)]

    if use_singleR_qc:
        cell_data = cell_data[cell_data["qc_singleR"]==1]

else:
    print("Using AISTIL labels")
    label_key = "class"
    if limit_classes:
        target_classes = ["f", "l", "t"]  # Modify this list to restrict classification to specific classes
        cell_data = cell_data[cell_data[label_key].isin(target_classes)]
    else:
        target_classes = ["f","l","o","t"]

# Align every per-cell table with the retained cells. Done for all branches:
# the dataset reads rows positionally, so each table must share cell_data's order.
gene_embeddings = gene_embeddings.reindex(cell_data.index)
morph_embeddings = morph_embeddings.reindex(cell_data.index)
spatial_coords = spatial_coords.reindex(cell_data.index)

# Convert class names to integer indices (position in target_classes).
num_classes = len(target_classes)
label_mapping = {cls_name: i for i, cls_name in enumerate(target_classes)}
labels = cell_data[label_key].map(label_mapping)
if labels.isna().any():
    # Unmapped names would otherwise become NaN and be cast to garbage integers.
    unknown = sorted(set(cell_data.loc[labels.isna(), label_key].astype(str)))
    raise ValueError(f"Labels outside target_classes {target_classes}: {unknown}")
labels = pd.Series(labels.astype("int64"))


# Marker genes whose expression values are fed to the model next to the gene embedding.
marker_genes = [
    "CD3E",   # T-cells
    "CD3G",    # T-cells
    "CD4",    # T-cells
    "CD8A",    # T-cells
    "CD8B",    # T-cells
    "CD19",   # B-cells
    "CD27",  # B-cells
    "CD68",   # Macrophages
    "PDGFRA", # stromal
    "EPCAM",  # Epithelial / tumor
    "EGFR",    # epithelial tumor
    ]

# Keep only the markers that exist in this dataset's gene list (adata.var_names),
# preserving the order given above.
valid_markers = list(filter(lambda gene: gene in adata.var_names, marker_genes))
print(f"Using {len(valid_markers)} valid markers: {valid_markers}")

# Slice the expression matrix down to the marker columns: (n_cells, num_markers).
marker_expr = adata[:, valid_markers].X

# adata.X may be sparse; densify it before building the DataFrame.
if hasattr(marker_expr, "toarray"):
    marker_dense = marker_expr.toarray()
else:
    marker_dense = marker_expr

# One row per cell in adata, then restricted/re-ordered to the retained cells.
marker_expr_df = pd.DataFrame(
    marker_dense,
    index=adata.obs.index,
    columns=valid_markers,
).reindex(cell_data.index)


Using granular labels from SingleR.
Using 11 valid markers: ['CD3E', 'CD3G', 'CD4', 'CD8A', 'CD8B', 'CD19', 'CD27', 'CD68', 'PDGFRA', 'EPCAM', 'EGFR']


In [5]:
# Sanity check: the first five cell IDs of every per-cell table should match.
for title, frame in (
    ("Cell data index:", cell_data),
    ("Morph embeddings index:", morph_embeddings),
    ("Gene embeddings index:", gene_embeddings),
    ("Spatial coords index:", spatial_coords),
):
    print(title, frame.index.tolist()[:5])

# Quick look at the value range of the expression matrix.
expr_matrix = adata.X
print("adata.X dtype:", expr_matrix.dtype)
print("Min:", expr_matrix.min(), "Max:", expr_matrix.max())
print("Mean:", expr_matrix.mean())
marker_expr_df

Cell data index: ['aaaaadnb-1', 'aaaabalp-1', 'aaaadjia-1', 'aaaafglb-1', 'aaaagbdd-1']
Morph embeddings index: ['aaaaadnb-1', 'aaaabalp-1', 'aaaadjia-1', 'aaaafglb-1', 'aaaagbdd-1']
Gene embeddings index: ['aaaaadnb-1', 'aaaabalp-1', 'aaaadjia-1', 'aaaafglb-1', 'aaaagbdd-1']
Spatial coords index: ['aaaaadnb-1', 'aaaabalp-1', 'aaaadjia-1', 'aaaafglb-1', 'aaaagbdd-1']
adata.X dtype: float32
Min: 0.0 Max: 8.229777
Mean: 0.18159501


Unnamed: 0,CD3E,CD3G,CD4,CD8A,CD8B,CD19,CD27,CD68,PDGFRA,EPCAM,EGFR
aaaaadnb-1,4.895923,0.0,3.812154,0.000000,0.0,0.0,0.0,3.812154,0.000000,0.0,0.0
aaaabalp-1,0.000000,0.0,0.000000,0.000000,0.0,0.0,0.0,0.000000,0.000000,0.0,0.0
aaaadjia-1,4.811788,0.0,3.449600,0.000000,0.0,0.0,0.0,0.000000,0.000000,0.0,0.0
aaaafglb-1,4.500041,0.0,0.000000,3.423402,0.0,0.0,0.0,0.000000,0.000000,0.0,0.0
aaaagbdd-1,0.000000,0.0,0.000000,0.000000,0.0,0.0,0.0,0.000000,0.000000,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...
oijmaenh-1,0.000000,0.0,0.000000,0.000000,0.0,0.0,0.0,0.000000,0.000000,0.0,0.0
oijpapcb-1,0.000000,0.0,0.000000,0.000000,0.0,0.0,0.0,0.000000,4.261694,0.0,0.0
oijpeago-1,0.000000,0.0,0.000000,0.000000,0.0,0.0,0.0,0.000000,0.000000,0.0,0.0
oikbajbf-1,0.000000,0.0,0.000000,0.000000,0.0,0.0,0.0,0.000000,0.000000,0.0,0.0


In [None]:

# focal Loss

class FocalLoss(nn.Module):
    """
    Multi-class focal loss: ``alpha_t * (1 - p_t) ** gamma * CE``.

    Args:
        gamma: focusing parameter; ``gamma=0`` reduces to (alpha-weighted)
            cross-entropy.
        alpha: ``None`` (no weighting), a scalar applied to every sample, or
            per-class weights of shape ``[num_classes]`` (list, array or tensor).
        reduction: ``'mean'`` or ``'sum'``; any other value returns the
            per-sample losses unreduced.

    ``alpha`` is stored as a non-persistent buffer, so it follows the module
    through ``.to(device)`` / ``.cuda()`` without adding a key to the state dict.
    """
    def __init__(self, gamma=2.0, alpha=None, reduction='mean'):
        super().__init__()
        self.gamma = gamma
        if alpha is None:
            alpha_tensor = None
        elif isinstance(alpha, (float, int)):
            alpha_tensor = torch.tensor([alpha], dtype=torch.float)
        else:
            # Flatten so that a 0-d tensor scalar behaves like a Python scalar.
            alpha_tensor = torch.as_tensor(alpha, dtype=torch.float).reshape(-1)
        self.register_buffer("alpha", alpha_tensor, persistent=False)
        self.reduction = reduction

    def forward(self, inputs, targets):
        """
        inputs: [N, C] logits (no softmax)
        targets: [N] class indices in [0, C-1]

        Returns a scalar for 'mean'/'sum' reduction, otherwise a [N] tensor.
        """
        # Per-sample cross-entropy; exp(-CE) is the probability of the true class.
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        modulating = (1 - pt) ** self.gamma

        if self.alpha is None:
            focal_loss = modulating * ce_loss
        else:
            # Guard against alpha living on a different device than the batch.
            alpha = self.alpha.to(inputs.device)
            if alpha.numel() == 1:
                focal_loss = alpha[0] * modulating * ce_loss
            else:
                # Per-class weights: pick the weight of each sample's target class.
                focal_loss = alpha[targets] * modulating * ce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss  # 'none' for no reduction




# Positional Encoding
class PositionalEncoding2D(nn.Module):
    """Sin/cos features of 2-D coordinates followed by a learned linear projection."""

    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        # Input features are (sin(x), cos(x), sin(y), cos(y)).
        self.proj = nn.Linear(4, d_model)

    def forward(self, coords):
        """Encode ``coords`` ([B, 2]: x in column 0, y in column 1) as [B, d_model]."""
        assert coords.ndim == 2, f"Expected coords shape [B, 2], but got {coords.shape}"

        # Scale each coordinate by 2*pi before taking sin/cos.
        x_angle, y_angle = (coords[:, axis] * 2 * torch.pi for axis in (0, 1))

        features = torch.stack(
            (x_angle.sin(), x_angle.cos(), y_angle.sin(), y_angle.cos()),
            dim=-1,
        )  # [B, 4]
        return self.proj(features)  # [B, d_model]

# Transformer Layer with Relative Position Attention
class RelativePositionTransformerLayer(nn.TransformerEncoderLayer):
    """Post-norm encoder layer with an extra cross-attention to a positional embedding.

    Order of sub-blocks: self-attention -> attention over ``pos_emb`` ->
    feed-forward, each followed by a residual connection and LayerNorm.

    Note: ``forward`` takes ``(src, pos_emb)`` and therefore does not follow the
    parent's ``forward`` signature; call the layer directly rather than through
    ``nn.TransformerEncoder``.
    """
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
        # Parent provides self_attn, linear1/2, norm1/2, dropout, dropout1/2.
        super().__init__(d_model, nhead, dim_feedforward, dropout, batch_first=True)
        self.pos_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        self.pos_norm = nn.LayerNorm(d_model)
        # Separate dropout for the feed-forward residual (dropout2 is used by pos_attn).
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, src, pos_emb):
        """
        src:     [B, S, d_model] token embeddings.
        pos_emb: [B, 1, d_model] (or [B, S, d_model]) positional embedding.

        Returns a tensor with the same shape as ``src``.
        """
        # Broadcast the positional embedding to the sequence length of src.
        # When pos_emb has length 1, every key/value row is identical after this.
        pos_emb = pos_emb.expand(-1, src.shape[1], -1)

        # Standard self-attention
        src2 = self.self_attn(src, src, src, need_weights=False)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)

        # Position-aware attention: tokens query the positional embedding.
        # The attention weights are never used, so skip computing them
        # (same as the self-attention call above).
        src2 = self.pos_attn(src, pos_emb, pos_emb, need_weights=False)[0]
        src = src + self.dropout2(src2)
        src = self.pos_norm(src)

        # Feedforward
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout3(src2)
        src = self.norm2(src)
        return src

class CellTransformer(nn.Module):
    """Classifier over two per-cell tokens: gene expression and morphology.

    The gene embedding (optionally concatenated with marker-gene expression)
    and the morphology embedding are each projected to ``d_model``, tagged with
    a learned type vector, passed through ``num_layers`` position-aware
    transformer layers, mean-pooled and classified.

    Args:
        d_model: token width inside the transformer.
        num_heads: attention heads per layer.
        num_classes: number of output logits.
        num_markers: number of marker-expression columns appended to the gene
            embedding (0 disables them).
        gene_dim: width of the incoming gene embedding (default 512).
        morph_dim: width of the incoming morphology embedding (default 1536).
        num_layers: number of transformer layers (default 6).
    """
    def __init__(self, d_model=512, num_heads=8, num_classes=3, num_markers=0,
                 gene_dim=512, morph_dim=1536, num_layers=6):
        super().__init__()

        self.num_markers = num_markers
        # Gene encoder input = gene embedding + marker-expression columns.
        self.gene_input_dim = gene_dim + num_markers

        # MLP for combined gene+markers
        self.gene_encoder = nn.Sequential(
            nn.Linear(self.gene_input_dim, 512),
            nn.ReLU(),
            nn.Linear(512, d_model)
        )

        # MLP for morphology
        self.morph_encoder = nn.Sequential(
            nn.Linear(morph_dim, 1024),
            nn.ReLU(),
            nn.Linear(1024, d_model)
        )

        self.spatial_pe = PositionalEncoding2D(d_model)
        # Learned type embeddings that distinguish the two tokens.
        self.gene_type = nn.Parameter(torch.randn(1, d_model))
        self.morph_type = nn.Parameter(torch.randn(1, d_model))

        self.layers = nn.ModuleList([
            RelativePositionTransformerLayer(d_model, nhead=num_heads)
            for _ in range(num_layers)
        ])

        # Classifier
        self.classifier = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, num_classes)
        )

    def forward(self, gene, morph, spatial, marker=None):
        """
        gene:    [B, gene_dim] gene embedding.
        morph:   [B, morph_dim] morphology embedding.
        spatial: [B, 2] cell coordinates.
        marker:  [B, num_markers] marker expression; may be omitted only when
                 the model was built with ``num_markers=0``.

        Returns [B, num_classes] logits.

        Raises:
            ValueError: if ``marker`` is omitted but the model expects markers.
        """
        if marker is None:
            if self.num_markers != 0:
                raise ValueError(
                    f"marker is required: model was built with num_markers={self.num_markers}"
                )
            combined_gene = gene
        else:
            combined_gene = torch.cat([gene, marker], dim=1)  # [B, gene_dim + num_markers]

        gene_emb = self.gene_encoder(combined_gene) + self.gene_type
        morph_emb = self.morph_encoder(morph) + self.morph_type

        spatial_emb = self.spatial_pe(spatial)  # [B, d_model]

        tokens = torch.stack([gene_emb, morph_emb], dim=1)  # [B, 2, d_model]
        spatial_emb = spatial_emb.unsqueeze(1)              # [B, 1, d_model]

        for layer in self.layers:
            tokens = layer(tokens, spatial_emb)

        # Average the two tokens into one vector per cell.
        pooled = tokens.mean(dim=1)
        return self.classifier(pooled)

    
# Data Handling

class CellDataset(Dataset):
    """In-memory dataset of per-cell features and labels.

    Args:
        gene_df, morph_df, spatial_df, marker_df: pandas objects with one row
            per cell. Rows are matched by position, not by index label, so all
            inputs must already be in the same cell order.
        labels: pandas Series of integer class indices, one per cell.

    Raises:
        ValueError: if the inputs do not all have the same number of rows.
    """
    def __init__(self, gene_df, morph_df, spatial_df, marker_df, labels):
        # Convert everything to torch.Tensor
        self.gene = torch.tensor(gene_df.values, dtype=torch.float32)
        self.morph = torch.tensor(morph_df.values, dtype=torch.float32)
        self.spatial = torch.tensor(spatial_df.values, dtype=torch.float32)
        self.marker = torch.tensor(marker_df.values, dtype=torch.float32)
        self.labels = torch.tensor(labels.values, dtype=torch.long)

        # Catch misaligned inputs here instead of failing (or silently
        # mispairing features and labels) somewhere inside the training loop.
        row_counts = {
            "gene": len(self.gene),
            "morph": len(self.morph),
            "spatial": len(self.spatial),
            "marker": len(self.marker),
            "labels": len(self.labels),
        }
        if len(set(row_counts.values())) != 1:
            raise ValueError(f"Inputs have different numbers of rows: {row_counts}")

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return (gene, morph, spatial, marker, label) for cell ``idx``."""
        return (
            self.gene[idx],
            self.morph[idx],
            self.spatial[idx],
            self.marker[idx],
            self.labels[idx],
        )

def compute_class_weights(labels, num_classes=None):
    """Inverse-frequency class weights.

    Args:
        labels: non-negative integer class indices (pandas Series, array or list).
        num_classes: optional total number of classes. When given, the result has
            at least this many entries even if the highest classes never occur.

    Returns:
        float32 tensor where entry ``c`` is ``1 / count(c)``; classes that do
        not occur get weight 0 instead of a huge ``1 / eps`` value.
    """
    counts = np.bincount(np.asarray(labels), minlength=num_classes or 0)
    weights = np.zeros(len(counts), dtype=np.float64)
    present = counts > 0
    weights[present] = 1.0 / counts[present]
    return torch.tensor(weights, dtype=torch.float32)



# Training and testing loop

num_epochs = 20
batch_size = 64
seed = 42  # used for weight init / shuffling and for the train/test split

# Fall back to CPU when no GPU is visible instead of failing at .to(device).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(seed)

num_markers = len(valid_markers)
model = CellTransformer(
    d_model=512,
    num_heads=8,
    num_classes=num_classes,
    num_markers=num_markers
).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5, weight_decay=0.01)

# All five inputs must share the same row order (they are read positionally).
dataset = CellDataset(
    gene_embeddings,
    morph_embeddings,
    spatial_coords,
    marker_expr_df,
    labels
)

# Split dataset into 80% train and 20% test (stratified by labels)
all_indices = np.arange(len(dataset))
train_idx, test_idx = train_test_split(
    all_indices, test_size=0.2, random_state=seed, stratify=dataset.labels.numpy()
)
train_dataset = Subset(dataset, train_idx)
test_dataset = Subset(dataset, test_idx)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# One scheduler step per optimizer step, hence steps_per_epoch=len(train_loader).
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=3e-5, steps_per_epoch=len(train_loader), epochs=num_epochs
)

# Inverse-frequency class weights, used by either loss.
class_weights = compute_class_weights(labels).to(device)
if use_focal_loss:
    criterion = FocalLoss(gamma=2.0, alpha=class_weights, reduction='mean').to(device)
else:
    criterion = nn.CrossEntropyLoss(weight=class_weights)


def _run_epoch(loader, train, desc=None):
    """Run one pass over ``loader``; return (mean per-sample loss, accuracy in %).

    With ``train=True`` the model is updated after every batch and a progress
    bar is shown; with ``train=False`` the pass runs in eval mode without grads.
    """
    model.train(train)
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    batches = tqdm(loader, desc=desc, leave=False) if train else loader
    with torch.set_grad_enabled(train):
        for gene, morph, spatial, marker, lbls in batches:
            gene, morph, spatial, marker, lbls = (
                t.to(device) for t in (gene, morph, spatial, marker, lbls)
            )

            if train:
                optimizer.zero_grad()
            outputs = model(gene, morph, spatial, marker)
            loss = criterion(outputs, lbls)
            if train:
                loss.backward()
                optimizer.step()
                scheduler.step()
                # Update progress bar with current loss
                batches.set_postfix(loss=loss.item())

            # Weight the batch-mean loss by batch size so the last (smaller)
            # batch does not skew the epoch average.
            loss_sum += loss.item() * lbls.size(0)
            n_correct += (outputs.argmax(dim=1) == lbls).sum().item()
            n_seen += lbls.size(0)
    return loss_sum / n_seen, n_correct / n_seen * 100


for epoch in range(num_epochs):
    avg_loss, accuracy = _run_epoch(train_loader, train=True, desc=f"Epoch {epoch+1}")
    print(f"Epoch {epoch+1}: Loss: {avg_loss:.4f} | Accuracy: {accuracy:.2f}%")

    # Evaluation on the test set
    avg_test_loss, test_accuracy = _run_epoch(test_loader, train=False)
    print(f"Test: Loss: {avg_test_loss:.4f} | Accuracy: {test_accuracy:.2f}%")



                                                                        

Epoch 1: Loss: 1.2414 | Accuracy: 51.25%
Test: Loss: 0.5802 | Accuracy: 83.21%


Epoch 2:  69%|██████▉   | 1056/1528 [00:32<00:13, 34.69it/s, loss=0.334]

In [None]:
# Inspect the gene-embedding table (one row per retained cell after the reindex above).
gene_embeddings

In [None]:
# Evaluate on the held-out test split of the training dataset (test_loader).
model.eval()
all_preds = []
all_labels = []

with torch.no_grad():
    for gene, morph, spatial, marker, lbls in test_loader:
        gene, morph, spatial, marker = (
            t.to(device) for t in (gene, morph, spatial, marker)
        )

        outputs = model(gene, morph, spatial, marker)
        preds = torch.argmax(outputs, dim=1)

        all_preds.extend(preds.cpu().tolist())
        all_labels.extend(lbls.tolist())

# Pass the full class-id list explicitly: if a class is missing from both the
# labels and the predictions, the matrix/report would otherwise shrink and no
# longer line up with target_classes.
class_ids = list(range(len(target_classes)))
cm = confusion_matrix(all_labels, all_preds, labels=class_ids)

disp = ConfusionMatrixDisplay(confusion_matrix=cm,
                              display_labels=target_classes)


disp.plot(cmap='viridis', xticks_rotation='vertical')
plt.title(f"Confusion Matrix: {platform}")
plt.show()
print("Classification Report:")
print(classification_report(all_labels, all_preds, labels=class_ids, target_names=target_classes))

In [None]:
# Raw confusion-matrix counts (rows: true class, columns: predicted class).
cm

In [None]:
#%%Eval trained model on different data 
# Select platform
platform = "xenium" # xenium or visium 
ground_truth = "refined"  # refined or cellvit
level = 0  # selects the level_{level} morphology-embedding folder

#%% Define Class Limiting Parameters
limit_classes = True  # Set to False to use all classes
target_classes = ["f", "l", "t"]  # Modify this list to restrict classification to specific classes

# Resolve the AnnData file, gene-embedding CSV and morphology-embedding folder
# of the evaluation sample.
if platform == "xenium":
    cancer = "breast"
    xenium_folder_dict = {"lung": "Xenium_Prime_Human_Lung_Cancer_FFPE_outs",
                          "breast":"Xenium_Prime_Breast_Cancer_FFPE_outs",
                          "lymph_node": "Xenium_Prime_Human_Lymph_Node_Reactive_FFPE_outs",
                          "prostate": "Xenium_Prime_Human_Prostate_FFPE_outs",
                          "skin": "Xenium_Prime_Human_Skin_FFPE_outs",
                          "ovarian": "Xenium_Prime_Ovarian_Cancer_FFPE_outs",
                          "cervical": "Xenium_Prime_Cervical_Cancer_FFPE_outs"
                          }

    xenium_folder = xenium_folder_dict[cancer]

    data_path = f"/rsrch9/home/plm/idso_fa1_pathology/TIER1/paul-xenium/public_data/10x_genomics/{xenium_folder}/preprocessed/fine_tune_{ground_truth}_v2/processed_xenium_data_fine_tune_{ground_truth}_v2.h5ad"
    gene_embedding_file = f"/rsrch9/home/plm/idso_fa1_pathology/TIER2/paul-xenium/embeddings/public_data/{xenium_folder}/processed_xenium_{ground_truth}_v2.csv"

    # Load Morphological Embeddings
    morph_embedding_dir = f"/rsrch9/home/plm/idso_fa1_pathology/TIER2/paul-xenium/embeddings/public_data/{xenium_folder}"

elif platform == "visium":
    data_path = "/rsrch9/home/plm/idso_fa1_pathology/TIER1/paul-xenium/public_data/10x_genomics/Visium_HD_Human_Lung_Cancer_post_Xenium_Prime_5k_Experiment2/binned_outputs/square_002um/preprocessed/bin2cell/to_tokenize/corrected_cells_matched_preprocessed_refined_v2.h5ad"

    gene_embedding_file = "/rsrch9/home/plm/idso_fa1_pathology/TIER2/paul-xenium/embeddings/public_data/Visium_HD_Human_Lung_Cancer_post_Xenium_Prime_5k_Experiment2/bin2cell/embeddings_output/processed_visium_hd_bin2cell.csv"
    morph_embedding_dir = f"/rsrch9/home/plm/idso_fa1_pathology/TIER2/paul-xenium/embeddings/public_data/Visium_HD_Human_Lung_Cancer_post_Xenium_Prime_5k_Experiment2"

else:
    # Without this, a typo would silently evaluate on the paths of the training cell.
    raise ValueError(f"Unknown platform {platform!r}; expected 'xenium' or 'visium'")

# Load AnnData
adata = sc.read_h5ad(data_path)
cell_data = adata.obs

# Spatial information (cell centroids, renamed to x / y)
spatial_coords = cell_data[['x_centroid', 'y_centroid']].rename(columns={'x_centroid': 'x', 'y_centroid': 'y'})

# Load gene embeddings; rows are relabelled positionally with the AnnData cell
# IDs (pandas raises ValueError if the row counts differ).
gene_embeddings = pd.read_csv(gene_embedding_file, index_col="Unnamed: 0")
gene_embeddings.index = cell_data.index

# Load morphology embeddings
morph_embedding_csv = os.path.join(morph_embedding_dir, "UNI2_cell_representation",f"level_{level}","morphology_embeddings_v2.csv")
morph_embeddings = pd.read_csv(morph_embedding_csv, index_col="Unnamed: 0")

# Change index type for Visium data to match the AnnData cell IDs.
# Done unconditionally so the alignment below works with or without class limiting.
if platform == "visium":
    morph_embeddings.index = morph_embeddings.index.astype(str)

if limit_classes:
    cell_data = cell_data[cell_data["class"].isin(target_classes)]
else:
    target_classes = ["f","l","o","t"]
num_classes = len(target_classes)

# Align every per-cell table with the retained cells. The dataset pairs rows by
# position, so this must also happen when no class filtering is applied.
gene_embeddings = gene_embeddings.reindex(cell_data.index)
morph_embeddings = morph_embeddings.reindex(cell_data.index)
spatial_coords = spatial_coords.reindex(cell_data.index)

# Convert labels to numerical indices (KeyError on a class outside target_classes)
label_mapping = {l:c for c,l in enumerate(target_classes)}
labels = pd.Series([label_mapping[lbl] for lbl in cell_data["class"]])

# The model consumes marker expression as a fourth input, so build it for the
# evaluation cells in exactly the column order used for training (valid_markers).
present_markers = [g for g in valid_markers if g in adata.var_names]
missing_markers = [g for g in valid_markers if g not in present_markers]
if missing_markers:
    print(f"WARNING: markers missing from this dataset are filled with 0: {missing_markers}")

marker_expr = adata[:, present_markers].X
marker_dense = marker_expr.toarray() if hasattr(marker_expr, "toarray") else marker_expr
marker_expr_df = pd.DataFrame(
    marker_dense, columns=present_markers, index=adata.obs.index
).reindex(index=cell_data.index, columns=valid_markers, fill_value=0.0)

# The classifier head must have one logit per evaluation class, otherwise
# predicted indices and label indices refer to different class lists.
# NOTE(review): matching counts do not prove matching class order/meaning
# (e.g. a model trained on granular SingleR classes vs. AISTIL labels) - confirm.
n_model_classes = model.classifier[-1].out_features
if n_model_classes != len(target_classes):
    raise ValueError(
        f"Model predicts {n_model_classes} classes but evaluation labels use "
        f"{len(target_classes)} classes {target_classes}"
    )

dataset = CellDataset(gene_embeddings, morph_embeddings, spatial_coords, marker_expr_df, labels)
test_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
model.eval()
all_preds = []
all_labels = []

with torch.no_grad():
    for gene, morph, spatial, marker, lbls in test_loader:
        gene, morph, spatial, marker = (
            t.to(device) for t in (gene, morph, spatial, marker)
        )

        outputs = model(gene, morph, spatial, marker)
        preds = torch.argmax(outputs, dim=1)

        all_preds.extend(preds.cpu().tolist())
        all_labels.extend(lbls.tolist())

# Explicit class ids keep the matrix and the report aligned with target_classes
# even when a class never occurs.
class_ids = list(range(len(target_classes)))
cm = confusion_matrix(all_labels, all_preds, labels=class_ids)

disp = ConfusionMatrixDisplay(confusion_matrix=cm,
                              display_labels=target_classes)


disp.plot(cmap='viridis', xticks_rotation='vertical')
plt.title(f"Confusion Matrix: {platform}")
plt.show()
print("Classification Report:")
print(classification_report(all_labels, all_preds, labels=class_ids, target_names=target_classes))