In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import scanpy as sc
import scipy.io as sio
import anndata as ad
import seaborn as sns
import os as os
import sys as sys
sys.path.append('/home/qiuaodon/Desktop/PanCancer_scRNA_analysis/utils/')
from scRNA_utils import *
import operator as op
import matplotlib.colors as mcolors

# Load the single-cell data

In [2]:
# Folder that holds the cohort-1 .h5ad files.
data_dir_NHDP = "/home/qiuaodon/Desktop/project_data_new/"

# Population 1: T cells whose PDCD1 expression value is greater than zero.
adata_T = sc.read(data_dir_NHDP + '1863-counts_cells_cohort1_T_cells.h5ad')
pdcd1_detected = adata_T[:, 'PDCD1'].X > 0
adata_1 = adata_T[pdcd1_detected, :]

# Population 2: the endothelial-cell object, used unfiltered.
adata_Endo = sc.read(data_dir_NHDP + '1863-counts_cells_cohort1_Endo_cells.h5ad')
adata_2 = adata_Endo

In [None]:
# Distribution of PDCD1 in the PDCD1 > 0 subset (visual check of the filter above).
sc.pl.violin(adata_1, 'PDCD1', size= 3)

In [None]:
adata_Endo

In [32]:
# Two-stop linear colormap running from grey to blue.
colors = ["grey", "blue"]  # Start with grey and end with blue
cmap = mcolors.LinearSegmentedColormap.from_list("grey_to_blue", colors)


In [None]:
# One UMAP per feature, same styling for both.
# NOTE(review): 'GEM_T_1' is not created anywhere in this notebook - presumably it
# already exists in adata_1 (obs or var); confirm.
for feature in ('GEM_T_1', 'ANXA1'):
    sc.pl.umap(adata_1, color=feature, cmap='Blues', size=10)

In [None]:
# Build one pseudo-bulk object per population, grouped by the 'sample_id' column.
# NOTE(review): scRNA2PseudoBulkAnnData presumably comes from the `scRNA_utils` star
# import; how it aggregates cells (sum / mean) is not visible here - confirm.
adata_1_pseudo = scRNA2PseudoBulkAnnData(adata_1, sample_id_col='sample_id')
adata_2_pseudo = scRNA2PseudoBulkAnnData(adata_2, sample_id_col='sample_id')

# Prepare the data matrix for the IV test

In [None]:
# Differential expression between 'timepoint' conditions, one table per population.
# The next cell reads a 'pval' column and uses the index as gene names.
# NOTE(review): paird_ttest is presumably provided by the `scRNA_utils` star import
# and pairs samples within 'patient_id' - confirm against the helper.
DEG_1 = paird_ttest(adata_1, condition_key = 'timepoint', sample_id_col = 'sample_id', patient_id_col = 'patient_id')
DEG_2 = paird_ttest(adata_2, condition_key = 'timepoint', sample_id_col = 'sample_id', patient_id_col = 'patient_id')

In [5]:
# Keep genes whose p-value is below 0.05 in each population.
DEG_1 = DEG_1.loc[DEG_1['pval'] < 0.05]
DEG_2 = DEG_2.loc[DEG_2['pval'] < 0.05]

# Significant genes that are also present in the pseudo-bulk objects.
gene_1 = [name for name in DEG_1.index.tolist() if name in adata_1_pseudo.var_names]
gene_2 = [name for name in DEG_2.index.tolist() if name in adata_2_pseudo.var_names]

# Sample x gene expression tables; columns are tagged with the population suffix
# so that the same gene symbol can appear once per population.
gene_1_matrix = adata_1_pseudo[:, gene_1].X
gene_2_matrix = adata_2_pseudo[:, gene_2].X
gene_1_df = pd.DataFrame(
    gene_1_matrix,
    columns=[name + '_T' for name in gene_1],
    index=adata_1_pseudo.obs['sample_id'],
)
gene_2_df = pd.DataFrame(
    gene_2_matrix,
    columns=[name + '_Endo' for name in gene_2],
    index=adata_2_pseudo.obs['sample_id'],
)

# Join the two tables on sample_id; samples missing from either side are dropped
# (pd.merge default is an inner join).
gene_df = pd.merge(gene_1_df, gene_2_df, on='sample_id')

In [None]:
# Treatment flag: 1 when the sample id contains the substring 'On', otherwise 0.
# NOTE(review): this is a substring match, so an id with 'On' anywhere would be flagged.
gene_df['treatment'] = gene_df.index.str.contains('On').astype(int)
# Patient id = sample id with the '_On' / '_Pre' substrings removed.
gene_df['patient_id'] = gene_df.index.str.replace('_On', '').str.replace('_Pre', '')
# Number of samples per treatment flag.
gene_df['treatment'].value_counts()

In [9]:
# Write gene_df (sample x gene table plus treatment / patient_id) to CSV in the data folder.
# NOTE(review): presumably this file is the input of the external IV regression whose
# results are read back below - confirm.
gene_df.to_csv(data_dir_NHDP + 'gene_df_PD1_Endo.csv')
# gene_df.to_csv(data_dir_NHDP + 'gene_df_CD4EX_Endo_filtered.csv')

# load in IV result and perform CIT

In [7]:
# Read the IV-regression table, discard its statistics columns, and expose the two
# gene columns under the names the downstream helpers are called with (g1, g2).
# DEG_pairs = pd.read_excel('/home/qiuaodon/Desktop/project_data_new/IV_regression_results_CD4EXvsEndo_filtered.xlsx')
DEG_pairs = (
    pd.read_excel('/home/qiuaodon/Desktop/project_data_new/IV_regression_results_PD1vsEndo.xlsx')
    .drop(columns=['p_value', 'r_squared'])
    .rename(columns={'gene_T': 'g1', 'gene_cell': 'g2'})
)

In [None]:
DEG_pairs

In [None]:
# The two pseudo-bulk objects may not contain the same samples. Instead of dropping
# samples, add the ids missing from each object so that no sample is filtered out.
# NOTE(review): what values add_missing_samples stores for the added samples is not
# visible here - confirm (the correlation helpers drop NaN rows).
from scRNA_utils import add_missing_samples
import anndata
# Sample ids present in each pseudo-bulk object
sample_ids_1 = set(adata_1_pseudo.obs['sample_id'])
sample_ids_2 = set(adata_2_pseudo.obs['sample_id'])

# Ids present in the other object but absent from this one
missing_sample_ids_1 = sample_ids_2 - sample_ids_1
missing_sample_ids_2 = sample_ids_1 - sample_ids_2

# Add the missing sample ids to each object
adata_1_pseudo = add_missing_samples(adata_1_pseudo, missing_sample_ids_1)
adata_2_pseudo = add_missing_samples(adata_2_pseudo, missing_sample_ids_2)
# Reorder the rows of both objects by sample_id
adata_1_pseudo = adata_1_pseudo[adata_1_pseudo.obs.sort_values('sample_id').index]
adata_2_pseudo = adata_2_pseudo[adata_2_pseudo.obs.sort_values('sample_id').index]

# Fail early unless both objects list the same sample_ids in the same order
assert (adata_1_pseudo.obs['sample_id'].values == adata_2_pseudo.obs['sample_id'].values).all()

# From here on adata_1_pseudo and adata_2_pseudo have the same sample_ids in the same order

In [None]:

# Correlation helpers from the project utilities (calculate_geneLR_correlation is
# imported but not called in this notebook).
from scRNA_utils import calculate_gene_correlation, calculate_geneLR_correlation, calculate_residuals_correlation

# The return value is not assigned; the next cell reads DEG_pairs['g1vsg2_correlation'],
# so the helper presumably adds that column to DEG_pairs in place - confirm.
calculate_gene_correlation(DEG_pairs, adata_1_pseudo, adata_2_pseudo)

In [None]:
# Keep gene pairs with |correlation| above 0.3, listed from the most positive
# correlation to the most negative.
strong_corr = DEG_pairs['g1vsg2_correlation'].abs() > 0.3
DEG_pairs = DEG_pairs[strong_corr].sort_values(by='g1vsg2_correlation', ascending=False)
DEG_pairs

In [None]:
g1 = list(DEG_pairs['g1'])
g2 = list(DEG_pairs['g2'])

# A gene_df column is kept only when fewer than MAX_ZERO_SAMPLES samples have an
# exactly-zero value. The scan is identical for both populations (gene_df holds the
# '_T' and the '_Endo' columns), so it is computed once for all columns.
MAX_ZERO_SAMPLES = 15
zero_counts = (gene_df == 0).sum(axis=0)
expressed_columns = zero_counts.index[zero_counts < MAX_ZERO_SAMPLES].tolist()
g1_in_CD4EX = list(expressed_columns)
g2_in_Endo = list(expressed_columns)

# Keep pairs whose two genes both pass the zero-count filter.
DEG_pairs = DEG_pairs[DEG_pairs['g1'].isin(g1_in_CD4EX) & DEG_pairs['g2'].isin(g2_in_Endo)]
DEG_pairs

In [None]:
# Ligand-receptor network: keep the ligand ('from') and receptor ('to') columns and
# rename them to L and R. Every row of the file is kept; no database filter is applied.
lrpair = (
    pd.read_csv('/home/qiuaodon/Desktop/project_data_new/lr_network_unique.tsv', sep='\t')
    .loc[:, ['from', 'to']]
    .rename(columns={'from': 'L', 'to': 'R'})
)
lrpair

In [None]:
# Keep ligands (population 1) and receptors (population 2) that exist in the data
# and are exactly zero in fewer than 75% of the pseudo-bulk samples.
def _genes_expressed_in_pseudobulk(candidates, adata_sc, adata_pseudo, max_zero_fraction=0.75):
    """Return the candidate genes that are expressed in enough pseudo-bulk samples.

    A gene is kept when it is present in both ``adata_sc.var_names`` and
    ``adata_pseudo.var_names`` (the latter is what gets indexed, so checking only
    the single-cell object could raise a KeyError) and the number of pseudo-bulk
    samples with an exactly-zero value is below ``max_zero_fraction * n_samples``.
    All genes are read in a single slice rather than one AnnData slice per gene.
    """
    present = sorted(set(candidates) & set(adata_sc.var_names) & set(adata_pseudo.var_names))
    if not present:
        return []
    values = adata_pseudo[:, present].X
    if hasattr(values, 'toarray'):  # sparse matrix -> dense
        values = values.toarray()
    zero_counts = (np.asarray(values) == 0).sum(axis=0)
    limit = max_zero_fraction * adata_pseudo.shape[0]
    return [gene for gene, n_zero in zip(present, zero_counts) if n_zero < limit]


L = list(lrpair['L'])
R = list(lrpair['R'])
L_in_1 = _genes_expressed_in_pseudobulk(L, adata_1, adata_1_pseudo)
R_in_2 = _genes_expressed_in_pseudobulk(R, adata_2, adata_2_pseudo)
print('L_in_CD4EX:', len(L_in_1))
print('R_in_2:', len(R_in_2))

In [None]:
# Keep only pairs whose ligand passed the L_in_1 filter and whose receptor passed
# the R_in_2 filter.
both_expressed = lrpair['L'].isin(L_in_1) & lrpair['R'].isin(R_in_2)
lrpair = lrpair[both_expressed]
print('lrpair_kegg:', lrpair.shape)

In [15]:
# Make sure the CXCL13-ACKR1 pair is part of the tested LR pairs. It is appended
# only when it is not already present, so the pair is never listed (and tested)
# twice. Note that a pair added here bypasses the expression filters above.
new_pair = pd.DataFrame({'L': ['CXCL13'], 'R': ['ACKR1']})
has_cxcl13_ackr1 = ((lrpair['L'] == 'CXCL13') & (lrpair['R'] == 'ACKR1')).any()
if has_cxcl13_ackr1:
    # Same 0..n-1 index as the concat branch produces.
    lrpair = lrpair.reset_index(drop=True)
else:
    lrpair = pd.concat([lrpair, new_pair], ignore_index=True)

## CIT using fisher-z

In [None]:
lrpair

In [None]:
DEG_pairs

In [18]:
# Conditional-independence test over the gene pairs and LR pairs using the 'fisherz'
# method with a 0.05 p-value threshold.
# NOTE(review): CIT_test is presumably provided by the `scRNA_utils` star import; later
# cells read columns g1, g2, L and R from its result - confirm the full schema.
results = CIT_test(DEG_pairs, lrpair, adata_1_pseudo, adata_2_pseudo, gene_df, method = 'fisherz', p_value_threshold=0.05)

In [None]:
# Add the g1-vs-g2 correlation to `results` (the 'g1vsg2_correlation' column is read
# by later cells; return values are not assigned, so the helpers presumably modify
# `results` in place - confirm).
calculate_gene_correlation(results, adata_1_pseudo, adata_2_pseudo)
# Add the correlation between g1 and g2 residuals with respect to the LR pair.
calculate_residuals_correlation(results, adata_1_pseudo, adata_2_pseudo)

In [None]:
results["R"].unique()

In [20]:
# Save the fisher-z CIT results (the DataFrame index is written as the first column).
results.to_excel('/home/qiuaodon/Desktop/project_data_new/CIT_results_PD1vsEndo_fisherz_corr03.xlsx')

In [None]:
# results = pd.read_excel('/home/qiuaodon/Desktop/project_data_new/CIT_results_CD4EXvsEndo_fisherz_nosamplefiltered.xlsx')

## CIT using KCI


In [None]:
DEG_pairs

In [None]:
lrpair

In [None]:
# Same test with the 'kci' method.
# NOTE(review): this overwrites the fisher-z `results`; every later cell uses whichever
# of the two CIT cells ran last. Consider distinct names (e.g. results_kci).
results = CIT_test(DEG_pairs, lrpair, adata_1_pseudo, adata_2_pseudo, gene_df, method = 'kci', p_value_threshold=0.05)

In [None]:
# Same two annotation calls as in the fisher-z section, applied to the KCI results.
# calculate the correlation between g1 and g2
calculate_gene_correlation(results, adata_1_pseudo, adata_2_pseudo)
# calculate the residuals correlation between g1, g2 and LR
calculate_residuals_correlation(results, adata_1_pseudo, adata_2_pseudo)

In [None]:
# Save the KCI results.
# NOTE(review): the file name says 'CD4EXvsMyeloid' while this notebook loads PDCD1+ T
# cells and endothelial cells - the name looks stale; confirm before running.
results.to_excel('/home/qiuaodon/Desktop/project_data_new/CIT_results_CD4EXvsMyeloid_KCI_nosamplefiltered.xlsx')

## Group the DEG pairs by ligand (L) and receptor (R)

In [21]:
# results = pd.read_excel('/home/qiuaodon/Desktop/project_data_new/CIT_results_CD4EXvsDC_fisherz_nosamplefiltered.xlsx')
# results = results.drop(columns=['Unnamed: 0'])
# Remove the population suffix that was appended when gene_df was built ('_T' for
# g1, '_Endo' for g2). The patterns are anchored at the end of the name, so a gene
# symbol that itself contains an underscore is left intact and re-running the cell
# is a no-op. (The previous pattern r'_\w+' deleted everything from the first
# underscore onwards.)
results['g1'] = results['g1'].str.replace(r'_T$', '', regex=True)
results['g2'] = results['g2'].str.replace(r'_Endo$', '', regex=True)

In [23]:
# Keep only positively correlated pairs (g1vsg2_correlation > 0.3); negatively
# correlated pairs are dropped here. The copy makes `results` an independent frame,
# so the columns added by the helper functions below do not write into a slice of
# the previous DataFrame (SettingWithCopyWarning).
results = results[(results['g1vsg2_correlation'] > 0.3)].copy()

In [24]:
def calculate_g1_L_correlation(results, adata_1_pseudo, method='pearson'):
    """Add the per-row correlation between gene g1 and ligand L to ``results``.

    Both genes are read from ``adata_1_pseudo``; samples where either value is NaN
    are dropped jointly so the two vectors always stay paired.

    Parameters
    ----------
    results : pandas.DataFrame
        Must contain the columns 'g1' and 'L'. A float column 'g1vsL_correlation'
        is added (or overwritten) in place.
    adata_1_pseudo : AnnData-like
        Object exposing ``var_names``, ``obs['sample_id']`` and ``[:, gene].X``.
    method : {'pearson', 'spearman'}
        Correlation coefficient to compute.

    Returns
    -------
    pandas.DataFrame
        The same ``results`` object. Rows whose g1 or L is not in
        ``adata_1_pseudo.var_names``, or with fewer than two usable samples, keep NaN.

    Raises
    ------
    ValueError
        If ``method`` is not 'pearson' or 'spearman'.
    """
    # Imported here so the function does not depend on a later notebook cell
    # having been executed first.
    from scipy.stats import pearsonr, spearmanr

    if method not in ['pearson', 'spearman']:
        raise ValueError("Method must be either 'pearson' or 'spearman'")
    corr_func = pearsonr if method == 'pearson' else spearmanr

    # Float/NaN initialisation: an int 0 column would be upcast on the first float
    # assignment, and 0 would be indistinguishable from a computed zero correlation.
    results['g1vsL_correlation'] = np.nan
    sample_ids = adata_1_pseudo.obs['sample_id'].to_numpy()

    def _expression(gene):
        """Dense 1-D expression vector of one gene across samples."""
        values = adata_1_pseudo[:, gene].X
        if hasattr(values, 'toarray'):  # sparse matrix
            values = values.toarray()
        return np.asarray(values, dtype=float).ravel()

    for index, row in results.iterrows():
        g1 = row['g1']
        if g1 not in adata_1_pseudo.var_names:
            # Fall back to the text before the first underscore (e.g. 'ANXA1_T').
            g1 = g1.split('_')[0]
        L = row['L']

        if g1 not in adata_1_pseudo.var_names or L not in adata_1_pseudo.var_names:
            continue

        paired = pd.DataFrame(
            {'g1': _expression(g1), 'L': _expression(L)}, index=sample_ids
        ).dropna()
        if len(paired) < 2:
            continue

        results.loc[index, 'g1vsL_correlation'] = corr_func(paired['g1'], paired['L'])[0]

    return results

In [25]:
def calculate_g2_R_correlation(results, adata_2_pseudo, method='pearson'):
    """Add the per-row correlation between gene g2 and receptor R to ``results``.

    Both genes are read from ``adata_2_pseudo``; samples where either value is NaN
    are dropped jointly so the two vectors always stay paired.

    Parameters
    ----------
    results : pandas.DataFrame
        Must contain the columns 'g2' and 'R'. A float column 'g2vsR_correlation'
        is added (or overwritten) in place.
    adata_2_pseudo : AnnData-like
        Object exposing ``var_names``, ``obs['sample_id']`` and ``[:, gene].X``.
    method : {'pearson', 'spearman'}
        Correlation coefficient to compute.

    Returns
    -------
    pandas.DataFrame
        The same ``results`` object. Rows whose g2 or R is not in
        ``adata_2_pseudo.var_names``, or with fewer than two usable samples, keep NaN.

    Raises
    ------
    ValueError
        If ``method`` is not 'pearson' or 'spearman'.
    """
    # Imported here so the function does not depend on a later notebook cell
    # having been executed first.
    from scipy.stats import pearsonr, spearmanr

    if method not in ['pearson', 'spearman']:
        raise ValueError("Method must be either 'pearson' or 'spearman'")
    corr_func = pearsonr if method == 'pearson' else spearmanr

    # Float/NaN initialisation: an int 0 column would be upcast on the first float
    # assignment, and 0 would be indistinguishable from a computed zero correlation.
    results['g2vsR_correlation'] = np.nan
    sample_ids = adata_2_pseudo.obs['sample_id'].to_numpy()

    def _expression(gene):
        """Dense 1-D expression vector of one gene across samples."""
        values = adata_2_pseudo[:, gene].X
        if hasattr(values, 'toarray'):  # sparse matrix
            values = values.toarray()
        return np.asarray(values, dtype=float).ravel()

    for index, row in results.iterrows():
        g2 = row['g2']
        if g2 not in adata_2_pseudo.var_names:
            # Fall back to the text before the first underscore (e.g. 'ACKR1_Endo').
            g2 = g2.split('_')[0]
        R = row['R']

        if g2 not in adata_2_pseudo.var_names or R not in adata_2_pseudo.var_names:
            continue

        paired = pd.DataFrame(
            {'g2': _expression(g2), 'R': _expression(R)}, index=sample_ids
        ).dropna()
        if len(paired) < 2:
            continue

        results.loc[index, 'g2vsR_correlation'] = corr_func(paired['g2'], paired['R'])[0]

    return results


In [None]:
# Add the 'g1vsL_correlation' and 'g2vsR_correlation' columns to `results`
# (both functions modify the DataFrame they receive and also return it).
calculate_g1_L_correlation(results, adata_1_pseudo)
calculate_g2_R_correlation(results, adata_2_pseudo)

In [27]:
# Keep rows whose g1-vs-L and g2-vs-R correlations both lie strictly between 0.3
# and 0.99 (the upper bound also removes correlations of ~1).
g1_L_in_range = (results['g1vsL_correlation'] > 0.3) & (results['g1vsL_correlation'] < 0.99)
g2_R_in_range = (results['g2vsR_correlation'] > 0.3) & (results['g2vsR_correlation'] < 0.99)
results = results[g1_L_in_range & g2_R_in_range]

In [None]:
# Reduce `results` to its (L, R) columns.
lr_pairs = results[['L', 'R']]
# Keep only (L, R) pairs that occur in at least 2 rows of `results`.
# NOTE(review): the previous comment said "less than 3 times" while the code uses
# `>= 2` - confirm which threshold is intended.
lr_pairs = lr_pairs.groupby(['L', 'R']).filter(lambda x: len(x) >= 2)
# One row per remaining (L, R) pair.
lr_pairs = lr_pairs.drop_duplicates()
lr_pairs

In [29]:
import pandas as pd
# Group the significant rows by ligand-receptor pair. For every (L, R) group the
# unique g1 genes form one gene set and the unique g2 genes form the other.
grouped_results = results.groupby(["L", "R"])

# A gene set must contain MORE than this many genes (i.e. at least three) to be kept.
min_gene_num = 2

# One record per retained (L, R) pair; converted to a DataFrame in a single step.
temp_data = []

for (ligand, receptor), group in grouped_results:
    gem1 = group["g1"].drop_duplicates().tolist()
    gem2 = group["g2"].drop_duplicates().tolist()

    if len(gem1) > min_gene_num and len(gem2) > min_gene_num:
        temp_data.append({
            "L": ligand,
            "R": receptor,
            "g1": gem1,
            "g2": gem2,
            "Num_genes_g1": len(gem1),
            "Num_genes_g2": len(gem2)
        })

# Explicit columns keep the expected schema even when no pair passes the filter,
# so the following cells do not fail with a KeyError on an empty, column-less frame.
grouped_df = pd.DataFrame(
    temp_data,
    columns=["L", "R", "g1", "g2", "Num_genes_g1", "Num_genes_g2"],
)

def find_overlapping_rows(df, column_name):
    """For every row, list the index labels of the other rows that overlap with it.

    Each cell of ``df[column_name]`` is treated as a collection of items. Row B is
    reported for row A when the number of shared items is greater than one third
    of the number of distinct items in A (the relation is therefore not symmetric).

    Returns a list with one entry per row of ``df``, in row order; each entry is
    the list of overlapping index labels.
    """
    # Build every row's set once instead of rebuilding it inside the inner loop.
    labelled_sets = [(label, set(row[column_name])) for label, row in df.iterrows()]

    overlap_info = []
    for label, members in labelled_sets:
        cutoff = len(members) / 3  # Set the threshold for overlap
        overlap_info.append([
            other_label
            for other_label, other_members in labelled_sets
            if other_label != label and len(members & other_members) > cutoff
        ])
    return overlap_info

# For each (L, R) row, store the index labels of the other rows whose g1 (resp. g2)
# gene list shares more than one third of this row's genes.
grouped_df["Overlapping_g1"] = find_overlapping_rows(grouped_df, "g1")
grouped_df["Overlapping_g2"] = find_overlapping_rows(grouped_df, "g2")


In [None]:
# Sort the genes inside every g1 / g2 list alphabetically.
for gem_col in ('g1', 'g2'):
    grouped_df[gem_col] = grouped_df[gem_col].map(sorted)
grouped_df

In [32]:
import numpy as np
import pandas as pd
from scipy.stats import pearsonr, spearmanr
# `results` holds the columns 'g1' and 'g2' (each a list of genes) plus 'L' and 'R'
# (the ligand and the receptor).
def calculate_GEM_correlation(results, adata_1_pseudo, adata_2_pseudo, method='pearson'):
    """Correlate the mean expression of gene set g1 with that of gene set g2.

    For every row, the genes in 'g1' are averaged per sample in ``adata_1_pseudo``
    and the genes in 'g2' in ``adata_2_pseudo``. The two per-sample means are
    aligned on ``obs['sample_id']`` and correlated.

    Parameters
    ----------
    results : pandas.DataFrame
        Must contain 'g1' and 'g2'; each cell is a list of gene names (a single
        name is also accepted). A float column 'g1vsg2_correlation' is added
        (or overwritten) in place.
    adata_1_pseudo, adata_2_pseudo : AnnData-like
        Objects exposing ``var_names``, ``obs['sample_id']`` and ``[:, genes].X``
        (dense or sparse).
    method : {'pearson', 'spearman'}

    Returns
    -------
    pandas.DataFrame
        The same ``results`` object. Rows with a gene missing from the data, or
        with fewer than two shared samples, keep NaN.

    Raises
    ------
    ValueError
        If ``method`` is not 'pearson' or 'spearman'.
    """
    if method not in ['pearson', 'spearman']:
        raise ValueError("Method must be either 'pearson' or 'spearman'")
    corr_func = pearsonr if method == 'pearson' else spearmanr

    # Float/NaN initialisation: an int 0 column would be upcast on the first float
    # assignment, and a skipped row would look like a computed zero correlation.
    results['g1vsg2_correlation'] = np.nan

    def _mean_expression(adata, genes):
        """Per-sample mean over `genes`, indexed by sample_id, NaN samples removed."""
        values = adata[:, genes].X
        if hasattr(values, 'toarray'):  # sparse matrix
            values = values.toarray()
        values = np.asarray(values, dtype=float)
        means = values.reshape(values.shape[0], -1).mean(axis=1)
        return pd.Series(means, index=adata.obs['sample_id'].to_numpy()).dropna()

    for index, row in results.iterrows():
        g1 = row['g1']
        g2 = row['g2']

        # Ensure g1 and g2 are lists
        if not isinstance(g1, list):
            g1 = [g1]
        if not isinstance(g2, list):
            g2 = [g2]

        # Skip the row when any gene of either set is absent from its data object
        missing_genes_g1 = set(g1) - set(adata_1_pseudo.var_names)
        if missing_genes_g1:
            print(f"Genes {missing_genes_g1} not found in adata_1_pseudo")
            continue
        missing_genes_g2 = set(g2) - set(adata_2_pseudo.var_names)
        if missing_genes_g2:
            print(f"Genes {missing_genes_g2} not found in adata_2_pseudo")
            continue

        g1_mean = _mean_expression(adata_1_pseudo, g1)
        g2_mean = _mean_expression(adata_2_pseudo, g2)

        # Align the samples
        common_samples = g1_mean.index.intersection(g2_mean.index)
        if len(common_samples) < 2:
            continue
        g1_values = g1_mean.loc[common_samples]
        g2_values = g2_mean.loc[common_samples]

        results.loc[index, 'g1vsg2_correlation'] = corr_func(g1_values, g2_values)[0]

    return results


In [31]:
import numpy as np
import pandas as pd
import statsmodels.api as sm
from scipy.stats import pearsonr, spearmanr
import scipy.sparse

def calculate_residuals_GEM_correlation(results, adata_T_pseudo, adata_M_pseudo, method='pearson'):
    """Correlate g1 and g2 gene-set means after regressing out the L*R product.

    For every row of ``results``:
      1. average the g1 and L genes per sample in ``adata_T_pseudo`` and the g2 and
         R genes per sample in ``adata_M_pseudo``;
      2. align the four per-sample vectors on ``obs_names`` and drop samples with a
         NaN in any of them;
      3. regress the g1 mean and the g2 mean separately on L*R (OLS with intercept);
      4. store the correlation between the two residual vectors.

    The samples are aligned BEFORE L and R are multiplied. Multiplying first, by
    position, pairs wrong samples (or fails) whenever the two objects differ in
    row order or size.

    Parameters
    ----------
    results : pandas.DataFrame
        Must contain 'g1', 'g2', 'L', 'R'; each cell is a gene name or a list of
        gene names. The float column 'g1_residualvsg2_residuals_correlation' is
        added (or overwritten) in place.
    adata_T_pseudo, adata_M_pseudo : AnnData-like
        Objects whose ``obs_names`` identify the same samples with the same labels
        (labels are expected to be unique).
    method : {'pearson', 'spearman'}

    Returns
    -------
    pandas.DataFrame
        The same ``results`` object; rows that are skipped keep NaN.

    Raises
    ------
    ValueError
        If ``method`` is not 'pearson' or 'spearman'.
    """
    if method not in ['pearson', 'spearman']:
        raise ValueError("Method must be either 'pearson' or 'spearman'")
    corr_func = pearsonr if method == 'pearson' else spearmanr

    column = 'g1_residualvsg2_residuals_correlation'
    # Float/NaN initialisation: skipped rows must not look like a computed 0.
    results[column] = np.nan

    def _as_gene_list(value):
        """Wrap a single gene name in a list; lists pass through unchanged."""
        return value if isinstance(value, list) else [value]

    def _mean_by_sample(adata, genes):
        """Per-sample mean over `genes`, indexed by obs_names."""
        values = adata[:, genes].X
        if scipy.sparse.issparse(values):
            values = values.toarray()
        values = np.asarray(values, dtype=float)
        means = values.reshape(values.shape[0], -1).mean(axis=1)
        return pd.Series(means, index=adata.obs_names)

    for index, row in results.iterrows():
        g1 = _as_gene_list(row['g1'])
        g2 = _as_gene_list(row['g2'])
        L = _as_gene_list(row['L'])
        R = _as_gene_list(row['R'])

        # Skip the row when any required gene is absent from its data object
        missing_genes_g1 = set(g1) - set(adata_T_pseudo.var_names)
        missing_genes_g2 = set(g2) - set(adata_M_pseudo.var_names)
        missing_genes_L = set(L) - set(adata_T_pseudo.var_names)
        missing_genes_R = set(R) - set(adata_M_pseudo.var_names)

        if missing_genes_g1:
            print(f"Genes {missing_genes_g1} not found in adata_T_pseudo")
            continue
        if missing_genes_g2:
            print(f"Genes {missing_genes_g2} not found in adata_M_pseudo")
            continue
        if missing_genes_L:
            print(f"Genes {missing_genes_L} not found in adata_T_pseudo")
            continue
        if missing_genes_R:
            print(f"Genes {missing_genes_R} not found in adata_M_pseudo")
            continue

        # Inner join on the sample label, then drop samples with any NaN
        frame = pd.concat(
            {
                'g1_mean': _mean_by_sample(adata_T_pseudo, g1),
                'L_mean': _mean_by_sample(adata_T_pseudo, L),
                'g2_mean': _mean_by_sample(adata_M_pseudo, g2),
                'R_mean': _mean_by_sample(adata_M_pseudo, R),
            },
            axis=1,
            join='inner',
        ).dropna()

        if frame.empty:
            print(f"No common samples between g1 and g2 at index {index}")
            continue

        # Ligand x receptor product, now guaranteed to pair values of the same sample
        frame['L*R'] = frame['L_mean'] * frame['R_mean']
        X = sm.add_constant(frame['L*R'])

        # Residuals of each gene-set mean after removing the linear L*R effect
        residuals_g1 = sm.OLS(frame['g1_mean'], X).fit().resid
        residuals_g2 = sm.OLS(frame['g2_mean'], X).fit().resid

        results.at[index, column] = corr_func(residuals_g1, residuals_g2)[0]

    return results


In [41]:
# Remove the trailing '-0' / '-1' suffixes from the obs index of both pseudo-bulk
# objects so that both use the same labels. The pattern is anchored at the end of
# the label: an unanchored replace would also delete '-0' / '-1' occurring inside
# an id (e.g. 'P-10_On' -> 'P0_On').
adata_2_pseudo.obs.index = adata_2_pseudo.obs.index.str.replace(r'(?:-[01])+$', '', regex=True)
adata_1_pseudo.obs.index = adata_1_pseudo.obs.index.str.replace(r'(?:-[01])+$', '', regex=True)

In [None]:
# Add 'g1vsg2_correlation' to grouped_df: correlation between the per-sample means
# of the g1 and g2 gene sets.
calculate_GEM_correlation(grouped_df, adata_1_pseudo, adata_2_pseudo)
# Add 'g1_residualvsg2_residuals_correlation': the same correlation after
# regressing out the L*R product.
calculate_residuals_GEM_correlation(grouped_df, adata_1_pseudo, adata_2_pseudo)

In [43]:
# Save the grouped (L, R) table with its gene sets and correlation columns to Excel.
grouped_df.to_excel('/home/qiuaodon/Desktop/project_data_new/CIT_results_PD1vsEndo_grouped_0918.xlsx')

## draw heatmap

In [None]:
# Genes shown on the heatmap columns (second population).
top_genes_M = [
    'RNF144B',
    'SLC1A3',
    'TNFAIP3',
    'CH25H',
    'FKBP5',
    'CLN8',
    'GRASP',
    'SMIM3',
    'DUSP6',
    'HIST1H4H'
]
# Genes shown on the heatmap rows (first population). 'ZFP36' was listed twice,
# which produced a duplicated heatmap row; it now appears once.
top_genes_T = [
    'CH25H',
    'TSC22D3',
    'TNFAIP3',
    'BIRC3',
    'SPON2',
    'ZFP36',
    'SLA',
    'TXNIP',
    'IER2',
    'KLF10'
]

In [None]:
import seaborn as sns
import scipy.stats as stats
from scipy.stats import spearmanr 
# Restrict both pseudo-bulk objects to the samples they share. The filtered views
# are derived from adata_1_pseudo / adata_2_pseudo (the objects this notebook
# actually builds) instead of relying on adata_T_pseudo / adata_M_pseudo already
# existing in the kernel, and the source objects are left untouched.
# NOTE(review): the correlation below pairs rows by position, so both objects must
# list the shared samples in the same order (the harmonisation cell asserts this).
adata_M_pseudo = adata_2_pseudo[adata_2_pseudo.obs['sample_id'].isin(adata_1_pseudo.obs['sample_id'])]
adata_T_pseudo = adata_1_pseudo[adata_1_pseudo.obs['sample_id'].isin(adata_M_pseudo.obs['sample_id'])]


def _gene_vector(adata, gene):
    """Dense 1-D expression vector of `gene`; empty when the gene is absent."""
    values = adata[:, adata.var_names == gene].X
    if hasattr(values, 'toarray'):  # sparse matrix
        values = values.toarray()
    return np.asarray(values, dtype=float).ravel()


# NaN marks gene pairs that could not be evaluated (they render as empty cells)
corr_matrix = np.full((len(top_genes_T), len(top_genes_M)), np.nan)
pval_matrix = np.full((len(top_genes_T), len(top_genes_M)), np.nan)

# Spearman correlation between every row gene and every column gene
for i, gene_T in enumerate(top_genes_T):
    # remove the '_T' and '_M' from the gene names
    gene_T = gene_T.replace('_T', '')
    gene_T_expression = _gene_vector(adata_T_pseudo, gene_T)
    for j, gene_M in enumerate(top_genes_M):
        gene_M = gene_M.replace('_M', '')
        gene_M_expression = _gene_vector(adata_M_pseudo, gene_M)

        # Skip genes missing from the data or vectors that cannot be paired
        if gene_T_expression.size == 0 or gene_T_expression.size != gene_M_expression.size:
            continue

        corr, pval = spearmanr(gene_T_expression, gene_M_expression)

        corr_matrix[i, j] = corr
        pval_matrix[i, j] = -np.log10(pval)  # stored as -log10(p)


# Plot the correlation matrix as a heatmap
plt.figure(figsize=(10, 8))
sns.heatmap(corr_matrix, annot=False, cmap='coolwarm', vmin=-1, vmax=1,
            xticklabels=top_genes_M, yticklabels=top_genes_T)
plt.title('Correlation between Highly Correlated Genes of CD4EX and M')
plt.xlabel('M Genes')
plt.ylabel('CD4EX Genes')
plt.show()

In [None]:

# Heatmap of the correlation matrix; the cell text is drawn manually below so that
# each cell can show two numbers.
plt.figure(figsize=(12, 10))
heatmap = sns.heatmap(corr_matrix, cmap='coolwarm', vmin=-1, vmax=1,
                      xticklabels=top_genes_M, yticklabels=top_genes_T,
                      annot=False)  # Turn off default annotation

# Write "correlation" and "(-log10 p)" in the centre of every cell, always in black
for i, j in np.ndindex(corr_matrix.shape):
    cell_label = f'{corr_matrix[i, j]:.2f}\n({pval_matrix[i, j]:.2f})'
    plt.text(j + 0.5, i + 0.5, cell_label,
             horizontalalignment='center', verticalalignment='center',
             fontsize=11, color='black')

plt.title('Correlation between Top DEGs of CD4EX and M')
plt.xlabel('M Genes')
plt.ylabel('CD4EX Genes')
plt.show()

In [None]:
# Grey-to-blue colormap used by the UMAP cells below (same definition as earlier in
# the notebook).
colors = ["grey", "blue"]  # Start with grey and end with blue
cmap = mcolors.LinearSegmentedColormap.from_list("grey_to_blue", colors)


In [None]:
# NOTE(review): adata_CD4EX is not defined anywhere in this notebook, so this cell
# fails on a fresh kernel - presumably adata_1 is meant; confirm.
sc.pl.umap(adata_CD4EX, color = ['SPON2', 'SELL'], color_map = cmap)

In [None]:
# NOTE(review): adata_M is not defined anywhere in this notebook, so this cell fails
# on a fresh kernel - presumably adata_2 is meant; confirm.
sc.pl.umap(adata_M, color = ['FKBP5', 'MADCAM1'], color_map = cmap)

In [None]:
# UMAPs coloured by 'timepoint'.
# NOTE(review): adata_CD4EX and adata_M are not defined in this notebook - presumably
# adata_1 and adata_2 are meant; confirm.
sc.pl.umap(adata_CD4EX, color = 'timepoint')
sc.pl.umap(adata_M, color = 'timepoint')

#### plot the scatter plots


In [9]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import statsmodels.api as sm
from scipy.stats import spearmanr, pearsonr

def plot_CIT_DEGcorr(g1, g2, L, R, DEG_exp, pseudo_1, pseudo_2):
    """Plot g1 vs g2 expression, then the same relationship with L*R regressed out.

    Two seaborn joint plots are shown:
      1. ``DEG_exp[g1]`` against ``DEG_exp[g2]``, coloured by treatment, with the
         Pearson correlation in the title;
      2. the OLS residuals of g1 and g2 after regressing each on the product of
         ligand L (read from ``pseudo_1``) and receptor R (read from ``pseudo_2``).

    Parameters
    ----------
    g1, g2 : str
        Column names in ``DEG_exp``.
    L, R : str
        Ligand gene in ``pseudo_1`` and receptor gene in ``pseudo_2``.
    DEG_exp : pandas.DataFrame
        Indexed by sample id, with columns ``g1``, ``g2`` and 'treatment'
        (0/1 or already 'pre'/'on'). It is not modified.
    pseudo_1, pseudo_2 : AnnData-like
        Pseudo-bulk objects exposing ``obs['sample_id']`` and ``[:, gene].X``.
        Sample ids are expected to be unique within each object.

    Raises
    ------
    ValueError
        If fewer than three samples have g1, g2 and L*R all available.

    Notes
    -----
    The treatment recoding is done on a local copy; the function no longer writes
    to the global ``gene_df``. L*R is matched to ``DEG_exp`` rows by sample id, not
    by row position, so the inputs may differ in sample order or sample count.
    """
    # Colours per treatment label
    palette = {'pre': "#E69F00", 'on': "#56B4E9"}
    title_fontsize = 14

    # Recode 0/1 to 'pre'/'on' without touching the caller's DataFrame
    treatment = DEG_exp['treatment'].replace({0: 'pre', 1: 'on'})

    def adjust_plot_limits(ax, x_data, y_data):
        """Set axis limits to the data range padded by 10% on each side."""
        x_min, x_max = x_data.min(), x_data.max()
        y_min, y_max = y_data.min(), y_data.max()
        x_padding = (x_max - x_min) * 0.1
        y_padding = (y_max - y_min) * 0.1
        ax.set_xlim(x_min - x_padding, x_max + x_padding)
        ax.set_ylim(y_min - y_padding, y_max + y_padding)

    def _hide_tick_marks(grid):
        """Remove tick marks from the joint and marginal axes."""
        grid.ax_joint.tick_params(left=False, bottom=False)
        grid.ax_marg_x.tick_params(bottom=False)
        grid.ax_marg_y.tick_params(left=False)

    def _expression_by_sample(adata, gene):
        """Expression of `gene` as a Series indexed by sample id (dense or sparse X)."""
        values = adata[:, gene].X
        if hasattr(values, 'toarray'):
            values = values.toarray()
        return pd.Series(np.asarray(values, dtype=float).ravel(),
                         index=adata.obs['sample_id'].to_numpy())

    # ---- Plot 1: raw g1 vs g2 -------------------------------------------------
    df1 = pd.DataFrame({g1: DEG_exp[g1], g2: DEG_exp[g2], 'treatment': treatment})
    df1 = df1.dropna()
    g = sns.jointplot(x=g1, y=g2, data=df1, hue='treatment', palette=palette, kind='scatter', marginal_kws=dict(fill=True))
    sns.regplot(x=g1, y=g2, data=df1, scatter=False, ax=g.ax_joint, color='black')
    adjust_plot_limits(g.ax_joint, df1[g1], df1[g2])
    _hide_tick_marks(g)
    cor = pearsonr(df1[g1], df1[g2])[0]
    plt.suptitle(f'Correlation between {g1} and {g2} is {cor:.3f}', fontsize=title_fontsize)
    plt.subplots_adjust(top=0.95)
    plt.show()

    # ---- Plot 2: residuals after removing the L*R effect -----------------------
    # Series multiplication aligns on sample id; the join then keeps only samples
    # that also have g1 / g2 values, and dropna removes samples without L*R.
    lr_product = (_expression_by_sample(pseudo_1, L) * _expression_by_sample(pseudo_2, R)).rename('L*R')
    fit_df = df1.join(lr_product, how='inner').dropna()
    if len(fit_df) < 3:
        raise ValueError(
            f"Fewer than three samples have {g1}, {g2} and {L}*{R} available; "
            "cannot fit the residual model"
        )

    X = sm.add_constant(fit_df['L*R'])
    residuals_g1 = sm.OLS(fit_df[g1], X).fit().resid
    residuals_g2 = sm.OLS(fit_df[g2], X).fit().resid

    df_res = pd.DataFrame({
        'G_1_residual': residuals_g1,
        'G_2_residual': residuals_g2,
        'treatment': fit_df['treatment']
    })

    g = sns.jointplot(x='G_1_residual', y='G_2_residual', data=df_res, hue='treatment', palette=palette, kind='scatter', marginal_kws=dict(fill=True))
    sns.regplot(x='G_1_residual', y='G_2_residual', data=df_res, scatter=False, ax=g.ax_joint, color='black')
    adjust_plot_limits(g.ax_joint, df_res['G_1_residual'], df_res['G_2_residual'])
    _hide_tick_marks(g)

    corr_res = pearsonr(df_res['G_1_residual'], df_res['G_2_residual'])[0]
    plt.suptitle(f'Correlation between residuals of G_1 and G_2 is {corr_res:.3f}', fontsize=title_fontsize)
    plt.subplots_adjust(top=0.95)
    plt.show()



In [None]:
# Scatter / residual plots for the third gene-set pair with ligand ITGA4 and receptor TFRC.
# NOTE(review): no cell in this notebook adds 'GEM_T_3' / 'GEM_M_3' columns to gene_df,
# so they presumably come from another session - confirm before running.
plot_CIT_DEGcorr('GEM_T_3', 'GEM_M_3', 'ITGA4', 'TFRC', gene_df, adata_1_pseudo, adata_2_pseudo)