In [None]:
import os
import json

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.stats
from tqdm import tqdm

import anndata
import scanpy as sc

from scmg.preprocessing.data_standardization import GeneNameMapper

# Instantiate the project's gene-name mapper once for the notebook.
# NOTE(review): `gene_name_mapper` is not referenced by any of the cells shown
# below — confirm whether it is still needed.
gene_name_mapper = GeneNameMapper()

In [None]:
# Load the single-gene perturbation dataset as an AnnData object.
# NOTE(review): this is an absolute, machine-specific path; consider deriving it
# from a configurable data directory so the notebook runs elsewhere.
adata = sc.read_h5ad('/GPUData_xingjie/SCMG/hESC_perturb_seq/adata_single_gene_pert.h5ad')
# De-duplicate the observation (cell) names in place.
adata.obs_names_make_unique()
# Display the AnnData summary as the cell output.
adata

In [None]:
# Scale every cell to a total of 1e4 counts. The call's return value is not
# assigned, so `adata` itself is modified and every expression value read in
# the cells below is a normalized count (no log transform is applied here).
sc.pp.normalize_total(adata, target_sum=1e4)
# Display the AnnData summary as the cell output.
adata

In [None]:
# Sorted array of the distinct guide identifiers present in the dataset;
# the knock-down loop below iterates over it one guide at a time.
feature_calls = np.unique(adata.obs['feature_call'].to_numpy())

In [None]:
def _expression_vector(adata_view, gene):
    """Return the expression of ``gene`` in every cell of ``adata_view`` as a 1-D array.

    ``.X`` is densified only if it exposes ``toarray`` (i.e. it is a sparse
    matrix); an already-dense ``.X`` is used as is.
    """
    values = adata_view[:, gene].X
    if hasattr(values, 'toarray'):
        values = values.toarray()
    return np.asarray(values).reshape(-1)


def _knockdown_stats(control_exps, perturbed_exps):
    """Compare target-gene expression between control and perturbed cells.

    Parameters
    ----------
    control_exps, perturbed_exps : 1-D numpy arrays
        Expression of the target gene in control cells and in the cells
        carrying one guide, respectively.

    Returns
    -------
    (mean_exp_ctl, mean_exp_perturbed, fc, pval)
        ``fc`` is perturbed mean / control mean. ``pval`` comes from a
        one-sided Mann-Whitney U test whose alternative is
        "control > perturbed". Any quantity that cannot be computed (empty
        group, zero control mean) is NaN.
    """
    mean_exp_ctl = np.mean(control_exps) if control_exps.size else np.nan
    mean_exp_perturbed = np.mean(perturbed_exps) if perturbed_exps.size else np.nan

    # Without cells in both groups there is nothing to compare or to test.
    if control_exps.size == 0 or perturbed_exps.size == 0:
        return mean_exp_ctl, mean_exp_perturbed, np.nan, np.nan

    # A gene that is not detected in control cells has no defined fold-change;
    # report NaN rather than dividing by zero (inf/NaN plus a RuntimeWarning).
    fc = mean_exp_perturbed / mean_exp_ctl if mean_exp_ctl > 0 else np.nan
    pval = scipy.stats.mannwhitneyu(control_exps, perturbed_exps, alternative='greater')[1]
    return mean_exp_ctl, mean_exp_perturbed, fc, pval


# Per-guide knock-down efficiency table, one entry per feature_call.
# Expression values are read from adata.X as it stands when this cell runs.
kd_dict = {
    'feature_call' : [],
    'perturbed_gene' : [],
    'n_cells' : [],
    'mean_exp_ctl' : [],
    'mean_exp_perturbed' : [],
    'fc' : [],
    'pval' : []
}

# Cells labelled 'non-targeting' are the control population for every guide.
adata_ctl = adata[adata.obs['perturbed_gene'] == 'non-targeting']

# A progress bar replaces one printed line per guide; the per-guide numbers
# are all available in kd_df / KD_efficiency.csv afterwards.
for feature in tqdm(feature_calls, desc='KD efficiency'):
    adata_feature = adata[adata.obs['feature_call'] == feature]
    # Takes the target from the first cell of the group.
    # NOTE(review): assumes every cell with this feature_call shares one
    # perturbed_gene — confirm.
    perturbed_gene = adata_feature.obs['perturbed_gene'].values[0]

    if perturbed_gene in adata.var_names:
        mean_exp_ctl, mean_exp_perturbed, fc, pval = _knockdown_stats(
            _expression_vector(adata_ctl, perturbed_gene),
            _expression_vector(adata_feature, perturbed_gene),
        )
    else:
        # The target is not among the measured genes (this includes the
        # 'non-targeting' label itself), so no statistic can be computed.
        mean_exp_ctl = mean_exp_perturbed = fc = pval = np.nan

    kd_dict['feature_call'].append(feature)
    kd_dict['perturbed_gene'].append(perturbed_gene)
    kd_dict['n_cells'].append(adata_feature.n_obs)
    kd_dict['mean_exp_ctl'].append(mean_exp_ctl)
    kd_dict['mean_exp_perturbed'].append(mean_exp_perturbed)
    kd_dict['fc'].append(fc)
    kd_dict['pval'].append(pval)


kd_df = pd.DataFrame(kd_dict).set_index('feature_call')
kd_df.to_csv('KD_efficiency.csv')

In [None]:
# Reload the table written by the previous cell; the first CSV column
# (feature_call) becomes the index again.
# NOTE(review): presumably this lets the cells below run from the saved CSV
# without repeating the per-guide loop — confirm.
kd_df = pd.read_csv('KD_efficiency.csv', index_col=0)
# Display the table as the cell output.
kd_df

In [None]:
# Distribution of per-guide fold-changes (perturbed mean / control mean),
# restricted to the 0-2 range and split into 50 bins.
fc_axes = kd_df['fc'].hist(bins=50, range=(0, 2))
fc_axes.set_xlabel('Target gene expression fold-change')
fc_axes.set_ylabel('Number of guide RNAs')

In [None]:
# Volcano-style view of knock-down efficiency: effect size (log2 fold-change)
# against significance (-log10 p-value), one point per guide.
# Rows with a missing, zero or negative fold-change, or a missing or zero
# p-value, have no finite position on these axes; filtering them first avoids
# divide-by-zero warnings from the log transforms.
volcano_df = kd_df[(kd_df['fc'] > 0) & (kd_df['pval'] > 0)]

fig, ax = plt.subplots()
ax.scatter(np.log2(volcano_df['fc']), -np.log10(volcano_df['pval']), s=1)
ax.set_xlabel('log2(target gene expression fold-change)')
ax.set_ylabel('-log10(p-value)');

In [None]:
# Guides counted as functional: one-sided p-value below 0.05 and target
# expression below 90% of the control mean. The p-values are used as computed
# (no multiple-testing adjustment is applied in this notebook).
significant_kd_df = kd_df.loc[kd_df['pval'].lt(0.05) & kd_df['fc'].lt(0.9)]
significant_kd_df

In [None]:
# For each target gene, how many of its guides passed the significance filter.
gene_significant_guide_count = significant_kd_df.loc[:, 'perturbed_gene'].value_counts()
gene_significant_guide_count

In [None]:
# Tally how many genes have 1, 2, 3, ... functional guides.
count_of_counts = gene_significant_guide_count.value_counts()

# Draw the tally on the current axes as a bar chart.
guide_count_axes = plt.gca()
guide_count_axes.bar(count_of_counts.index, count_of_counts.values)
guide_count_axes.set_xlabel('Number of functional guides')
guide_count_axes.set_ylabel('Number of genes')