# Welch t-test evaluation (frozen priors)

This notebook reproduces the **main results** using **frozen priors** saved by Notebook 01.

- No LLM calls
- Deterministic given the same seed and frozen JSON files


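**Prerequisite:** An `AnnData` object named `adata` must be in memory, with the same preprocessing and `var_names` that were used to build the frozen priors and with cell labels in `adata.obs['Cell_class']`. Either run Notebook 01 first in the same runtime, or load the same `.h5ad` here by setting the `MERFISH_H5AD` environment variable (or `DATA_PATH`) before running the next cell.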


## 0) Imports + load frozen priors

In [None]:
import os
import anndata as ad

import re, math
import numpy as np, pandas as pd
from scipy.stats import ttest_ind

# Dataset location (e.g., a preprocessed MERFISH .h5ad), taken from the
# MERFISH_H5AD environment variable; empty string when the variable is unset.
DATA_PATH = os.getenv('MERFISH_H5AD', '')

# An `adata` already present in this runtime is reused as-is; otherwise it is
# read from DATA_PATH, and a missing path is reported before any read is tried.
if 'adata' not in globals() and not DATA_PATH:
    raise ValueError("`adata` not found. Run Notebook 01 first or set MERFISH_H5AD / DATA_PATH to a .h5ad file.")
if 'adata' not in globals():
    adata = ad.read_h5ad(DATA_PATH)

print('adata:', adata)


In [None]:

# Housekeeping / panel filter settings (keep consistent with synthesis).
# Gene names starting with one of these prefixes are dropped by _hk_filter_to_panel.
HOUSEKEEPING_RE = r'^(Rpl|Rps|Mrpl|Mrps|mt\-|Mt\-)'
# Smallest gene-set size that classify_by_ttest_custom will evaluate.
MIN_GENES = 3
# Gene names available in the loaded dataset, as plain strings.
panel = {str(name) for name in adata.var_names}

def _hk_filter_to_panel(gs, panel=panel):
    """
    Restrict a gene list to the panel and drop housekeeping genes.

    Every entry of gs is converted to str. An entry is kept when it is in
    `panel` and does not match HOUSEKEEPING_RE at the start of its name.
    Duplicates are removed; the first occurrence fixes the output order.
    Returns a list of str.
    """
    # A dict keeps insertion order, so it doubles as an ordered set here.
    kept = {}
    for raw in gs:
        name = str(raw)
        if name in panel and not re.match(HOUSEKEEPING_RE, name):
            kept.setdefault(name, None)
    return list(kept)

def _aggregate_scores(X_slice, gene_idx, weights=None, agg="weighted", trim_pct=0.10):
    """
    Collapse the selected gene columns of X_slice into one score per row.

    X_slice  : 2-D array, indexed as X_slice[:, gene_idx]
    gene_idx : column indices of the gene set; an empty selection returns None
    weights  : per-gene weights aligned to gene_idx, or None
    agg      : "weighted"     -> weighted mean (plain mean when weights is None)
               "trimmed_mean" -> per row, drop the int(n * trim_pct) smallest and
                                 largest values, then average the rest
               anything else  -> plain mean
    trim_pct : fraction trimmed from each end for "trimmed_mean"
    """
    if len(gene_idx) == 0:
        return None

    selected = X_slice[:, gene_idx]

    if agg == "weighted" and weights is not None:
        # Weights are floored at 1e-6 so the normalising sum stays positive.
        gene_w = np.clip(np.asarray(weights, float), 1e-6, np.inf)
        return (selected * gene_w).sum(axis=1) / (gene_w.sum())

    if agg == "trimmed_mean":
        n_genes = selected.shape[1]
        n_trim = int(n_genes * trim_pct)
        if n_trim != 0:
            ordered = np.sort(selected, axis=1)
            if (n_genes - 2 * n_trim) > 0:
                kept = ordered[:, n_trim: n_genes - n_trim]
            else:
                # Trimming would remove every column: use the smallest value only.
                kept = ordered[:, :1]
            return kept.mean(axis=1)

    # Unweighted "weighted", zero trim, or an unrecognised agg: plain mean.
    return selected.mean(axis=1)

def _cohens_d_posneg(pos, neg):
    """
    Cohen's d of pos versus neg: (mean(pos) - mean(neg)) / pooled SD.

    Sample variances use ddof=1. Returns NaN when either sample has fewer
    than two values or when the pooled variance is non-finite or not positive.
    """
    a = np.asarray(pos, float)
    b = np.asarray(neg, float)
    n_a, n_b = len(a), len(b)
    if min(n_a, n_b) < 2:
        return np.nan

    mean_a, mean_b = a.mean(), b.mean()
    dof = n_a + n_b - 2
    if dof <= 0:
        return np.nan

    pooled_var = ((n_a - 1) * a.var(ddof=1) + (n_b - 1) * b.var(ddof=1)) / dof
    if not (np.isfinite(pooled_var) and pooled_var > 0):
        return np.nan
    return (mean_a - mean_b) / math.sqrt(pooled_var)

def classify_by_ttest_custom(
    adata, sets_dict, weights_dict=None, K=100, neg_per_pos=5,
    agg="weighted", trim_pct=0.10, seed=42, boot_idx=None
):
    """
    Held-out Welch (pos > neg) classification using custom aggregation.

    For each *true* class (one prediction row per class):
      - sample up to K positives (y == true) and up to neg_per_pos * K
        negatives (y != true), without replacement, from cells not in boot_idx
      - for each candidate class, aggregate expression over its gene set
      - run a one-sided Welch t-test, H1: mean(pos_scores) > mean(neg_scores)
      - predict the candidate with the smallest p-value among those with a
        positive mean difference

    Parameters
    ----------
    adata : object exposing `.X` (dense, or with `.toarray()`),
        `.obs['Cell_class']`, `.var_names` and `.n_obs`.
    sets_dict : dict, class -> collection of gene names. Classes with fewer
        than MIN_GENES genes are ignored; a candidate is also skipped when
        fewer than MIN_GENES of its genes are present in `adata.var_names`.
    weights_dict : dict or None, class -> {gene: weight}. Genes without an
        entry get weight 1.0; classes without an entry are unweighted.
    K, neg_per_pos : sample-size controls described above.
    agg, trim_pct : forwarded to `_aggregate_scores`.
    seed : seed for `np.random.default_rng`.
    boot_idx : row indices treated as an exclusion set (e.g. train_idx),
        preventing leakage.

    Returns
    -------
    acc : float, fraction of rows of df_pred that are correct (NaN if empty).
    df_pred : DataFrame with columns "class", "predicted", "correct". The
        prediction is the string "None" when a class has fewer than two
        usable positives or negatives, or when no candidate shows a positive
        mean difference with p < 1; such rows never count as correct.
    diag : DataFrame indexed by class with "diag_p" and "diag_d" for the true
        class's own gene set; empty when nothing was evaluated.
    """
    rng = np.random.default_rng(seed)
    X = adata.X.toarray() if hasattr(adata.X, 'toarray') else np.asarray(adata.X)
    y = adata.obs['Cell_class'].astype(str).values
    g = np.array(adata.var_names, dtype=str)
    g2i = {gg: i for i, gg in enumerate(g)}
    classes = sorted([c for c in sets_dict.keys() if len(sets_dict[c]) >= MIN_GENES])

    # Column indices and weights depend only on the candidate, so build them
    # once instead of once per (true class, candidate) pair.
    candidates = []
    for cand in classes:
        # Sorted so the summation order does not depend on set iteration
        # order, which differs between interpreter sessions.
        genes = sorted(gg for gg in sets_dict[cand] if gg in g2i)
        if len(genes) < MIN_GENES:
            continue
        gi = np.array([g2i[gg] for gg in genes], int)
        w = None
        if weights_dict is not None and cand in weights_dict:
            w = np.array([weights_dict[cand].get(gg, 1.0) for gg in genes], float)
        candidates.append((cand, gi, w))

    # exclude any provided bootstrap/train indices
    mask_notboot = np.ones(adata.n_obs, dtype=bool)
    if boot_idx is not None:
        boot_idx = np.asarray(boot_idx, dtype=int)
        if boot_idx.size > 0:
            mask_notboot[boot_idx] = False

    preds, rows = [], []
    for true_c in classes:
        pos_pool = np.where(mask_notboot & (y == true_c))[0]
        neg_pool = np.where(mask_notboot & (y != true_c))[0]
        n_pos = min(K, len(pos_pool))
        n_neg = min(max(neg_per_pos * n_pos, n_pos), len(neg_pool)) if n_pos > 0 else 0

        if n_pos < 2 or n_neg < 2:
            preds.append(("None", true_c))
            continue

        pos_idx = rng.choice(pos_pool, size=n_pos, replace=False)
        neg_idx = rng.choice(neg_pool, size=n_neg, replace=False)
        # Fancy indexing copies the rows; do it once per true class.
        X_pos, X_neg = X[pos_idx], X[neg_idx]

        best_p, best_c = 1.0, None
        for cand, gi, w in candidates:
            pos_scores = _aggregate_scores(X_pos, gi, weights=w, agg=agg, trim_pct=trim_pct)
            neg_scores = _aggregate_scores(X_neg, gi, weights=w, agg=agg, trim_pct=trim_pct)
            if pos_scores is None or neg_scores is None:
                continue

            diff = float(np.mean(pos_scores) - np.mean(neg_scores))
            p = ttest_ind(pos_scores, neg_scores, equal_var=False, alternative='greater').pvalue

            # A NaN p-value compares False here and is therefore never selected.
            if (diff > 0) and (p < best_p):
                best_p, best_c = p, cand

            if cand == true_c:
                d = _cohens_d_posneg(pos_scores, neg_scores)
                rows.append({"class": true_c, "diag_p": float(p), "diag_d": float(d)})

        # No qualifying candidate: report "None" rather than defaulting to the
        # first class, which would score as correct whenever it is the true one.
        preds.append((best_c if best_c is not None else "None", true_c))

    df_pred = pd.DataFrame({"class": [t for (_, t) in preds], "predicted": [p for (p, _) in preds]})
    df_pred["correct"] = (df_pred["class"] == df_pred["predicted"])
    acc = float(df_pred["correct"].mean()) if len(df_pred) else np.nan
    # Explicit columns keep the groupby valid when no diagnostic rows exist.
    diag = pd.DataFrame(rows, columns=["class", "diag_p", "diag_d"]).groupby("class", as_index=True).last()
    return acc, df_pred, diag


In [None]:
import json  # not imported in the notebook's import cell; required by the loaders below

OUTDIR = "./"  # directory that holds the frozen JSON artifacts


def _load_json(name):
    """Parse the JSON file `name` located in OUTDIR."""
    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(os.path.join(OUTDIR, name), encoding="utf-8") as f:
        return json.load(f)


def _load_gene_sets(name):
    """Load a {class: [gene, ...]} file and return {class: set(genes)}."""
    return {c: set(gs) for c, gs in _load_json(name).items()}


def _load_gene_weights(name):
    """Load a {class: {gene: weight}} file with every weight cast to float."""
    return {c: {g: float(w) for g, w in inner.items()} for c, inner in _load_json(name).items()}


ref_llm_top = _load_gene_sets("ref_llm_top.json")
ref_llm_w = _load_gene_weights("ref_llm_w.json")

marker_top = _load_gene_sets("marker_top.json")
marker_w = _load_gene_weights("marker_w.json")

# Row indices of the frozen train/test split.
split = _load_json("split.json")
train_idx = np.array(split["train_idx"], int)
test_idx = np.array(split["test_idx"], int)

print("Loaded classes:", len(ref_llm_top), "| Train:", len(train_idx), "| Test:", len(test_idx))

## 1) Welch evaluation functions 

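For each true class `c`, the classifier performs these steps:

- sample up to `K` positives from class `c` (excluding `boot_idx`)
- sample up to `neg_per_pos*K` negatives from the other classes (also excluding `boot_idx`)
- compute the aggregated score of every candidate gene set on both samples
- pick the candidate with the smallest one-sided Welch p-value (subject to a positive mean difference)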

In [None]:
# Minimum number of genes a gene set needs to be used by classify_by_ttest_custom below.
MIN_GENES = 3

def _aggregate_scores(X_slice, gene_idx, weights=None, agg="weighted", trim_pct=0.10):
    """
    Reduce the gene_idx columns of X_slice to a single score per row.

    Returns None for an empty gene_idx. With agg="weighted" and weights given,
    the result is the weighted mean (weights floored at 1e-6). With
    agg="trimmed_mean", the int(n * trim_pct) lowest and highest values of each
    row are dropped before averaging. Every other case is the plain row mean.
    """
    if len(gene_idx) == 0:
        return None

    cols = X_slice[:, gene_idx]
    use_weights = agg == "weighted" and weights is not None
    if use_weights:
        floored = np.clip(np.asarray(weights, float), 1e-6, np.inf)
        return (cols * floored).sum(axis=1) / floored.sum()

    if agg == "trimmed_mean":
        width = cols.shape[1]
        cut = int(width * trim_pct)
        if cut != 0:
            ranked = np.sort(cols, axis=1)
            # If both tails together cover the whole row, keep only the smallest value.
            middle = ranked[:, cut:width - cut] if (width - 2 * cut) > 0 else ranked[:, :1]
            return middle.mean(axis=1)

    return cols.mean(axis=1)


def _cohens_d(pos, neg):
    """
    Standardised mean difference (mean(pos) - mean(neg)) / pooled SD.

    Variances are sample variances (ddof=1). NaN is returned when either
    input has fewer than two values or the pooled variance is not a finite
    positive number.
    """
    xs = np.asarray(pos, float)
    ys = np.asarray(neg, float)
    n_x, n_y = len(xs), len(ys)
    if n_x < 2 or n_y < 2:
        return np.nan

    dof = n_x + n_y - 2
    if dof > 0:
        pooled = ((n_x - 1) * xs.var(ddof=1) + (n_y - 1) * ys.var(ddof=1)) / dof
    else:
        pooled = np.nan

    if np.isfinite(pooled) and pooled > 0:
        return (xs.mean() - ys.mean()) / math.sqrt(pooled)
    return np.nan

def classify_by_ttest_custom(
    adata, sets_dict, weights_dict=None, K=100, neg_per_pos=5,
    agg="weighted", trim_pct=0.10, seed=42, boot_idx=None
):
    """
    Classify each class's held-out cells by one-sided Welch t-tests.

    For every class with at least MIN_GENES genes in sets_dict, up to K
    positives (cells of that class) and up to neg_per_pos * K negatives (all
    other cells) are drawn without replacement from the cells not listed in
    boot_idx. Each candidate gene set is scored on both samples with
    `_aggregate_scores`; the prediction is the candidate with the smallest
    p-value for H1: mean(pos) > mean(neg) among candidates whose mean
    difference is positive.

    Parameters
    ----------
    adata : object exposing `.X` (dense, or with `.toarray()`),
        `.obs["Cell_class"]`, `.var_names` and `.n_obs`.
    sets_dict : dict, class -> collection of gene names. A candidate is
        skipped when fewer than MIN_GENES of its genes are in `var_names`.
    weights_dict : dict or None, class -> {gene: weight}; missing genes get
        weight 1.0, missing classes are unweighted.
    K, neg_per_pos : sample-size controls described above.
    agg, trim_pct : forwarded to `_aggregate_scores`.
    seed : seed for `np.random.default_rng`.
    boot_idx : row indices to exclude from sampling (e.g. train_idx).

    Returns
    -------
    acc : float, fraction of correct rows in df_pred (NaN if it is empty).
    df_pred : DataFrame with columns "class", "predicted", "correct". The
        prediction is the string "None" when a class has fewer than two
        usable positives or negatives, or when no candidate shows a positive
        mean difference with p < 1; such rows never count as correct.
    diag : DataFrame indexed by class with "diag_p" and "diag_d" for the true
        class's own gene set; empty when nothing was evaluated.
    """
    rng = np.random.default_rng(seed)
    X = adata.X.toarray() if hasattr(adata.X, "toarray") else np.asarray(adata.X)
    y = adata.obs["Cell_class"].astype(str).values
    g = np.array(adata.var_names, dtype=str)
    g2i = {gg: i for i, gg in enumerate(g)}
    classes = sorted([c for c in sets_dict.keys() if len(sets_dict[c]) >= MIN_GENES])

    # Column indices and weights depend only on the candidate, so build them
    # once instead of once per (true class, candidate) pair.
    candidates = []
    for cand in classes:
        # Sorted so the summation order does not depend on set iteration
        # order, which differs between interpreter sessions.
        genes = sorted(gg for gg in sets_dict[cand] if gg in g2i)
        if len(genes) < MIN_GENES:
            continue
        gi = np.array([g2i[gg] for gg in genes], int)
        w = None
        if weights_dict is not None and cand in weights_dict:
            w = np.array([weights_dict[cand].get(gg, 1.0) for gg in genes], float)
        candidates.append((cand, gi, w))

    # Cells listed in boot_idx are never sampled.
    mask_notboot = np.ones(adata.n_obs, dtype=bool)
    if boot_idx is not None:
        boot_idx = np.asarray(boot_idx, dtype=int)
        if boot_idx.size > 0:
            mask_notboot[boot_idx] = False

    preds, diag_rows = [], []
    for true_c in classes:
        pos_pool = np.where(mask_notboot & (y == true_c))[0]
        neg_pool = np.where(mask_notboot & (y != true_c))[0]
        n_pos = min(K, len(pos_pool))
        n_neg = min(max(neg_per_pos * n_pos, n_pos), len(neg_pool)) if n_pos > 0 else 0
        if n_pos < 2 or n_neg < 2:
            preds.append(("None", true_c))
            continue

        pos_idx = rng.choice(pos_pool, size=n_pos, replace=False)
        neg_idx = rng.choice(neg_pool, size=n_neg, replace=False)
        # Fancy indexing copies the rows; do it once per true class.
        X_pos, X_neg = X[pos_idx], X[neg_idx]

        best_p, best_c = 1.0, None
        for cand, gi, w in candidates:
            pos_scores = _aggregate_scores(X_pos, gi, weights=w, agg=agg, trim_pct=trim_pct)
            neg_scores = _aggregate_scores(X_neg, gi, weights=w, agg=agg, trim_pct=trim_pct)
            if pos_scores is None or neg_scores is None:
                continue

            diff = float(pos_scores.mean() - neg_scores.mean())
            p = ttest_ind(pos_scores, neg_scores, equal_var=False, alternative="greater").pvalue

            # A NaN p-value compares False here and is therefore never selected.
            if (diff > 0) and (p < best_p):
                best_p, best_c = p, cand

            if cand == true_c:
                diag_rows.append({"class": true_c, "diag_p": float(p), "diag_d": float(_cohens_d(pos_scores, neg_scores))})

        # No qualifying candidate: report "None" rather than defaulting to the
        # first class, which would score as correct whenever it is the true one.
        preds.append((best_c if best_c is not None else "None", true_c))

    df_pred = pd.DataFrame({"class": [t for (_, t) in preds],
                            "predicted": [p for (p, _) in preds]})
    df_pred["correct"] = (df_pred["class"] == df_pred["predicted"])
    acc = float(df_pred["correct"].mean()) if len(df_pred) else np.nan
    # Explicit columns keep the groupby valid when no diagnostic rows exist.
    diag = pd.DataFrame(diag_rows, columns=["class", "diag_p", "diag_d"]).groupby("class", as_index=True).last()
    return acc, df_pred, diag

## 2) Sweep over seeds 

In [None]:
def sweep(seeds=(1,2,3,4,5)):
    """
    Evaluate both frozen priors once per seed and summarise the accuracies.

    For each seed, classify_by_ttest_custom is run on the LLM priors
    (ref_llm_top / ref_llm_w) and then on the marker priors
    (marker_top / marker_w), with train_idx excluded from sampling.

    Returns ((llm_mean, llm_sd), (marker_mean, marker_sd)); the SD is
    np.std with its default ddof, i.e. the population standard deviation.
    """
    fixed = dict(K=100, neg_per_pos=5, agg="weighted", trim_pct=0.10, boot_idx=train_idx)

    def _accuracy(sets, weights, seed):
        # Only the accuracy (first return value) is needed here.
        return classify_by_ttest_custom(adata, sets, weights, seed=seed, **fixed)[0]

    def _mean_sd(values):
        return float(np.mean(values)), float(np.std(values))

    llm_accs, marker_accs = [], []
    for seed in seeds:
        llm_accs.append(_accuracy(ref_llm_top, ref_llm_w, seed))
        marker_accs.append(_accuracy(marker_top, marker_w, seed))
    return _mean_sd(llm_accs), _mean_sd(marker_accs)

# Run the sweep with its default seeds and print ((LLM mean, sd), (marker mean, sd)).
print("LLM mean±sd, Markers mean±sd:", sweep())