In [None]:
from efaar_benchmarking.data_loading import load_replogle
from efaar_benchmarking.efaar import *
from efaar_benchmarking.constants import *
from efaar_benchmarking.benchmarking import univariate_consistency_benchmark, multivariate_benchmark
from efaar_benchmarking.plotting import plot_recall

# Symmetric (lower, upper) threshold pairs handed to multivariate_benchmark as
# `recall_thr_pairs`: (0.01, 0.99), (0.02, 0.98), ..., (0.1, 0.9).
# Each pair is derived from an integer index i (hundredths) rather than by
# repeatedly adding/subtracting 0.01, so no floating-point drift accumulates.
n_threshold_pairs = 10
recall_threshold_pairs = [
    (round(i / 100, 2), round((100 - i) / 100, 2))
    for i in range(1, n_threshold_pairs + 1)
]

print(recall_threshold_pairs)

In [None]:
# Build per-cell embeddings for every (method, dimensionality, alignment) combination,
# then aggregate each one and benchmark it with multivariate_benchmark / plot_recall.
pc_counts = [128, 256, 512, 1024]
data_dir = "../../project"  # root directory passed to load_replogle for both dataset variants
all_embeddings_pre_agg = {}
# Per-cell metadata (the AnnData .obs) that belongs to each entry of all_embeddings_pre_agg.
# PCA embeddings come from the normalized AnnData and scVI embeddings from the raw one, so
# each embedding must be aggregated with the .obs of the object it was computed from --
# not with whichever `metadata` happened to be loaded last.
all_metadata_pre_agg = {}

adata_norm = load_replogle("genome_wide", "normalized", data_dir)
metadata = adata_norm.obs
### PCA embeddings with different PC counts and alignment
### ("-CS" keys use centerscale_on_controls, "-TVN" keys use tvn_on_controls)
for pcc in pc_counts:
    print(pcc)
    embeddings = embed_by_pca_anndata(adata_norm, pcc)
    for k, fn in {f"PCA{pcc}-CS": centerscale_on_controls, f"PCA{pcc}-TVN": tvn_on_controls}.items():
        all_embeddings_pre_agg[k] = fn(embeddings, metadata, pert_col=REPLOGLE_PERT_LABEL_COL, control_key=REPLOGLE_CONTROL_PERT_LABEL)
        all_metadata_pre_agg[k] = metadata
# Release the normalized AnnData before loading the raw one; its .obs stays alive via
# all_metadata_pre_agg, which is all the aggregation step below needs.
del adata_norm

adata_raw = load_replogle("genome_wide", "raw", data_dir)
metadata = adata_raw.obs
### scVI embeddings with different latent and hidden node counts and alignment
### (n_hidden is set to twice n_latent)
for pcc in pc_counts:
    print(pcc)
    embeddings = embed_by_scvi_anndata(adata_raw, n_latent=pcc, n_hidden=pcc * 2)
    for k, fn in {f"scVI{pcc}-CS": centerscale_on_controls, f"scVI{pcc}-TVN": tvn_on_controls}.items():
        all_embeddings_pre_agg[k] = fn(embeddings, metadata, pert_col=REPLOGLE_PERT_LABEL_COL, control_key=REPLOGLE_CONTROL_PERT_LABEL)
        all_metadata_pre_agg[k] = metadata
del adata_raw

### Aggregate and compute metrics
# Only the two-tailed setting is run; add True to the list to also evaluate the right tail.
for right_sided in [False]:
    all_metrics = {}
    for k, embeddings in all_embeddings_pre_agg.items():
        # Use the metadata of the AnnData this embedding was derived from (see above).
        emb_metadata = all_metadata_pre_agg[k]
        # consistency_pvals = univariate_consistency_benchmark(embeddings, emb_metadata, pert_col=REPLOGLE_PERT_LABEL_COL, keys_to_drop=[REPLOGLE_CONTROL_PERT_LABEL])
        map_data = aggregate(embeddings, emb_metadata, pert_col=REPLOGLE_PERT_LABEL_COL, control_key=REPLOGLE_CONTROL_PERT_LABEL)
        metrics = multivariate_benchmark(
            map_data,
            recall_thr_pairs=recall_threshold_pairs,
            pert_col=REPLOGLE_PERT_LABEL_COL,
            n_null_samples=10000,
            n_iterations=1,
            right_sided=right_sided,
        )
        print(k)
        # Quick summary: mean recall at the (0.05, 0.95) threshold pair per benchmark source.
        print(metrics.groupby('source')['recall_0.05_0.95'].mean())
        all_metrics[f"GWPS {k}"] = metrics
    plot_recall(all_metrics, right_sided=right_sided, title="Right tail only" if right_sided else "Both tails")