## 1. Setup and Configuration

In [None]:
import os
import tempfile
import scanpy as sc
import scvi
import seaborn as sns
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import rcParams
%matplotlib inline
import warnings
warnings.filterwarnings("ignore")

In [None]:
# Root directory of the analysis; the input .h5ad, the saved model folder and
# the output .h5ad are all resolved relative to it in later cells.
# NOTE(review): hard-coded absolute path - adjust when running on another system.
base_dir = "/blue/clive/smith6jt/KINTSUGI/notebooks"

In [None]:
# Pin the scvi-tools global seed and record which library version produced this run.
scvi.settings.seed = 0
print(f"Last run with scvi-tools version: {scvi.__version__}")

In [None]:
# Optional float32 matmul-precision tweak, currently disabled:
# torch.set_float32_matmul_precision("high")


# Folder under base_dir that the save/load cells use as the parent of the model directory.
save_dir = os.path.join(base_dir, "CODEX_scvi_model_v2")
# Explicit directory creation and the path echo are currently disabled:
# os.makedirs(save_dir, exist_ok=True)

# print(f"Save directory: {save_dir}")

## 2. Load and Prepare CODEX Data

Load the CODEX data and prepare it for scVI training. Key considerations:
- CODEX uses protein markers (not genes)
- Raw MFI values (not transcript counts)
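- Donor metadata is stored in `adata.obs` (`imageid`, `Donor Status` with ND/Aab+/T1D); note that the scVI setup in Section 3 passes no `batch_key` and instead registers `Age` and `Gender` as categorical covariates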

In [None]:
# Read the CODEX AnnData file located one level above base_dir
codex_path = os.path.join(base_dir, '../panc_CODEX.h5ad')
adata = sc.read_h5ad(codex_path)

# Echo the object summary followed by the cell and protein counts
print("Loaded CODEX data:")
print(adata)
print(f"\nCells: {adata.n_obs:,}", f"Proteins: {adata.n_vars}", sep="\n")

In [None]:
# Overview of the metadata: cell counts per image and per donor status,
# the available obs columns, and the protein panel.
overview = (
    ("Image ID (donor) distribution:", lambda: adata.obs['imageid'].value_counts()),
    ("\nDonor Status distribution:", lambda: adata.obs['Donor Status'].value_counts()),
    ("\nAvailable metadata columns:", lambda: adata.obs.columns.tolist()),
    ("\nProtein markers:", lambda: adata.var_names.tolist()),
)

# Each entry is evaluated only when it is printed, one heading at a time
for heading, compute in overview:
    print(heading)
    print(compute())

In [None]:
# QC metrics are computed on the raw MFI values, before any normalization;
# filtering and normalization in later cells rely on these metrics.
from scipy.sparse import issparse

print("Step 1: Store raw MFI data and calculate QC metrics")
raw_min, raw_max = adata.X.min(), adata.X.max()
print(f"  Raw data range: {raw_min:.2f} to {raw_max:.2f}")

# Work with a dense matrix and keep an untouched copy of the raw values
if issparse(adata.X):
    adata.X = adata.X.toarray()
adata.layers["raw_mfi"] = adata.X.copy()

# Adds n_genes_by_counts / total_counts (among others) to adata.obs
sc.pp.calculate_qc_metrics(adata, percent_top=None, log1p=False, inplace=True)

proteins_per_cell = adata.obs['n_genes_by_counts']
mfi_per_cell = adata.obs['total_counts']
print("\nQC metrics (from raw MFI):")
print(f"  Mean proteins detected per cell: {proteins_per_cell.mean():.1f}")
print(f"  Mean total MFI per cell: {mfi_per_cell.mean():.1f}")
print(f"  Median total MFI per cell: {mfi_per_cell.median():.1f}")

In [None]:
# Step 2: Quality control filtering (using QC metrics from raw data)
print("\nStep 2: Quality Control Filtering")

# Remember the pre-filter size so the removal summary below is relative to the
# data that was actually loaded rather than to a hard-coded cell count.
n_cells_before = adata.n_obs

print(f"Before filtering:")
print(f"  Cells: {n_cells_before:,}")
print(f"  Proteins: {adata.n_vars}")

# Check the QC distribution to set appropriate thresholds
print(f"\nQC Distribution:")
print(f"  Proteins detected per cell (5th percentile): {adata.obs['n_genes_by_counts'].quantile(0.05):.1f}")
print(f"  Proteins detected per cell (median): {adata.obs['n_genes_by_counts'].median():.1f}")
print(f"  Total MFI per cell (5th percentile): {adata.obs['total_counts'].quantile(0.05):.1f}")
print(f"  Total MFI per cell (median): {adata.obs['total_counts'].median():.1f}")

# Filter cells with very few detected proteins or low total MFI (likely debris)
# NOTE(review): an earlier comment described this threshold as "at least 15 of
# 31 proteins"; the value actually applied is 2 - confirm which one is intended.
min_proteins = 2  # A cell needs at least 2 detected proteins to be kept
min_total_mfi = adata.obs['total_counts'].quantile(0.01)  # Remove bottom 1% by total MFI

print(f"\nApplying filters:")
print(f"  Minimum proteins: {min_proteins}")
print(f"  Minimum total MFI: {min_total_mfi:.1f}")

sc.pp.filter_cells(adata, min_genes=min_proteins)
adata = adata[adata.obs['total_counts'] >= min_total_mfi].copy()

# Filter proteins detected in very few cells (optional - usually not needed for CODEX)
# Since CODEX panels are pre-selected, all proteins should be informative
min_cells = int(0.001 * adata.n_obs)  # Protein must be detected in at least 0.1% of cells
sc.pp.filter_genes(adata, min_cells=min_cells)

print(f"\nAfter filtering:")
print(f"  Cells: {adata.n_obs:,}")
print(f"  Proteins: {adata.n_vars}")
n_removed = n_cells_before - adata.n_obs
# Guard the percentage against an empty input object
pct_removed = n_removed / n_cells_before * 100 if n_cells_before else 0.0
print(f"  Removed {n_removed:,} low-quality cells ({pct_removed:.2f}%)")

In [None]:
# Step 3: Apply CLR normalization (AFTER filtering)
# CLR (Centered Log-Ratio): for every cell, log(x / geometric_mean(x)) taken
# across that cell's proteins (axis=1).
print("\nStep 3: Apply CLR normalization (protein-specific)")

# Small pseudocount to avoid log(0); applied to the filtered raw MFI layer
eps = 1e-5
shifted = adata.layers["raw_mfi"] + eps

# Warnings are silenced at the top of the notebook, so a non-positive value
# would otherwise turn into NaN/-inf without any visible message.
if shifted.min() <= 0:
    raise ValueError(
        "CLR normalization needs strictly positive values after adding the "
        f"pseudocount; found minimum {shifted.min()!r} in the 'raw_mfi' layer"
    )

# log(x / gm) == log(x) - mean(log(x)): take the log once and centre each row,
# which avoids the exp/log round trip and extra full-matrix copies.
log_mfi = np.log(shifted)
log_geom_mean = log_mfi.mean(axis=1, keepdims=True)
geom_mean = np.exp(log_geom_mean)  # per-cell geometric mean, kept for inspection
adata.X = log_mfi - log_geom_mean

print(f"  After CLR: range {adata.X.min():.2f} to {adata.X.max():.2f}")
print(f"  Mean: {adata.X.mean():.2f}, Std: {adata.X.std():.2f}")

# Store CLR-normalized data for later use
adata.layers["clr_normalized"] = adata.X.copy()

In [None]:
# Step 4: give scVI its own "counts" layer (raw MFI) and leave a scaled CLR
# matrix in X for PCA / visualization.
print("\nStep 4: Prepare counts layer for scVI")

# The counts layer is an independent copy of the raw MFI values
raw_layer = adata.layers["raw_mfi"]
adata.layers["counts"] = raw_layer.copy()
print("Created 'counts' layer from raw MFI values")

# X holds the CLR values, then is scaled in place (values clipped at 10)
adata.X = adata.layers["clr_normalized"].copy()
sc.pp.scale(adata, max_value=10)

counts_shape = adata.layers['counts'].shape
clr_shape = adata.layers['clr_normalized'].shape
print("\nData prepared:")
print(f"  'counts' layer (for scVI): raw MFI, shape {counts_shape}")
print(f"  'clr_normalized' layer: CLR-transformed, shape {clr_shape}")
print(f"  X (for visualization): scaled CLR, shape {adata.X.shape}")

### Important Note: Why Use Raw MFI for scVI?

**The Question:** Should we use raw MFI counts or CLR-normalized data for scVI?

**Answer: Use RAW MFI counts** ✓

**Reasoning:**
1. **scVI was designed for raw counts** (UMI-based RNA-seq data)
2. **scVI internally models the data distribution** using a negative binomial likelihood
3. **scVI performs its own normalization** as part of the generative model
4. **CLR normalization** transforms data to log-ratios, which violates scVI's count-based assumptions

**What we're doing:**
- `counts` layer: Raw MFI → given to scVI for modeling
- `clr_normalized` layer: CLR-transformed → used for visualization/DE testing (after scVI)
- `X`: Scaled CLR → used for PCA visualization if needed

**Alternative Approach (if skeptical of using raw MFI with scVI):**
- Skip scVI entirely
- Use: CLR → PCA → Harmony/BBKNN for batch correction → neighbors → UMAP → clustering
- This is more standard for protein data but loses scVI's advantages (better batch correction, latent space)

**Bottom Line:**
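Raw MFI as "counts" for scVI is the correct approach, even though CODEX MFI is not exactly like RNA-seq UMI counts. The model is robust enough to handle it, and this is the standard practice in the field (see CITE-seq analysis tutorials). Keep in mind that MFI values are continuous intensities rather than integer counts, so it is worth checking that the fitted model reconstructs the data reasonably (or comparing against the CLR → PCA alternative above) before relying on the latent space.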

In [None]:
# OPTIONAL: side-by-side histograms of the raw MFI layer (scVI input) and the
# CLR-normalized layer (used for visualization).
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# (layer, bar colour, x label, title, log-scaled y axis?)
hist_specs = [
    ('counts', 'blue', 'Raw MFI Value', 'Raw MFI Distribution (Input to scVI)', True),
    ('clr_normalized', 'green', 'CLR-Normalized Value', 'CLR-Normalized Distribution (For Visualization)', False),
]

for ax, (layer_name, bar_color, x_label, panel_title, log_y) in zip(axes, hist_specs):
    ax.hist(adata.layers[layer_name].flatten(), bins=100, alpha=0.7, color=bar_color, edgecolor='black')
    ax.set_xlabel(x_label)
    ax.set_ylabel('Frequency')
    ax.set_title(panel_title)
    if log_y:
        ax.set_yscale('log')

plt.tight_layout()
plt.show()

print("Left: Raw MFI (input to scVI) - count-like, right-skewed distribution")
print("Right: CLR-normalized (for visualization) - centered at 0, more normal distribution")

In [None]:
# No highly-variable-feature selection here: the whole protein panel is kept.
# Print the final dataset size and the per-image / per-status cell counts.
n_proteins, n_cells = adata.n_vars, adata.n_obs
print(f"Using all {n_proteins} protein markers for scVI")
print(f"Total cells: {n_cells:,}")

for heading, column in (
    ("\nImage IDs (donors) for batch correction:", 'imageid'),
    ("\nDonor Status breakdown:", 'Donor Status'),
):
    print(heading)
    print(adata.obs[column].value_counts())

## 3. Setup scVI Model

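Configure scVI for protein expression data. The `setup_anndata` call below registers `Age` and `Gender` as categorical covariates and does not pass a `batch_key`; if per-donor batch correction by `imageid` is intended, it has to be added to that call.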

In [None]:
# Register the AnnData object with scVI.
#
# IMPORTANT: We're using the "counts" layer which contains RAW MFI values
# scVI expects raw count-like data and performs its own normalization internally
#
# NOTE(review): no batch_key is passed, so imageid is NOT used as a batch
# variable here; Age and Gender are the only covariates given to the model.
# If per-image batch correction is intended, add batch_key="imageid".
categorical_covariates = ["Age", "Gender"]

scvi.model.SCVI.setup_anndata(
    adata,
    layer="counts",                                     # RAW MFI counts (not CLR-normalized!)
    categorical_covariate_keys=categorical_covariates,  # Covariates the model conditions on
    continuous_covariate_keys=None                      # No continuous covariates
)

# Report the configuration that was actually registered
print("scVI setup complete:")
print(f"  Using counts layer (RAW MFI): {adata.layers['counts'].shape}")
print("  Batch key: none (no batch_key passed to setup_anndata)")
print(f"  Categorical covariates: {categorical_covariates}")
print(f"  Images in data: {adata.obs['imageid'].nunique()} unique donors/images")
print(f"  Images: {adata.obs['imageid'].unique().tolist()}")
print(f"  Total cells: {adata.n_obs:,}")
print(f"  Total proteins: {adata.n_vars}")
print(f"\nNote: scVI will internally normalize and model the raw MFI distribution")

## 4. Configure Hardware for Training

In [None]:
# Report the available hardware and apply PyTorch performance settings
print("Hardware Configuration:")
print(f"  CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"  GPU: {torch.cuda.get_device_name(0)}")
    print(f"  GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    print(f"  CUDA Version: {torch.version.cuda}")

    # Enable cuDNN autotuning and TF32 kernels for matmul / convolutions
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

# CPU configuration
print(f"\nCPU Configuration:")
print(f"  CPU cores available: {os.cpu_count()}")
# Single source of truth so the printed value always matches what is applied
n_threads = 64
torch.set_num_threads(n_threads)
print(f"  Set PyTorch to use {n_threads} threads")

In [None]:
# Create the scVI model for the protein panel.
# The hyperparameters that are echoed below are held in variables so the
# printed summary cannot drift from the values passed to the model.
n_latent = 10    # Smaller latent space for protein data (vs 30 for RNA)
n_hidden = 128   # Hidden layer size

model = scvi.model.SCVI(
    adata,
    n_latent=n_latent,
    n_layers=2,                     # Neural network depth
    n_hidden=n_hidden,
    dropout_rate=0.1,               # Regularization
    # NOTE(review): setup_anndata registers no batch_key, so 'gene-batch'
    # presumably reduces to a single batch here - confirm this is intended.
    dispersion='gene-batch',
    gene_likelihood='nb',           # Negative binomial likelihood
)

print("scVI model created for CODEX protein data:")
print(f"  Latent dimensions: {n_latent} (optimized for ~30-60 proteins)")
print(f"  Hidden layers: {n_hidden} nodes")
print(f"  Dispersion: gene-batch")
print(f"  Covariates: Age, Gender (categorical, from setup_anndata); no batch_key registered")
print(model)

## 5. Train the Model

In [None]:
# Train the model on a single GPU with early stopping on the validation split
model.train(
    max_epochs=400,                  # Upper bound; early stopping may end training sooner
    accelerator="gpu",               # Train on GPU
    devices=1,                       # Single GPU
    batch_size=2048,                 # Mini-batch size
    train_size=0.9,                  # 90% training, 10% validation
    early_stopping=True,             # Stop if validation loss plateaus
    early_stopping_patience=45,      # Wait 45 epochs before stopping
    check_val_every_n_epoch=1,       # Validate every epoch
    plan_kwargs={
        "lr": 0.001,                 # Learning rate
    },
    datasplitter_kwargs={
        "num_workers": 64,           # 64 data-loading worker processes
        "pin_memory": True           # Pin memory for faster GPU transfer
    }
)

print("\n✓ Training complete on B200 GPU")
# No batch_key is registered in setup_anndata; the model conditions on the
# Age and Gender categorical covariates only.
print("Model trained with the covariates registered in setup_anndata (Age, Gender; no batch_key)")

In [None]:
# Write the trained model into a sub-folder of save_dir, replacing any earlier save
model_dir = os.path.join(save_dir, "CODEX_scvi_model_v2")
model.save(model_dir, overwrite=True)

print(
    f"✓ Model saved to: {model_dir}",
    f"Model can be reloaded with: scvi.model.SCVI.load('{model_dir}', adata)",
    sep="\n",
)

## 6. Extract Latent Representation and Compute Neighbors

In [None]:
# Rebuild the model path (so this cell can run without the save cell) and
# restore the trained model against the current AnnData object.
model_dir = os.path.join(save_dir, "CODEX_scvi_model_v2")
model = scvi.model.SCVI.load(model_dir, adata=adata)

In [None]:
# Pull the latent embedding out of the trained model and attach it to adata.obsm
SCVI_LATENT_KEY = "X_scVI"

latent = model.get_latent_representation()
adata.obsm[SCVI_LATENT_KEY] = latent

print(
    "Latent representation extracted:",
    f"  Shape: {latent.shape}",
    f"  Stored in adata.obsm['{SCVI_LATENT_KEY}']",
    sep="\n",
)

In [None]:
# Build the neighbor graph on the scVI embedding stored under SCVI_LATENT_KEY
neighbor_settings = dict(
    use_rep=SCVI_LATENT_KEY,
    n_neighbors=10,
    n_pcs=None,
    metric='cosine',
)
sc.pp.neighbors(adata, **neighbor_settings)

print("Computed neighbors on scVI latent space")
print(f"  Latent dimensions: {latent.shape[1]}")
print(f"  Number of neighbors: {neighbor_settings['n_neighbors']}")
print(f"  Metric: {neighbor_settings['metric']}")

In [None]:
# Optional: compare how images mix in a PCA of the expression matrix versus a
# PCA of the scVI latent space.
# (This is just for visualization - don't use for downstream analysis.)
from sklearn.decomposition import PCA

# One integer code per image, shared by both panels so colours are comparable
image_codes = adata.obs['imageid'].astype('category').cat.codes

fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Left panel: first two principal components of adata.X (no scVI involved)
pca_viz = PCA(n_components=2)
pca_coords = pca_viz.fit_transform(adata.X)

axes[0].scatter(pca_coords[:, 0], pca_coords[:, 1],
                c=image_codes, s=1, alpha=0.5, cmap='tab10')
axes[0].set_title('PCA Space (Before Correction)\nDonors may be separated')
axes[0].set_xlabel('PC1')
axes[0].set_ylabel('PC2')

# Right panel: first two principal components of the scVI latent embedding.
# The axes are PCs of the latent space, not individual latent dimensions.
pca_scvi = PCA(n_components=2)
scvi_coords = pca_scvi.fit_transform(latent)

axes[1].scatter(scvi_coords[:, 0], scvi_coords[:, 1],
                c=image_codes, s=1, alpha=0.5, cmap='tab10')
axes[1].set_title('scVI Latent Space (After Correction)\nBiological variation preserved')
axes[1].set_xlabel('scVI latent PC1')
axes[1].set_ylabel('scVI latent PC2')

# No legend is drawn: points are coloured by the integer image code only.

plt.tight_layout()
plt.show()

print("scVI integrates tissues while preserving biological differences")

In [None]:
# UMAP embedding on the neighbor graph computed above (visualization only)
umap_settings = {
    "min_dist": 0.05,
    "spread": 2.0,
    "init_pos": 'spectral',
}
sc.tl.umap(adata, **umap_settings)

print("UMAP computed")

In [None]:
# Leiden clustering at several resolutions; each result is stored in its own
# obs column named leiden_res_<resolution>.
leiden_resolutions = (0.50, 1.00, 1.50, 2.00)

for resolution in leiden_resolutions:
    column = f"leiden_res_{resolution:.2f}"
    sc.tl.leiden(adata, resolution=resolution, key_added=column, flavor="igraph")
    print(f"  Resolution {resolution:.2f}: {adata.obs[column].nunique()} clusters")
    # Keep the labels categorical for plotting and grouping
    adata.obs[column] = adata.obs[column].astype('category')

print("\n✓ Clustering complete")

In [None]:
# Combined label "<imageid>_<Donor Status>" for every cell
image_label = adata.obs['imageid'].astype(str)
status_label = adata.obs['Donor Status'].astype(str)
adata.obs['donor_id'] = image_label + '_' + status_label

# Show the distinct combined labels
adata.obs['donor_id'].unique()

In [None]:
# Candidate label -> marker proteins used to score clusters.
# Insertion order is kept as-is: it decides which label wins when two labels
# score equally in the annotation step.
cell_type_markers = dict([
    ('Beta cells', ['INS']),
    ('Alpha cells', ['GCG']),
    ('Delta cells', ['SST']),
    ('Acinar cells', ['BActin']),
    ('Ductal cells', ['CK19', 'panCK']),
    ('T cells CD8+', ['CD8a', 'CD3e']),
    ('CD4+', ['CD4']),
    ('B cells', ['CD20']),
    ('Macrophages', ['CD68', 'CD163', 'LGALS3']),
    ('Stroma', ['VIM']),
    ('APCs', ['HLADR']),
    ('Vasculature', ['CD31', 'CD34']),
    ('Lymphatic', ['PDPN']),
    ('Fibroblasts', ['ColIV']),
    ('Smooth Muscle', ['SMA']),
    ('Neural', ['PGP9.5', 'B3TUBB', 'GAP43']),
])

In [None]:
# Collect every Leiden result column and tabulate cluster sizes per resolution
cluster_keys = [col for col in adata.obs.columns if col.startswith('leiden_res_')]
if len(cluster_keys) == 0:
    raise ValueError("No Leiden clustering columns found in adata.obs")

cluster_summaries = []
for column in sorted(cluster_keys):
    # Cells per cluster, ordered by cluster label
    sizes = adata.obs[column].value_counts().sort_index()
    shares = sizes / sizes.sum() * 100
    frame = pd.DataFrame(
        {
            "resolution": column.replace("leiden_res_", ""),
            "cluster": sizes.index.astype(str),
            "n_cells": sizes.values,
            "pct_cells": shares.values,
        }
    )
    cluster_summaries.append(frame)

# Long table: one row per (resolution, cluster)
summary_df = pd.concat(cluster_summaries, ignore_index=True)

# Per-resolution overview: number of clusters and the spread of their sizes
summarized = summary_df.groupby("resolution").agg(
    clusters=pd.NamedAgg(column="cluster", aggfunc="nunique"),
    min_cells=pd.NamedAgg(column="n_cells", aggfunc="min"),
    median_cells=pd.NamedAgg(column="n_cells", aggfunc="median"),
    max_cells=pd.NamedAgg(column="n_cells", aggfunc="max"),
)

summarized

In [None]:
# Draft annotation: score every cluster of the chosen resolution against the
# marker sets in cell_type_markers and keep the best-scoring label.
annotation_resolution = "leiden_res_1.00"

if annotation_resolution not in adata.obs:
    raise KeyError(f"Resolution {annotation_resolution} not present. Available: {sorted(cluster_keys)}")

cluster_means = None
if "rank_genes_groups" in adata.uns:
    try:
        # Capture differential expression results if they exist for downstream review.
        # `group` is a required argument of rank_genes_groups_df; without it the
        # call raised TypeError, which was swallowed below, so this was always None.
        # group=None requests the results for all groups.
        cluster_means = sc.get.rank_genes_groups_df(adata, group=None, key="rank_genes_groups")
    except (KeyError, ValueError, TypeError):
        # Best effort only: stored results that cannot be read are ignored
        cluster_means = None

# Reuse a cached z-scored matrix if present, otherwise scale a copy of adata
# and cache the result as a layer.
scaled = adata.layers.get("scaled", None)
if scaled is None:
    scaled = sc.pp.scale(adata, zero_center=True, copy=True).X
    adata.layers["scaled"] = scaled

scaled_df = pd.DataFrame(
    scaled,
    index=adata.obs.index,
    columns=adata.var_names,
)

# Mean scaled expression of every protein within each cluster
cluster_profiles = (
    scaled_df.join(adata.obs[[annotation_resolution]])
    .groupby(annotation_resolution)
    .mean()
    .sort_index()
)

annotation_table = []
for cluster, profile in cluster_profiles.iterrows():
    marker_scores = {}
    for label, markers in cell_type_markers.items():
        # Only markers that are part of the panel contribute to a label's score
        present = [m for m in markers if m in profile.index]
        if not present:
            continue
        marker_scores[label] = profile[present].mean()
    # "top_marker" holds the best-scoring label; ties go to the first label in dict order
    top_marker = max(marker_scores, key=marker_scores.get) if marker_scores else "Unknown"
    annotation_table.append(
        {
            "cluster": cluster,
            "top_marker": top_marker,
            "score": marker_scores.get(top_marker, np.nan),
            "marker_scores": marker_scores,
        }
    )

annotation_df = pd.DataFrame(annotation_table).sort_values("cluster").reset_index(drop=True)
annotation_df

In [None]:
import plotly.express as px

# Box plot of cluster sizes (as % of cells), one box per Leiden resolution,
# with every cluster drawn as an individual point.
box_settings = dict(
    x="resolution",
    y="pct_cells",
    points="all",
    color="resolution",
    labels={"pct_cells": "Cluster size (%)"},
    title="Cluster size distribution per resolution",
)
fig = px.box(summary_df, **box_settings)
fig.show()

In [None]:
# Translate each cell's cluster id into the draft label chosen for that cluster
cluster_to_label = annotation_df.set_index("cluster")["top_marker"]
adata.obs["draft_cell_type"] = adata.obs[annotation_resolution].map(cluster_to_label)

# Ten most frequent draft labels
label_counts = adata.obs["draft_cell_type"].value_counts()
label_counts.sort_values(ascending=False).head(10)

In [None]:
# Persist the AnnData object (including the results added above) under base_dir.
adata.write_h5ad(os.path.join(base_dir, 'CODEX_panc_scvi_BioCov.h5ad'))

## 8. Visualize Results

In [None]:
# Global plotting defaults for the figures below: scanpy figure resolution and
# font size, plus a larger legend marker scale in matplotlib.
sc.set_figure_params(dpi=100, dpi_save=300, fontsize=26)
plt.rcParams['legend.markerscale'] = 4.0

In [None]:
# 2x2 UMAP overview: one panel per obs column, all drawn with identical styling
fig, axes = plt.subplots(2, 2, figsize=(25,18))

# Row-major order: top-left, top-right, bottom-left, bottom-right
overview_colors = ['leiden_res_1.00', 'donor_id', 'draft_cell_type', 'Donor Status']

for ax, obs_key in zip(axes.flat, overview_colors):
    sc.pl.umap(
        adata,
        color=obs_key,
        ax=ax,
        show=False,
        frameon=False,
        size=1,
        alpha=0.4,
        wspace=0.4,
        legend_loc='right margin',
    )

plt.tight_layout()
plt.show()

In [None]:
# Visualize donor-level covariates on the UMAP.
# NOTE: Age and Gender are passed to scVI as categorical covariates in
# setup_anndata, so the model conditions on them; Donor Status and the
# autoantibody columns are not given to the model.

fig, axes = plt.subplots(2, 3, figsize=(24, 16))
obs_columns = adata.obs.columns

# imageid must be categorical so it is drawn with discrete colours
if not isinstance(adata.obs['imageid'].dtype, pd.CategoricalDtype):
    adata.obs['imageid'] = adata.obs['imageid'].astype('category')

# Resolve the optional columns once; None means "draw a placeholder panel"
age_col = next((c for c in ('Age', 'age') if c in obs_columns), None)
gender_cols = [col for col in obs_columns if 'gender' in col.lower() or 'sex' in col.lower()]
gender_col = gender_cols[0] if gender_cols else None
gada_col = 'GADA' if 'GADA' in obs_columns else None
# Prefer a combined Aab status column, fall back to the ZnT8A flag
if 'Aab_status' in obs_columns:
    aab_col, aab_title = 'Aab_status', 'Combined Aab Status'
elif 'ZnT8A' in obs_columns:
    aab_col, aab_title = 'ZnT8A', 'ZnT8A+ (True=has ZnT8A autoantibody)'
else:
    aab_col, aab_title = None, None

# (axis, obs column, title, extra plotting kwargs, placeholder text if missing)
panels = [
    (axes[0,0], 'Donor Status', 'Donor Status (Biological)', {}, None),
    (axes[0,1], 'imageid', 'Donor ID (individual donors)', {'palette': 'tab10'}, None),
    (axes[0,2], age_col, 'Age (Biological)', {'cmap': 'viridis'}, 'Age data not available'),
    (axes[1,0], gender_col, f'{gender_col} (Biological)', {}, 'Gender data not available'),
    (axes[1,1], gada_col, 'GADA+ (True=has GADA autoantibody)', {}, 'GADA data not available'),
    (axes[1,2], aab_col, aab_title, {}, 'Aab status not available'),
]

for ax, obs_key, panel_title, extra_kwargs, missing_text in panels:
    if obs_key is None:
        ax.text(0.5, 0.5, missing_text, ha='center', va='center')
        ax.axis('off')
        continue
    sc.pl.umap(adata, color=obs_key, ax=ax, show=False,
               title=panel_title, frameon=False, size=3, **extra_kwargs)

plt.tight_layout()
plt.show()

print("\n✓ Donor-level covariates visualized on the scVI UMAP")
print("Age and Gender were given to scVI as categorical covariates; "
      "Donor Status and autoantibody columns were not part of the model")

In [None]:
# Visualize donor-level covariates on the UMAP.
# NOTE: Age and Gender are passed to scVI as categorical covariates in
# setup_anndata, so the model conditions on them; Donor Status and the
# autoantibody columns are not given to the model.
# NOTE(review): this cell repeats the preceding visualization cell.

fig, axes = plt.subplots(2, 3, figsize=(24, 16))
obs_columns = adata.obs.columns

# imageid must be categorical so it is drawn with discrete colours
if not isinstance(adata.obs['imageid'].dtype, pd.CategoricalDtype):
    adata.obs['imageid'] = adata.obs['imageid'].astype('category')

# Resolve the optional columns once; None means "draw a placeholder panel"
age_col = next((c for c in ('Age', 'age') if c in obs_columns), None)
gender_cols = [col for col in obs_columns if 'gender' in col.lower() or 'sex' in col.lower()]
gender_col = gender_cols[0] if gender_cols else None
gada_col = 'GADA' if 'GADA' in obs_columns else None
# Prefer a combined Aab status column, fall back to the ZnT8A flag
if 'Aab_status' in obs_columns:
    aab_col, aab_title = 'Aab_status', 'Combined Aab Status'
elif 'ZnT8A' in obs_columns:
    aab_col, aab_title = 'ZnT8A', 'ZnT8A+ (True=has ZnT8A autoantibody)'
else:
    aab_col, aab_title = None, None

# (axis, obs column, title, extra plotting kwargs, placeholder text if missing)
panels = [
    (axes[0,0], 'Donor Status', 'Donor Status (Biological)', {}, None),
    (axes[0,1], 'imageid', 'Donor ID (individual donors)', {'palette': 'tab10'}, None),
    (axes[0,2], age_col, 'Age (Biological)', {'cmap': 'viridis'}, 'Age data not available'),
    (axes[1,0], gender_col, f'{gender_col} (Biological)', {}, 'Gender data not available'),
    (axes[1,1], gada_col, 'GADA+ (True=has GADA autoantibody)', {}, 'GADA data not available'),
    (axes[1,2], aab_col, aab_title, {}, 'Aab status not available'),
]

for ax, obs_key, panel_title, extra_kwargs, missing_text in panels:
    if obs_key is None:
        ax.text(0.5, 0.5, missing_text, ha='center', va='center')
        ax.axis('off')
        continue
    sc.pl.umap(adata, color=obs_key, ax=ax, show=False,
               title=panel_title, frameon=False, size=3, **extra_kwargs)

plt.tight_layout()
plt.show()

print("\n✓ Donor-level covariates visualized on the scVI UMAP")
print("Age and Gender were given to scVI as categorical covariates; "
      "Donor Status and autoantibody columns were not part of the model")

In [None]:
# Donor Status next to three Leiden resolutions, all on the same UMAP
fig, axes = plt.subplots(2, 2, figsize=(20, 20))

# Styling shared by all four panels
shared_style = dict(show=False, frameon=False, size=3, alpha=0.3)

sc.pl.umap(adata, color='Donor Status', ax=axes[0,0],
           title='Donor Status (ND/Aab+/T1D)', **shared_style)

# Clustering panels use the obs column name as title and a legend in the right margin
leiden_panels = (
    (axes[0,1], 'leiden_res_0.50'),
    (axes[1,0], 'leiden_res_1.00'),
    (axes[1,1], 'leiden_res_1.50'),
)
for ax, leiden_key in leiden_panels:
    sc.pl.umap(adata, color=leiden_key, ax=ax, title=leiden_key,
               legend_loc='right margin', **shared_style)

plt.tight_layout()
plt.show()

## 9. Differential Expression and Marker Analysis

In [None]:
# Wilcoxon rank-sum test per cluster at the resolution chosen for detailed analysis
de_groupby = "leiden_res_1.00"
sc.tl.rank_genes_groups(adata, groupby=de_groupby, method="wilcoxon")

print(
    "Differential protein expression computed",
    "Use sc.pl.rank_genes_groups(adata) to visualize top markers per cluster",
    sep="\n",
)

In [None]:
# Visualize top marker proteins per cluster (reads the default
# 'rank_genes_groups' results); sharey=False gives each panel its own y-axis.
sc.pl.rank_genes_groups(adata, sharey=False)

In [None]:
# Dot plot of the full protein panel across the leiden_res_1.00 clusters.
# protein_markers can be narrowed to a subset of the panel if desired.
protein_markers = list(adata.var_names)

dotplot_settings = dict(
    var_names=protein_markers,
    groupby='leiden_res_1.00',
    standard_scale='var',
    figsize=(20, 8),
)
sc.pl.dotplot(adata, **dotplot_settings)

## 10. Biological Covariate Analysis (Age, Gender, Autoantibodies)

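**Important**: In this notebook `Age` and `Gender` ARE passed to the scVI model as categorical covariates (see the `setup_anndata` call in Section 3), and no `batch_key` such as `imageid` is registered. Only donor status and the autoantibody columns are left out of the model. The Age and Gender patterns below therefore come from a latent space that was conditioned on those two covariates and should be interpreted with that in mind; the donor status and autoantibody signals were not given to the model and can be analyzed directly.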

In [None]:
# Visualize donor-level covariates on the UMAP.
# NOTE: Age and Gender are passed to scVI as categorical covariates in
# setup_anndata, so the model conditions on them; Donor Status and the
# autoantibody columns are not given to the model.

fig, axes = plt.subplots(2, 3, figsize=(24, 16))
obs_columns = adata.obs.columns

# imageid must be categorical so it is drawn with discrete colours
if not isinstance(adata.obs['imageid'].dtype, pd.CategoricalDtype):
    adata.obs['imageid'] = adata.obs['imageid'].astype('category')

# Resolve the optional columns once; None means "draw a placeholder panel"
age_col = next((c for c in ('Age', 'age') if c in obs_columns), None)
gender_cols = [col for col in obs_columns if 'gender' in col.lower() or 'sex' in col.lower()]
gender_col = gender_cols[0] if gender_cols else None
gada_col = 'GADA' if 'GADA' in obs_columns else None
# Prefer a combined Aab status column, fall back to the ZnT8A flag
if 'Aab_status' in obs_columns:
    aab_col, aab_title = 'Aab_status', 'Combined Aab Status'
elif 'ZnT8A' in obs_columns:
    aab_col, aab_title = 'ZnT8A', 'ZnT8A+ (True=has ZnT8A autoantibody)'
else:
    aab_col, aab_title = None, None

# (axis, obs column, title, extra plotting kwargs, placeholder text if missing)
panels = [
    (axes[0,0], 'Donor Status', 'Donor Status (Biological)', {}, None),
    (axes[0,1], 'imageid', 'Donor ID (individual donors)', {'palette': 'tab10'}, None),
    (axes[0,2], age_col, 'Age (Biological)', {'cmap': 'viridis'}, 'Age data not available'),
    (axes[1,0], gender_col, f'{gender_col} (Biological)', {}, 'Gender data not available'),
    (axes[1,1], gada_col, 'GADA+ (True=has GADA autoantibody)', {}, 'GADA data not available'),
    (axes[1,2], aab_col, aab_title, {}, 'Aab status not available'),
]

for ax, obs_key, panel_title, extra_kwargs, missing_text in panels:
    if obs_key is None:
        ax.text(0.5, 0.5, missing_text, ha='center', va='center')
        ax.axis('off')
        continue
    sc.pl.umap(adata, color=obs_key, ax=ax, show=False,
               title=panel_title, frameon=False, size=3, **extra_kwargs)

plt.tight_layout()
plt.show()

print("\n✓ Donor-level covariates visualized on the scVI UMAP")
print("Age and Gender were given to scVI as categorical covariates; "
      "Donor Status and autoantibody columns were not part of the model")

In [None]:
# Wilcoxon rank-sum test between the Donor Status groups; results are stored
# under their own key so the per-cluster results are not overwritten.
banner = "=" * 60
print(banner)
print("Differential Protein Expression: ND vs Aab+ vs T1D")
print(banner)

donor_de_settings = dict(
    groupby='Donor Status',
    method='wilcoxon',
    key_added='donor_status_de',
)
sc.tl.rank_genes_groups(adata, **donor_de_settings)

print("\n✓ Differential expression computed")
print("Use sc.pl.rank_genes_groups(adata, key='donor_status_de') to visualize")

In [None]:
# Visualize top differentially expressed proteins between disease groups,
# reading the results stored under 'donor_status_de'; each panel has its own y-axis.
sc.pl.rank_genes_groups(adata, key='donor_status_de', sharey=False)

In [None]:
# Analyze autoantibody-positive vs negative cells.
# NOTE: Each Aab column is True/False, and donors can have multiple Aabs.
# The 'None' column indicates donors with NO autoantibodies.
#
# This cell:
#   1. prints how many cells carry each flag,
#   2. breaks the flags down by 'Donor Status',
#   3. runs a Wilcoxon DE test (flag True vs False) per autoantibody and
#      stores it under '<ab lowercased>_de',
#   4. writes a combined categorical 'Aab_status' column to adata.obs.

autoantibody_cols = ['GADA', 'ZnT8A', 'IA2A', 'mIAA', 'None']
found_ab = [col for col in autoantibody_cols if col in adata.obs.columns]
# Actual autoantibody flags, i.e. every found column except the 'None' marker
ab_types = [col for col in found_ab if col != 'None']

if found_ab:
    rule = "=" * 60
    print(rule)
    print("Autoantibody Analysis")
    print(rule)

    # Overall summary
    print("\nAutoantibody Status Summary:")
    for ab in found_ab:
        n_positive = adata.obs[ab].sum()
        pct_positive = n_positive / adata.n_obs * 100
        print(f"  {ab}: {n_positive:,} cells ({pct_positive:.1f}%)")

    # Check for multiple autoantibodies
    if 'None' in found_ab:
        print(f"\nCells with NO autoantibodies (None=True): {adata.obs['None'].sum():,}")

        if ab_types:
            # Number of positive flags per cell: computed once, reused below
            n_abs_per_cell = adata.obs[ab_types].sum(axis=1)
            has_multiple = n_abs_per_cell > 1
            n_multiple = has_multiple.sum()
            n_single = (n_abs_per_cell == 1).sum()
            print(f"Cells with exactly 1 autoantibody: {n_single:,}")
            print(f"Cells with multiple autoantibodies: {n_multiple:,}")

            # Show combinations (first 10 multi-positive cells only)
            print("\nAutoantibody combinations (cells with multiple):")
            for idx, row in adata.obs[has_multiple].head(10).iterrows():
                positive_abs = [ab for ab in ab_types if row[ab]]
                print(f"  Cell {idx}: {', '.join(positive_abs)}")

    # Distribution by Donor Status
    print("\n" + rule)
    print("Distribution by Donor Status (ND/Aab+/T1D):")
    print(rule)
    # Cells per donor status: one pass instead of one column scan per
    # (autoantibody, status) pair
    status_totals = adata.obs['Donor Status'].value_counts()
    for ab in found_ab:
        print(f"\n{ab}:")
        status_counts = adata.obs.groupby('Donor Status')[ab].value_counts()
        for (status, ab_val), count in status_counts.items():
            if ab_val:  # Only show True values
                total_in_status = status_totals.loc[status]
                pct = count / total_in_status * 100
                print(f"  {status}: {count:,} cells ({pct:.1f}% of {status} cells)")

    # Differential expression for each autoantibody type
    print("\n" + rule)
    print("Differential Expression Analysis")
    print(rule)

    for ab in ab_types:  # 'None' is a marker column, not an autoantibody
        n_positive = adata.obs[ab].sum()
        if not n_positive > 0:  # Skip empty columns
            continue
        # Derive the negative count from the total rather than via `~`:
        # on an object-dtype column `~True` evaluates to -2, which would
        # silently corrupt the count.
        n_negative = adata.n_obs - n_positive

        # Only run DE if we have enough cells in both groups
        if n_positive >= 50 and n_negative >= 50:
            print(f"\nComputing DE: {ab}+ ({n_positive:,} cells) vs {ab}- ({n_negative:,} cells)")

            # Convert boolean to categorical for scanpy
            ab_cat_col = f'{ab}_categorical'
            adata.obs[ab_cat_col] = adata.obs[ab].astype(str).astype('category')

            sc.tl.rank_genes_groups(
                adata,
                groupby=ab_cat_col,
                method='wilcoxon',
                key_added=f'{ab.lower()}_de'
            )
            print(f"  ✓ Saved as '{ab.lower()}_de'")
            print(f"  Use: sc.pl.rank_genes_groups(adata, key='{ab.lower()}_de')")
        else:
            print(f"\n{ab}: Skipping DE (need ≥50 cells per group, have {n_positive} positive, {n_negative} negative)")

    # Create a combined Aab status column for easier visualization
    print("\n" + rule)
    print("Creating combined autoantibody status column")
    print(rule)

    if 'None' in found_ab:
        # Build the label per cell with column-wise boolean masks instead of
        # a row-wise DataFrame.apply (one Python call per cell).
        # Later assignments override earlier ones, giving the precedence:
        #   'None' flag  >  'Multiple' (>1 flag)  >  single flag name  >  'Unknown'
        flag_matrix = adata.obs[ab_types].astype(bool).to_numpy()
        aab_status = np.full(adata.n_obs, 'Unknown', dtype=object)
        for j, ab in enumerate(ab_types):
            aab_status[flag_matrix[:, j]] = ab
        aab_status[flag_matrix.sum(axis=1) > 1] = 'Multiple'
        aab_status[adata.obs['None'].astype(bool).to_numpy()] = 'None'

        adata.obs['Aab_status'] = pd.Categorical(aab_status)

        print("Created 'Aab_status' column with categories:")
        print(adata.obs['Aab_status'].value_counts())
        print("\nYou can now visualize with: sc.pl.umap(adata, color='Aab_status')")
else:
    print("No autoantibody columns found in data")

In [None]:
# Side-by-side UMAPs of the combined autoantibody status and the donor
# status, followed by a column-normalised cross-tabulation of the two.
if 'Aab_status' in adata.obs.columns:
    fig, axes = plt.subplots(1, 2, figsize=(20, 8))

    # (obs column, panel title) for the left and right panels
    umap_panels = [
        ('Aab_status', 'Autoantibody Status'),
        ('Donor Status', 'Donor Status (ND/Aab+/T1D)'),
    ]
    for panel_ax, (obs_key, panel_title) in zip(axes, umap_panels):
        sc.pl.umap(
            adata,
            color=obs_key,
            ax=panel_ax,
            show=False,
            title=panel_title,
            frameon=False,
            size=3,
            alpha=0.5,
        )

    plt.tight_layout()
    plt.show()

    print("\nAutoantibody status overlaid on batch-corrected UMAP")
    print("This shows how Aab+ status relates to disease progression")

    # Cross-tabulation: fractions are normalised within each Donor Status column
    rule = "=" * 60
    print("\n" + rule)
    print("Autoantibody Status × Donor Status Cross-tabulation")
    print(rule)
    crosstab = pd.crosstab(
        adata.obs['Aab_status'],
        adata.obs['Donor Status'],
        normalize='columns',
    )
    print(crosstab.to_string())

    print("\nInterpretation:")
    for note in (
        "  - 'None' autoantibodies should be enriched in ND (non-diabetic)",
        "  - Single or multiple Aabs should be enriched in Aab+ and T1D",
        "  - This validates that the biological signal is preserved after batch correction",
    ):
        print(note)

In [None]:
# Visualize each autoantibody type individually on UMAP:
# one panel per available flag column, laid out on a grid of up to 3 columns.
autoantibody_cols = ['GADA', 'ZnT8A', 'IA2A', 'mIAA', 'None']
found_ab = [col for col in autoantibody_cols if col in adata.obs.columns]

if found_ab:
    n_cols = min(3, len(found_ab))
    n_rows = (len(found_ab) + n_cols - 1) // n_cols  # ceiling division

    # squeeze=False makes plt.subplots always return a 2-D array of Axes,
    # so the 1x1 and 1xN layouts need no special-casing.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(8*n_cols, 6*n_rows),
                             squeeze=False)
    panels = axes.ravel()  # row-major order: left-to-right, top-to-bottom

    for ax, ab in zip(panels, found_ab):
        n_positive = adata.obs[ab].sum()
        pct = n_positive / adata.n_obs * 100

        sc.pl.umap(adata, color=ab, ax=ax, show=False,
                   title=f'{ab} (n={n_positive:,}, {pct:.1f}%)',
                   frameon=False, size=2, alpha=0.6)

    # Hide empty subplots (grid cells beyond the last autoantibody)
    for ax in panels[len(found_ab):]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

    print("Individual autoantibody types on UMAP")
    print("True (red/orange) = cells/donors with this autoantibody")
    print("False (blue) = cells/donors without this autoantibody")
else:
    print("No autoantibody columns found")

### Summary: Why Biological Covariates Were Not Included in scVI

**The Key Principle:**
- scVI **removes/corrects** the variation associated with any batch key or covariate you register with it
- We gave it `imageid` (donor batch) → removed technical donor effects ✓
- We did NOT give it age, gender, or autoantibodies → preserved biological signal ✓

**What we did:**
1. ✅ Batch correction: `imageid` → removes donor-specific technical noise
2. ✅ Biological preservation: age, gender, autoantibodies → kept for downstream analysis

**What would have happened if we included them in scVI:**
- ❌ scVI would try to "correct out" age differences → lose age signal
- ❌ scVI would try to "correct out" gender differences → lose sex-specific biology
- ❌ scVI would try to "correct out" autoantibody effects → lose disease-related signals

**The correct workflow:**
1. Train scVI with ONLY technical batch variables (`imageid`)
2. Extract batch-corrected latent space
3. Perform clustering and visualization on corrected data
4. **Then** analyze biological covariates (age, gender, autoantibodies) on the corrected data

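This way, you get the best of both worlds:
- Clean data (technical batch effects reduced)
- Preserved biology (can still study disease differences)

**Caveat:** age, gender, and autoantibody status are donor-level attributes, so they are nested within `imageid`. Correcting for `imageid` can therefore also attenuate donor-level biological differences. Check on the UMAPs and differential-expression results above that these signals are still visible, rather than assuming they were preserved.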

In [None]:
import anndata as ad

# Load the AnnData object stored in 'CODEX_panc_scvi.h5ad' under base_dir;
# this rebinds the in-memory `adata`.
adata = ad.read_h5ad(
    os.path.join(base_dir, 'CODEX_panc_scvi.h5ad')
)

In [None]:
# Bare expression: the notebook displays the AnnData summary as cell output
adata

## 11. Save Processed Data

In [None]:
# Write the processed AnnData object (with its scVI latent representation)
# to disk, then print a short summary of what was saved.
output_file = os.path.join(base_dir, 'Pancreas_CODEX_processed_with_scVI.h5ad')
adata.write_h5ad(output_file)

print(f"✓ Processed data saved to: {output_file}")
print("\nSummary:")
print(f"  Cells: {adata.n_obs:,}")
print(f"  Proteins: {adata.n_vars}")

n_images = adata.obs['imageid'].nunique()
print(f"  Images (donors): {n_images}")

status_groups = adata.obs['Donor Status'].unique().tolist()
print(f"  Donor Status groups: {status_groups}")

n_latent_dims = adata.obsm[SCVI_LATENT_KEY].shape[1]
print(f"  scVI latent dimensions: {n_latent_dims}")

# Every obs column whose name contains 'leiden'
leiden_columns = [name for name in adata.obs.columns if 'leiden' in name]
print(f"  Clustering resolutions: {leiden_columns}")

## Next Steps

1. **Cell Type Annotation**: Use the differential protein expression results to annotate clusters based on known cell type markers
2. **Spatial Analysis**: Integrate spatial coordinates if available to analyze tissue organization
3. **Donor Group Comparison**: Compare cell type distributions and protein expression across ND/Aab+/T1D groups
4. **Downstream Analysis**: Perform statistical tests to identify disease-associated changes

In [None]:
# One representative protein marker per cell type, as ordered
# (cell type, marker) pairs; the order sets the dict's iteration order.
marker_pairs = [
    ('Beta cells', 'INS'),
    ('Alpha cells', 'GCG'),
    ('Delta cells', 'SST'),
    ('Ductal cells', 'CK19'),
    ('T cells', 'CD3e'),
    ('B cells', 'CD20'),
    ('Macrophages', 'CD68'),
    ('Endothelial', 'CD31'),
    ('Stromal', 'VIM'),
]
primary_markers = dict(marker_pairs)

In [None]:
# Scale a copy of adata (zero-centered) and keep only its matrix as the
# "scaled" layer; adata.X itself is left untouched.
scaled_copy = sc.pp.scale(adata, zero_center=True, copy=True)
adata.layers["scaled"] = scaled_copy.X
del scaled_copy  # drop the temporary AnnData copy; only its matrix is kept

In [None]:
# IPython magic: selects the inline matplotlib backend for this notebook
%matplotlib inline

In [None]:
# Dot plot of the primary markers per Leiden cluster (resolution 1.00),
# drawn from the "scaled" layer with a diverging colormap clipped to [-2, 2].
sc.pl.dotplot(
    adata,
    primary_markers,
    groupby='leiden_res_1.00',
    layer="scaled",
    cmap='RdBu_r',
    vmin=-2,
    vmax=2,
    figsize=(12, 6),
)

In [None]:
# Marker panels per cell type, declared per tissue compartment and then
# merged in order, so the combined dict keeps the compartment grouping.

# Endocrine cells (Islet cells)
endocrine_markers = {
    'Beta cells': ['INS', 'CHGA', 'NaKATPase'],
    'Alpha cells': ['GCG', 'CHGA'],
    'Delta cells': ['SST', 'CHGA'],
}

# Exocrine cells
exocrine_markers = {
    'Acinar cells': ['BActin', 'ECAD'],
    'Ductal cells': ['CK19', 'panCK', 'KRT14', 'ECAD'],
}

# Immune cells - T cells, B cells and myeloid
immune_markers = {
    'T cells CD8+': ['CD8a', 'CD3e', 'CD45'],
    'T cells CD4+': ['CD4', 'CD3e', 'CD45'],
    'B cells': ['CD20', 'CD45'],
    'Macrophages': ['CD68', 'CD163', 'CD45'],
    'Antigen Presenting Cells': ['HLADR', 'CD45'],
}

# Stromal cells
stromal_markers = {
    'Endothelial cells': ['CD31', 'CD34', 'ECAD'],
    'Fibroblasts': ['VIM', 'ColIV'],
    'Pericytes': ['SMA', 'VIM', 'CD44'],
}

# Neural cells
neural_markers = {
    'Neurons': ['PGP9.5', 'GAP43', 'B3TUBB'],
}

cell_type_markers = {
    **endocrine_markers,
    **exocrine_markers,
    **immune_markers,
    **stromal_markers,
    **neural_markers,
}

In [None]:
# Dot plot of the full marker panel per Leiden cluster (resolution 2.00),
# drawn from the "scaled" layer with a diverging colormap clipped to [-2, 2].
sc.pl.dotplot(
    adata,
    cell_type_markers,
    groupby='leiden_res_2.00',
    layer="scaled",
    cmap='RdBu_r',
    vmin=-2,
    vmax=2,
    figsize=(12, 6),
)