In [None]:
import sys
sys.path.append('../src/')  # Replace with your actual path
from train import str2bool, set_seed, parse_config, get_predictions, get_attention, get_embedding

import json

import datetime
import os
import pickle as pkl
import random

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
from presage_datamodule import ReploglePRESAGEDataModule
from evaluator import Evaluator

In [None]:
dataset = "replogle_k562_essential_unfiltered"
seed = "seed_0"

default_config_file = "../configs/defaults_config.json"
singles_config_file = "../configs/singles_config.json"
ds_config_file = f"../configs/{dataset}_config.json"

# Load the default, single-perturbation and dataset-specific configs.
# JSON is UTF-8 by specification, so decode explicitly instead of relying on
# the platform's locale encoding.
with open(default_config_file, "r", encoding="utf-8") as f:
    config = json.load(f)
with open(singles_config_file, "r", encoding="utf-8") as f:
    singles_config = json.load(f)
with open(ds_config_file, "r", encoding="utf-8") as f:
    ds_config = json.load(f)

# Dataset-specific values take precedence over the generic
# single-perturbation settings.
singles_config.update(ds_config)

# Convert flat "section_param" keys into the dotted "section.param" form
# (only the FIRST underscore becomes a dot, e.g. "training_eval_test" ->
# "training.eval_test"). Unset (None) values and the meta keys "config" /
# "data_config" are dropped so they cannot override the defaults.
new_config = {
    key.replace("_", ".", 1): value
    for key, value in singles_config.items()
    if value is not None and key not in {"config", "data_config"}
}
singles_config = new_config
config.update(singles_config)

In [None]:
# Notebook-specific overrides, keyed with dotted "section.param" names.
# "training.eval_test": True makes the evaluation run all metrics.
modify_config = {
    "training.eval_test": True,
    "model.pathway_files": "../sample_files/prior_files/sample.knowledge_experimental.txt",
    "data.data_dir": "../data/",
}

config.update(modify_config)

In [None]:
# Convert the flat dotted-key config into the sectioned form indexed below as
# config["training"], config["data"], ... (exact behaviour lives in
# train.parse_config — confirm there).
config = parse_config(config)

# Seed before anything stochastic runs; None is passed when no seed is set.
set_seed(config["training"].pop("seed", None))

# NOTE(review): these `.pop` calls remove the keys from config["training"].
# Re-running this cell on its own therefore silently falls back to the
# defaults (seed=None, offline=False, eval_test=True). It also re-applies
# parse_config to an already-parsed config. Re-run from the config cells above.
# `offline` is not used anywhere else in this notebook.
offline = config["training"].pop("offline", False)
do_test_eval = config["training"].pop("eval_test", True)

In [None]:
config["data"]["dataset"] = dataset

# JSON file describing the random train/val/test split for this seed.
# Kept in its own variable so the global `seed` ("seed_0") is not clobbered.
# Previously, re-running this cell built a path from a path.
split_file = f"../splits/{dataset}_random_splits/{seed}.json"

# The split is not passed through `from_config` (presumably it does not accept
# a "seed" entry — confirm). It is handed to the datamodule separately when
# supported, and stored back into the config afterwards.
config["data"].pop("seed", None)
datamodule = ReploglePRESAGEDataModule.from_config(config["data"])
datamodule.do_test_eval = do_test_eval

if hasattr(datamodule, "set_seed"):
    datamodule.set_seed(split_file)
config["data"]["seed"] = split_file

In [None]:
# Prepare the data once, then set up both the training ("fit") and test stages.
datamodule.prepare_data()
for stage in ("fit", "test"):
    datamodule.setup(stage)

# NOTE(review): resets a private datamodule flag after setup. Presumably this
# stops a later setup call from being short-circuited — confirm in
# ReploglePRESAGEDataModule.
datamodule._data_setup = False

In [None]:
# Read (not pop) the predictions path so this cell can be re-run safely.
# With `.pop`, a second run got None and `pd.read_csv(None)` failed obscurely.
predictions_file = config["training"].get("predictions_file")
if predictions_file is None:
    raise ValueError(
        "config['training']['predictions_file'] is not set; "
        "point it at the CSV of mean predictions to evaluate"
    )

# First CSV column is the row index (perturbation labels, matched against
# the target means further down).
mean_preds = pd.read_csv(predictions_file, index_col=0)

In [None]:
# Preview the loaded predictions: rows are perturbations (the CSV's first
# column), columns are genes (later aligned to datamodule.var_names).
mean_preds

In [None]:
# Control (unperturbed) cells of the *training* split define the baseline.
train_adata = datamodule.train_dataset.adata
is_control = train_adata.obs.loc[:, datamodule.perturb_field] == datamodule.control_key
ctrl_cells = train_adata[is_control]
train_keys = datamodule.splits["train"]

# Load the preprocessed data and shift every cell by the mean training-control
# expression, so values are expressed relative to control.
adata = datamodule.load_preprocessed()
ctrl_mean = np.mean(ctrl_cells.X, axis=0)
adata.X = adata.X - ctrl_mean

evaluator = Evaluator(
    datamodule.var_names,
    datamodule.degs,
    ctrl_cells,
    train_keys,
    adata,
    geneset_file=datamodule.gs_file,
    perturbation_cluster_file=datamodule.pclust_file,
    ncells_per_perturbation_file=datamodule.ncells_per_perturbation_file,
    dataset=datamodule.dataset,
    seed=datamodule.seed,
)

In [None]:
# One (perturbation label, obs sub-frame) pair per value of the perturbation field.
grouped = list(adata.obs.groupby(datamodule.perturb_field))
tgt_inds = [pert for pert, _ in grouped]

# Mean (control-centred) expression of each perturbation's cells, one row per label.
mean_tgts = pd.DataFrame(
    [adata[obs_frame.index].X.mean(axis=0) for _, obs_frame in grouped],
    index=tgt_inds,
    columns=adata.var_names,
)

In [None]:
# Keep only perturbations that have predictions, in `mean_preds` row order
# (pandas raises KeyError if a predicted perturbation has no target row).
mean_tgts = mean_tgts.loc[mean_preds.index]

In [None]:
# Genes the evaluator expects (datamodule.var_names) but the prediction file
# lacks are appended as all-zero columns. Presumably zero means "no change
# from control", since the targets were centred on the control mean — confirm
# the predictions use the same convention.
is_missing = ~np.isin(datamodule.var_names, mean_preds.columns)
missing_genes = datamodule.var_names[is_missing]
zero_fill = pd.DataFrame(0, index=mean_preds.index, columns=missing_genes)
mean_preds = pd.concat([mean_preds, zero_fill], axis=1)

In [None]:
# Put both frames in the evaluator's gene order so their `.values` line up
# column-for-column.
mean_preds, mean_tgts = (
    frame.loc[:, datamodule.var_names] for frame in (mean_preds, mean_tgts)
)

In [None]:
# Pass the perturbation labels, then the target and predicted matrices. Their
# rows follow those labels and their columns follow datamodule.var_names.
# NOTE(review): the last argument is the *string* "True", not the bool True.
# Confirm what Evaluator.__call__ expects (the notebook imports `str2bool`,
# hinting at string flags). If the flag is tested for truthiness, any non-empty
# string — even "False" — counts as true.
perturbation_labels = mean_tgts.index.values.ravel()
temp = evaluator(
    perturbation_labels,
    mean_tgts.values,
    mean_preds.values,
    "True",
)

In [None]:
# Evaluation metrics, tagged with the split name (the file stem, e.g. "seed_0").
# NOTE: `eval_df` is the same object as `evaluator.eval_dfs`, so the new
# "split" entry is added to the evaluator's copy as well.
eval_df = evaluator.eval_dfs
split_name = seed.rsplit("/", 1)[-1].partition(".json")[0]
eval_df["split"] = split_name

In [None]:
# Expose the evaluator's detailed results under notebook-level names:
#   perturbations_with_effect       - perturbations with a statistical effect from the mean
#   single_perturbation_predictions - evaluation metrics for individual perturbations
#   virtual_screen_enriched_perts   - perturbations with a significant effect on gene sets
#                                     at different MAD levels
(
    perturbations_with_effect,
    single_perturbation_predictions,
    virtual_screen_enriched_perts,
) = (
    evaluator.perturbations_with_effect,
    evaluator.all_single_evals,
    evaluator.ground_truth_virtual_screen_perts,
)

In [None]:
# Display the evaluation metrics collected from the evaluator (including the
# "split" entry set earlier).
eval_df