# Inference validation with simulation


In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import scanpy as sc
import scipy as sp
import itertools
import numpy as np
import scipy.stats as stats
from scipy.integrate import dblquad
import seaborn as sns
from statsmodels.stats.multitest import fdrcorrection
import imp
# Print up to 999 rows of a DataFrame before pandas truncates the output.
pd.options.display.max_rows = 999
# Never truncate cell contents. `None` is the supported "no limit" value; the
# legacy `-1` sentinel is deprecated and rejected by newer pandas releases.
pd.set_option('display.max_colwidth', None)
import pickle as pkl
import time

  pd.set_option('display.max_colwidth', -1)


In [2]:
import matplotlib as mpl
import matplotlib.pylab as pylab

# Use font type 42 for both the PDF and the PostScript backends.
mpl.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})

# Text sizes applied to every figure: small legend/tick text, medium-sized
# axis labels and titles.
params = {
    'legend.fontsize': 'xx-small',
    'xtick.labelsize': 'xx-small',
    'ytick.labelsize': 'xx-small',
    'axes.labelsize': 'medium',
    'axes.titlesize': 'medium',
    'figure.titlesize': 'medium',
}
pylab.rcParams.update(params)


In [3]:
import sys
# Put a locally built memento egg on the import path so that `import memento`
# below can resolve. The path is hard-coded and machine-specific.
sys.path.append('/home/ssm-user/Github/scrna-parameter-estimation/dist/memento-0.0.9-py3.8.egg')
import memento
# Short alias for the simulation helpers used by the functions further down.
import memento.simulate as simulate

In [4]:
# import sys
# sys.path.append('/home/ssm-user/Github/single_cell_eb/')
# sys.path.append('/home/ssm-user/Github/single_cell_eb/sceb/')
# import scdd

In [5]:
# Directory prefix for input data; the .h5ad file read below is loaded from here.
data_path = '/data_volume/memento/demux/'
# Presumably the output directory for figures; it is not used in the visible
# part of this file.
fig_path = '/data/home/Github/scrna-parameter-estimation/figures/fig3/'

### Extract parameters from interferon dataset

In [6]:
# Load the filtered interferon dataset and keep only the cells whose
# `cell_type` label is exactly 'CD4 T cells - ctrl'.
adata = sc.read(data_path + 'interferon_filtered.h5ad')
adata = adata[adata.obs.cell_type == 'CD4 T cells - ctrl']
# Take a copy of the expression matrix so later work does not touch `adata.X`.
data = adata.X.copy()
# Divide every row by its row sum.
# NOTE(review): `.toarray()` implies `data` is a scipy sparse matrix, in which
# case `data.sum(axis=1)` is an (n_cells, 1) `np.matrix` and the result is an
# `np.matrix` as well — confirm if the exact type matters downstream.
relative_data = data.toarray()/data.sum(axis=1)

  df_sub[k].cat.remove_unused_categories(inplace=True)


In [7]:
# `q` is handed to the parameter extraction twice (as `q` and as the
# `min_mean` cutoff) and is reused later by the simulation/testing functions.
# NOTE(review): presumably the capture/sampling rate — confirm against memento.
q = 0.07
x_param, z_param, Nc, good_idx = simulate.extract_parameters(
    adata.X,
    q=q,
    min_mean=q,
)

### Functions for simulating DE, DV, and DC

In [None]:
def _count_scale_moments(log_means, log_variances, Nc):
    """Turn log-scale moments into the `means`/`variances` arguments of the simulator.

    The mean is scaled by ``Nc.mean()``; the variance is computed as
    ``(var + mean**2) * E[Nc**2] - mean**2 * E[Nc]**2``.
    """
    rel_means = np.exp(log_means)
    rel_variances = np.exp(log_variances)
    means = rel_means*Nc.mean()
    variances = (rel_variances + rel_means**2)*(Nc**2).mean() - rel_means**2*Nc.mean()**2
    return means, variances


def simulate_two_datasets(x_param, Nc, n_cells, q, diff='mean'):
    """Simulate two groups of cells that differ in one chosen moment.

    Group 'A' is simulated from the supplied parameters; group 'B' is simulated
    from a copy of them that is perturbed according to `diff`. Both groups are
    stacked, passed through ``simulate.capture_sampling`` (``process='hyper'``)
    and wrapped in an AnnData object.

    Parameters
    ----------
    x_param : sequence
        ``x_param[0]`` and ``x_param[1]`` are used as per-gene means and
        variances (their logs are taken, so they must be positive).
    Nc : array-like
        Passed to the simulator as ``Nc``; its first and second moments scale
        the simulated means and variances.
    n_cells : int
        Number of cells simulated per group.
    q : float
        Forwarded to ``simulate.capture_sampling``.
    diff : {'null', 'mean', 'variability', 'correlation'}
        * 'null'        - both groups share identical parameters.
        * 'mean'        - log-means of the first 500 genes are raised by 0.5 in group B.
        * 'variability' - log-variances of the first 500 genes are raised by 0.5 in group B.
        * 'correlation' - up to 150 entries of the top-left 100x100 block of
          the correlation matrix that are below 0.5 are raised by 0.5 in group B.

    Returns
    -------
    (AnnData, tuple or None)
        The captured counts with ``obs['ct_real']`` ('A'/'B' group labels) and
        ``obs['ct_shuffled']`` (labels drawn at random), plus the
        ``(rows, cols)`` index arrays of the altered correlation entries, or
        ``None`` unless ``diff == 'correlation'``.

    Raises
    ------
    ValueError
        If `diff` is not one of the supported modes.
    """
    valid_modes = ('null', 'mean', 'variability', 'correlation')
    if diff not in valid_modes:
        # Without this check an unknown mode fails later with an unrelated
        # unbound-name error instead of telling the caller what went wrong.
        raise ValueError(
            "diff must be one of {}, got {!r}".format(valid_modes, diff))

    log_means_1, log_variances_1 = np.log(x_param[0]), np.log(x_param[1])
    log_means_2, log_variances_2 = log_means_1.copy(), log_variances_1.copy()

    # Defaults shared by every mode except 'correlation'.
    norm_cov_1, norm_cov_2 = 'indep', 'indep'
    change_indices = None

    if diff == 'mean':
        log_means_2[:500] += 0.5
    elif diff == 'variability':
        log_variances_2[:500] += 0.5
    elif diff == 'correlation':
        # NOTE(review): `make_spd_matrix` is not imported in the visible part
        # of this file (presumably sklearn.datasets.make_spd_matrix); it must
        # be in scope for this mode to run.
        norm_cov_1 = make_spd_matrix(log_means_1.shape[0])
        std_outer = np.outer(np.sqrt(np.diag(norm_cov_1)), np.sqrt(np.diag(norm_cov_1)))
        norm_corr_1 = norm_cov_1/std_outer
        norm_corr_subset = norm_corr_1[:100, :100].copy()

        # Keep the first 150 (row-major order) entries below 0.5 and raise them.
        # NOTE(review): only entry (i, j) is changed, not (j, i), so the
        # resulting matrix need not stay symmetric — confirm the simulator
        # tolerates this.
        change_indices = np.where(norm_corr_subset < 0.5)
        change_indices = (change_indices[0][:150], change_indices[1][:150])
        norm_corr_subset[change_indices] += 0.5

        norm_corr_2 = norm_corr_1.copy()
        norm_corr_2[:100, :100] = norm_corr_subset
        # Convert back to a covariance using group A's standard deviations.
        norm_cov_2 = norm_corr_2 * std_outer

    means_1, variances_1 = _count_scale_moments(log_means_1, log_variances_1, Nc)
    data_1 = simulate.simulate_transcriptomes(
        n_cells=n_cells,
        means=means_1,
        variances=variances_1,
        Nc=Nc,
        norm_cov=norm_cov_1)

    means_2, variances_2 = _count_scale_moments(log_means_2, log_variances_2, Nc)
    data_2 = simulate.simulate_transcriptomes(
        n_cells=n_cells,
        means=means_2,
        variances=variances_2,
        Nc=Nc,
        norm_cov=norm_cov_2)

    # Group A rows first, then group B rows; the labels below rely on this order.
    true_data = np.vstack([data_1, data_2])
    _, hyper_captured = simulate.capture_sampling(true_data, q=q, process='hyper')

    sim_adata = sc.AnnData(sp.sparse.csr_matrix(hyper_captured))
    sim_adata.obs['ct_real'] = ['A']*n_cells + ['B']*n_cells
    sim_adata.obs['ct_shuffled'] = np.random.choice(['A', 'B'], sim_adata.shape[0])

    return sim_adata, change_indices

    
def calculate_power(n_cells, test='mean', test_null=False):
    """Simulate two groups and measure how often the chosen test rejects at 0.05.

    Uses the module-level `x_param`, `Nc` and `q` to simulate the data via
    `simulate_two_datasets`, then runs the moment computation and hypothesis
    test with ``num_boot=5000`` and ``num_cpus=6``.

    NOTE(review): the calls below go through `schypo`, which is not imported in
    the visible part of this file (only `memento` is) — confirm which name and
    which ``uns`` key the installed package provides.

    Parameters
    ----------
    n_cells : int
        Number of cells simulated per group.
    test : {'mean', 'variability', 'correlation'}
        Which hypothesis test to run; unless `test_null` is set it is also the
        kind of difference simulated between the two groups.
    test_null : bool
        If True, the data are simulated with no difference (``diff='null'``),
        so the returned rate is a false-positive rate rather than power.

    Returns
    -------
    (float, array, AnnData)
        ``power``: fraction of p-values below 0.05 among the first 500 genes
        ('mean' / 'variability'), or among the altered gene pairs that fall
        inside the tested 50x50 block ('correlation'; all tested pairs when
        `test_null` is True). ``pvals``: the full p-value array.
        ``sim_adata``: the simulated AnnData with results stored in ``uns``.

    Raises
    ------
    ValueError
        If `test` is not one of the supported tests.
    """
    valid_tests = ('mean', 'variability', 'correlation')
    if test not in valid_tests:
        # Fail before the expensive simulation; previously an unknown test ran
        # everything and then crashed on the unbound `power`/`pvals`.
        raise ValueError(
            "test must be one of {}, got {!r}".format(valid_tests, test))

    sim_adata, change_indices = simulate_two_datasets(
        x_param, Nc, n_cells, q=q, diff='null' if test_null else test)

    schypo.create_groups(
        sim_adata,
        q=q,
        label_columns=['ct_real'],
        inplace=True)

    schypo.compute_1d_moments(
        sim_adata,
        inplace=True,
        filter_genes=False,
        residual_var=True,
        filter_mean_thresh=0.0,
        min_perc_group=.9)

    if test in ('mean', 'variability'):

        schypo.ht_1d_moments(
            sim_adata,
            formula_like='1 + ct_real',
            cov_column='ct_real',
            num_boot=5000,
            num_cpus=6,
            verbose=3)

        asl_key = 'mean_asl' if test == 'mean' else 'var_asl'
        pvals = sim_adata.uns['schypo']['1d_ht'][asl_key]
        # Only the first 500 genes are perturbed by the simulation.
        power = (pvals[:500] < 0.05).mean()

    else:  # test == 'correlation'

        n_corr_genes = 50
        gene_names = np.arange(n_corr_genes).astype(str).tolist()

        schypo.compute_2d_moments(
            sim_adata,
            inplace=True,
            gene_1=gene_names,
            gene_2=gene_names)

        schypo.ht_2d_moments(
            sim_adata,
            formula_like='1 + ct_real',
            cov_column='ct_real',
            num_boot=5000,
            num_cpus=6,
            verbose=3)

        pvals = sim_adata.uns['schypo']['2d_ht']['corr_asl']
        tested_block = pvals[:n_corr_genes, :n_corr_genes]

        if change_indices is None:
            # Null simulation: every tested pair counts towards the rate.
            power = (tested_block < 0.05).mean()
        else:
            # The altered pairs are drawn from a 100x100 block but only a
            # 50x50 block is tested, so keep just the pairs inside the tested
            # block; indexing with the rest would raise IndexError.
            rows, cols = change_indices
            inside = (rows < n_corr_genes) & (cols < n_corr_genes)
            power = (tested_block[rows[inside], cols[inside]] < 0.05).mean()

    return power, pvals, sim_adata