<a href="https://colab.research.google.com/github/pachterlab/CGP_2024_2/blob/main/combinatorial_prediction_example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**In this notebook we test additive and multiplicative models that predict the effect of combinatorial perturbations, using single- and dual-gRNA CRISPRa conditions. This demonstrates how to identify potential modes of combinatorial regulation at the level of kinetic parameters.**

## **Read in data and metadata**

In [None]:
#Install
!pip install --quiet monod
!pip install -U --quiet loompy


In [None]:

# Data provenance for this example (the cells below read these files from the working directory):
#   Cell metadata (GEO GSE133344): norman_GSE133344_filtered_cell_identities.csv
#       author's original location: /home/tchari/metadata/norman_GSE133344_filtered_cell_identities.csv
#   Transcriptome reference (transcriptome_filepath): gg_200525_genome_polyA_cum_3
#   Spliced/unspliced counts: allcrispr.loom
#       author's original location: /home/tchari/counts/norman_crispr/loom/allcrispr.loom
#   Loom download (tar.gz archive, presumably containing allcrispr.loom):
#       https://data.caltech.edu/records/ahjyk-gsj16/files/allcrispr_looms.tar.gz?download=1

In [None]:
import monod
from monod import preprocess, extract_data, cme_toolbox, inference, analysis

In [None]:
import pandas as pd
import numpy as np
import loompy as lp
import matplotlib.pyplot as plt
import scipy
import seaborn as sns
import scipy.stats

import random
import glob
import os

## **Run *Monod* inference on loom files**

Set up files for Monod run and set cell barcode filters for desired conditions

In [None]:
#Per-cell metadata from the study (GEO GSE133344); its 'guide_identity' and
#'cell_barcode' columns are used below to select the cells for each condition
metadata_path = './norman_GSE133344_filtered_cell_identities.csv'
meta = pd.read_csv(metadata_path)


In [None]:

#Conditions to fit. Each entry lists the guide identities pooled into one dataset:
#  0-3: individual negative controls (NegCtrl10, NegCtrl11, NegCtrl1, NegCtrl0)
#  4:   CBL alone
#  5:   CBL + CNN1 (dual perturbation)
#  6:   CNN1 alone (both guide orders pooled)
#  7:   pooled controls (NegCtrl10 + NegCtrl11 + NegCtrl0)
subcluster_names = [
    ['NegCtrl10_NegCtrl0__NegCtrl10_NegCtrl0'],
    ['NegCtrl11_NegCtrl0__NegCtrl11_NegCtrl0'],
    ['NegCtrl1_NegCtrl0__NegCtrl1_NegCtrl0'],
    ['NegCtrl0_NegCtrl0__NegCtrl0_NegCtrl0'],
    ['CBL_NegCtrl0__CBL_NegCtrl0'],
    ['CBL_CNN1__CBL_CNN1'],
    ['NegCtrl0_CNN1__NegCtrl0_CNN1', 'CNN1_NegCtrl0__CNN1_NegCtrl0'],
    ['NegCtrl10_NegCtrl0__NegCtrl10_NegCtrl0', 'NegCtrl11_NegCtrl0__NegCtrl11_NegCtrl0',
     'NegCtrl0_NegCtrl0__NegCtrl0_NegCtrl0'],
]

#Dataset names are the guide identities joined with '_' (also used as per-dataset folder names)
sub_names_only = list(map('_'.join, subcluster_names))

cluster_names = []
dataset_names = sub_names_only   #To save
print('dataset_names: ', dataset_names)
print('len(dataset_names): ', len(dataset_names))
print()

#Reference and loom-layout settings
transcriptome_filepath = './gg_200525_genome_polyA_cum_3'
spliced_layer, unspliced_layer = 'spliced', 'unspliced'
gene_attr = 'gene_name'
cell_attr = 'barcode'
attribute_names = [(unspliced_layer, spliced_layer), gene_attr, cell_attr]

#Every condition is read from the same loom file; cells are selected later by barcode
loom_filepaths = ['./allcrispr.loom' for _ in dataset_names]
print('loom_filepaths: ', loom_filepaths)

n_datasets = len(loom_filepaths)

In [None]:
#Boolean cell filter for each dataset: True for loom cells whose barcode the study metadata
#assigns to one of that dataset's guide identities
cf = []
barcodes_by_file = {}  #open each loom file once, even when several datasets share it

for k in range(len(dataset_names)):
    filename = loom_filepaths[k]
    if filename not in barcodes_by_file:
        with lp.connect(filename,mode='r') as ds:
            barcodes_by_file[filename] = ds.ca[cell_attr]
    bcs = barcodes_by_file[filename]

    annot_bcs = meta.loc[meta['guide_identity'].isin(subcluster_names[k]),'cell_barcode']
    selected = np.isin(bcs,annot_bcs)
    cf.append(selected)
    print(f'{dataset_names[k]}:')
    print(f'\t{len(annot_bcs)} cells in annotations. {selected.sum()} selected.')



**Set up *Monod* and select 500 genes to fit across all conditions**

In [None]:
#Create the Monod batch directory for all datasets and select the genes to fit
dir_string, dataset_strings = monod.preprocess.construct_batch(
    loom_filepaths,
    transcriptome_filepath,
    dataset_names,
    attribute_names=attribute_names,
    batch_location='./',
    meta='crispra_combo',
    batch_id=1,
    n_genes=500,               #genes fitted across all conditions
    exp_filter_threshold=None,
    cf=cf,                     #per-dataset boolean cell filters built above
)

**Run *Monod* inference**

In [None]:
#Search bounds (log10) for the biophysical parameters b, beta, gamma
phys_lb = [-2.0, -1.8, -1.8]
phys_ub = [4.2, 2.5, 2.5]

#Technical sampling parameters (c_u, lambda_s) are fixed to the values used in the study:
#lower bound == upper bound on a 1x1 grid, so no grid search is performed
study_samp_pars = (-7.157894736842105, -1.25)
samp_lb = list(study_samp_pars)
samp_ub = list(study_samp_pars)
gridsize = [1, 1]

In [None]:
#Fit each dataset independently; result_strings[i] locates dataset i's search results
result_strings = []
for i in range(n_datasets):
    #Bursty transcription model with Poisson technical sampling
    fitmodel = monod.cme_toolbox.CMEModel('Bursty', 'Poisson')

    #Gradient-descent settings: moment-based initialization, 5 iterations, a single start
    grad_settings = {'max_iterations': 5, 'init_pattern': 'moments', 'num_restarts': 1}
    inference_parameters = monod.inference.InferenceParameters(
        phys_lb, phys_ub, samp_lb, samp_ub, gridsize,
        dataset_strings[i], fitmodel,
        use_lengths=True,
        gradient_params=grad_settings,
    )

    search_data = monod.extract_data.extract_data(
        loom_filepaths[i], transcriptome_filepath, dataset_names[i],
        dataset_strings[i], dir_string,
        dataset_attr_names=attribute_names, cf=cf[i],
    )

    #Run inference (the leading 1 is presumably the number of cores -- confirm in Monod docs)
    full_result_string = inference_parameters.fit_all_grid_points(1, search_data)
    result_strings.append(full_result_string)

In [None]:
#For each dataset: load fit results and data, pick the sampling optimum, run chi-square
#testing and correction, run compute_sigma, and write the updated results back to disk
chisq_threshold = 1e-3
sr_arr = []
sd_arr = []
for i in range(n_datasets):
    sr = monod.analysis.load_search_results(result_strings[i])
    sd_path = f'{dir_string}/{dataset_names[i]}/raw.sd'
    sd = monod.analysis.load_search_data(sd_path)

    sr.find_sampling_optimum()

    _ = sr.chisquare_testing(sd, threshold=chisq_threshold)
    sr.chisq_best_param_correction(sd, Ntries=4, viz=False, threshold=chisq_threshold)

    sr.compute_sigma(sd, num_cores=1)

    monod.analysis.make_batch_analysis_dir([sr], dir_string)
    sr.update_on_disk()

    sr_arr.append(sr)
    sd_arr.append(sd)

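## **Multiplicative and Additive prediction models**

Consider two perturbations $A$ and $B$ with fold changes $FC_A$ and $FC_B$ relative to the control. The additive model predicts the combined fold change as $FC_{AB} = FC_A + FC_B - 1$. The multiplicative model predicts $FC_{AB} = FC_A \cdot FC_B$.

For each gene, the observed combined fold change is compared with both predictions using an interval of $\pm 1.96$ standard errors around the observed value. The standard error is estimated across the individual control conditions. Each gene then receives one of these labels:

- `add`: only the additive prediction falls inside the interval.
- `mult`: only the multiplicative prediction falls inside the interval.
- `supermult`: the observed value is above both predictions.
- `subadd`: the observed value is below both predictions.
- `ambig`: both predictions fall inside the interval, or they fall on opposite sides of it.

The same comparison is applied to the fitted kinetic parameters $b$ and $\beta$.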

In [None]:
def get_multAdd_annot(meanAdd,meanMult,true,trueErr,ci):
    '''
    Classify an observed change by which model prediction(s) are consistent with it.

    The interval [true - ci*trueErr, true + ci*trueErr] (endpoints inclusive) is built
    around the observed value, and each prediction is checked against it.

    Parameters
    ----------
    meanAdd : float
        Value predicted by the additive model.
    meanMult : float
        Value predicted by the multiplicative model.
    true : float
        Observed value.
    trueErr : float
        Standard error of the observed value.
    ci : float
        Number of standard errors in the interval half-width.

    Returns
    -------
    str
        'ambig' if both predictions are inside the interval, 'add' / 'mult' if only that
        prediction is inside, 'supermult' if both are below it, 'subadd' if both are above
        it, and 'ambig' otherwise. A NaN prediction never counts as inside the interval.
    '''
    half_width = ci*trueErr
    lower = true - half_width
    upper = true + half_width

    add_inside = lower <= meanAdd <= upper
    mult_inside = lower <= meanMult <= upper

    if add_inside and mult_inside:
        return 'ambig'
    if add_inside:
        return 'add'
    if mult_inside:
        return 'mult'
    #Neither prediction fits: label only when both miss on the same side
    if meanAdd < lower and meanMult < lower:
        return 'supermult'
    if meanAdd > upper and meanMult > upper:
        return 'subadd'
    return 'ambig'


def get_multAdd_preds(par_vals,inds,ctrlInds,sd_arr,filt,mod=0,ci=1.96):
    '''
    Annotate each gene's combined-perturbation effect as additive, multiplicative, etc.

    Expression: the observed log2 fold change of the combined condition (vs. the control)
    is compared with the additive prediction log2(FC_A + FC_B - 1) and the multiplicative
    prediction log2(FC_A) + log2(FC_B) via get_multAdd_annot, using an interval of
    +/- ci standard errors. The standard error is the std of the combined condition's log2
    fold change against each individual control in ctrlInds, divided by sqrt(len(ctrlInds)).
    Parameters b and beta are annotated the same way from pred_params / pred_stds.

    NOTE(review): pred_params and pred_stds are not defined in this notebook as shown;
    they must be defined (or imported) before this function is called.

    Parameters
    ----------
    par_vals : np.ndarray
        Passed through to pred_params / pred_stds. In this notebook it has shape
        (2, n_datasets, n_genes, n_phys_pars): [0] the control fit repeated for every
        dataset, [1] each dataset's own fit (b, beta, gamma; presumably log10).
    inds : sequence of int
        Dataset indices [perturbation A, perturbation B, combined A+B, control].
    ctrlInds : sequence of int
        Indices of the individual control datasets used to estimate standard errors.
    sd_arr : list
        Monod search-data objects, one per dataset; uses .layers and .gene_names.
    filt : np.ndarray of bool
        Gene mask; only genes where filt is True appear in the output.
    mod : int, default 0
        Count layer used for the expression comparison: 0 = unspliced, 1 = spliced.
    ci : float, default 1.96
        Number of standard errors in the interval half-width (1.96 ~ 95% under normality).

    Returns
    -------
    pd.DataFrame
        One row per retained gene with columns gene, exprAnnot, exprFC (observed log2
        expression FC), bAnnot, betaAnnot, bStd/betaStd/gammaStd (std_true from pred_stds,
        in its own units), and bFC/betaFC/gammaFC (observed parameter FC, log10 -> log2).
    '''
    def layer_means(k):
        #Per-gene mean (over cells) of the selected count layer for dataset k, genes in filt only
        return np.mean(sd_arr[k].layers[mod][filt,:],axis=1)

    #Compute each dataset's means once instead of re-reading the layer for every term
    mean_a = layer_means(inds[0])
    mean_b = layer_means(inds[1])
    mean_ab = layer_means(inds[2])
    mean_ctrl = layer_means(inds[-1])
    fc_a = mean_a/mean_ctrl
    fc_b = mean_b/mean_ctrl

    #Predicted log2 FC of the combined condition under each model
    meanU_add = np.log2(fc_a + fc_b - 1)          #additive: FC_AB = FC_A + FC_B - 1
    meanU_mult = np.log2(fc_a) + np.log2(fc_b)    #multiplicative: FC_AB = FC_A * FC_B

    #Observed log2 FC vs. the control, and its standard error across the individual controls
    true_meanU = np.log2(mean_ab/mean_ctrl)
    fc_vs_ctrls = np.array([np.log2(mean_ab/layer_means(c)) for c in ctrlInds])
    true_meanU_err = np.std(fc_vs_ctrls,axis=0)/np.sqrt(len(ctrlInds))

    n_kept = meanU_add.shape[0]
    means_multAdd = [get_multAdd_annot(meanU_add[i],meanU_mult[i],true_meanU[i],true_meanU_err[i],ci)
                     for i in range(n_kept)]

    #Parameter predictions and observations (rows = genes, columns = b, beta, gamma)
    z_mult,true = pred_params(par_vals,inds,model='mult')
    _,std_true,_,_ = pred_stds(par_vals,inds,ctrlInds,model='mult')
    z_add,_ = pred_params(par_vals,inds,model='add')

    z_mult = z_mult[filt,:]
    true = true[filt,:]
    z_add = z_add[filt,:]
    std_true_err = (std_true/np.sqrt(len(ctrlInds)))[filt,:]

    def annotate_param(col):
        #Mult/add annotation of parameter column col for every retained gene
        return [get_multAdd_annot(z_add[i,col],z_mult[i,col],true[i,col],std_true_err[i,col],ci)
                for i in range(n_kept)]

    return pd.DataFrame({
        'gene': sd_arr[inds[0]].gene_names[filt],
        'exprAnnot': means_multAdd,
        'exprFC': true_meanU,
        'bAnnot': annotate_param(0),
        'betaAnnot': annotate_param(1),
        #Std columns keep pred_stds' units (not converted to log2 like the *FC columns)
        'bStd': std_true[filt,0],
        'betaStd': std_true[filt,1],
        'gammaStd': std_true[filt,2],
        #Observed parameter FC converted from log10 to log2
        'bFC': np.log2(10**true[:,0]),
        'betaFC': np.log2(10**true[:,1]),
        'gammaFC': np.log2(10**true[:,2]),
    })



In [None]:
#Collect all fitted biophysical parameters for comparison against the pooled control (dataset 7).
#Axis 0: [control fit, condition fit]; axis 1: condition; then genes x parameters.
control = 7

n_genes = sr_arr[0].n_genes
n_phys_pars = sr_arr[0].sp.n_phys_pars
par_vals_uncorrected = np.zeros((2, n_datasets, n_genes, n_phys_pars))

#Slot 0 repeats the control fit for every condition (broadcast over the condition axis)
par_vals_uncorrected[0] = sr_arr[control].phys_optimum
#Slot 1 holds each condition's own fit
par_vals_uncorrected[1] = np.stack([sr_arr[j].phys_optimum for j in range(n_datasets)])


In [None]:
#Dataset indices follow the order of subcluster_names / dataset_names:
#  0 NegCtrl10, 1 NegCtrl11, 2 NegCtrl1, 3 NegCtrl0, 4 CBL, 5 CBL+CNN1, 6 CNN1, 7 pooled controls
control = 7  #pooled control (NegCtrl10 + NegCtrl11 + NegCtrl0)

#Individual controls that make up the pooled control; used to estimate standard errors
ctrlInds = [0, 1, 3]

#Conditions for the prediction: [perturbation A, perturbation B, combined A+B, control]
inds = [4, 6, 5, control]  #[CBL, CNN1, CBL+CNN1, pooled control]

In [None]:
#Make a filter for genes that are not rejected (presumably by the chi-square testing above)
#in any of the four conditions used for prediction: both single perturbations, the combination, and the control
assert all(0 <= k < len(sr_arr) for k in [*inds, *ctrlInds]), \
    f'condition indices {inds} / {ctrlInds} out of range for {len(sr_arr)} datasets'
cond1_res = sr_arr[inds[0]]
cond2_res = sr_arr[inds[1]]
both_res = sr_arr[inds[2]]
ctrl_res = sr_arr[inds[3]]
forFilt = ~cond1_res.rejected_genes & ~cond2_res.rejected_genes & ~both_res.rejected_genes & ~ctrl_res.rejected_genes

In [None]:
#Annotate genes using unspliced counts (mod=0; use mod=1 for spliced) and a +/-1.96 SE interval
res = get_multAdd_preds(
    par_vals=par_vals_uncorrected,
    inds=inds,
    ctrlInds=ctrlInds,
    sd_arr=sd_arr,
    filt=forFilt,
    mod=0,
    ci=1.96,
)
res.head()

In [None]:
#Save the per-gene annotations as CSV without the DataFrame's integer index
res.to_csv('multAdd_preds.csv', index=False)

In [None]:
#Distribution of mult/add annotations of the beta fold change for repressed genes,
#i.e. genes whose observed log2 expression fold change is below -0.5
sub = res[res['exprFC'] < -0.5]
fig, ax = plt.subplots()
sns.histplot(sub, x='betaAnnot', ax=ax)
ax.set(title='Beta fold-change annotation, repressed genes (log2 exprFC < -0.5)',
       xlabel='Beta annotation', ylabel='Number of genes');