In [None]:
import scanpy as sc 
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import colors, cm
import numpy as np
import scipy
import os

from utils import plot_histogram
from var import *

In [None]:
#fix numpy's global random state (seed 0) so that stochastic steps give the same result on every run
np.random.seed(seed=0)

In [None]:
#Create the output folder for the figures of this notebook
#(results_folder is not defined in this notebook; presumably it comes from `from var import *`)
figures_folder = os.path.join(results_folder, 'figures/cell_state/')
#exist_ok=True keeps the cell idempotent on re-runs and avoids the check-then-create race
os.makedirs(figures_folder, exist_ok=True)

1. Assess target gene expression with respect to guide UMI counts

In [None]:
#read the log-normalised AnnData object from disk
#(per the original note: already filtered to cells with more than 5000 UMI counts)
preprocessed_h5ad_path = '/lustre/scratch123/hgi/teams/parts/kl11/cell2state_tf_activation/results/20230116_adata_pre_processed.h5ad'
adata_preprocessed = sc.read_h5ad(preprocessed_h5ad_path)

In [None]:
#load the per-cell obs table (including the 'perturbation_state' column) from the crispra csv file
adata_obs_df = pd.read_csv('/lustre/scratch123/hgi/teams/parts/kl11/cell2state_tf_activation/data/crispra_data_exploration/20230202_perturbation_state.csv', index_col=0)

#every cell of adata_preprocessed must be present in the csv. reindex() silently inserts
#all-NaN rows for labels it cannot find, so comparing the indices only AFTER reindexing can
#never fail, and NaN != 'not_perturbed' would keep such cells in the perturbed subset below.
missing_cells = adata_preprocessed.obs.index.difference(adata_obs_df.index)
assert missing_cells.empty, f'{len(missing_cells)} cells of adata_preprocessed are missing from the perturbation state csv'

#order the rows of adata_obs_df to match adata_preprocessed
adata_obs_df = adata_obs_df.reindex(adata_preprocessed.obs.index)
#sanity check that the row order now matches adata_preprocessed
assert adata_obs_df.index.equals(adata_preprocessed.obs.index)
#overwrite adata_preprocessed.obs with adata_obs_df
adata_preprocessed.obs = adata_obs_df

#drop the cells labelled 'not_perturbed'. Subsetting an AnnData returns a view; .copy()
#materialises it so that layers can be added later without an implicit view-to-copy conversion
adata_perturbed = adata_preprocessed[adata_preprocessed.obs['perturbation_state'] != 'not_perturbed'].copy()

In [None]:
#unique var names with a trailing '_1' / '_2' suffix removed.
#A regex anchored at the end of the name is used because str.rstrip('_1|_2') treats its
#argument as a SET of characters ('_', '1', '|', '2') and strips all of them repeatedly,
#e.g. 'MYOD1_1' -> 'MYOD', and it also eats trailing 1s and 2s of names without a suffix.
guides = adata_perturbed.var_names.str.replace(r'_[12]$', '', regex=True).unique()

In [None]:
#guide features are the var_names that do not carry an 'ENSG' prefix
is_ensg_feature = adata_preprocessed.var_names.str.startswith('ENSG')
targeted_tf_guides = adata_preprocessed.var_names[~is_ensg_feature].tolist()
#the targeted TF is the part of each guide name before the first '_'
targeted_tf = [guide_name.partition('_')[0] for guide_name in targeted_tf_guides]

In [None]:

# Create a list of expected guides including the control guides 
expected_guides = [
    'AIRE',
    'ASCL1',
    'DLX1',
    'IRF3',
    'LHX6',
    'MAFB',
    'MYOD1',
    'NEUROG2',
    'OLIG2',
    'PROX1',
    'RORA',
    'RORB',
    'SATB2',
    'sgRNA1_SCP',
    'sgRNA6_SCP',
    'ONE_INTERGENIC_SITE_1194'
]

# Report the expected guides that are absent from the dataset.
# The comparison is made against the full guide names (targeted_tf_guides): targeted_tf only
# holds the text before the first '_', so an expected name that itself contains '_'
# (e.g. 'sgRNA1_SCP') could never be found in it and was always reported as missing.
# For names without '_' this test is equivalent to `name in targeted_tf`.
missing_guides = [
    name for name in expected_guides
    if not any(guide == name or guide.startswith(name + '_') for guide in targeted_tf_guides)
]
missing_guides

In [None]:
#store log(1 + x) of the 'counts' layer of adata_perturbed as a new 'log1p' layer
perturbed_counts = adata_perturbed.layers['counts']
adata_perturbed.layers['log1p'] = np.log1p(perturbed_counts)

In [None]:
#2D histogram per expected guide over the perturbed cells:
#  x axis = sum of the log1p counts of all var columns whose name contains the guide name
#  y axis = log1p count of the gene whose var['SYMBOL'] equals the guide name

#strip a trailing '_1' / '_2' suffix with an anchored regex; str.rstrip('_1|_2') treats its
#argument as a set of characters and would also strip trailing '1', '2', '_' and '|' that
#belong to the name itself (e.g. 'MYOD1_1' -> 'MYOD')
guides = adata_perturbed.var_names.str.replace(r'_[12]$', '', regex=True).unique()

for short_name in expected_guides:
    #literal (non-regex) substring match of the guide name
    tmp_guides = guides[guides.str.contains(short_name, regex=False)]

    #only names that match exactly two (suffix-stripped) guide names are plotted
    if len(tmp_guides) != 2:
        print(f'{short_name}: skipped, {len(tmp_guides)} matching guide names (expected 2)')
        continue

    #the target gene is looked up by symbol. With no match (presumably the case for the
    #control guides - TODO confirm) or with several matches, the flattened target vector
    #would not have one value per cell and hist2d would fail, so require exactly one match
    target_mask = (adata_perturbed.var['SYMBOL'] == short_name).to_numpy()
    n_targets = int(target_mask.sum())
    if n_targets != 1:
        print(f'{short_name}: skipped, {n_targets} genes with this SYMBOL (expected 1)')
        continue

    #per-cell sum of the log1p guide counts (note: sum of log1p values, not log1p of the sum)
    tmp_adata = adata_perturbed[:, adata_perturbed.var_names.str.contains(short_name, regex=False)]
    #.sum(axis=1) of a sparse matrix returns a 2D matrix; flatten it to one value per cell
    tmp_sum_perturbation = np.asarray(tmp_adata.layers['log1p'].sum(axis=1)).reshape(-1)

    #log1p expression of the target gene, one value per cell
    tmp_target = adata_perturbed[:, target_mask].layers['log1p']
    #the layer may be sparse or dense; densify only when needed
    if hasattr(tmp_target, 'toarray'):
        tmp_target = tmp_target.toarray()
    tmp_target = np.asarray(tmp_target).reshape(-1)

    #density-normalised 2D histogram with a logarithmic colour scale
    fig, ax = plt.subplots()
    ax.hist2d(tmp_sum_perturbation, tmp_target, bins=100, density=True,
              norm=colors.LogNorm())

    #add title, x and y label
    ax.set_title(short_name)
    ax.set_xlabel(f'{short_name} guide pair count (log1p)')
    ax.set_ylabel(f'{short_name} target count (log1p)')

    #render the figure now and release it, so figures do not accumulate over the loop
    plt.show()
    plt.close(fig)



In [None]:
#for each expected guide: UMAP of the cells with that perturbation state plus the
#'non_activating' cells, coloured by target gene expression (left) and by
#perturbation state (right)
for short_name in expected_guides:
    #map short_name to its var_names entry using .var['SYMBOL']
    symbol_mask = (adata_perturbed.var['SYMBOL'] == short_name).to_numpy()
    tmp_ENSG = adata_perturbed.var_names[symbol_mask]

    #skip names that do not map to exactly one gene
    if len(tmp_ENSG) != 1:
        continue

    #create subplots
    fig, axs = plt.subplots(1, 2, figsize=(10, 5))

    #subset adata_perturbed to this perturbation state and the 'non_activating' cells
    tmp_adata = adata_perturbed[adata_perturbed.obs['perturbation_state'].isin([short_name, 'non_activating'])]

    #plot UMAP for target gene expression and perturbation state;
    #a single gene id is passed as a plain string because only one axis is provided per panel
    sc.pl.umap(tmp_adata, color=tmp_ENSG[0], show=False, ax=axs[0], title=short_name)
    sc.pl.umap(tmp_adata, color='perturbation_state', show=False, ax=axs[1])

    #render the figure now and release it, so figures do not accumulate over the loop
    plt.show()
    plt.close(fig)
            
