Copyright 2023 Recursion

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

In [10]:
from efaar_benchmarking.data_loading import load_cpg16_crispr
from efaar_benchmarking.efaar import *
from efaar_benchmarking.constants import *
from efaar_benchmarking.benchmarking import *
from efaar_benchmarking.plotting import plot_recall

# Symmetric (lower, upper) threshold pairs passed to the recall benchmark:
# (0.01, 0.99), (0.02, 0.98), ..., (0.1, 0.9).
# The pairs are derived from integer percentages. Repeatedly adding 0.01 to a
# float accumulates rounding error, and doing so previously needed fudged loop
# bounds (.105 / .895) to stop at the right pair.
recall_threshold_pairs = [
    (round(pct / 100, 2), round(1 - pct / 100, 2)) for pct in range(1, 11)
]

print(recall_threshold_pairs)

[(0.01, 0.99), (0.02, 0.98), (0.03, 0.97), (0.04, 0.96), (0.05, 0.95), (0.06, 0.94), (0.07, 0.93), (0.08, 0.92), (0.09, 0.91), (0.1, 0.9)]


In [11]:
import pandas as pd

# Number of principal components kept by the PCA embedding step.
pc_count = 128
# The univariate benchmarks (perturbation-signal distance / consistency) are slow:
# if you change this to True, the run will take a couple hours to complete.
compute_univariate_metrics = False
# NOTE(review): `right_sided` was passed to plot_recall below but never defined
# in this notebook (NameError). The thresholds in recall_threshold_pairs are
# two-sided (lower, upper) pairs, so False is assumed here -- confirm against
# plot_recall's signature.
right_sided = False


def _fraction_significant(pvals, alpha=0.01):
    """Return the fraction of non-NaN p-values in ``pvals`` that are <= ``alpha``.

    Returns NaN when there are no non-NaN p-values, rather than dividing by zero.
    """
    n_tested = pd.notna(pvals).sum()
    if n_tested == 0:
        return float("nan")
    return (pvals <= alpha).sum() / n_tested


features, metadata = load_cpg16_crispr()
features, metadata = filter_cell_profiler_features(features, metadata)

### Generate embeddings
print("PCA embedding for", pc_count, "dimensions...")
pca_embeddings = embed_by_pca(features.values, metadata, variance_or_ncomp=pc_count, batch_col=JUMP_BATCH_COL)
# Two alignments of the same PCA embedding, keyed by a descriptive label.
all_embeddings_pre_agg = {
    # CS alignment
    f"CP-PCA{pc_count}-CS": centerscale_on_controls(
        pca_embeddings, metadata, pert_col=JUMP_PERT_LABEL_COL, control_key=JUMP_CONTROL_PERT_LABEL
    ),
    # TVN alignment
    f"CP-PCA{pc_count}-TVN": tvn_on_controls(
        pca_embeddings,
        metadata,
        pert_col=JUMP_PERT_LABEL_COL,
        control_key=JUMP_CONTROL_PERT_LABEL,
        batch_col_coral=JUMP_BATCH_COL_2,
    ),
}

### Aggregate and compute benchmarks -- consider saving computationally expensive results like all_embeddings_pre_agg, dist_res, cons_res
all_metrics = {}
for emb_name, aligned_embeddings in all_embeddings_pre_agg.items():
    if compute_univariate_metrics:
        dist_res = pert_signal_distance_benchmark(
            aligned_embeddings,
            metadata,
            pert_col=JUMP_PERT_LABEL_COL,
            batch_col=JUMP_BATCH_COL,
            control_key=JUMP_CONTROL_PERT_LABEL,
            keys_to_drop=['negCtrl', 'no-guide'],
            n_samples=1000,
        )
        print(emb_name, _fraction_significant(dist_res.pval))
        cons_res = pert_signal_consistency_benchmark(
            aligned_embeddings,
            metadata,
            pert_col=JUMP_PERT_LABEL_COL,
            batch_col=JUMP_BATCH_COL,
            keys_to_drop=[JUMP_CONTROL_PERT_LABEL, 'negCtrl', 'no-guide'],
            n_samples=1000,
        )
        print(emb_name, _fraction_significant(cons_res.pval))
    map_data = aggregate(aligned_embeddings, metadata, pert_col=JUMP_PERT_LABEL_COL, control_key=JUMP_CONTROL_PERT_LABEL)
    metrics = known_relationship_benchmark(
        map_data,
        recall_thr_pairs=recall_threshold_pairs,
        pert_col=JUMP_PERT_LABEL_COL,
        n_null_samples=10000,
        n_iterations=1,
    )
    print(emb_name)
    # Per-source mean of the (0.05, 0.95) recall column. (0.05, 0.95) is one of
    # recall_threshold_pairs; the column naming is assumed from the literal.
    print(metrics.groupby('source')['recall_0.05_0.95'].mean())
    all_metrics[f"JUMP {emb_name}"] = metrics

plot_recall(all_metrics, right_sided=right_sided)