Copyright 2023 Recursion

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

In [None]:
from efaar_benchmarking.data_loading import load_cpg16_crispr
from efaar_benchmarking.efaar import *
from efaar_benchmarking.constants import *
from efaar_benchmarking.benchmarking import univariate_consistency_benchmark, multivariate_benchmark
from efaar_benchmarking.plotting import plot_recall

# (low, high) recall threshold pairs: (0.01, 0.99), (0.02, 0.98), ..., (0.1, 0.9).
# Each pair is computed from an integer index instead of repeatedly adding
# `step` to a running float, so the number of pairs cannot be affected by
# accumulated floating-point error and no fudge-factor loop bounds are needed.
step = 0.01
n_pairs = 10
recall_threshold_pairs = [
    (round(i * step, 2), round(1 - i * step, 2)) for i in range(1, n_pairs + 1)
]

print(recall_threshold_pairs)

In [None]:
# Principal-component counts swept by the PCA embeddings.
pc_counts = [128, 256, 512, 1024, 2048]

# Loading may take some time if the files are not cached yet, depending on the
# speed of your internet connection.
loaded_features, loaded_metadata = load_cpg16_crispr()
features, metadata = filter_cell_profiler_features(loaded_features, loaded_metadata)

# Embeddings before aggregation, keyed by a short method label.
# "CP" is the filtered feature matrix as-is; "CP-CS" is the same matrix passed
# through centerscale_on_controls.
all_embeddings_pre_agg = {
    "CP": features.values,
    "CP-CS": centerscale_on_controls(
        features.values,
        metadata,
        pert_col=JUMP_PERT_LABEL_COL,
        control_key=JUMP_CONTROL_PERT_LABEL,
    ),
}

### PCA embeddings with different PC counts and alignment
# For every PC count, run embed_by_pca once and store two aligned variants of
# the result: one via centerscale_on_controls ("-CS") and one via
# tvn_on_controls ("-TVN").
for n_components in pc_counts:
    print(n_components)
    pca_embeddings = embed_by_pca(
        features.values,
        metadata,
        variance_or_ncomp=n_components,
        plate_col=JUMP_PLATE_COL,
    )
    aligners = (
        (f"CP-PCA{n_components}-CS", centerscale_on_controls),
        (f"CP-PCA{n_components}-TVN", tvn_on_controls),
    )
    for embedding_key, align_fn in aligners:
        all_embeddings_pre_agg[embedding_key] = align_fn(
            pca_embeddings,
            metadata,
            pert_col=JUMP_PERT_LABEL_COL,
            control_key=JUMP_CONTROL_PERT_LABEL,
        )

### Aggregate and compute metrics
# Only right_sided=False is evaluated here; the list can be extended to run
# additional settings.
for right_sided in [False]:
    # Benchmark results for this setting, keyed by "JUMP <embedding name>".
    all_metrics = {}
    for embedding_name, pre_agg_embeddings in all_embeddings_pre_agg.items():
        # Optional univariate benchmark, currently disabled:
        # consistency_pvals = univariate_consistency_benchmark(embeddings, metadata, pert_col=JUMP_PERT_LABEL_COL, keys_to_drop=[JUMP_CONTROL_PERT_LABEL, 'no-guide'])
        aggregated = aggregate(
            pre_agg_embeddings,
            metadata,
            pert_col=JUMP_PERT_LABEL_COL,
            control_key=JUMP_CONTROL_PERT_LABEL,
        )
        benchmark_results = multivariate_benchmark(
            aggregated,
            recall_thr_pairs=recall_threshold_pairs,
            pert_col=JUMP_PERT_LABEL_COL,
            n_null_samples=10000,
            n_iterations=1,
            right_sided=right_sided,
        )
        print(embedding_name)
        # Mean of the (0.05, 0.95) recall column per benchmark source.
        recall_by_source = benchmark_results.groupby('source')['recall_0.05_0.95']
        print(recall_by_source.mean())
        all_metrics[f"JUMP {embedding_name}"] = benchmark_results

    if right_sided:
        plot_title = "Right tail only"
    else:
        plot_title = "Both tails"
    plot_recall(all_metrics, right_sided=right_sided, title=plot_title)