# Cross-platform transfer (MERFISH → Stereo-seq) using frozen priors

This notebook evaluates **transfer** of **frozen LLM gene-set priors** (learned once on MERFISH) onto a **new spatial technology** dataset (e.g., Stereo-seq).

**Workflow**
- Load frozen priors from disk (`ref_llm_top.json`, `ref_llm_w.json`) and the matched marker baseline (`marker_top.json`, `marker_w.json`).
- Map the new dataset’s raw labels into the same **macro classes** used for MERFISH.
- Intersect priors/markers with the new dataset gene panel (and remove housekeeping genes).
- Run the **Welch one-sided** classifier (`classify_by_ttest_custom`) used in the paper.
- Sweep over random seeds, then plot the per-seed accuracies with paired statistics (paired t-test and Wilcoxon signed-rank).

**Note:** This notebook does **not** run any LLM calls.


In [None]:
import os, re, json, math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from scipy.stats import ttest_ind, ttest_rel, wilcoxon, t as student_t


## 1) Inputs
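Before running, define `adataX`: the new dataset (e.g., Stereo-seq) as an AnnData object. Its raw cell-type labels are read from the `obs` column named by `RAW_COL` in section 3.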


In [None]:
# New technology dataset (e.g., Stereo-seq). `adataX` is not created by this
# notebook and must already exist in the session. A copy is taken so that the
# `Cell_class` column added in section 3 does not modify the original object.
adata_new = adataX.copy()

# Folder containing the frozen prior and marker JSON files loaded in section 2.
OUTDIR = "./"


## 2) Load frozen priors + marker baseline



In [None]:
# JSON text is UTF-8 by specification, so the encoding is passed explicitly
# instead of relying on the platform's locale default.

# LLM priors: {class -> set of genes}.
with open(os.path.join(OUTDIR, "ref_llm_top.json"), encoding="utf-8") as f:
    ref_llm_top = {c: set(gs) for c, gs in json.load(f).items()}

# LLM per-gene weights: {class -> {gene -> weight}}, weights coerced to float.
with open(os.path.join(OUTDIR, "ref_llm_w.json"), encoding="utf-8") as f:
    _w = json.load(f)
    ref_llm_w = {c: {g: float(w) for g, w in gm.items()} for c, gm in _w.items()}

# Marker baseline: same two structures as the LLM priors.
with open(os.path.join(OUTDIR, "marker_top.json"), encoding="utf-8") as f:
    marker_top = {c: set(gs) for c, gs in json.load(f).items()}

with open(os.path.join(OUTDIR, "marker_w.json"), encoding="utf-8") as f:
    _mw = json.load(f)
    marker_w = {c: {g: float(w) for g, w in gm.items()} for c, gm in _mw.items()}

print("Loaded classes (LLM):", len(ref_llm_top))
print("Loaded classes (MRK):", len(marker_top))


## 3) Map new dataset labels → MERFISH macro classes




In [None]:
# Macro classes considered for evaluation; section 4 keeps only classes that
# are both in this list and present in the new dataset.
ALL_CLASSES = ['Astrocyte','Endothelial','Ependymal','Excitatory','Inhibitory',
               'Microglia','OD Immature','OD Mature','Pericytes']

# Column of `adata_new.obs` that holds the new dataset's raw labels.
RAW_COL = "annotation"

# Regex pattern -> macro class. `to_macro` tries the patterns in insertion
# order, case-insensitively, and returns the target of the first match, so the
# order of the entries matters whenever two patterns could match one label.
mapper = {
    r"^Astro.*":         "Astrocyte",
    r"^Endoth.*":        "Endothelial",
    r"^Ependym.*":       "Ependymal",
    r"^(Ex|Excit).*":    "Excitatory",
    r"^(Inh|GABA).*":    "Inhibitory",
    r"^Microglia.*":     "Microglia",
    r"^(OPC|OD Imm).*":  "OD Immature",
    r"^(Oligodendro).*": "OD Mature",
    r"^(Pericyte|VSMC|VLMC).*": "Pericytes",
}

def to_macro(label):
    """Translate a raw annotation label into its macro class name.

    The label is converted to ``str`` and checked against each regex key of
    ``mapper`` in insertion order, ignoring case. The value of the first
    matching pattern is returned; ``"Unknown"`` is returned when none match.
    """
    text = str(label)
    hits = (
        macro
        for pattern, macro in mapper.items()
        if re.search(pattern, text, flags=re.IGNORECASE)
    )
    return next(hits, "Unknown")

# Derive a macro-class label for every observation; raw labels that match no
# pattern in `mapper` end up as "Unknown".
adata_new.obs["Cell_class"] = adata_new.obs[RAW_COL].map(to_macro)
# Sanity check: how many observations landed in each macro class (and in "Unknown").
print("Per-class counts in new dataset:\n", adata_new.obs["Cell_class"].value_counts())


## 4) Panel intersection + housekeeping filter


In [None]:
# Gene-name prefixes treated as housekeeping and removed from every gene set
# (Rpl/Rps/Mrpl/Mrps and mt-/Mt- -- presumably ribosomal and mitochondrial
# genes). It is applied with `re.match` and no flags, so it is case-sensitive.
HOUSEKEEPING_RE = r'^(Rpl|Rps|Mrpl|Mrps|mt\-|Mt\-)'
# Minimum number of genes a class must retain to be kept / evaluated.
MIN_GENES = 3
# Gene names available in the new dataset, as plain strings.
panel_new = set(map(str, adata_new.var_names))

def hk_filter_to_panel(gs, panel):
    """Return the genes of ``gs`` that are on ``panel`` and not housekeeping.

    Every gene is converted to ``str``. A gene is kept when it is a member of
    ``panel`` and does not match ``HOUSEKEEPING_RE`` at its start. Duplicates
    are dropped and the first-seen order of ``gs`` is preserved.

    Returns a list of gene names.
    """
    kept = {}
    for gene in map(str, gs):
        if gene not in panel:
            continue
        if re.match(HOUSEKEEPING_RE, gene):
            continue
        # A dict both de-duplicates and remembers first-insertion order.
        kept.setdefault(gene, None)
    return list(kept)

def renorm_weights_mean1(wdict):
    """Rescale a ``{key: weight}`` mapping so that its weights average to 1.

    NaN weights are treated as 0 before rescaling. Previously they were
    skipped in the sum but still counted in the length and returned as NaN,
    so the result neither averaged to 1 nor stayed finite.

    Parameters
    ----------
    wdict : dict
        Mapping from key to a numeric weight.

    Returns
    -------
    dict
        A new dict with the same keys (the input is never returned or mutated):
        - empty input gives an empty dict;
        - if the sum of the weights is non-finite or not positive, every
          weight is set to 1.0;
        - otherwise each weight is multiplied by ``len(wdict) / sum``.
    """
    if not wdict:
        return {}
    cleaned = {}
    for key, value in wdict.items():
        value = float(value)
        cleaned[key] = 0.0 if np.isnan(value) else value
    total = float(np.sum(np.array(list(cleaned.values()), float)))
    if not np.isfinite(total) or total <= 0:
        return {key: 1.0 for key in cleaned}
    scale = len(cleaned) / total
    return {key: value * scale for key, value in cleaned.items()}

def _restrict_to_panel(top_sets, weight_maps, panel, default_weight, clip_negative):
    """Intersect each class's gene set with ``panel`` and renormalize its weights.

    Parameters
    ----------
    top_sets : dict
        ``{class: iterable of genes}``.
    weight_maps : dict
        ``{class: {gene: weight}}``; classes or genes may be missing.
    panel : set
        Gene names available in the target dataset.
    default_weight : float
        Weight used for a kept gene that has no entry in ``weight_maps``.
    clip_negative : bool
        When true, negative weights are raised to 0 before renormalization.

    Returns
    -------
    (dict, dict)
        ``{class: set of kept genes}`` and ``{class: {gene: weight}}`` for the
        classes that keep at least ``MIN_GENES`` genes; weights are rescaled by
        ``renorm_weights_mean1``.
    """
    kept_sets, kept_weights = {}, {}
    for cls, genes in top_sets.items():
        keep = hk_filter_to_panel(genes, panel)
        if len(keep) < MIN_GENES:
            continue
        cls_weights = weight_maps.get(cls, {})
        raw = {g: float(cls_weights.get(g, default_weight)) for g in keep}
        if clip_negative:
            raw = {g: max(v, 0.0) for g, v in raw.items()}
        kept_sets[cls] = set(keep)
        kept_weights[cls] = renorm_weights_mean1(raw)
    return kept_sets, kept_weights


# Filter priors + weights (missing weight -> 0, negative weights clipped to 0)
ref_llm_top_X, ref_llm_w_X = _restrict_to_panel(
    ref_llm_top, ref_llm_w, panel_new, default_weight=0.0, clip_negative=True
)

# Filter markers + weights (missing weight -> 1, no clipping)
marker_top_X, marker_w_X = _restrict_to_panel(
    marker_top, marker_w, panel_new, default_weight=1.0, clip_negative=False
)

# Keep only classes present in new dataset and macro list
present = set(adata_new.obs["Cell_class"].astype(str).unique()) & set(ALL_CLASSES)
ref_llm_top_X = {c: s for c, s in ref_llm_top_X.items() if c in present}
ref_llm_w_X   = {c: w for c, w in ref_llm_w_X.items()   if c in present}
marker_top_X  = {c: s for c, s in marker_top_X.items()  if c in present}
marker_w_X    = {c: w for c, w in marker_w_X.items()    if c in present}

print("Classes evaluated:", sorted(ref_llm_top_X.keys()))

# The two methods are compared seed-by-seed later on. If one of them lost a
# class during filtering, their accuracies are averaged over different class
# sets, so make that visible instead of letting it pass silently.
_llm_only = sorted(set(ref_llm_top_X) - set(marker_top_X))
_mrk_only = sorted(set(marker_top_X) - set(ref_llm_top_X))
if _llm_only or _mrk_only:
    print("WARNING: LLM and marker class sets differ;",
          "LLM-only:", _llm_only, "| marker-only:", _mrk_only)


## 5) Welch t-test classifier 


In [None]:
def _aggregate_scores(X_slice, gene_idx, weights=None, agg="weighted", trim_pct=0.10):
    """Collapse the selected columns of ``X_slice`` into one score per row.

    Parameters
    ----------
    X_slice : 2-D array
        Rows are the items to score; indexed as ``X_slice[:, gene_idx]``.
    gene_idx : sequence of int
        Column indices to aggregate.
    weights : sequence of float, optional
        One weight per entry of ``gene_idx``; only used when ``agg="weighted"``.
        Weights are clipped to a minimum of 1e-6.
    agg : str
        ``"weighted"`` gives a weighted mean (plain mean when ``weights`` is
        None); ``"trimmed_mean"`` gives the mean after dropping the ``k``
        smallest and ``k`` largest values per row, ``k = int(n * trim_pct)``;
        any other value gives a plain mean.
    trim_pct : float
        Fraction trimmed from each end for ``"trimmed_mean"``.

    Returns
    -------
    1-D array of per-row scores, or None when ``gene_idx`` is empty.

    Notes
    -----
    ``k`` is clamped to ``[0, (n - 1) // 2]`` so that at least one central
    value always remains. Without the clamp, a ``trim_pct`` of 0.5 or more fell
    back to the row minimum, and a negative ``trim_pct`` produced a negative
    slice offset. Results for ``0 <= trim_pct < 0.5`` are unchanged.
    """
    if len(gene_idx) == 0:
        return None
    sub = X_slice[:, gene_idx]

    if agg == "weighted" and weights is not None:
        w = np.clip(np.asarray(weights, float), 1e-6, np.inf)
        return (sub * w).sum(axis=1) / w.sum()

    if agg == "trimmed_mean":
        n = sub.shape[1]
        k = min(max(int(n * trim_pct), 0), (n - 1) // 2)
        if k > 0:
            return np.sort(sub, axis=1)[:, k:n - k].mean(axis=1)

    # Unweighted "weighted", untrimmed "trimmed_mean", and unknown modes.
    return sub.mean(axis=1)

def _cohens_d_posneg(pos, neg):
    """Return Cohen's d of ``pos`` versus ``neg`` using the pooled variance.

    Both inputs are converted to float arrays. The effect size is
    ``(mean(pos) - mean(neg)) / sqrt(pooled sample variance)`` with sample
    variances computed with ``ddof=1``.

    Returns NaN when either group has fewer than two values, or when the
    pooled variance is non-finite or not positive.
    """
    a = np.asarray(pos, float)
    b = np.asarray(neg, float)
    n_a, n_b = len(a), len(b)
    if min(n_a, n_b) < 2:
        return np.nan

    dof = max(n_a + n_b - 2, 1)
    pooled_var = ((n_a - 1) * a.var(ddof=1) + (n_b - 1) * b.var(ddof=1)) / dof
    if np.isfinite(pooled_var) and pooled_var > 0:
        return (a.mean() - b.mean()) / math.sqrt(pooled_var)
    return np.nan

def classify_by_ttest_custom(
    adata, sets_dict, weights_dict=None, K=100, neg_per_pos=5,
    agg="weighted", trim_pct=0.10, seed=42, boot_idx=None
):
    """Predict a class for each per-class trial via one-sided Welch t-tests.

    One trial is run per class that has at least ``MIN_GENES`` genes in
    ``sets_dict``. In a trial, up to ``K`` observations labelled with that
    class ("positives") and ``neg_per_pos`` times as many observations with
    any other label ("negatives", at least as many as positives, capped by
    availability) are drawn without replacement. For each candidate class, its
    gene set is reduced to one score per observation with
    ``_aggregate_scores``, and a Welch t-test (``equal_var=False``,
    ``alternative='greater'``) compares positives against negatives. The
    prediction is the candidate with the smallest p-value among those whose
    mean difference is positive.

    Parameters
    ----------
    adata : AnnData-like
        Must provide ``X`` (dense or with ``toarray``), ``obs['Cell_class']``,
        ``var_names`` and ``n_obs``.
    sets_dict : dict
        ``{class: collection of gene names}``.
    weights_dict : dict, optional
        ``{class: {gene: weight}}``; genes without an entry get weight 1.0.
    K : int
        Maximum number of positives per trial.
    neg_per_pos : int
        Negatives drawn per positive.
    agg, trim_pct
        Forwarded to ``_aggregate_scores``.
    seed : int
        Seed of the NumPy generator used for sampling.
    boot_idx : array-like of int, optional
        Row indices excluded from both sampling pools.

    Returns
    -------
    acc : float
        Fraction of trials whose prediction equals the true class (NaN when
        there are no trials).
    df_pred : pandas.DataFrame
        Columns ``class``, ``predicted``, ``correct``; one row per trial.
    diag : pandas.DataFrame
        Indexed by ``class`` with ``diag_p`` / ``diag_d``: the p-value and
        Cohen's d of each true class's own gene set. Empty (with those
        columns) when no trial could be evaluated.

    Notes
    -----
    A trial with fewer than two positives or negatives, or one in which no
    candidate qualifies, is recorded with the prediction ``"None"`` and counts
    as incorrect. An earlier version predicted the alphabetically first class
    when no candidate qualified, which credited that class with a correct
    answer it had not earned.
    """
    rng = np.random.default_rng(seed)
    X = adata.X.toarray() if hasattr(adata.X, 'toarray') else np.asarray(adata.X)
    y = adata.obs['Cell_class'].astype(str).values
    g = np.array(adata.var_names, dtype=str)
    g2i = {gg: i for i, gg in enumerate(g)}
    classes = sorted(c for c in sets_dict if len(sets_dict[c]) >= MIN_GENES)

    # Column indices and weights of a candidate do not depend on the trial,
    # so resolve them once instead of once per (trial, candidate) pair.
    candidates = []
    for cand in classes:
        genes = [gg for gg in sets_dict[cand] if gg in g2i]
        if len(genes) < MIN_GENES:
            continue
        gi = np.array([g2i[gg] for gg in genes], int)
        w = None
        if weights_dict is not None and cand in weights_dict:
            w = np.array([weights_dict[cand].get(gg, 1.0) for gg in genes], float)
        candidates.append((cand, gi, w))

    # Rows listed in boot_idx are excluded from both sampling pools.
    mask_notboot = np.ones(adata.n_obs, dtype=bool)
    if boot_idx is not None:
        boot_idx = np.asarray(boot_idx, dtype=int)
        if boot_idx.size > 0:
            mask_notboot[boot_idx] = False

    preds, rows = [], []
    for true_c in classes:
        pos_pool = np.where(mask_notboot & (y == true_c))[0]
        neg_pool = np.where(mask_notboot & (y != true_c))[0]
        n_pos = min(K, len(pos_pool))
        n_neg = min(max(neg_per_pos * n_pos, n_pos), len(neg_pool)) if n_pos > 0 else 0
        if n_pos < 2 or n_neg < 2:
            # Too few observations for a t-test: record an unevaluable trial.
            preds.append(("None", true_c))
            continue

        pos_idx = rng.choice(pos_pool, size=n_pos, replace=False)
        neg_idx = rng.choice(neg_pool, size=n_neg, replace=False)
        X_pos, X_neg = X[pos_idx], X[neg_idx]

        best_p, best_c = 1.0, None
        for cand, gi, w in candidates:
            pos_scores = _aggregate_scores(X_pos, gi, weights=w, agg=agg, trim_pct=trim_pct)
            neg_scores = _aggregate_scores(X_neg, gi, weights=w, agg=agg, trim_pct=trim_pct)
            if pos_scores is None or neg_scores is None:
                continue

            diff = float(pos_scores.mean() - neg_scores.mean())
            p = ttest_ind(pos_scores, neg_scores, equal_var=False, alternative='greater').pvalue
            # A NaN p-value never satisfies `p < best_p`, so it cannot win.
            if (diff > 0) and (p < best_p):
                best_p, best_c = p, cand

            if cand == true_c:
                d = _cohens_d_posneg(pos_scores, neg_scores)
                rows.append({"class": true_c, "diag_p": float(p), "diag_d": float(d)})

        preds.append((best_c if best_c is not None else "None", true_c))

    df_pred = pd.DataFrame({"class": [t for (_, t) in preds],
                            "predicted": [p for (p, _) in preds]})
    df_pred["correct"] = (df_pred["class"] == df_pred["predicted"])
    acc = float(df_pred["correct"].mean()) if len(df_pred) else np.nan
    if rows:
        diag = pd.DataFrame(rows).groupby("class", as_index=True).last()
    else:
        # An empty frame has no "class" column to group by.
        diag = pd.DataFrame(columns=["diag_p", "diag_d"],
                            index=pd.Index([], name="class"))
    return acc, df_pred, diag


## 6) Run cross-platform evaluation 


In [None]:
def run_once(seed, sets, weights, neg_per_pos=5):
    """Run the classifier once on ``adata_new`` with the notebook's fixed settings.

    ``sets`` and ``weights`` are passed as the gene sets and per-gene weights;
    ``seed`` controls sampling. Uses K=100, weighted aggregation,
    ``trim_pct=0.10`` and no excluded rows.

    Returns ``(accuracy as float, prediction table, diagnostics table)``.
    """
    fixed = dict(
        K=100,
        neg_per_pos=neg_per_pos,
        agg="weighted",
        trim_pct=0.10,
        seed=seed,
        boot_idx=None,
    )
    accuracy, predictions, diagnostics = classify_by_ttest_custom(
        adata_new, sets_dict=sets, weights_dict=weights, **fixed
    )
    return float(accuracy), predictions, diagnostics

# Evaluate both gene-set sources on the same seeds (1..10) so that the
# resulting accuracy lists are paired element by element.
SEEDS = list(range(1, 11))
acc_llm, acc_mrk = [], []
for s in SEEDS:
    aL = run_once(s, ref_llm_top_X, ref_llm_w_X, neg_per_pos=5)[0]
    aM = run_once(s, marker_top_X, marker_w_X, neg_per_pos=5)[0]
    acc_llm.append(aL)
    acc_mrk.append(aM)

print("LLM runs:", np.round(acc_llm, 3).tolist())
print("MRK runs:", np.round(acc_mrk, 3).tolist())

def mean_ci95(a):
    """Return ``(mean, half_width)`` of a two-sided 95% Student-t interval.

    The half-width is ``t(0.975, n-1) * sample_sd / sqrt(n)`` with the sample
    standard deviation computed with ``ddof=1``. With fewer than two values
    the half-width is 0.0.
    """
    values = np.asarray(a, float)
    count = len(values)
    center = np.mean(values)
    if count <= 1:
        return center, 0.0
    spread = np.std(values, ddof=1)
    half_width = student_t.ppf(0.975, count - 1) * spread / np.sqrt(count)
    return center, half_width

def paired_stats(a, b):
    """Compare two paired samples with a paired t-test and a Wilcoxon test.

    Returns ``(t statistic, t-test p-value, Wilcoxon p-value)``. The Wilcoxon
    p-value is NaN when ``scipy.stats.wilcoxon`` raises for the given data.
    """
    x = np.asarray(a, float)
    y = np.asarray(b, float)
    paired_t = ttest_rel(x, y)
    try:
        wilcoxon_p = wilcoxon(x, y, zero_method="wilcox").pvalue
    except Exception:
        # Best effort: the t-test result is still reported.
        wilcoxon_p = np.nan
    return paired_t.statistic, paired_t.pvalue, wilcoxon_p

def plot_acc_ci(acc_llm, acc_mrk, title="MERFISH → Stereo-seq (no-disjoint)", outfile="fig_llm_vs_markers.png"):
    """Plot mean accuracy with 95% CI bars for both methods and save the figure.

    Draws one bar per method (mean from ``mean_ci95``, error bar = its
    half-width), overlays the individual runs with a small horizontal jitter
    (fixed seed 123), and connects the i-th run of each method with a grey
    line. The title is extended with the paired t-test and Wilcoxon p-values
    from ``paired_stats``. The figure is written to ``outfile``, shown, and
    the output path is printed.
    """
    mean_llm, half_llm = mean_ci95(acc_llm)
    mean_mrk, half_mrk = mean_ci95(acc_mrk)
    _, p_t, p_w = paired_stats(acc_llm, acc_mrk)

    positions = np.array([0, 1], float)
    fig, ax = plt.subplots(1, 1, figsize=(5.5, 4), dpi=150)
    ax.bar(
        positions,
        [mean_llm, mean_mrk],
        yerr=[half_llm, half_mrk],
        width=0.6,
        capsize=5,
        alpha=0.7,
        edgecolor="black",
    )

    # Jittered x positions keep overlapping runs distinguishable.
    jitter = np.random.default_rng(123)
    x_llm = positions[0] + jitter.normal(0, 0.03, size=len(acc_llm))
    x_mrk = positions[1] + jitter.normal(0, 0.03, size=len(acc_mrk))
    ax.scatter(x_llm, acc_llm, s=18, zorder=3)
    ax.scatter(x_mrk, acc_mrk, s=18, zorder=3)
    # Link the two results that share a position in the input lists.
    for left_x, left_y, right_x, right_y in zip(x_llm, acc_llm, x_mrk, acc_mrk):
        ax.plot([left_x, right_x], [left_y, right_y], color="0.7", lw=0.8, zorder=2)

    ax.set_xticks(positions)
    ax.set_xticklabels(["LLM (frozen)", "Markers"])
    ax.set_ylim(0.0, 1.05)
    ax.set_ylabel("Accuracy")
    ax.grid(axis="y", ls="--", alpha=0.3)
    ax.set_title(title + f"\npaired t p={p_t:.2e} | Wilcoxon p={p_w:.2e}")
    plt.tight_layout()
    plt.savefig(outfile, bbox_inches="tight")
    plt.show()
    print("Saved:", outfile)

# Render and save the LLM-vs-markers comparison for the seed sweep above,
# using the default title and output file name.
plot_acc_ci(acc_llm, acc_mrk)
