In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import scanpy as sc
import scipy.io as sio
import anndata as ad
import seaborn as sns
import os as os
import sys as sys
sys.path.append('/home/qiuaodon/Desktop/PanCancer_scRNA_analysis/utils/')
from scRNA_utils import *
import operator as op
import matplotlib.colors as mcolors

# Load the single-cell data

In [2]:
# Folder holding the cohort-1 .h5ad exports. Paths below are built by plain
# concatenation, so the trailing slash matters.
data_dir_NHDP = "/home/qiuaodon/Desktop/project_data_new/"

# Population 1: the T-cell object restricted to Leiden cluster '3'.
adata_T = sc.read(data_dir_NHDP + '1863-counts_cells_cohort1_T_cells.h5ad')
in_cluster_3 = adata_T.obs['leiden'] == '3'
adata_1 = adata_T[in_cluster_3, :]

# Population 2: the B-cell object, used as a whole.
adata_B = sc.read(data_dir_NHDP + '1863-counts_cells_cohort1_B_cells.h5ad')
adata_2 = adata_B

In [None]:
# UMAP of each population, coloured by Leiden cluster with the labels drawn
# directly on the embedding.
for _adata in (adata_1, adata_2):
    sc.pl.umap(_adata, color='leiden', legend_loc='on data')

In [None]:
# Build one pseudo-bulk AnnData per population, keyed on the 'sample_id' column.
# NOTE(review): the aggregation itself lives in scRNA_utils.scRNA2PseudoBulkAnnData.
pseudobulk_kwargs = {'sample_id_col': 'sample_id'}
adata_1_pseudo = scRNA2PseudoBulkAnnData(adata_1, **pseudobulk_kwargs)
adata_2_pseudo = scRNA2PseudoBulkAnnData(adata_2, **pseudobulk_kwargs)

In [None]:
# Some samples are present in only one of the two pseudo-bulk objects. Pad each
# object with the sample_ids it lacks so that no sample has to be filtered out
# and both objects cover exactly the same samples.
from scRNA_utils import add_missing_samples
import anndata

# sample_ids present in each pseudo-bulk object
sample_ids_1 = set(adata_1_pseudo.obs['sample_id'])
sample_ids_2 = set(adata_2_pseudo.obs['sample_id'])

# sample_ids that exist in the other object but not in this one
missing_sample_ids_1 = sample_ids_2 - sample_ids_1
missing_sample_ids_2 = sample_ids_1 - sample_ids_2

# Add the missing sample_ids to each dataset
adata_1_pseudo = add_missing_samples(adata_1_pseudo, missing_sample_ids_1)
adata_2_pseudo = add_missing_samples(adata_2_pseudo, missing_sample_ids_2)

# Put both objects in the same row order (sorted by sample_id)
adata_1_pseudo = adata_1_pseudo[adata_1_pseudo.obs.sort_values('sample_id').index]
adata_2_pseudo = adata_2_pseudo[adata_2_pseudo.obs.sort_values('sample_id').index]

# Verify that both objects list the same sample_ids in the same order.
# The comparison is done on plain lists: an element-wise `==` on the underlying
# arrays followed by `.all()` breaks with an unrelated error when the lengths
# (or categorical categories) differ, hiding the real problem.
sample_order_1 = adata_1_pseudo.obs['sample_id'].tolist()
sample_order_2 = adata_2_pseudo.obs['sample_id'].tolist()
assert sample_order_1 == sample_order_2, (
    "adata_1_pseudo and adata_2_pseudo do not share the same sample_id order: "
    f"{len(sample_order_1)} vs {len(sample_order_2)} samples"
)

# From here on adata_1_pseudo and adata_2_pseudo are row-aligned by sample_id.

# Prepare the data matrix for the IV test

In [5]:
# Store each pseudo-bulk object in its own .raw slot.
# NOTE(review): whether the scRNA_utils helpers used later read .raw is not
# visible from this notebook — confirm this snapshot is still required.
adata_1_pseudo.raw = adata_1_pseudo
adata_2_pseudo.raw = adata_2_pseudo

In [None]:
# Differential expression per population with the project's paird_ttest helper.
# Note that it is run on the single-cell objects (adata_1 / adata_2), not on
# the pseudo-bulk ones; both calls share the same column settings.
paired_test_kwargs = dict(
    condition_key='timepoint',
    sample_id_col='sample_id',
    patient_id_col='patient_id',
)
DEG_1 = paird_ttest(adata_1, **paired_test_kwargs)
DEG_2 = paird_ttest(adata_2, **paired_test_kwargs)

In [7]:
# Keep the genes whose paired test has pval < 0.05.
DEG_1 = DEG_1.loc[DEG_1['pval'] < 0.05]
DEG_2 = DEG_2.loc[DEG_2['pval'] < 0.05]

# Only genes that also exist in the pseudo-bulk objects can be extracted below.
gene_1 = [name for name in DEG_1.index.tolist() if name in adata_1_pseudo.var_names]
gene_2 = [name for name in DEG_2.index.tolist() if name in adata_2_pseudo.var_names]

# Expression matrices restricted to those genes.
gene_1_matrix = adata_1_pseudo[:, gene_1].X
gene_2_matrix = adata_2_pseudo[:, gene_2].X

# One table per population, indexed by sample_id; column names carry the
# population tag ('_T' for population 1, '_B' for population 2).
gene_1_df = pd.DataFrame(
    gene_1_matrix, index=adata_1_pseudo.obs['sample_id'], columns=gene_1
).add_suffix('_T')
gene_2_df = pd.DataFrame(
    gene_2_matrix, index=adata_2_pseudo.obs['sample_id'], columns=gene_2
).add_suffix('_B')

# Both populations side by side, one row per sample_id.
gene_df = pd.merge(gene_1_df, gene_2_df, on='sample_id')

In [None]:
sample_index = gene_df.index

# treatment = 1 when the sample id contains 'On', otherwise 0.
gene_df['treatment'] = sample_index.str.contains('On').astype(int)

# patient id = sample id with the '_On' / '_Pre' tag removed.
without_on_tag = sample_index.str.replace('_On', '')
gene_df['patient_id'] = without_on_tag.str.replace('_Pre', '')

gene_df['treatment'].value_counts()

In [12]:
# Export gene_df (pseudo-bulk expression of the selected genes plus the
# treatment and patient_id columns) to CSV next to the input data.
# The commented-out line below is the alternative target used for the
# CD4EX-vs-Endo "filtered" variant of this analysis.
gene_df.to_csv(data_dir_NHDP + 'gene_df_CD4EX_B.csv')
# gene_df.to_csv(data_dir_NHDP + 'gene_df_CD4EX_Endo_filtered.csv')

# Load the IV results and perform CIT

In [9]:
# Load the IV-regression hits for CD4EX vs B cells. Only the gene-pair columns
# are needed downstream: p_value and r_squared are dropped and the pair columns
# gene_T / gene_cell are renamed to g1 / g2.
# DEG_pairs = pd.read_excel('/home/qiuaodon/Desktop/project_data_new/IV_regression_results_CD4EXvsEndo_filtered.xlsx')
DEG_pairs = (
    pd.read_excel('/home/qiuaodon/Desktop/project_data_new/IV_regression_results_CD4EXvsB.xlsx')
    .drop(columns=['p_value', 'r_squared'])
    .rename(columns={'gene_T': 'g1', 'gene_cell': 'g2'})
)

In [None]:
DEG_pairs

In [None]:

from scRNA_utils import calculate_gene_correlation, calculate_geneLR_correlation, calculate_residuals_correlation

# NOTE(review): the result is not assigned, yet the next cell reads
# DEG_pairs['g1vsg2_correlation'] — so this helper presumably adds that column
# to DEG_pairs in place. Confirm in scRNA_utils.
calculate_gene_correlation(DEG_pairs, adata_1_pseudo, adata_2_pseudo)

In [None]:
# Keep the pairs whose g1-g2 correlation is stronger than 0.3 in either
# direction, then order them from the most positive to the most negative.
DEG_pairs = (
    DEG_pairs[DEG_pairs['g1vsg2_correlation'].abs() > 0.3]
    .sort_values(by='g1vsg2_correlation', ascending=False)
)
DEG_pairs

In [None]:
# Keep only the pairs whose g1 and g2 are both gene_df columns with fewer than
# 20 zero-valued samples.
g1 = list(DEG_pairs['g1'])
g2 = list(DEG_pairs['g2'])

# The criterion is the same for both sides of the pair, so it is evaluated once.
expressed_columns = [col for col in gene_df.columns if (gene_df[col] == 0).sum() < 20]
g1_in_CD4EX = expressed_columns
g2_in_Endo = list(expressed_columns)

g1_ok = DEG_pairs['g1'].isin(g1_in_CD4EX)
g2_ok = DEG_pairs['g2'].isin(g2_in_Endo)
DEG_pairs = DEG_pairs[g1_ok & g2_ok]
DEG_pairs

In [None]:
# Ligand-receptor reference table: keep the ligand ('from') and receptor ('to')
# columns and rename them to L and R.
# NOTE(review): an earlier comment here mentioned restricting to the 'kegg'
# database, but no database filter is applied by this cell.
lrpair = (
    pd.read_csv('/home/qiuaodon/Desktop/project_data_new/lr_network_unique.tsv', sep='\t')
    .loc[:, ['from', 'to']]
    .rename(columns={'from': 'L', 'to': 'R'})
)
lrpair

In [None]:
# Ligands must be usable in population 1 and receptors in population 2.
# Candidates are matched against the pseudo-bulk var_names because those are
# the objects indexed just below; a gene present only in the single-cell object
# would raise a KeyError in the zero-count filter. sorted() keeps the resulting
# lists in a reproducible order across kernel restarts.
L = list(lrpair['L'])
R = list(lrpair['R'])
L_in_1 = sorted(set(L).intersection(adata_1_pseudo.var_names))
R_in_2 = sorted(set(R).intersection(adata_2_pseudo.var_names))

# Drop genes that are zero in 75% or more of the pseudo-bulk samples.
L_in_1 = [x for x in L_in_1 if np.sum(adata_1_pseudo[:, x].X == 0) < 0.75 * adata_1_pseudo.shape[0]]
R_in_2 = [x for x in R_in_2 if np.sum(adata_2_pseudo[:, x].X == 0) < 0.75 * adata_2_pseudo.shape[0]]
print('L_in_CD4EX:', len(L_in_1))
print('R_in_2:', len(R_in_2))

In [None]:
# Keep only the ligand-receptor pairs whose ligand passed the population-1
# filter and whose receptor passed the population-2 filter.
ligand_ok = lrpair['L'].isin(L_in_1)
receptor_ok = lrpair['R'].isin(R_in_2)
lrpair = lrpair[ligand_ok & receptor_ok]
print('lrpair_kegg:', lrpair.shape)

In [17]:
# Make sure the CXCL13-ACKR1 pair is part of the tested ligand-receptor pairs,
# even if the filters above removed it. It is appended only when absent, so the
# pair is never duplicated and re-running this cell does not keep stacking
# extra rows. The index is reset in both branches, as the unconditional concat
# used to do.
new_pair = pd.DataFrame({'L': ['CXCL13'], 'R': ['ACKR1']})
already_listed = ((lrpair['L'] == 'CXCL13') & (lrpair['R'] == 'ACKR1')).any()
if already_listed:
    lrpair = lrpair.reset_index(drop=True)
else:
    lrpair = pd.concat([lrpair, new_pair], ignore_index=True)


## CIT using fisher-z

In [None]:
lrpair

In [None]:
DEG_pairs

In [18]:
# Run the project's CIT_test helper with method='fisherz' on the filtered DEG
# pairs and ligand-receptor pairs.
# NOTE(review): the meaning of p_value_threshold and the layout of `results`
# are defined in scRNA_utils.CIT_test, which is not visible here.
results = CIT_test(DEG_pairs, lrpair, adata_1_pseudo, adata_2_pseudo, gene_df, method = 'fisherz', p_value_threshold=0.05)

In [None]:
# Calculate the correlation between g1 and g2 for the CIT hits.
# NOTE(review): neither call's result is assigned, so both helpers presumably
# add their columns to `results` in place — confirm in scRNA_utils.
calculate_gene_correlation(results, adata_1_pseudo, adata_2_pseudo)
# Calculate the residuals correlation between g1, g2 and the LR pair.
calculate_residuals_correlation(results, adata_1_pseudo, adata_2_pseudo)

In [20]:
# Save the Fisher-z CIT results (CD4EX vs B cells, no samples filtered out) to Excel.
results.to_excel('/home/qiuaodon/Desktop/project_data_new/CIT_results_CD4EXvsB_fisherz_nosamplefiltered.xlsx')

In [33]:
# Disabled: this cell wrote the CD4EX-vs-B `results` under a CD4EXvsEndo
# "filtered" file name (a leftover from the endothelial run). On a
# top-to-bottom run it would overwrite that analysis' output with B-cell
# results. The B-cell results are already saved by the previous cell.
# results.to_excel('/home/qiuaodon/Desktop/project_data_new/CIT_results_CD4EXvsEndo_fisherz_filtered.xlsx')

In [24]:
# results = pd.read_excel('/home/qiuaodon/Desktop/project_data_new/CIT_results_CD4EXvsEndo_fisherz_nosamplefiltered.xlsx')

## CIT using KCI

In [None]:
DEG_pairs

In [None]:
lrpair

In [None]:
# Same CIT_test call as in the Fisher-z section, with method='kci'.
# This rebinds `results`, so every later cell that reads `results` sees the KCI
# output unless the Fisher-z results are reloaded first.
results = CIT_test(DEG_pairs, lrpair, adata_1_pseudo, adata_2_pseudo, gene_df, method = 'kci', p_value_threshold=0.05)

In [None]:
# Calculate the correlation between g1 and g2 for the KCI hits.
# NOTE(review): neither call's result is assigned, so both helpers presumably
# add their columns to `results` in place — confirm in scRNA_utils.
calculate_gene_correlation(results, adata_1_pseudo, adata_2_pseudo)
# Calculate the residuals correlation between g1, g2 and the LR pair.
calculate_residuals_correlation(results, adata_1_pseudo, adata_2_pseudo)

In [23]:
# Save the KCI CIT results (CD4EX vs B cells, no samples filtered out) to Excel.
results.to_excel('/home/qiuaodon/Desktop/project_data_new/CIT_results_CD4EXvsB_KCI_nosamplefiltered.xlsx')

## Group the DEG pairs by ligand (L) and receptor (R)

In [21]:
# results = pd.read_excel('/home/qiuaodon/Desktop/project_data_new/CIT_results_CD4EXvsDC_fisherz_nosamplefiltered.xlsx')
# results = results.drop(columns=['Unnamed: 0'])
# NOTE(review): `results` is whatever the most recent CIT_test cell produced.
# On a top-to-bottom run that is the KCI output, while the grouped table built
# from it is exported under a "fisherz" file name. Reload the intended results
# first (see the commented lines above) if that is not what you want.
#
# Strip the trailing population tag ('_T', '_B', ...) from g1 and g2.
# The pattern is anchored to the end of the name so only the last '_<tag>' is
# removed. The previous unanchored r'_\w+' removed every '_xxx' run, starting
# at the first underscore, which mangles gene names that contain one.
results['g1'] = results['g1'].str.replace(r'_[^_]+$', '', regex=True)
results['g2'] = results['g2'].str.replace(r'_[^_]+$', '', regex=True)

In [None]:
results

In [None]:
results['R'].unique()

In [None]:
results['R'].nunique()

In [None]:
import pandas as pd

# Bucket the CIT hits by ligand-receptor pair.
grouped_results = results.groupby(["L", "R"])

# (ligand, receptor) -> sub-table of hits, for interactive lookup.
grouped_dict = dict(iter(grouped_results))

# Bookkeeping containers.
grouped_data = []       # not filled by this cell
tracked_gem1 = set()    # g1 genes of every group kept so far
tracked_gem2 = set()    # g2 genes of every group kept so far

min_gene_num = 3

# One record per kept ligand-receptor group; turned into a DataFrame at the end.
temp_data = []

for lr_key, hits in grouped_results:
    ligand, receptor = lr_key
    gem1 = hits["g1"].unique().tolist()
    gem2 = hits["g2"].unique().tolist()

    # Overlap with the genes of the groups kept so far. These counts are
    # computed but do not take part in the filter below.
    overlap_gem1 = len(tracked_gem1.intersection(gem1))
    overlap_gem2 = len(tracked_gem2.intersection(gem2))

    # A group is kept only when it has MORE than min_gene_num distinct g1 genes.
    # NOTE(review): an earlier comment described the rule as "at least three
    # genes in GEM_g1 and GEM_g2" plus an overlap cap; the code only tests
    # g1, with a strict '>'. Confirm which rule is intended.
    if len(gem1) <= min_gene_num:
        continue

    temp_data.append(dict(
        Ligand=ligand,
        Receptor=receptor,
        GEM_g1=gem1,
        GEM_g2=gem2,
        Num_genes_g1=len(gem1),
        Num_genes_g2=len(gem2),
    ))
    tracked_gem1 |= set(gem1)
    tracked_gem2 |= set(gem2)

# Convert the collected records to a DataFrame
grouped_df = pd.DataFrame(temp_data)

# Function to find overlapping rows
def find_overlapping_rows(df, column_name, min_overlap_fraction=1 / 3):
    """List, for every row, the other rows whose gene list overlaps with it.

    Parameters
    ----------
    df : pandas.DataFrame
        Table whose `column_name` column holds one iterable of hashable items
        (e.g. a list of gene names) per row.
    column_name : str
        Name of the column that holds those iterables.
    min_overlap_fraction : float, default 1/3
        Row B is reported for row A when the number of shared items is strictly
        greater than ``len(set(A)) * min_overlap_fraction``. The test is
        relative to row A's own size, so the relation is not symmetric.

    Returns
    -------
    list of list
        One entry per row of `df`, in row order; each entry holds the index
        labels of the overlapping rows. An empty `df` yields ``[]``.
    """
    # An empty frame may not even have the column (e.g. pd.DataFrame([])).
    if len(df) == 0:
        return []

    # Build each row's set once instead of re-creating it for every comparison.
    row_sets = [(label, set(members)) for label, members in df[column_name].items()]

    overlap_info = []
    for label, current_set in row_sets:
        threshold = len(current_set) * min_overlap_fraction
        overlapping_rows = [
            other_label
            for other_label, other_set in row_sets
            if other_label != label and len(current_set & other_set) > threshold
        ]
        overlap_info.append(overlapping_rows)
    return overlap_info

# Annotate every ligand-receptor group with the row labels of the groups it
# overlaps with, separately for its g1 and its g2 gene list.
grouped_df = grouped_df.assign(
    Overlapping_g1=find_overlapping_rows(grouped_df, "GEM_g1"),
    Overlapping_g2=find_overlapping_rows(grouped_df, "GEM_g2"),
)

grouped_df


In [27]:
# Export the grouped ligand-receptor table to Excel.
# NOTE(review): the file name says "fisherz"; make sure `results` held the
# Fisher-z output (not the KCI one) when grouped_df was built.
grouped_df.to_excel('/home/qiuaodon/Desktop/project_data_new/grouped_GEM_CD4EXvsB_fisherz.xlsx')

## draw heatmap

In [None]:
# Genes shown on the x-axis of the correlation heatmap ("M" population).
top_genes_M = [
    'RNF144B',
    'SLC1A3',
    'TNFAIP3',
    'CH25H',
    'FKBP5',
    'CLN8',
    'GRASP',
    'SMIM3',
    'DUSP6',
    'HIST1H4H',
]

# Genes shown on the y-axis of the correlation heatmap (CD4EX T cells).
# 'ZFP36' used to be listed twice, which drew the same heatmap row twice;
# every gene now appears exactly once (first position kept).
top_genes_T = [
    'CH25H',
    'TSC22D3',
    'TNFAIP3',
    'BIRC3',
    'SPON2',
    'ZFP36',
    'SLA',
    'TXNIP',
    'IER2',
    'KLF10',
]

In [None]:
import seaborn as sns
import scipy.stats as stats
from scipy.stats import spearmanr

# NOTE(review): adata_M_pseudo and adata_T_pseudo are not created by any cell
# shown above (the pseudo-bulk objects built in this notebook are
# adata_1_pseudo / adata_2_pseudo). Define or load them before running this cell.

# Restrict both objects to the sample_ids they have in common ...
adata_M_pseudo = adata_M_pseudo[adata_M_pseudo.obs['sample_id'].isin(adata_T_pseudo.obs['sample_id'])]
adata_T_pseudo = adata_T_pseudo[adata_T_pseudo.obs['sample_id'].isin(adata_M_pseudo.obs['sample_id'])]

# ... and put them in the same row order. Filtering alone does not align the
# rows, and the correlation below pairs the two expression vectors by position.
adata_M_pseudo = adata_M_pseudo[adata_M_pseudo.obs.sort_values('sample_id').index]
adata_T_pseudo = adata_T_pseudo[adata_T_pseudo.obs.sort_values('sample_id').index]
assert adata_T_pseudo.obs['sample_id'].tolist() == adata_M_pseudo.obs['sample_id'].tolist(), \
    "adata_T_pseudo and adata_M_pseudo are not aligned by sample_id"

# Matrices for the Spearman coefficient and for -log10(p-value);
# rows follow top_genes_T, columns follow top_genes_M.
corr_matrix = np.zeros((len(top_genes_T), len(top_genes_M)))
pval_matrix = np.zeros((len(top_genes_T), len(top_genes_M)))

# Calculate the correlation between the top genes of T and M cells
for i, gene_T in enumerate(top_genes_T):
    for j, gene_M in enumerate(top_genes_M):
        # remove the '_T' and '_M' tags from the gene names, if present
        gene_T = gene_T.replace('_T', '')
        gene_M = gene_M.replace('_M', '')
        # Expression of the two genes across samples, flattened to 1D arrays
        gene_T_expression = adata_T_pseudo[:, adata_T_pseudo.var_names == gene_T].X.flatten()
        gene_M_expression = adata_M_pseudo[:, adata_M_pseudo.var_names == gene_M].X.flatten()

        # Spearman rank correlation across the aligned samples
        corr, pval = spearmanr(gene_T_expression, gene_M_expression)

        # Store the coefficient and the -log10 p-value
        corr_matrix[i, j] = corr
        pval_matrix[i, j] = -np.log10(pval)


# Plot the correlation matrix as a heatmap
plt.figure(figsize=(10, 8))
sns.heatmap(corr_matrix, annot=False, cmap='coolwarm', vmin=-1, vmax=1,
            xticklabels=top_genes_M, yticklabels=top_genes_T)
plt.title('Correlation between Highly Correlated Genes of CD4EX and M')
plt.xlabel('M Genes')
plt.ylabel('CD4EX Genes')
plt.show()

In [None]:

# Heatmap of corr_matrix with every cell labelled by its correlation value and,
# in parentheses, the matching pval_matrix entry.
plt.figure(figsize=(12, 10))
heatmap = sns.heatmap(corr_matrix, cmap='coolwarm', vmin=-1, vmax=1,
                      xticklabels=top_genes_M, yticklabels=top_genes_T,
                      annot=False)  # seaborn's own annotation stays off

# Labels are written by hand at the centre of each cell; the text is always black.
for (row, col), rho in np.ndenumerate(corr_matrix):
    label = f'{rho:.2f}\n({pval_matrix[row, col]:.2f})'
    plt.text(col + 0.5, row + 0.5, label,
             horizontalalignment='center', verticalalignment='center',
             fontsize=11, color='black')

plt.title('Correlation between Top DEGs of CD4EX and M')
plt.xlabel('M Genes')
plt.ylabel('CD4EX Genes')
plt.show()

In [None]:
# Two-stop colormap for the expression UMAPs: grey at the low end, blue at the high end.
colors = ["grey", "blue"]
cmap = mcolors.LinearSegmentedColormap.from_list(name="grey_to_blue", colors=colors)


In [None]:
# Expression UMAP of SPON2 and SELL using the grey-to-blue colormap.
# NOTE(review): adata_CD4EX is not created by any cell shown above
# (the CD4EX subset is called adata_1 in this notebook) — confirm the source.
sc.pl.umap(adata_CD4EX, color = ['SPON2', 'SELL'], color_map = cmap)

In [None]:
# Expression UMAP of FKBP5 and MADCAM1 using the grey-to-blue colormap.
# NOTE(review): adata_M is not created by any cell shown above — confirm the source.
sc.pl.umap(adata_M, color = ['FKBP5', 'MADCAM1'], color_map = cmap)

In [None]:
# UMAPs coloured by timepoint for both populations.
# NOTE(review): adata_CD4EX and adata_M are not created by any cell shown above.
sc.pl.umap(adata_CD4EX, color = 'timepoint')
sc.pl.umap(adata_M, color = 'timepoint')

#### plot the scatter plots


In [None]:
# NOTE(review): the gene_df built in this notebook only has '_T' / '_B' gene
# columns (plus treatment and patient_id), so this '_Endo' lookup needs a
# gene_df from the CD4EX-vs-Endo run.
gene_df['GZMA_Endo']

In [None]:
# Scatter plots for one DEG pair (RGS1_T, TCIM_Endo) and one ligand-receptor pair (XCL2, TACR2).
# NOTE(review): adata_T_pseudo and adata_Endo_pseudo are not created by any
# cell shown above, and gene_df here has no '_Endo' columns — this cell
# appears to belong to the CD4EX-vs-Endo run.
plot_CIT_DEGcorr('RGS1_T', 'TCIM_Endo', 'XCL2', 'TACR2', gene_df, adata_T_pseudo, adata_Endo_pseudo)