In [None]:
import scanpy as sc 
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import colors, cm
import numpy as np
import scipy
import os
import re

from utils import plot_histogram
from var import *
from guide_assignment import *


In [None]:
#set seed to 0 for reproducibility of results
np.random.seed(0)

In [None]:
#define which dataset to use
dataset = '49_04_l4_3_s_gex'

In [None]:
#Create figures folder
figures_folder = os.path.join(results_folder, f'figures/{dataset}/guide_assignment/')
if not os.path.exists(figures_folder):
    os.makedirs(figures_folder)

## 1. Assessing extent of guide overassignment issue

In [None]:
#load raw anndata object
adata = sc.read_h5ad('/lustre/scratch123/hgi/teams/parts/kl11/cell2state_tf_activation/results/20230301_adata_49_04_l4_3_s_gex_pre_processed_filtered_above_5000.h5ad')

### 1.1 Load barcode data and experimental design for automation

In [None]:
# load barcode count data 
barcode_count = pd.read_csv('/lustre/scratch123/hgi/teams/parts/kl11/cell2state_tf_activation/data/barcode_output/Barcode-3S_barcode_counts.csv', index_col=0)

In [None]:
# function to add guide barcodes to adata obsm
def add_guide_barcodes(adata, barcode_count, keep_cells_w_no_guide_barcodes=True):
    """
    Add guide barcodes to adata obsm

    Parameters
    ----------
    adata : AnnData
        AnnData object containing the raw data
    barcode_count : pd.DataFrame
        DataFrame containing the barcode counts for each guide barcode
    keep_cells_w_no_guide_barcodes : bool, optional
        If True, keep cells with no guide barcodes assigned, by default True
    Returns
    -------
    adata : AnnData
        AnnData object containing the raw data with guide barcodes added to adata.obsm
        
    """

    if keep_cells_w_no_guide_barcodes:
        # reindex barcode count index based ada obs column 'barcodes'
        adata.obs.index = adata.obs['barcodes']
        # add non intersecting guides to barcode couns
        barcode_count = barcode_count.reindex(adata.obs.index.unique(), fill_value=0)
        # add barcode count to adata obsm
        adata.obsm['barcode_count'] = barcode_count
        
    else:
        # adata obs column 'barcodes' use as .obs index
        adata.obs.index = adata.obs['barcodes']
        # filter adata obs based on barcode count data -> this drops cells without guide barcodes being assigned
        adata = adata[adata.obs.index.isin(barcode_count.index)]
        # filter barcode count data based on adata obs -> this drops barcodes without cells being assigned
        barcode_count = barcode_count.loc[adata.obs.index.unique()]
        # assign barcode count to adata obsm
        adata.obsm['barcode_count'] = barcode_count
    
    #replace all nan values with 0
    adata.obsm['barcode_count'] = adata.obsm['barcode_count'].fillna(0)

    return adata

In [None]:
# assign guide barcodes to adata obsm
adata = add_guide_barcodes(adata, barcode_count, keep_cells_w_no_guide_barcodes=True)

In [None]:
#assign guide barcode availability to adata obs
adata.obs['guide_barcodes'] = adata.obsm['barcode_count'].sum(axis=1) > 0
#change true/false to available/not available
adata.obs['guide_barcodes'] = adata.obs['guide_barcodes'].replace({True:'available', False:'not available'})

In [None]:
#plot umap of guide barcode availability
sc.pl.umap(adata, color='guide_barcodes')

In [None]:
# load experimental design data 
experimental_design = pd.read_csv('/lustre/scratch123/hgi/teams/parts/kl11/cell2state_tf_activation/data/49_04_exp_design_20221209 - experimental_design.csv')

In [None]:
#function pivot experimental design data from wide to long format
def pivot_experimental_design(df):
    '''
    Pivot experimental design data from wide to long format with data wrangling to match barcode count data

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame containing the experimental design data
    Returns
    -------
    df_long : pd.DataFrame
        DataFrame containing the experimental design data in long format
        
    '''
    # create index by merging all columns containing pertubation in column name
    df['perturbation'] = df.filter(regex='perturbation').apply(lambda x: '_'.join(x.dropna().astype(str)), axis=1)
    # drop double control condition by dropping duplicates
    df = df.drop_duplicates(subset='perturbation', keep='first')
    #convert all columns header to string
    df.columns = df.columns.astype(str)

    # pivot experimental design data from wide to long format drop variable column
    guide_col=df.columns[df.columns.str.startswith('guide')]
    df_long = pd.melt(df, id_vars=['perturbation','number_of_pert'], value_vars=guide_col, var_name='guide_no', value_name='guide')
    df_long = df_long.drop(columns=['guide_no'])
    # sort rows by perturbation
    df_long = df_long.sort_values(by=['perturbation'])
    # drop all nan columns in guide column
    df_long = df_long.dropna(subset=['guide'])
    #in grna column replace regex r'+\d+' with ''
    df_long['guide'] = df_long['guide'].str.replace(r'\+.+', '', regex=True)
    #create a columns with target gene name by splitting guide column on '_'
    df_long['guide_target_gene'] = df_long['guide'].str.split('_gRNA').str[0]
    #for CONTROL condition set target gene to 'CONTROL'
    df_long.loc[df_long['perturbation'] == 'CONTROL', 'guide_target_gene'] = 'CONTROL'

    return df_long

In [None]:
experimental_design_long = pivot_experimental_design(experimental_design)
#filter experimental design long for guides which are in adata obsm barcode count
experimental_design_long = experimental_design_long[experimental_design_long['guide'].isin(adata.obsm['barcode_count'].columns)]

### 1.2 guide UMI count distribution

In [None]:
#assign raw counts to adata
adata.X = adata.layers['counts']

In [None]:
dense_count_2d_arr = sparse_to_2d_arr(adata)
guide_frequency = count_guide_frequency(dense_count_2d_arr)

### 1.2.1 guide UMI count histograms

In [None]:
#plot total guide UMI counts from adata.obsm['barcode_count']
# order columns alphabetically
adata.obsm['barcode_count'] = adata.obsm['barcode_count'].reindex(sorted(adata.obsm['barcode_count'].columns), axis=1)
adata.obsm['barcode_count'].sum(axis=0).plot(kind='bar', figsize=(20,10))
# add x and y labels
plt.xlabel('Guide')
plt.ylabel('total UMI counts')
# save figure
plt.savefig(os.path.join(figures_folder, f'{today}_guide_UMI_counts.pdf'), dpi=300, bbox_inches='tight')
print(adata.obsm['barcode_count'].sum(axis=0))

In [None]:
#plot average guide UMI counts/cell from adata.obsm['barcode_count']
# order columns alphabetically
adata.obsm['barcode_count'] = adata.obsm['barcode_count'].reindex(sorted(adata.obsm['barcode_count'].columns), axis=1)
adata.obsm['barcode_count'].mean(axis=0).plot(kind='bar', figsize=(20,10))

In [None]:
#plot the amount of cells with a guide barcode for each guide
# order columns alphabetically
#count non zero values per column
adata.obsm['barcode_count'] = adata.obsm['barcode_count'].reindex(sorted(adata.obsm['barcode_count'].columns), axis=1)
adata.obsm['barcode_count'].astype(bool).sum(axis=0).plot(kind='bar', figsize=(20,10))
# add x and y labels
plt.xlabel('Guide')
plt.ylabel('Number of cells with guide barcode counts')

In [None]:
#count number of unique guide per cell by counting number of all non-zero values in each row
guide_per_cell = np.count_nonzero(adata.obsm['barcode_count'], axis=1)
#plot histogram of guide per cell
plt.hist(guide_per_cell, bins=100)
# add x and y labels
plt.xlabel('Number of guides per cell')
plt.ylabel('Number of cells')
# save figure
plt.savefig(os.path.join(figures_folder, f'{today}_unique_guides_per_cell.pdf'), dpi=300, bbox_inches='tight')

### 1.3 Guide assignment

#### 1.3.1 Guide fraction for guide pool

In [None]:
def compute_guide_fraction(array):
    """
    Parameters
    ----------
    array : 2d array
    Returns
    -------
    guide_frac_per_cell : 2d array
        DESCRIPTION.
    Compute fraction per column to determine guide fraction of each cell
    """
    #compute fraction per column
    guide_frac_per_cell=array.T/array.sum(axis=1)
    guide_frac_per_cell=guide_frac_per_cell.T
    #sort fraction in ascending order for each row
    guide_frac_per_cell.sort(axis=0)
    #remove all zeros in a column 
    guide_frac_per_cell=guide_frac_per_cell[~np.all(guide_frac_per_cell == 0, axis = 1),:]
    return guide_frac_per_cell

In [None]:
guide_frac_per_cell=compute_guide_fraction(adata.obsm['barcode_count'].values)

In [None]:
barcode_count=adata.obsm['barcode_count'].values
guide_frac_per_cell=barcode_count.T/barcode_count.sum(axis=1)
guide_frac_per_cell=guide_frac_per_cell.T

In [None]:
#write a function to extract all unique targets appearing in multiple perturbations
def extract_multiple_pert_targets(experimental_design_long):
    """
    This function extracts all unique targets appearing in multiple perturbations from the experimental design long format dataframe.
    Parameters
    ----------
    experimental_design_long : pd.DataFrame
        Long format dataframe with experimental design.

    Returns
    -------
    multiple_pert_targets : list
        List of all unique targets appearing in multiple perturbations.

    """
    # subset experimental design long for multiple perturbations
    multiple_pert = experimental_design_long[experimental_design_long['number_of_pert'] > 1]
    #extract all unique targets appearing in multiple perturbations
    multiple_pert_targets = multiple_pert['perturbation'].unique()
    multiple_pert_targets = [string.split('_') for string in  multiple_pert_targets]
    #flatten list
    multiple_pert_targets = [item for sublist in multiple_pert_targets for item in sublist]
    #unique elements
    multiple_pert_targets = list(set(multiple_pert_targets))
    
    return multiple_pert_targets


In [None]:
multiple_pert_targets = extract_multiple_pert_targets(experimental_design_long)

In [38]:
#write function for 1d hist plot
def plot_1d_hist(filtered_pert_frac, filtered_pert_count, pert_string, path_to_save=None):
        """
        This function plots a 1d histogram of guide fraction per cell and guide counts per cell for a given perturbation.
        Parameters
        ----------
        filtered_pert_frac : pd.DataFrame
                Dataframe with guide fraction per cell for each perturbation.
        filtered_pert_count : pd.DataFrame
                Dataframe with guide counts per cell for each perturbation.
        pert_string : list
                List of strings containing the perturbation name.
        path_to_save : str, optional
                Path to save figure. The default is None.
        Returns
        -------
        None.
        """
        #assert that pert_string has only one element
        assert len(pert_string) == 1, 'pert_string should only contain one element'
        #initialise figure and axes with two subplots
        fig, ax = plt.subplots(1, 2, figsize=(20,10))
        # set title of whole figure
        fig.suptitle(f'{pert}')

        # plot histogram of guide fraction per cell using pert_frac
        ax[0].hist(filtered_pert_frac[pert_string[0]], bins=20)
        # add x and y labels
        ax[0].set_xlabel(f'{pert_string[0]} guide fraction per cell')
        ax[0].set_ylabel(f'Number of cells')

        # plot histogram of guide counts per cell using pert_count
        # add x and y labels
        ax[1].hist(filtered_pert_count[pert_string[0]], bins=20)
        ax[1].set_xlabel(f'{pert_string[0]} guide counts per cell')
        ax[1].set_ylabel('Number of cells')

        if path_to_save:
                # save figure
                # plt.savefig(os.path.join(figures_folder, f'{today}_{pert}_guide_fraction.pdf'), dpi=300, bbox_inches='tight')
                plt.savefig(path_to_save, dpi=300, bbox_inches='tight')


In [47]:
#write function for 2d hist plot
def plot_2d_hist(filtered_pert_frac, filtered_pert_count, pert_string, path_to_save=None):
        """
        This function plots a 2d histogram of guide fraction per cell and guide counts per cell for a given perturbation.
        Parameters
        ----------
        filtered_pert_frac : pd.DataFrame
                Dataframe with guide fraction per cell for each perturbation.
        filtered_pert_count : pd.DataFrame
                Dataframe with guide counts per cell for each perturbation.
        pert_string : list
                List of strings containing the perturbation name.
        path_to_save : str, optional
                Path to save figure. The default is None.
        Returns
        -------
        None.
        """
        #assert that pert_string has only one element
        assert len(pert_string) == 2, 'pert_string should contain two elements'
        #initialise figure and axes with two subplots
        fig, ax = plt.subplots(1, 2, figsize=(20,10))
        # set title of whole figure
        fig.suptitle(f'{pert}')

        # plot histogram of guide fraction per cell using pert_frac
        ax[0].hist2d(filtered_pert_frac[pert_string[0]], filtered_pert_frac[pert_string[1]], bins=50, norm=colors.LogNorm(), cmap='viridis')
        # add x and y labels
        ax[0].set_xlabel(f'{pert_string[0]} guide fraction per cell')
        ax[0].set_ylabel(f'{pert_string[1]} guide counts per cell')

        # plot histogram of guide counts per cell using pert_count
        # add x and y labels
        ax[1].hist2d(filtered_pert_count[pert_string[0]], filtered_pert_count[pert_string[1]], bins=50, cmap='viridis', norm=colors.LogNorm())
        ax[1].set_xlabel(f'{pert_string[0]} guide counts per cell')
        ax[1].set_ylabel(f'{pert_string[1]} guide fraction per cell')

        if path_to_save:
                # save figure
                # plt.savefig(os.path.join(figures_folder, f'{today}_{pert}_guide_fraction.pdf'), dpi=300, bbox_inches='tight')
                plt.savefig(path_to_save, dpi=300, bbox_inches='tight')

In [50]:
#plot 3d scatter plot
def plot_3d_scatter(filtered_pert_frac, filtered_pert_count, pert_string, path_to_save=None):
        """
        This function plots a 3d scatter plot of guide fraction per cell and guide counts per cell for a given perturbation.
        Parameters
        ----------
        filtered_pert_frac : pd.DataFrame
                Dataframe with guide fraction per cell for each perturbation.
        filtered_pert_count : pd.DataFrame
                Dataframe with guide counts per cell for each perturbation.
        pert_string : list
                List of strings containing the perturbation name.
        path_to_save : str, optional
                Path to save figure. The default is None.
        Returns
        -------
        None.
        """
        #assert that pert_string has only one element
        assert len(pert_string) == 3, 'pert_string should contain three elements'

        #initialise figure and axes with two subplots
        fig, ax = plt.subplots(1, 2, figsize=(20,10))
        # turn off axis
        ax[0].axis('off')
        ax[1].axis('off')
        # set title of whole figure
        fig.suptitle(f'{pert}')
        # set 3d figure
        ax[0] = fig.add_subplot(1, 2, 1, projection='3d')
        ax[1] = fig.add_subplot(1, 2, 2, projection='3d')
        # tight layout
        fig.tight_layout()
        # plot 3d scatter plot
        ax[0].scatter(filtered_pert_frac[pert_string[0]], filtered_pert_frac[pert_string[1]], filtered_pert_frac[pert_string[2]], marker='o', alpha=0.5)
        # add x, y and z labels
        ax[0].set_xlabel(f'{pert_string[0]} guide fraction per cell')
        ax[0].set_ylabel(f'{pert_string[1]} guide fraction per cell')
        ax[0].set_zlabel(f'{pert_string[2]} guide fraction per cell')

        ax[1].scatter(filtered_pert_count[pert_string[0]], filtered_pert_count[pert_string[1]], filtered_pert_count[pert_string[2]], marker='o', alpha=0.5)
        # add x, y and z labels
        ax[1].set_xlabel(f'{pert_string[0]} guide count per cell')
        ax[1].set_ylabel(f'{pert_string[1]} guide count per cell')
        ax[1].set_zlabel(f'{pert_string[2]} guide count per cell')


        if path_to_save:
                # save figure
                # plt.savefig(os.path.join(figures_folder, f'{today}_{pert}_guide_fraction.pdf'), dpi=300, bbox_inches='tight')
                plt.savefig(path_to_save, dpi=300, bbox_inches='tight')

In [62]:
#assign new columnt to assign perturbation state
adata.obs['pert_state'] = 'not_perturbed'
for pert in experimental_design_long['perturbation'].unique():
    #create figures folder if it does not exist
    tmp_figures_folder = os.path.join(figures_folder, 'perturbation_frac_counts')
    if not os.path.exists(tmp_figures_folder):
        os.makedirs(tmp_figures_folder)

    #filter for guides of perturbation
    guides = experimental_design_long[experimental_design_long['perturbation'] == pert]['guide'].unique()   

    #split perturbation string
    pert_string = pert.split('_')
    #sum for each perturbation the number of counts per cell with dict comprehension by summing row-wise adatas obsm['barcode_count'] 
    pert_count = {pert: adata.obsm['barcode_count'][experimental_design_long[experimental_design_long['guide_target_gene']==pert]['guide'].unique()].sum(axis=1) for pert in pert_string}

    #calculate fraction of perturbation counts per cell
    pert_frac = {k: v / adata.obsm['barcode_count'].sum(axis=1) for k, v in pert_count.items()}
    #fill all nan values with 0 in pert_frac
    pert_frac = {k: v.fillna(0) for k, v in pert_frac.items()}
    #create path to save figure
    path_to_save = os.path.join(tmp_figures_folder, f'{today}_{pert}.pdf')
    
    # check if string contains any of the multiple perturbations
    if any(x in pert for x in multiple_pert_targets):

        if len(pert_string)==1:
            print(pert)
            # extract all conditions where pert is one of the multiple perturbations
            conditions = experimental_design_long[experimental_design_long['perturbation'].str.contains(pert)]['perturbation'].unique()
            # split conditions by '_'
            conditions = [string.split('_') for string in conditions]
            # flatten list
            conditions = [item for sublist in conditions for item in sublist]
            # remove pert from conditions
            conditions_to_exclude = [x for x in conditions if x != pert]
            # extract list of guides for conditions to exclude based on guide_target_gene column 
            guides_to_exclude = experimental_design_long[experimental_design_long['guide_target_gene'].isin(conditions_to_exclude)]['guide'].unique()
            
            #create boolean conditions and apply to pert_count and pert_frac
            #boolean mask to filter for cells with >0 guide count per cell
            tmp_count_mask = adata.obsm['barcode_count'][guides].sum(axis=1) > 0
            #boolean mask for guides to exclude all counts need to be 0 
            tmp_count_mask = tmp_count_mask & (adata.obsm['barcode_count'][guides_to_exclude].sum(axis=1) == 0)
            #apply mask to pert_count and pert_frac
            filtered_pert_count = {k: v[tmp_count_mask] for k, v in pert_count.items()}
            filtered_pert_frac = {k: v[tmp_count_mask] for k, v in pert_frac.items()}

            # #plot 1d hist for one condition
            # plot_1d_hist(filtered_pert_frac, filtered_pert_count, pert_string, path_to_save)

        #plot hist2d for two conditions
        elif len(pert_string)==2:
            # boolean mask to filter for cells with each pert count per cell >0  in pert_count
            tmp_count_mask = (pert_count[pert_string[0]] > 0) & (pert_count[pert_string[1]] > 0)
            # apply mask to pert_count and pert_frac
            filtered_pert_count = {k: v[tmp_count_mask] for k, v in pert_count.items()}
            filtered_pert_frac = {k: v[tmp_count_mask] for k, v in pert_frac.items()}

            # #plot 2d hist for two conditions
            # plot_2d_hist(filtered_pert_frac, filtered_pert_count, pert_string, path_to_save)

        elif len(pert_string)==3:
            #plot 3d scatter plot
            # boolean mask to filter for cells with each pert count per cell >0  in pert_count
            tmp_count_mask = (pert_count[pert_string[0]] > 0) & (pert_count[pert_string[1]] > 0) & (pert_count[pert_string[2]] > 0)
            # apply mask to pert_count and pert_frac
            filtered_pert_count = {k: v[tmp_count_mask] for k, v in pert_count.items()}
            filtered_pert_frac = {k: v[tmp_count_mask] for k, v in pert_frac.items()}

            # #plot 3d scatter plot
            # plot_3d_scatter(filtered_pert_frac, filtered_pert_count, pert_string, path_to_save)

        
    else:
        # create boolean mask to filter for cells with >1 guide count per cell
        tmp_count_mask = adata.obsm['barcode_count'][guides].sum(axis=1) > 1
        # apply mask to pert_count and pert_frac
        filtered_pert_count = {k: v[tmp_count_mask] for k, v in pert_count.items()}
        filtered_pert_frac = {k: v[tmp_count_mask] for k, v in pert_frac.items()}

        # #plot hist1d
        # plot_1d_hist(filtered_pert_count, filtered_pert_frac, pert_string, path_to_save)

    #assign perturbation state based on summed frequency of perturbation guides per cell
    #extract all list from dict
    tmp_list = list(filtered_pert_frac.values())
    #sum individual lists
    tmp_sum = sum(tmp_list)
    #filter for cells with sum > 0.5
    tmp_sum = tmp_sum[tmp_sum > 0.5]

    #assign perturbation state to cells with sum > 0.5
    adata.obs.loc[tmp_sum.index, 'pert_state'] = pert


AIRE
ASCL1
CONTROL
DBX2
DLX1
DLX5
FOXN1
GATA2
IRF3
LHX6
MAFB
MYOD1
MYOD1
MYOD1_NEUROD2
NEUROD2
NEUROD2
NEUROG2
NEUROG2
NEUROG2_NEUROD2
NEUROG2_NEUROD2_RORB
NEUROG2_RORB
OLIG1
OLIG2
PROX1
RORA
RORA
RORA_NEUROD2
RORB
RORB
RORB_RORA
RORB_RORA_NEUROD2
SATB2


In [None]:
#save adata
adata.write(os.path.join(results_folder, f'{today}_adata_49_04_l4_3_s_gex_guide_assigned.h5ad'))