# pySCENIC regulon analysis
Related to Extended Fig. 1.

## Overview

- **Goal:** infer TF regulons using **pySCENIC** and summarize **cell-group–specific regulon activity**.
- **Inputs (prepared offline in `./data/`):** an AnnData `.h5ad` plus pySCENIC auxiliary resources (TF list, ranking database, motif annotations).
- **Outputs (written to `./results/`):** loom files, AUC matrix, RSS table, top-regulon table, and key figures (PDF/PNG).


In [None]:
# =========================
# CONFIG + reproducibility
# =========================
from __future__ import annotations
import os
import json
import random
from pathlib import Path
import numpy as np
import pandas as pd
try:
    import scanpy as sc
except Exception as e:
    raise ImportError("scanpy is required for this notebook.") from e
import matplotlib as mpl
import matplotlib.pyplot as plt
# Font settings for exported figures (PDF, PS and SVG backends).
# NOTE(review): presumably chosen so text stays editable in vector editors — confirm.
mpl.rcParams["pdf.fonttype"] = 42
mpl.rcParams["ps.fonttype"]  = 42
mpl.rcParams["svg.fonttype"] = "none"
# Single place for every input location and parameter used by this notebook.
CONFIG = {
    # Used as the output sub-directory name and as a prefix for output files.
    "dataset_id": "HCC_new_Fibro",
    # AnnData file read by the data-loading cell.
    "input_h5ad": "./outputs/h5ad/1_Fibro_count_new.h5ad",
    # Column of adata.obs that defines the cell groups used for RSS.
    "group_key": "subtype",
    # None -> adata.X is exported; otherwise the named entry of adata.layers.
    "layer_for_scenic": None,  # e.g. "counts"
    # Directory holding the three auxiliary files named below.
    "aux_dir": "./data/auxiliaries",
    "tf_list": "allTFs_hg38.txt",
    "ranking_db": "hg38_10kbp_up_10kbp_down_full_tx_v10_clust.genes_vs_motifs.rankings.feather",
    "motif_annotations": "motifs-v10nr_clust-nr.hgnc-m0.001-o0.0.tbl",
    # Seed applied to random, numpy and (if importable) torch.
    "seed": 0,
    # Passed as --num_workers to every pyscenic CLI step.
    "n_workers": 16,
    # Number of highest-RSS regulons kept per group in tables and bar plots.
    "top_n_regulons_per_group": 10,
    # Root output directory; results go to <out_dir>/<dataset_id>/.
    "out_dir": "./results/pyscenic",
}
# Seed the random number generators available in this environment.
seed = int(CONFIG["seed"])
random.seed(seed)
np.random.seed(seed)
try:
    import torch
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
except Exception:
    # Best effort: any failure while importing or seeding torch is ignored.
    pass
# Create the output tree: <out_dir>/<dataset_id>/ and its figures/ sub-directory.
dataset_id = CONFIG["dataset_id"]
out_dir = Path(CONFIG["out_dir"]) / dataset_id
fig_dir = out_dir / "figures"
out_dir.mkdir(parents=True, exist_ok=True)
fig_dir.mkdir(parents=True, exist_ok=True)
aux_dir = Path(CONFIG["aux_dir"])
input_h5ad = Path(CONFIG["input_h5ad"])
# Every file this notebook writes (or expects the CLI script to write), keyed by role.
paths = {
    "expr_csv": out_dir / f"{dataset_id}.qc.tpm.csv",
    "loom_in":  out_dir / f"{dataset_id}.qc.count.loom",
    "adj":      out_dir / f"{dataset_id}.adjacencies.tsv",
    "motifs":   out_dir / f"{dataset_id}.motifs.csv",
    "loom_out": out_dir / f"{dataset_id}.out.loom",
    "auc_csv":  out_dir / "regulon_auc.csv.gz",
    "rss_csv":  out_dir / "rss.csv",
    "top_csv":  out_dir / "top_regulons_by_group.csv",
    "adata_out": out_dir / "adata_with_regulon_auc.h5ad",
    "cli_sh":   out_dir / "run_pyscenic.sh",
    "config_json": out_dir / "config.json",
}
def _v(modname: str):
    try:
        m = __import__(modname)
        return getattr(m, "__version__", "unknown")
    except Exception:
        return "not-installed"
import sys

# Record interpreter and package versions next to the results for provenance.
versions = {
    # Read version_info from `sys` itself; `os.sys` only resolves because the
    # os module happens to import sys internally and is not a documented API.
    "python": ".".join(str(part) for part in sys.version_info[:3]),
    "numpy": _v("numpy"),
    "pandas": _v("pandas"),
    "scanpy": _v("scanpy"),
    "anndata": _v("anndata"),
    "pyscenic": _v("pyscenic"),
    "loompy": _v("loompy"),
    "matplotlib": _v("matplotlib"),
    "seaborn": _v("seaborn"),
}
print(json.dumps(versions, indent=2))
# Persist the configuration, the resolved output paths (as strings, since Path
# objects are not JSON-serialisable) and the versions in one JSON file.
Path(paths["config_json"]).write_text(
    json.dumps(
        {"CONFIG": CONFIG, "paths": {k: str(v) for k, v in paths.items()}, "versions": versions},
        indent=2,
    )
)
print(f"\nOutputs will be written under: {out_dir.resolve()}")


## Data loading

Data acquisition and preprocessing (QC, normalization, cell-type annotation) were performed **offline**.  
The notebook assumes all required input files are already present locally under `./data/` (or provided by the authors).


In [None]:
# =========================
# Data loading + minimal sanity checks
# =========================
import warnings
# Silences every warning category for the rest of the session.
warnings.filterwarnings("ignore")
# Fallbacks for common path variants
if not aux_dir.exists():
    # Try the alternative spelling "auxilliaries" of the auxiliary directory.
    alt_aux = Path(str(aux_dir).replace("auxiliaries", "auxilliaries"))
    if alt_aux.exists():
        print(f"INFO: aux_dir '{aux_dir}' not found; using fallback '{alt_aux}'.")
        aux_dir = alt_aux
if not input_h5ad.exists():
    # Try the copy of the input file under ./data/.
    alt_input = Path("./data/1_Fibro_count_new.h5ad")
    if alt_input.exists():
        print(f"INFO: input_h5ad '{input_h5ad}' not found; using fallback '{alt_input}'.")
        input_h5ad = alt_input
ranking_db_path = aux_dir / CONFIG["ranking_db"]
if not ranking_db_path.exists():
    raise FileNotFoundError(f"Ranking DB not found: {ranking_db_path}")
# Gene universe of the ranking database = its column names minus the "motifs" column.
try:
    import pyarrow.feather as feather
    ranking_genes = [c for c in feather.read_table(ranking_db_path).column_names if c != "motifs"]
except Exception:
    # Any failure above (including pyarrow being unavailable) falls back to pandas.
    # Here "motifs" is only dropped when it is the last column.
    ranking_genes = list(pd.read_feather(ranking_db_path).columns)
    if ranking_genes and ranking_genes[-1] == "motifs":
        ranking_genes = ranking_genes[:-1]
# A set gives constant-time membership tests in the gene filter below.
ranking_genes = set(ranking_genes)
if not input_h5ad.exists():
    raise FileNotFoundError(f"Input h5ad not found: {input_h5ad}")
adata = sc.read_h5ad(input_h5ad)
group_key = CONFIG["group_key"]
if group_key not in adata.obs.columns:
    raise KeyError(f"Required obs column missing: {group_key}")
# Keep only genes present in the ranking database, preserving adata's gene order.
genes_before = adata.n_vars
keep_genes = [g for g in adata.var_names if g in ranking_genes]
if len(keep_genes) == 0:
    raise ValueError("No overlap between adata.var_names and ranking DB genes.")
adata = adata[:, keep_genes].copy()
print(f"Loaded adata: n_cells={adata.n_obs:,}  n_genes={adata.n_vars:,}  (before intersection: {genes_before:,})")
print(f"Groups in adata.obs[{group_key!r}]: {adata.obs[group_key].nunique()}")
# Store group labels as plain strings (they are later written as a LOOM column attribute).
adata.obs[group_key] = adata.obs[group_key].astype(str)


In [None]:
# =========================
# Export expression + create LOOM for pySCENIC
# =========================
from scipy import sparse

# Pick the matrix handed to pySCENIC: adata.X unless a layer is configured.
layer = CONFIG["layer_for_scenic"]
if layer is not None and layer not in adata.layers:
    raise KeyError(f"layer_for_scenic={layer!r} not found in adata.layers")
X = adata.X if layer is None else adata.layers[layer]

# Work on a CSR matrix regardless of how the data are stored.
X_csr = X.tocsr() if sparse.issparse(X) else sparse.csr_matrix(X)

# Cells × genes table, written as CSV with the cell IDs as the index column.
expr_df = pd.DataFrame.sparse.from_spmatrix(
    X_csr,
    index=adata.obs_names,
    columns=adata.var_names,
)
expr_df.to_csv(paths["expr_csv"], index=True)
print(f"Wrote expression matrix: {paths['expr_csv']}")

import loompy as lp
import numpy as np

# LOOM layout: rows are genes, columns are cells, so the matrix is transposed
# and densified as float32 before writing.
M = X_csr.transpose().astype(np.float32).toarray()
row_attrs = {"Gene": np.array(adata.var_names)}
col_attrs = {
    "CellID": np.array(adata.obs_names),
    group_key: np.array(adata.obs[group_key].values, dtype=str),
}
lp.create(str(paths["loom_in"]), M, row_attrs, col_attrs)
print(f"Wrote LOOM input: {paths['loom_in']}")

## Core analysis

1. Match input genes to the cisTarget ranking database gene universe.
2. Create a **LOOM** file for pySCENIC input.
3. Run pySCENIC CLI (**grn → ctx → aucell**) via an autogenerated shell script (no installs/downloads here).
4. Load the `*.out.loom`, compute **regulon specificity scores (RSS)** by a chosen `group_key`.


In [None]:
# =========================
# pySCENIC CLI script (grn → ctx → aucell)
# =========================


import shlex

tf_list_path = aux_dir / CONFIG["tf_list"]
motif_anno_path = aux_dir / CONFIG["motif_annotations"]
for p in [tf_list_path, motif_anno_path]:
    if not p.exists():
        raise FileNotFoundError(f"Required auxiliary file not found: {p}")


def _sh(value) -> str:
    """Return *value* quoted so that bash reads it as exactly one argument."""
    # shlex.quote leaves ordinary paths untouched and single-quotes anything
    # containing spaces, quotes, `$`, backticks or other shell metacharacters.
    return shlex.quote(str(value))


n_workers = int(CONFIG["n_workers"])

# One entry per script line; joined once at the end instead of repeated `+=`.
# The three steps run in order: grn -> ctx -> aucell, each consuming the
# previous step's output file.
script_lines = [
    "#!/usr/bin/env bash",
    "set -euo pipefail",
    "",
    "# GRN inference (GRNBoost2)",
    "pyscenic grn \\",
    f"  {_sh(paths['loom_in'])} \\",
    f"  {_sh(tf_list_path)} \\",
    f"  -o {_sh(paths['adj'])} \\",
    f"  --num_workers {n_workers} \\",
    "  --method grnboost2",
    "",
    "# Regulon prediction (cisTarget / ctx)",
    "pyscenic ctx \\",
    f"  {_sh(paths['adj'])} \\",
    f"  {_sh(ranking_db_path)} \\",
    f"  --annotations_fname {_sh(motif_anno_path)} \\",
    f"  --expression_mtx_fname {_sh(paths['loom_in'])} \\",
    f"  --output {_sh(paths['motifs'])} \\",
    f"  --num_workers {n_workers} \\",
    "  --mode custom_multiprocessing",
    "",
    "# AUCell scoring",
    "pyscenic aucell \\",
    f"  {_sh(paths['loom_in'])} \\",
    f"  {_sh(paths['motifs'])} \\",
    f"  --output {_sh(paths['loom_out'])} \\",
    f"  --num_workers {n_workers}",
]
cmds = "\n".join(script_lines) + "\n"

from pathlib import Path
import os
Path(paths["cli_sh"]).write_text(cmds)
# Make the script executable (rwxr-xr-x).
os.chmod(paths["cli_sh"], 0o755)

# The script contains the paths exactly as configured; if those are relative,
# it has to be launched from this notebook's working directory.
print(f"Wrote CLI script: {paths['cli_sh']}")
print(f"Run: bash {_sh(paths['cli_sh'])}")

In [None]:
# =========================
# Load AUCell output + compute RSS
# =========================
from pathlib import Path
import loompy as lp
import numpy as np

# Stop early, with instructions, when the AUCell step has not produced its loom yet.
aucell_loom_present = Path(paths["loom_out"]).exists()
if not aucell_loom_present:
    raise FileNotFoundError(
        f"Missing AUCell output loom: {paths['loom_out']}\n"
        f"Run the generated CLI script first: bash {paths['cli_sh']}"
    )

def load_regulon_auc_from_loom(loom_path: str) -> pd.DataFrame:
    """Load the regulon AUC matrix stored in a pySCENIC ``*.out.loom`` file.

    Parameters
    ----------
    loom_path:
        Path of the loom file written by ``pyscenic aucell``.

    Returns
    -------
    pandas.DataFrame
        Cells × regulons; the index holds the ``CellID`` column attribute and
        the columns hold the regulon names.

    Raises
    ------
    KeyError
        If the loom has no ``RegulonsAUC`` column attribute.
    ValueError
        If the shape of ``RegulonsAUC`` cannot be matched to the number of cells.

    Notes
    -----
    pySCENIC writes ``RegulonsAUC`` as a NumPy structured array with one record
    per cell and one named field per regulon. Those field names are the real
    regulon names and are used directly. The name-guessing heuristic and the
    ``Regulon_<i>`` placeholders are only a fallback for plain 2-D arrays.
    """
    with lp.connect(loom_path, mode="r") as lf:
        if "RegulonsAUC" not in lf.ca:
            raise KeyError("Expected column attribute 'RegulonsAUC' not found in loom.ca")

        cell_ids = np.array(lf.ca["CellID"]).astype(str)
        n_cells = len(cell_ids)
        arr = np.array(lf.ca["RegulonsAUC"])

        if arr.dtype.names is not None:
            # Structured array: flatten to one record per cell and let pandas
            # turn each named field into a column.
            records = arr.reshape(-1)
            if records.shape[0] != n_cells:
                raise ValueError(f"Unexpected RegulonsAUC shape={arr.shape} vs n_cells={n_cells}")
            auc = pd.DataFrame(records, index=cell_ids)
            auc.columns = [str(name) for name in auc.columns]
            return auc

        # Plain numeric array: work out which axis is the cell axis.
        if arr.ndim == 1 and arr.shape[0] == n_cells:
            auc = pd.DataFrame(arr, index=cell_ids)
        elif arr.ndim == 2 and arr.shape[0] == n_cells:
            auc = pd.DataFrame(arr, index=cell_ids)
        elif arr.ndim == 2 and arr.shape[1] == n_cells:
            auc = pd.DataFrame(arr.T, index=cell_ids)
        else:
            raise ValueError(f"Unexpected RegulonsAUC shape={arr.shape} vs n_cells={n_cells}")

        # Best effort: look for a 1-D, unstructured attribute whose key
        # mentions "regulon" and whose length equals the number of regulons.
        reg_names = None
        for attrs in (lf.ca, lf.ra):
            for key in attrs.keys():
                candidate = attrs[key]
                if (
                    isinstance(candidate, np.ndarray)
                    and candidate.dtype.names is None
                    and candidate.ndim == 1
                    and len(candidate) == auc.shape[1]
                    and "regulon" in key.lower()
                ):
                    reg_names = list(candidate.astype(str))
                    break
            if reg_names is not None:
                break

        if reg_names is None:
            # No names available anywhere: fall back to positional placeholders.
            reg_names = [f"Regulon_{i}" for i in range(auc.shape[1])]

        auc.columns = reg_names
        return auc

auc_mtx = load_regulon_auc_from_loom(str(paths["loom_out"]))

# Restrict both the AUC matrix and the AnnData to the cells they share, using
# the same index object for both selections.
common_cells = adata.obs_names.intersection(auc_mtx.index)
if len(common_cells) == 0:
    raise ValueError("No overlapping CellIDs between input adata and out.loom AUC matrix.")
auc_mtx = auc_mtx.loc[common_cells]
adata_sub = adata[common_cells].copy()

print(f"AUC matrix: cells={auc_mtx.shape[0]:,} regulons={auc_mtx.shape[1]:,}")
auc_mtx.to_csv(paths["auc_csv"], compression="gzip")
print(f"Wrote: {paths['auc_csv']}")

# Regulon specificity scores of each regulon for each group in adata_sub.obs[group_key].
# NOTE(review): pySCENIC documents the returned frame as cell type × regulon
# (groups in the index, regulons in the columns) — confirm before indexing
# `rss` by column as if the columns were groups.
from pyscenic.rss import regulon_specificity_scores
rss = regulon_specificity_scores(auc_mtx, adata_sub.obs[group_key])
rss.to_csv(paths["rss_csv"])
print(f"Wrote: {paths['rss_csv']}")

## Results & exports

This section saves:
- `regulon_auc.csv.gz` (cells × regulons)
- `rss.csv` (groups × regulons, the orientation returned by pySCENIC's `regulon_specificity_scores`)
- `top_regulons_by_group.csv`
- figures: `rss_heatmap.*`, `rss_barplots.*`, and optional UMAP overlays if available
- `adata_with_regulon_auc.h5ad` (AnnData augmented with regulon AUC in `obsm`)


In [None]:
# =========================
# Top regulons per group + figures
# =========================
import seaborn as sns

top_n = int(CONFIG["top_n_regulons_per_group"])

# ---------------------------------------------------------------------------
# Fix the RSS orientation explicitly.
# pySCENIC's regulon_specificity_scores() returns groups × regulons, and
# plot_rss() expects exactly that layout (it looks the group up in rss.T).
# The top-N table, z-scores and heatmap below need regulons × groups.
# Treating rss.columns as groups would swap "group" and "regulon" everywhere.
# ---------------------------------------------------------------------------
regulon_names = set(auc_mtx.columns)
if set(rss.columns) == regulon_names:
    rss_by_group = rss        # groups × regulons (pySCENIC layout)
elif set(rss.index) == regulon_names:
    rss_by_group = rss.T      # already transposed upstream
else:
    raise ValueError(
        "Cannot determine RSS orientation: neither axis of `rss` matches the regulon names in `auc_mtx`."
    )
rss_by_regulon = rss_by_group.T  # regulons × groups

# Top-N regulons per group, highest RSS first.
rows = []
for grp in rss_by_regulon.columns:
    top_scores = rss_by_regulon[grp].sort_values(ascending=False).head(top_n)
    for regulon, val in top_scores.items():
        rows.append({"group": grp, "regulon": regulon, "rss": float(val)})
top_tbl = pd.DataFrame(rows, columns=["group", "regulon", "rss"]).sort_values(
    ["group", "rss"], ascending=[True, False]
)
top_tbl.to_csv(paths["top_csv"], index=False)
print(f"Wrote: {paths['top_csv']}")

# Z-score each group's RSS across regulons; a zero spread is replaced by 1 so
# that constant columns become all zeros instead of dividing by zero.
group_sd = rss_by_regulon.std(ddof=0).replace(0, 1.0)
rss_z = (rss_by_regulon - rss_by_regulon.mean()) / group_sd

top_regulons = sorted(set(top_tbl["regulon"].tolist()))
rss_z_sub = rss_z.loc[top_regulons]

sns.set_context("paper")

# Rows = selected regulons, columns = groups. Hierarchical clustering needs at
# least two entries along an axis, so it is switched off for a length-1 axis.
cg = sns.clustermap(
    rss_z_sub,
    cmap="vlag",
    center=0,
    linewidths=0.2,
    figsize=(max(6, 0.25 * rss_z_sub.shape[1] + 2), max(6, 0.18 * rss_z_sub.shape[0] + 2)),
    yticklabels=True,
    xticklabels=True,
    row_cluster=rss_z_sub.shape[0] > 1,
    col_cluster=rss_z_sub.shape[1] > 1,
)
heatmap_pdf = fig_dir / "rss_heatmap.pdf"
heatmap_png = fig_dir / "rss_heatmap.png"
cg.savefig(heatmap_pdf, bbox_inches="tight")
cg.savefig(heatmap_png, dpi=300, bbox_inches="tight")
plt.close(cg.fig)
print(f"Saved: {heatmap_pdf}")
print(f"Saved: {heatmap_png}")

# One RSS rank plot per group, laid out on a grid with four columns.
from pyscenic.plotting import plot_rss
groups = list(rss_by_group.index)
n_groups = len(groups)
n_cols = 4
n_rows = int(np.ceil(n_groups / n_cols))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(4*n_cols, 3*n_rows), squeeze=False)

for ax, grp in zip(axes.ravel(), groups):
    plot_rss(rss_by_group, grp, top_n=top_n, ax=ax)
    ax.set_title(grp)

# Hide the unused panels of the last grid row.
for ax in axes.ravel()[n_groups:]:
    ax.axis("off")

bar_pdf = fig_dir / "rss_barplots.pdf"
bar_png = fig_dir / "rss_barplots.png"
fig.tight_layout()
fig.savefig(bar_pdf)
fig.savefig(bar_png, dpi=300)
plt.close(fig)
print(f"Saved: {bar_pdf}")
print(f"Saved: {bar_png}")

if "X_umap" in adata_sub.obsm:
    # Best regulon of each group; de-duplicated (order kept) because one
    # regulon can be the top hit of several groups.
    best = top_tbl.sort_values("rss", ascending=False).groupby("group", as_index=False).head(1)
    regulons_to_plot = list(dict.fromkeys(best["regulon"].tolist()))
    for r in regulons_to_plot:
        # auc_mtx and adata_sub were both subset with the same cell index, so
        # their rows are in the same order.
        adata_sub.obs[f"AUC:{r}"] = auc_mtx[r].values

    sc.pl.umap(
        adata_sub,
        color=[group_key] + [f"AUC:{r}" for r in regulons_to_plot],
        wspace=0.4,
        frameon=False,
        show=False,
    )
    umap_pdf = fig_dir / "umap_regulon_auc.pdf"
    plt.savefig(umap_pdf, bbox_inches="tight")
    plt.close()
    print(f"Saved: {umap_pdf}")
else:
    print("UMAP not found in adata.obsm['X_umap']; skipped UMAP overlays.")

In [None]:
# =========================
# Save augmented AnnData (AUC in obsm) + final paths
# =========================
# Attach the AUC matrix as a float32 array; column names are kept separately in
# .uns because obsm here stores only the raw values.
adata_sub.obsm["X_regulon_auc"] = auc_mtx.values.astype(np.float32)
adata_sub.uns["regulon_auc_columns"] = list(auc_mtx.columns)
# Record where the companion tables were written (paths stored as strings).
adata_sub.uns["rss_table_path"] = str(paths["rss_csv"])
adata_sub.uns["top_regulons_table_path"] = str(paths["top_csv"])

adata_sub.write_h5ad(paths["adata_out"])
print(f"Wrote: {paths['adata_out']}")

# Summary of the main output locations.
print("\nKey outputs:")
for k in ["cli_sh", "loom_in", "adj", "motifs", "loom_out", "auc_csv", "rss_csv", "top_csv", "adata_out"]:
    print(f"- {k}: {paths[k]}")
print(f"- figures: {fig_dir}")

# scFEA metabolic flux analysis (Fibro example)
Related to Extended Fig. 1.

**Inputs (in `./data/`).**
1. `1_Fibro_count_new.h5ad`: AnnData with `obs['subtype']` (fibroblast subtypes).
2. `Fibro_flux.csv`: scFEA-predicted flux matrix (`cells × modules`, columns like `M_1 ...`).  
3. `Human_M168_information.symbols.csv`: module annotations (at least `Compound_IN_name`, `Compound_OUT_name`).

**Outputs (written to `./outputs/scfea_fibro/`).**
- Heatmap (PDF/SVG) of subtype signature modules (row z-scores).
- CSV tables: mean flux per subtype, z-scored matrix, per-subtype top modules.
- A compact `.h5ad` subset containing the matched cells and selected module fluxes (optional but reproducible).


## Data loading

This notebook can start from an existing `Fibro_flux.csv`.  
If you **do not** have it yet, run the **scFEA upstream** cells below to generate it from the `.h5ad`.


In [None]:
from pathlib import Path
import os
import random
import numpy as np
import pandas as pd
import scanpy as sc
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from matplotlib.colors import LinearSegmentedColormap

# Input and output locations; the output directory is created if missing.
DATA_DIR = Path("./data")
OUT_DIR  = Path("./outputs/scfea_fibro")
OUT_DIR.mkdir(parents=True, exist_ok=True)

H5AD_PATH        = DATA_DIR / "1_Fibro_count_new.h5ad"
FLUX_CSV_PATH    = DATA_DIR / "Fibro_flux.csv"
MODULE_INFO_PATH = DATA_DIR / "Human_M168_information.symbols.csv"

SEED = 0
# Column of adata.obs holding the subtype label of each cell.
SUBTYPE_KEY = "subtype"
# Subtypes removed from the AnnData before any flux analysis.
DROP_SUBTYPES = ["myFibro_TAGLN", "Pericyte"]
# Number of highest z-score modules kept per subtype.
TOP_K = 8

random.seed(SEED)
np.random.seed(SEED)
# NOTE(review): PYTHONHASHSEED is normally read at interpreter start-up, so
# setting it here presumably only affects child processes — confirm.
os.environ["PYTHONHASHSEED"] = str(SEED)

# Font settings for exported figures (PDF, PS and SVG backends).
mpl.rcParams["pdf.fonttype"] = 42
mpl.rcParams["ps.fonttype"]  = 42
mpl.rcParams["svg.fonttype"] = "none"

def _safe_version(pkg_name: str) -> str:
    try:
        import importlib
        mod = importlib.import_module(pkg_name)
        return getattr(mod, "__version__", "unknown")
    except Exception:
        return "not installed"

# Report the versions of the main packages, one per line.
version_entries = [
    f"{pkg}={_safe_version(pkg)}"
    for pkg in ("scanpy", "anndata", "numpy", "pandas", "matplotlib")
]
print("Versions:", *version_entries, sep="\n  - ")

# Echo the configured input/output locations.
print("\nCONFIG:")
for label, location in (
    ("  H5AD_PATH       =", H5AD_PATH),
    ("  FLUX_CSV_PATH   =", FLUX_CSV_PATH),
    ("  MODULE_INFO_PATH=", MODULE_INFO_PATH),
    ("  OUT_DIR         =", OUT_DIR),
):
    print(label, location)

### scFEA upstream to generate Fibro_flux.csv

Run this section only if `./data/Fibro_flux.csv` is missing.

What it does:
1. Subset the AnnData by fibro subtypes (drops `DROP_SUBTYPES`).
2. Export scFEA input matrix (**genes × cells**) from raw counts if available.
3. Call `src/scFEA.py` to generate `Fibro_flux.csv` and `Fibro_balance.csv`.

Requirements:
- A local scFEA repo checkout (set `SCFEA_DIR`, e.g. `/content/scFEA` on Colab).
- `module_gene_m168.csv` and `cmMat_c70_m168.csv` available under `SCFEA_DIR/data/` (the code will also try to copy them from `./data/` if present).


In [None]:
import sys
import subprocess
import shutil

try:
    from scipy.sparse import issparse
except Exception as e:
    raise ImportError("scipy is required for exporting scFEA input (sparse matrix handling).") from e

# -------------------------
# scFEA repository path
# -------------------------
# EDIT ME if needed:
#   - Colab example: SCFEA_DIR = Path("/content/scFEA")
#   - Local example: SCFEA_DIR = Path("./scFEA")
# Location of the scFEA checkout: the SCFEA_DIR environment variable if set,
# otherwise ./scFEA relative to the working directory.
SCFEA_DIR = Path(os.environ.get("SCFEA_DIR", "./scFEA"))
SCFEA_SRC = SCFEA_DIR / "src" / "scFEA.py"

# Sub-directories of the scFEA checkout used for reference data, input and output.
SCFEA_DATA_DIR  = SCFEA_DIR / "data"
SCFEA_INPUT_DIR = SCFEA_DIR / "input"
SCFEA_OUT_DIR   = SCFEA_DIR / "output"

# File name of the exported expression matrix and its full path under the input directory.
SCFEA_TEST_FILE = "Fibro_scFEA_input.csv"
SCFEA_INPUT_CSV = SCFEA_INPUT_DIR / SCFEA_TEST_FILE

# scFEA reference filenames (expected under SCFEA_DATA_DIR)
SCFEA_MODULE_GENE_FILE = "module_gene_m168.csv"
SCFEA_STOICH_FILE      = "cmMat_c70_m168.csv"

# Where to write scFEA results used by downstream analysis
BALANCE_CSV_PATH = DATA_DIR / "Fibro_balance.csv"

# Informational only: a missing checkout produces a warning here, not an error.
if SCFEA_DIR.exists():
    print("SCFEA_DIR:", SCFEA_DIR.resolve())
    print("SCFEA_SRC:", SCFEA_SRC)
else:
    print("WARNING: SCFEA_DIR does not exist:", SCFEA_DIR)
    print("         If you want to run scFEA here, clone the repo and set SCFEA_DIR accordingly.")

print("Will use (downstream) FLUX_CSV_PATH:", FLUX_CSV_PATH.resolve())
print("Will write BALANCE_CSV_PATH:", BALANCE_CSV_PATH.resolve())
print("Will write scFEA input to:", SCFEA_INPUT_CSV)


In [None]:
def export_scfea_input_from_adata(
    adata_in: sc.AnnData,
    out_csv: Path,
    prefer_layers=("raw_counts", "raw"),
    make_unique_genes: bool = True,
) -> pd.DataFrame:
    """Write the scFEA input table (genes × cells) as CSV and return it.

    The expression source is chosen in this order: the first name in
    ``prefer_layers`` found in ``adata_in.layers``, then ``adata_in.raw.X``,
    then ``adata_in.X``. Commas in gene and cell identifiers are replaced by
    underscores. When ``make_unique_genes`` is true, rows sharing a gene
    symbol are summed into one row. Parent directories of ``out_csv`` are
    created as needed.
    """
    # Pick the expression source and the gene names that belong to it.
    chosen_layer = next((name for name in prefer_layers if name in adata_in.layers), None)
    if chosen_layer is not None:
        counts = adata_in.layers[chosen_layer]
        gene_ids = adata_in.var_names
        print(f"Using adata.layers['{chosen_layer}'] for scFEA input.")
    elif getattr(adata_in, "raw", None) is not None:
        counts = adata_in.raw.X
        gene_ids = adata_in.raw.var_names
        print("Using adata.raw.X for scFEA input.")
    else:
        counts = adata_in.X
        gene_ids = adata_in.var_names
        print("Using adata.X for scFEA input (fallback). Make sure it is not log1p-normalized.")

    # Densify (cells × genes), then flip to the genes × cells layout scFEA reads.
    dense_cells_x_genes = np.asarray(counts.toarray() if issparse(counts) else counts)
    table = pd.DataFrame(
        dense_cells_x_genes.T,
        index=pd.Index(gene_ids.astype(str), name="gene"),
        columns=pd.Index(adata_in.obs_names.astype(str), name="cell"),
    )

    # Commas inside identifiers would break the CSV, so swap them for underscores.
    table.index = table.index.str.replace(",", "_", regex=False)
    table.columns = table.columns.str.replace(",", "_", regex=False)

    # Collapse repeated gene symbols by summing their rows (first-seen order kept).
    if make_unique_genes:
        repeated = table.index.duplicated()
        if repeated.any():
            n_dup = int(repeated.sum())
            print(f"WARNING: duplicated gene symbols detected (n={n_dup}). Collapsing by sum.")
            table = table.groupby(table.index, sort=False).sum()

    destination = Path(out_csv)
    destination.parent.mkdir(parents=True, exist_ok=True)
    table.to_csv(destination)

    print("Wrote scFEA input CSV:", destination.resolve())
    print("scFEA input shape (genes × cells):", table.shape)
    return table


In [None]:
import sys
import subprocess
import shutil

# Load the full AnnData and make sure the subtype annotation is present.
adata_all = sc.read_h5ad(H5AD_PATH)
if SUBTYPE_KEY not in adata_all.obs.columns:
    raise KeyError(f"AnnData.obs is missing required column: {SUBTYPE_KEY}")

# 1) Subset cells used for scFEA (drop unwanted subtypes)
adata = adata_all[~adata_all.obs[SUBTYPE_KEY].isin(DROP_SUBTYPES)].copy()

# 2) Clean cell IDs early (commas break CSV and will also break scFEA I/O)
adata.obs_names = adata.obs_names.astype(str).str.replace(",", "_", regex=False)

# Local path for balance output (kept for completeness)
# Reuses the value from the upstream config cell when that cell was run.
BALANCE_CSV_PATH = globals().get("BALANCE_CSV_PATH", DATA_DIR / "Fibro_balance.csv")

# 3) If Fibro_flux.csv is missing, try to generate it by running scFEA here
if not FLUX_CSV_PATH.exists():
    print(f"Missing flux file: {FLUX_CSV_PATH} -> will try to run scFEA to generate it.")

    # If upstream config cells were not executed, define sane defaults
    # (same values as in the scFEA repository configuration cell).
    if "SCFEA_DIR" not in globals():
        SCFEA_DIR = Path(os.environ.get("SCFEA_DIR", "./scFEA"))
        SCFEA_SRC = SCFEA_DIR / "src" / "scFEA.py"
        SCFEA_DATA_DIR  = SCFEA_DIR / "data"
        SCFEA_INPUT_DIR = SCFEA_DIR / "input"
        SCFEA_OUT_DIR   = SCFEA_DIR / "output"
        SCFEA_TEST_FILE = "Fibro_scFEA_input.csv"
        SCFEA_INPUT_CSV = SCFEA_INPUT_DIR / SCFEA_TEST_FILE
        SCFEA_MODULE_GENE_FILE = "module_gene_m168.csv"
        SCFEA_STOICH_FILE      = "cmMat_c70_m168.csv"

    if "export_scfea_input_from_adata" not in globals():
        raise NameError("export_scfea_input_from_adata() not found. Run the scFEA upstream cells above first.")

    # Guardrails: only run if scFEA repo looks available
    if not SCFEA_DIR.exists():
        raise FileNotFoundError(
            "SCFEA_DIR does not exist.\n"
            f"Current SCFEA_DIR={SCFEA_DIR}\n"
            "Clone scFEA repo and/or set SCFEA_DIR (see upstream config cell)."
        )
    if not SCFEA_SRC.exists():
        raise FileNotFoundError(
            "Cannot find scFEA entry script.\n"
            f"Expected: {SCFEA_SRC}\n"
            "Check your scFEA repo layout (should contain 'src/scFEA.py')."
        )

    # Ensure folders
    SCFEA_INPUT_DIR.mkdir(parents=True, exist_ok=True)
    SCFEA_OUT_DIR.mkdir(parents=True, exist_ok=True)
    SCFEA_DATA_DIR.mkdir(parents=True, exist_ok=True)

    # Make sure scFEA reference files exist under SCFEA_DATA_DIR
    needed = [
        SCFEA_DATA_DIR / SCFEA_MODULE_GENE_FILE,
        SCFEA_DATA_DIR / SCFEA_STOICH_FILE,
    ]
    missing = [p for p in needed if not p.exists()]
    if missing:
        print("Missing scFEA reference files under SCFEA_DATA_DIR:")
        for p in missing:
            print(" -", p)
        print("Will try to copy from ./data/ if available ...")

        # Copy any file with the same name found under DATA_DIR.
        for p in missing:
            alt = DATA_DIR / p.name
            if alt.exists():
                shutil.copy2(alt, p)
                print(f"Copied {alt} -> {p}")

        # Re-check after the copy attempt; anything still absent is fatal.
        missing = [p for p in needed if not p.exists()]
        if missing:
            raise FileNotFoundError(
                "Still missing required scFEA reference files.\n"
                + "\n".join([str(p) for p in missing])
                + "\nPut them under SCFEA_DIR/data/ (or ./data/) and rerun."
            )

    # Export scFEA input CSV (genes × cells)
    _ = export_scfea_input_from_adata(adata, SCFEA_INPUT_CSV)

    # Run scFEA (write outputs under SCFEA_DIR/output/, then copy flux/balance into ./data/)
    # The directory and file arguments are relative because the process is
    # started with cwd=SCFEA_DIR below.
    cmd = [
        sys.executable, str(SCFEA_SRC),
        "--data_dir", "data",
        "--input_dir", "input",
        "--test_file", SCFEA_TEST_FILE,
        "--moduleGene_file", SCFEA_MODULE_GENE_FILE,
        "--stoichiometry_matrix", SCFEA_STOICH_FILE,
        "--sc_imputation", "True",
        "--output_flux_file", "output/Fibro_flux.csv",
        "--output_balance_file", "output/Fibro_balance.csv",
    ]
    print("Running scFEA with command:")
    print("  (cwd =", SCFEA_DIR.resolve(), ")")
    print(" ", " ".join(cmd))

    # check=True raises CalledProcessError when scFEA exits with a non-zero status.
    subprocess.run(cmd, cwd=str(SCFEA_DIR), check=True)

    # Copy results into ./data/ so the downstream analysis is path-stable
    scfea_flux_out = SCFEA_OUT_DIR / "Fibro_flux.csv"
    scfea_bal_out  = SCFEA_OUT_DIR / "Fibro_balance.csv"

    if not scfea_flux_out.exists():
        raise FileNotFoundError(f"scFEA finished but flux file not found: {scfea_flux_out}")

    shutil.copy2(scfea_flux_out, FLUX_CSV_PATH)
    print("Copied scFEA flux ->", FLUX_CSV_PATH.resolve())

    # The balance file is optional: its absence only produces a warning.
    if scfea_bal_out.exists():
        shutil.copy2(scfea_bal_out, BALANCE_CSV_PATH)
        print("Copied scFEA balance ->", BALANCE_CSV_PATH.resolve())
    else:
        print("WARNING: scFEA balance file not found:", scfea_bal_out)

# 4) Load scFEA flux and match cell IDs
flux_df = pd.read_csv(FLUX_CSV_PATH, index_col=0)

# Apply the same comma -> underscore cleaning that was applied to adata.obs_names.
flux_df.index = flux_df.index.astype(str).str.replace(",", "_", regex=False)

common_cells = flux_df.index.intersection(adata.obs_names)
if len(common_cells) == 0:
    raise ValueError("No overlapping cell IDs between flux CSV and AnnData.")

# Select the shared cells from both objects with the same index.
flux_df = flux_df.loc[common_cells].copy()
adata   = adata[common_cells].copy()

print("Matched cells:", len(common_cells))
print("flux_df shape (cells × modules):", flux_df.shape)
print("adata shape (cells × genes):", adata.shape)
print("n_subtypes:", adata.obs[SUBTYPE_KEY].nunique())


## Core analysis

For each module, compute subtype-wise mean flux, then perform **row-wise z-scoring** (module-centered) across subtypes.  
Select the **top-k modules per subtype** (by z-score) as subtype “signature” modules, take the union set, and visualize as a heatmap.

In [None]:
# Mean flux per subtype, transposed to modules × subtypes.
# observed=True: when the subtype column is categorical, only subtypes that
# actually occur in the matched cells get a column. Without it, a leftover
# empty category yields an all-NaN column that the fillna(0.0) below turns
# into zeros, and that phantom subtype then contributes TOP_K arbitrary
# modules to the signature. The pandas default for `observed` also differs
# between versions, so it is pinned here.
flux_by_subtype = flux_df.groupby(adata.obs[SUBTYPE_KEY], observed=True).mean().T

# Row-wise z-score: centre and scale each module across subtypes. A zero
# standard deviation is mapped to NaN first so that constant rows become 0
# rather than raising a division warning or producing inf.
flux_z = flux_by_subtype.sub(flux_by_subtype.mean(axis=1), axis=0)
flux_z = flux_z.div(flux_z.std(axis=1, ddof=0).replace(0, np.nan), axis=0).fillna(0.0)

# TOP_K highest z-score modules for every subtype.
modules_per_subtype = {
    st: flux_z[st].sort_values(ascending=False).head(TOP_K).index.tolist()
    for st in flux_z.columns
}

# Union of all per-subtype modules, keeping first-seen order (dict keys are
# ordered and unique, which replaces the manual "not in list" scan).
from itertools import chain
selected_modules = list(dict.fromkeys(chain.from_iterable(modules_per_subtype.values())))

print("Selected modules (union) =", len(selected_modules))
print("Example (first 10):", selected_modules[:10])

In [None]:
# Module annotation table, indexed by its first column.
mod_info = pd.read_csv(MODULE_INFO_PATH, index_col=0)

required_cols = {"Compound_IN_name", "Compound_OUT_name"}
missing = required_cols.difference(mod_info.columns)
if missing:
    raise KeyError(f"Module info CSV is missing columns: {sorted(missing)}")

# Human-readable label: "<module id>: <input compound> → <output compound>".
mod_info["pretty"] = (
    mod_info.index.astype(str)
    + ": "
    + mod_info["Compound_IN_name"].astype(str)
    + " \u2192 "
    + mod_info["Compound_OUT_name"].astype(str)
)

# Z-scores of the selected modules, with columns in order of first appearance
# of each subtype in adata.obs.
flux_sel = flux_z.loc[selected_modules].copy()
subtype_order = [
    st for st in list(pd.unique(adata.obs[SUBTYPE_KEY])) if st in flux_sel.columns
]
flux_sel = flux_sel[subtype_order]

# Use the pretty label where the module is annotated, otherwise the raw module id.
row_labels = [
    mod_info.loc[m, "pretty"] if m in mod_info.index else m
    for m in flux_sel.index
]

flux_sel_labeled = flux_sel.copy()
flux_sel_labeled.index = row_labels

In [None]:
def plot_flux_heatmap(
    df_modules_x_subtypes: pd.DataFrame,
    vmin: float = -2,
    vmax: float =  2,
    figsize=(12, 4),
    out_pdf: Path = OUT_DIR / "scFEA_flux_signature_fibro.pdf",
    out_svg: Path = OUT_DIR / "scFEA_flux_signature_fibro.svg",
):
    """Draw a bordered heatmap (rows = index, columns = columns) and save it.

    The values of ``df_modules_x_subtypes`` are shown with a fixed colour
    range ``[vmin, vmax]``; every cell gets a thin black border and the whole
    matrix a slightly thicker frame. The figure is written to ``out_pdf`` and
    ``out_svg`` and is left open.

    Returns
    -------
    (fig, ax)
        The matplotlib figure and axes that were drawn on.
    """
    flux_cmap = LinearSegmentedColormap.from_list(
        "flux_cmap",
        [
            "#bbd9f2", "#c9e0f4", "#d7e7f7", "#f9fcf7",
            "#e8a8c2", "#e084a9", "#a31515",
        ],
    )

    values = df_modules_x_subtypes.values
    n_rows, n_cols = values.shape
    y_labels = df_modules_x_subtypes.index.tolist()
    x_labels = df_modules_x_subtypes.columns.tolist()

    fig, ax = plt.subplots(figsize=figsize)
    image = ax.imshow(
        values,
        aspect="auto",
        interpolation="nearest",
        cmap=flux_cmap,
        vmin=vmin,
        vmax=vmax,
        origin="upper",
    )
    # Only the image is rasterised; labels and borders stay vector graphics.
    image.set_rasterized(True)
    ax.grid(False)

    # Thin border around every cell (row-major order), then the outer frame.
    for row, col in np.ndindex(n_rows, n_cols):
        cell_border = Rectangle(
            (col - 0.5, row - 0.5), 1.0, 1.0,
            fill=False, edgecolor="black", linewidth=0.8,
        )
        ax.add_patch(cell_border)
    outer_frame = Rectangle(
        (-0.5, -0.5), n_cols, n_rows,
        fill=False, edgecolor="black", linewidth=1.2,
    )
    ax.add_patch(outer_frame)

    ax.set_xticks(np.arange(n_cols))
    ax.set_xticklabels(x_labels, rotation=60, ha="right", va="top", fontsize=9)
    ax.set_yticks(np.arange(n_rows))
    ax.set_yticklabels(y_labels, fontsize=7.5)

    # The rectangles above already frame the plot, so hide the axes spines.
    for spine in ax.spines.values():
        spine.set_visible(False)

    # Fit the view to the cell grid, first row at the top.
    ax.set_xlim(-0.5, n_cols - 0.5)
    ax.set_ylim(n_rows - 0.5, -0.5)
    cbar = plt.colorbar(image, ax=ax, fraction=0.03, pad=0.02)
    cbar.set_label("Predicted metabolic flux (row z-score)", rotation=90)

    fig.tight_layout()
    for target in (out_pdf, out_svg):
        fig.savefig(target, dpi=300, bbox_inches="tight")
    return fig, ax

# Figure height grows with the number of modules, with a minimum of 3 inches.
heatmap_height = max(3, 0.18 * flux_sel_labeled.shape[0])
_ = plot_flux_heatmap(flux_sel_labeled, figsize=(12, heatmap_height))
plt.show()

## Results & exports

This section writes all key artifacts to disk and prints the final output paths.

In [None]:
# Tables: subtype means, row z-scores for all modules, and for the selected ones.
flux_by_subtype.to_csv(OUT_DIR / "flux_mean_by_subtype_modules_x_subtypes.csv")
flux_z.to_csv(OUT_DIR / "flux_row_zscore_modules_x_subtypes.csv")
flux_sel.to_csv(OUT_DIR / "flux_row_zscore_selected_modules_x_subtypes.csv")

# Long-format table: one row per (subtype, ranked module).
rows = [
    {
        "subtype": st,
        "rank": rank,
        "module_id": m,
        "module_pretty": mod_info.loc[m, "pretty"] if m in mod_info.index else m,
        "zscore": float(flux_z.loc[m, st]),
    }
    for st, mods in modules_per_subtype.items()
    for rank, m in enumerate(mods, start=1)
]
top_table = pd.DataFrame(rows)
top_table.to_csv(OUT_DIR / "top_modules_per_subtype.csv", index=False)

# AnnData copy carrying the per-cell flux of the selected modules; the column
# names are stored in .uns next to the raw value array in .obsm.
selected_flux_cell = flux_df[selected_modules].copy()
adata_out = adata.copy()
adata_out.obsm["scfea_flux_selected"] = selected_flux_cell.values
adata_out.uns["scfea_flux_selected_columns"] = selected_flux_cell.columns.tolist()
adata_out.write_h5ad(OUT_DIR / "Fibro_scFEA_matched_cells_with_selected_flux.h5ad")

print("Wrote outputs to:", OUT_DIR.resolve())
key_output_names = (
    "scFEA_flux_signature_fibro.pdf",
    "scFEA_flux_signature_fibro.svg",
    "top_modules_per_subtype.csv",
    "flux_row_zscore_selected_modules_x_subtypes.csv",
    "Fibro_scFEA_matched_cells_with_selected_flux.h5ad",
)
for file_name in key_output_names:
    print(" -", OUT_DIR / file_name)