## Imports and variable definitions

In [None]:
import os
import pickle

import numpy as np
import pandas as pd

from efaar_benchmarking.data_loading import *
from efaar_benchmarking.efaar import *
from efaar_benchmarking.constants import *
from efaar_benchmarking.benchmarking import *
from efaar_benchmarking.plotting import *

# Embedding dimensionality shared by the runs below: passed as the number of
# PCA components and as scVI's n_latent (scVI's n_hidden is set to twice this).
pc_count = 128

## GWPS run

In [None]:
# GWPS run: build scVI and PCA embeddings (raw, center-scaled on controls, and
# TVN on controls), print the mean recall per annotation source for each
# embedding, and cache the PCA-TVN map plus the metadata for downstream analysis.
dataset = "GWPS"
pert_colname = "gene"
gem_group_colname = "gem_group"  # used as the batch column for every step below
control_key = "non-targeting"
all_controls = ["non-targeting"]

# Create the output folder up front so a missing directory cannot make the
# cell fail only after all embeddings have been computed.
output_dir = "data"
os.makedirs(output_dir, exist_ok=True)

adata_raw = load_gwps("raw")
print("Perturbation dataset loaded")
metadata = adata_raw.obs

# Embedding name -> per-observation embedding (before aggregation).
all_embeddings_pre_agg = {}
print("Running for embedding size", pc_count)
all_embeddings_pre_agg[f"scVI{pc_count}"] = embed_by_scvi_anndata(
    adata_raw,
    batch_col=gem_group_colname,
    n_latent=pc_count,
    n_hidden=pc_count * 2,
)
print("embed_by_scvi_anndata completed")
all_embeddings_pre_agg[f"scVI{pc_count}-CS"] = centerscale_on_controls(
    all_embeddings_pre_agg[f"scVI{pc_count}"],
    metadata,
    pert_col=pert_colname,
    control_key=control_key,
    batch_col=gem_group_colname,
)
print("centerscale completed")
all_embeddings_pre_agg[f"scVI{pc_count}-TVN"] = tvn_on_controls(
    all_embeddings_pre_agg[f"scVI{pc_count}"],
    metadata,
    pert_col=pert_colname,
    control_key=control_key,
    batch_col=gem_group_colname,
)
print("tvn completed")
all_embeddings_pre_agg[f"PCA{pc_count}"] = embed_by_pca_anndata(adata_raw, gem_group_colname, pc_count)
print("embed_by_pca_anndata completed")
all_embeddings_pre_agg[f"PCA{pc_count}-CS"] = centerscale_on_controls(
    all_embeddings_pre_agg[f"PCA{pc_count}"],
    metadata,
    pert_col=pert_colname,
    control_key=control_key,
    batch_col=gem_group_colname,
)
print("centerscale completed")
all_embeddings_pre_agg[f"PCA{pc_count}-TVN"] = tvn_on_controls(
    all_embeddings_pre_agg[f"PCA{pc_count}"],
    metadata,
    pert_col=pert_colname,
    control_key=control_key,
    batch_col=gem_group_colname,
)
print("tvn completed")

# Keep every aggregated map by embedding name so the one written to disk is
# chosen explicitly instead of being whatever the loop happened to visit last.
maps_by_embedding = {}
array_list = []
for k, emb in all_embeddings_pre_agg.items():
    print("Aggregating...")
    map_data = aggregate(emb, metadata, pert_col=pert_colname, keys_to_remove=all_controls)
    maps_by_embedding[k] = map_data
    print("Computing recall...")
    metrics = known_relationship_benchmark(map_data, recall_thr_pairs=[(.05, .95)], pert_col=pert_colname)
    # One row per embedding (in insertion order of all_embeddings_pre_agg):
    # mean recall per "source" group, as a percentage rounded to one decimal.
    array_list.append((metrics.groupby("source")["recall_0.05_0.95"].mean() * 100).round(1).values)
res = np.vstack(array_list)
formatted_result = np.array2string(res, separator=", ")
print(formatted_result)

# Rebind map_data to the PCA-TVN map by name; this is the map that is cached.
saved_map_key = f"PCA{pc_count}-TVN"
map_data = maps_by_embedding[saved_map_key]

with open(f'{output_dir}/{dataset}_map_cache.pkl', 'wb') as f:
    pickle.dump(map_data, f)  # storing the PCA-TVN map data for downstream analysis

with open(f'{output_dir}/{dataset}_metadata.pkl', 'wb') as f:
    pickle.dump(metadata, f)  # storing the metadata for downstream analysis

## cpg0016 run

In [None]:
# cpg0016 run: build PCA embeddings (raw, center-scaled on controls, and TVN on
# controls), restrict to expressed genes plus the control, print the mean
# recall per annotation source for each embedding, and cache the PCA-TVN map
# plus the metadata for downstream analysis.
dataset = "cpg0016"
pert_colname = "Metadata_Symbol"
plate_colname = "Metadata_Plate"
run_colname = "Metadata_Batch"
control_key = "non-targeting"
all_controls = ["non-targeting", "no-guide"]

# Create the output folder up front so a missing directory cannot make the
# cell fail only after all embeddings have been computed.
output_dir = "data"
os.makedirs(output_dir, exist_ok=True)

features, metadata = load_cpg16_crispr()
print("Perturbation dataset loaded")
features, metadata = filter_cell_profiler_features(features, metadata)

# Embedding name -> per-observation embedding (before aggregation).
# NOTE(review): PCA is batched by plate while centerscale/TVN are batched by
# Metadata_Batch -- confirm this difference is intentional.
all_embeddings_pre_agg = {}
print("Computing PCA embedding for", pc_count, "dimensions...")
all_embeddings_pre_agg[f"PCA{pc_count}"] = embed_by_pca(
    features.values,
    metadata,
    variance_or_ncomp=pc_count,
    batch_col=plate_colname,
)
print("Computing centerscale...")
all_embeddings_pre_agg[f"PCA{pc_count}-CS"] = centerscale_on_controls(
    all_embeddings_pre_agg[f"PCA{pc_count}"],
    metadata,
    pert_col=pert_colname,
    control_key=control_key,
    batch_col=run_colname,
)
print("Computing TVN...")
all_embeddings_pre_agg[f"PCA{pc_count}-TVN"] = tvn_on_controls(
    all_embeddings_pre_agg[f"PCA{pc_count}"],
    metadata,
    pert_col=pert_colname,
    control_key=control_key,
    batch_col=run_colname,
)

# Split genes by their median zFPKM in the U2OS expression table: values below
# -3 are treated as unexpressed, everything else as expressed.
expression_data_folder = "../efaar_benchmarking/expression_data"
expr = pd.read_csv(f"{expression_data_folder}/U2OS_expression.csv", index_col=0).groupby("gene").zfpkm.agg("median").reset_index()
unexpr_genes = list(expr.loc[expr.zfpkm < -3, "gene"])  # not used in this cell; kept in case later cells read it
expr_genes = list(expr.loc[expr.zfpkm >= -3, "gene"])
# Row mask: perturbations of expressed genes plus the control rows.
ind = metadata[pert_colname].isin(expr_genes + [control_key])

# Keep every aggregated map by embedding name so the one written to disk is
# chosen explicitly instead of being whatever the loop happened to visit last.
maps_by_embedding = {}
array_list = []
for k, emb in all_embeddings_pre_agg.items():
    print("Aggregating...")
    map_data = aggregate(emb[ind], metadata[ind], pert_col=pert_colname, keys_to_remove=all_controls)
    maps_by_embedding[k] = map_data
    print("Computing recall...")
    metrics = known_relationship_benchmark(map_data, recall_thr_pairs=[(.05, .95)], pert_col=pert_colname)
    # One row per embedding (in insertion order of all_embeddings_pre_agg):
    # mean recall per "source" group, as a percentage rounded to one decimal.
    array_list.append((metrics.groupby("source")["recall_0.05_0.95"].mean() * 100).round(1).values)
res = np.vstack(array_list)
formatted_result = np.array2string(res, separator=", ")
print(formatted_result)

# Rebind map_data to the PCA-TVN map by name; this is the map that is cached.
saved_map_key = f"PCA{pc_count}-TVN"
map_data = maps_by_embedding[saved_map_key]

with open(f'{output_dir}/{dataset}_map_cache.pkl', 'wb') as f:
    pickle.dump(map_data, f)  # storing the PCA-TVN map data for downstream analysis

with open(f'{output_dir}/{dataset}_metadata.pkl', 'wb') as f:
    pickle.dump(metadata, f)  # storing the metadata for downstream analysis

## cpg0021 run

In [None]:
# cpg0021 run: build PCA embeddings (raw, center-scaled on controls, and TVN on
# controls), restrict to expressed genes plus the control, print the mean
# recall per annotation source for each embedding, and cache the PCA-TVN map
# plus the metadata for downstream analysis.
dataset = "cpg0021"
pert_colname = "Metadata_Foci_Barcode_MatchedTo_GeneCode"
plate_colname = "Metadata_Plate"  # used as the batch column for every step below
control_key = "nontargeting"
all_controls = ["nontargeting", "negCtrl"]

# Create the output folder up front so a missing directory cannot make the
# cell fail only after all embeddings have been computed.
output_dir = "data"
os.makedirs(output_dir, exist_ok=True)

features, metadata = load_periscope()
print("Perturbation dataset loaded")

# Embedding name -> per-observation embedding (before aggregation).
all_embeddings_pre_agg = {}
print("Computing PCA embedding for", pc_count, "dimensions...")
all_embeddings_pre_agg[f"PCA{pc_count}"] = embed_by_pca(
    features.values,
    metadata,
    variance_or_ncomp=pc_count,
    batch_col=plate_colname,
)
print("Computing centerscale...")
all_embeddings_pre_agg[f"PCA{pc_count}-CS"] = centerscale_on_controls(
    all_embeddings_pre_agg[f"PCA{pc_count}"],
    metadata,
    pert_col=pert_colname,
    control_key=control_key,
    batch_col=plate_colname,
)
print("Computing TVN...")
all_embeddings_pre_agg[f"PCA{pc_count}-TVN"] = tvn_on_controls(
    all_embeddings_pre_agg[f"PCA{pc_count}"],
    metadata,
    pert_col=pert_colname,
    control_key=control_key,
    batch_col=plate_colname,
)

expression_data_folder = "../efaar_benchmarking/expression_data"
# note that we assume the HeLa expression data was used for PERISCOPE which is the default option in load_periscope()
expr = pd.read_csv(f"{expression_data_folder}/HeLa_expression.csv")
expr.columns = ["gene", "tpm"]
# Keep only the text before the first space in each gene entry.
expr.gene = expr.gene.apply(lambda x: x.split(" ")[0])
# Genes with zero TPM are treated as unexpressed, positive TPM as expressed.
unexpr_genes = list(expr.loc[expr.tpm == 0, "gene"])  # not used in this cell; kept in case later cells read it
expr_genes = list(expr.loc[expr.tpm > 0, "gene"])
# Row mask: perturbations of expressed genes plus the control rows.
ind = metadata[pert_colname].isin(expr_genes + [control_key])

# Keep every aggregated map by embedding name so the one written to disk is
# chosen explicitly instead of being whatever the loop happened to visit last.
maps_by_embedding = {}
array_list = []
for k, emb in all_embeddings_pre_agg.items():
    print("Aggregating...")
    map_data = aggregate(emb[ind], metadata[ind], pert_col=pert_colname, keys_to_remove=all_controls)
    maps_by_embedding[k] = map_data
    print("Computing recall...")
    metrics = known_relationship_benchmark(map_data, recall_thr_pairs=[(.05, .95)], pert_col=pert_colname)
    # One row per embedding (in insertion order of all_embeddings_pre_agg):
    # mean recall per "source" group, as a percentage rounded to one decimal.
    array_list.append((metrics.groupby("source")["recall_0.05_0.95"].mean() * 100).round(1).values)
res = np.vstack(array_list)
formatted_result = np.array2string(res, separator=", ")
print(formatted_result)

# Rebind map_data to the PCA-TVN map by name; this is the map that is cached.
saved_map_key = f"PCA{pc_count}-TVN"
map_data = maps_by_embedding[saved_map_key]

with open(f'{output_dir}/{dataset}_map_cache.pkl', 'wb') as f:
    pickle.dump(map_data, f)  # storing the PCA-TVN map data for downstream analysis

with open(f'{output_dir}/{dataset}_metadata.pkl', 'wb') as f:
    pickle.dump(metadata, f)  # storing the metadata for downstream analysis