# Zebrafish Dataset - Postprocessing and Gene Perturbation Analysis

This notebook demonstrates postprocessing of LSD results including:
- Velocity projection and streamline visualization
- Cell fate analysis
- **Gene perturbation analysis (NOTO gene)**
- Large-scale gene knockout for early blastomeres
- Enrichment analysis

**Prerequisites:**
- Complete the `train.ipynb` notebook first
- Have a trained model saved in `./zebrafish_model/` (this notebook loads the checkpoint file `lsd_model_epoch0250.pth` from that directory)

## Setup

In [None]:
import os
import scanpy as sc
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from scipy.stats import gaussian_kde

# Import lsdpy components
from sclsd import LSD, LSDConfig, set_all_seeds, clear_pyro_state

# Fix the seed used by this notebook and pass it to sclsd's seeding helper
SEED = 42
set_all_seeds(SEED)

# Run on the GPU when torch reports one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

## Load Trained Model and Results

In [None]:
# Input AnnData file, the directory holding the trained model, and the checkpoint file
# name inside that directory (directory and file name are passed to lsd.load separately)
data_path = "../../data/Zebrafish/perturb_valid_adata.h5ad"
model_dir = "./zebrafish_model"
model_path = "lsd_model_epoch0250.pth"

# Read the dataset from disk
adata = sc.read(data_path)

# Start from the default configuration and override the potential layer sizes
# with the values used for the Zebrafish dataset
cfg = LSDConfig()
cfg.model.layer_dims.potential = [16, 16]

# Wrap the dataset in an LSD model on the selected device
lsd = LSD(adata, cfg, device=device, lib_size_key="librarysize", raw_count_key="raw")

# Restore the trained weights from model_dir / model_path
lsd.load(model_dir, model_path)

# AnnData returned by the model; its obs / obsm contents are listed below
result = lsd.get_adata()

# Three report lines, each later one preceded by a blank line
print(
    f"Loaded {result.n_obs} cells x {result.n_vars} genes",
    f"\nAvailable obs columns: {result.obs.columns.tolist()}",
    f"\nAvailable obsm keys: {list(result.obsm.keys())}",
    sep="\n",
)

## Predict Cell Fates

In [None]:
# Predict cell fates with lsd.get_cell_fates on the AnnData returned by lsd.get_adata().
# time_range=20 is the value used for this dataset (described here as ODE propagation over
# a longer time range for Zebrafish); cluster_key="clusters" is read BEFORE the overwrite below.
dyn_adata = lsd.get_cell_fates(result, time_range=20, cluster_key="clusters")

# Overwrite obs["clusters"] with the "lineages" annotation. Every later cell that passes
# cluster_key="clusters" or reads obs["clusters"] therefore works on lineage labels.
dyn_adata.obs["clusters"] = dyn_adata.obs["lineages"]

## Visualize LSD Outputs

In [None]:
# Keys used to colour the embeddings (reused by the next plotting cell)
cols = ["potential", "lsd_pseudotime", "fate", "entropy"]

# Banner: rule, caption, rule - one per line
print("=" * 60, "Plotting LSD outputs on force directed projection", "=" * 60, sep="\n")

# One panel per key, two panels per row, on the force-directed layout
sc.pl.embedding(dyn_adata, basis="X_force_directed", color=cols, ncols=2)

In [None]:
# Same panels, drawn on the differentiation-state embedding (`cols` is defined in the previous plotting cell)
print("=" * 60, "Plotting LSD outputs on differentiation state projection", "=" * 60, sep="\n")
sc.pl.embedding(dyn_adata, basis="X_diff_state", color=cols, ncols=2)

## Plot Velocity Streamlines

In [None]:
# Streamline plot produced by the model on the force-directed layout
print("=" * 60, "Plotting LSD streamlines on force directed projection", "=" * 60, sep="\n")
lsd.stream_lines(size=60, embedding="X_force_directed")

In [None]:
# Streamline plot produced by the model on the differentiation-state embedding
print("=" * 60, "Plotting LSD streamlines on differentiation state projection", "=" * 60, sep="\n")
lsd.stream_lines(size=60, embedding="X_diff_state")

## Gene Perturbation Analysis - NOTO Gene

This section demonstrates how to analyze the effect of perturbing a specific gene (NOTO) on cell fate decisions.

In [None]:
# Sweep NOTO perturbation levels 1..10 (max_perturbations=lvl) and record, for every cell,
# the predicted fate at each level next to the unperturbed (baseline) fate.
expr = dyn_adata.X
# Densify only when X is sparse; a dense matrix has no .toarray(). np.array copies, so the
# tensor never aliases dyn_adata.X.
x = torch.from_numpy(expr.toarray() if hasattr(expr, "toarray") else np.array(expr))
all_levels_fates = []
levels = list(range(1, 10 + 1))

for lvl in levels:
    pert_fates, unpert_fates = lsd.perturb(
        adata=dyn_adata,
        x=x,
        gene_name="NOTO",
        cluster_key="clusters",
        max_perturbations=lvl,
        batch_size=2048
    )
    all_levels_fates.append(np.asarray(pert_fates, dtype=object))
# Baseline comes from the last call of the sweep.
# NOTE(review): assumes the unperturbed fates do not depend on max_perturbations - confirm.
baseline_fate = np.asarray(unpert_fates, dtype=object)

# Encode fates as integer codes for fast comparison across levels
all_unique_fates = pd.Index(np.unique(np.concatenate(all_levels_fates + [baseline_fate])))
fate_to_code = {f: i for i, f in enumerate(all_unique_fates)}
# otypes makes the mapping well-defined even for an empty array of cells
encode_fates = np.vectorize(fate_to_code.get, otypes=[int])

code_levels = np.vstack([encode_fates(f) for f in all_levels_fates])  # (n_levels, n_cells)
baseline_codes = encode_fates(baseline_fate)

# Per-cell: first level where fate differs from baseline
chg_vs_base = code_levels != baseline_codes[None, :]
any_change = chg_vs_base.any(axis=0)

# argmax gives the first True row per column; it is 0 for columns with no True at all,
# which is why every use below is masked by any_change
first_idx0 = np.argmax(chg_vs_base, axis=0)
# Look the level up in `levels` instead of assuming level == row index + 1
first_change_level = np.where(any_change, np.asarray(levels)[first_idx0], -1)

# Pick, for every cell, the fate in the row of its first change (one fancy-index lookup)
fates_by_level = np.vstack(all_levels_fates)
fate_at_first_change = np.where(
    any_change,
    fates_by_level[first_idx0, np.arange(fates_by_level.shape[1])],
    baseline_fate,
)

# Number of flips between successive perturbation levels, per cell.
# With a single level the slices are empty and the sum is an all-zero per-cell array.
total_changes = np.sum(code_levels[1:, :] != code_levels[:-1, :], axis=0)

change_status = np.where(
    first_change_level == -1,
    "unchanged",
    np.array([f"changed_at_{lv}" for lv in first_change_level], dtype=object)
)

# Summary dataframe: one row per cell
summary_df = pd.DataFrame({
    "cell_id": dyn_adata.obs_names.values,
    "cluster": dyn_adata.obs["clusters"].astype(str).values,
    "baseline_fate": baseline_fate,
    "lsd_pseudotime": dyn_adata.obs["lsd_pseudotime"].values,
    "first_change_level": first_change_level,
    "fate_at_first_change": fate_at_first_change,
    "final_fate": all_levels_fates[-1],
    "total_changes": np.asarray(total_changes, dtype=int),
    "change_status": change_status,
})

print(f"Summary of perturbation effects on {len(summary_df)} cells")
print(f"Cells with fate change: {(summary_df.first_change_level != -1).sum()}")

In [None]:
# Pseudotime distribution of cells whose fate did / did not change under NOTO perturbation

# Partition pseudotime by fate change (first_change_level == -1 marks cells that never changed)
fate_changed_mask = summary_df.first_change_level != -1
changed = summary_df.loc[fate_changed_mask, "lsd_pseudotime"].dropna().values
unchanged = summary_df.loc[~fate_changed_mask, "lsd_pseudotime"].dropna().values
all_vals = np.concatenate([changed, unchanged])

# Shared x-range: 30 bin edges for the histogram, 500 evaluation points for the KDE curves
lo, hi = all_vals.min(), all_vals.max()
bins = np.linspace(lo, hi, 30)
x_kde = np.linspace(lo, hi, 500)

fig, ax = plt.subplots(figsize=(8, 6))

# Stacked, density-normalised histogram of both groups
ax.hist(
    [unchanged, changed],
    bins=bins,
    density=True,
    stacked=True,
    alpha=0.6,
    color=["#9e9e9e", "#1f77b4"],
    label=["Unchanged fate", "Changed fate"],
)

# KDE over all cells, plus a KDE of the changed cells scaled by their share of all cells
ax.plot(x_kde, gaussian_kde(all_vals)(x_kde), color="0.3", lw=2, label="All KDE")
if len(changed) > 1:
    changed_share = len(changed) / len(all_vals)
    ax.plot(
        x_kde,
        gaussian_kde(changed)(x_kde) * changed_share,
        color="#1f77b4",
        lw=2,
        label="Changed KDE (scaled)",
    )

ax.set_xlabel("LSD pseudotime")
ax.set_ylabel("Density")
ax.set_title("Pseudotime distribution by fate change (NOTO perturbation)")
ax.legend(frameon=False)
sns.despine()
fig.tight_layout()
fig.savefig(os.path.join(model_dir, "noto_perturbation_pseudotime.png"), dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Inputs for the fold-change heatmaps: pseudotime bins plus the predicted fate of every
# cell at each NOTO perturbation level.
n_bins = 10
max_pert = 10

baseline_fate_arr = summary_df["baseline_fate"].to_numpy(dtype=object)
# No longer needed to build level_fates; kept because they are notebook-level names
final_fate_arr = summary_df["final_fate"].to_numpy(dtype=object)
first_lvl = summary_df["first_change_level"].to_numpy()
pt = summary_df["lsd_pseudotime"].to_numpy(dtype=float)

# Bin pseudotime into n_bins equal-width bins; the clip places pt == pt.max() in the last bin
edges = np.linspace(pt.min(), pt.max(), n_bins + 1)
centers = 0.5 * (edges[:-1] + edges[1:])
bin_idx = np.clip(np.digitize(pt, edges) - 1, 0, n_bins - 1)

# Fate at each perturbation level (level 0 = baseline).
# Use the fates recorded by the sweep (all_levels_fates[k] holds level k + 1) rather than
# rebuilding them from first_change_level / final_fate: that reconstruction assumes a cell
# jumps straight to its final fate at its first change and never flips again, which is
# wrong for any cell with more than one fate change.
if max_pert > len(all_levels_fates):
    raise ValueError(
        f"Expected fates for {max_pert} perturbation levels, found {len(all_levels_fates)}"
    )
level_fates = [baseline_fate_arr]
for lvl in range(1, max_pert + 1):
    fates_at_lvl = np.asarray(all_levels_fates[lvl - 1], dtype=object)
    # Rows of summary_df and entries of all_levels_fates must describe the same cells
    if fates_at_lvl.shape != baseline_fate_arr.shape:
        raise ValueError(
            f"Fates at level {lvl} cover {fates_at_lvl.shape[0]} cells, "
            f"summary_df has {baseline_fate_arr.shape[0]}"
        )
    level_fates.append(fates_at_lvl)

def fold_change_mat(target):
    """Fold change, relative to level 0, of the per-bin fraction of cells whose fate is `target`.

    Reads the notebook-level names ``max_pert``, ``n_bins``, ``level_fates`` and ``bin_idx``.

    Returns a DataFrame with one row per perturbation level (0..max_pert) and one column
    per pseudotime bin. An entry is NaN when the bin holds no cells or when the level-0
    fraction for that bin is zero.
    """
    n_levels = max_pert + 1
    # The bin membership masks do not depend on the level, so build them once
    bin_masks = [bin_idx == b for b in range(n_bins)]

    freq = np.full((n_levels, n_bins), np.nan, dtype=float)
    for level in range(n_levels):
        is_target = level_fates[level] == target
        for b, in_bin in enumerate(bin_masks):
            if in_bin.any():
                freq[level, b] = is_target[in_bin].mean()

    # Divide every row by the level-0 row; x/0 gives inf and 0/0 gives NaN, both end up NaN
    with np.errstate(divide="ignore", invalid="ignore"):
        ratio = freq / freq[0]
    ratio = np.where(np.isinf(ratio), np.nan, ratio)
    return pd.DataFrame(ratio, index=range(n_levels), columns=range(n_bins))

# One fold-change heatmap (perturbation level x pseudotime bin) per fate of interest
for target in ["Notochord", "Prechordal Plate"]:
    mat = fold_change_mat(target)

    fig, ax = plt.subplots(figsize=(9, 4))
    sns.heatmap(mat, cmap="viridis", cbar=True, ax=ax)
    ax.set_title(f"Fold-change vs baseline: {target}")
    ax.set_xlabel("Pseudotime bin")
    ax.set_ylabel("# perturbations")

    # Label five evenly spaced bins with their centre pseudotime; +0.5 centres the tick on the heatmap cell
    xt = np.linspace(0, n_bins - 1, 5, dtype=int)
    ax.set_xticks(xt + 0.5)
    ax.set_xticklabels([f"{centers[i]:.2f}" for i in xt], rotation=30, ha="right")

    fig.tight_layout()
    fig.savefig(os.path.join(model_dir, f"fold_change_{target.replace(' ', '_')}.png"), dpi=150, bbox_inches='tight')
    plt.show()

In [None]:
# Share of fate-changing cells whose first change happens at each NOTO perturbation level
max_pert = 10
perturbation_levels = range(1, max_pert + 1)

# First-change level of every cell that changed fate at all
lvl = summary_df.loc[summary_df.first_change_level != -1, "first_change_level"]

# Count cells per level, with explicit zero rows for levels at which nothing changed
counts_by_level = lvl.value_counts().reindex(perturbation_levels, fill_value=0).sort_index()
fate_change_percent = counts_by_level.rename_axis("perturbation_level").reset_index(name="n_changed_cells")
fate_change_percent["percent_of_changed"] = 100 * fate_change_percent["n_changed_cells"] / len(lvl)

# Bar: % share per level
fig, ax = plt.subplots(figsize=(6, 4))
ax.bar(fate_change_percent["perturbation_level"], fate_change_percent["percent_of_changed"])
ax.set_xlabel("# perturbations")
ax.set_ylabel("% of changed cells")
ax.set_title("Fate-change share by perturbation level")
ax.set_xticks(perturbation_levels)
fig.tight_layout()
fig.savefig(os.path.join(model_dir, "fate_change_by_level.png"), dpi=150, bbox_inches='tight')
plt.show()

# Line: cumulative %
cumulative_percent = fate_change_percent["percent_of_changed"].cumsum()
fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(fate_change_percent["perturbation_level"], cumulative_percent, marker="o")
ax.set_xlabel("# perturbations")
ax.set_ylabel("Cumulative % of changed cells")
ax.set_title("Cumulative fate-change share")
ax.set_xticks(perturbation_levels)
ax.set_ylim(0, 105)
fig.tight_layout()
fig.savefig(os.path.join(model_dir, "cumulative_fate_change.png"), dpi=150, bbox_inches='tight')
plt.show()

## Large-Scale Gene Knockout for Early Blastomeres

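This section runs a large-scale in-silico perturbation screen on Early Blastomere cells: every gene in `adata.var_names` is perturbed in turn, and the resulting Notochord and Prechordal Plate fate counts are compared with the unperturbed counts.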

In [None]:
# Large-scale gene KO for early blastomeres: perturb every gene in turn and compare how
# many Early Blastomere cells end in the Notochord / Prechordal Plate fate with and
# without the perturbation.
from tqdm.auto import tqdm  # notebook widget when available, plain-text bar otherwise

rows = []
expr = dyn_adata.X
# Densify only when X is sparse; np.array copies so the tensor never aliases dyn_adata.X
x = torch.from_numpy(expr.toarray() if hasattr(expr, "toarray") else np.array(expr))
idx = np.where(dyn_adata.obs['clusters'] == "Early Blastomeres")[0]
if len(idx) == 0:
    raise ValueError("No cells labelled 'Early Blastomeres' in dyn_adata.obs['clusters']")
IC = x[idx].to(device)

print(f"Running perturbation analysis on {len(idx)} Early Blastomere cells")
print(f"Total genes to perturb: {len(adata.var_names)}")

# A zero pseudocount turns log2(0 / n) into -inf and log2(0 / 0) into NaN; -inf survives
# the later dropna() and cannot be drawn or annotated. Adding 1 to both counts keeps every
# log fold change finite.
pseudocount = 1

# NOTE(review): genes come from adata.var_names while the perturbation runs on dyn_adata -
# confirm both objects carry the same genes.
for gene in tqdm(adata.var_names, desc="Perturbing genes", total=len(adata.var_names)):
    perturbed_fates, unperturbed_fates = lsd.perturb(
        adata=dyn_adata,
        x=IC,
        gene_name=gene,
        cluster_key="clusters",
        batch_size=4096
    )
    # Force arrays: comparing a plain list with a string yields a single False, i.e. a
    # silent count of 0 for every gene
    perturbed_fates = np.asarray(perturbed_fates, dtype=object)
    unperturbed_fates = np.asarray(unperturbed_fates, dtype=object)

    # Count fate outcomes
    pert_count_noto = int(np.sum(perturbed_fates == "Notochord"))
    unpert_count_noto = int(np.sum(unperturbed_fates == "Notochord"))
    pert_count_pre = int(np.sum(perturbed_fates == "Prechordal Plate"))
    unpert_count_pre = int(np.sum(unperturbed_fates == "Prechordal Plate"))

    # Log2 fold changes (perturbed vs unperturbed), smoothed by the pseudocount
    logFC_noto = np.log2((pert_count_noto + pseudocount) / (unpert_count_noto + pseudocount))
    logFC_pre = np.log2((pert_count_pre + pseudocount) / (unpert_count_pre + pseudocount))

    rows.append({
        "Gene": gene,
        "Perturbed_Notochord_count": pert_count_noto,
        "Unperturbed_Notochord_count": unpert_count_noto,
        "Perturbed_Prechordal_count": pert_count_pre,
        "Unperturbed_Prechordal_count": unpert_count_pre,
        "Notochord_LogFC": logFC_noto,
        "Prechordal_LogFC": logFC_pre,
        "Total_cells": len(IC)
    })

# One row per gene; persisted next to the model so the screen need not be re-run
df = pd.DataFrame(rows)
df.to_csv(os.path.join(model_dir, "large_scale_gene_perturbation.csv"), index=False)
print(f"\nResults saved to {model_dir}/large_scale_gene_perturbation.csv")

In [None]:
# Rank genes by Notochord LogFC (rank 1 = most negative) after dropping rows with missing values
df_clean = df.dropna()
df_ranked = (
    df_clean
    .assign(Noto_Rank=lambda d: d['Notochord_LogFC'].rank(method='first', ascending=True).astype(int))
    .sort_values('Noto_Rank')
)
# The four genes with the lowest Notochord LogFC
top_noto_negative = df_clean.nsmallest(4, 'Notochord_LogFC')[['Gene', 'Notochord_LogFC']]

# Scatter of LogFC against rank for every gene
fig, ax = plt.subplots(figsize=(8, 6))
ax.scatter(
    df_ranked['Noto_Rank'],
    df_ranked['Notochord_LogFC'],
    alpha=0.6, s=20, c='darkorange', edgecolor='none'
)

# Highlight and label the top genes
top_genes_noto = top_noto_negative.head(4)
mark_df = df_ranked[df_ranked['Gene'].isin(top_genes_noto['Gene'])]
ax.scatter(mark_df['Noto_Rank'], mark_df['Notochord_LogFC'], color='red', s=50, zorder=5)
for gene, rank, logfc in zip(mark_df['Gene'], mark_df['Noto_Rank'], mark_df['Notochord_LogFC']):
    ax.annotate(gene, (rank, logfc), fontsize=8, ha='center', va='bottom')

ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
ax.set_xlabel('Rank by Notochord Log2FC (smaller LogFC -> higher rank)', fontsize=11)
ax.set_ylabel('Notochord Log2FC', fontsize=11)
ax.set_title('Notochord LogFC vs Rank', fontsize=12, fontweight='bold')
ax.grid(True, alpha=0.3)
fig.tight_layout()
fig.savefig(os.path.join(model_dir, "notochord_logfc_rank.png"), dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Gene symbols to highlight as known Notochord regulators
known_tfs = [
    "FOXA2", "SP5L", "NOTO", "CREB3L2", "EVE1", "HER6", "JUN",
    "CDX4", "LEF1", "FOXA", "TWIST2", "SOX2", "SP5A"
]

# Rows of the ranked table that belong to those genes, in rank order
is_known_tf = df_ranked["Gene"].isin(known_tfs)
tfs_df = df_ranked[is_known_tf].copy().sort_values("Noto_Rank")

fig, ax = plt.subplots(figsize=(8, 6))

# Background layer: every gene in grey
ax.scatter(
    df_ranked["Noto_Rank"],
    df_ranked["Notochord_LogFC"],
    color="#b0b0b0",
    alpha=0.6,
    s=18,
    rasterized=True,
    label="All genes",
)

# Foreground layer: the known regulators, each with a "GENE (rank N)" label
if not tfs_df.empty:
    ax.scatter(
        tfs_df["Noto_Rank"],
        tfs_df["Notochord_LogFC"],
        color="#1f77b4",
        s=50,
        zorder=3,
        rasterized=True,
        label="Known Notochord regulators",
    )

    tf_points = zip(tfs_df["Gene"], tfs_df["Noto_Rank"], tfs_df["Notochord_LogFC"])
    for j, (gene, rank, logfc) in enumerate(tf_points):
        # Cycle through four vertical offsets so labels of neighbouring ranks are staggered
        y_offset = (j % 4 - 1.5) * 8
        ax.annotate(
            f"{gene} (rank {rank})",
            (rank, logfc),
            fontsize=8,
            ha="center",
            va="bottom",
            xytext=(0, 6 + y_offset),
            textcoords="offset points",
            arrowprops=dict(arrowstyle="-", lw=0.5, color="black", alpha=0.6),
        )

ax.axhline(0, color="gray", linestyle="--", alpha=0.5, linewidth=1)
ax.set_xlabel("Rank")
ax.set_ylabel("Notochord Log2FC")
ax.set_title("LogFC of Notochord fate for known Notochord lineage regulators", fontweight="bold")
ax.grid(True, alpha=0.3)
ax.legend(frameon=False)

fig.tight_layout()
fig.savefig(os.path.join(model_dir, "known_regulators_logfc.png"), dpi=150, bbox_inches='tight')
plt.show()

## Enrichment Analysis

This section performs pathway enrichment analysis on the perturbation results.

In [None]:
# Step 1: Fit a Negative Binomial distribution to the per-gene perturbed Notochord counts
# (the data are the "Perturbed_Notochord_count" column, not the log fold changes)
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, TraceMeanField_ELBO
from pyro.infer.autoguide import AutoDelta
from pyro.optim import Adam as PyroAdam

# One observation per gene that survived df.dropna()
data = df_clean["Perturbed_Notochord_count"].values
data_tensor = torch.tensor(data, dtype=torch.float)

def model(data):
    """Negative Binomial model with a Gamma(2, 1) prior on total_count and a Beta(1, 1) prior on probs.

    `data` holds the observed counts; every entry is scored against the same
    NegativeBinomial(total_count, probs) inside the "data_plate" plate.
    """
    total_count = pyro.sample("total_count", dist.Gamma(2.0, 1.0))
    probs = pyro.sample("probs", dist.Beta(1.0, 1.0))
    with pyro.plate("data_plate", len(data)):
        pyro.sample("obs", dist.NegativeBinomial(total_count, probs), obs=data)

# AutoDelta guide for both latent sites, optimised with Adam (lr=0.1).
# NOTE(review): the Pyro param store is not cleared in this cell, so re-running it
# presumably continues from the previous fit - confirm whether clear_pyro_state /
# pyro.clear_param_store() should be called first.
guide = AutoDelta(model)
optimizer = PyroAdam({"lr": 0.1})
svi = SVI(model, guide, optimizer, loss=TraceMeanField_ELBO())

# Run the optimisation, reporting the loss every 1000 steps
num_steps = 5000
for step in range(num_steps):
    loss = svi.step(data_tensor)
    if step % 1000 == 0:
        print(f"Step {step} : loss = {loss:.2f}")

# Fitted values of the two latent sites, as returned by the guide; used by Step 2
inferred_params = guide.median()
inferred_total_count = inferred_params["total_count"]
inferred_probs = inferred_params["probs"]

print(f"\nInferred total_count: {inferred_total_count}")
print(f"Inferred probs: {inferred_probs}")

In [None]:
# Step 2: Score every gene by -log of the lower-tail probability of its perturbed
# Notochord count under the fitted Negative Binomial (fewer Notochord cells -> larger score)
from scipy.stats import nbinom

# torch/Pyro's `probs` and SciPy's `p` are complementary parameterisations, hence 1 - probs
r = float(inferred_total_count.item() if hasattr(inferred_total_count, "item") else inferred_total_count)
p = float(1 - inferred_probs.item() if hasattr(inferred_probs, "item") else 1 - inferred_probs)

# Score the rows of df_clean itself: taking the counts from `df` raises a length-mismatch
# error as soon as dropna() removed a row. logcdf avoids log(0) when the cdf underflows.
neg_log_p = -nbinom.logcdf(df_clean["Perturbed_Notochord_count"].to_numpy(), r, p)
# assign() returns a new frame, so the frame derived from df.dropna() is not written into in place
df_clean = df_clean.assign(**{"-log_p_val": neg_log_p})

# Ranked gene list (highest score first) for the preranked enrichment step
rnk = (
    df_clean[['Gene', '-log_p_val']]
    .dropna()
    .set_index('Gene')['-log_p_val']
    .sort_values(ascending=False)
)

print("Top 10 genes by -log(p-value):")
print(rnk.head(10))

In [None]:
# Step 3: Preranked enrichment of the gene ranking against two gene-set libraries
import gseapy as gp

pre_res = gp.prerank(
    rnk=rnk,
    gene_sets=['GO_Biological_Process_2021', 'Reactome_Pathways_2024'],
    organism='Zebrafish',
    permutation_num=1000,
    seed=42
)

# First five terms of the result table
# NOTE(review): "top" relies on the row order of pre_res.res2d - confirm how gseapy sorts it
top_5_terms = pre_res.res2d["Term"].head(5).tolist()

print("Top 5 enriched pathways:")
for position, term in enumerate(top_5_terms, start=1):
    print(f"{position}. {term}")

In [None]:
# Running enrichment score curves for the first five terms of the prerank result table
pathways = pre_res.res2d['Term'].head(5).tolist()
colors = plt.cm.Set1(np.linspace(0, 1, len(pathways)))

fig, ax = plt.subplots(figsize=(12, 8))

# One curve per pathway, x = position in the ranked gene list
for color, pathway in zip(colors, pathways):
    res_scores = pre_res.results[pathway]["RES"]
    gene_positions = np.arange(len(res_scores))
    ax.plot(gene_positions, res_scores, linewidth=2, alpha=0.7, color=color, label=pathway)

ax.axhline(y=0, color='black', linestyle='--', alpha=0.5, linewidth=1)
ax.set_xlabel('Gene Rank', fontsize=12)
ax.set_ylabel('Running Enrichment Score (RES)', fontsize=12)
ax.set_title('Running Enrichment Score vs Gene Rank - Top 5 Pathways', fontsize=14, fontweight='bold')
ax.grid(True, alpha=0.3)
ax.legend(loc='upper right', fontsize=9, frameon=True, fancybox=True, shadow=True, framealpha=0.9)

fig.tight_layout()
fig.savefig(os.path.join(model_dir, "enrichment_analysis.png"), dpi=150, bbox_inches='tight')
plt.show()

## Save Results

In [None]:
import json

# Boolean per cell: did the NOTO sweep ever change its fate?
fate_changed = summary_df.first_change_level != -1

# Headline numbers of both perturbation analyses, written as JSON next to the model
perturbation_summary = {
    "noto_perturbation": {
        "total_cells": len(summary_df),
        "cells_with_fate_change": int(fate_changed.sum()),
        "percentage_changed": float(fate_changed.mean() * 100),
    },
    "large_scale_ko": {
        "genes_tested": len(df),
        "early_blastomere_cells": int(len(idx)),
    },
    "top_notochord_regulators": top_noto_negative.to_dict('records'),
}

summary_file = os.path.join(model_dir, "perturbation_summary.json")
with open(summary_file, "w") as f:
    json.dump(perturbation_summary, f, indent=2)

print(f"Perturbation summary saved to {model_dir}/perturbation_summary.json")
print("\nAnalysis complete!")
print(f"\nAnalysis complete!")