In [None]:
from efaar_benchmarking.data_loading import *
from efaar_benchmarking.efaar import *
from efaar_benchmarking.constants import *
from efaar_benchmarking.benchmarking import *
from efaar_benchmarking.plotting import *
import pickle

# --- Run configuration ---
# Embedding size: used as the number of PCA components and as the scVI latent dimension.
pc_count = 128
# Whether to run the perturbation signal (distance and consistency) benchmarks.
# Note that if you change this to True, the run will take a couple of hours to complete.
compute_univariate_metrics = False
# If True, save computationally expensive results (e.g. the univariate metrics and the
# aggregated maps) to disk so they do not have to be recomputed next time.
save_results = True
# Folder to save the processed data and results in.
res_folder = 'data'
# Folder where the gene expression data is stored. It is needed to filter out unexpressed
# genes prior to saving the aggregated map for further evaluation.
expression_data_folder = '../efaar_benchmarking/expression_data'

# Symmetric (lower, upper) recall threshold pairs: (0.01, 0.99), (0.02, 0.98), ..., (0.1, 0.9).
# Each pair is derived from an integer index and rounded to two decimals.
step = 0.01
recall_threshold_pairs = [
    (round(idx * step, 2), round(1 - idx * step, 2)) for idx in range(1, 11)
]
print(recall_threshold_pairs)

# Maps an embedding name (e.g. "CP-PCA128-CS") to its pre-aggregation embeddings.
all_embeddings_pre_agg = {}

dataset = "JUMP" # replace accordingly to run the other datasets

# Each branch loads one dataset, sets the column/label names used by the benchmarks below,
# and fills `all_embeddings_pre_agg` with the aligned embeddings to evaluate.
# `unused_keys` lists perturbation labels that are passed to the perturbation signal
# benchmarks as `keys_to_drop`.
if dataset == "GWPS":
    adata_norm = load_gwps("normalized")
    adata_raw = load_gwps("raw")
    metadata = adata_raw.obs # use adata_raw for both PCA and scVI maps since adata_raw has the same metadata as adata_norm

    pert_colname = GWPS_PERT_LABEL_COL
    ctrl_colname = GWPS_CONTROL_PERT_LABEL
    batch_colname = GWPS_BATCH_COL

    embeddings = embed_by_pca_anndata(adata_norm, pc_count)
    del adata_norm # no longer needed; release it before the next large computation
    all_embeddings_pre_agg[f"PCA{pc_count}-CS"] = centerscale_on_controls(embeddings, metadata, pert_col=pert_colname, control_key=ctrl_colname)

    # Pass the raw data object itself, not its `.obs` table: `.obs` is only the metadata
    # (kept in `metadata` above) and carries nothing to embed.
    embeddings = embed_by_scvi_anndata(adata_raw, n_latent=pc_count, n_hidden=pc_count*2)
    del adata_raw # `metadata` still references the `.obs` table
    all_embeddings_pre_agg[f"scVI{pc_count}-CS"] = centerscale_on_controls(embeddings, metadata, pert_col=pert_colname, control_key=ctrl_colname)
    # NOTE(review): this call uses the keyword `batch_col`, while the JUMP and PERISCOPE
    # branches use `batch_col_coral` -- confirm against the `tvn_on_controls` signature.
    all_embeddings_pre_agg[f"scVI{pc_count}-TVN"] = tvn_on_controls(embeddings, metadata, pert_col=pert_colname, control_key=ctrl_colname, batch_col=batch_colname)

    unused_keys = []

elif dataset == "JUMP":
    features, metadata = load_cpg16_crispr()
    features, metadata = filter_cell_profiler_features(features, metadata)

    pert_colname = JUMP_PERT_LABEL_COL
    ctrl_colname = JUMP_CONTROL_PERT_LABEL
    batch_colname = JUMP_BATCH_COL
    batch_colname_2 = JUMP_BATCH_COL_2

    print("Computing PCA embedding for", pc_count, "dimensions...")
    embeddings = embed_by_pca(features.values, metadata, variance_or_ncomp=pc_count, batch_col=batch_colname)
    all_embeddings_pre_agg[f"CP-PCA{pc_count}-CS"] = centerscale_on_controls(embeddings, metadata, pert_col=pert_colname, control_key=ctrl_colname) ## CS alignment
    all_embeddings_pre_agg[f"CP-PCA{pc_count}-TVN"] = tvn_on_controls(embeddings, metadata, pert_col=pert_colname, control_key=ctrl_colname, batch_col_coral=batch_colname_2)  ## TVN alignment

    unused_keys = ['negCtrl', 'no-guide']

elif dataset == "PERISCOPE":
    features, metadata = load_periscope()

    pert_colname = PERISCOPE_PERT_LABEL_COL
    ctrl_colname = PERISCOPE_CONTROL_PERT_LABEL
    batch_colname = PERISCOPE_BATCH_COL

    print("Computing PCA embedding for", pc_count, "dimensions...")
    embeddings = embed_by_pca(features.values, metadata, variance_or_ncomp=pc_count, batch_col=batch_colname)
    all_embeddings_pre_agg[f"CP-PCA{pc_count}-CS"] = centerscale_on_controls(embeddings, metadata, pert_col=pert_colname, control_key=ctrl_colname) ## CS alignment
    all_embeddings_pre_agg[f"CP-PCA{pc_count}-TVN"] = tvn_on_controls(embeddings, metadata, pert_col=pert_colname, control_key=ctrl_colname, batch_col_coral=batch_colname)  ## TVN alignment

    unused_keys = ['negCtrl']

else:
    # Fail here with a clear message instead of a NameError on `metadata` further down.
    raise ValueError(f"Unknown dataset {dataset!r}; expected 'GWPS', 'JUMP' or 'PERISCOPE'")

if save_results:
    import os

    # Create the output folder first so saving does not fail when it does not exist yet.
    os.makedirs(res_folder, exist_ok=True)
    # Save metadata to disk to get the statistics prior to aggregation; we will use it to
    # check the total and expressed gene counts.
    metadata.to_pickle(f'{res_folder}/{dataset}_pre_agg_metadata.pkl')

### Compute benchmarks (perturbation signal benchmarks are computed pre-aggregation, while known relationship benchmarks are computed post-aggregation)

# Row mask selecting the perturbations to keep in the saved aggregated maps (expressed
# genes plus the control label). It depends only on the dataset and the metadata, so it is
# built once here instead of re-reading the expression file for every embedding. Doing it
# up front also surfaces a missing expression file before the expensive benchmarks run.
if save_results:
    if dataset == "PERISCOPE":
        expr = pd.read_csv(f'{expression_data_folder}/HeLa_expression.csv') # note that we assume the HeLa expression data was used for PERISCOPE which is the default option in load_periscope()
        expr.columns = ['gene', 'tpm']
        expr.gene = expr.gene.apply(lambda x: x.split(' ')[0]) # keep only the text before the first space
        expr_genes = list(expr.loc[expr.tpm > 0, 'gene'])
        ind = metadata[pert_colname].isin(expr_genes + [ctrl_colname])
    elif dataset == "JUMP":
        expr = pd.read_csv(f'{expression_data_folder}/U2OS_expression.csv', index_col=0)
        expr = expr.groupby('gene').zfpkm.agg('median').reset_index() # one median zFPKM value per gene
        expr_genes = list(expr.loc[expr.zfpkm >= -3, 'gene'])
        ind = metadata[pert_colname].isin(expr_genes + [ctrl_colname])
    else:
        # No expression filter for the remaining datasets: keep every row.
        ind = [True] * len(metadata)

# Maps "<dataset> <embedding name>" to the known relationship benchmark results.
known_relationship_metrics = {}
for k, embeddings in all_embeddings_pre_agg.items():
    if compute_univariate_metrics:
        dist_res = pert_signal_distance_benchmark(embeddings, metadata, pert_col=pert_colname, batch_col=batch_colname, control_key=ctrl_colname, keys_to_drop=unused_keys, n_samples=1000)
        # Fraction of non-missing p-values that are <= 0.01.
        print(k, sum(dist_res.pval <= .01) / sum(~pd.isna(dist_res.pval)))

        cons_res = pert_signal_consistency_benchmark(embeddings, metadata, pert_col=pert_colname, batch_col=batch_colname, keys_to_drop=[ctrl_colname]+unused_keys, n_samples=1000)
        print(k, sum(cons_res.pval <= .01) / sum(~pd.isna(cons_res.pval)))

        if save_results:
            dist_res.to_csv(f'{res_folder}/{dataset}_distance_results_{k}.csv', index=False)
            cons_res.to_csv(f'{res_folder}/{dataset}_consistency_results_{k}.csv', index=False)

    # The known relationship benchmark runs on the map aggregated from all rows.
    map_data = aggregate(embeddings, metadata, pert_col=pert_colname, control_key=ctrl_colname)
    metrics = known_relationship_benchmark(map_data, recall_thr_pairs=recall_threshold_pairs, pert_col=pert_colname, n_null_samples = 10000, n_iterations = 1)
    print(k)
    # Mean of the 'recall_0.05_0.95' column per source, shown as a percentage.
    print((metrics.groupby('source')['recall_0.05_0.95'].mean() * 100).round(1))
    known_relationship_metrics[f"{dataset} {k}"] = metrics

    if save_results:
        # The saved map is re-aggregated from the rows selected by the mask built above.
        map_data = aggregate(embeddings[ind], metadata[ind], pert_col=pert_colname, control_key=ctrl_colname)
        with open(f'{res_folder}/{dataset}_aggr_{k}_map.pkl', 'wb') as outfile:
            pickle.dump(map_data, outfile)

plot_recall(known_relationship_metrics)
