# DeltaNMF demo: loading `test_data.h5ad` and running one-stage + two-stage

This notebook demonstrates how `resources/test_data.h5ad` is structured and how to run:

1. **One-stage** DeltaNMF (control-only) with **no** foundation-model (FM) regularization  
2. **One-stage** DeltaNMF with **FM** regularization  
3. **Two-stage** DeltaNMF (case-control)

`test_data.h5ad` is a small, GitHub-friendly subset of the Engreitz lab Perturb-seq dataset (control vs. perturbed cells), downsampled to 1,000 cells per group (https://pubmed.ncbi.nlm.nih.gov/38326615/).


In [None]:
# Notebook setup
from pathlib import Path
import sys
import numpy as np
import anndata as ad


## Paths + imports (no install)

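Adds the repo root to `sys.path` so that `import deltanmf` works without a pip install. The cell checks the current working directory first, then its parent, for a `deltanmf/` folder.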

In [None]:
repo_root = Path.cwd().resolve()
if not (repo_root / "deltanmf").exists() and (repo_root.parent / "deltanmf").exists():
    repo_root = repo_root.parent
sys.path.insert(0, str(repo_root))

from deltanmf.api import run_onestage_deltanmf, run_twostage_deltanmf
from deltanmf.io import h5ad_to_npy

resources_dir = repo_root / "resources"
h5ad_path = resources_dir / "test_data.h5ad"
out_dir = resources_dir / "out_notebook"
out_dir.mkdir(parents=True, exist_ok=True)


## What’s inside `test_data.h5ad`

DeltaNMF runners expect:
- `adata.X`: **cells × genes**
- `adata.obs["condition"]`: contains `"NTC"` and `"CASE"`
- `adata.var_names`: gene names (strings)

`h5ad_to_npy(...)` converts to:
- `X_ntc`: **genes × control_cells**
- `X_spec`: **genes × case_cells**
- `gene_names`: length = n_genes
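
Conceptually, this conversion is a split by condition label followed by a transpose. The sketch below illustrates that on toy data with plain NumPy. It assumes a dense matrix and treats every label other than `"NTC"` as case; the helper `split_by_condition` is hypothetical and is not the implementation of `h5ad_to_npy`.

```python
import numpy as np

def split_by_condition(X, conditions, ntc_key="NTC"):
    """Hypothetical helper: split a cells x genes matrix into two genes x cells blocks."""
    conditions = np.asarray(conditions)
    is_ntc = conditions == ntc_key      # boolean mask over cells
    X_ntc_block = X[is_ntc].T           # genes x control_cells
    X_case_block = X[~is_ntc].T         # genes x case_cells (assumes CASE = not NTC)
    return X_ntc_block, X_case_block

# Toy data: 5 cells x 3 genes
X_toy = np.arange(15, dtype=float).reshape(5, 3)
conditions_toy = ["NTC", "CASE", "NTC", "CASE", "CASE"]

X_ntc_toy, X_case_toy = split_by_condition(X_toy, conditions_toy)
print(X_ntc_toy.shape, X_case_toy.shape)
```

With two control cells and three case cells, the toy blocks have three rows each (one per gene), matching the **genes × cells** orientation listed above.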


In [None]:
adata = ad.read_h5ad(h5ad_path)
print("adata.X (cells x genes):", adata.X.shape)
print("condition counts:\n", adata.obs["condition"].value_counts())
print("genes:", adata.n_vars, "cells:", adata.n_obs)


In [None]:
X_ntc, X_spec, gene_names = h5ad_to_npy(
    h5ad_path,
    ntc_key="NTC",
    condition_key="condition",
    case_key=None,   # if set, only that label becomes CASE; else CASE = not NTC
    layer=None,      # set to a layer name to use adata.layers[layer]
)
print("X_ntc:", X_ntc.shape, "(genes x control_cells)")
print("X_spec:", X_spec.shape, "(genes x case_cells)")
print("gene_names:", gene_names.shape)


## FM resources (scGPT or TranscriptFormer)

Optional FM regularization uses:
- `resources/<model>/S_E_relu.npy`
- `resources/<model>/genes_order.json`

Set `model_name` to `"scgpt"` or `"transcriptformer"`.
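
To give a rough idea of how these two files might relate to the data, the sketch below restricts an FM matrix to the genes shared with a dataset. The formats are assumptions, not something this notebook confirms: `genes_order.json` is taken to be a JSON list of gene names, and `S_E_relu.npy` a square genes × genes matrix in the same order. All toy names and values are made up for illustration.

```python
import json
import numpy as np

# Assumed format: a JSON list of gene names giving the row/column order of S_E.
fm_genes = json.loads('["GATA1", "TAL1", "KLF1", "MYC"]')
# Assumed format: a square genes x genes matrix (toy values).
S_E_toy = np.arange(16, dtype=float).reshape(4, 4)

# Gene names from the expression data (toy example).
data_genes = ["MYC", "ACTB", "GATA1"]

# Keep genes present in both, in the order they appear in the data.
fm_index = {gene: i for i, gene in enumerate(fm_genes)}
shared = [gene for gene in data_genes if gene in fm_index]
idx = [fm_index[gene] for gene in shared]

# Select the matching rows and columns of the FM matrix.
S_sub = S_E_toy[np.ix_(idx, idx)]
print(shared, S_sub.shape)
```

Genes missing from the FM gene list (here `ACTB`) simply drop out of the submatrix in this sketch; how the package handles them is not shown in this notebook.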


In [None]:
model_name = "scgpt"  # or "transcriptformer"
S_E_PATH = resources_dir / model_name / "S_E_relu.npy"
S_E_GENES_PATH = resources_dir / model_name / "genes_order.json"

for p in [h5ad_path, S_E_PATH, S_E_GENES_PATH]:
    if not p.exists():
        raise FileNotFoundError(f"Missing: {p}")


# 1) One-stage (control-only), no FM loss

Key args:
- `K`: number of programs
- `MIN_CELLS`: per-gene presence filter used during alignment
- `rel_alpha=0.0`: disables FM regularization
- `use_minibatch_ntc`: choose full-batch (`False`) or minibatch (`True`) for NTC fitting
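
As an illustration of what a per-gene presence filter like `MIN_CELLS` typically does, the sketch below keeps genes detected in at least `min_cells` cells. It assumes "presence" means a nonzero count and a **genes × cells** matrix; the helper `genes_passing_min_cells` is hypothetical, and the package's own filter may differ.

```python
import numpy as np

def genes_passing_min_cells(X_genes_by_cells, min_cells):
    """Hypothetical helper: mask of genes with nonzero counts in >= min_cells cells."""
    n_cells_detected = (X_genes_by_cells > 0).sum(axis=1)  # per-gene detection count
    return n_cells_detected >= min_cells

X_toy = np.array([
    [1, 0, 2, 3],   # detected in 3 cells
    [0, 0, 0, 5],   # detected in 1 cell
    [4, 1, 1, 1],   # detected in 4 cells
])

keep = genes_passing_min_cells(X_toy, min_cells=2)
print(keep.tolist())  # [True, False, True]
```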


In [None]:
res_one_no_fm = run_onestage_deltanmf(
    X_control=X_ntc,
    gene_names=gene_names,
    S_E_PATH=S_E_PATH,
    S_E_GENES_PATH=S_E_GENES_PATH,
    K=30,
    MIN_CELLS=10,
    REMOVE_GENES=[],
    rel_alpha=0.0,
    max_iter=10000,
    lr=0.01,
    use_minibatch_ntc=False,  # full-batch
)

print("W:", res_one_no_fm["W"].shape, "(genes x K)")
print("H:", res_one_no_fm["H"].shape, "(K x control_cells)")


# 2) One-stage with FM loss

Turn on FM regularization by setting `rel_alpha > 0`. A value of 1 is recommended; the cell below uses `0.05`.

This example also enables minibatching for NTC fitting.


In [None]:
res_one_fm = run_onestage_deltanmf(
    X_control=X_ntc,
    gene_names=gene_names,
    S_E_PATH=S_E_PATH,
    S_E_GENES_PATH=S_E_GENES_PATH,
    K=30,
    MIN_CELLS=10,
    REMOVE_GENES=[],
    rel_alpha=0.05,
    max_iter=10000,
    lr=0.01,
    use_minibatch_ntc=True,
    minibatch_size_ntc=40960,
)

print("W:", res_one_fm["W"].shape, "(genes x K)")
print("H:", res_one_fm["H"].shape, "(K x control_cells)")


# 3) Two-stage (case-control)

Stage 1: fit baseline programs on control cells (`K_stage1`)  
Stage 2: fit case-specific programs on case residuals and refit on case cells (`K_stage2`)

Key args:
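- `stage1_rel_alpha`: FM regularization strength for the baseline programs (stage 1); `0.0` disables it
- `stage2_rel_alpha`: FM regularization strength for the case-specific programs (stage 2); `0.0` disables it
- `stage2_rel_gamma`: optional orthogonality term between baseline and case-specific programs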
- `stage1_use_minibatch_ntc`: choose full-batch (`False`) or minibatch (`True`) for stage-1 NTC fitting
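
The shapes involved in the two stages can be sketched with random matrices. This is only a shape illustration under stated assumptions: it supposes that the case cells are reconstructed from the baseline and case-specific programs side by side, and that the baseline rows come first in the case loadings. It uses random toy factors and does not call or reproduce the DeltaNMF fitting procedure.

```python
import numpy as np

rng = np.random.default_rng(0)
n_genes, n_case_cells = 50, 20
K1, K2 = 3, 2                                 # toy stand-ins for K_stage1, K_stage2

W1 = rng.random((n_genes, K1))                # baseline programs (genes x K1)
W2 = rng.random((n_genes, K2))                # case-specific programs (genes x K2)
H_case = rng.random((K1 + K2, n_case_cells))  # loadings on case cells

# Assumed reconstruction: concatenate programs column-wise, then multiply.
W_all = np.hstack([W1, W2])                   # genes x (K1 + K2)
X_case_hat = W_all @ H_case                   # genes x case_cells

# The same reconstruction, split into baseline and case-specific contributions
# (assumes the first K1 rows of H_case belong to the baseline programs).
baseline_part = W1 @ H_case[:K1]
case_part = W2 @ H_case[K1:]
print(X_case_hat.shape, np.allclose(X_case_hat, baseline_part + case_part))
```

Under these assumptions, a loading matrix with `K_stage1 + K_stage2` rows lets each case cell be described as a mix of baseline and case-specific programs.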


In [None]:
res_two = run_twostage_deltanmf(
    X_control=X_ntc,
    X_case=X_spec,
    gene_names=gene_names,
    S_E_PATH=S_E_PATH,
    S_E_GENES_PATH=S_E_GENES_PATH,
    K_stage1=30,
    K_stage2=60,
    MIN_CELLS=10,
    REMOVE_GENES=[],
    stage1_rel_alpha=0.0,
    stage2_rel_alpha=0.0,
    stage2_rel_gamma=0.0,
    stage1_max_iter=10000,
    stage2_max_iter=10000,
    lr=0.01,
    stage1_use_minibatch_ntc=True,
    stage1_minibatch_size_ntc=40960,
)

print("W_stage1:", res_two["W_stage1"].shape, "(genes x K_stage1)")
print("H_stage1:", res_two["H_stage1"].shape, "(K_stage1 x control_cells)")
print("W_stage2:", res_two["W_stage2"].shape, "(genes x K_stage2)")
print("H_stage2:", res_two["H_stage2"].shape, "(K_stage1+K_stage2 x case_cells)")
