In [None]:
import benchutils as bu

from tqdm import tqdm 
from benchdb import *
from itertools import product 
from benchplots import *
from benchutils import *
from benchmodels import RefCM
from benchdb import load_benchdb, save_benchdb

import sys
sys.path.append("../src/")

from refcm.embeddings import HVGEmbedder, PCAEmbedder, NMFEmbedder


%load_ext autoreload
%autoreload 2

In [None]:
# scIB pancreas datasets, annotated under the `celltype` column.
pancreas_names = [
    "pancreas_celseq",
    "pancreas_celseq2",
    "pancreas_fluidigmc1",
    "pancreas_inDrop1",
    "pancreas_inDrop2",
    "pancreas_inDrop3",
    "pancreas_inDrop4",
    "pancreas_smarter",
    "pancreas_smartseq2",
]

dss = {}
for name in tqdm(pancreas_names):
    dss[name] = load_adata(f"../data/{name}.h5ad")

# every ordered (query, reference) pair of distinct datasets
combs = [(q, r) for q in dss for r in dss if q != r]
key = 'celltype'

In [None]:
# brain datasets, annotated under the `labels34` column.
# NOTE: this cell rebinds `dss`, `combs` and `key`, replacing whatever the pancreas
# cell loaded — run only one of the two before the ablation cells.
brain_names = ["ALM", "MTG", "VISp"]

dss = {}
for name in tqdm(brain_names):
    dss[name] = load_adata(f"../data/{name}.h5ad")

# every ordered (query, reference) pair of distinct datasets
combs = [(q, r) for q in dss for r in dss if q != r]
key = 'labels34'


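## Embeddings

Ablation over the embedding method (HVG, PCA, NMF) and its dimensionality `k`, run on every ordered (query, reference) pair.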

In [None]:
fpath = 'ablation.json'
db = load_benchdb(fpath)

m = RefCM()

kwargs = {"RefCM": {"discovery_threshold": 0.0, "verbose": False}}
refcm_cfg = kwargs['RefCM']  # alias: writes below land directly in kwargs

# embedder class per method, and the dimensionalities k to sweep for each
embedders = {"HVG": HVGEmbedder, "PCA": PCAEmbedder, "NMF": NMFEmbedder}
ks = {
    "HVG": [10, 20, 50, 100, 200, 500, 1000, 2000, 3000, 5000],
    "PCA": [10, 20, 50, 100, 200, 500],
    "NMF": [10, 20, 50, 100],
}

for e, embedder in embedders.items():
    for k in ks[e]:
        run_name = f"{e}-{k}"

        for q, r in combs:
            # a new embedder instance per pair — presumably so no fitted state
            # carries over between pairs
            refcm_cfg['embedder'] = embedder(k)

            m.setref(dss[r], key, **kwargs)
            result = m.annotate(dss[q], key, **kwargs)
            result.eval_(dss[q], key)

            # NOTE(review): this cell reports `cacc` while the other ablation cells
            # report `acc` — confirm which metric is intended here.
            print(f"{e:<3} | {k:<5} : {q:>20} | {r:<20} : {result.cacc:>5.3f} | {m.time}")

            add_bench(db, run_name, q, r, result)

        # checkpoint the results after each (embedder, k) setting
        save_benchdb(db, fpath)
        

## Ground cost

Ablation over the pairwise-distance function (`pdist`): euclidean, manhattan, cosine and inner.

In [None]:
fpath = 'ablation.json'
db = load_benchdb(fpath)

m = RefCM()

kwargs = {"RefCM": {"discovery_threshold": 0.0, "verbose": False}}
refcm_cfg = kwargs['RefCM']  # alias: writes below land directly in kwargs
pdists = ["euclidean", "manhattan", "cosine", "inner"]

for pdist in pdists:
    for q, r in combs:
        # (re)set on every pair, as before, in case RefCM touches its kwargs
        refcm_cfg['pdist'] = pdist

        m.setref(dss[r], key, **kwargs)
        result = m.annotate(dss[q], key, **kwargs)
        result.eval_(dss[q], key)

        print(f"{pdist:<15} : {q:>10} | {r:<10} : {result.acc:>5.3f} | {m.time}")
        add_bench(db, pdist, q, r, result)

    # checkpoint the results after each distance function
    save_benchdb(db, fpath)

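## OT solver

Solvers `emd` and `gw` (run with `reg=0`), and `sink` at several `reg` values; then `uot` with different per-cell marginal-weighting functions.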

In [None]:
fpath = 'ablation.json'
db = load_benchdb(fpath)

m = RefCM()

kwargs = {"RefCM": {"discovery_threshold": 0.0, "verbose": False}}
refcm_cfg = kwargs['RefCM']  # alias: writes below land directly in kwargs

# paired element-wise: emd and gw at reg=0, then five sink regularisations
reg_values = [0, 0, 0.01, 0.05, 0.1, 0.5, 1.0]
ot_solvers = ['emd', 'gw'] + ['sink'] * 5

# these solvers are run at reg=0, so their bench label omits the reg suffix
reg_free_solvers = ('emd', 'gw')

for solver, reg in zip(ot_solvers, reg_values):
    run_name = solver if solver in reg_free_solvers else f"{solver}-{reg}"

    for q, r in combs:
        refcm_cfg['ot_solver'] = solver
        refcm_cfg['reg'] = reg

        m.setref(dss[r], key, **kwargs)
        result = m.annotate(dss[q], key, **kwargs)
        result.eval_(dss[q], key)

        print(f"{solver}-{reg} : {q:>10} | {r:<10} : {result.acc:>5.3f} | {m.time}")
        add_bench(db, run_name, q, r, result)

    # checkpoint the results after each solver setting
    save_benchdb(db, fpath)

In [None]:
import numpy as np
from sklearn.neighbors import NearestNeighbors

def uot_uniform(x_q, x_r):
    """Uniform marginals: every cell on each side gets weight 1/n, so each side sums to 1.

    Returns a tuple (w_q, w_r) of 1-D arrays of lengths len(x_q) and len(x_r).
    """
    def _flat(X):
        n_cells = len(X)
        return np.ones(n_cells) / n_cells

    return _flat(x_q), _flat(x_r)

def uot_size_proportional(x_q, x_r):
    """Per-cell weight 1/sqrt(n_q * n_r) on both sides.

    The total mass therefore scales with dataset size: the query side sums to
    sqrt(n_q / n_r) and the reference side to sqrt(n_r / n_q). The two sides are
    not normalised to a common total.

    Returns a tuple (w_q, w_r) of 1-D arrays of lengths len(x_q) and len(x_r).
    """
    n_q, n_r = len(x_q), len(x_r)
    geom_mean = np.sqrt(n_q * n_r)

    def _side(n):
        # same arithmetic as before: (n / geom_mean) spread evenly over n cells
        return np.ones(n) * (n / geom_mean) / n

    return _side(n_q), _side(n_r)

def uot_centroid_weighted(x_q, x_r, temp=1.0):
    """Weight cells by proximity to their own dataset's centroid.

    For each of ``x_q`` and ``x_r`` independently:
    - compute every row's Euclidean distance to the row mean;
    - scale the distances by their standard deviation;
    - apply a softmax to ``-dist / (temp * std)``.

    Cells nearer the centroid get more mass, and each returned vector sums to 1.

    Parameters
    ----------
    x_q, x_r : array-like, indexed as (n_cells, n_features)
        Query / reference embeddings (assumed dense arrays — TODO confirm).
    temp : float
        Positive softmax temperature. Larger values flatten the weights
        towards uniform.

    Returns
    -------
    tuple of two 1-D arrays
        ``(w_q, w_r)`` with lengths ``len(x_q)`` and ``len(x_r)``.

    Raises
    ------
    ValueError
        If ``temp <= 0``. Zero yields inf/nan weights and negative values
        silently invert the weighting.
    """
    if temp <= 0:
        raise ValueError(f"temp must be positive, got {temp}")

    def weights(X):
        n_cells = len(X)
        if n_cells == 0:
            # previously crashed in scaled.max() on a zero-size array
            return np.empty(0)
        if n_cells == 1:
            return np.array([1.0])
        dists = np.linalg.norm(X - X.mean(axis=0), axis=1)
        std = dists.std()
        if std < 1e-10:
            # all cells equidistant from the centroid: fall back to uniform
            return np.ones(n_cells) / n_cells
        scaled = -dists / (temp * std)
        scaled -= scaled.max()  # shift so the max is 0: exp() cannot overflow
        w = np.exp(scaled)
        return w / w.sum()

    return weights(x_q), weights(x_r)

def uot_variance_weighted(x_q, x_r):
    """Weight each cell by the inverse of its across-feature variance, normalised to sum to 1.

    Row variance is taken along axis 1. The 1e-8 offset keeps constant rows
    finite, although such rows then dominate the mass.

    Returns a tuple (w_q, w_r) of 1-D arrays of lengths len(x_q) and len(x_r).
    """
    eps = 1e-8

    def inverse_variance(X):
        return 1 / (X.var(axis=1) + eps)

    raw_q = inverse_variance(x_q)
    raw_r = inverse_variance(x_r)
    return raw_q / raw_q.sum(), raw_r / raw_r.sum()

def uot_density_weighted(x_q, x_r, k=10):
    """Weight cells by local density, normalised to sum to 1.

    Density is the inverse of the mean distance to a cell's ``k`` nearest
    neighbours, with ``k`` capped at ``n_cells - 1``. When fewer than one
    neighbour is available (a single cell, or ``k < 1``), weights are uniform.

    Returns a tuple (w_q, w_r) of 1-D arrays of lengths len(x_q) and len(x_r).
    """
    def weights(X):
        n_cells = len(X)
        n_nbrs = min(k, n_cells - 1)
        if n_nbrs < 1:
            return np.ones(n_cells) / n_cells
        # query one extra neighbour: column 0 is the cell itself (distance 0)
        knn = NearestNeighbors(n_neighbors=n_nbrs + 1).fit(X)
        nbr_dists = knn.kneighbors(X)[0]
        mean_dist = nbr_dists[:, 1:].mean(axis=1)
        density = 1 / (mean_dist + 1e-8)
        return density / density.sum()

    return weights(x_q), weights(x_r)



# (bench run name, marginal-weight function) pairs for the unbalanced-OT ablation
uot_configs = [
    ("uot-uniform", uot_uniform),
    ("uot-size_prop", uot_size_proportional),
    ("uot-centroid", uot_centroid_weighted),
    ("uot-variance", uot_variance_weighted),
    ("uot-density", uot_density_weighted),
]


fpath = 'ablation.json'
db = load_benchdb(fpath)
m = RefCM()
kwargs = {"RefCM": {"discovery_threshold": 0.0, "verbose": False}}
refcm_cfg = kwargs['RefCM']  # alias: writes below land directly in kwargs

reg = 0.1  # one regularisation strength shared by every UOT config

for name, w_fn in uot_configs:
    # the weight function travels to RefCM through kwargs['RefCM'];
    # presumably it is called as w_fn(x_q, x_r) there — confirm in refcm
    refcm_cfg.update(ot_solver='uot', reg=reg, uot_w_fn=w_fn)

    for q, r in combs:
        m.setref(dss[r], key, **kwargs)
        result = m.annotate(dss[q], key, **kwargs)
        result.eval_(dss[q], key)

        print(f"{name} : {q:>10} | {r:<10} : {result.acc:>5.3f} | {m.time}")
        add_bench(db, name, q, r, result)

    # checkpoint the results after each weighting scheme
    save_benchdb(db, fpath)

## Plot

In [None]:
# Heatmap sizing in pixels:
#   cell_height   — fixed height per method row
#   margin_top    — room for the title / annotations
#   margin_bottom — room for the rotated x-axis labels
cell_height, margin_top, margin_bottom = 20, 60, 150


def pad_labels_to_width(labels: list[str]) -> list[str]:
    """Right-pad labels with non-breaking spaces so they all share the longest label's length.

    U+00A0 is used rather than a regular space, presumably so the plot renderer
    does not trim the padding — TODO confirm.

    Args:
        labels: Strings to pad. Any iterable is accepted; it is read exactly once.

    Returns:
        A new list of padded labels in the original order. An empty input
        returns ``[]`` (previously ``max()`` raised ValueError).
    """
    labels = list(labels)  # materialise so a generator isn't exhausted by the length pass
    width = max((len(label) for label in labels), default=0)
    return [label.ljust(width, "\u00A0") for label in labels]

def make_ablation_heatmap(df, datasets, methods, **kwargs):
    """Build the ablation heatmap with a height that grows with the number of methods.

    Wraps ``plot_heatmap_panel`` with these defaults:
    - "plasma" colorscale, no reordering by mean, no cell text;
    - 1200 px width;
    - height of ``cell_height`` px per method plus ``margin_top`` and
      ``margin_bottom``.

    Any of those defaults can be overridden through ``**kwargs``; before, doing
    so raised a duplicate-keyword TypeError. All other kwargs are forwarded
    unchanged.

    Method labels in each trace's ``y`` values are right-padded with
    non-breaking spaces to a common width. Every layout annotation is shifted
    up by 4 px.

    Returns:
        The figure returned by ``plot_heatmap_panel``, modified in place.
    """
    panel_kwargs = {
        "colorscale": "plasma",
        "reorder_by_mean": False,
        "show_text": False,
        "height": len(methods) * cell_height + margin_top + margin_bottom,
        "width": 1200,
    }
    panel_kwargs.update(kwargs)  # caller-supplied values win over the defaults

    fig = plot_heatmap_panel(df, datasets=datasets, methods=methods, **panel_kwargs)

    fig.update_yaxes(dtick=1)
    fig.update_xaxes(tickangle=90)

    # original -> padded label; built once (it does not depend on the trace)
    label_map = dict(zip(methods, pad_labels_to_width(methods)))
    for trace in fig.data:
        if trace.y is not None:
            trace.y = [label_map.get(y, y) for y in trace.y]

    for ann in fig.layout.annotations:
        ann.yshift = 4

    return fig

In [None]:
df = bench_to_df(db)

# Normalise method names to the hyphen-separated form used by the other runs.
method_renames = {'uot-size_prop': 'uot-size-prop'}
for old_name, new_name in method_renames.items():
    df.loc[df['method'] == old_name, 'method'] = new_name

In [None]:
import os

# Full ablation heatmap over both dataset groups; every method in `df`, in first-seen order.
fig = plot_heatmap_panel(
    df,
    datasets=["Allen Brain", "scIB pancreas"],
    methods=df['method'].unique().tolist(),
    colorscale="Plasma",
    reorder_by_mean=False,
    show_text=False,
    height=800,
    width=1400,
)
fig.update_yaxes(dtick=1)
fig.update_xaxes(tickangle=90)

for ann in fig.layout.annotations:
    ann.yshift = 4

fig.show()

# write_image does not create missing parent directories, so make sure SUP/ exists
# (on a fresh checkout it may not, and the export would fail).
os.makedirs('SUP', exist_ok=True)
fig.write_image('SUP/ablation.png', scale=3)