In [6]:
import numpy as np
import scanpy as sc
from scripts.EGGFM.eggfm import run_eggfm_dimred
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score

In [7]:
def subset_anndata(ad: sc.AnnData, n_cells: int, random_state: int = 0) -> sc.AnnData:
    rng = np.random.default_rng(random_state)
    n = ad.n_obs
    n_subset = min(n_cells, n)
    idx = rng.choice(np.arange(n), size=n_subset, replace=False)
    return ad[idx].copy()

def compute_ari_fixed(X, labels, k, random_state: int = 0) -> float:
    Xk = X[:, :k]
    km = KMeans(
        n_clusters=len(np.unique(labels)),
        n_init=10,
        random_state=random_state,
    )
    km.fit(Xk)
    return adjusted_rand_score(labels, km.labels_)

In [8]:
params = {
    "seed": 7,
    "pca_n_top_genes": 2000,

    "spec": {
        "n_pcs": 20,
        "dcol_max_cells": 3000,
        "ari_label_key": "Cell type annotation",  # <-- this must match an obs column name
        # "ari_label_key": "paul15_clusters",         # <-- this must match an obs column name
        "ari_n_dims": 10,                           # how many dims to use for ARI per embedding
        # "ad_file": "data/paul15/paul15.h5ad",
        "ad_file": "data/prep/qc.h5ad",
    },

    "qc": {
        "min_cells": 500,
        "min_genes": 200,
        "max_pct_mt": 15,
    },

    "eggfm_model": {
        "hidden_dims": [512, 512, 512, 512],
        "latent_dim": 64,
    },

    "eggfm_train": {
        "batch_size": 2048,
        "num_epochs": 150,
        "lr": 4.0e-4,
        "sigma": 1.0,
        "device": "cuda",
        "latent_space": "hvg",
        "early_stop_patience": 30,
        "early_stop_min_delta": 0.0,
        "n_cells_sample":20000
    },

    "eggfm_diffmap": {
        "geometry_source": "pca",          # "pca" or "hvg"
        "energy_source": "hvg",            # where SCM/Hessian read energies
        "metric_mode": "euclidean",    # "euclidean", "scm", or "hessian_mixed"
        "n_neighbors": 30,
        "n_comps": 30,
        "device": "cuda",
        "hvp_batch_size": 1024,
        "eps_mode": "median",
        "eps_value": 1.0,
        "eps_trunc": "yes",
        "distance_power": 1.0,
        "t": 0.5,
        "norm_type": "l1",

        # SCM hypered
        "metric_gamma": 0.4,
        "metric_lambda": 4.0,
        "energy_clip_abs": 2.0,
        "energy_batch_size": 2048,

        # Hessian mixing hyperparams
        "hessian_mix_mode": "none",   # "additive" | "multiplicative" | "none"
        "hessian_mix_alpha": 0.3,
        "hessian_beta": 0.2,
        "hessian_clip_std": 2.0,
        "hessian_use_neg": True,
    },
}

spec = params["spec"]
k = spec.get("ari_n_dims", spec.get("n_pcs", 10))

base_ad = sc.read_h5ad(spec.get("ad_file"))


scores_eggfm = []
scores_eggfm_2 = []
scores_eggfm_3 = []
scores_eggfm_4 = []
scores_eggfm_5 = []
# scores_eggfm_6 = []
scores_pca = []
scores_pca_2 = []
total = 1
for run in range(total):
    run_seed = 0 + run
    ad_prep = subset_anndata(base_ad, params["eggfm_train"]["n_cells_sample"], run_seed)
    labels = ad_prep.obs[spec["ari_label_key"]].to_numpy()
    
    qc = ad_prep.copy()

    print(f"=== Run {run+1}/{total} ===")
    qc, _ = run_eggfm_dimred(qc, params)        

    # PCA → Diffmap
    sc.pp.neighbors(qc, n_neighbors=30, use_rep="X_pca")
    sc.tl.diffmap(qc, n_comps=k)
    X_diff_pca = qc.obsm["X_diffmap"][:, :k]
    qc.obsm["X_diff_pca"] = X_diff_pca

    # # PCA → Diffmap → Diffmap
    # sc.pp.neighbors(qc, n_neighbors=30, use_rep="X_diff_pca")
    # sc.tl.diffmap(qc, n_comps=k)
    # X_diff_pca_double = qc.obsm["X_diffmap"][:, :k]
    # qc.obsm["X_diff_pca_x2"] = X_diff_pca_double

    # EGGFM
    X_eggfm = qc.obsm["X_eggfm"][:, :k]

    # EGGFM DM
    sc.pp.neighbors(qc, n_neighbors=30, use_rep="X_eggfm")
    sc.tl.diffmap(qc, n_comps=k)
    X_diff_eggfm = qc.obsm["X_diffmap"][:, :k]
    qc.obsm["X_diff_eggfm"] = X_diff_eggfm

    # One more
    
    # EGGFM DM DM
    sc.pp.neighbors(qc, n_neighbors=30, use_rep="X_diff_eggfm")
    sc.tl.diffmap(qc, n_comps=k)
    X_diff_eggfm_x2 = qc.obsm["X_diffmap"][:, :k]
    qc.obsm["X_diff_eggfm_x2"] = X_diff_eggfm_x2

    # EGGFM DM DM DM
    # sc.pp.neighbors(qc, n_neighbors=30, use_rep="X_diff_eggfm_x2")
    # sc.tl.diffmap(qc, n_comps=k)
    # X_diff_eggm_x3 = qc.obsm["X_diffmap"][:, :k]
    # qc.obsm["X_diff_eggm_x3"] = X_diff_eggm_x3

    # # EGGFM DM DM DM DM
    # sc.pp.neighbors(qc, n_neighbors=30, use_rep="X_diff_eggm_x3")
    # sc.tl.diffmap(qc, n_comps=k)
    # X_diff_eggm_x4 = qc.obsm["X_diffmap"][:, :k]
    # qc.obsm["X_diff_eggm_x4"] = X_diff_eggm_x4

    scores_pca.append(compute_ari_fixed(X_diff_pca, labels, k))
    # scores_pca_2.append(compute_ari(X_diff_pca_double, labels, k))
    scores_eggfm.append(compute_ari_fixed(X_eggfm, labels, k))
    scores_eggfm_2.append(compute_ari_fixed(X_diff_eggfm, labels, k))
    scores_eggfm_3.append(compute_ari_fixed(X_diff_eggfm_x2, labels, k))
    # scores_eggfm_4.append(compute_ari_fixed(X_diff_eggm_x3, labels, k))
    # scores_eggfm_5.append(compute_ari_fixed(X_diff_eggm_x4, labels, k))

print("\n=== Variance results ===")
print(f"PCA→DM:    mean={np.mean(scores_pca):.4f}, std={np.std(scores_pca):.4f}")
# print(
#     f"PCA→DM2:   mean={np.mean(scores_pca_2):.4f}, std={np.std(scores_pca_2):.4f}"
# )
print(
    f"EGGFM:     mean={np.mean(scores_eggfm):.4f}, std={np.std(scores_eggfm):.4f}"
)
print(
    f"EGGFM DM:  mean={np.mean(scores_eggfm_2):.4f}, std={np.std(scores_eggfm_2):.4f}"
)
print(
    f"EGGFM DM2: mean={np.mean(scores_eggfm_3):.4f}, std={np.std(scores_eggfm_3):.4f}"
)
# print(
#     f"EGGFM DM3: mean={np.mean(scores_eggfm_4):.4f}, std={np.std(scores_eggfm_4):.4f}"
# )
# print(
#     f"EGGFM DM4: mean={np.mean(scores_eggfm_5):.4f}, std={np.std(scores_eggfm_5):.4f}")

  utils.warn_names_duplicates("obs")


=== Run 1/1 ===
[Energy DSM] Epoch 1/150  loss=1.915038e+03
[Energy DSM] Epoch 2/150  loss=1.848374e+03
[Energy DSM] Epoch 3/150  loss=1.844114e+03
[Energy DSM] Epoch 4/150  loss=1.844109e+03
[Energy DSM] Epoch 5/150  loss=1.844538e+03
[Energy DSM] Epoch 6/150  loss=1.843519e+03
[Energy DSM] Epoch 7/150  loss=1.843196e+03
[Energy DSM] Epoch 8/150  loss=1.842886e+03
[Energy DSM] Epoch 9/150  loss=1.842082e+03
[Energy DSM] Epoch 10/150  loss=1.842751e+03
[Energy DSM] Epoch 11/150  loss=1.842343e+03
[Energy DSM] Epoch 12/150  loss=1.840866e+03
[Energy DSM] Epoch 13/150  loss=1.838538e+03
[Energy DSM] Epoch 14/150  loss=1.835809e+03
[Energy DSM] Epoch 15/150  loss=1.829278e+03
[Energy DSM] Epoch 16/150  loss=1.820722e+03
[Energy DSM] Epoch 17/150  loss=1.813543e+03
[Energy DSM] Epoch 18/150  loss=1.805895e+03
[Energy DSM] Epoch 19/150  loss=1.799928e+03
[Energy DSM] Epoch 20/150  loss=1.794583e+03
[Energy DSM] Epoch 21/150  loss=1.789633e+03
[Energy DSM] Epoch 22/150  loss=1.785761e+03
[En

In [9]:
from datetime import datetime
import os
import pandas as pd

# ---- build the result row (config + score summaries) ----

all_results = []
ed = params["eggfm_diffmap"]
ari_label_key = params["spec"]["ari_label_key"]

row = {
    # break out the ARI label explicitly
    "ari_label_key": ari_label_key,

    # all diffmap hyperparams become their own columns
    **ed,

    # EGGFM score summaries (rounded for nicer CSV)
    "ari_pca_mean": round(float(np.mean(scores_pca)), 4),
    "ari_eggfm_mean": round(float(np.mean(scores_eggfm)), 4),
    "ari_eggfm_std": round(float(np.std(scores_eggfm)), 4),
    "ari_eggfm_dm_mean": round(float(np.mean(scores_eggfm_2)), 4),
    "ari_eggfm_dm_std": round(float(np.std(scores_eggfm_2)), 4),
    "ari_eggfm_dm2_mean": round(float(np.mean(scores_eggfm_3)), 4),
    "ari_eggfm_dm2_std": round(float(np.std(scores_eggfm_3)), 4),
    "ari_eggfm_dm3_mean": round(float(np.mean(scores_eggfm_4)), 4),
    "ari_eggfm_dm3_std": round(float(np.std(scores_eggfm_4)), 4),
    "ari_eggfm_dm4_mean": round(float(np.mean(scores_eggfm_5)), 4),
    "ari_eggfm_dm4_std": round(float(np.std(scores_eggfm_5)), 4),
}

all_results.append(row)

results_df = pd.DataFrame(all_results)

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
results_path = f"out/check_var/eggfm_admr_layered_ablation_subset_{timestamp}.csv"
results_df.to_csv(results_path, index=False)

gcs_path = f"gs://medit-uml-prod-uscentral1-8e7a/{results_path}"
os.system(f"gsutil cp {results_path} {gcs_path}")
print("Uploaded to:", gcs_path)

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
  arrmean = um.true_divide(arrmean, div, out=arrmean,
  ret = ret.dtype.type(ret / rcount)
Copying file://out/check_var/eggfm_admr_layered_ablation_subset_20251130_173400.csv [Content-Type=text/csv]...
/ [1 files][  647.0 B/  647.0 B]                                                
Operation completed over 1 objects/647.0 B.                                      


Uploaded to: gs://medit-uml-prod-uscentral1-8e7a/out/check_var/eggfm_admr_layered_ablation_subset_20251130_173400.csv


In [10]:
# Final focused hyperparam sweep

pattern_grid = {
    # Euclidean baseline: test both l0 and linf, since linf was a star on Paul15
    "eucl_only": {
        "n_layers": [1, 2, 3],
        "norm_types": ["l0", "linf"],
        "distance_powers": [0.0, 0.25, 0.5],
    },
    # SCM: winner on Weinreb at L3, p ~ 0.25, norm l0
    "scm_alt_euclid": {
        "n_layers": [1, 2, 3],
        "norm_types": ["l0"],
        "distance_powers": [0.0, 0.25, 0.5],
    },
    # Hessian-mixed: keep as a single EGGFM competitor
    "hessMult_alt_euclid": {
        "n_layers": [1, 2, 3],
        "norm_types": ["l0"],
        "distance_powers": [0.25, 0.5],
    },
}

t_euclid_values = [2.0]

config_list = []
for pattern_type, grid in pattern_grid.items():
    for n_layers in grid["n_layers"]:
        for norm in grid["norm_types"]:
            for p in grid["distance_powers"]:
                for t_eucl in t_euclid_values:
                    exp_name = (
                        f"{pattern_type}_L{n_layers}_norm{norm}_p{p}_teucl{t_eucl}"
                    )
                    config_list.append(
                        dict(
                            exp_name=exp_name,
                            pattern_type=pattern_type,
                            n_layers=n_layers,
                            t_euclid=t_eucl,
                            norm_type=norm,
                            distance_power=p,
                        )
                    )

len(config_list)


33