In [61]:
import pandas as pd
import numpy as np
import scanpy as sc
import glob
from scipy.sparse import issparse
import os

### 1. Read in single-cell data and predictions from 5-fold cross-validation

In [62]:
# Run configuration ------------------------------------------------------------
# random_seed and num_epochs are used below to locate prediction folders and to
# name the output folder. filter_zeros switches on clamping of negative predicted
# values to 0 and selects the output file name.
random_seed=12
num_epochs=100
filter_zeros = False

# Specify the model type and model name ----------------------------------------
# Supported values: 'Gears', 'morph', 'Control' (baseline) and 'Truth'.
model_type='morph'
if model_type == 'morph':
    # These settings are path components of the morph prediction and output folders.
    representation_type='DepMap_GeneEffect'
    model_name = 'best_model'
    recon_loss = 'mmd'
    null_label = 'zeros'
    mxAlpha = 2.0
    tolerance_epochs = 20
elif model_type == 'Gears':
    model_name='model.pt'
elif model_type in ('Control', 'Truth'):
    # No saved model file is associated with these two modes.
    model_name=None
else:
    # Fail fast: without this, a typo only surfaces later as a NameError, or
    # silently reuses variables left in the kernel by an earlier run.
    raise ValueError(
        f"Unknown model_type {model_type!r}; expected 'Gears', 'morph', 'Control' or 'Truth'"
    )

# Specify the number of genes to use -------------------------------------------
# Number of highly variable genes kept in the saved table
# (the original note listed 5044 and 1000 as other values).
num_gene = 2500

In [None]:
# Dataset selection: the '_hvg' suffix of dataset_name encodes whether the
# HVG-filtered version of the dataset should be loaded.
dataset_name = 'norman_k562_hvg'
dataset = dataset_name.replace('_hvg', '')
use_hvg = 'True' if 'hvg' in dataset_name else 'False'

# Lookup table with (at least) the columns 'dataset', 'use_hvg' and 'file_path'.
scdata_file = pd.read_csv('/home/che/perturb-project/git/gene_ptb_prediction/scdata_file_path.csv')

# Combine both conditions into ONE boolean mask. Chaining `df[mask1][mask2]`
# applies a full-length mask to an already-filtered frame, which pandas only
# tolerates by reindexing (with a UserWarning).
row_mask = (scdata_file['dataset'] == dataset) & (scdata_file['use_hvg'] == (use_hvg == 'True'))
matching_paths = scdata_file.loc[row_mask, 'file_path']
if matching_paths.empty:
    # Same exception type as the former `.values[0]` on an empty selection,
    # but with a message that says what was being looked up.
    raise IndexError(
        f'No entry in scdata_file_path.csv for dataset={dataset!r}, use_hvg={use_hvg}'
    )
adata_path = matching_paths.iloc[0]
adata = sc.read(adata_path)
print('Loaded adata_refer from ', adata_path)

In [None]:
# Independent copy of the cells whose obs['gene'] label is 'non-targeting'.
is_control = adata.obs['gene'] == 'non-targeting'
adata_ctrl = adata[is_control].copy()
adata_ctrl

In [None]:
# Flag the `num_gene` most variable genes; the flag is read back from
# adata.var['highly_variable'] below.
sc.pp.highly_variable_genes(adata, n_top_genes=num_gene)
hvg_mask = adata.var['highly_variable'] == True
# Names of the selected genes ...
top_hvg = adata.var.loc[hvg_mask].index
# ... and their integer positions within adata.var.
top_hvg_idx = list(map(adata.var.index.get_loc, top_hvg))

### Step 1. Read in predictions for the selected model type

In [66]:
# Root folder holding the saved predictions of each model type.
# 'Truth' has no entry, so model_fold_path is left untouched in that mode.
prediction_roots = {
    'Gears': f'/home/che/GEARS/{dataset_name}',
    'morph': f'/home/che/perturb-project/predict_model/result/rna/{dataset_name}',
    'Control': f'/home/che/perturb-project/git/gene_ptb_prediction/gene_interaction_prediction/baseline_model/predict_control/{dataset_name}',
}
if model_type in prediction_roots:
    model_fold_path = prediction_roots[model_type]

In [None]:
if model_type in ('Gears', 'morph'):
    # Collect one prediction object per fold; folder names number the folds from 1.
    pred_whole = []
    for fold in range(5):
        fold_num = fold + 1

        # Step 1: work out where this fold's prediction pickle lives.
        if model_type == 'Gears':
            model_path = f'{model_fold_path}/predict_gi_fold_{fold_num}/random_seed_{random_seed}'
            pred_path = os.path.join(model_path, "gears_pred_dict.pkl")
        else:
            # model_type == 'morph': the run directory name ends in a suffix that
            # is not known in advance, so it is resolved with a glob.
            model_path = f'{model_fold_path}/predict_gi_fold_{fold_num}/recon_loss_{recon_loss}/null_label_{null_label}/epochs_{num_epochs}/tolerance_epochs_{tolerance_epochs}/mxAlpha_{mxAlpha}/random_seed_{random_seed}/'
            pattern = os.path.join(model_path, f'{representation_type}_{model_type}_run*')
            run_dirs = glob.glob(pattern)

            # Exactly one matching run is required; otherwise the fold is
            # skipped (it contributes nothing to pred_whole).
            if not run_dirs:
                print(f'No runs found for fold {fold_num}')
                continue
            if len(run_dirs) > 1:
                print(f'Multiple runs found for fold {fold_num}')
                continue
            model_dir = run_dirs[0]
            pred_path = f'{model_dir}/{model_name}_pred_test.pkl'

        # Step 2: load the pickle and record it.
        with open(pred_path, 'rb') as f:
            pred = pd.read_pickle(f)
        print(f'Loaded predictions from {pred_path}')
        pred_whole.append(pred)

In [68]:
if model_type == 'Control':
    # The control baseline has a single prediction file (no per-fold folders),
    # so pred_whole ends up with exactly one entry.
    pred_whole = []
    pred_path = f'{model_fold_path}/random_seed_{random_seed}/y_ctrl_pred.pkl'
    with open(pred_path, 'rb') as f:
        pred = pd.read_pickle(f)
    pred_whole += [pred]

In [69]:
# Collapse each prediction entry to its mean along axis 0
# (presumably the per-cell axis -- TODO confirm against the prediction format).
# If the same key occurs in more than one element of pred_whole, only the
# first occurrence is kept.
if model_type != 'Truth':
    mean_dict = {}
    for fold_pred in pred_whole:
        for ptb_key, ptb_values in fold_pred.items():
            if ptb_key in mean_dict:
                continue
            mean_dict[ptb_key] = np.mean(ptb_values, axis=0)

In [None]:
# Mean expression profile per perturbation label (one row per value of obs['gene']).
# A sparse matrix is densified first so it can be wrapped in a DataFrame.
X_dense = adata.X.toarray() if issparse(adata.X) else adata.X

# Perturbation label of every cell
obs_df = pd.DataFrame(adata.obs['gene'])

# Expression matrix: rows are cells, columns are genes
X_df = pd.DataFrame(X_dense, index=adata.obs.index, columns=adata.var.index)

# Put label and expression side by side (both share adata.obs.index)
full_df = pd.concat([obs_df, X_df], axis=1)

# Average expression within each label. observed=True matters when 'gene' is
# categorical: it restricts the result to labels that actually occur, instead
# of adding all-NaN rows for unused categories, and avoids relying on the
# deprecated default of `observed`.
average_intervention_effects = full_df.groupby('gene', observed=True).mean()

assert average_intervention_effects.shape[0] == len(adata.obs['gene'].unique()), \
    'expected exactly one averaged row per perturbation label'

In [None]:
if model_type == 'Truth':
    # Ground truth: the observed per-perturbation means, without the control row.
    pred_df = average_intervention_effects.drop(index='non-targeting')
else:
    # One row per prediction key; columns are labelled with the genes of adata.
    pred_df = pd.DataFrame.from_dict(mean_dict, orient='index')
    pred_df.columns = adata.var.index
pred_df

In [None]:
if filter_zeros:
    # Clamp negative entries to 0 (modifies pred_df in place).
    negative_entries = pred_df < 0
    pred_df[negative_entries] = 0
pred_df.head(5)

In [None]:
# Express every row relative to the mean 'non-targeting' (control) profile.
control_profile = average_intervention_effects.loc['non-targeting']
pred_df_delta = pred_df.sub(control_profile, axis='columns')
pred_df_delta.head()

In [None]:
# Spot check: show the row(s) for the 'MAPK1+PRTG' pair.
pred_df_delta.loc[pred_df_delta.index == 'MAPK1+PRTG']

In [75]:
# \mu_p = \bar{X}_p - \bar{X}_{\text{non-targeting}}
# Re-centre every row on the control profile; the control row itself becomes all zeros.
control_mean = average_intervention_effects.loc['non-targeting']
average_intervention_effects = average_intervention_effects.sub(control_mean, axis='columns')
assert average_intervention_effects.loc['non-targeting'].sum() == 0

In [None]:
# Rows whose label contains no '+', i.e. everything that is not a combination.
# The '+' is matched literally (regex=False): the previous pattern '\+' was a
# non-raw string with an invalid escape sequence, which Python warns about.
is_combination = average_intervention_effects.index.str.contains('+', regex=False)
single_perturbations = average_intervention_effects.index[~is_combination]
single_perturbations_df = average_intervention_effects.loc[single_perturbations]
single_perturbations_df.head()

In [None]:
# Append the observed '+'-free rows (single_perturbations_df) below the predicted rows.
# Skipped for 'Truth', where pred_df_delta was built from the observed means.
# Note: running this cell again appends the same rows a second time.
if model_type != 'Truth':
    pred_df_delta = pd.concat((pred_df_delta, single_perturbations_df), axis=0)
pred_df_delta

In [None]:
# Spot check again after the concatenation: row(s) for the 'MAPK1+PRTG' pair.
pred_df_delta.loc[pred_df_delta.index == 'MAPK1+PRTG']

In [None]:
# Keep only the columns of the selected highly variable genes, in top_hvg order.
pred_df_delta = pred_df_delta[top_hvg]
pred_df_delta

In [None]:
# Save the delta expression table (pred_df_delta) as a pickle file.
# `os` is already imported in the first cell of the notebook.
import pickle

output_path_base = f'/home/che/perturb-project/git/gene_ptb_prediction/gene_interaction_prediction/data/{dataset_name}/predict/use_gt_single/num_gene_{num_gene}/'

# Output folder for the selected model type.
if model_type == 'Gears':
    output_path = output_path_base + f'Gears/seed_{random_seed}'
elif model_type == 'morph':
    output_path = output_path_base + f'{representation_type}_{model_type}/recon_loss_{recon_loss}/null_label_{null_label}/epochs_{num_epochs}/tolerance_epochs_{tolerance_epochs}/mxAlpha_{mxAlpha}/seed_{random_seed}/{model_name}'
elif model_type == 'Control':
    output_path = output_path_base + f'Control/seed_{random_seed}'
elif model_type == 'Truth':
    output_path = output_path_base + 'Truth'
else:
    # Without this branch, an unexpected model_type would reuse an output_path
    # left in the kernel by an earlier run and overwrite that model's results.
    raise ValueError(f'Unknown model_type {model_type!r}; cannot choose an output folder')

# exist_ok=True creates the folder if needed, without a separate existence check.
os.makedirs(output_path, exist_ok=True)

# The file name records whether negative values were clamped (filter_zeros).
output_name = 'delta_expression_filtered.pkl' if filter_zeros else 'delta_expression.pkl'
output_file = f'{output_path}/{output_name}'
with open(output_file, 'wb') as f:
    pickle.dump(pred_df_delta, f)
print('Saved to ', output_file)