
# Subcluster analysis for MULTIVISPLICE models

This notebook:
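1. Loads a MuData file, subsets it to the target cell type (optionally restricted to target tissues), and loads the trained MULTIVISPLICE model on that subset — all later steps operate on this subset.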
2. Creates an output folder stamped with date/time and model name.
3. Computes Leiden clusters on the **joint** latent space at a configurable resolution.
4. Plots **UMAP** and **t-SNE** of the joint space colored by clusters, and side-by-side with `medium_cell_type` (or `broad_cell_type` if missing).
5. Within a chosen cell type (e.g., *Excitatory neuron*), finds **subclusters** and runs **pairwise** Differential Expression (DE) and Differential Splicing (DS) between all subcluster pairs.
6. Summarizes significant DE/DS counts and their overlaps (mapping junctions → genes), and shows a few representative features.
7. Plots UMAP of the selected cell type colored by **denoised PSI** (from the model) and **normalized expression**.

> Notes:
> - This assumes your model class exposes `get_normalized_expression`, `get_normalized_splicing`, `differential_expression`, and `differential_splicing`.
> - It expects the splicing modality to be registered as `junc_ratio_layer` and accessible via the registry key used by your class (e.g., `REGISTRY_KEYS.JUNC_RATIO_X_KEY` inside scvi-tools setup).


In [None]:

# === Parameters ===
# Fill these in before running the notebook.

# Paths
TRAINED_MODEL_DIR = "/gpfs/commons/home/svaidyanathan/splice_vi_partial_vae_sweep/batch_20251020_102953/mouse_trainandtest_REAL_cd=32_mn=50000_ld=25_lr=1e-5_0_scatter_PartialEncoderEDDI_pool=sum/models"
MUDATA_PATH       = "/gpfs/commons/groups/knowles_lab/Karin/Leaflet-analysis-WD/MOUSE_SPLICING_FOUNDATION/MODEL_INPUT/072025/train_70_30_ge_splice_combined_20250730_164104.h5mu"

# Plot/output base dir (a timestamped subfolder is created under this)
BASE_OUTDIR = "/gpfs/commons/home/svaidyanathan/repos/multivi_tools_splicing/multivi_splice_utils/jupyter_notebooks/figures"

# Clustering
LEIDEN_RESOLUTION = 1.0

# Cell type column preference
PREFERRED_CELLTYPE_COL = "medium_cell_type"   # falls back to "broad_cell_type" if missing

# DE/DS settings
DE_DELTA = 0.25  # genes (on appropriate scale; keep your model's default meaning)
DS_DELTA = 0.10  # PSI threshold in [0,1]
FDR      = 0.05
BATCH_SIZE_POST = 512  # for decoding normalized values

# Target cell type to analyze subclusters in (exact match in the cell-type column of obs).
# NOTE(review): the load cell also assigns TARGET_CELLTYPE / TARGET_TISSUES, and its values win.
# These defaults are kept identical to those so this cell reflects what actually runs.
TARGET_CELLTYPE = "Excitatory neuron"
# Optional tissue filter applied together with TARGET_CELLTYPE
# (case-insensitive match on obs["tissue"]; [] disables it).
TARGET_TISSUES = ["cortex"]

# How many top features to visualize from DE/DS for quick look
N_TOP_SHOW = 12


In [2]:

import os
from pathlib import Path
from datetime import datetime
import numpy as np
import pandas as pd
import scanpy as sc
import scvi
import mudata as mu
import matplotlib.pyplot as plt
import seaborn as sns

import torch
from sklearn.metrics import adjusted_mutual_info_score, silhouette_score
from sklearn.decomposition import PCA

# Plotting defaults: seaborn "notebook" context with a white grid background.
sns.set_context("notebook")
sns.set_style("whitegrid")

# Record the versions of the key libraries next to the outputs.
for _lib_name, _lib in (("scvi", scvi), ("scanpy", sc), ("torch", torch)):
    print(f"{_lib_name} version:", _lib.__version__)


scvi version: 1.3.1
scanpy version: 1.10.4
torch version: 2.6.0+cu124


In [None]:

# === Load MuData and trained model ===
assert os.path.exists(MUDATA_PATH), f"MuData path not found: {MUDATA_PATH}"
# NOTE(review): the file is opened in backed (read-only) mode. Depending on the mudata/anndata
# version, `.copy()` of a backed view may require a target filename (or `.to_memory()`);
# confirm the subset step at the end of this block works.
mdata = mu.read_h5mu(MUDATA_PATH, backed="r")

# Required layer names
x_layer = "junc_ratio"
junction_counts_layer = "cell_by_junction_matrix"
cluster_counts_layer  = "cell_by_cluster_matrix"
mask_layer            = "psi_mask"

print("Splicing layers:", list(mdata["splicing"].layers.keys()))

# --- Subset BEFORE setup_mudata ---
obs_rna = mdata["rna"].obs

# Cell-type column: honour PREFERRED_CELLTYPE_COL, else fall back to broad_cell_type. This is the
# same rule the later cells use. Hard-coding "medium_cell_type" here would let the subset and
# the downstream analysis silently use different columns.
if PREFERRED_CELLTYPE_COL in obs_rna:
    ct_key = PREFERRED_CELLTYPE_COL
elif "broad_cell_type" in obs_rna:
    ct_key = "broad_cell_type"
else:
    ct_key = None
assert ct_key is not None, f"No cell-type column ({PREFERRED_CELLTYPE_COL} / broad_cell_type) in RNA obs."

# NOTE(review): these assignments take precedence over any TARGET_CELLTYPE / TARGET_TISSUES set in
# the Parameters cell; consider keeping them in one place only.
TARGET_CELLTYPE = "Excitatory neuron"  # exact label in your obs
TARGET_TISSUES = ["cortex"]            # case-insensitive match; set [] to ignore tissue

mask = (obs_rna[ct_key] == TARGET_CELLTYPE)
if len(TARGET_TISSUES) > 0 and "tissue" in obs_rna:
    mask &= obs_rna["tissue"].astype(str).str.lower().isin([t.lower() for t in TARGET_TISSUES])

n_keep = int(mask.sum())
assert n_keep > 0, f"No cells matched {TARGET_CELLTYPE!r} with tissues={TARGET_TISSUES}."

# Apply to the whole MuData (cells only).
# NOTE(review): `mask` is built from the RNA obs; this assumes mdata.obs lists the same cells in the
# same order as mdata["rna"].obs.
mdata = mdata[mask].copy()
print(f"Subset to {n_keep} cells for {TARGET_CELLTYPE} (tissues={TARGET_TISSUES})")


# Register the (already subset) MuData with MULTIVISPLICE. The layer names come from the variables
# defined when the MuData was loaded. `modalities` says which MuData modality holds the RNA layer
# and which holds the junction-ratio layer.
setup_kwargs = dict(
    batch_key=None,
    size_factor_key="X_library_size",
    rna_layer="length_norm",
    junc_ratio_layer=x_layer,
    atse_counts_layer=cluster_counts_layer,
    junc_counts_layer=junction_counts_layer,
    psi_mask_layer=mask_layer,
    modalities={"rna_layer": "rna", "junc_ratio_layer": "splicing"},
)
scvi.model.MULTIVISPLICE.setup_mudata(mdata, **setup_kwargs)

# Restore the trained model from disk, attached to the registered MuData.
model_dir = TRAINED_MODEL_DIR
assert os.path.isdir(model_dir), f"Model dir not found: {model_dir}"
model = scvi.model.MULTIVISPLICE.load(model_dir, mdata=mdata)

# Make timestamped output dir: <BASE_OUTDIR>/<YYYYmmdd_HHMMSS>__<model run name>.
# The configured TRAINED_MODEL_DIR ends in ".../<run name>/models", so `.name` alone would label
# every run "models". In that case, use the parent directory's name instead.
model_path = Path(TRAINED_MODEL_DIR)
if model_path.name == "models" and model_path.parent.name:
    model_name = model_path.parent.name
else:
    model_name = model_path.name
stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
OUTDIR = Path(BASE_OUTDIR) / f"{stamp}__{model_name}"
OUTDIR.mkdir(parents=True, exist_ok=True)
OUTDIR


In [None]:

# === Joint latent + neighbors ===
# The RNA AnnData is the canvas for the kNN graph, UMAP and t-SNE; all three are computed from the
# model's joint latent representation stored in obsm["X_latent_joint"].
Z_joint = model.get_latent_representation(adata=mdata, modality="joint")
ad_rna = mdata["rna"].copy()
ad_rna.obsm["X_latent_joint"] = Z_joint

joint_nbrs = "neighbors_joint"
sc.pp.neighbors(ad_rna, use_rep="X_latent_joint", key_added=joint_nbrs)

# scanpy writes to the default X_umap / X_tsne slots; keep "_joint" copies for later cells.
sc.tl.umap(ad_rna, neighbors_key=joint_nbrs)
ad_rna.obsm["X_umap_joint"] = ad_rna.obsm["X_umap"].copy()

sc.tl.tsne(
    ad_rna,
    use_rep="X_latent_joint",
    n_pcs=0,
    learning_rate=200.0,
    perplexity=30.0,
    random_state=0,
)
ad_rna.obsm["X_tsne_joint"] = ad_rna.obsm["X_tsne"].copy()

# Cell-type column: the preferred one if present, else broad_cell_type, else None.
if PREFERRED_CELLTYPE_COL in ad_rna.obs:
    ct_key = PREFERRED_CELLTYPE_COL
elif "broad_cell_type" in ad_rna.obs:
    ct_key = "broad_cell_type"
else:
    ct_key = None
print("Cell type key:", ct_key)


In [None]:

# === Leiden clustering on joint space ===
# Labels are written to obs["leiden_joint"] using the joint-space kNN graph.
sc.tl.leiden(
    ad_rna,
    resolution=LEIDEN_RESOLUTION,
    neighbors_key="neighbors_joint",
    key_added="leiden_joint",
)
leiden_labels = ad_rna.obs["leiden_joint"]
print("n_clusters:", leiden_labels.nunique())
leiden_labels.value_counts().head()


In [None]:

# === Plots ===
# (obsm basis, display label) for the two joint embeddings.
joint_embeddings = (("X_umap_joint", "UMAP"), ("X_tsne_joint", "tSNE"))

# 1) + 2) One figure per embedding, colored by Leiden cluster.
for basis, label in joint_embeddings:
    fig, ax = plt.subplots(figsize=(8,6))
    sc.pl.embedding(ad_rna, basis=basis, color="leiden_joint", legend_loc="right margin", show=False, ax=ax, frameon=True)
    ax.set_title(f"Joint {label} — Leiden (res={LEIDEN_RESOLUTION})")
    plt.tight_layout()
    fig.savefig(OUTDIR / f"{label.lower()}_joint_leiden.png", dpi=300, bbox_inches="tight")
    plt.show()

# 3) Both embeddings side by side, colored by cell type.
if ct_key is None:
    print("No cell-type column found; skipping side-by-side plot.")
else:
    fig, axs = plt.subplots(1, 2, figsize=(14,6))
    for ax, (basis, label) in zip(axs, joint_embeddings):
        sc.pl.embedding(ad_rna, basis=basis, color=ct_key, legend_loc="right margin", show=False, ax=ax, frameon=True)
        ax.set_title(f"{label} — {ct_key}")
    plt.tight_layout()
    fig.savefig(OUTDIR / f"umap_tsne_joint_{ct_key}.png", dpi=300, bbox_inches="tight")
    plt.show()


In [None]:

# === Focus on a single cell type and extract subclusters ===
assert ct_key is not None, "Target cell type column not available in obs."
mask_ct = ad_rna.obs[ct_key] == TARGET_CELLTYPE
assert mask_ct.any(), f"No cells found for TARGET_CELLTYPE={TARGET_CELLTYPE!r}"

ad_ct = ad_rna[mask_ct].copy()
# Sort labels by (length, text) so that "2" comes before "10".
subclusters = sorted(
    ad_ct.obs["leiden_joint"].astype(str).unique().tolist(),
    key=lambda label: (len(label), label),
)
print("Subclusters in", TARGET_CELLTYPE, ":", subclusters)


In [None]:

# === Helper: run pairwise DE and DS between subclusters within the selected cell type ===

def run_pairwise_de_ds(model, mdata, ad_rna, ct_key, target_ct, de_delta, ds_delta, fdr, batch_size_post):
    """Run DE (RNA) and DS (splicing) for every unordered pair of Leiden subclusters.

    Cells are restricted to ``ad_rna.obs[ct_key] == target_ct``. Subcluster labels are read from
    ``ad_rna.obs["leiden_joint"]`` and ordered by (string length, string), so "2" precedes "10".
    For each pair (a, b) with a earlier in that order, ``model.differential_expression`` is called
    on the RNA modality and ``model.differential_splicing`` on the splicing modality.

    Returns
    -------
    ad_rna_ct : RNA AnnData restricted to the target cell type
    pairs : list of "<a>_vs_<b>" keys in the order the comparisons were run
    de_tables, ds_tables : dicts mapping each key to the DE / DS result table
    """
    in_type = ad_rna.obs[ct_key] == target_ct
    # NOTE(review): the same mask is applied to mdata, assuming its cells are ordered like ad_rna.obs.
    m_ct = mdata[in_type].copy()
    ad_rna_ct = ad_rna[in_type].copy()  # for plotting/labels

    labels = ad_rna_ct.obs["leiden_joint"].astype(str).values
    ordered = sorted(pd.unique(labels), key=lambda lab: (len(lab), lab))
    comparisons = [(a, b) for pos, a in enumerate(ordered) for b in ordered[pos + 1:]]

    # Arguments shared by the DE and DS calls; only the modality and delta differ.
    shared = dict(
        mode="change",
        fdr_target=fdr,
        batch_size=batch_size_post,
        all_stats=True,
        silent=True,
    )

    pairs, de_tables, ds_tables = [], {}, {}
    for a, b in comparisons:
        in_a = labels == a
        in_b = labels == b
        key = f"{a}_vs_{b}"
        de_tables[key] = model.differential_expression(
            adata=m_ct["rna"], idx1=in_a, idx2=in_b, delta=de_delta, **shared
        )
        ds_tables[key] = model.differential_splicing(
            adata=m_ct["splicing"], idx1=in_a, idx2=in_b, delta=ds_delta, **shared
        )
        pairs.append(key)

    return ad_rna_ct, pairs, de_tables, ds_tables

# Run every subcluster-pair comparison inside the target cell type.
ad_rna_ct, PAIRS, DE_TABLES, DS_TABLES = run_pairwise_de_ds(
    model, mdata, ad_rna, ct_key, TARGET_CELLTYPE,
    de_delta=DE_DELTA,
    ds_delta=DS_DELTA,
    fdr=FDR,
    batch_size_post=BATCH_SIZE_POST,
)
PAIRS


In [None]:

# === Summarize counts and overlaps ===
def _significant_rows(df, prob_cols, fdr):
    """Boolean Series marking the rows of a DE/DS result table that are called significant.

    Prefers the ``is_de_fdr_<fdr>`` column, which scvi-tools' change-mode DE adds when
    ``fdr_target`` is given (posterior expected FDR control). Otherwise it falls back to
    thresholding the first available column from ``prob_cols`` at ``1 - fdr``, which is a
    per-feature cut and not FDR control.

    Raises
    ------
    RuntimeError
        If neither the FDR column nor any of ``prob_cols`` is present.
    """
    fdr_col = f"is_de_fdr_{fdr}"
    if fdr_col in df.columns:
        return df[fdr_col].astype(bool)
    prob_col = next((c for c in prob_cols if c in df.columns), None)
    if prob_col is None:
        raise RuntimeError("Could not find probability columns in DE/DS outputs.")
    return df[prob_col] >= 1 - fdr


def summarize_pair(de_df, ds_df, gene_col_in_splicing=("gene_id","gene_name"),
                   fdr=None, splicing_var=None, rna_var_names=None):
    """Count significant DE genes / DS junctions for one pair and their gene-level overlap.

    Parameters
    ----------
    de_df, ds_df : pd.DataFrame
        DE / DS result tables for one subcluster pair, indexed by feature name.
    gene_col_in_splicing : str or list/tuple of str
        Candidate column(s) in the splicing ``.var`` that map a junction to a gene.
    fdr : float, optional
        FDR target; defaults to the notebook-level ``FDR``.
    splicing_var : pd.DataFrame, optional
        Splicing feature annotations; defaults to ``mdata["splicing"].var``.
    rna_var_names : index-like, optional
        RNA feature names (what the DE index uses); defaults to ``mdata["rna"].var_names``.

    Returns
    -------
    summary : dict with ``n_sig_de``, ``n_sig_ds``, ``n_overlap_genes`` and ``overlap_genes``
        (the first 50 overlapping genes, sorted).
    sig_de, sig_ds : copies of the significant rows of ``de_df`` / ``ds_df``.

    Raises
    ------
    RuntimeError
        If a table has neither an FDR column nor a probability column.
    """
    # Resolve notebook-level defaults lazily so the function also works with explicit inputs.
    if fdr is None:
        fdr = FDR
    if splicing_var is None:
        splicing_var = mdata["splicing"].var
    if rna_var_names is None:
        rna_var_names = mdata["rna"].var_names

    sig_de = de_df[_significant_rows(de_df, ("proba_de", "probability"), fdr)].copy()
    sig_ds = ds_df[_significant_rows(ds_df, ("proba_ds", "proba_de"), fdr)].copy()

    # Map significant junctions -> genes through a splicing .var annotation column.
    if isinstance(gene_col_in_splicing, (list, tuple)):
        candidates = list(gene_col_in_splicing)
    else:
        candidates = [gene_col_in_splicing]
    present = [c for c in candidates if c in splicing_var.columns]

    if present:
        rna_names = set(pd.Index(rna_var_names).astype(str))
        # DE results are indexed by RNA var names, so choose the annotation (e.g. gene_id vs gene_name)
        # whose values actually match them; on a tie max() keeps the earliest candidate.
        gene_col = max(present, key=lambda c: len(rna_names.intersection(splicing_var[c].astype(str))))
        # reindex instead of .loc: junctions missing from the annotation become NaN instead of raising.
        ds_genes_set = set(splicing_var[gene_col].reindex(sig_ds.index).dropna().astype(str))
    else:
        ds_genes_set = set()

    de_genes = set(sig_de.index.astype(str))
    gene_overlap = sorted(de_genes & ds_genes_set)

    summary = {
        "n_sig_de": int(sig_de.shape[0]),
        "n_sig_ds": int(sig_ds.shape[0]),
        "n_overlap_genes": len(gene_overlap),
        "overlap_genes": gene_overlap[:50],  # show first 50
    }
    return summary, sig_de, sig_ds

# Per-pair summary plus the significant DE / DS rows, each stored in a dict keyed by pair.
SUMMARIES, SIG_DE, SIG_DS = {}, {}, {}
for pair_key in PAIRS:
    SUMMARIES[pair_key], SIG_DE[pair_key], SIG_DS[pair_key] = summarize_pair(
        DE_TABLES[pair_key], DS_TABLES[pair_key]
    )

pd.DataFrame(SUMMARIES).T.sort_index()


In [None]:

# === Quick visualization of top features ===
def top_features_table(df, score_cols=("effect_size","emp_effect","lfc_mean"), n=N_TOP_SHOW):
    """Return the ``n`` rows of ``df`` with the largest absolute score.

    The score is the first column listed in ``score_cols`` that exists in ``df``. If none exists,
    the first ``n`` rows are returned in their original order.
    """
    score_col = next((c for c in score_cols if c in df.columns), None)
    if score_col is None:
        return df.head(n)
    return df.sort_values(score_col, key=lambda s: s.abs(), ascending=False).head(n)

# For each pair, show the top significant DE genes followed by the top significant DS junctions.
for pair_key in PAIRS:
    print("\n===", pair_key, "===")
    for heading, sig_tables in (("Top DE:", SIG_DE), ("Top DS:", SIG_DS)):
        print(heading)
        display(top_features_table(sig_tables[pair_key]))


In [None]:

# === UMAP of target cell type colored by denoised PSI and normalized expression ===
# Demo overlay: take the first listed DS junction and the first listed DE gene from the first
# subcluster pair that has both kinds of hits (falling back to the first pair), and plot whichever
# of the two is available. Only the modalities that are actually needed are decoded.


def _safe_filename_part(text):
    """Replace every character other than letters, digits and '._-=+' with '_'.

    Cell-type labels and feature names may contain '/', ':' or spaces, which break or complicate paths.
    """
    return "".join(ch if ch.isalnum() or ch in "._-=+" else "_" for ch in str(text))


if len(PAIRS) == 0:
    print("No subcluster pairs found; nothing to plot.")
else:
    p0 = next((pair for pair in PAIRS if len(SIG_DS[pair]) > 0 and len(SIG_DE[pair]) > 0), PAIRS[0])
    ds_hits = SIG_DS[p0].index.tolist()
    de_hits = SIG_DE[p0].index.tolist()
    junc = ds_hits[0] if ds_hits else None
    gene = de_hits[0] if de_hits else None

    if junc is None and gene is None:
        print("Not enough hits for a demo overlay.")
    else:
        print("Demo pair:", p0, "| DS junction:", junc, " | DE gene:", gene)

        # Same cell-type mask for the MuData (decoding) and the RNA AnnData (which carries X_umap_joint).
        in_target = ad_rna.obs[ct_key] == TARGET_CELLTYPE
        m_ct = mdata[in_target].copy()
        ad_plot = ad_rna[in_target].copy()

        overlays = []  # (obs column, plot title, file stem)
        if junc is not None and junc in m_ct["splicing"].var_names:
            psi = model.get_normalized_splicing(adata=m_ct, batch_size=BATCH_SIZE_POST, return_numpy=True)   # n_cells x n_juncs
            col = f"PSI::{junc}"
            ad_plot.obs[col] = psi[:, m_ct["splicing"].var_names.get_loc(junc)]
            overlays.append((col, f"{TARGET_CELLTYPE} — PSI {junc}", f"umap_{TARGET_CELLTYPE}_PSI_{junc}"))
        if gene is not None and gene in m_ct["rna"].var_names:
            expr = model.get_normalized_expression(adata=m_ct, batch_size=BATCH_SIZE_POST, return_numpy=True) # n_cells x n_genes
            col = f"EXP::{gene}"
            ad_plot.obs[col] = expr[:, m_ct["rna"].var_names.get_loc(gene)]
            overlays.append((col, f"{TARGET_CELLTYPE} — Expr {gene}", f"umap_{TARGET_CELLTYPE}_Expr_{gene}"))

        for col, title, stem in overlays:
            fig, ax = plt.subplots(figsize=(7,6))
            sc.pl.embedding(ad_plot, basis="X_umap_joint", color=col, color_map="viridis", show=False, ax=ax, frameon=True)
            ax.set_title(title)
            plt.tight_layout()
            fig.savefig(OUTDIR / f"{_safe_filename_part(stem)}.png", dpi=300, bbox_inches="tight")
            plt.show()


In [None]:

# === Save pairwise DE/DS tables ===
# For each pair, write DE_<celltype>_<pair>.csv and then DS_<celltype>_<pair>.csv into OUTDIR.
for pair_key in PAIRS:
    for prefix, tables in (("DE", DE_TABLES), ("DS", DS_TABLES)):
        tables[pair_key].to_csv(OUTDIR / f"{prefix}_{TARGET_CELLTYPE}_{pair_key}.csv")
print("Saved pairwise DE/DS CSVs to:", OUTDIR)
