In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import scanpy as sc
import scipy.io as sio
import anndata as ad
import os as os
import sys as sys
import seaborn as sns
from load_10X_matrices import load_10X_matrices
from scipy import stats
import argparse


In [None]:
import rpy2
import rpy2.robjects as robjects
from rpy2.robjects.packages import importr
from rpy2.robjects import pandas2ri

In [None]:
import sys as sys
# sys.path.append('/home/xinghua/projects/PanCancer_scRNA_analysis/utils/')
# from scRNA_utils import *

In [None]:
def clustering_adata(adata):
    '''
    Run PCA, build the neighbour graph, compute Leiden clusters and plot a UMAP.

    Parameters:
        adata: AnnData object; the scanpy calls below are applied to this object directly

    Returns:
        adata: the same AnnData object, with the cluster label of every cell in adata.obs['leiden'].
            None is returned (after printing a message) when the input is not an AnnData
            object, or when it has fewer than 1000 cells or fewer than 1000 genes.
    '''

    # reject anything that is not an AnnData object
    if not isinstance(adata, ad.AnnData):
        print ("Input adata is not an AnnData object")
        return None

    # require at least 1000 cells and at least 1000 genes (cells are checked first)
    size_checks = (
        (adata.shape[0], "Input adata has less than 1000 cells"),
        (adata.shape[1], "Input adata has less than 1000 genes"),
    )
    for size, message in size_checks:
        if size < 1000:
            print (message)
            return None

    # NOTE: no gene selection or normalisation happens here; the caller is
    # expected to pass data that is ready for PCA.

    # 50-component PCA -> 50-neighbour graph on those 50 PCs -> Leiden at resolution 0.5
    sc.tl.pca(adata, n_comps=50, svd_solver='arpack')
    sc.pp.neighbors(adata, n_pcs=50, n_neighbors=50)
    sc.tl.leiden(adata, resolution=0.5)

    # embed and show the clusters on a UMAP
    sc.tl.umap(adata)
    sc.pl.umap(adata, title='leiden', legend_loc='on data', color=['leiden'])

    return adata

In [None]:
def labelClusterWithCellType(adata, cell_type_markers, cluster_column='leiden'):
    '''
    Label every cluster with the cell type whose marker genes are most widely expressed in it.

    For each cluster and each cell type the score is the fraction of
    (cell, marker) pairs with raw expression > 0, i.e. the average fraction of
    that cell type's markers detected per cell of the cluster.  The cell type
    with the highest score is written to adata.obs['cell_type'] for all cells of
    the cluster (ties go to the first cell type in dictionary order).

    Parameters:
        adata: AnnData object; expression is read from adata.raw.X (sparse or dense)
        cell_type_markers: a dictionary where the key is the cell type and the value is a list of markers for that cell type.
            Markers missing from adata.raw.var_names are ignored; a cell type with no marker present scores 0.
        cluster_column: the column in adata.obs that contains the cluster labels

    Returns:
        adata: the same AnnData object (modified in place) with a column in adata.obs called 'cell_type' that contains the cell type label for each cell

    Raises:
        ValueError: if cell_type_markers is empty
    '''
    if not cell_type_markers:
        raise ValueError("cell_type_markers must contain at least one cell type")

    raw_X = adata.raw.X
    raw_var_names = adata.raw.var_names

    # The marker -> column mapping does not depend on the cluster, so build it once.
    marker_masks = {
        cell_type: np.asarray(raw_var_names.isin(marker_genes))
        for cell_type, marker_genes in cell_type_markers.items()
    }

    cluster_labels = adata.obs[cluster_column]

    # iterate through all clusters
    for i in cluster_labels.unique():
        # boolean vector: True means the cell is in cluster i
        cell_in_cls_i = (cluster_labels == i).to_numpy()
        n_cells = int(cell_in_cls_i.sum())
        if n_cells == 0:
            # e.g. a missing (NaN) cluster label never compares equal to itself
            continue

        # Slice the cluster's rows once; boolean indexing works for both
        # scipy sparse matrices and dense arrays, so no .toarray() is needed.
        X_cls = raw_X[cell_in_cls_i, :]

        # score of every cell type for cluster i
        cell_type_cluster_overlapp_pct = dict()
        for cell_type, marker_mask in marker_masks.items():
            n_markers = int(marker_mask.sum())
            if n_markers == 0:
                cell_type_cluster_overlapp_pct[cell_type] = 0.0
                continue
            # number of (cell, marker) pairs with expression > 0 in this cluster
            n_expressed = (X_cls[:, marker_mask] > 0).sum()
            cell_type_cluster_overlapp_pct[cell_type] = float(n_expressed) / (n_markers * n_cells)

        # the cell type with the highest average marker presence wins
        max_type = max(cell_type_cluster_overlapp_pct, key=cell_type_cluster_overlapp_pct.get)
        print('Cluster ' + str(i) + ' is most likely ' + max_type + ' with ' + str(cell_type_cluster_overlapp_pct[max_type]) + ' overlap')
        adata.obs.loc[cell_in_cls_i, 'cell_type'] = max_type

    return adata
        


In [None]:
def scRNA2PseudoBulkAnnData(adata, sample_id_col = None): 
    '''        
        Convert a single-cell AnnData object into a pseudo-bulk AnnData object
        with one row per sample.

        For every sample, adata.X is summed over the sample's cells and the
        resulting vector is rescaled so that it sums to one million
        (reported here as "TPM"; no gene-length correction is applied).

        Parameters:
            adata: AnnData object; adata.raw must be set (the check is kept from the
                original implementation, the values themselves are read from adata.X)
            sample_id_col: the column in adata.obs that contains the sample id

        Returns:
            AnnData object (samples x genes, float32) with:
              * X: per-sample values scaled to sum to 1e6; a sample whose total is 0 gets an all-zero row
              * obs: one row per sample, indexed by the sample id (as string), holding the
                annotation of the first cell of that sample; QC columns added by
                scanpy (total_counts, pct_counts_mt, ...) are dropped
              * var: a copy of adata.var
              * uns["pseudoBulk"] = "tpm"
              * raw: set to the pseudo-bulk object itself
            None is returned (after printing a message) when an input check fails.
    '''
    # check if input adata is AnnData object
    if not isinstance(adata, ad.AnnData):
        print ("Input adata is not an AnnData object")
        return None
    if not sample_id_col:
        print ("sample id column not provided")
        return None

    # check if adata have sample id col
    if sample_id_col not in adata.obs.columns:
        print ("sample id", sample_id_col, "column not available in adata.obs")
        return None

    # check if adata have raw data
    if not adata.raw:
        print ("adata.raw is not available")
        return None

    # per-cell QC columns (compared case-insensitively) are meaningless per sample
    qc_cols = {'ncount_rna', 'nfeature_rna', 'total_counts', 'total_counts_mt',
               'pct_counts_mt', 'n_genes_by_counts', 'log1p_n_genes_by_counts'}
    cols_to_keep = [col for col in adata.obs.columns if str(col).lower() not in qc_cols]

    # The first cell of each sample defines the sample order (same order as
    # .unique()) and provides the sample-level annotation.  Cells without a
    # sample id are ignored.
    sample_series = adata.obs[sample_id_col]
    sample_labels = sample_series.to_numpy()
    first_cell_of_sample = (sample_series.notna() & ~sample_series.duplicated()).to_numpy()
    sample_ids = sample_labels[first_cell_of_sample]

    tpm = np.zeros((len(sample_ids), adata.shape[1]), dtype=np.float32)
    for row, sample in enumerate(sample_ids):
        in_sample = sample_labels == sample
        # np.asarray(...).ravel() handles both dense arrays and the 1 x n matrix
        # returned when adata.X is sparse
        counts = np.asarray(adata.X[in_sample, :].sum(axis=0), dtype=np.float64).ravel()
        total = counts.sum()
        if total > 0:
            tpm[row, :] = counts / total * 1e6

    # NOTE: columns that vary between cells of one sample (e.g. cluster labels)
    # keep only the value of the sample's first cell.
    df_obs = adata.obs.loc[first_cell_of_sample, cols_to_keep].copy()
    df_obs.index = pd.Index([str(sample) for sample in sample_ids])

    # Create an AnnData object for the pseudo-bulk RNA data
    adata_sample_tpm = ad.AnnData(tpm, obs=df_obs, var=adata.var.copy())
    adata_sample_tpm.uns["pseudoBulk"] = "tpm"
    adata_sample_tpm.raw = adata_sample_tpm

    return adata_sample_tpm


In [None]:
def analyze_cell_type(adata, cell_type, markers, adata_name):
    '''
    Extract one cell type, re-cluster it and re-label the sub-clusters with marker genes.

    The cells with adata.obs['cell_type'] == cell_type are copied into a new
    AnnData object whose X is the original adata.raw.X, then highly variable
    genes, PCA, neighbours, Leiden clusters and a UMAP are recomputed and the new
    clusters are labelled with labelClusterWithCellType.  The input adata is
    not modified.

    Parameters:
        adata: AnnData object with adata.raw set and a 'cell_type' column in adata.obs
        cell_type: value of adata.obs['cell_type'] to extract
        markers: dictionary {sub-type name: list of marker genes} used to label the new clusters
        adata_name: name under which the result is also stored in this module's globals
            (kept so that existing cells that read that global keep working)

    Returns:
        adata_type: the re-clustered, re-labelled AnnData object of the selected cells

    Raises:
        ValueError: if adata.raw is missing or no cell has the requested cell type
    '''
    # fail early, before any plotting or expensive computation
    if adata.raw is None:
        raise ValueError("adata.raw is required to restore the unprocessed expression matrix")
    is_cell_type = (adata.obs['cell_type'] == cell_type).to_numpy()
    if not is_cell_type.any():
        raise ValueError(f"No cells with cell_type == {cell_type!r} found in adata.obs['cell_type']")

    # set Scanpy plotting parameters
    sc.set_figure_params()

    # extract the cells; .copy() detaches the subset from the input object
    adata_sub = adata[is_cell_type].copy()

    # restore the X to the original raw.X for re-processing
    adata_type = ad.AnnData(X=adata_sub.raw.X, obs=adata_sub.obs, var=adata_sub.raw.var, obsm=adata_sub.obsm, uns=adata_sub.uns)
    adata_type.raw = adata_type
    print(str(adata_type.shape))

    # re-calculate highly variable genes
    # NOTE(review): these thresholds are usually applied to log-normalised data; confirm raw.X is on that scale
    sc.pp.highly_variable_genes(adata_type, min_mean=0.0125, max_mean=3, min_disp=0.5)
    sc.pl.highly_variable_genes(adata_type)

    # re-cluster the specified cell type
    sc.tl.pca(adata_type, svd_solver='arpack', n_comps=40)
    sc.pp.neighbors(adata_type, n_neighbors=80, n_pcs=40)
    sc.tl.leiden(adata_type, resolution=.5)

    sc.tl.umap(adata_type)
    sc.pl.umap(adata_type, color='leiden', legend_loc='on data')

    # replace the coarse label by sub-type labels derived from the marker dictionary
    adata_type.obs = adata_type.obs.drop(columns="cell_type")
    labelClusterWithCellType(adata_type, markers, cluster_column='leiden')

    # publish the result before the final plots so that a plotting problem cannot lose it
    globals()[adata_name] = adata_type

    # UMAP
    sc.pl.umap(adata_type, color='cell_type')

    # more UMAPs; 'timepoint' is only shown when the column exists
    extra_colors = [col for col in ['timepoint'] if col in adata_type.obs.columns]
    if extra_colors:
        sc.pl.umap(adata_type, color=extra_colors + ['cell_type'])

    return adata_type


In [None]:
def pairwise_ttest(adata, condition_key = None, sample_id_col = None, patient_id_col = None, pval_cutoff = 0.05, log2fc_cutoff = 1):
    '''
    Find genes or gene modules that are differentially expressed between two
    conditions collected from the same subject, e.g. tumor-vs-normal or before
    and after a treatment, with a paired t-test per gene.

    Steps in the process:
        1. If the input is not flagged as pseudo-bulk (uns["pseudoBulk"] missing),
           convert it with scRNA2PseudoBulkAnnData.
        2. For every patient that has samples under BOTH conditions, take the
           (mean of the) samples of each condition from adata.raw.X.
        3. Perform a paired t-test across those patients for each gene.

    Patients with samples under only one condition are excluded from the test.

    Parameters:
        adata: AnnData object; either single-cell data or pseudo-bulk data annotated with uns["pseudoBulk"]
        condition_key: the column in adata.obs that contains the condition; it must have exactly two distinct values
        sample_id_col: the column in adata.obs that contains the sample id
        patient_id_col: the column in adata.obs that contains the patient id
        pval_cutoff: currently unused; kept for interface compatibility
        log2fc_cutoff: currently unused; kept for interface compatibility

    return:
        A dataframe indexed by gene with columns 'pval', 'log2fc',
        'mean_condition1' and 'mean_condition2'.  condition1 / condition2 are the
        first / second value returned by adata.obs[condition_key].unique() and
        log2fc = log2(mean_condition1 / mean_condition2) (NaN when either mean is 0).
        None is returned (after printing a message) when an input check fails.
    '''

    # check inputs
    if not isinstance(adata, ad.AnnData):
        print ("Input adata is not an AnnData object")
        return None 
    if not condition_key:
        print ("Condition key not provided")
        return None
    if not sample_id_col:
        print ("sample id column not provided")
        return None
    if not patient_id_col:
        print ("patient id column not provided")
        return None
    for col in (condition_key, sample_id_col, patient_id_col):
        if col not in adata.obs.columns:
            print ("column", col, "not available in adata.obs")
            return None
    # check if condition to compare is binary
    if len(adata.obs[condition_key].unique()) != 2:
        print ("Condition to compare is not binary")
        return None
    # check if adata have raw data
    if not adata.raw:
        print ("adata.raw is not available")
        return None

    # convert to pseudo-bulk only when the data is NOT already pseudo-bulk
    if 'pseudoBulk' not in adata.uns.keys():
        print ("Input adata is not pseudo-bulk RNA data. Converting to pseudo-bulk RNA data.")
        adata = scRNA2PseudoBulkAnnData(adata, sample_id_col=sample_id_col)
        if adata is None:
            return None

    expr = adata.raw.X
    gene_names = adata.raw.var_names
    patient_labels = adata.obs[patient_id_col].to_numpy()
    condition_labels = adata.obs[condition_key].to_numpy()
    condition1, condition2 = list(adata.obs[condition_key].unique())[:2]

    # Collect one expression vector per condition for every fully paired patient.
    paired_1, paired_2 = [], []
    for patient in adata.obs[patient_id_col].unique():
        is_patient = patient_labels == patient
        rows_1 = np.flatnonzero(is_patient & (condition_labels == condition1))
        rows_2 = np.flatnonzero(is_patient & (condition_labels == condition2))
        if rows_1.size == 0 or rows_2.size == 0:
            # unpaired patient: leaving it in as zeros would bias the test
            continue
        # average if a patient has several samples under one condition
        paired_1.append(np.asarray(expr[rows_1, :].mean(axis=0), dtype=np.float64).ravel())
        paired_2.append(np.asarray(expr[rows_2, :].mean(axis=0), dtype=np.float64).ravel())

    res_df = pd.DataFrame(np.nan, index=gene_names, columns=['pval', 'log2fc', 'mean_condition1', 'mean_condition2'], dtype=float)
    if not paired_1:
        print ("No patient has samples under both conditions")
        return res_df

    x_1 = np.vstack(paired_1)  # patients x genes, condition 1
    x_2 = np.vstack(paired_2)  # patients x genes, condition 2
    mean_condition1 = x_1.mean(axis=0)
    mean_condition2 = x_2.mean(axis=0)

    with np.errstate(divide='ignore', invalid='ignore'):
        # a paired test needs at least two pairs; all genes are tested in one call
        if x_1.shape[0] >= 2:
            res_df['pval'] = np.asarray(stats.ttest_rel(x_1, x_2, axis=0)[1], dtype=float)
        res_df['log2fc'] = np.where((mean_condition1 == 0) | (mean_condition2 == 0),
                                    np.nan, np.log2(mean_condition1 / mean_condition2))
    res_df['mean_condition1'] = mean_condition1
    res_df['mean_condition2'] = mean_condition2

    return res_df

In [None]:
def find_cluster_DEGs_pairwise(adata, cluster_label, condition_key, cluster_column='leiden', conditions=('Pre', 'On'), pval_cutoff=0.05, chunk_size=2000):
    '''
    Find genes that differ between two conditions within one cluster.

    The cells of the requested cluster are split by condition and, for every
    gene, an independent two-sample t-test (scipy.stats.ttest_ind) is run on the
    per-cell values in adata.X.  This is a cell-level, unpaired test; for a
    patient-paired pseudo-bulk analysis use paird_ttest instead.

    Parameters:
        adata: AnnData object (adata.X may be sparse or dense)
        cluster_label: label of the cluster to analyse (compared as string, so 0 and '0' are equivalent)
        condition_key: the column in adata.obs that contains the condition
        cluster_column: the column in adata.obs that contains the cluster labels
        conditions: the two values of condition_key to compare
        pval_cutoff: genes with p-value below this cutoff are reported
        chunk_size: number of genes densified and tested at a time (bounds memory use)

    Returns:
        DEGs: list of gene names with p-value < pval_cutoff, in adata.var_names order.
            An empty list is returned (after printing a message) when either
            condition has fewer than two cells in the cluster.
    '''
    def _to_dense(block):
        # sparse matrices expose .toarray(); dense input is passed through
        return block.toarray() if hasattr(block, 'toarray') else np.asarray(block)

    condition_1, condition_2 = conditions

    # Filter cells based on the cluster
    in_cluster = (adata.obs[cluster_column].astype(str) == str(cluster_label)).to_numpy()
    condition_labels = adata.obs[condition_key].to_numpy()

    # split the cluster's cells into the two conditions
    rows_1 = np.flatnonzero(in_cluster & (condition_labels == condition_1))
    rows_2 = np.flatnonzero(in_cluster & (condition_labels == condition_2))
    if rows_1.size < 2 or rows_2.size < 2:
        print("Cluster", cluster_label, "has fewer than two cells in at least one condition; no test performed")
        return []

    X_1 = adata.X[rows_1, :]
    X_2 = adata.X[rows_2, :]
    gene_names = np.asarray(adata.var_names)

    # create list for storing data
    DEGs = []

    # test the genes chunk by chunk instead of slicing the AnnData once per gene
    for start in range(0, len(gene_names), chunk_size):
        stop = start + chunk_size
        with np.errstate(divide='ignore', invalid='ignore'):
            p_values = np.asarray(
                stats.ttest_ind(_to_dense(X_1[:, start:stop]), _to_dense(X_2[:, start:stop]), axis=0)[1],
                dtype=float,
            )
        # NaN p-values (e.g. constant genes) compare False and are not reported
        significant = np.flatnonzero(p_values < pval_cutoff)
        DEGs.extend(gene_names[start:stop][significant].tolist())

    return DEGs

In [None]:
def paird_ttest(adata, condition_key = None, sample_id_col = None, patient_id_col = None, pval_cutoff = 0.05, log2fc_cutoff = 1):
    '''
    Find genes or gene modules that are differentially expressed between two
    conditions collected from the same subject, e.g. tumor-vs-normal or before
    and after a treatment, with a paired t-test per gene and q-value estimation.

    Steps in the process:
        1. If the input is not flagged as pseudo-bulk (uns["pseudoBulk"] missing),
           convert it with scRNA2PseudoBulkAnnData.
        2. For every patient that has samples under BOTH conditions, take the
           (mean of the) samples of each condition from adata.raw.X.
        3. Perform a paired t-test across those patients for each gene.
        4. Estimate q-values from the finite p-values with the R package 'qvalue' (via rpy2).

    Patients with samples under only one condition are excluded from the test.

    Parameters:
        adata: AnnData object; either single-cell data or pseudo-bulk data annotated with uns["pseudoBulk"]
        condition_key: the column in adata.obs that contains the condition; it must have exactly two distinct values
        sample_id_col: the column in adata.obs that contains the sample id
        patient_id_col: the column in adata.obs that contains the patient id
        pval_cutoff: currently unused; kept for interface compatibility
        log2fc_cutoff: currently unused; kept for interface compatibility

    return:
        A dataframe indexed by gene with columns 'pval', 'log2fc',
        'mean_condition1', 'mean_condition2' and 'qval'.  condition1 / condition2
        are the first / second value returned by adata.obs[condition_key].unique()
        and log2fc = log2(mean_condition1 / mean_condition2) (NaN when either mean is 0).
        None is returned (after printing a message) when an input check fails.
    '''

    # check inputs
    if not isinstance(adata, ad.AnnData):
        print ("Input adata is not an AnnData object")
        return None 
    if not condition_key:
        print ("Condition key not provided")
        return None
    if not sample_id_col:
        print ("sample id column not provided")
        return None
    if not patient_id_col:
        print ("patient id column not provided")
        return None
    for col in (condition_key, sample_id_col, patient_id_col):
        if col not in adata.obs.columns:
            print ("column", col, "not available in adata.obs")
            return None
    # check if condition to compare is binary
    if len(adata.obs[condition_key].unique()) != 2:
        print ("Condition to compare is not binary")
        return None
    # check if adata have raw data
    if not adata.raw:
        print ("adata.raw is not available")
        return None

    # convert to pseudo-bulk only when the data is not already pseudo-bulk
    if 'pseudoBulk' not in adata.uns.keys():
        print ("Input adata is not pseudo-bulk RNA data. Converting to pseudo-bulk RNA data.")
        adata = scRNA2PseudoBulkAnnData(adata, sample_id_col=sample_id_col)
        if adata is None:
            return None

    expr = adata.raw.X
    gene_names = adata.raw.var_names
    patient_labels = adata.obs[patient_id_col].to_numpy()
    condition_labels = adata.obs[condition_key].to_numpy()
    condition1, condition2 = list(adata.obs[condition_key].unique())[:2]

    # Collect one expression vector per condition for every fully paired patient.
    paired_1, paired_2 = [], []
    for patient in adata.obs[patient_id_col].unique():
        is_patient = patient_labels == patient
        rows_1 = np.flatnonzero(is_patient & (condition_labels == condition1))
        rows_2 = np.flatnonzero(is_patient & (condition_labels == condition2))
        if rows_1.size == 0 or rows_2.size == 0:
            # unpaired patient: leaving it in as zeros would bias the test
            continue
        # average if a patient has several samples under one condition
        paired_1.append(np.asarray(expr[rows_1, :].mean(axis=0), dtype=np.float64).ravel())
        paired_2.append(np.asarray(expr[rows_2, :].mean(axis=0), dtype=np.float64).ravel())

    res_df = pd.DataFrame(np.nan, index=gene_names, columns=['pval', 'log2fc', 'mean_condition1', 'mean_condition2', 'qval'], dtype=float)
    if not paired_1:
        print ("No patient has samples under both conditions")
        return res_df

    x_1 = np.vstack(paired_1)  # patients x genes, condition 1
    x_2 = np.vstack(paired_2)  # patients x genes, condition 2
    mean_condition1 = x_1.mean(axis=0)
    mean_condition2 = x_2.mean(axis=0)

    with np.errstate(divide='ignore', invalid='ignore'):
        # a paired test needs at least two pairs; all genes are tested in one call
        if x_1.shape[0] >= 2:
            res_df['pval'] = np.asarray(stats.ttest_rel(x_1, x_2, axis=0)[1], dtype=float)
        res_df['log2fc'] = np.where((mean_condition1 == 0) | (mean_condition2 == 0),
                                    np.nan, np.log2(mean_condition1 / mean_condition2))
    res_df['mean_condition1'] = mean_condition1
    res_df['mean_condition2'] = mean_condition2

    # estimate q-values from the finite p-values only; untestable genes keep qval = NaN
    testable = np.isfinite(res_df['pval'].to_numpy())
    if testable.any():
        qvalue = importr('qvalue')
        r_p_values = robjects.FloatVector(res_df['pval'].to_numpy()[testable].tolist())
        r_q_values = qvalue.qvalue(r_p_values)
        res_df.loc[testable, 'qval'] = np.array(r_q_values.rx2('qvalues'))

    return res_df

In [None]:
def plotGEM(adata, cluster_id_col, cluster_id_to_plot, ncols=4, min_cells_expr=5, expr_threshold=25):
    '''
    Plot, on the UMAP of one cluster, the genes that are expressed in that cluster.

    A gene is selected when its value in adata.X exceeds expr_threshold in at
    least min_cells_expr cells of the cluster.

    Parameters:
        adata: AnnData object (adata.X may be sparse or dense) with a UMAP embedding
        cluster_id_col: column name in adata.obs that contains the cluster id
        cluster_id_to_plot: cluster id to plot
        ncols: number of columns in the plot
        min_cells_expr: Minimum number of cells expressing the gene within the cluster
        expr_threshold: value a gene must exceed in a cell to count as expressed

    Returns:
        None.  A message is printed and nothing is plotted when an input check
        fails or when there is nothing to colour the UMAP by.
    '''

    # check input values (the emptiness check must come first to be reachable)
    if not cluster_id_col:
        print("Error: cluster_id_col is empty")
        return
    if cluster_id_col not in adata.obs.columns:
        print("Error: cluster_id_col not found in adata.obs.columns")
        return

    # check if cluster_id_to_plot is in adata.obs[cluster_id_col]
    if cluster_id_to_plot not in adata.obs[cluster_id_col].unique():
        print("Error: cluster_id_to_plot not found in adata.obs[cluster_id_col]")
        return

    in_cluster = (adata.obs[cluster_id_col] == cluster_id_to_plot).to_numpy()

    # Number of cluster cells above the threshold per gene.  A sparse X yields a
    # 1 x n_genes matrix here, so flatten to 1-D before deriving indices.
    gene_expr_counts = np.asarray((adata.X[in_cluster] > expr_threshold).sum(axis=0)).ravel()
    gene_indices = np.flatnonzero(gene_expr_counts >= min_cells_expr)

    # Debug prints to check array sizes
    print("Gene filter array:", gene_expr_counts >= min_cells_expr)
    print("Length of gene_filter:", len(gene_indices))
    print("Length of adata.var_names:", len(adata.var_names))
    print("Number of selected genes:", len(gene_indices))

    # Extract the gene names that satisfy the filter from the original adata object
    GEMs_exprs_in_cls = adata.var_names[gene_indices].tolist()

    # Show 'timepoint' first, when available, followed by the selected genes
    if 'timepoint' in adata.obs.columns:
        GEMs_exprs_in_cls = ['timepoint'] + GEMs_exprs_in_cls
    if not GEMs_exprs_in_cls:
        print("No gene passes the expression filter in this cluster; nothing to plot")
        return

    # Plot the UMAP of the cluster's cells
    adata_tmp = adata[in_cluster, :].copy()
    sc.pl.umap(adata_tmp, color=GEMs_exprs_in_cls, ncols=ncols)


In [None]:
def plotGEMs(adata, cluster_id_col, cluster_id_to_plot, ncols=4, expr_threshold=25, min_frac_cells=.05):
    '''
    Plot, on the UMAP of one cluster, the genes that are expressed in that cluster.

    A gene is selected when its value in adata.X exceeds expr_threshold in more
    than min_frac_cells of the cluster's cells.

    Parameters:
        adata: AnnData object (adata.X may be sparse or dense) with a UMAP embedding
        cluster_id_col: column name in adata.obs that contains the cluster id
        cluster_id_to_plot: cluster id to plot
        ncols: number of columns in the plot
        expr_threshold: value a gene must exceed in a cell to count as expressed
        min_frac_cells: fraction of the cluster's cells that must express the gene

    Returns:
        None.  A message is printed and nothing is plotted when an input check
        fails or when there is nothing to colour the UMAP by.
    '''

    # check input values (the emptiness check must come first to be reachable)
    if not cluster_id_col:
        print("Error: cluster_id_col is empty")
        return
    if cluster_id_col not in adata.obs.columns:
        print("Error: cluster_id_col not found in adata.obs.columns")
        return

    # check if cluster_id_to_plot is in adata.obs[cluster_id_col]
    if cluster_id_to_plot not in adata.obs[cluster_id_col].unique():
        print("Error: cluster_id_to_plot not found in adata.obs[cluster_id_col]")
        return

    # identify the cells assigned in the cluster_id_to_plot
    adata_tmp = adata[adata.obs[cluster_id_col] == cluster_id_to_plot, :].copy()
    nCells = adata_tmp.shape[0]

    # search for GEMs expressed in the cells of this cluster; a sparse X yields a
    # 1 x n_genes matrix, so flatten to 1-D before using it as a mask
    frac_expressing = np.asarray((adata_tmp.X > expr_threshold).sum(axis=0)).ravel() / nCells
    GEMs_exprs_in_cls = adata_tmp.var_names[frac_expressing > min_frac_cells].tolist()

    # Show 'timepoint' first, when available, followed by the selected genes
    if 'timepoint' in adata_tmp.obs.columns:
        GEMs_exprs_in_cls = ['timepoint'] + GEMs_exprs_in_cls
    if not GEMs_exprs_in_cls:
        print("No gene passes the expression filter in this cluster; nothing to plot")
        return

    sc.pl.umap(adata_tmp, color=GEMs_exprs_in_cls, ncols=ncols)

In [None]:
def findDEGsFromClusters(adata, condition_col = None, condition_1 = None, condition_2 = None, top_n_degs=20, sample_id_col='sample_id', patient_id_col='patient_id'):
    '''
    Cluster the cells, then search for DEGs between two conditions within each cluster.

    Parameters
    --------
    adata: AnnData object
        Annotated data matrix with rows for cells and columns for genes.
        It is clustered by clustering_adata, which works on the object it is given.
    condition_col: the column name of the condition in the adata.obs
    condition_1: the first condition to compare
    condition_2: the second condition to compare
    top_n_degs: number of genes with the smallest p-value plotted per cluster
    sample_id_col: the column in adata.obs that contains the sample id
    patient_id_col: the column in adata.obs that contains the patient id

    Returns:
    --------
    DEGs: the per-cluster result frames of paird_ttest stacked into one dataframe,
        with an extra 'cluster' column identifying the cluster of each row.
        An empty dataframe is returned when no cluster could be tested, and None
        (after printing a message) when an input check or the clustering fails.

    Only cells whose condition is condition_1 or condition_2 are tested.  The
    sign of log2fc follows paird_ttest (order of appearance of the conditions),
    not the order of condition_1 / condition_2.
    '''

    if condition_col is None or condition_1 is None or condition_2 is None:
        print("Error: Missing condition information.")
        return None
    if condition_col not in adata.obs.columns:
        print("Error: condition_col not found in adata.obs.columns")
        return None

    # 1: find clusters with the provided clustering_adata function
    adata_clusters = clustering_adata(adata)
    if adata_clusters is None:
        # clustering_adata already printed the reason
        return None

    in_conditions = adata_clusters.obs[condition_col].isin([condition_1, condition_2]).to_numpy()

    # 2: loop through each cluster, extract its cells, and find DEGs
    clusters = adata_clusters.obs['leiden'].unique()
    result_dfs = []  # store DEG dataframes for each cluster

    for cluster in clusters:
        print(f"Finding DEGs for cluster {cluster}")

        # 2.1. extract the cluster's cells that belong to one of the two conditions
        in_cluster = (adata_clusters.obs['leiden'] == cluster).to_numpy()
        adata_cluster = adata_clusters[in_cluster & in_conditions].copy()

        # 2.2. paired t-test between the two conditions within the cluster
        DEGs_cluster = paird_ttest(adata_cluster, condition_key=condition_col, sample_id_col=sample_id_col, patient_id_col=patient_id_col, pval_cutoff=0.05, log2fc_cutoff=1)
        if DEGs_cluster is None:
            # paird_ttest printed the reason (e.g. only one condition present)
            print(f"Skipping cluster {cluster}: paired t-test could not be run")
            continue

        # 2.3. keep the result, tagged with its cluster
        DEGs_cluster = DEGs_cluster.copy()
        DEGs_cluster['pval'] = pd.to_numeric(DEGs_cluster['pval'])
        result_dfs.append(DEGs_cluster.assign(cluster=cluster))

        # UMAP of the cluster, coloured by the annotations that are available
        sc.pp.neighbors(adata_cluster, n_neighbors=30, n_pcs=50)
        sc.tl.umap(adata_cluster)
        obs_colors = [col for col in ['cell_type', condition_col] if col in adata_cluster.obs.columns]
        if obs_colors:
            sc.pl.umap(adata_cluster, color=obs_colors, legend_loc='on data', title=f'Cluster {cluster}')

        # UMAP for the genes with the smallest p-values (genes without a p-value are skipped)
        top_n_degs_cluster = DEGs_cluster.dropna(subset=['pval']).nsmallest(top_n_degs, 'pval')
        if not top_n_degs_cluster.empty:
            sc.pl.umap(adata_cluster, color=top_n_degs_cluster.index.tolist(), use_raw=False, cmap='viridis', legend_loc='on data')

    # Combine all the DEG dataframes into a single DataFrame
    if not result_dfs:
        return pd.DataFrame()
    DEGs = pd.concat(result_dfs)

    return DEGs

## T cell processing

In [None]:
# Load the cohort-1 T-cell AnnData object from an absolute path on the analysis server.
# NOTE(review): the path is machine-specific; consider a configurable data directory.
adata = sc.read('/home/data/ICI_exprs/ICI_T_cell_collection/1863-counts_cells_cohort1_T_cells.h5ad')

In [None]:
# Derive the sample id of every cell from its index name: the first three
# '_'-separated fields joined back together with '_'.
adata.obs['sample_id'] = [
    '_'.join(cell_name.split('_')[:3])
    for cell_name in adata.obs.index
]


In [None]:
# Display the per-cell annotation table (it now includes the 'sample_id' column).
adata.obs

In [None]:
# Display the per-gene annotation table.
adata.var

In [None]:
# T cells: select the cells annotated as 'T cells', re-cluster them and label
# the new clusters with the T-cell subtype markers below.

cell_type = 'T cells'

# subtype name -> marker genes passed to analyze_cell_type
markers = {
    'CD4'	: ['CD4', 'IL7R'],
    'CD8'	: [ 'CD8A', 'CD8B'],
    'Naïve'	: ['TCF7', 'SELL', 'LEF1', 'CCR7'],
    'Exhausted' : ['LAG3', 'TIGIT', 'PDCD1', 'HAVCR2', 'CTLA4'],
    'Cytotoxic' : ['IL2', 'GZMA', 'GNLY', 'PRF1', 'GZMB', 'GZMK', 'IFNG', 'NKG7'],
    'Treg' : ['IL2RA', 'FOXP3', 'IKZF2', 'IKZF4',  'TNFRSF18'],
    'Gamma-delta' : ['TRGC1', 'TRGC2', 'TRDC'],
    'Th17' : ['IL17A',  'CCR6', 'KLRB1'],  #'IL22',
    'MAIT' : ['SLC4A10', 'KLRB1', 'IL7R', 'DPP4'],  
    'ILC' :	['KIT', 'IL1R1'],
    'Th1' :	['STAT4', 'IL12RB2', 'IFNG'],
    'Th2' :	['GATA3', 'STAT6', 'IL4'],
    'Tfh'	: ['MAF', 'CXCL13', 'CXCR5', 'PDCD1'],
    'NK' :  ['XCL1', 'FCGR3A', 'KLRD1', 'KLRF1', 'NCAM1'],
    'Proliferation' : ['MKI67', 'PCNA', 'STMN1']
}


# analyze_cell_type stores its result in a global variable with this name
adata_name = 'adata_Tcell'

analyze_cell_type(adata, cell_type, markers, adata_name)

In [None]:
# Display the annotation of the re-clustered T cells; adata_Tcell is the global
# created by analyze_cell_type in the previous cell.
adata_Tcell.obs

In [None]:
# Log-transform the expression values of adata_Tcell.
# NOTE(review): this appears to modify adata_Tcell in place, so re-running the
# cell would transform the data a second time - confirm before re-executing.
sc.pp.log1p(adata_Tcell)

In [None]:
# Save the processed T cells next to the input file, under a new "_updated" file name.
adata_Tcell.write('/home/data/ICI_exprs/ICI_T_cell_collection/1863-counts_cells_cohort1_T_cells_updated.h5ad')

## Reload the processed T cells as `adata_T`

In [None]:
# Reload the processed T cells written by the earlier cell; the analyses below use adata_T.
adata_T = sc.read('/home/data/ICI_exprs/ICI_T_cell_collection/1863-counts_cells_cohort1_T_cells_updated.h5ad')

In [None]:
# Display the per-cell annotation of the reloaded T cells.
adata_T.obs

In [None]:
# Re-cluster adata_T and test 'Pre' vs 'On' within every cluster; the 20 genes
# with the smallest p-value are plotted per cluster.
# NOTE(review): the returned dataframe is only displayed, not stored.
findDEGsFromClusters(adata_T, condition_col = 'timepoint', condition_1 = 'Pre', condition_2 = 'On', top_n_degs = 20)            

In [None]:
# Same analysis as the previous cell, plotting only the top 8 genes per cluster.
# NOTE(review): this repeats the clustering and all tests just to change the plots.
findDEGsFromClusters(adata_T, condition_col = 'timepoint', condition_1 = 'Pre', condition_2 = 'On', top_n_degs = 8)            

In [None]:
# Recompute and plot a UMAP separately for every Leiden cluster, coloured by timepoint.
for cluster in adata_T.obs['leiden'].unique():
    cluster_cells = adata_T.obs['leiden'] == cluster
    # .copy() turns the view into an independent object: sc.tl.umap writes its
    # result into the AnnData, which should not be done on a view of adata_T
    cluster_data = adata_T[cluster_cells, :].copy()

    # Perform UMAP for the cluster
    sc.tl.umap(cluster_data)

    # plot UMAP + label w cluster name
    fig, ax = plt.subplots()
    sc.pl.umap(cluster_data, color='timepoint', title=f'Cluster {cluster}', ax=ax)
    plt.show()

In [None]:
# Separate 'Pre' and 'On' cells.
# NOTE(review): cluster_data is the variable left over from the loop in the
# previous cell, i.e. only the LAST cluster is split here - confirm this is intended.

# boolean masks over the cells of cluster_data
pre_cells = cluster_data.obs['timepoint'] == 'Pre'
on_cells = cluster_data.obs['timepoint'] == 'On'

# subsets of cluster_data for each timepoint
pre_data = cluster_data[pre_cells, :]
on_data = cluster_data[on_cells, :]

In [None]:
# Collapse the single-cell T-cell data into one pseudo-bulk profile per sample.
adata_T_bulk = scRNA2PseudoBulkAnnData(adata_T, sample_id_col = 'sample_id')

In [None]:
# Display the per-sample annotation of the pseudo-bulk object.
adata_T_bulk.obs

In [None]:
# Paired t-test between the two timepoints across patients, for all T cells.
# The single-cell object is passed; paird_ttest converts it to pseudo-bulk
# itself when uns['pseudoBulk'] is not set.  The result is a per-gene dataframe.
adata_T_ttest = paird_ttest(adata_T, condition_key = 'timepoint', sample_id_col = 'sample_id', patient_id_col = 'patient_id', pval_cutoff = 0.05, log2fc_cutoff = 1)

In [None]:
# Display the per-gene test results (a dataframe, despite the adata_ prefix).
adata_T_ttest

In [None]:
# Display only the genes with an unadjusted p-value below 0.05.
adata_T_ttest.loc[adata_T_ttest['pval'] < 0.05]

In [None]:
# Plot the expressed genes of every Leiden cluster of the T cells (count-based filter, plotGEM).
for c in adata_T.obs['leiden'].unique():
    print("Plotting cluster: ", c)
    plotGEM(adata_T, 'leiden', c, ncols=4)


In [None]:
# Plot the expressed genes of every T-cell subtype (fraction-based filter, plotGEMs).
for c in adata_T.obs['cell_type'].unique():
    print ("Plotting cell type: ", c)
    plotGEMs(adata_T, 'cell_type', c, ncols=4)

In [None]:
# Same groups as the previous cell, using the count-based filter of plotGEM instead.
for c in adata_T.obs['cell_type'].unique():
    print ("Plotting cell type: ", c)
    plotGEM(adata_T, 'cell_type', c, ncols=4)

In [None]:
# Scatter plot of log2 fold change against -log10(p-value) for the significant genes.
# NOTE: only genes passing both filters are drawn, so this is not a full volcano plot.

# filter only significant results-- pval < 0.05 and |log2fc| > 1
significant_results = adata_T_ttest[(adata_T_ttest['pval'] < 0.05) & (np.abs(adata_T_ttest['log2fc']) > 1)]

# convert the selected columns to float64 arrays for plotting
log2fc_values = significant_results['log2fc'].values.astype(np.float64)
pval_values = significant_results['pval'].values.astype(np.float64)

plt.scatter(log2fc_values, -np.log10(pval_values), alpha=0.5)

plt.xlabel('Log2 Fold Change')
plt.ylabel('-log10(p-value)')
plt.show()


In [None]:
# Print the first ten rows of the significant genes selected in the previous cell.
print(significant_results[:10])

In [None]:
# Violin plot comparing the distribution of per-gene mean expression under the two conditions.

# Convert 'mean_condition1' and 'mean_condition2' columns to numeric arrays
# (this overwrites the columns of adata_T_ttest; non-numeric entries become NaN)
adata_T_ttest['mean_condition1'] = pd.to_numeric(adata_T_ttest['mean_condition1'], errors='coerce')
adata_T_ttest['mean_condition2'] = pd.to_numeric(adata_T_ttest['mean_condition2'], errors='coerce')

# Filter out rows with missing values in mean_condition1 and mean_condition2
valid_rows = ~np.isnan(adata_T_ttest['mean_condition1']) & ~np.isnan(adata_T_ttest['mean_condition2'])
filtered_data = adata_T_ttest[valid_rows]

# Violin plot: one violin per condition, each built from the per-gene means
plt.figure(figsize=(8, 6))
plt.violinplot([filtered_data['mean_condition1'], filtered_data['mean_condition2']], showmeans=True)
plt.xticks([1, 2], ['Condition 1', 'Condition 2'])
plt.ylabel('Mean Expression')
plt.title('Violin Plot')
plt.show()


In [None]:
# make mono-color colormap
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

def create_custom_blue_colormap():
    """Return a single-hue colormap running from white through light blue to blue.

    The colormap is named 'custom_blue_colormap' and interpolates linearly
    between the (position, colour) anchors listed below.
    """
    # anchors: white at 0.0, light blue at the midpoint, blue at 1.0
    return mcolors.LinearSegmentedColormap.from_list(
        'custom_blue_colormap',
        [(0.0, 'white'), (0.5, 'lightblue'), (1.0, 'blue')],
    )

In [None]:
# find_cluster_DEGs_pairwise expects the label of ONE cluster as its second
# argument; passing the column name 'leiden' selects no cells.  Run the test
# for every cluster and collect the gene lists per cluster label.
cluster_DEGs = {
    cluster: find_cluster_DEGs_pairwise(adata_T, cluster, 'timepoint')
    for cluster in adata_T.obs['leiden'].unique()
}
cluster_DEGs

In [None]:
# make mono-color colormap
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

# NOTE(review): this definition is identical to the create_custom_blue_colormap
# defined in an earlier cell of this notebook and silently replaces it; one of
# the two cells can be removed.
def create_custom_blue_colormap():
    # (position, colour) anchors of the colormap, from white at 0.0 to blue at 1.0
    colors = [(0.0, 'white'), (0.5, 'lightblue'), (1.0, 'blue')]  # Blue shades from white to blue
    cmap_name = 'custom_blue_colormap'
    
    # Build a linearly interpolated colormap from the anchors
    cmap = mcolors.LinearSegmentedColormap.from_list(cmap_name, colors)

    return cmap

In [None]:
### keep GEMs whose total over all cells is > 1000 and whose value is > 2 in more than 500 cells
# np.asarray(...).ravel() gives 1-D vectors for both dense and sparse adata_T.X,
# so the combined mask can be used to index var_names.
gem_totals = np.asarray(adata_T.X.sum(axis=0)).ravel()
gem_n_cells_above_2 = np.asarray((adata_T.X > 2).sum(axis=0)).ravel()
nonZeroGEMs = adata_T.var_names[(gem_totals > 1000) & (gem_n_cells_above_2 > 500)].values
print(nonZeroGEMs)


In [None]:
# Colour the T-cell UMAP by each selected GEM with the white-to-blue colormap;
# values are read from adata_T.X (use_raw=False) and the colour scale is capped at 25.
blue_cmap = create_custom_blue_colormap()
sc.pl.umap(adata_T, color = nonZeroGEMs, use_raw = False, cmap=blue_cmap, ncols = 4, vmax=25)

## B cell processing

In [None]:
# Load the cohort-1 B-cell AnnData object from an absolute path on the analysis server.
adata_Bcell = sc.read('/home/data/ICI_exprs/ICI_B_cell_collection/1863-counts_cells_cohort1_B_cells.h5ad')

In [None]:
# Display the per-cell annotation of the B-cell data.
adata_Bcell.obs

In [None]:
# B cells: select the cells annotated as 'B cells', re-cluster them and label
# the new clusters with the B-cell subtype markers below.

cell_type = 'B cells'

# subtype name -> markers passed to analyze_cell_type
# NOTE(review): several entries look like protein / CD names rather than gene
# symbols (e.g. 'CD20', 'CD138', 'IgM'); markers that are absent from
# raw.var_names are silently ignored when labelling - confirm the intended symbols.
markers = {
    'B_cell' : ['CD19', 'CD20', 'CD79A', 'CD79B', 'MS4A1', 'IGHM', 'IGLC2', 'IGLC3', 'IGHG1'],
    'Plasma_cell' : ['CD38', 'CD138', 'XBP1', 'PRDM1', 'IRF4', 'MUM1'],
    'Memory_B_cell' : ['CD27', 'CD21', 'CD23', 'CD24', 'CD5'],
    'Naive_B_cell' : ['CD27', 'CD21', 'CD23', 'CD24', 'CD5', 'CD38'],
    'Germinal_center_B_cell' : ['BCL6', 'PAX5', 'CD10', 'CD38'],
    'Follicular_B_cell' : ['CD21', 'CD35', 'CXCR4', 'CD23'],
    'Marginal_zone_B_cell' : ['CD27', 'CD21', 'CD35', 'IgM', 'IgD'],
    'B1_cell' : ['CD20', 'CD43', 'CD5', 'IgM', 'IgD'],
    'B_regulatory_cell' : ['CD19', 'CD20', 'CD24', 'CD38', 'CD5', 'CD27', 'CD1d', 'CD21'],
    'Plasmablast' : ['CD38', 'CD138', 'IRF4', 'XBP1', 'PRDM1', 'MUM1'],
    'Transitional_B_cell' : ['CD10', 'CD24', 'CD38', 'CD21', 'CD23'],
    'IgM_B_cell' : ['IgM'],
    'IgD_B_cell' : ['IgD']
}



# analyze_cell_type stores its result in a global variable with this name
adata_name = 'adata_B'

# The loaded object is adata_Bcell; adata_B does not exist until
# analyze_cell_type has created it, so it cannot be the input here.
analyze_cell_type(adata_Bcell, cell_type, markers, adata_name)

In [None]:
# Paired t-test between the two timepoints across patients, for the B cells.
# adata_B is the global created by analyze_cell_type in the previous cell.
# NOTE(review): the call relies on a 'sample_id' column in adata_B.obs; unlike
# the T-cell section, no cell derives it for the B-cell data - confirm it exists.
adata_B_ttest = paird_ttest(adata_B, condition_key = 'timepoint', sample_id_col = 'sample_id', patient_id_col = 'patient_id', pval_cutoff = 0.05, log2fc_cutoff = 1)

In [None]:
# Plot the expressed genes of every Leiden cluster of the B cells.
# The clusters come from adata_B, so adata_B (not adata_T) must be plotted.
for c in adata_B.obs['leiden'].unique():
    print("Plotting cluster: ", c)
    plotGEMs(adata_B, 'leiden', c, ncols=4)

In [None]:
# Plot the expressed genes of every B-cell subtype.
# The subtypes come from adata_B, so adata_B (not adata_T) must be plotted.
for c in adata_B.obs['cell_type'].unique():
    print ("Plotting cell type: ", c)
    plotGEMs(adata_B, 'cell_type', c, ncols=4)