In [1]:
from collections import defaultdict
import numpy as np
import os
import sys
import scanpy as sc
import pandas as pd
from collections import defaultdict
import numpy as np

In [2]:
def get_mean_control(adata):
    """Average the expression of the control cells, gene by gene.

    Selects the observations whose ``obs['condition']`` equals ``'ctrl'``,
    converts that subset with ``to_df()`` and returns its column-wise mean.
    """
    is_ctrl = adata.obs['condition'] == 'ctrl'
    ctrl_frame = adata[is_ctrl].to_df()
    return ctrl_frame.mean()

def subtract_ctrl_mean(preds_avg, adata):
    """Turn per-condition expression profiles into deltas against control.

    Parameters
    ----------
    preds_avg : dict
        Maps a condition name to a sequence of expression values. Values are
        paired positionally with ``adata.var_names``.
    adata
        Object passed to ``get_mean_control`` and whose ``var_names`` supply
        the gene label used to look up each control mean.

    Returns
    -------
    dict
        Same keys as ``preds_avg``; each value is a list holding
        ``value - control_mean[gene]`` for every (value, gene) pair.
    """
    ctrl_mean = get_mean_control(adata)

    # zip() stops at the shorter sequence, so a profile longer than
    # adata.var_names is silently truncated.
    return {
        condition: [
            value - ctrl_mean[gene]
            for value, gene in zip(profile, adata.var_names)
        ]
        for condition, profile in preds_avg.items()
    }
    
def _average_by_perturbation(res, field):
    """Average the rows of ``res[field]`` that share a ``res['pert_cat']`` label.

    Rows are accumulated one at a time in input order, then divided by the
    number of rows seen for that label. Labels keep first-seen order.
    """
    totals = {}
    counts = defaultdict(int)

    for i, cell in enumerate(res['pert_cat']):
        # np.array() copies, so the in-place += below never touches `res`.
        row = np.array(res[field][i])
        if cell in totals:
            totals[cell] += row
        else:
            totals[cell] = row
        counts[cell] += 1

    return {cell: (total / counts[cell]).tolist() for cell, total in totals.items()}


def calculate_avg_expression_truth(res):
    """Mean of ``res['truth']`` per perturbation label.

    Parameters
    ----------
    res : mapping
        Must provide ``'pert_cat'`` (one label per row) and ``'truth'``
        (one expression row per label, indexable by position).

    Returns
    -------
    dict
        Label -> list with the element-wise mean of that label's rows.
    """
    return _average_by_perturbation(res, 'truth')


def calculate_avg_expression_pred(res):
    """Mean of ``res['pred']`` per perturbation label.

    Parameters
    ----------
    res : mapping
        Must provide ``'pert_cat'`` (one label per row) and ``'pred'``
        (one expression row per label, indexable by position).

    Returns
    -------
    dict
        Label -> list with the element-wise mean of that label's rows.
    """
    return _average_by_perturbation(res, 'pred')



In [3]:
def subset_preds_pp(subset_preds, combo):
    """Rank genes by how far the double perturbation departs from additivity.

    The dict is modified in place and also returned. Four keys are added:

    * ``'Naive'``     - element-wise sum of the two single-perturbation
      profiles (``'<combo[0]>+ctrl'`` and ``'<combo[1]>+ctrl'``).
    * ``'abs_diff'``  - absolute difference between the double-perturbation
      profile (``'<combo[0]>+<combo[1]>'``) and ``'Naive'``.
    * ``'diff_rank'`` - rank of each position by ``'abs_diff'``, 1 being the
      largest difference.
    * ``'index'``     - the positions ``0 .. n-1``.

    Parameters
    ----------
    subset_preds : dict
        Must contain the three condition keys described above.
    combo : sequence
        The two perturbation names, in the order used to build the keys.
    """
    single_first = subset_preds[f'{combo[0]}+ctrl']
    single_second = subset_preds[f'{combo[1]}+ctrl']
    subset_preds['Naive'] = [a + b for a, b in zip(single_first, single_second)]

    double = subset_preds[f'{combo[0]}+{combo[1]}']
    subset_preds['abs_diff'] = [
        np.abs(observed - expected)
        for observed, expected in zip(double, subset_preds['Naive'])
    ]

    # argsort is ascending; reversing puts the largest deviation first.
    order_desc = np.argsort(subset_preds['abs_diff'])[::-1]
    n_positions = len(order_desc)

    # Scatter 1..n back to the original positions to obtain per-position ranks.
    ranks = np.full_like(order_desc, n_positions)
    ranks[order_desc] = np.arange(1, n_positions + 1)

    subset_preds['diff_rank'] = ranks.tolist()
    subset_preds['index'] = list(range(0, n_positions))

    return subset_preds

def subset_preds(subset_preds, top_n=20):
    """Keep only the best-ranked positions of every list in the dict.

    Parameters
    ----------
    subset_preds : dict
        Must contain ``'diff_rank'``. Every value is indexed at the positions
        of the ``top_n`` smallest ``'diff_rank'`` entries, so each value has to
        be indexable at those positions. Values longer than ``'diff_rank'``
        are subset by position, not by gene.
    top_n : int, optional
        How many positions to keep (default 20, the previously hard-coded
        value). If fewer positions exist, all of them are kept.

    Returns
    -------
    dict
        New dict with the same keys; each value is a list ordered by
        ascending ``'diff_rank'``. The input dict is not modified.

    Raises
    ------
    ValueError
        If ``top_n`` is negative.
    """
    if top_n < 0:
        raise ValueError(f"top_n must be non-negative, got {top_n}")

    diff_rank = subset_preds['diff_rank']
    top_indices = np.argsort(diff_rank)[:top_n]

    return {
        key: [values[index] for index in top_indices]
        for key, values in subset_preds.items()
    }

## data

In [4]:
import pickle
# Load the pickled split-subgroup file from an absolute, machine-specific path.
# NOTE(review): pickle.load can execute arbitrary code while unpickling; only
# load files from a trusted source.
# NOTE(review): `split1_subgroup` is not referenced by any other cell visible
# in this notebook section.
with open('/home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/01A_total_re/03final/normanweissman2019/splits/normanweissman2019_simulation_1_0.75_subgroup.pkl', 'rb') as f:
     split1_subgroup = pickle.load(f)

In [5]:
# Load the pickled test results. Later cells read res['pert_cat'],
# res['truth'] and res['pred'] from this object.
# NOTE(review): pickle.load can execute arbitrary code while unpickling; only
# load files from a trusted source.
with open('/home/share/huadjyin/home/zhoumin3/zhoumin/model_benchmark/01_A_results/NormanWeissman2019/gears/split1/NormanWeissman2019_split1_test_res.pkl', 'rb') as f:
     res = pickle.load(f)

In [6]:
# scanpy is already imported in the first cell; this repeat import is harmless.
import scanpy as sc
# Read the processed dataset from an absolute, machine-specific path. Later
# cells use adata.obs['condition'], adata.var_names, adata.var['gene_name']
# and adata.X.
adata = sc.read_h5ad('/home/share/huadjyin/home/zhoumin3/zhoumin/benchmark_data/01A_total_re/03final/normanweissman2019/perturb_processed.h5ad')

### truth

In [45]:
# delta
truth_avg = calculate_avg_expression_truth(res)
truth_avg = subtract_ctrl_mean(truth_avg, adata)

In [46]:
# The gene pair under study; the three keys below are the two single
# perturbations and the double perturbation.
combo = ['CBL', 'PTPN12']
# keys_to_keep is a set, so the key order of subset_truth follows set
# iteration order rather than the order written here.
keys_to_keep = {f'{combo[0]}+ctrl', f'{combo[1]}+ctrl', f'{combo[0]}+{combo[1]}'}
# Keys missing from truth_avg are skipped silently; subset_preds_pp later
# raises KeyError if any of the three is absent.
subset_truth = {key: truth_avg[key] for key in keys_to_keep if key in truth_avg}

In [47]:
# subset_preds_pp adds 'Naive', 'abs_diff', 'diff_rank' and 'index' to the dict
# it receives and returns that same dict, so subset_truth and
# subset_truth_processed refer to one object after this cell.
subset_truth_processed = subset_preds_pp(subset_truth, combo)

In [48]:
subset_truth_processed.keys()

dict_keys(['CBL+PTPN12', 'CBL+ctrl', 'PTPN12+ctrl', 'Naive', 'abs_diff', 'diff_rank', 'index'])

In [49]:
#add gene_name
# Attach gene names so they are subset together with the expression lists.
# NOTE(review): assumes adata.var['gene_name'] is in the same order as the
# per-gene lists already in the dict — confirm.
subset_truth_processed['gene_name'] = adata.var['gene_name'].tolist()

In [50]:
subset_truth_processed = subset_preds(subset_truth_processed)

### pred

In [51]:
# This is an alias, not a copy: keys added to truth_pred below also appear in
# subset_truth_processed.
truth_pred = subset_truth_processed

In [52]:
# Delta expression for the predictions: average res['pred'] per perturbation,
# then subtract the per-gene control mean. `pred_avg` is rebound on the second
# line, so after this cell it holds deltas rather than raw averages.
pred_avg = calculate_avg_expression_pred(res)
pred_avg = subtract_ctrl_mean(pred_avg, adata)

In [53]:
# pred_avg['CBL+PTPN12'] has one entry per gene, whereas every list already in
# truth_pred was reduced to the top-ranked genes by subset_preds. Storing the
# full vector would make the next subset_preds call take its first positions
# (unrelated genes). Select the same genes through the stored 'index' list so
# the prediction row lines up with the other rows.
truth_pred['CBL+PTPN12_p'] = [pred_avg['CBL+PTPN12'][i] for i in truth_pred['index']]

In [54]:
truth_pred.keys()

dict_keys(['CBL+PTPN12', 'CBL+ctrl', 'PTPN12+ctrl', 'Naive', 'abs_diff', 'diff_rank', 'index', 'gene_name', 'CBL+PTPN12_p'])

In [55]:
truth_pred_20 =  subset_preds(truth_pred)

In [56]:
truth_pred_20.keys()

dict_keys(['CBL+PTPN12', 'CBL+ctrl', 'PTPN12+ctrl', 'Naive', 'abs_diff', 'diff_rank', 'index', 'gene_name', 'CBL+PTPN12_p'])

In [95]:
truth_pred_20

{'CBL+PTPN12': [0.5592501536011696,
  0.416536808013916,
  0.9357820153236389,
  0.917165607213974,
  0.7560104131698608,
  -0.2701316177845001,
  1.2825572490692139,
  0.8081920146942139,
  -0.4057002663612366,
  -0.5600814968347549,
  -0.012369513511657715,
  0.4044220447540283,
  0.35859426856040955,
  0.018280386924743652,
  0.8034462332725525,
  0.3428364396095276,
  0.8359328508377075,
  -0.06716972589492798,
  -0.206298828125,
  -0.22855722904205322],
 'CBL+ctrl': [0.07270268350839615,
  0.9072510004043579,
  0.2619195580482483,
  0.34826937317848206,
  0.21902143955230713,
  0.12879401445388794,
  0.5642321109771729,
  0.3291623592376709,
  -0.07589304447174072,
  -0.39467325806617737,
  0.18758392333984375,
  0.12069070339202881,
  0.08970889449119568,
  0.02955937385559082,
  0.2680312991142273,
  0.06865668296813965,
  0.25530409812927246,
  0.07212424278259277,
  -0.18888092041015625,
  -0.18153870105743408],
 'PTPN12+ctrl': [0.05861476808786392,
  -0.07010841369628906,
  0

## GIs_Error_bar

In [67]:
def calculate_delta_expression(adata, gene_indices, combo):
    """Per-cell expression deltas against the control mean for a gene pair.

    Parameters
    ----------
    adata
        Annotated data object with ``obs['condition']`` and an ``X`` matrix
        that is either sparse (exposes ``toarray()``) or already dense.
    gene_indices
        Column selector applied to ``adata`` to pick the genes of interest.
    combo : sequence
        The two perturbation names.

    Returns
    -------
    dict
        Keys are ``combo[0]``, ``combo[1]`` and ``'+'.join(combo)``. Each
        value is the cells-by-genes matrix of that condition
        (``'<name>+ctrl'`` for the singles, ``'<a>+<b>'`` for the double)
        minus the per-gene mean of the ``'ctrl'`` cells.
    """
    def _dense_block(label):
        # Cells of one condition, restricted to the requested genes.
        block = adata[adata.obs['condition'] == label][:, gene_indices].X
        # Sparse matrices need densifying; dense input is used as it is.
        # The original called .toarray() unconditionally and failed on dense X.
        return block.toarray() if hasattr(block, 'toarray') else np.asarray(block)

    # Per-gene mean over the control cells.
    control_vec = _dense_block('ctrl').mean(axis=0)

    # Output key -> label used in obs['condition'].
    conditions = {
        combo[0]: f"{combo[0]}+ctrl",
        combo[1]: f"{combo[1]}+ctrl",
        '+'.join(combo): f"{combo[0]}+{combo[1]}"
    }

    # Broadcasting subtracts the control mean from every cell row.
    return {
        condition: _dense_block(label) - control_vec
        for condition, label in conditions.items()
    }


In [80]:
import scipy
import statsmodels.api as sm
import numpy as np
from scipy import stats
import numpy as np
from scipy.sparse import csr_matrix
from sklearn.utils import resample

def mean_confidence_interval(data, confidence=0.95):
    """Mean of ``data`` with a two-sided Student-t confidence interval.

    Parameters
    ----------
    data : array-like
        Sample values.
    confidence : float, optional
        Confidence level of the interval (default 0.95).

    Returns
    -------
    tuple
        ``(mean, lower, upper)`` where the bounds are the mean minus/plus the
        standard error times the t quantile with ``len(data) - 1`` degrees of
        freedom.
    """
    samples = np.array(data) * 1.0
    center = np.mean(samples)
    std_err = stats.sem(samples)
    dof = len(samples) - 1
    half_width = std_err * stats.t.ppf((1 + confidence) / 2., dof)
    return center, center - half_width, center + half_width
def calculate_p_ctrl_vec(adata, num_samples=300, num_reps=100):
    """Bootstrap estimate of the mean control expression.

    Draws ``num_reps`` resamples of ``num_samples`` control cells each (via
    ``sklearn.utils.resample`` with its default settings), takes the per-gene
    mean of every resample, and averages those means.

    Parameters
    ----------
    adata
        Annotated data object with ``obs['condition']``; cells labelled
        ``'ctrl'`` are used.
    num_samples : int, optional
        Cells drawn per resample (default 300).
    num_reps : int, optional
        Number of resamples (default 100).

    Returns
    -------
    numpy.ndarray
        Mean over resamples of the per-resample column means.

    Notes
    -----
    Reseeds NumPy's *global* random generator with 42 on every call, so any
    later use of ``np.random`` in the session is affected.
    """
    np.random.seed(42)
    ctrl_adata = adata[adata.obs['condition'] == 'ctrl']
    ctrl_matrix = ctrl_adata.X
    if isinstance(ctrl_matrix, csr_matrix):
        # toarray() is the documented way to densify; the `.A` shorthand used
        # before is a deprecated alias that is not available on every SciPy
        # sparse type/version.
        ctrl_data_dense = ctrl_matrix.toarray()
    else:
        # Anything that is not a csr_matrix is passed through unchanged.
        ctrl_data_dense = ctrl_matrix

    p_ctrl_list = []

    for _ in range(num_reps):
        # Each resample gets its own seed drawn from the (seeded) global RNG,
        # which keeps the whole procedure reproducible.
        sample = resample(ctrl_data_dense, n_samples=num_samples, random_state=np.random.randint(0, 1000))
        mean_sample = np.mean(sample, axis=0)
        p_ctrl_list.append(mean_sample)

    p_ctrl = np.array(p_ctrl_list)
    p_ctrl_vec = p_ctrl.mean(0)
    return p_ctrl_vec


In [74]:
def calculate_confidence_intervals(delta_exp):
    """Column-wise mean and confidence interval for every condition.

    Parameters
    ----------
    delta_exp : dict
        Maps a condition name to a 2-D array; each column is summarised
        separately.

    Returns
    -------
    dict
        Same keys; each value is a list with one ``(mean, lower, upper)``
        tuple per column, as produced by ``mean_confidence_interval``.
    """
    def _summarise(matrix):
        summaries = []
        for col in range(matrix.shape[1]):
            mean, lower_bound, upper_bound = mean_confidence_interval(matrix[:, col])
            summaries.append((mean, lower_bound, upper_bound))
        return summaries

    return {condition: _summarise(data) for condition, data in delta_exp.items()}

In [66]:
# Same gene pair as in the truth/pred section above.
combo = ['CBL', 'PTPN12']
# Positions of the selected top-ranked genes (the 'index' list carried through
# subset_preds); used below to pick columns of adata and res['pred'].
gene_indices = truth_pred_20['index']

In [68]:
delta_exp = calculate_delta_expression(adata, gene_indices, combo)

In [73]:
delta_exp

{'CBL': array([[-0.0703242 ,  0.07580388,  0.5402542 , ...,  0.4888112 ,
          0.20078707,  0.09040797],
        [ 0.495473  ,  1.9054011 , -0.33267742, ..., -0.07384676,
         -0.17737079, -0.3661145 ],
        [-0.0703242 ,  1.183451  , -0.33267742, ...,  0.14600354,
         -0.2960949 ,  0.4689349 ],
        ...,
        [-0.0703242 ,  0.4343233 , -0.33267742, ..., -0.02126181,
          0.4499054 ,  0.10911715],
        [-0.0703242 ,  0.87593734, -0.33267742, ..., -0.63964397,
         -0.3089974 ,  0.22018719],
        [-0.0703242 ,  0.846462  , -0.33267742, ...,  0.42276162,
          1.1750953 ,  0.88004124]], dtype=float32),
 'PTPN12': array([[-0.0703242 ,  0.05867279, -0.33267742, ...,  0.4716801 ,
         -0.83700216, -1.554736  ],
        [ 0.43300918, -0.5493179 , -0.33267742, ...,  0.19693345,
         -0.78570724, -1.0514026 ],
        [-0.0703242 , -1.0526513 , -0.33267742, ...,  0.13930917,
          0.10743761,  0.67035544],
        ...,
        [-0.0703242 , 

In [92]:
error_bar = calculate_confidence_intervals(delta_exp)

In [93]:
error_bar

{'CBL': [(0.0727027, 0.04777931004818419, 0.09762608677093049),
  (0.9072521, 0.8572623104764034, 0.9572418360995243),
  (0.2619194, 0.21600289071950235, 0.30783592735377036),
  (0.34826937, 0.29806413804966986, 0.39847460830729425),
  (0.21902102, 0.1682932890467431, 0.26974875559284434),
  (0.12879378, 0.08119038210195213, 0.17639716996866556),
  (0.5642312, 0.4951438540620688, 0.6333185797529337),
  (0.3291617, 0.26581593067701337, 0.39250747649614337),
  (-0.07589319, -0.12038950956916594, -0.03139687739753938),
  (-0.39467323, -0.4326822005380657, -0.35666425598964424),
  (0.18758395, 0.1375447904218736, 0.23762311586245868),
  (0.12069116, 0.07953708431858275, 0.1618452314363077),
  (0.08970877, 0.04857770650851131, 0.13083982915413975),
  (0.029559432, -0.00312758372683275, 0.06224644692201364),
  (0.26803127, 0.22025754778996248, 0.31580499083384733),
  (0.06865675, 0.02830305777703792, 0.10901044226969211),
  (0.25530317, 0.20496981144316903, 0.3056365370713879),
  (0.07212423

#### p_ctrl_data

In [87]:
# Bootstrap mean of the control cells, then keep only the selected genes.
# NOTE(review): the `[:, gene_indices]` indexing requires a 2-D result; if
# calculate_p_ctrl_vec returns a 1-D vector this raises IndexError — confirm
# the returned shape for this dataset.
# NOTE(review): `p_ctrl_vec` is rebound from itself, so re-running only the
# second line would index the already-subset array.
p_ctrl_vec = calculate_p_ctrl_vec(adata)
p_ctrl_vec = p_ctrl_vec[:, gene_indices]

In [88]:
# Rows of res['pred'] that belong to the double perturbation.
# NOTE(review): presumes res['pert_cat'] and res['pred'] are NumPy arrays (the
# element-wise == and the [rows, :] indexing rely on it) — confirm.
indices = np.where(res['pert_cat'] == 'CBL+PTPN12')[0]
pred_CBL_PTPN12 = res['pred'][indices, :]

In [89]:
# Restrict the predicted cells to the selected genes, then subtract the
# bootstrap control mean to get per-cell predicted deltas.
p_pred_CBL_PTPN12 = pred_CBL_PTPN12[:, gene_indices]
p_CBL_PTPN12 = p_pred_CBL_PTPN12 - p_ctrl_vec

In [90]:
# One (mean, lower, upper) tuple per selected gene for the predicted deltas.
# This is the same per-column loop that calculate_confidence_intervals runs
# for each condition.
CBL_PTPN12_p = []
for col in range(p_CBL_PTPN12.shape[1]):
    column_data = p_CBL_PTPN12[:, col]
    mean, lower_bound, upper_bound = mean_confidence_interval(column_data)
    CBL_PTPN12_p.append((mean, lower_bound, upper_bound))

In [106]:
CBL_PTPN12_p

[(0.2984241, 0.2737979453214698, 0.32305024450931025),
 (-0.119589694, -0.17786802945509483, -0.061311358233541946),
 (0.2479058, 0.20032195805267278, 0.2954896533612829),
 (0.46178576, 0.41378262720096937, 0.5097888998032726),
 (0.38551825, 0.3128754970759141, 0.4581610086232436),
 (-0.051144995, -0.11591844265889538, 0.013628453068250543),
 (0.3473718, 0.24779004821354783, 0.44695352545207107),
 (0.25798362, 0.1785153055940385, 0.3374519442762618),
 (-0.20398858, -0.26455208133441344, -0.14342508245724306),
 (-0.19246545, -0.2530344138693852, -0.131896494810577),
 (0.0736198, 0.012231890856496522, 0.13500770479513008),
 (0.17038612, 0.11128237155511725, 0.2294898697988714),
 (0.14534332, 0.09308535852720248, 0.19760127863595975),
 (-0.04346157, -0.08720629077815142, 0.00028315347098437776),
 (0.11858929, 0.05677770594271413, 0.18040087303010233),
 (0.08594741, 0.028670454317599715, 0.1432243642267881),
 (0.09841772, 0.035665043509803374, 0.16117039986769144),
 (-0.06472984, -0.120867

In [94]:
# Store the prediction intervals next to the ground-truth ones. Note the key
# uses an underscore ('CBL_PTPN12_p'); a later cell renames it to
# 'CBL+PTPN12_p' on the 10-gene subset.
error_bar['CBL_PTPN12_p'] = CBL_PTPN12_p

In [122]:
error_bar.keys()

dict_keys(['CBL+PTPN12', 'CBL_PTPN12_p', 'CBL+ctrl', 'PTPN12+ctrl'])

In [121]:
# Rename the single-perturbation keys to the '<gene>+ctrl' form used by
# truth_pred_20 and plot_df.
# Not idempotent: pop() without a default raises KeyError if this cell is run
# a second time.
error_bar['CBL+ctrl'] = error_bar.pop('CBL')
error_bar['PTPN12+ctrl'] = error_bar.pop('PTPN12')

# Plot

In [99]:
truth_pred_20.keys()

dict_keys(['CBL+PTPN12', 'CBL+ctrl', 'PTPN12+ctrl', 'Naive', 'abs_diff', 'diff_rank', 'index', 'gene_name', 'CBL+PTPN12_p'])

In [108]:
# Table to plot: one row per condition, one column per gene, limited to the
# first 10 entries of the ranked lists in truth_pred_20.
row_names = ['CBL+ctrl', 'PTPN12+ctrl', 'CBL+PTPN12', 'Naive', 'CBL+PTPN12_p']
column_names = truth_pred_20['gene_name'][:10]
plot_df = pd.DataFrame(index=row_names, columns=column_names)

# Every row label is also a key of truth_pred_20, so the rows are filled in
# the order given by row_names.
for row_name in row_names:
    plot_df.loc[row_name] = truth_pred_20[row_name][:10]

In [118]:
plot_df

Unnamed: 0,HBA2,CD99,HBA1,ALAS2,GYPB,CTSL,HBG2,HBG1,KCNH2,PRSS57
CBL+ctrl,0.072703,0.907251,0.26192,0.348269,0.219021,0.128794,0.564232,0.329162,-0.075893,-0.394673
PTPN12+ctrl,0.058615,-0.070108,0.291212,0.210517,0.184954,-0.104332,0.436146,0.208772,-0.070537,-0.403382
CBL+PTPN12,0.55925,0.416537,0.935782,0.917166,0.75601,-0.270132,1.282557,0.808192,-0.4057,-0.560081
Naive,0.131317,0.837143,0.553132,0.558786,0.403975,0.024462,1.000379,0.537935,-0.14643,-0.798055
CBL+PTPN12_p,0.298424,-0.11959,0.247906,0.461786,0.385518,-0.051145,0.347372,0.257984,-0.203989,-0.192465


In [116]:
# Overwrite two rows of plot_df in place:
# - 'CBL+PTPN12_p' takes the interval means (first tuple element) of the first
#   10 entries of CBL_PTPN12_p;
# - 'Naive' is recomputed as the sum of the two single-perturbation rows.
# Any earlier display of plot_df is stale after this cell.
first_values = [t[0] for t in CBL_PTPN12_p[:10]]
plot_df.loc['CBL+PTPN12_p'] = first_values
plot_df.loc['Naive'] = plot_df.loc['CBL+ctrl'] + plot_df.loc['PTPN12+ctrl']

In [125]:
# Keep the first 10 interval tuples per condition, matching the 10 columns of
# plot_df.
subset_error_bar = {key: value[:10] for key, value in error_bar.items()}

In [130]:
# Align the key with the plot_df row label 'CBL+PTPN12_p'.
# Not idempotent: pop() without a default raises KeyError if this cell is run
# a second time.
subset_error_bar['CBL+PTPN12_p'] = subset_error_bar.pop('CBL_PTPN12_p')