# One-time prior synthesis (MERFISH)

This notebook implements the **one-time** LLM reasoning step to synthesize gene-set priors and then **freeze** them to disk.
1) Loads MERFISH (Squidpy)
2) Builds a spatial kNN graph
3) Learns spot/cell embeddings using a Graph Autoencoder (GAE)
4) Runs one-time LLM inference to generate frozen, reusable priors:
   - per-class gene sets from LLM justifications


Outputs saved to `OUTDIR`:
- `ref_llm_sets.json` — refined per-class gene sets (unordered, for audit)
- `ref_llm_top.json`  — final frozen per-class gene sets used downstream
- `ref_llm_w.json`    — per-class gene weights (used for weighted aggregation)
- `marker_top.json` / `marker_w.json` — matched DE-marker baselines
- `split.json`        — train/test indices used for (leak-free) training stats

Downstream evaluation (Welch t-test + sweeps) is in Notebook 02 and **does not** require any LLM access.

## 0) Imports 



In [None]:
import os, json, re, math, time, random
from functools import lru_cache
import numpy as np
import pandas as pd

from typing import List, Dict, Iterable

import scanpy as sc
import squidpy as sq

from scipy.spatial import cKDTree
from scipy.sparse import coo_matrix

import torch
from torch_geometric.data import Data
from torch_geometric.nn import GAE, GCNConv
from torch_geometric.utils import from_scipy_sparse_matrix

# Seed every RNG used in this notebook (Python, NumPy, PyTorch CPU and, when
# available, all CUDA devices) with the same value.
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(RANDOM_SEED)

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device  # bare expression: shown as the cell output

## 1)  Data loading, pre-processing and GAE model training

In [None]:
# Load the MERFISH dataset through Squidpy.
adata = sq.datasets.merfish()

# Everything below is keyed on the 'Cell_class' annotation, so fail early without it.
if "Cell_class" not in adata.obs.columns:
    raise ValueError("adata.obs does not contain 'Cell_class'.")

# Collapse numbered sub-clusters into their parent class label.
rename_dict = {
    "Endothelial 1": "Endothelial", "Endothelial 2": "Endothelial", "Endothelial 3": "Endothelial",
    "OD Mature 1": "OD Mature", "OD Mature 2": "OD Mature", "OD Mature 3": "OD Mature", "OD Mature 4": "OD Mature",
    "OD Immature 1": "OD Immature", "OD Immature 2": "OD Immature",
}
adata.obs["Cell_class"] = adata.obs["Cell_class"].replace(rename_dict)
# Drop cells labelled 'Ambiguous'; .copy() turns the filtered selection into its own object.
if "Ambiguous" in adata.obs["Cell_class"].unique():
    adata = adata[adata.obs["Cell_class"] != "Ambiguous"].copy()

# DE markers

# Rank genes per class with a Wilcoxon test unless a ranking is already stored.
# NOTE(review): the ranking is computed on all cells, before the train/test
# split made in section 5 — confirm this is acceptable for the marker baseline.
if "rank_genes_groups" not in adata.uns:
    sc.tl.rank_genes_groups(adata, groupby="Cell_class", method="wilcoxon")

def get_top_marker_genes_for_class(adata, cls, top_n=20):
    """Return up to ``top_n`` top-ranked marker gene names for class ``cls``.

    Names are read from ``adata.uns["rank_genes_groups"]["names"]``. An empty
    list is returned when no ranking is stored or when ``cls`` is not one of
    the fields of the ranking table.
    """
    ranking = adata.uns.get("rank_genes_groups", None)
    if ranking is None:
        return []
    names_by_class = ranking["names"]
    if cls not in names_by_class.dtype.names:
        return []
    return list(map(str, names_by_class[cls][:top_n]))

# Number of top-ranked DE genes kept per class.
K_MARKERS = 20
# MARKER_DICT is keyed by class, so the class list must exist at this point.
# It is defined here because the prompt-building cell further down, which also
# assigns ALL_CLASSES, runs after this one in a top-to-bottom execution.
ALL_CLASSES = sorted(map(str, adata.obs["Cell_class"].unique()))
# class label -> ranked list of its top K_MARKERS marker genes
MARKER_DICT = {c: get_top_marker_genes_for_class(adata, c, top_n=K_MARKERS) for c in ALL_CLASSES}

# Build spatial kNN adjacency

def build_spatial_knn_graph(adata, k=30):
    """Build a symmetric spatial k-nearest-neighbour adjacency matrix.

    Each cell is linked to its ``k`` nearest *other* cells, using Euclidean
    distance between the ``Centroid_X`` / ``Centroid_Y`` columns of
    ``adata.obs``. The directed kNN graph is then symmetrised: an edge is kept
    when either endpoint lists the other as a neighbour.

    Parameters
    ----------
    adata
        Object exposing ``obs`` (with the two centroid columns) and ``n_obs``.
    k : int
        Neighbours per cell. Clamped to ``n_obs - 1`` so that the tree query
        never returns "missing neighbour" placeholder indices.

    Returns
    -------
    scipy.sparse matrix
        ``(n_obs, n_obs)`` matrix with 1.0 on edges and an empty diagonal.
    """
    coords = adata.obs[["Centroid_X", "Centroid_Y"]].values
    n = adata.n_obs
    k_eff = min(int(k), n - 1)
    if k_eff <= 0:
        # Zero or one cell (or k <= 0): there are no neighbours to connect.
        return coo_matrix((n, n)).tocsr()

    tree = cKDTree(coords)
    # One vectorised query for all cells; ask for one extra hit because each
    # cell normally finds itself at distance 0.
    _, idxs = tree.query(coords, k=k_eff + 1)
    idxs = np.asarray(idxs).reshape(n, -1)

    # Drop the cell itself explicitly instead of assuming it is the first hit:
    # with duplicated coordinates the zero-distance hits come back in arbitrary
    # order. Then keep the first k_eff remaining hits of every row.
    not_self = idxs != np.arange(n)[:, None]
    keep = not_self & (np.cumsum(not_self, axis=1) <= k_eff)
    rows = np.nonzero(keep)[0]
    cols = idxs[keep]

    mat = coo_matrix((np.ones(len(rows)), (rows, cols)), shape=(n, n))
    return mat.maximum(mat.transpose())

# Symmetric 30-nearest-neighbour adjacency over the cell centroids.
adj = build_spatial_knn_graph(adata, k=30)
# Convert the SciPy adjacency to a PyG edge index; the second return value is discarded.
edge_index, _ = from_scipy_sparse_matrix(adj)
edge_index.shape

# PCA + GAE embeddings

# The first N_PCS principal components are used as node features for the GAE below.
N_PCS = 30
sc.pp.pca(adata, n_comps=N_PCS, random_state=RANDOM_SEED)
X_pca = adata.obsm["X_pca"]

class Encoder(torch.nn.Module):
    """Two-layer GCN encoder: ``in_channels -> 64 -> out_channels``."""

    def __init__(self, in_channels, out_channels=20):
        super().__init__()
        hidden_channels = 64
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Encode node features ``x`` over ``edge_index``; ReLU after the first layer only."""
        hidden = self.conv1(x, edge_index)
        hidden = torch.relu(hidden)
        embedding = self.conv2(hidden, edge_index)
        return embedding

# Graph autoencoder around the GCN encoder; the embedding size is 20.
model = GAE(Encoder(in_channels=N_PCS, out_channels=20)).to(device)
opt = torch.optim.Adam(model.parameters(), lr=1e-2)

# Single graph object: PCA scores as node features, spatial kNN edges as structure.
data_pyg = Data(
    x=torch.tensor(X_pca, dtype=torch.float32),
    edge_index=edge_index
).to(device)

# Full-batch training: every epoch encodes the whole graph once and takes one
# optimiser step on the edge-reconstruction loss.
EPOCHS = 30
model.train()
for epoch in range(EPOCHS):
    opt.zero_grad()
    z = model.encode(data_pyg.x, data_pyg.edge_index)
    loss = model.recon_loss(z, data_pyg.edge_index)
    loss.backward()
    opt.step()
    # Report the loss every fifth epoch.
    if (epoch + 1) % 5 == 0:
        print(f"epoch {epoch+1}/{EPOCHS} | recon_loss={loss.item():.4f}")

# Final embeddings, computed without gradient tracking and moved to NumPy on the CPU.
model.eval()
with torch.no_grad():
    SPATIAL_EMB = model.encode(data_pyg.x, data_pyg.edge_index).cpu().numpy()

# top genes + neighborhood top genes

# Gene names as plain strings; indexed below with column positions of adata.X.
VAR_NAMES = list(map(str, adata.var_names))

def top_expressed_genes(adata, idx, top_n=10):
    """Return the names of the ``top_n`` most highly expressed genes of one cell.

    Parameters
    ----------
    adata
        Object exposing ``X`` (dense or SciPy-sparse, cells x genes) and
        ``var_names`` (gene names, one per column of ``X``).
    idx : int
        Row (cell) index into ``adata.X``.
    top_n : int
        Number of genes to return; values <= 0 give an empty list.

    Returns
    -------
    list of str
        Gene names ordered from highest to lower expression.
    """
    if top_n <= 0:
        # A slice of [-0:] would select every gene rather than none.
        return []
    row = adata.X[idx]
    if hasattr(row, "toarray"):
        # A sparse row must be densified first; np.asarray alone would wrap
        # the sparse object instead of exposing its values.
        row = row.toarray()
    values = np.asarray(row).ravel()
    order = np.argsort(values)[-top_n:][::-1]
    # Names come from the passed adata so they always match its columns.
    return [str(adata.var_names[i]) for i in order]

def neighbor_top_genes(adata, adj, idx, top_n=5):
    """Return the ``top_n`` genes with the highest mean expression among a cell's neighbours.

    Parameters
    ----------
    adata
        Object exposing ``X`` (cells x genes) and ``var_names``.
    adj
        Sparse adjacency matrix supporting ``getrow``; the non-zero columns of
        row ``idx`` are taken as the neighbours of that cell.
    idx : int
        Index of the cell whose neighbourhood is summarised.
    top_n : int
        Number of genes to return; values <= 0 give an empty list.

    Returns
    -------
    list of str
        Gene names ordered from highest to lower neighbourhood mean, or an
        empty list when the cell has no neighbours.
    """
    nbrs = adj.getrow(idx).indices
    if len(nbrs) == 0 or top_n <= 0:
        # No neighbours, or nothing requested ([-0:] would select every gene).
        return []
    avg = np.asarray(adata.X[nbrs].mean(axis=0)).ravel()
    order = np.argsort(avg)[-top_n:][::-1]
    # Names come from the passed adata so they always match its columns.
    return [str(adata.var_names[i]) for i in order]




## 2) Vertex AI / LLM set up 

If we need to run LLM calls, first import the agent classes used below (`Agent`, `LLM`, `Task`, `Crew`, `Process` — they are not imported in section 0), then use the agent helpers below.
We can skip LLM calls entirely if we already have `llm_records` saved from a prior run.

In [None]:
# --- Vertex environment (set these in your shell or here)
# The environment value is used when present; the second argument is the fallback.
GOOGLE_CLOUD_PROJECT  = os.getenv("GOOGLE_CLOUD_PROJECT",  "YOUR_PROJECT_ID")
GOOGLE_CLOUD_LOCATION = os.getenv("GOOGLE_CLOUD_LOCATION", "global")


In [None]:
# Sorted, de-duplicated class labels as strings; iterated by the cells below.
ALL_CLASSES = sorted(map(str, adata.obs["Cell_class"].unique()))

def build_prompt(cell_idx: int, candidate_classes: List[str], k_markers_per_class: int = 10) -> str:
    """Assemble the classification prompt for a single cell.

    The prompt contains, in order: per-class marker hints (the first
    ``k_markers_per_class`` entries of ``MARKER_DICT`` after panel and
    housekeeping filtering), the cell's top expressed genes, the top genes of
    its spatial neighbourhood, the first five embedding dimensions plus the
    embedding norm, the candidate classes, and the required JSON reply format.

    NOTE(review): relies on ``_hk_filter_to_panel`` and ``panel``, which this
    notebook defines in a later cell (section 4); they must exist before the
    first call.
    """
    hints = []
    for cls_name in candidate_classes:
        top_markers = MARKER_DICT.get(cls_name, [])[:k_markers_per_class]
        usable = _hk_filter_to_panel(top_markers, panel)
        shown = ', '.join(usable) if usable else 'none'
        hints.append(f"- {cls_name}: {shown}")

    expressed = top_expressed_genes(adata, cell_idx)
    neighbourhood = neighbor_top_genes(adata, adj, cell_idx)
    cell_emb = SPATIAL_EMB[cell_idx]
    emb_norm = float(np.linalg.norm(cell_emb))
    emb_head = ', '.join(f'{v:.2f}' for v in cell_emb[:5])

    sections = [
        "We are classifying a MERFISH cell using marker genes, expression, neighbour context, and spatial embedding.\n",
        f"Marker genes (hints): {', '.join(hints) or 'none'}.\n",
        f"Top expressed genes: {', '.join(expressed)}.\n",
        f"Neighbour genes: {', '.join(neighbourhood) or 'none'}.\n",
        f"Embedding dims: [{emb_head}], norm ~ {emb_norm:.2f}.\n",
        f"Candidate classes: {candidate_classes}.\n",
        "Return *only* JSON with keys 'label','confidence','justification'.\n",
    ]
    return "".join(sections)



In [None]:
def configure_agents():
    """Build the inference and reviewer agents for the one-time LLM pass.

    Both agents use ``vertex_ai/gemini-2.5-pro`` at temperature 0.7, each with
    its own ``LLM`` instance. The Vertex project and location are taken from
    the notebook-level ``GOOGLE_CLOUD_PROJECT`` / ``GOOGLE_CLOUD_LOCATION``
    settings, so values supplied through the environment *or* edited into the
    notebook are both honoured.

    NOTE(review): ``Agent`` and ``LLM`` are not imported in the visible
    notebook; they must be imported before this function is called.

    Returns
    -------
    tuple
        ``(inf, rev)`` — the inference agent and the reviewing agent.

    Raises
    ------
    ValueError
        If no real Google Cloud project id has been configured.
    """
    project = GOOGLE_CLOUD_PROJECT
    location = GOOGLE_CLOUD_LOCATION
    if not project or project == "YOUR_PROJECT_ID":
        # Fail here with a clear message instead of deep inside the first LLM call.
        raise ValueError(
            "GOOGLE_CLOUD_PROJECT is not configured; set it in the environment "
            "or in the Vertex environment cell."
        )

    def _vertex_llm():
        """Create one LLM handle with the shared Vertex configuration."""
        return LLM(
            model="vertex_ai/gemini-2.5-pro",
            temperature=0.7,
            vertex_project=project,
            vertex_location=location,
        )

    tools = []
    inf = Agent(
        role="Bioinformatics Researcher",
        goal="Predict MERFISH cell type and return JSON {label, confidence, justification}.",
        backstory="You classify spatial single cells using gene markers and neighborhood context.",
        allow_delegation=False,
        tools=tools,
        llm=_vertex_llm(),
        verbose=False,
        memory=False,
    )
    rev = Agent(
        role="Senior QA Bioinformatician",
        goal="Review predictions and return JSON {'label': <class>} only.",
        backstory="Senior scientist checks markers and adjusts labels.",
        allow_delegation=True,
        tools=tools,
        llm=_vertex_llm(),
        verbose=False,
        memory=False,
    )
    return inf, rev

# inf, rev = configure_agents()

## 3) Bootstrap indices + `llm_records`

`llm_records` must be a list of dicts, each with at least the keys `idx`, `label`, `confidence`, and `justification`.

In [None]:
# Directory that receives every file written by this notebook.
OUTDIR = "./"
os.makedirs(OUTDIR, exist_ok=True)

# Location of the raw LLM replies; used both to save and to reload them.
LLM_RECORDS_PATH = os.path.join(OUTDIR, "llm_records.json")



def balanced_indices(n_per_class: int = 20, seed: int = 42):
    """Sample up to ``n_per_class`` cell indices from every class, without replacement.

    Classes are visited in ``ALL_CLASSES`` order with a single generator seeded
    by ``seed``; a class with fewer cells than requested contributes all of
    them. Reads ``adata`` and ``ALL_CLASSES`` from the notebook globals.

    Returns an integer array of positional cell indices.
    """
    rng = np.random.default_rng(seed)
    labels = adata.obs["Cell_class"].astype(str).values
    picked = []
    for cls_name in ALL_CLASSES:
        members = np.where(labels == cls_name)[0]
        n_take = min(len(members), n_per_class)
        if n_take <= 0:
            continue
        picked.extend(rng.choice(members, n_take, replace=False))
    return np.array(picked, dtype=int)

# Class-balanced bootstrap sample: at most 20 cells per class, fixed seed.
boot_idx = balanced_indices(n_per_class=20, seed=42)
print("Boot cells:", len(boot_idx))


# Option A: generate llm_records by calling the LLM (run once) 
def _parse_llm_json(raw: str):
    """Decode the JSON payload of an LLM reply, tolerating fences and surrounding prose."""
    # Strip Markdown code fences regardless of the case of the language tag.
    txt = re.sub(r"```(?:json)?", "", raw, flags=re.IGNORECASE).strip()
    try:
        return json.loads(txt)
    except json.JSONDecodeError as err:
        first_error = err
    # The reply is not pure JSON: decode the first parsable object embedded in it.
    decoder = json.JSONDecoder()
    for match in re.finditer(r"\{", txt):
        try:
            value, _ = decoder.raw_decode(txt, match.start())
        except json.JSONDecodeError:
            continue
        return value
    raise ValueError(f"No JSON object found in LLM response: {raw[:200]!r}") from first_error


def call_llm_once(prompt: str, agent) -> Dict:
    """Run one prompt through ``agent`` and return the parsed JSON reply.

    A single-task, single-agent crew is executed sequentially and the raw text
    of its result is decoded. Code fences and prose around the JSON are
    tolerated; when the reply is a list of dicts, its first element is used.

    NOTE(review): ``Task``, ``Crew`` and ``Process`` are not imported in the
    visible notebook; they must be imported before this function is called.

    Parameters
    ----------
    prompt : str
        Task description sent to the agent.
    agent
        Agent that executes the task.

    Returns
    -------
    dict
        Parsed reply; guaranteed to contain the key ``'label'``.

    Raises
    ------
    ValueError
        If the reply is empty, contains no decodable JSON, or lacks ``'label'``.
    """
    task = Task(description=prompt, expected_output="JSON with keys 'label','confidence','justification'", agent=agent)
    result = Crew(agents=[agent], tasks=[task], process=Process.sequential).kickoff()
    raw = getattr(result, "raw", None)
    if not raw or not isinstance(raw, str) or not raw.strip():
        raise ValueError("Invalid response from LLM call.")
    js = _parse_llm_json(raw)
    if isinstance(js, list) and js and isinstance(js[0], dict):
        js = js[0]
    if not isinstance(js, dict) or "label" not in js:
        raise ValueError(f"LLM JSON missing 'label': {js}")
    return js

# Records are produced once and then reused. Saved records are loaded when
# present (Option B); the LLM is called only when no file exists AND generation
# is switched on explicitly (Option A), so the notebook runs end to end without
# LLM access. To regenerate, delete the saved file and enable the flag.
RUN_LLM_BOOTSTRAP = False  # set to True to generate llm_records with the LLM

if os.path.exists(LLM_RECORDS_PATH):
    # Option B: load existing records
    with open(LLM_RECORDS_PATH) as f:
        llm_records = json.load(f)
    print(f"Loaded llm_records: {len(llm_records)} from {LLM_RECORDS_PATH}")
elif RUN_LLM_BOOTSTRAP:
    # Option A: generate llm_records by calling the LLM (run once).
    # NOTE: build_prompt uses `_hk_filter_to_panel` and `panel`, which are
    # defined in section 4 — run that cell before enabling this branch.
    inf_agent, _ = configure_agents()
    llm_records = []
    for i in boot_idx:
        prompt = build_prompt(int(i), ALL_CLASSES, k_markers_per_class=10)
        js = call_llm_once(prompt, inf_agent)
        llm_records.append({"idx": int(i), **js})

    with open(LLM_RECORDS_PATH, "w") as f:
        json.dump(llm_records, f, indent=2)
    print(f"Saved llm_records → {LLM_RECORDS_PATH}")
else:
    llm_records = []
    print("No llm_records.json found. Set RUN_LLM_BOOTSTRAP = True to generate it.")


## 4) Build raw per-class LLM gene sets from justifications

We extract gene mentions from the LLM `justification` field, filter to the gene panel, and drop housekeeping genes.  
We then **bucket** these genes by the **true class** of each bootstrap cell (`use_true_class=True`) since the bootstrap set is labeled.

In [None]:
# Gene-symbol prefixes excluded as housekeeping (Rpl/Rps/Mrpl/Mrps and mt-/Mt-);
# the pattern is anchored at the start of the symbol.
HOUSEKEEPING_RE = r'^(Rpl|Rps|Mrpl|Mrps|mt\-|Mt\-)'
# Set of all gene names present in adata, used for membership filtering.
panel = set(map(str, adata.var_names))

def _hk_filter_to_panel(gs, panel):
    """Restrict ``gs`` to panel genes, drop housekeeping genes, and de-duplicate.

    Every entry is converted to ``str``; entries not in ``panel`` or matching
    ``HOUSEKEEPING_RE`` are discarded. The first occurrence of each remaining
    gene is kept, in input order.
    """
    kept = {}
    for gene in map(str, gs):
        if gene not in panel:
            continue
        if re.match(HOUSEKEEPING_RE, gene):
            continue
        # dict keys preserve insertion order and ignore repeats
        kept.setdefault(gene, None)
    return list(kept)


def extract_gene_mentions(text: str) -> List[str]:
    """Return the panel genes mentioned in ``text``, in order of first mention.

    Tokens of two or more letters, digits or hyphens are compared
    case-insensitively with ``adata.var_names``; matches are reported with the
    spelling used in ``adata.var_names``. Non-string or blank input yields an
    empty list.
    """
    if not isinstance(text, str) or not text.strip():
        return []

    # upper-cased symbol -> spelling used in the panel
    canonical = {}
    for name in map(str, adata.var_names):
        canonical[name.upper()] = name

    tokens = re.findall(r"\b[A-Za-z0-9\-]{2,}\b", text)
    hits = (canonical.get(tok.upper()) for tok in tokens)
    # dict.fromkeys removes repeats while keeping first-seen order
    return list(dict.fromkeys(hit for hit in hits if hit is not None))

def llm_gene_sets_from_records_frozen(
    llm_records: List[dict],
    use_true_class: bool = True,
    min_genes: int = 3,
    topup_from_markers: bool = True,
) -> Dict[str, set]:
    """Collect, per class, the panel genes mentioned in LLM justifications.

    Parameters
    ----------
    llm_records
        Records with a ``justification`` text and, when ``use_true_class`` is
        set, an ``idx`` giving the cell's position in ``adata``.
    use_true_class
        If True, genes are assigned to the annotated class of cell ``idx``;
        otherwise to the record's predicted ``label``.
    min_genes
        Minimum set size aimed for when topping up from markers.
    topup_from_markers
        If True, classes with fewer than ``min_genes`` genes are filled with
        their top DE markers that are not already in the set.

    Returns
    -------
    dict
        ``{class: set of gene names}`` with one entry per class in ``ALL_CLASSES``.
    """
    class2genes = {ct: set() for ct in ALL_CLASSES}
    for r in llm_records:
        genes = extract_gene_mentions(r.get("justification", ""))
        genes = _hk_filter_to_panel(genes, panel)
        if not genes:
            continue
        if use_true_class:
            ct = str(adata.obs["Cell_class"].iloc[int(r["idx"])])
        else:
            ct = str(r.get("label", "None"))
        # Labels outside the known classes are ignored.
        if ct in class2genes:
            class2genes[ct].update(genes)

    if topup_from_markers:
        for ct, have in class2genes.items():
            if len(have) >= min_genes:
                continue
            # Add markers one at a time so that markers already in the set do
            # not use up the quota; a plain extras[:need] slice could leave the
            # class below min_genes.
            for marker in _hk_filter_to_panel(MARKER_DICT.get(ct, []), panel):
                if len(have) >= min_genes:
                    break
                have.add(marker)
    return class2genes

# Raw per-class gene sets: genes are bucketed by each bootstrap cell's
# annotated class, and small sets are topped up from the DE markers.
llm_sets_raw = llm_gene_sets_from_records_frozen(
    llm_records,
    use_true_class=True,
    min_genes=3,
    topup_from_markers=True
)



## 5) Refine gene sets on TRAIN only and compute weights

In [None]:
# Below this many retained genes, build_refined_llm_sets switches to its fallback ranking.
MIN_GENES = 3

def _cohens_d_1gene(pos, neg):
    """Cohen's d of ``pos`` versus ``neg`` using the pooled standard deviation.

    Returns NaN when either group has fewer than two values, or when the
    pooled variance is not finite or not strictly positive.
    """
    group_a = np.asarray(pos, float)
    group_b = np.asarray(neg, float)
    n_a, n_b = len(group_a), len(group_b)
    if min(n_a, n_b) < 2:
        return np.nan

    dof = n_a + n_b - 2
    if dof > 0:
        var_a = group_a.var(ddof=1)
        var_b = group_b.var(ddof=1)
        pooled_var = ((n_a - 1) * var_a + (n_b - 1) * var_b) / dof
    else:
        pooled_var = np.nan

    if np.isfinite(pooled_var) and pooled_var > 0:
        return (group_a.mean() - group_b.mean()) / math.sqrt(pooled_var)
    return np.nan

def _per_gene_stats(X, y, gene_names, cls, genes, train_idx):
    """Compute ``(mean difference, Cohen's d)`` per gene for ``cls`` versus the rest.

    Only rows listed in ``train_idx`` are used. Genes missing from
    ``gene_names`` are skipped. The mean difference is NaN when either group
    is empty.

    Returns a dict ``{gene: (diff, d)}``.
    """
    column_of = {name: col for col, name in enumerate(gene_names)}

    # Split the training rows into the target class and everything else.
    in_class, out_class = [], []
    for row in train_idx:
        if y[row] == cls:
            in_class.append(row)
        else:
            out_class.append(row)

    stats = {}
    for gene in genes:
        col = column_of.get(gene)
        if col is None:
            continue
        pos = X[in_class, col]
        neg = X[out_class, col]
        if len(pos) and len(neg):
            diff = float(np.mean(pos) - np.mean(neg))
        else:
            diff = np.nan
        stats[gene] = (diff, _cohens_d_1gene(pos, neg))
    return stats

def build_refined_llm_sets(adata, llm_sets_raw, marker_dict, train_idx,
                           target_N=18, min_gene_d=0.10, keep_if_pos_only=True):
    """Refine raw per-class gene sets on training cells and derive gene weights.

    For every class the raw genes are filtered to the panel, scored on the
    rows in ``train_idx`` (mean difference and Cohen's d versus all other
    classes), and kept when the mean difference is finite, positive (if
    ``keep_if_pos_only``) and Cohen's d is not below ``min_gene_d``. If fewer
    than ``MIN_GENES`` genes survive, the raw genes are ranked by (difference,
    d) and the top ones are used instead. The set is then cut to ``target_N``
    genes, strongest first, and topped up from ``marker_dict`` (weight 0.05)
    when shorter. Weights of the final set are rescaled to average 1.0, or all
    set to 1.0 when they sum to zero.

    Parameters
    ----------
    adata
        Annotated data with ``X``, ``obs["Cell_class"]`` and ``var_names``.
    llm_sets_raw : dict
        ``{class: iterable of gene names}``.
    marker_dict : dict
        ``{class: ranked list of marker genes}`` used for topping up.
    train_idx
        Row indices used for all statistics.
    target_N : int
        Maximum (and top-up target) size of each refined set.
    min_gene_d : float
        Minimum Cohen's d for a gene to be kept.
    keep_if_pos_only : bool
        If True, genes with a non-positive mean difference are dropped.

    Returns
    -------
    (dict, dict)
        ``refined_sets`` as ``{class: set of genes}`` and ``gene_weights`` as
        ``{class: {gene: weight}}``. Classes whose filtered raw set is empty
        are omitted from both.
    """
    X = adata.X.toarray() if hasattr(adata.X, "toarray") else np.asarray(adata.X)
    y = adata.obs["Cell_class"].astype(str).values
    gene_names = np.array(adata.var_names, dtype=str)
    panel = set(map(str, gene_names))
    missing = (np.nan, np.nan)

    def _finite(value, default):
        """Return ``value`` as float, or ``default`` when it is NaN or infinite."""
        return float(value) if np.isfinite(value) else default

    refined_sets, gene_weights = {}, {}
    for c, raw_genes in llm_sets_raw.items():
        # Sort first: raw sets iterate in hash-dependent order, which would
        # make the frozen output differ between interpreter runs.
        llm_list = _hk_filter_to_panel(sorted(map(str, raw_genes)), panel)
        if not llm_list:
            continue

        stats = _per_gene_stats(X, y, gene_names, c, llm_list, train_idx)

        keep, wts = [], {}
        for gene in llm_list:
            diff, d = stats.get(gene, missing)
            if not np.isfinite(diff):
                continue
            if keep_if_pos_only and diff <= 0:
                continue
            if np.isfinite(d) and d < min_gene_d:
                continue
            keep.append(gene)
            wts[gene] = max(_finite(d, 0.0), 0.0)

        if len(keep) < MIN_GENES:
            # Too few genes passed: rank all raw genes by (difference, d),
            # pushing undefined statistics to the end instead of letting NaN
            # disturb the sort.
            def _fallback_key(g):
                g_diff, g_d = stats.get(g, missing)
                return (-_finite(g_diff, -np.inf), -_finite(g_d, -np.inf), g)

            ranked = sorted(llm_list, key=_fallback_key)
            keep = ranked[:max(MIN_GENES, min(target_N, len(ranked)))]
            wts = {g: max(_finite(stats.get(g, missing)[1], 0.0), 0.0) for g in keep}
        else:
            # Strongest genes first, so the cut to target_N drops the weakest;
            # the name breaks ties deterministically.
            keep.sort(key=lambda g: (-wts[g], g))

        keep = keep[:target_N]
        wts = {g: wts[g] for g in keep}

        # Top up short sets with ranked DE markers at a small fixed weight.
        present = set(keep)
        for mg in _hk_filter_to_panel(marker_dict.get(c, []), panel):
            if len(keep) >= target_N:
                break
            if mg not in present:
                keep.append(mg)
                present.add(mg)
                wts[mg] = 0.05

        # Normalise over the genes actually returned so their weights average 1.0.
        vals = np.array([wts[g] for g in keep], float)
        total = vals.sum()
        if np.isfinite(vals).all() and total > 0:
            scale = float(len(vals) / total)
            wts = {g: w * scale for g, w in wts.items()}
        else:
            wts = {g: 1.0 for g in keep}

        refined_sets[c] = set(keep)
        gene_weights[c] = {g: float(wts[g]) for g in keep}
    return refined_sets, gene_weights


# NOTE(review): StratifiedKFold is not imported anywhere in the visible notebook;
# presumably it comes from sklearn.model_selection — confirm and add the import.
# Only the first of the five stratified folds is used, giving a single train/test split.
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
train_idx, test_idx = next(skf.split(adata.X, adata.obs["Cell_class"]))


# NOTE(review): MARKER_DICT is built with top_n=K_MARKERS (20), so the [:200]
# slice cannot yield more than 20 genes per class — confirm the intended depth.
marker_ranked = {c: MARKER_DICT.get(c, [])[:200] for c in ALL_CLASSES}

# Refine the raw LLM sets using statistics from the training indices only.
ref_llm_sets, ref_llm_w = build_refined_llm_sets(
    adata=adata,
    llm_sets_raw=llm_sets_raw,
    marker_dict=marker_ranked,
    train_idx=train_idx,
    target_N=18,
    min_gene_d=0.10,
)

## 6) Freeze final priors and build a matched marker baseline

In [None]:
# Same output directory as in section 3, repeated so this cell does not depend on it.
OUTDIR = "./"
os.makedirs(OUTDIR, exist_ok=True)

def sets_to_jsonable(d):
    """Convert ``{key: iterable}`` to ``{str(key): sorted list of str}`` for JSON output."""
    jsonable = {}
    for key, members in d.items():
        jsonable[str(key)] = sorted(str(member) for member in members)
    return jsonable

def weights_to_jsonable(d):
    """Convert ``{class: {gene: weight}}`` to plain ``str`` keys and ``float`` values."""
    jsonable = {}
    for cls_name, gene_map in d.items():
        row = {}
        for gene, weight in gene_map.items():
            row[str(gene)] = float(weight)
        jsonable[str(cls_name)] = row
    return jsonable

def freeze_llm_top(ref_llm_sets, ref_llm_w):
    """Return a fresh ``{class: set of genes}`` copy of ``ref_llm_sets``.

    Genes are ranked by their weight in ``ref_llm_w`` (a missing weight counts
    as 0.0) before being stored, but storing them as a ``set`` discards that
    ranking, so the result holds the same genes as the input.
    """
    frozen = {}
    for cls_name, genes in ref_llm_sets.items():
        cls_weights = ref_llm_w.get(cls_name, {})
        ranked = sorted(genes, key=lambda g: cls_weights.get(g, 0.0), reverse=True)
        frozen[cls_name] = set(ranked)
    return frozen

def make_marker_baseline(ref_llm_top, marker_ranked):
    """Build a DE-marker baseline matched in size to each LLM gene set.

    For every class in ``ref_llm_top`` the ranked markers are filtered to the
    panel (housekeeping genes removed) and cut to the size of the class's LLM
    set. Each chosen marker gets the weight ``1 / rank`` (rank starting at 1).

    Returns ``(marker_top, marker_w)``: ``{class: set of genes}`` and
    ``{class: {gene: weight}}``.
    """
    top_sets, rank_weights = {}, {}
    for cls_name, llm_genes in ref_llm_top.items():
        candidates = _hk_filter_to_panel(marker_ranked.get(cls_name, []), panel)
        chosen = candidates[:len(llm_genes)]
        top_sets[cls_name] = set(chosen)
        rank_weights[cls_name] = {
            gene: 1.0 / rank for rank, gene in enumerate(chosen, start=1)
        }
    return top_sets, rank_weights

# Final LLM gene sets, and a DE-marker baseline with the same number of genes per class.
ref_llm_top = freeze_llm_top(ref_llm_sets, ref_llm_w)
marker_top, marker_w = make_marker_baseline(ref_llm_top, marker_ranked)



## 7) Save priors (for Welch t-test evaluation)


In [None]:
def sets_to_jsonable(d):
    """Return a JSON-ready copy of ``d``: string keys mapped to sorted lists of strings."""
    keys = [str(key) for key in d]
    values = [sorted(str(item) for item in d[key]) for key in d]
    return dict(zip(keys, values))

def weights_to_jsonable(d):
    """Return a JSON-ready copy of ``{class: {gene: weight}}`` with str keys and float weights."""
    def _as_floats(gene_map):
        return dict((str(gene), float(weight)) for gene, weight in gene_map.items())

    return dict((str(cls_name), _as_floats(gene_map)) for cls_name, gene_map in d.items())

def _write_json(filename, payload):
    """Write ``payload`` to ``OUTDIR/filename`` as JSON with two-space indentation."""
    with open(os.path.join(OUTDIR, filename), "w") as f:
        json.dump(payload, f, indent=2)


# Raw LLM replies.
_write_json("llm_records.json", llm_records)

# LLM-derived gene sets and their weights.
_write_json("ref_llm_sets.json", sets_to_jsonable(ref_llm_sets))
_write_json("ref_llm_top.json", sets_to_jsonable(ref_llm_top))
_write_json("ref_llm_w.json", weights_to_jsonable(ref_llm_w))

# Size-matched DE-marker baseline.
_write_json("marker_top.json", sets_to_jsonable(marker_top))
_write_json("marker_w.json", weights_to_jsonable(marker_w))

# Train/test indices as plain Python ints.
_write_json("split.json", {"train_idx": list(map(int, train_idx)),
                           "test_idx":  list(map(int, test_idx))})

print("Saved frozen priors to:", OUTDIR)
