# Compare DEGs from EGAS000010040809 and GSE169246 

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import scanpy as sc
import scipy.io as sio
import anndata as ad
import os as os
import sys as sys
import seaborn as sns
from load_10X_matrices import load_10X_matrices
from scipy import stats
import argparse

In [None]:
import rpy2
import rpy2.robjects as robjects
from rpy2.robjects.packages import importr
from rpy2.robjects import pandas2ri

In [None]:
# does not work

import sys as sys
sys.path.append('/home/xinghua/projects/PanCancer_scRNA_analysis/utils/scRNA_utils')
from scRNA_utils import *

1) Merge the two datasets
2) Perform clustering on the merged data
3) For T cells, isolate the cells that express PDCD1 (PD-1). Skip this requirement for other cell types, because they do not express PDCD1
4) Produce pseudobulk (`scRNA2PseudoBulkAnnData`) --> sample-by-gene matrix

    a) Loop through each gene and extract the lists of 'pre' and 'on' samples

    b) Do a t-test for each cluster

    c) Save the results as a DataFrame

### 1. Merge two datasets

In [None]:
# Input datasets: one AnnData per cohort, read from .h5ad files
egas_h5ad_path = '/home/data/ICI_exprs/ICI_T_cell_collection/1863-counts_cells_cohort1_T_cells.h5ad'
gse_h5ad_path = '/data/ICI_exprs/GSE169246/GSE169246_TNBC_RNA.h5ad'

adata_egas = sc.read(egas_h5ad_path, cache=True)
adata_gse = sc.read_h5ad(gse_h5ad_path)

In [None]:
# egas: sample_id is the first three '_'-separated fields of each cell name
egas_sample_ids = ['_'.join(cell_name.split('_', 3)[:3]) for cell_name in adata_egas.obs.index]
adata_egas.obs['sample_id'] = egas_sample_ids

In [None]:
# gse metadata: gene names are the first column of the header-less, tab-separated features table
gse_features = pd.read_csv(
    '/data/ICI_exprs/GSE169246/GSE169246_TNBC_RNA_features.tsv.gz',
    sep='\t',
    header=None,
)
adata_gse.var_names = gse_features[0]

In [None]:
# De-duplicate the GSE gene names that were just assigned from the features file.
# The call previously targeted `adata`, which is not defined until later in the
# notebook, so running the cells top to bottom raised a NameError here.
adata_gse.var_names_make_unique()

In [None]:
# Harmonise the GSE timepoint labels: 'pre' -> 'Pre'; 'on' and 'prog' -> 'On'
timepoint_labels = {'pre': 'Pre', 'prog': 'On', 'on': 'On'}
adata_gse.obs['timepoint'] = adata_gse.obs['timepoint'].replace(timepoint_labels)


In [None]:
# Inspect the EGAS per-cell metadata (shown as the cell output)
adata_egas.obs

In [None]:
# Inspect the GSE per-cell metadata (shown as the cell output)
adata_gse.obs

In [None]:
# Stack the cells of both datasets; join="outer" keeps genes present in either one
datasets_to_merge = [adata_egas, adata_gse]
adata_merged = sc.concat(datasets_to_merge, join="outer")

In [None]:
# Inspect the per-cell metadata of the merged object (shown as the cell output)
adata_merged.obs

## preprocessing

In [None]:
# Drop genes detected in fewer than 50 cells.
# NOTE(review): an earlier comment here said "<10 cells" while the code uses 50 — confirm the intended cutoff.
sc.pp.filter_genes(adata_merged, min_cells=50)

# Keep cells with at least 400 and at most 8000 detected genes.
# filter_cells takes one threshold per call, hence the two passes.
for gene_bound in ({'min_genes': 400}, {'max_genes': 8000}):
    sc.pp.filter_cells(adata_merged, **gene_bound)

In [None]:
# Keep cells whose total UMI count is at least 600 and at most 120000
# (one threshold per filter_cells call, so two passes).
for umi_bound in ({'min_counts': 600}, {'max_counts': 120000}):
    sc.pp.filter_cells(adata_merged, **umi_bound)

In [None]:
# Flag mitochondrial genes by their 'MT-' name prefix
adata_merged.var['mt'] = adata_merged.var_names.str.startswith('MT-')

# Add QC metrics to the object, including the percentage of counts from 'mt' genes
sc.pp.calculate_qc_metrics(adata_merged, qc_vars=['mt'], percent_top=None, log1p=False, inplace=True)

# Keep cells with < 10% mitochondrial counts.
# NOTE(review): the earlier comment said 20% while the code uses 10 — confirm the intended cutoff.
# .copy() materialises the subset so that the in-place normalisation steps that
# follow operate on a standalone AnnData instead of a view of the unfiltered object.
adata_merged = adata_merged[adata_merged.obs['pct_counts_mt'] < 10, :].copy()
adata_merged.shape

In [None]:
# Scale every cell to a total of 10000 counts (the log transform is a separate step below)
sc.pp.normalize_total(adata_merged, target_sum=1e4)

In [None]:
# Log-transform the normalised values: log2(x + 1)
sc.pp.log1p(adata_merged, base=2)

### 2) Perform clustering using the merged data

In [None]:
# Build the neighbourhood graph with 10 neighbours using 30 principal components.
# NOTE(review): no explicit sc.tl.pca call is visible before this cell — confirm that
# relying on neighbors() to obtain the PCA representation is intended.
sc.pp.neighbors(adata_merged, n_neighbors=10, n_pcs=30)

In [None]:
# use the Leiden algorithm (resolution 0.5) to find clusters on the neighbourhood graph
sc.tl.leiden(adata_merged, resolution=0.5)

`adata_merged.raw` does not exist, so create a copy called `adata` and set `adata.raw` on it.
- Use `adata` from now on.

In [None]:
# adata_merged is without raw data
# create raw

adata = adata_merged.copy()  # create copy 

# create raw layer from existing .X attribute
# NOTE: .X has already been filtered, normalised and log2-transformed in the cells
# above, so .raw stores those values rather than the original counts.
adata.raw = adata  # This assigns the current .X as the raw layer

# adata.raw.var = adata.var # issues with adata.var
# NOTE(review): presumably this only attaches an ad-hoc attribute to the Raw object — confirm it is needed.
adata.raw.obs = adata.obs


In [None]:
def findDEGsFromClusters(adata, condition_col=None, condition_1=None, condition_2=None, top_n_degs=100):
    '''
    Cluster the cells, then find and plot DEGs between conditions inside every cluster.

    Parameters
    --------
    adata: AnnData object
        Annotated data matrix with rows for cells and columns for genes.
    condition_col: the column name of the condition in the adata.obs
    condition_1: the condition_1
    condition_2: the condition_2
        NOTE: condition_1 and condition_2 are only checked for presence here; they are
        not forwarded to paird_ttest or rank_genes_groups.
    top_n_degs: Number of top DEGs (smallest p-value) to colour on each cluster's UMAP

    Returns:
    --------
    None if any condition argument is missing or clustering_adata rejects the input.
    Otherwise a tuple (DEGs, significant_genes_df):
    DEGs: the per-cluster DataFrames returned by paird_ttest, concatenated
        (an empty DataFrame when no cluster produced a result).
    significant_genes_df: rows of the per-cluster results with pval < 0.05 and
        qval < 0.1 for ALL clusters, keeping the original row index and adding a
        'cluster' column that identifies the originating cluster.

    Steps:
    1. find clusters by calling clustering_adata
    2. loop through each cluster:
        2.1. extract the cells belonging to the cluster (copy)
        2.2. call paird_ttest on the cluster to find DEGs
        2.3. collect the DataFrame of DEGs
        2.4. rank genes with sc.tl.rank_genes_groups and plot them
        2.5. plot UMAPs (cell type / timepoint, then the top DEGs)
        2.6. collect the significant DEGs
        2.7. volcano plot of DEGs (not implemented yet)
    '''

    if condition_col is None or condition_1 is None or condition_2 is None:
        print("Error: Missing condition information.")
        return None

    # 1: find clusters using leiden by calling clustering_adata
    adata_clusters = clustering_adata(adata)
    if adata_clusters is None:
        # clustering_adata has already printed why the input was rejected
        return None

    # 2: loop through each cluster, extract its cells, and find DEGs
    clusters = adata_clusters.obs['leiden'].unique()
    result_dfs = []        # DEG DataFrame of every cluster
    significant_dfs = []   # significant subset of every cluster, tagged with the cluster

    for cluster in clusters:
        print(f"Finding DEGs for cluster {cluster}")

        # 2.1. extract cells belonging to the cluster
        adata_cluster = adata_clusters[adata_clusters.obs['leiden'] == cluster].copy()

        # 2.2. find DEGs for this cluster
        DEGs_cluster = paird_ttest(adata_cluster, condition_key=condition_col, sample_id_col='sample_id', patient_id_col='patient_id', pval_cutoff=0.05, log2fc_cutoff=1)
        # paird_ttest may hand back None; treat that like "no DEGs" below
        has_degs = DEGs_cluster is not None and not DEGs_cluster.empty

        # 2.3. collect the DataFrame of DEGs
        if DEGs_cluster is not None:
            result_dfs.append(DEGs_cluster)

        # 2.4. rank genes between the groups of condition_col and plot them
        print(f"plotting significant genes for cluster {cluster}")

        # Make sure the log base is recorded where rank_genes_groups looks for it
        adata_cluster.uns['log1p'] = {'base': 2}

        sc.tl.rank_genes_groups(adata_cluster, groupby=condition_col, method='wilcoxon')
        sc.pl.rank_genes_groups(adata_cluster, n_genes=25, sharey=False)

        print(f"DEGs: \n{DEGs_cluster}")

        # 2.5. UMAP of the cluster coloured by annotation
        sc.pp.neighbors(adata_cluster, n_neighbors=30, n_pcs=50)
        sc.tl.umap(adata_cluster)
        sc.pl.umap(adata_cluster, color=['cell_type', 'timepoint'], legend_loc='on data', title = f"cluster {cluster}")

        if not has_degs:
            continue

        # UMAP coloured by the top DEGs; p-values are coerced to numbers so nsmallest can rank them
        DEGs_cluster['pval'] = pd.to_numeric(DEGs_cluster['pval'])
        top_n_degs_cluster = DEGs_cluster.nsmallest(top_n_degs, 'pval')
        sc.pl.umap(adata_cluster, color=top_n_degs_cluster.index.tolist(), use_raw=False, cmap='viridis', legend_loc='on data')

        # 2.6. collect the significant DEGs of this cluster
        significant_genes = DEGs_cluster[(DEGs_cluster['pval'] < 0.05) & (DEGs_cluster['qval'] < 0.1)].copy()
        if not significant_genes.empty:
            significant_genes['cluster'] = cluster
            significant_dfs.append(significant_genes)

        # 2.7 volcano plot: (still a WIP)

    # Combine the per-cluster tables; pd.concat raises on an empty list, so guard both
    DEGs = pd.concat(result_dfs) if result_dfs else pd.DataFrame()
    if significant_dfs:
        significant_genes_df = pd.concat(significant_dfs)
    else:
        significant_genes_df = pd.DataFrame(columns=['pval', 'log2FC', 'mean_1', 'mean_2', 'qval', 'cluster'])

    return DEGs, significant_genes_df

In [None]:
def clustering_adata(adata, resolution=0.5, n_top_genes=5000):
    '''
    Cluster an AnnData object with the Leiden algorithm and plot the clusters on a UMAP.

    Parameters:
        adata: AnnData object; must have .raw set and at least 1000 cells and 1000 genes
        resolution: resolution passed to sc.tl.leiden
        n_top_genes: number of highly variable genes to keep when adata.var has no
            'highly_variable' column yet

    Returns:
        adata: AnnData object with a new column in adata.obs called 'leiden' that contains
            the cluster label for each cell, or None (after printing the reason) when one of
            the input checks fails. When highly variable genes are selected here, the
            returned object is a new AnnData restricted to those genes; otherwise the
            input object itself is annotated and returned.
    '''

    # check if adata is AnnData object
    if not isinstance(adata, ad.AnnData):
        print ("Input adata is not an AnnData object")
        return None

    # check if adata has raw data
    if adata.raw is None:
        print ("Input adata does not have raw data")
        return None

    # check if adata has at least 1000 cells
    if adata.shape[0] < 1000:
        print ("Input adata has less than 1000 cells")
        return None

    # check if adata has at least 1000 genes
    if adata.shape[1] < 1000:
        print ("Input adata has less than 1000 genes")
        return None

    # select highly variable genes unless that has been done already
    if 'highly_variable' not in adata.var.columns:
        print ("Select ", n_top_genes, " high variance genes")

        # sc.pp.log1p records itself in .uns['log1p']; checking only .layers never
        # matched and caused already-logged data to be normalised and logged again.
        # The .layers check is kept for inputs that flag the transform that way.
        already_logged = 'log1p' in adata.uns or 'log1p' in adata.layers
        if not already_logged:
            # Normalise on all genes, on a copy, so the caller's matrix is not rewritten
            adata = adata.copy()
            sc.pp.normalize_total(adata, target_sum=1e4)
            sc.pp.log1p(adata, base=2)

        # HVG selection runs after the log transform because the default flavor of
        # highly_variable_genes expects logarithmized data
        sc.pp.highly_variable_genes(adata, n_top_genes=n_top_genes)

        # .copy() turns the subset into a standalone object for the in-place steps below
        adata = adata[:, adata.var['highly_variable']].copy()

    # run PCA, build the neighbourhood graph and cluster
    sc.tl.pca(adata, svd_solver='arpack', n_comps=50)
    sc.pp.neighbors(adata, n_neighbors=50, n_pcs=50)
    sc.tl.leiden(adata, resolution=resolution)

    # plot UMAP
    sc.tl.umap(adata)
    sc.pl.umap(adata, color=['leiden'],  title='leiden')

    return adata

In [None]:
# Run the per-cluster DEG analysis on the 'timepoint' column with condition labels
# 'Pre' and 'On'; the 12 DEGs with the smallest p-values are coloured on each cluster UMAP.
findDEGsFromClusters(adata, condition_col = 'timepoint', condition_1 = 'Pre', condition_2 = 'On', top_n_degs = 12)            

### 3) For T cells, isolate the cells that express PDCD1 (PD-1). Skip this requirement for other cell types, because they do not express PDCD1

In [None]:
# Filter cells expressing PDCD1 above 0.7.
# PDCD1 is a gene, so its values live in the expression matrix rather than in the
# per-cell metadata; adata.obs['PDCD1'] fails unless such a column was added by hand.
# obs_vector() resolves the key against .obs columns or var_names, so it covers both.
# NOTE(review): .X is log2-normalised at this point, so 0.7 is on that scale — confirm the cutoff.
adata_T_PD1 = adata[adata.obs_vector('PDCD1') > 0.7]

### 4) Produce pseudobulk (`scRNA2PseudoBulkAnnData`) --> sample-by-gene matrix

In [None]:
def scRNA2PseudoBulkAnnData(adata, sample_id_col = 'sample_id'): 