# Mouse Cortex Neuronal Development - Training and Transfer Learning

This notebook demonstrates LSD training on mouse cortex neuronal development data and includes:
- Training LSD on neuronal differentiation data
- Cell fate prediction for excitatory/inhibitory neurons
- Transfer learning to unseen cell populations
- Evaluation on unseen data from a different study

**Data Requirements:**
- Training data: `../../data/Mouse_Cortex/pertrub_adata.h5ad`
- Unseen data: `../../data/Mouse_Cortex/perturb_adata_unseen.h5ad`

## Setup

In [None]:
import os
from pathlib import Path
import scanpy as sc
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch

# Import sclsd components
from sclsd import LSD, LSDConfig, set_all_seeds, clear_pyro_state

# Seed the run through the sclsd helper so results can be reproduced
SEED = 42
set_all_seeds(SEED)

# Use CUDA when PyTorch can see a GPU, otherwise fall back to the CPU
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(f"Using device: {device}")

# Checkpoints, figures and the JSON summary are written under this directory
model_dir = Path("./mouse_cortex_model")
model_dir.mkdir(parents=True, exist_ok=True)

# Global plot appearance
plt.rcParams.update({'figure.dpi': 150})
sns.set_style('white')

## Load and Inspect Data

In [None]:
# Read the training AnnData object from disk
data_path = "../../data/Mouse_Cortex/pertrub_adata.h5ad"
adata = sc.read(data_path)

# Report dimensions, the available per-cell annotations and cluster sizes
n_cells, n_genes = adata.n_obs, adata.n_vars
print(f"Loaded {n_cells:,} cells x {n_genes:,} genes")

obs_columns = adata.obs.columns.tolist()
print(f"\nObs columns: {obs_columns}")

print("\nCell type distribution:")
cluster_counts = adata.obs['clusters'].value_counts()
print(cluster_counts)

In [None]:
# UMAP of the training cells, coloured by their annotated cluster
# (assumes `adata` already carries a UMAP embedding — TODO confirm)
sc.pl.umap(
    adata,
    color=["clusters"],
    title="Mouse Cortex Cell Types",
)

## Configure and Train LSD Model

In [None]:
# Start from the library defaults and override only the fields set below
cfg = LSDConfig()

# Hidden layer sizes for the "potential" component (presumably a small MLP —
# confirm against the sclsd documentation)
cfg.model.layer_dims.potential = [16, 16]

# Optimiser overrides
adam_cfg = cfg.optimizer.adam
adam_cfg.lr = 2e-3
adam_cfg.T_0 = 100

# Random-walk overrides; the walk sampler reuses the notebook-wide SEED
walk_cfg = cfg.walks
walk_cfg.batch_size = 256
walk_cfg.path_len = 30
walk_cfg.num_walks = 8192
walk_cfg.random_state = SEED

In [None]:
# Reset Pyro state before constructing the model
clear_pyro_state()

# Names of the obs column / layer the model should read
# (library size from obs["librarysize"], raw counts from layers["raw"] —
# presumably; confirm against the LSD constructor documentation)
lsd_kwargs = {
    "device": device,
    "lib_size_key": "librarysize",
    "raw_count_key": "raw",
}
lsd = LSD(adata, cfg, **lsd_kwargs)

In [None]:
# Lineage graph for cortical development: each key is a parent cluster and
# its value lists the child clusters. Terminal states map to an empty list.
#   Apical progenitors -> IPC -> {ULPN, Migrating neurons} -> DLPN
#   Apical progenitors -> IN nonMGE -> IN MGE
PHYLOGENY = {
    "Apical progenitors": ["IPC", "IN nonMGE"],
    "IPC": ["ULPN", "Migrating neurons"],
    "IN nonMGE": ["IN MGE"],
    "ULPN": ["DLPN"],
    "Migrating neurons": ["DLPN"],
    "DLPN": [],
    "IN MGE": [],
}

# Register the lineage (keyed on obs["clusters"]), then the pseudotime prior,
# then build the random walks — kept in the original call order.
lsd.set_phylogeny(PHYLOGENY, "clusters")
lsd.set_prior_transition(prior_time_key="prior_pseudotime")
lsd.prepare_walks()

In [None]:
# Reset Pyro state, then run training; checkpoints go to `model_dir`
clear_pyro_state()

train_kwargs = dict(
    num_epochs=100,
    save_dir=str(model_dir),
    save_interval=50,
)
lsd.train(**train_kwargs)

print("Training completed!")

## Extract and Visualize Results

In [None]:
# Retrieve the AnnData object held by the trained model
final_adata = lsd.get_adata()

# Static reminder of the fields the later cells read from `final_adata`
for line in (
    "LSD outputs added:",
    "  Obs: potential, lsd_pseudotime, entropy",
    "  Obsm: X_cell_state, X_diff_state",
):
    print(line)

In [None]:
# 2x2 grid of UMAPs: cluster labels plus the three per-cell LSD outputs
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# (obs key, panel title, colormap); None keeps scanpy's default colouring
panels = [
    ("clusters", "Cell Types", None),
    ("entropy", "Entropy", "viridis"),
    ("potential", "Potential", "coolwarm"),
    ("lsd_pseudotime", "LSD Pseudotime", "plasma"),
]

# axes.flat walks the grid row by row, matching the order of `panels`
for ax, (obs_key, panel_title, colormap) in zip(axes.flat, panels):
    extra = {} if colormap is None else {"cmap": colormap}
    sc.pl.umap(
        final_adata,
        color=obs_key,
        ax=ax,
        show=False,
        title=panel_title,
        **extra,
    )

plt.tight_layout()
plt.savefig(model_dir / "umap_analysis.png", dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Streamline plot drawn on the UMAP coordinates stored under "X_umap"
stream_embedding = "X_umap"
lsd.stream_lines(embedding=stream_embedding)

## Cell Fate Prediction

In [None]:
# Predict a fate for every cell by ODE simulation; the result carries the
# prediction in obs["fate"]
fate_kwargs = dict(
    time_range=50,
    dt=0.2,
    batch_size=512,
    cluster_key="clusters",
    return_paths=False,
)
dyn_adata = lsd.get_cell_fates(final_adata, **fate_kwargs)

fate_counts = dyn_adata.obs['fate'].value_counts()
print("\nPredicted fate distribution:")
print(fate_counts)

In [None]:
# Group the cluster names into three coarse functional classes
excitatory = ["DLPN", "ULPN", "Migrating neurons"]
inhibitory = ["IN nonMGE", "IN MGE"]
progenitors = ["Apical progenitors", "IPC"]

# Map each predicted fate to its functional class; fates that belong to none
# of the three groups fall through to 'unknown'
choices = ['excitatory', 'inhibitory', 'progenitors']
predicted_fate = dyn_adata.obs['fate']
conditions = [
    predicted_fate.isin(group)
    for group in (excitatory, inhibitory, progenitors)
]
dyn_adata.obs['fate_function'] = np.select(conditions, choices, default='unknown')

# Show the functional class on the "X_diff_state" embedding
sc.pl.embedding(
    dyn_adata,
    basis="X_diff_state",
    color="fate_function",
    title="Predicted Cell Fate Function",
)

## Transfer Learning to Unseen Data

Test the trained model on unseen cell populations from a different study.

In [None]:
# Load the held-out dataset
unseen_path = "../../data/Mouse_Cortex/perturb_adata_unseen.h5ad"
bdata = sc.read(unseen_path)

# Restrict to the genes of the training object. Slicing an AnnData returns a
# *view*; take an explicit copy so the layer/obs writes and the in-place
# normalisation below operate on a real object instead of triggering an
# implicit copy (and an ImplicitModificationWarning) on first write.
bdata = bdata[:, final_adata.var_names].copy()

# Preserve the untouched matrix and per-cell totals before X is normalised
# in place (assumes X holds raw counts on load — TODO confirm)
bdata.layers["raw"] = bdata.X.copy()
bdata.obs['librarysize'] = np.asarray(bdata.X.sum(axis=1)).flatten()

# Library-size normalise to 1e4 counts per cell, log-transform, then build
# the neighbour graph
sc.pp.normalize_total(bdata, target_sum=1e4)
sc.pp.log1p(bdata)
sc.pp.neighbors(bdata)

print(f"Loaded unseen data: {bdata.n_obs} cells")
print("\nUnseen cell types:")
print(bdata.obs['clusters'].value_counts())

In [None]:
# Stack training and held-out cells; obs["data_source"] records the origin
source_labels = ['Seen', 'Unseen']
adata_combined = sc.concat(
    [final_adata, bdata],
    label='data_source',
    keys=source_labels,
)

# The merged object needs its own neighbour graph and UMAP
sc.pp.neighbors(adata_combined)
sc.tl.umap(adata_combined)

# One panel per colouring, stacked vertically
sc.pl.umap(
    adata_combined,
    color=["clusters", "data_source"],
    ncols=1,
)

In [None]:
# Functional groups widened with the cluster names used in the unseen data
excitatory_extended = [
    "DLPN", "ULPN", "Migrating neurons",
    "Excit_Car3", "Excit_L2 IT ENTl", "Excit_L5 PT CTX",
    "Excit_L5IT", "Excit_L5NP_CTX", "Excit_L6CT_CTX",
    "Excit_L6IT", "Excit_L6b CTX", "Excit_L6b/CT ENT", "Excit_Upper",
]
inhibitory_extended = [
    "IN nonMGE", "IN MGE",
    "Inhib_Id2", "Inhib_Lhx6+Sst-", "Inhib_Meis2", "Inhib_Sst",
]

# Ground-truth functional class per cell, derived from its annotated cluster.
# `progenitors` and `choices` come from the fate-prediction cell above;
# clusters in none of the groups are labelled 'unknown'.
cluster_labels = adata_combined.obs['clusters']
conditions = [
    cluster_labels.isin(group)
    for group in (excitatory_extended, inhibitory_extended, progenitors)
]
adata_combined.obs['function'] = np.select(conditions, choices, default='unknown')

sc.pl.umap(adata_combined, color=['function', 'data_source'])

In [None]:
# Hand the combined dataset to the trained model and retrieve the AnnData
# object it returns
lsd.set_adata(adata_combined)
valid_adata = lsd.get_adata()

# Fate prediction for every cell, with the same settings as for the
# training data
fate_kwargs = dict(
    time_range=50,
    dt=0.2,
    batch_size=512,
    cluster_key="clusters",
    return_paths=False,
)
combined_dyn = lsd.get_cell_fates(valid_adata, **fate_kwargs)

# Collapse predicted fates into functional classes ('unknown' if unmatched)
predicted_fate = combined_dyn.obs['fate']
conditions = [
    predicted_fate.isin(group)
    for group in (excitatory_extended, inhibitory_extended, progenitors)
]
combined_dyn.obs['fate_function'] = np.select(conditions, choices, default='unknown')

print("Fate predictions for combined data:")
print(predicted_fate.value_counts())

## Evaluate Transfer Learning Performance

In [None]:
# Split the predictions by origin
source = combined_dyn.obs["data_source"]
seen_cells = combined_dyn[source == "Seen"]
unseen_cells = combined_dyn[source == "Unseen"]


def _function_confusion(cells):
    """Cross-tabulate true vs. predicted function; each row sums to 1."""
    return pd.crosstab(
        cells.obs["function"],
        cells.obs["fate_function"],
        normalize='index',
    )


confusion_seen = _function_confusion(seen_cells)
confusion_unseen = _function_confusion(unseen_cells)

print("Confusion matrix (Seen data):")
print(confusion_seen)

print("\nConfusion matrix (Unseen data):")
print(confusion_unseen)

In [None]:
# Side-by-side heatmaps of the two row-normalised confusion matrices
cmap = sns.light_palette("navy", as_cmap=True)
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left panel: training ("Seen") cells; right panel: held-out ("Unseen") cells
panels = (("Seen", confusion_seen), ("Unseen", confusion_unseen))
for ax, (source_label, matrix) in zip(axes, panels):
    sns.heatmap(
        matrix,
        annot=True,
        fmt=".2f",
        cmap=cmap,
        linewidths=0.5,
        square=True,
        ax=ax,
        cbar_kws={'label': 'Proportion'},
    )
    ax.set_title(f"Fate Prediction - {source_label} Data", fontsize=14)
    ax.set_xlabel("Predicted Fate Function")
    ax.set_ylabel("True Function")

plt.tight_layout()
plt.savefig(model_dir / "transfer_learning_evaluation.png", dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Compute accuracy metrics
def compute_accuracy(confusion_matrix):
    """Return the mean per-class recall of a row-normalised confusion matrix.

    Rows are true classes and columns are predicted classes, with each row
    summing to 1. The diagonal entry of a row is therefore that class's
    recall, and the returned value is the unweighted mean of those recalls
    over all true classes (macro-averaged recall, not sample-weighted
    accuracy).

    A true class whose label never appears among the predictions has no
    matching column; it counts as recall 0 instead of being left out of the
    average, which would overstate the score.

    Args:
        confusion_matrix: pandas DataFrame indexed by true label, with
            predicted labels as columns.

    Returns:
        Mean diagonal value over all rows, or 0.0 when there are no rows.
    """
    true_labels = list(confusion_matrix.index)
    if not true_labels:
        return 0.0

    predicted_labels = set(confusion_matrix.columns)
    recall_sum = sum(
        confusion_matrix.loc[label, label]
        for label in true_labels
        if label in predicted_labels
    )
    # Divide by the number of true classes, including the never-predicted ones
    return recall_sum / len(true_labels)

# Score each confusion matrix, then print a small banner-framed report
seen_accuracy = compute_accuracy(confusion_seen)
unseen_accuracy = compute_accuracy(confusion_unseen)

banner = "=" * 50
print("\nTransfer Learning Results:")
print(banner)
print(f"Seen data accuracy: {seen_accuracy:.1%}")
print(f"Unseen data accuracy: {unseen_accuracy:.1%}")
print(banner)

## Save Results

In [None]:
import json

# Headline numbers of the run, collected into a JSON-serialisable dict
fate_counts = dyn_adata.obs['fate'].value_counts()
results_summary = {
    "training_cells": len(final_adata),
    "unseen_cells": len(bdata),
    "seen_accuracy": float(seen_accuracy),
    "unseen_accuracy": float(unseen_accuracy),
    "fate_distribution": fate_counts.to_dict(),
}

# Written next to the checkpoints and figures
summary_path = model_dir / "analysis_summary.json"
with summary_path.open("w") as f:
    json.dump(results_summary, f, indent=2)

print(f"Results saved to {model_dir}/")
print("\nAnalysis complete!")