In [None]:
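# Inputs:
#   - 11282025_PaperRevision_Exp2_Lahi_report.pg_matrix.tsv  (protein x sample intensities; path set by `data_path`)
# Outputs:
#   - exp2_subset_for_psupertime.h5ad (intermediate AnnData passed to pypsupertime)
#   - exp2_rev_psupertime.h5ad (AnnData with psupertime results in .obs)
#   - psupertime_scTransient_spikes_tuned.csv, cluster_gene_lists_qlt0.02.tsv,
#     top_positiveNES_pathways_byTES_leadgenes.tsv (result tables)
#   - ineuron_results.png / .svg, total clusters.svg, gsea-ranked.svg (figures)
#
# Install (if needed):
#   pip install anndata scanpy pandas numpy scipy scikit-learn matplotlib PyWavelets gseapy
#   pip install git+https://github.com/AlexandreHutton/pyPsupertime.git  # updated version of pyPsupertime
#   scTransient must also be installed (windowing, wavelet, TES and permutation utilities)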

import os, re
import numpy as np
import pandas as pd
import anndata as ad
import sklearn
from pypsupertime import Psupertime
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import os
import time
import pickle
import numpy as np
import pandas as pd
import pywt
import matplotlib.pyplot as plt

from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
from sklearn.linear_model import LinearRegression
from typing import Union
from pywt import Wavelet

from scTransient.windowing import ConfinedGaussianWindow
import os
import time
import pickle
import numpy as np
import pandas as pd
import pywt
import matplotlib.pyplot as plt

from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
from sklearn.linear_model import LinearRegression
from scTransient.utils import permutation_dist
from scTransient.metrics import transient_event_score
from scTransient.utils import convert_to_signal
from scTransient.windowing import ConfinedGaussianWindow
from scipy.stats import false_discovery_control
from scTransient.wavelets import WaveletTransform

In [None]:
# Parameters (edit here; everything downstream reads these names)

# --- input / intermediate files ---
data_path = "11282025_PaperRevision_Exp2_Lahi_report.pg_matrix.tsv"
h5ad_path = "exp2_subset_for_psupertime.h5ad"  # AnnData written for pypsupertime to read

# --- protein filtering ---
# A protein row is kept only if it is detected (non-zero) in at least this
# fraction of the kept samples; it filters features, not samples.
min_detect_frac = 0.30

# --- pseudotime windowing ---
ptime_col = "psupertime"  # obs column holding the pseudotime
n_windows = 30            # number of windows the samples are summarized into along pseudotime
sigma = 0.10              # std of the Gaussian window
max_distance = 0.25       # beyond this distance the confined Gaussian window weight is cut to 0

# --- wavelet transform ---
wavelet = "mexh"               # wavelet name passed to pywt.cwt
scales = np.array([2, 4, 8])   # wavelet scales, in units of windows

# --- preprocessing of windowed signals ---
zscore_genes = True       # z-score each gene's abundances across samples
detrend_monotonic = True  # subtract a per-gene linear fit before the wavelet
edge_frac = 0.02          # drop genes whose strongest wavelet response is this close to either end

# --- permutation testing (slow) ---
do_perms = True          # compute permutation p-values for the top genes
n_perms = 10             # number of permutations
topK_for_perms = 5000    # only the top-TES genes up to this count are tested
seed = 1                 # random seed

# --- clustering of trajectories ---
do_cluster = True        # k-means on gene trajectories
n_clusters = 4           # number of clusters
cluster_topN = 5000      # cluster only the top-TES genes; None clusters every gene in res

# --- plotting ---
top_plot_n = 12

# Cache: change this string if you change parameters and want a fresh run
cache_path = "psupertime_transient_wavelet_cache.pkl"

In [None]:
# 1) Load the tab-separated protein x sample matrix; sample metadata is parsed from column names below.
df = pd.read_csv(data_path, sep="\t")

# Intensity columns are recognised by a ".raw" suffix (case-insensitive); adjust if yours differ.
sample_cols = list(filter(lambda name: name.lower().endswith(".raw"), df.columns))
if not sample_cols:
    raise ValueError("No sample columns found ending in '.raw'. Update sample_cols selection.")

def basename_win(path):
    """Return the last component of ``path``, treating both '/' and '\\' as separators.

    Works for Windows-style paths even when running on POSIX, where os.path.basename
    would not split on backslashes.
    """
    normalized = path.replace("\\", "/")
    return normalized.rsplit("/", 1)[-1]

def parse_sample(col):
    """Derive plate-layout metadata from one sample column name.

    The file stem must be a well ID (row A-H, 1-2 digit column), optionally followed by
    "_Pl2" for plate 2. Layout:
      * plate 1: columns 1-12 form four triplets for weeks 8, 6, 4, 2; within each
        triplet the columns are rev, hom, het.
      * plate 2: week 1 only; columns 1-3 are rev, hom, het.

    Args:
        col: sample column name (a path ending in the well stem and an extension).

    Returns:
        dict with keys sample, well, plate, row, col, week, genotype.

    Raises:
        ValueError: if no well can be parsed, or its column number falls outside the layout.
    """
    # Column looks like: D:\...\A1.raw or similar (maybe with _Pl2 before .raw)
    b = os.path.splitext(basename_win(col))[0]     # e.g. "A1" or "A1_Pl2"
    m = re.match(r"([A-H]\d{1,2})(?:_Pl2)?$", b)
    if not m:
        raise ValueError(f"Could not parse well from sample column: {col}")
    well = m.group(1)
    plate = 2 if b.endswith("_Pl2") else 1
    row = well[0]
    colnum = int(well[1:])

    genotype_cycle = ("rev", "hom", "het")
    if plate == 1:
        plate1_weeks = (8, 6, 4, 2)
        # Validate explicitly: column 0 or >12 used to surface as a bare KeyError.
        if not 1 <= colnum <= 3 * len(plate1_weeks):
            raise ValueError(f"Unexpected plate-1 column number {colnum} in sample column: {col}")
        week = plate1_weeks[(colnum - 1) // 3]
        genotype = genotype_cycle[(colnum - 1) % 3]
    else:
        if not 1 <= colnum <= len(genotype_cycle):
            raise ValueError(f"Unexpected plate-2 column number {colnum} in sample column: {col}")
        week = 1
        genotype = genotype_cycle[colnum - 1]

    return dict(sample=b, well=well, plate=plate, row=row, col=colnum, week=int(week), genotype=genotype)

# Per-sample metadata table, indexed by the original column name.
obs = pd.DataFrame([parse_sample(c) for c in sample_cols], index=sample_cols)


# 2) Pick the subset you want trajectories for
keep_mask = obs["genotype"] == "rev"

obs_sub = obs[keep_mask].copy()
sample_cols_sub = list(obs_sub.index)
n_kept = len(sample_cols_sub)
if n_kept < 5:
    raise ValueError(f"Too few samples kept ({n_kept}). Check keep_mask.")


# 3) Build expression matrix (samples x genes) for AnnData
# Gene label = first symbol of the ";"-separated "Genes" field; rows without one are dropped.
gene = (
    df["Genes"]
    .fillna("")
    .astype(str)
    .str.split(";")
    .str[0]
    .str.strip()
)
mask_gene = gene.ne("")  # remove blank genes
df2 = df.loc[mask_gene].copy()
df2["gene"] = gene.loc[mask_gene].values

# Pull intensities and do simple proteomics-friendly transforms (0 is treated as "not detected")
X = df2[sample_cols_sub].replace(0, np.nan).astype(float)

# Filter to reduce missingness: keep protein rows detected in at least `min_detect_frac`
# of the kept samples. The threshold is taken from the Parameters cell -- do not
# re-assign it here, or edits made there are silently ignored.
keep_feat = (X.notna().mean(axis=1) >= min_detect_frac)
X = X.loc[keep_feat]
df2 = df2.loc[keep_feat]

# Log2 transform with a small offset (smallest positive intensity) so values stay finite
vals = X.values
min_pos = np.nanmin(vals[vals > 0]) if np.any(vals > 0) else 1.0
X_log = np.log2(X + min_pos)

# Per-sample median centering
X_norm = X_log.sub(X_log.median(axis=0), axis=1)

# Collapse to gene-level (median across protein rows per gene)
X_gene = X_norm.assign(gene=df2["gene"]).groupby("gene").median()   # genes x samples

# Samples x genes for AnnData; remaining gaps are filled with that gene's median across samples
M = X_gene.T
M = M.apply(lambda col: col.fillna(col.median()), axis=0)

adata = ad.AnnData(
    X=M.values,
    obs=obs_sub.loc[M.index].copy(),
    var=pd.DataFrame(index=M.columns)
)

# pypsupertime expects a numeric ordinal label in .obs
# Use "week" as the ordinal time label
adata.obs["time"] = adata.obs["week"].astype(int)


# 4) Run pypsupertime
# Psupertime.run() is given a file path, so the AnnData is written to disk first;
# "time" names the obs column holding the ordinal label.
adata.write_h5ad(h5ad_path)
p = Psupertime()
adata_out = p.run(h5ad_path, "time")

# 5) Inspect / export results
obs_cols = list(adata_out.obs.columns)
print("obs columns:", obs_cols)
print("var columns:", list(adata_out.var.columns))

# Candidate pseudotime columns: names containing "super"/"pseudo", or exactly "time" (case-insensitive).
likely = [
    c for c in obs_cols
    if any(key in c.lower() for key in ("super", "pseudo")) or c.lower() == "time"
]
print("likely pseudotime columns:", likely)

# Save output
out_path = "exp2_rev_psupertime.h5ad"
adata_out.write_h5ad(out_path)
print("Wrote:", out_path)


In [None]:
# Grid-search diagnostic plot from the fitted Psupertime object `p`.
p.plot_grid_search(title="Grid Search")

In [None]:
# Model-performance plot on (adata_out.X, ordinal "time" labels).
# NOTE(review): these are the same samples the model was fitted on (p.run used this data),
# so this is presumably an in-sample view -- confirm before reporting it as performance.
p.plot_model_perf((adata_out.X, adata_out.obs.time), figsize=(6,5))

In [None]:
# Quick 2x2 panel: labels over psupertime (top-right) and top gene coefficients (bottom-right).
fig, ax = plt.subplots(nrows=2, ncols=2)
_ = p.plot_identified_gene_coefficients(adata_out, n_top=6, ax=ax[1, 1])
_ = p.plot_labels_over_psupertime(adata_out, "time", ax=ax[0, 1])

In [None]:
# obs columns whose name mentions "time" (case-insensitive)
time_like_cols = [name for name in adata_out.obs.columns if "time" in name.lower()]
print(time_like_cols)


In [None]:
# genes to inspect
genes = [
    "DCC", "TH", "CAMK2A", "DYRK1A",
    "SRGAP1", "SEPTIN8", "HPCAL4"
]

# The pseudotime column is `ptime_col` from the Parameters cell, so this plot and
# run_or_load() always read the same column.

# ensure genes exist after psupertime filtering
genes = [g for g in genes if g in adata_out.var_names]
print("Plotting:", genes)

if not genes:
    # plt.subplots(0, 1) would raise; report the empty selection instead
    print("None of the requested genes are in adata_out.var_names; nothing to plot.")
else:
    # dense copy so the row/column indexing below also works if adata_out.X is sparse
    X_dense = adata_out.X.toarray() if hasattr(adata_out.X, "toarray") else np.asarray(adata_out.X)

    # sort cells by pseudotime
    order = np.argsort(adata_out.obs[ptime_col].values)
    ptime = adata_out.obs[ptime_col].values[order]

    n = len(genes)
    # squeeze=False always yields a 2-D grid of axes, so n == 1 needs no special case
    fig, axes = plt.subplots(n, 1, figsize=(6, 1.8*n), sharex=True, squeeze=False)
    axes = axes[:, 0]

    for ax, g in zip(axes, genes):
        y = X_dense[order, adata_out.var_names.get_loc(g)]
        ax.plot(ptime, y, marker="o", linestyle="-", alpha=0.7)
        ax.set_ylabel(g)
        ax.axhline(0, color="gray", lw=0.5, ls="--")

    axes[-1].set_xlabel(ptime_col)
    plt.tight_layout()
    plt.show()


In [None]:
# # For visualizing the wavelet - in case you're unsure about what the wavelets look like with data.

# def ricker(center, pos, scale):
#     return (2/(np.sqrt(3*scale)*np.power(np.pi,1/4))) * (1 - ((pos-center)/scale)**2) * (np.exp((-(pos-center)**2)/(2*scale**2)))

# x = np.linspace(-6,6,1024)
# wlet = ricker(0, x, scale=2)
# plt.plot(x, wlet)

In [None]:
# # Here, "num_samples" corresponds to the number of windows you would use for converting samples to a signal. 
# # This affects what the scale means for the wavelet.
# num_samples = 40  
# # For visualizing the wavelet of different scales with different window counts.

# for num_samples in [20,30,40]:
#     plt.figure()
#     x = np.linspace(0,num_samples-1,num_samples)
#     c = np.mean(x)
#     # scales = [2,4,8,16]
#     for s in scales:
#         plt.plot(x, ricker(c, x, s), ".-", label=f"{s}")
#     plt.legend(title="Scales")
#     plt.title(f"Wavelets of different scales (n={num_samples})")
#     plt.ylabel("Wavelet value")

In [None]:
# Pseudotime rescaled to [0, 1] (original sample order). Uses the configured column and the
# same epsilon as run_or_load so a zero range gives 0 instead of NaN/inf.
ptime = adata_out.obs[ptime_col]
ptime = (ptime - np.min(ptime)) / (np.max(ptime) - np.min(ptime) + 1e-12)

In [None]:
def _to_dense(X):
    try:
        import scipy.sparse as sp
        if sp.issparse(X):
            return X.toarray()
    except Exception:
        pass
    return np.asarray(X)


def _make_params_dict():
    """Snapshot the Parameters-cell settings that determine run_or_load's results.

    The snapshot is pickled alongside cached results and compared on reload. Numeric
    and boolean settings are cast to plain Python types, and ``scales`` becomes a list
    of ints.
    """
    top_n = cluster_topN
    return dict(
        ptime_col=ptime_col,
        n_windows=n_windows,
        sigma=sigma,
        max_distance=max_distance,
        wavelet=wavelet,
        scales=scales.astype(int).tolist(),
        zscore_genes=bool(zscore_genes),
        detrend_monotonic=bool(detrend_monotonic),
        edge_frac=float(edge_frac),
        do_perms=bool(do_perms),
        n_perms=int(n_perms),
        topK_for_perms=int(topK_for_perms),
        seed=int(seed),
        do_cluster=bool(do_cluster),
        n_clusters=int(n_clusters),
        cluster_topN=(int(top_n) if top_n is not None else None),
        top_plot_n=int(top_plot_n),
    )


def run_or_load(cache_path: str = None):
    """Compute (or load from a pickle cache) per-gene transient-event scores along pseudotime.

    Pipeline:
      1. order samples by ``adata_out.obs[ptime_col]`` and rescale pseudotime to [0, 1];
      2. optionally z-score each gene across samples;
      3. summarize samples into ``n_windows`` windowed signals with a confined Gaussian window;
      4. per gene: optionally remove a linear trend, run ``pywt.cwt`` at ``scales`` and score it
         with ``transient_event_score``. Genes whose strongest coefficient lies within
         ``edge_frac`` of either end are dropped;
      5. optionally add permutation p-values (``p_perm``) and BH q-values (``q_perm_bh``) for the
         top ``topK_for_perms`` genes;
      6. optionally attach k-means cluster labels to the top ``cluster_topN`` genes.

    Reads the module-level ``adata_out`` and the Parameters-cell settings.

    Args:
        cache_path: pickle file to load from / save to; ``None`` disables caching.

    Returns:
        dict with keys ``res`` (DataFrame sorted by TES, descending), ``signals``
        (n_windows x n_genes), ``centers``, ``ptime`` (sorted, rescaled), ``var_names_arr``,
        ``gene_to_idx``, ``cluster_info`` (None when clustering is off) and ``params``.

    Raises:
        KeyError: if ``ptime_col`` is not a column of ``adata_out.obs``.
    """
    params = _make_params_dict()

    if cache_path is not None and os.path.exists(cache_path):
        print(f"[cache] Found {cache_path}, loading...")
        # pickle.load can execute arbitrary code: only point cache_path at files this notebook wrote.
        with open(cache_path, "rb") as f:
            out = pickle.load(f)

        old = out.get("params", {})
        if old != params:
            print("[cache] Cache params differ from current settings.")
            print("        Delete/rename cache_path if you want a full recompute.")
            print("        Using cached results anyway.")
        print("[cache] Loaded keys:", list(out.keys()))
        return out

    t0 = time.time()

    # 1) Pull ordered data
    print("[1/8] Pulling psupertime and ordering cells...")
    if ptime_col not in adata_out.obs.columns:
        raise KeyError(f"{ptime_col} not in adata_out.obs. Available: {list(adata_out.obs.columns)}")

    ptime = adata_out.obs[ptime_col].to_numpy().astype(float)
    order = np.argsort(ptime)
    ptime = ptime[order]
    # normalize pseudotime range to 0-1; the epsilon avoids 0/0 when all pseudotimes are equal
    ptime = (ptime - ptime.min()) / (ptime.max() - ptime.min() + 1e-12)

    print("[2/8] Loading X and var names...")
    X = _to_dense(adata_out.X)[order, :].astype(float)

    var_names_arr = np.array([str(x) for x in adata_out.var_names], dtype=str)
    gene_to_idx = {g: i for i, g in enumerate(var_names_arr)}

    print(f"       cells: {X.shape[0]:,}   genes: {X.shape[1]:,}")

    if zscore_genes:
        print("[3/8] Z-scoring per gene (across cells)...")
        X = (X - X.mean(axis=0, keepdims=True)) / (X.std(axis=0, keepdims=True) + 1e-12)
    else:
        print("[3/8] Skipping z-scoring (zscore_genes=False)")

    # 2) Confined Gaussian windowing
    print("[4/8] Building window weights and computing windowed signals (matmul)...")
    conf_gauss_window = ConfinedGaussianWindow(n_windows=params["n_windows"],
                                               sigma=params["sigma"],
                                               max_distance=params["max_distance"],
                                               signal_domain=(np.min(ptime), np.max(ptime)))
    signals = convert_to_signal(values=X, positions=ptime, window=conf_gauss_window)
    signals = signals.T  # downstream indexes signals[:, gene], i.e. (n_windows, n_genes)

    # 3) Detrend + Wavelet + TES
    print("[5/8] Computing TES (per gene)...")
    n_genes = signals.shape[1]
    tes = np.zeros(n_genes, dtype=float)
    peak_scale = np.zeros(n_genes, dtype=float)
    peak_t = np.zeros(n_genes, dtype=float)
    peak_sign = np.zeros(n_genes, dtype=float)
    centers = conf_gauss_window.window_centers
    t_feat = centers.reshape(-1, 1)
    lr = LinearRegression()

    step = max(1, n_genes // 20)
    for j in range(n_genes):
        if j % step == 0:
            print(f"       wavelets: {j:,}/{n_genes:,}")

        sig = signals[:, j]
        if detrend_monotonic:  # shouldn't matter for wavelets
            lr.fit(t_feat, sig)
            sig = sig - lr.predict(t_feat)

        coefs, _ = pywt.cwt(sig, scales=scales, wavelet=wavelet)
        tes[j] = transient_event_score(coefs)

        # Location of the strongest |coefficient|: window index rescaled to [0, 1].
        # NOTE(review): equals the window center only if centers evenly span [0, 1].
        s_idx, t_idx = np.unravel_index(np.argmax(np.abs(coefs)), coefs.shape)
        pt = float(t_idx / (n_windows - 1))
        peak_t[j] = pt
        peak_scale[j] = float(scales[s_idx])
        peak_sign[j] = float(np.sign(coefs[s_idx, t_idx]))

        # Peaks at the very ends are dominated by boundary effects; drop those genes.
        if pt < edge_frac or pt > (1.0 - edge_frac):
            tes[j] = np.nan

    res = pd.DataFrame({
        "gene": var_names_arr,
        "TES": tes,
        "peak_pseudotime": peak_t,
        "peak_scale": peak_scale,
        "peak_sign": peak_sign,
    }).dropna(subset=["TES"]).sort_values("TES", ascending=False).reset_index(drop=True)

    # 4) Two-stage permutation p-values (topK only)
    if do_perms:
        print("[6/8] Two-stage permutations (topK only)...")
        # permutation_dist is not given a seed here; seed numpy's global RNG so reruns are
        # reproducible if it draws from it. NOTE(review): confirm against scTransient.
        np.random.seed(seed)

        # Always create both columns so later cells can rely on them existing.
        res["p_perm"] = np.nan
        res["q_perm_bh"] = np.nan

        topK = min(topK_for_perms, len(res))
        if topK > 0:
            genes_test = res["gene"].head(topK).astype(str).to_numpy()
            idx_test = np.array([gene_to_idx[g] for g in genes_test], dtype=int)

            X_test = X[:, idx_test]  # (cells, topK)
            wt = WaveletTransform(wavelet=params["wavelet"],
                                  scales=params["scales"])
            # Use a separate name: the scores returned here cover only the topK subset,
            # and must not overwrite the full-gene `tes` array.
            _tes_perm, pvals, _ = permutation_dist(values=X_test,
                                                   positions=ptime,
                                                   n_permutations=n_perms,
                                                   wavelet_transform=wt)

            res["p_perm"] = res["gene"].astype(str).map(dict(zip(genes_test, pvals)))

            tested_mask = res["p_perm"].notna()
            if tested_mask.any():
                res.loc[tested_mask, "q_perm_bh"] = false_discovery_control(
                    res.loc[tested_mask, "p_perm"].to_numpy(dtype=float), method="bh")
    else:
        print("[6/8] Skipping permutations (do_perms=False)")

    # 5) Cluster trajectories by shape
    cluster_info = None
    if do_cluster:
        print("[7/8] Clustering trajectories (k-means)...")
        if cluster_topN is None:
            use_genes = res["gene"].astype(str).to_numpy()
        else:
            use_genes = res.head(min(cluster_topN, len(res)))["gene"].astype(str).to_numpy()

        use_idx = np.array([gene_to_idx[g] for g in use_genes], dtype=int)

        S = signals[:, use_idx].T  # (n_use_genes, n_windows)

        if detrend_monotonic:
            Sd = np.zeros_like(S)
            for i in range(S.shape[0]):
                lr.fit(t_feat, S[i])
                Sd[i] = S[i] - lr.predict(t_feat)
            S = Sd

        # NOTE(review): StandardScaler standardizes each column, i.e. each *window* across
        # genes, not each gene's trajectory. If the goal is clustering by shape only,
        # row-wise z-scoring may be intended -- confirm before changing (it alters clusters).
        S = StandardScaler(with_mean=True, with_std=True).fit_transform(S)
        labels = KMeans(n_clusters=n_clusters, random_state=0, n_init=20).fit_predict(S)

        cl_map = dict(zip(use_genes, labels))
        res["cluster"] = res["gene"].astype(str).map(cl_map)

        cluster_info = {
            "use_genes": use_genes,
            "labels": labels,
            "n_clusters": n_clusters,
        }
    else:
        print("[7/8] Skipping clustering (do_cluster=False)")

    # Save cache
    out = {
        "res": res,
        "signals": signals,
        "centers": centers,
        "ptime": ptime,
        "var_names_arr": var_names_arr,
        "gene_to_idx": gene_to_idx,
        "cluster_info": cluster_info,
        "params": params,
    }
    if cache_path is not None:
        print(f"[8/8] Saving cache to: {cache_path}")
        with open(cache_path, "wb") as f:
            pickle.dump(out, f)

    print(f"[done] total time: {(time.time() - t0):.1f} sec")
    return out

In [None]:
# Full recompute (no cache). Use run_or_load(cache_path) instead to reuse saved results.
out = run_or_load(None)

res, signals, centers, gene_to_idx = (out[key] for key in ("res", "signals", "centers", "gene_to_idx"))

In [None]:
# Assemble the 3x2 summary figure: psupertime diagnostics in the top row, cluster-average
# and top-TES trajectories in the middle row. Later cells draw into ax[2,0] / ax[2,1]
# and re-display `fig`.
fig, ax = plt.subplots(3,2)
# Copy the ordinal label under the name "Week" -- presumably so the plot's legend/axis
# reads "Week" (TODO confirm how plot_labels_over_psupertime uses the column name).
adata_out.obs["Week"] = adata_out.obs["time"]
_ = p.plot_identified_gene_coefficients(adata_out, n_top=6, ax=ax[0,1])
_ = p.plot_labels_over_psupertime(adata_out, "Week", ax=ax[0,0])
ax[0,1].set_ylabel("Gene")

# PLOT cluster-average trajectories (if cluster exists)
# Each curve is the mean windowed signal of a cluster's genes (detrended when configured).
cluster_ax = ax[1,0]
if "cluster" in res.columns and res["cluster"].notna().any():
    print("[plot] Cluster-average trajectories...")
    # plt.figure(figsize=(7, 4))
    for c in sorted(res["cluster"].dropna().unique()):
        genes_c = res.loc[res["cluster"] == c, "gene"].astype(str).to_list()
        idx_c = np.array([gene_to_idx[g] for g in genes_c], dtype=int)
        traj = signals[:, idx_c].mean(axis=1)

        if detrend_monotonic:
            # Detrending the mean equals the mean of per-gene detrended signals (linear fit).
            lr = LinearRegression()
            t_feat = centers.reshape(-1, 1)
            lr.fit(t_feat, traj)
            traj = traj - lr.predict(t_feat)

        cluster_ax.plot(centers, traj, label=f"{int(c)} (n={len(idx_c)})", linewidth=2.5)

    cluster_ax.set_xlabel("psupertime (window centers)")
    cluster_ax.set_ylabel("Mean windowed trajectory")
    cluster_ax.set_title("Cluster-average trajectories")
    cluster_ax.legend(title="Cluster", loc="lower right")
    # cluster_ax.tight_layout()
    # cluster_ax.show()

# PLOT top spike-like genes
# "Top" = first top_plot_n rows of res, which run_or_load sorts by TES (descending).
print("[plot] Top spike-like genes...")
# res.head(top_plot_n)["Gene"] = res.head(top_plot_n)["gene"]
top_genes = res.head(top_plot_n)["gene"].astype(str).tolist()

# plt.figure(figsize=(8, 5))
spike_gene_ax = ax[1,1]
for g in top_genes:
    j = gene_to_idx.get(g, None)
    if j is None:
        print("[plot] missing gene:", repr(g))
        continue

    sig = signals[:, j].copy()
    if detrend_monotonic:
        lr = LinearRegression()
        t_feat = centers.reshape(-1, 1)
        lr.fit(t_feat, sig)
        sig = sig - lr.predict(t_feat)

    spike_gene_ax.plot(centers, sig, label=g, alpha=0.95, linewidth=2.5, marker=".")

spike_gene_ax.set_xlabel("Psupertime (window centers)")
spike_gene_ax.set_ylabel("Windowed abundance (detrended)" if detrend_monotonic else "windowed abundance")
spike_gene_ax.set_title("Top spike-like genes")
spike_gene_ax.legend(loc="center left", bbox_to_anchor=(1, 0.5), fontsize=8)
# plt.tight_layout()
# plt.show()
fig.set_size_inches(10,8)
fig.tight_layout()
print(res.head(30))


In [None]:
# Re-render the shared summary figure `fig`.
display(fig)

In [None]:
# Cluster-average trajectories (show n, use median to be robust)
cluster_info = out["cluster_info"]
if cluster_info is None:
    raise ValueError("out['cluster_info'] is None: rerun run_or_load() with do_cluster=True to plot clusters.")
labels = np.asarray(cluster_info["labels"])
use_genes = np.asarray(cluster_info["use_genes"])
use_idx_map = out["gene_to_idx"]
signals = out["signals"]

# Build the detrending regressor here rather than relying on `lr` / `t_feat` left over
# from another cell (they only exist there when that cell's detrend branch ran).
lr = LinearRegression()
t_feat = centers.reshape(-1, 1)

plt.figure(figsize=(5,4))
for c in sorted(np.unique(labels)):
    g_c = use_genes[labels == c]
    idx_c = np.array([use_idx_map[g] for g in g_c], dtype=int)
    # Named M_c so the samples x genes expression matrix `M` is not clobbered.
    M_c = signals[:, idx_c]  # (n_windows, n_genes_in_cluster)

    if detrend_monotonic:
        # detrend each gene (column) separately
        M_d = np.zeros_like(M_c)
        for k in range(M_c.shape[1]):
            lr.fit(t_feat, M_c[:, k])
            M_d[:, k] = M_c[:, k] - lr.predict(t_feat)
        M_c = M_d

    med = np.median(M_c, axis=1)
    plt.plot(centers, med, label=f"cluster {c} (n={len(g_c)})", linewidth=3)

plt.xlabel("psupertime (window centers)")
plt.ylabel("median windowed trajectory")
plt.title("Cluster-average trajectories (median)")
plt.legend(loc="center left", bbox_to_anchor=(1, 0.5), fontsize=8)
plt.tight_layout()
plt.savefig('total clusters.svg', bbox_inches='tight')
plt.show()


In [None]:
from IPython.display import display

In [None]:
# Re-render the shared summary figure `fig`.
display(fig)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Representative transient genes to overlay in panel ax[2,0].
genes_to_plot = ["CALB1", "NEFH", "GPRIN3", "GRID1"]

# ALWAYS resolve from adata_out, never from a reused variable
var_names = np.array(adata_out.var_names, dtype=str)

def find_gene_index(target, var_names):
    """Find ``target`` in ``var_names`` using progressively looser matching.

    Strategies, in order: exact match, case-insensitive exact match, case-insensitive
    substring match. The first hit of the first strategy that succeeds is returned.

    Args:
        target: gene symbol to look for.
        var_names: numpy array of str gene names.

    Returns:
        ``(index, matched_name, how)``. When nothing matches, ``(None, None, "missing")``.
    """
    exact = np.where(var_names == target)[0]
    if exact.size > 0:
        return int(exact[0]), target, "exact"

    lowered = np.char.lower(var_names)
    needle = target.lower()

    ci_exact = np.where(lowered == needle)[0]
    if ci_exact.size > 0:
        first = ci_exact[0]
        return int(first), var_names[first], "case-insensitive exact"

    partial = np.where(np.char.find(lowered, needle) >= 0)[0]
    if partial.size > 0:
        first = partial[0]
        return int(first), var_names[first], f"contains (picked 1 of {partial.size})"

    return None, None, "missing"

# Resolve gene indices with logging
resolved = []

print("Resolving genes against adata_out.var_names:")
for g in genes_to_plot:
    idx, matched_name, how = find_gene_index(g, var_names)
    print(f"  {g:8s} -> {how}" + (f" : {matched_name}" if matched_name else ""))
    if idx is not None:
        resolved.append((g, idx, matched_name, how))

if len(resolved) == 0:
    raise ValueError(
        "None of the requested genes were found in adata_out.var_names. "
        "Check gene naming (symbols vs IDs)."
    )

# Detrending regressor defined here instead of relying on `lr` / `t_feat` left over
# from another cell.
lr = LinearRegression()
t_feat = centers.reshape(-1, 1)

# Plot trajectories into the shared summary figure.
# Indices are positions in adata_out.var_names, which run_or_load also uses for the
# column order of `signals`.
rep_transient_ax = ax[2,0]
for g, j, matched_name, how in resolved:
    sig = signals[:, j].copy()

    # detrend if used upstream
    if detrend_monotonic:
        lr.fit(t_feat, sig)
        sig = sig - lr.predict(t_feat)

    label = matched_name if matched_name is not None else g
    rep_transient_ax.plot(
        centers,
        sig,
        linewidth=2.8,
        alpha=1.0,
        label=label
    )

rep_transient_ax.set_xlabel("psupertime (window centers)")
rep_transient_ax.set_ylabel("windowed abundance (detrended)" if detrend_monotonic else "windowed abundance")
rep_transient_ax.set_title("Representative transient trajectories")
rep_transient_ax.legend(frameon=False, fontsize=8)
display(fig)


In [None]:
# Export significant transient genes (BH q < 0.02), with their cluster when clustering ran.
if "q_perm_bh" not in res.columns:
    raise KeyError("res has no 'q_perm_bh' column; rerun run_or_load() with do_perms=True.")
sig = res.loc[res["q_perm_bh"] < 0.02].copy()

# "cluster" only exists when do_cluster=True; export gene names alone otherwise.
export_cols = [col for col in ("cluster", "gene") if col in sig.columns]
sig[export_cols].to_csv(
    "cluster_gene_lists_qlt0.02.tsv",
    sep="\t",
    index=False
)


In [None]:
# Rows whose gene symbol contains "GPM" (case-insensitive; missing symbols count as no match).
gpm_mask = res["gene"].str.contains("GPM", case=False, na=False)
res_gpm = res.loc[gpm_mask]
print(res_gpm)


In [None]:
# Export the full TES results table (one row per gene kept by run_or_load).
res.to_csv("psupertime_scTransient_spikes_tuned.csv", index=False)


In [None]:
import numpy as np
import matplotlib.pyplot as plt

gene = "NEFH"
nefh_raw_ax = ax[2,1]
# find gene index
if gene not in adata_out.var_names:
    raise ValueError(f"{gene} not found in adata_out.var_names")

g_idx = np.where(adata_out.var_names == gene)[0][0]

# extract data: un-rescaled pseudotime and this gene's stored values, in sample order
ptime = adata_out.obs[ptime_col].to_numpy()
expr = adata_out.X[:, g_idx]

# handle sparse matrix: only sparse slices have .toarray(); dense rows go through np.asarray
expr = expr.toarray().ravel() if hasattr(expr, "toarray") else np.asarray(expr).ravel()

nefh_raw_ax.scatter(ptime, expr, s=40, alpha=0.7)
nefh_raw_ax.set_xlabel(ptime_col)
# Labels follow `gene`, so changing the gene above keeps axis and title in sync.
nefh_raw_ax.set_ylabel(f"{gene} abundance")
nefh_raw_ax.set_title(f"{gene} raw values over psupertime")
display(fig)


In [None]:
# Titles for the two psupertime panels in the top row of the shared figure.
ax[0,0].set_title("Distribution of labels across pseudotime")
ax[0,1].set_title("Genes predictive of pseudotime")

In [None]:
# Re-render the shared summary figure `fig`.
display(fig)

In [None]:
# Save the summary figure as raster (PNG) and vector (SVG).
for ext in ("png", "svg"):
    fig.savefig(f"ineuron_results.{ext}", bbox_inches="tight")

In [None]:
# Significant genes (BH q < 0.02), strongest TES first; ties broken by larger peak scale.
hits = (
    res[res["q_perm_bh"] < 0.02]
    .copy()
    .sort_values(by=["TES", "peak_scale"], ascending=[False, False])
)


In [None]:
import gseapy as gp

# Preranked GSEA on the significant genes, ranked by TES.
# NOTE(review): the ranking contains only the q < 0.02 hits rather than every scored gene;
# confirm this restricted background is intended.
ranked = hits.loc[:, ["gene", "TES"]].dropna()

prerank_opts = dict(
    gene_sets="GO_Biological_Process_2023",
    min_size=10,
    max_size=300,
    permutation_num=1000,
    outdir=None,
    seed=1,
)
pre_res = gp.prerank(rnk=ranked, **prerank_opts)

pre_res.res2d.sort_values("NES", ascending=False).head(20)


In [None]:
# Top 20 GSEA terms by NES (highest first).
pre_res.res2d.sort_values("NES", ascending=False).head(20)


In [None]:
# Top 20 GSEA terms by FDR q-value (most significant first).
pre_res.res2d.sort_values("FDR q-val", ascending=True).head(20)


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# 0) Grab the GSEA results table
gsea_df = pre_res.res2d.copy()

# Sanity: make sure expected columns exist
required_cols = ["Term", "NES", "FDR q-val", "Lead_genes"]
missing = [c for c in required_cols if c not in gsea_df.columns]
if missing:
    raise KeyError(f"Missing columns in pre_res.res2d: {missing}. Found: {list(gsea_df.columns)}")

# gseapy may hold NES / FDR as object dtype; coerce to numbers so sorting is numeric and
# float formatting of these values works (unparsable entries become NaN).
for num_col in ("NES", "FDR q-val"):
    gsea_df[num_col] = pd.to_numeric(gsea_df[num_col], errors="coerce")

# 1) TES-ranked gene list
tes_ranked = res.sort_values("TES", ascending=False)["gene"].astype(str).values
tes_rank_index = {g: i for i, g in enumerate(tes_ranked)}
N = len(tes_ranked)

# 2) Helpers
def parse_lead_genes(s):
    """Split a gseapy ``Lead_genes`` cell into a list of gene symbols.

    ';' is used as the separator when present, otherwise ','. Surrounding whitespace
    and empty pieces are dropped, and missing values (NaN/None) give ``[]``.
    """
    if pd.isna(s):
        return []
    text = str(s).strip()

    # gseapy can use ";", sometimes ","
    sep = ";" if ";" in text else ("," if "," in text else None)
    pieces = text.split(sep) if sep is not None else [text]

    return [piece.strip() for piece in pieces if piece.strip()]

def cumulative_fraction_curve(genes, rank_index, N):
    """Cumulative fraction of ``genes`` found at or above each position of a ranked list.

    Args:
        genes: gene names. Names absent from ``rank_index`` are ignored; repeated names
            are counted each time.
        rank_index: mapping gene -> 0-based rank in the ranked list.
        N: length of the ranked list.

    Returns:
        ``(x, y, n_hit)``:
          * ``x``: ranks ``0..N-1`` rescaled to [0, 1];
          * ``y[i]``: fraction of matched genes with rank <= i;
          * ``n_hit``: number of matched genes.
        Returns ``(None, None, 0)`` when no gene matches.
    """
    idx = [rank_index[g] for g in genes if g in rank_index]
    if not idx:
        return None, None, 0

    # Count hits per rank, then cumulative-sum: O(N + k) instead of an O(N * k) slice update.
    # Truncate to N so out-of-range ranks are ignored.
    hits_per_rank = np.bincount(idx, minlength=N)[:N].astype(float)
    y = np.cumsum(hits_per_rank) / len(idx)

    # Guard N == 1 (single-gene ranking) against 0 / 0.
    x = np.arange(N) / max(N - 1, 1)
    return x, y, len(idx)

def make_unique_labels(df, base_prefix):
    """Build ``"<base_prefix>: <Term>"`` labels for each row of ``df``.

    Repeated terms get a ``" #k"`` suffix (k = occurrence count, starting at 2), so
    legend entries stay distinguishable.
    """
    seen = {}
    out_labels = []
    for term in df["Term"].astype(str).tolist():
        seen[term] = seen.get(term, 0) + 1
        occurrence = seen[term]
        tail = f" #{occurrence}" if occurrence > 1 else ""
        out_labels.append(f"{base_prefix}: {term}{tail}")
    return out_labels

# 3) Pick pathways: top 4 by descending NES, and top 4 by ascending FDR q-val
#    Then union (dedupe) while keeping a stable order.
top_nes = (
    gsea_df.dropna(subset=["NES"])
    .sort_values("NES", ascending=False)
    .head(4)
    .copy()
)
top_nes["plot_label"] = make_unique_labels(top_nes, "Top NES")

top_fdr = (
    gsea_df.dropna(subset=["FDR q-val"])
    .sort_values("FDR q-val", ascending=True)
    .head(4)
    .copy()
)
top_fdr["plot_label"] = make_unique_labels(top_fdr, "Top FDR")

# Union and de-dupe by Term, preferring the first occurrence (keeps NES-ranked choices first)
to_plot = pd.concat([top_nes, top_fdr], axis=0, ignore_index=True)
to_plot = to_plot.drop_duplicates(subset=["Term"], keep="first").reset_index(drop=True)

print("Plotting these terms:")
print(to_plot[["Term", "NES", "FDR q-val"]])

# 4) Plot cumulative recovery curves
# For each term: the fraction of its lead genes found within the top x of the full TES
# ranking (tes_rank_index covers every gene in res, not only the GSEA input).
plt.figure(figsize=(6, 5))

for _, row in to_plot.iterrows():
    term = str(row["Term"])
    label_prefix = str(row["plot_label"])

    lead = parse_lead_genes(row.get("Lead_genes", np.nan))
    x, y, n_hit = cumulative_fraction_curve(lead, tes_rank_index, N)
    if x is None:
        print(f"Skipping '{term}': 0 lead genes found in TES-ranked list")
        continue

    nes = row["NES"]
    fdr = row["FDR q-val"]
    # NES / FDR must be numeric here for the :.2f / :.3g formats.
    lab = f"{label_prefix} (n={n_hit}, NES={nes:.2f}, FDR={fdr:.3g})"

    plt.plot(x, y, label=lab, linewidth=2)

# diagonal reference: genes spread uniformly over the ranking would follow this line
plt.plot([0, 1], [0, 1], linestyle="--", color="gray", alpha=0.6)

plt.xlabel("Fraction of genes (TES-ranked)")
plt.ylabel("Fraction of pathway lead genes recovered")
plt.title("Cumulative recovery of GSEA lead genes across TES ranking")
plt.legend(fontsize=7, loc="center left", bbox_to_anchor=(1, 0.5))
plt.tight_layout()
plt.savefig('gsea-ranked.svg', bbox_inches='tight')
plt.show()


In [None]:
import numpy as np
import pandas as pd

# -------- settings --------
K_PATHS = 25                 # export this many positive-NES pathways
MIN_NES = 0                  # only positive NES
USE_Q_FILTER = True          # restrict genes to res[q_perm_bh < 0.02]
Q_THRESH = 0.02
TOPN_GENES_PER_PATH = 50     # keep only top N TES genes per pathway (after intersection)
OUTFILE = "top_positiveNES_pathways_byTES_leadgenes.tsv"

# -------- inputs --------
gsea_df = pre_res.res2d.copy()
res_df = res.copy()

# sanity checks
need_gsea = ["Term", "NES", "FDR q-val", "Lead_genes"]
missing = [c for c in need_gsea if c not in gsea_df.columns]
if missing:
    raise KeyError(f"Missing columns in pre_res.res2d: {missing}")

# gseapy may hold NES / FDR as object dtype; make them numeric so the NES > MIN_NES
# comparison and the NES sort compare numbers (unparsable entries become NaN).
for num_col in ("NES", "FDR q-val"):
    gsea_df[num_col] = pd.to_numeric(gsea_df[num_col], errors="coerce")

if "gene" not in res_df.columns or "TES" not in res_df.columns:
    raise KeyError("`res` must have columns: gene, TES")

# optional q filter
if USE_Q_FILTER:
    if "q_perm_bh" not in res_df.columns:
        raise KeyError("USE_Q_FILTER=True but res has no column 'q_perm_bh'")
    res_df = res_df[res_df["q_perm_bh"].notna() & (res_df["q_perm_bh"] < Q_THRESH)].copy()

# make lookup: gene -> TES, q
res_df["gene"] = res_df["gene"].astype(str)
tes_map = dict(zip(res_df["gene"], res_df["TES"]))
q_map = dict(zip(res_df["gene"], res_df["q_perm_bh"])) if "q_perm_bh" in res_df.columns else {}

# helpers
def parse_lead_genes(s):
    """Split a gseapy ``Lead_genes`` cell into gene symbols.

    ';' is used as the separator when present, otherwise ','. Whitespace and empty
    pieces are dropped. NaN/None give ``[]``.
    """
    if pd.isna(s):
        return []
    text = str(s).strip()
    if ";" in text:
        sep = ";"
    elif "," in text:
        sep = ","
    else:
        return [text] if text else []
    return [piece.strip() for piece in text.split(sep) if piece.strip()]

def summarize_genes(glist):
    """Join gene labels into a single ';'-separated string (empty input -> "")."""
    separator = ";"
    return separator.join(glist)

# pick top K positive NES pathways
# The NES mask is built from the dropna'd frame itself (callable .loc), not from the
# unfiltered gsea_df, so index alignment cannot silently pick up other rows.
paths = (
    gsea_df.dropna(subset=["NES"])
    .loc[lambda d: d["NES"] > MIN_NES]
    .sort_values("NES", ascending=False)  # if you meant "top by NES"
    .head(K_PATHS)
    .copy()
)

rows = []
for _, r in paths.iterrows():
    term = str(r["Term"])
    nes = float(r["NES"])
    fdr = float(r["FDR q-val"]) if pd.notna(r["FDR q-val"]) else np.nan

    lead = parse_lead_genes(r.get("Lead_genes", np.nan))
    lead_set = [g for g in lead if g in tes_map]  # present in TES (and pass q filter if enabled)

    # rank lead genes by TES
    lead_ranked = sorted(lead_set, key=lambda g: tes_map.get(g, -np.inf), reverse=True)
    lead_ranked = lead_ranked[:TOPN_GENES_PER_PATH]

    # assemble per-gene strings like GENE(TES=...,q=...)
    gene_summ = []
    for g in lead_ranked:
        tesv = tes_map.get(g, np.nan)
        qv = q_map.get(g, np.nan) if q_map else np.nan
        if np.isfinite(qv):
            gene_summ.append(f"{g}(TES={tesv:.3f},q={qv:.3g})")
        else:
            gene_summ.append(f"{g}(TES={tesv:.3f})")

    rows.append({
        "Term": term,
        "NES": nes,
        "FDR_qval": fdr,
        "n_lead_genes_total": len(lead),
        "n_lead_genes_in_TES_list": len(lead_set),
        "top_lead_genes_by_TES": summarize_genes(lead_ranked),
        "top_lead_genes_by_TES_with_scores": summarize_genes(gene_summ),
        "all_lead_genes_raw": summarize_genes(lead),
    })

# Explicit columns keep the export well-formed (header only) when no pathway qualifies;
# without them sort_values would raise KeyError on an empty frame.
export_columns = [
    "Term", "NES", "FDR_qval", "n_lead_genes_total", "n_lead_genes_in_TES_list",
    "top_lead_genes_by_TES", "top_lead_genes_by_TES_with_scores", "all_lead_genes_raw",
]
export_df = pd.DataFrame(rows, columns=export_columns).sort_values(["NES", "FDR_qval"], ascending=[False, True])
export_df.to_csv(OUTFILE, sep="\t", index=False)

print(f"Wrote: {OUTFILE}")
print(export_df.head(10)[["Term","NES","FDR_qval","n_lead_genes_in_TES_list"]])


In [None]:
# GO-theme label -> representative gene symbol (the next cell re-assigns genes_to_plot to a list).
genes_to_plot = dict([
    ("Stem cell maintenance", "HMGA2"),
    ("Positive stem cell maintenance", "SMARCA4"),
    ("Negative regulation of development", "THY1"),
    ("DSB repair regulation", "KDM1A"),
])


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pywt
from sklearn.linear_model import LinearRegression

# ---------- helpers ----------
def rescale_to(sig, target):
    """Give ``sig`` the mean and standard deviation of ``target``, keeping its shape.

    ``sig`` is z-scored first. An epsilon in the denominator keeps a constant input
    finite (it maps to ``target``'s mean). Both inputs are converted to float arrays.
    """
    src = np.asarray(sig, float)
    ref = np.asarray(target, float)
    z = (src - src.mean()) / (src.std() + 1e-12)
    return ref.mean() + z * ref.std()

def nearest_center_index(pt01, centers):
    """Index of the window center closest to pseudotime ``pt01`` (first one on ties).

    ``pt01`` is expected in [0, 1]; ``centers`` is a numpy array of window centers.
    """
    distances = np.abs(centers - float(pt01))
    return int(distances.argmin())

def get_gene_index_map(adata):
    """Return the variable names of ``adata`` and a name -> column-index lookup.

    Returns
    -------
    (names, lookup)
        ``names`` is ``adata.var_names`` as a NumPy string array; ``lookup``
        maps each name to its position. If a name repeats, its last position
        is kept.
    """
    names = np.asarray(adata.var_names).astype(str)
    lookup = dict(zip(names, range(len(names))))
    return names, lookup

# ---------- choose genes to plot ----------
genes_to_plot = ["HMGA2", "SMARCA4"]  # edit

var_names, gene_to_idx = get_gene_index_map(adata_out)

missing = [g for g in genes_to_plot if g not in gene_to_idx]
if missing:
    raise ValueError(f"Not found in adata_out.var_names: {missing[:20]}{'...' if len(missing)>20 else ''}")

# ---------- raw data ordered by psupertime ----------
# Densify the expression matrix if it is sparse. Only a missing scipy is
# tolerated; a failure inside ``toarray()`` (e.g. MemoryError) is allowed to
# surface rather than being swallowed and leaving a sparse matrix for the
# dense float conversion below.
X_raw = adata_out.X
try:
    import scipy.sparse as sp
except ImportError:
    sp = None
if sp is not None and sp.issparse(X_raw):
    X_raw = X_raw.toarray()

ptime_raw = adata_out.obs[ptime_col].to_numpy().astype(float)
order = np.argsort(ptime_raw)

# min-max scale the sorted pseudotime to 0..1 (1e-12 guards a constant pseudotime)
ptime_sorted = ptime_raw[order]
ptime_scaled = (ptime_sorted - ptime_sorted.min()) / (ptime_sorted.max() - ptime_sorted.min() + 1e-12)

# samples (rows) reordered by pseudotime; columns stay aligned with var_names
X_sorted = np.asarray(X_raw, float)[order, :]

# match upstream z-scoring choice (per-gene, across samples)
if zscore_genes:
    X_sorted = (X_sorted - X_sorted.mean(axis=0, keepdims=True)) / (X_sorted.std(axis=0, keepdims=True) + 1e-12)

# linear model + design matrix reused for per-gene detrending over window centers
lr = LinearRegression()
t_feat = centers.reshape(-1, 1)

# ---------- plotting ----------
# For each selected gene: overlay the pseudotime-ordered raw samples, the
# windowed trajectory, and the wavelet coefficients at the scale stored in
# `res`, and print a check comparing the stored peak to the recomputed
# global |coefficient| maximum.
raw_label = "raw (z-scored)" if zscore_genes else "raw"
y_label = "abundance (z-scored)" if zscore_genes else "abundance"
scales_arr = np.asarray(scales, dtype=float)  # numeric view for nearest-scale lookup

for g in genes_to_plot:
    j = gene_to_idx[g]

    # pull stored peak info from res
    row = res.loc[res["gene"].astype(str) == g]
    if row.empty:
        raise ValueError(f"{g} not found in res['gene']. Are you plotting genes not in the TES results?")
    row = row.iloc[0]

    peak_pt01 = float(row["peak_pseudotime"])
    peak_scale_val = float(row["peak_scale"])

    # map peak pseudotime to a window index
    t_idx_from_res = nearest_center_index(peak_pt01, centers)

    # windowed trajectory (from precomputed signals)
    y_win = signals[:, j].copy()
    if detrend_monotonic:
        # remove the linear trend over window centers
        lr.fit(t_feat, y_win)
        y_win = y_win - lr.predict(t_feat)

    # wavelet coefficients on the same signal used for TES
    coefs, _ = pywt.cwt(y_win, scales=scales, wavelet=wavelet)

    # find nearest scale index to stored peak_scale
    s_idx_from_res = int(np.argmin(np.abs(scales_arr - peak_scale_val)))

    # coefficient time series at that scale, mapped onto y_win's mean/std for overlay
    wave_overlay = rescale_to(coefs[s_idx_from_res, :], y_win)

    # sanity check: does recomputed global max agree with stored peak?
    # Report the actual window center (as the stored-peak line below does)
    # instead of assuming evenly spaced centers on [0, 1]; this also avoids a
    # division by zero when there is only one window.
    s_idx_max, t_idx_max = np.unravel_index(np.argmax(np.abs(coefs)), coefs.shape)
    pt_max = float(centers[t_idx_max])
    sc_max = float(scales[s_idx_max])

    print(
        f"{g}: stored peak_pt={peak_pt01:.3f} (t_idx={t_idx_from_res}), stored scale={peak_scale_val:.1f} "
        f"| recomputed max_pt={pt_max:.3f} (t_idx={t_idx_max}), max scale={sc_max:.1f}"
    )

    # raw points
    y_raw = X_sorted[:, j]

    fig, ax = plt.subplots(figsize=(8, 4.5))
    ax.scatter(ptime_scaled, y_raw, s=12, alpha=0.25, label=raw_label)

    ax.plot(centers, y_win, linewidth=2.8,
            label="windowed trajectory (detrended)" if detrend_monotonic else "windowed trajectory")

    ax.plot(centers, wave_overlay, linewidth=2.2, linestyle="--",
            label=f"wavelet coef @ scale={scales[s_idx_from_res]} (from res)")

    # vertical line at stored peak
    ax.axvline(centers[t_idx_from_res], linestyle=":", alpha=0.7,
               label=f"stored peak (pt={centers[t_idx_from_res]:.2f})")

    ax.set_xlabel("psupertime")
    ax.set_ylabel(y_label)
    ax.set_title(f"{g} (TES={row['TES']:.3f}, q={row.get('q_perm_bh', np.nan):.3g})")
    ax.legend(frameon=False, fontsize=8)
    fig.tight_layout()
    plt.show()
    plt.close(fig)  # free the figure when looping over many genes


In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from sklearn.linear_model import LinearRegression

# 1. Load the pathway data and get top 4 by NES
df_pathways = pd.read_csv('top_positiveNES_pathways_byTES_leadgenes.tsv', sep='\t')
top_4_pathways = df_pathways.sort_values(by='NES', ascending=False).head(4)

# 2. Setup indices from your adata object (assumes adata_out exists in your env)
var_names = np.asarray(adata_out.var_names).astype(str)
gene_to_idx = {g: i for i, g in enumerate(var_names)}

# Detrending model setup (as per your template)
lr = LinearRegression()
t_feat = centers.reshape(-1, 1)


def _split_lead_genes(raw):
    """Split a ';'-separated lead-gene cell into unique, whitespace-stripped symbols.

    A missing cell (an empty field is read back by ``pd.read_csv`` as NaN)
    yields an empty list. First-occurrence order is kept, so each gene is
    plotted at most once.
    """
    if pd.isna(raw):
        return []
    tokens = (tok.strip() for tok in str(raw).split(';'))
    return list(dict.fromkeys(tok for tok in tokens if tok))


# 3. Generate one plot per pathway
for _, row in top_4_pathways.iterrows():
    pathway_name = row['Term']
    genes_in_pathway = _split_lead_genes(row['all_lead_genes_raw'])
    genes_present = [gene for gene in genes_in_pathway if gene in gene_to_idx]

    # Skip instead of rendering an empty, unlabeled figure.
    if not genes_present:
        print(f"Skipping '{pathway_name}': none of its lead genes are in adata_out.var_names")
        continue

    fig, ax = plt.subplots(figsize=(10, 5))
    for gene in genes_present:
        # Extract windowed signal (assumes 'signals' and 'centers' exist)
        y_win = signals[:, gene_to_idx[gene]].copy()

        # Apply detrending if your pipeline uses it
        if detrend_monotonic:
            lr.fit(t_feat, y_win)
            y_win = y_win - lr.predict(t_feat)

        # Plot the smoothed trajectory for this gene
        ax.plot(centers, y_win, linewidth=2, alpha=0.7, label=gene)

    # Formatting the plot
    ax.set_title(f"Top Lead Genes: {pathway_name}\n(NES: {row['NES']:.2f})", fontsize=12)
    ax.set_xlabel("psupertime", fontsize=10)
    ax.set_ylabel("z-scored abundance (windowed)", fontsize=10)

    # Place legend outside since pathways can have many genes
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', frameon=False, fontsize=8)

    ax.grid(axis='y', linestyle='--', alpha=0.3)
    fig.tight_layout()
    plt.show()
    plt.close(fig)  # free the figure when looping over pathways