In [1]:
import pandas as pd
import pickle as pkl
import numpy as np

import matplotlib.pyplot as plt

import scanpy as sc
%matplotlib inline
import matplotlib

import scanpy as sc
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import glob

### Random baseline

In [2]:
def generate_random_rounds(all_genes, round0=None, rounds = 16, sample_size=64):
    past_set =list(all_genes)
    random_rounds = {}
    if round0 is not None:
        random_rounds[0] = round0
    else:
        random_rounds[0] = np.random.choice(all_genes, sample_size)
    
    for i in range(1,rounds):
        selection_set = set(past_set).difference(set(random_rounds[i-1]))
        random_rounds[i] = np.concatenate([random_rounds[i-1], np.random.choice(list(selection_set), sample_size)])
        past_set = list(selection_set)
        
    return random_rounds

def generate_random_rounds_custom(all_genes, rounds = 16, sample_sizes=None):
    past_set =list(all_genes)
    random_rounds = {}
    random_rounds[0] = []
    
    for i in range(1,rounds):
        selection_set = set(past_set).difference(set(random_rounds[i-1]))
        random_rounds[i] = np.concatenate([random_rounds[i-1], np.random.choice(list(selection_set), sample_sizes[i])])
        past_set = list(selection_set)
        
    return random_rounds

def return_hits(list_):
    return list(set(list_).intersection(set(topmovers)))

def get_all_sampled_genes(exp_path):
    all_sampled_genes = []
    num_rounds = len(glob.glob(exp_path + '*'))
    for i in range(num_rounds):
        exp_path_ = exp_path + str(i) + '.npy'
        all_sampled_genes.append(np.load(exp_path_))
    return all_sampled_genes

def get_successful_sample_sizes(sampled_genes):
    sizes = []
    prev_size = 0
    for sample in sampled_genes:
        sizes.append(len(sample)-prev_size)
        prev_size = prev_size + sizes[-1]
    return sizes

essential_genes = '/dfs/user/yhr/bagel/CEGv2.txt'
essential_genes = pd.read_csv(essential_genes, delimiter='\t')['GENE'].tolist()

def return_hits(arr, remove_essential=True):
    res = list(set(arr).intersection(set(topmovers)))
    if remove_essential == True:
        res = list(set(res).difference(set(essential_genes)))
    return res

In [3]:
REPS = np.arange(1, 10).astype('str')
ACQUISITION_FUNCTIONS = ["random", "softuncertain", "topuncertain",
                         "marginsample", "coreset", "badge",
                         "kmeans_embedding", "kmeans_data"]

dataset = 'carnevale_adenosine'
sample_size= 32
num_steps = 30

data_name_map = {'scharenberg': 'Scharenberg22',
                 'ifng': 'IFNG',
                 'il2': 'IL2',
                 'steinhart': 'Steinhart_crispra_GD2_D22',
                 'scharenberg': 'Scharenberg22',
                 'sanchez': 'Sanchez21_down',
                 'carnevale_adenosine': 'Carnevale22_Adenosine'}

result_name_map = {'scharenberg': 'scharenberg_2022',
                 'ifng': 'Schmidt_2021_ifng',
                 'il2': 'Schmidt_2021_il2',
                 'steinhart': 'steinhart_2024_crispra_GD2_D22',
                 'scharenberg': 'scharenberg_2022',
                 'sanchez': 'sanchez_2021_down',
                  'carnevale_adenosine': 'Carnevale_2022_Adenosine'}

    
num_reps = len(REPS)
batch_size=sample_size

data_df = pd.read_csv('/dfs/user/yhr/AI_RA/research_assistant/datasets/ground_truth_'+data_name_map[dataset]+'.csv')
topmovers = np.load('/dfs/user/yhr/AI_RA/research_assistant/datasets/topmovers_'+data_name_map[dataset]+'.npy')
try:
    data_df = data_df.rename(columns={'0':'Gene', '1':'Score'})
except:
    pass

all_genes = data_df['Gene'].values
data_df = data_df.set_index('Gene')

### ML model results

In [4]:
## For steinhart

#data_name = 'schmidt_2021_ifng_noise'
#data_name = 'belk_2022'
data_name = result_name_map[dataset]

def read_ml_rounds(data_name, num_steps, batch_size):
    #path = f'/dfs/user/yhr/genedisco/genedisco/results/Schmidt_2021_'+data_name+'/'
    
    ## special path for steinhart
    if sample_size == 128:
        path = f'/dfs/user/yhr/genedisco/genedisco/results/longruns/'+data_name+'/'
    elif sample_size == 32:
        path = f'/dfs/user/yhr/genedisco/genedisco/results/longruns_32/'+data_name+'/'
    print(path)
    all_pred_genes = []
    for i in range(num_steps):
        try:
            all_pred_genes.append(pd.read_pickle(path+'cycle_{}/selected_indices.pickle'.format(i)))
        except:
            print('No sampling round {}'.format(i))
    return all_pred_genes

def get_ml_hits(data_name, ML_model_name, num_steps = num_steps, 
                batch_size=batch_size,
                num_reps=2):
    
    ml_rounds = {}
    ml_hits = {}
    ml_hit_rates = {}
    
    for rep in range(1, num_reps+1):
        rep=str(rep)
        ml_rounds[rep] = read_ml_rounds(data_name + '_' + ML_model_name + '_{}'.format(rep), 
                                        num_steps, batch_size=batch_size)

        ml_hits[rep] = []
        ml_hit_rates[rep] = []
        for step in range(num_steps):
            ml_hits[rep].append(return_hits(ml_rounds[rep][step]))
            ml_hit_rates[rep].append(len(ml_hits[rep][-1])/len(topmovers))
        
    return ml_hit_rates, ml_hits


def get_random_hits(sample_size, num_steps = num_steps, num_reps=3):
    random_rounds = {}
    random_hits = {}
    random_hit_rates = {}

    for rep in range(1, num_reps+1):
        rep=str(rep)
        random_rounds[rep] = generate_random_rounds(all_genes, rounds = num_steps, sample_size=sample_size)
        #random_rounds[rep] = generate_random_rounds_custom(all_genes, rounds = num_steps, 
        #                                                   sample_sizes=sample_sizes[rep])

        random_hits[rep] = []
        random_hit_rates[rep] = []
        for step in range(num_steps):
            random_hits[rep].append(return_hits(random_rounds[rep][step]))
            random_hit_rates[rep].append(len(random_hits[rep][-1])/len(topmovers))
            
    return random_hit_rates, random_hits

def get_avg_hit_score(hits):
    return data_df.loc[hits].abs().mean().values[0]

def get_avg_hit_scores(hits, num_reps=3):
    hit_scores = []
    
    for rep in range(1, num_reps+1):
        hit_score_rep = []
        for hit_list in hits[str(rep)]:
            hit_score_rep.append(get_avg_hit_score(hit_list))
        hit_scores.append(hit_score_rep)

    return np.round(np.nanmean(np.array(hit_scores), 0),3)

In [9]:
# calculate the mean and standard deviation of the random data
random_hit_rates, random_hits = get_random_hits(sample_size, num_steps = num_steps, num_reps=num_reps)
random_mean = np.mean(np.array([x for x in random_hit_rates.values()]),0)
random_std = np.std(np.array([x for x in random_hit_rates.values()]),0)

# calculate the mean and standard deviation for the model predictions
ml_mean = {}
ml_std = {}
ml_scores = {}

for ML_model_name in ACQUISITION_FUNCTIONS:
    
    print(ML_model_name)
    
    try:
        ml_hit_rates, ml_hits = get_ml_hits(data_name, ML_model_name, num_steps = num_steps, 
                                            num_reps=num_reps,
                                            batch_size=batch_size)
        ml_scores[ML_model_name] = get_avg_hit_scores(ml_hits, num_reps=num_reps)
        
        ml_mean[ML_model_name] = np.round(np.mean(np.array([x for x in ml_hit_rates.values()]),0),3)
        ml_std[ML_model_name] = np.round(np.std(np.array([x for x in ml_hit_rates.values()]),0),3)
    except:
        print('Failed')

random
/dfs/user/yhr/genedisco/genedisco/results/longruns_32/Carnevale_2022_Adenosine_random_1/
/dfs/user/yhr/genedisco/genedisco/results/longruns_32/Carnevale_2022_Adenosine_random_2/
/dfs/user/yhr/genedisco/genedisco/results/longruns_32/Carnevale_2022_Adenosine_random_3/
/dfs/user/yhr/genedisco/genedisco/results/longruns_32/Carnevale_2022_Adenosine_random_4/
/dfs/user/yhr/genedisco/genedisco/results/longruns_32/Carnevale_2022_Adenosine_random_5/
/dfs/user/yhr/genedisco/genedisco/results/longruns_32/Carnevale_2022_Adenosine_random_6/
/dfs/user/yhr/genedisco/genedisco/results/longruns_32/Carnevale_2022_Adenosine_random_7/
/dfs/user/yhr/genedisco/genedisco/results/longruns_32/Carnevale_2022_Adenosine_random_8/
/dfs/user/yhr/genedisco/genedisco/results/longruns_32/Carnevale_2022_Adenosine_random_9/
softuncertain
/dfs/user/yhr/genedisco/genedisco/results/longruns_32/Carnevale_2022_Adenosine_softuncertain_1/
/dfs/user/yhr/genedisco/genedisco/results/longruns_32/Carnevale_2022_Adenosine_sof

In [11]:
for means, stds in zip(ml_mean.items(), ml_std.items()):
    assert means[0] == stds[0]
    name = means[0]
    row = means[0]
    means= means[1]
    stds= stds[1]
    print(name, '&', '{} & {} & {}$\pm${}'.format(means[9],  
                                                     means[19],
                                                     means[29], stds[29]))

random & 0.016 & 0.033 & 0.049$\pm$0.006
softuncertain & 0.018 & 0.037 & 0.052$\pm$0.01
topuncertain & 0.014 & 0.033 & 0.048$\pm$0.004
marginsample & 0.012 & 0.027 & 0.044$\pm$0.005
badge & 0.013 & 0.03 & 0.05$\pm$0.005
kmeans_embedding & 0.013 & 0.023 & 0.036$\pm$0.005


In [10]:
## For Scharenberg only 20 rounds work

for means, stds in zip(ml_mean.items(), ml_std.items()):
    assert means[0] == stds[0]
    name = means[0]
    row = means[0]
    means= means[1]
    stds= stds[1]
    print(name, '&', '{} & {} $\pm${}'.format(means[9],  
                                                     means[19],stds[19]))

random & 0.243 & 0.518 $\pm$0.045
softuncertain & 0.281 & 0.57 $\pm$0.029
topuncertain & 0.354 & 0.604 $\pm$0.0
marginsample & 0.319 & 0.604 $\pm$0.0
coreset & 0.358 & 0.587 $\pm$0.014
badge & 0.363 & 0.587 $\pm$0.024
kmeans_embedding & 0.237 & 0.543 $\pm$0.021
kmeans_data & 0.388 & 0.585 $\pm$0.018


In [11]:
get_ml_hits(data_name, ML_model_name, num_steps = num_steps, 
                                            num_reps=num_reps,
                                            batch_size=batch_size)

/dfs/user/yhr/genedisco/genedisco/results/longruns_32/scharenberg_2022_kmeans_data_1/
/dfs/user/yhr/genedisco/genedisco/results/longruns_32/scharenberg_2022_kmeans_data_2/
/dfs/user/yhr/genedisco/genedisco/results/longruns_32/scharenberg_2022_kmeans_data_3/
/dfs/user/yhr/genedisco/genedisco/results/longruns_32/scharenberg_2022_kmeans_data_4/
/dfs/user/yhr/genedisco/genedisco/results/longruns_32/scharenberg_2022_kmeans_data_5/
/dfs/user/yhr/genedisco/genedisco/results/longruns_32/scharenberg_2022_kmeans_data_6/
/dfs/user/yhr/genedisco/genedisco/results/longruns_32/scharenberg_2022_kmeans_data_7/
/dfs/user/yhr/genedisco/genedisco/results/longruns_32/scharenberg_2022_kmeans_data_8/
/dfs/user/yhr/genedisco/genedisco/results/longruns_32/scharenberg_2022_kmeans_data_9/


({'1': [0.07547169811320754,
   0.09433962264150944,
   0.09433962264150944,
   0.1320754716981132,
   0.18867924528301888,
   0.2830188679245283,
   0.3584905660377358,
   0.3584905660377358,
   0.37735849056603776,
   0.41509433962264153,
   0.41509433962264153,
   0.4528301886792453,
   0.4528301886792453,
   0.4716981132075472,
   0.5094339622641509,
   0.5094339622641509,
   0.5283018867924528,
   0.5471698113207547,
   0.5660377358490566,
   0.5849056603773585],
  '2': [0.018867924528301886,
   0.03773584905660377,
   0.09433962264150944,
   0.1320754716981132,
   0.18867924528301888,
   0.22641509433962265,
   0.24528301886792453,
   0.3018867924528302,
   0.33962264150943394,
   0.39622641509433965,
   0.41509433962264153,
   0.4339622641509434,
   0.4716981132075472,
   0.5094339622641509,
   0.5849056603773585,
   0.6037735849056604,
   0.6226415094339622,
   0.6226415094339622,
   0.6226415094339622,
   0.6226415094339622],
  '3': [0.0,
   0.03773584905660377,
   0.113207547