# Multi-model Latent Space Analysis Notebook
This notebook will:

1. Load a MuData object and multiple MULTIVISPLICE models from given paths.
2. Compute latent representations for each model.
3. Generate UMAP embeddings for each model × grouping-key combination, with the models plotted side by side.
4. Compute silhouette scores for each model and grouping key, and plot a bar chart.


## 1. Imports and Setup

In [2]:
# Standard library
import os

# Third-party
import matplotlib.pyplot as plt
import mudata as mu
import pandas as pd
import scanpy as sc
import scvi  # your local copy of scvi-tools
import seaborn as sns
from sklearn.metrics import silhouette_score

# Report which scvi build was picked up, so it is visible in the output whether
# the local checkout or some other installation got imported.
print("scvi version:", getattr(scvi, "__version__", "No version attr"))
print("scvi loaded from:", scvi.__file__)

scvi version: 1.3.1
scvi loaded from: /gpfs/commons/home/svaidyanathan/repos/scvi-tools-splicing/src/scvi/__init__.py


### User Inputs

Note: All models must have been trained on the same MuData file (the one at `MUDATA_PATH`), because every model is reloaded against that single object.

In [None]:
# MuData (.h5mu) file that every model below is reloaded against
MUDATA_PATH = "/gpfs/commons/groups/knowles_lab/Karin/Leaflet-analysis-WD/MOUSE_SPLICING_FOUNDATION/MODEL_INPUT/052025/mouse_foundation_data_20250502_155802_ge_splice_combined.h5mu"

# Model label -> directory containing that model's saved files.
# The label is what appears in panel titles and in the bar-chart legend.
model_paths = {
    "dataset_batch_key": "/gpfs/commons/home/svaidyanathan/multi_vi_splice_runs/MultiVISpliceTraining_20250512_174535_job4478547/models",
    "mouse_id_batch_key": "/gpfs/commons/home/svaidyanathan/multi_vi_splice_runs/MultiVISpliceTraining_20250511_192758_job4472612/models",
    # add more models as needed
}

# Columns of mdata['rna'].obs used to color the UMAPs and to score silhouettes
groups = [
    "dataset",
    "broad_cell_type",
    "mouse.id",
]

# Every figure is written into this directory; create it up front if needed
FIG_DIR = "/gpfs/commons/home/svaidyanathan/repos/multivi_tools_splicing/models/multivisplice"
os.makedirs(FIG_DIR, exist_ok=True)

## 2. Load Data and Models

In [None]:
# Check every input path before loading anything, so a bad path is reported
# up front (with all missing model directories listed together) instead of
# surfacing one at a time part-way through the cell.
if not os.path.isfile(MUDATA_PATH):
    raise FileNotFoundError(f"MuData file not found: {MUDATA_PATH}")
missing_model_dirs = {
    label: model_dir
    for label, model_dir in model_paths.items()
    if not os.path.isdir(model_dir)
}
if missing_model_dirs:
    raise FileNotFoundError(f"Model directories not found: {missing_model_dirs}")

# Load the MuData object once; every model is reloaded against this same object
mdata = mu.read_h5mu(MUDATA_PATH)
print("Loaded MuData: ", mdata)

# Latent representation per model, keyed by the labels used in `model_paths`
latents = {}

# Reload each saved model and compute its latent representation
for name, path in model_paths.items():
    print(f"Loading model {name} from {path}")
    model = scvi.model.MULTIVISPLICE.load(path, adata=mdata)
    print(f"Model {name} reloaded. AnnData shape: {mdata['rna'].shape}")
    model.view_anndata_setup()
    print(f"Computing latent for {name}...")
    latent = model.get_latent_representation()
    latents[name] = latent
    print(f"Latent for {name} has shape {latent.shape}")

FileNotFoundError: [Errno 2] No such file or directory: '/gpfs/commons/groups/knowles_lab/Karin/Leaflet-analysis-WD/MOUSE_SPLICING_FOUNDATION/MODEL_INPUT/052025/mouse_foundation_data_20250502_155802_ge_splice_combined.h5mu'

## 3. Generate Side-by-Side UMAPs

In [None]:
# Make every grouping column categorical on the shared MuData object; the
# per-model AnnData objects below copy `obs` and so inherit the dtype.
for group in groups:
    mdata['rna'].obs[group] = mdata['rna'].obs[group].astype('category')

# The neighbor graph and the UMAP depend only on the latent space, not on the
# coloring, so embed each model once and reuse the result for every group.
umap_adatas = {}
for model_name, latent in latents.items():
    # Temporary AnnData wrapping the latent space.
    # NOTE(review): assumes the latent rows line up with mdata['rna'].obs — confirm.
    ad = sc.AnnData(latent)
    ad.obs = mdata['rna'].obs.copy()
    ad.obsm['X_umap_input'] = latent

    # Neither sc.pp.neighbors nor sc.tl.umap accepts `show` (only the sc.pl.*
    # plotting functions do); passing it raises a TypeError.
    sc.pp.neighbors(ad, use_rep='X_umap_input')
    sc.tl.umap(ad, min_dist=0.2)
    umap_adatas[model_name] = ad

# One panel per model that actually has a latent representation
n_models = len(umap_adatas)

for group in groups:
    fig, axes = plt.subplots(1, n_models, figsize=(5*n_models, 4), squeeze=False)

    for ax, (model_name, ad) in zip(axes[0], umap_adatas.items()):
        sc.pl.umap(
            ad,
            color=group,
            ax=ax,
            show=False,
            title=f"{model_name}",
            legend_loc='right margin'
        )

    fig.suptitle(f"UMAPs colored by {group}")
    outpath = os.path.join(FIG_DIR, f"umap_{group}.png")
    fig.savefig(outpath, dpi=300, bbox_inches='tight')
    # Close after saving so figures do not pile up in memory across groups
    plt.close(fig)
    print(f"Saved UMAP figure for {group} to {outpath}")


## 4. Compute and Plot Silhouette Scores

In [None]:
# Optional subsampling for silhouette_score, whose cost grows quadratically
# with the number of cells. None scores every cell (the original behavior);
# set an integer (e.g. 50_000) to score a fixed-seed random subset instead.
SILHOUETTE_SAMPLE_SIZE = None

# Integer label codes per group, built locally so this cell does not rely on
# another cell having already converted the columns to categorical.
group_codes = {
    group: pd.Categorical(mdata['rna'].obs[group]).codes
    for group in groups
}

# Scores table: one row per group, one column per model with a latent space
sil_df = pd.DataFrame(index=groups, columns=list(latents), dtype=float)

for model_name, latent in latents.items():
    for group in groups:
        labels = group_codes[group]
        # Missing values get code -1; leave those cells out rather than
        # scoring them as a cluster of their own.
        labelled = labels >= 0
        if labelled.all():
            X = latent
        else:
            X, labels = latent[labelled], labels[labelled]

        # silhouette_score needs at least two distinct labels; record NaN for
        # a degenerate group instead of aborting the whole cell.
        if len(set(labels.tolist())) < 2:
            sil_df.loc[group, model_name] = float('nan')
            print(f"Silhouette for {model_name} on {group}: skipped (fewer than 2 labels)")
            continue

        score = silhouette_score(
            X,
            labels,
            sample_size=SILHOUETTE_SAMPLE_SIZE,
            random_state=0,  # only used when subsampling; keeps it repeatable
        )
        sil_df.loc[group, model_name] = score
        print(f"Silhouette for {model_name} on {group}: {score:.3f}")

# Plot bar chart (groups on the x-axis, one bar per model)
fig, ax = plt.subplots(figsize=(8, 6))
sil_df.plot(kind='bar', ax=ax)
ax.set_ylabel('Silhouette Score')
ax.set_xlabel('Group')
ax.set_title('Silhouette Scores by Model and Group')
# Style the tick labels and legend through the explicit axes, not pyplot state
plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
ax.legend(title='Model')
fig.tight_layout()
outpath = os.path.join(FIG_DIR, "silhouette_scores.png")
fig.savefig(outpath, dpi=300)
plt.show()
print(f"Saved silhouette bar plot to {outpath}")
