In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import scanpy as sc
import scipy.io as sio
import anndata as ad
import seaborn as sns
import os as os
import sys as sys
sys.path.append('/home/qiuaodon/Desktop/PanCancer_scRNA_analysis/utils/')
from scRNA_utils import *
import operator as op
import matplotlib.colors as mcolors

# Load the single-cell data

In [2]:
# Directory that holds the cohort-1 .h5ad files and the derived tables.
data_dir_NHDP = "/home/qiuaodon/Desktop/project_data_new/"

# Read each file from disk once. adata_T / adata_Endo stay as loaded, while
# adata_1 / adata_2 are independent copies whose .X is overwritten further
# down with the projected expression values.
adata_T = sc.read(data_dir_NHDP + '1863-counts_cells_cohort1_T_cells.h5ad')
adata_1 = adata_T.copy()
adata_Endo = sc.read(data_dir_NHDP + '1863-counts_cells_cohort1_Endo_cells.h5ad')
adata_2 = adata_Endo.copy()

In [8]:
# Projected expression tables. They are transposed after loading so that their
# shape matches the .X they are written into below.
projected_data_T = pd.read_csv(data_dir_NHDP + 'T_data_projected.csv', index_col=0).T
projected_data_M = pd.read_csv(data_dir_NHDP + 'M_data_projected.csv', index_col=0).T

# `.values` drops the row/column labels, so the tables must already be in the
# same cell and gene order as the AnnData objects.
# NOTE(review): M_data_projected.csv is written into adata_2, which was loaded
# from the Endo-cell file -- confirm this is the intended table.
adata_1.X = projected_data_T.values
adata_2.X = projected_data_M.values

In [None]:
projected_data_M

In [13]:
# Overwrite .X of both working copies with the projected expression values.
# This repeats the assignment already made in the cell that loads the CSVs.
adata_1.X = projected_data_T.values
adata_2.X = projected_data_M.values

In [None]:
# ABCA1 on the UMAP, once for the object as loaded from disk and once for the
# copy whose .X holds the projected values.
# `adata_M` is never defined in this notebook; `adata_Endo` is the unmodified
# counterpart of `adata_2`, so it is used for the first panel.
sc.pl.umap(adata_Endo, color=['ABCA1'], cmap='Blues')
sc.pl.umap(adata_2, color=['ABCA1'], cmap='Blues')

In [None]:
# CD274 expression next to the cell-type annotation.
# `adata_M` is never defined in this notebook; `adata_Endo` (the object as
# loaded from disk) is used instead.
sc.pl.umap(adata_Endo, color=['CD274', 'cell_type'], cmap='Blues')

In [None]:
# Leiden clusters on the UMAP of each working copy, labels drawn on the embedding.
for _adata in (adata_1, adata_2):
    sc.pl.umap(_adata, color='leiden', legend_loc='on data')

In [None]:
# Collapse each single-cell object into one pseudo-bulk profile per sample_id.
adata_1_pseudo, adata_2_pseudo = (
    scRNA2PseudoBulkAnnData(_cells, sample_id_col='sample_id')
    for _cells in (adata_1, adata_2)
)

In [None]:
# Per-cell distribution of CXCR4.
# `adata_M` is never defined in this notebook; `adata_Endo` (the object as
# loaded from disk) is used instead.
sc.pl.violin(adata_Endo, 'CXCR4', size=3)

In [None]:

gene_of_interest = 'CXCR4'
# projected_data_M was transposed after loading and its values are assigned to
# AnnData.X, so genes are its columns. `.loc[gene]` would search the row labels
# and raise KeyError; select the column instead.
gene_data = projected_data_M[gene_of_interest]

# Long-format frame for seaborn: one row per entry of the selected column.
gene_data_df = pd.DataFrame({
    'Sample': gene_data.index,
    'Expression': gene_data.values
})

In [None]:
gene_data_df

In [None]:

# Violin of the expression values with the individual points drawn on top.
plt.figure(figsize=(10, 6))
violin_ax = sns.violinplot(data=gene_data_df, y='Expression', inner=None)
sns.stripplot(data=gene_data_df, y='Expression', size=2, jitter=True, ax=violin_ax)
violin_ax.set_title('Violin plot')
violin_ax.set_xlabel('Samples')
violin_ax.set_ylabel('Expression level')
plt.xticks(rotation=90)  # keep x tick labels readable if there are many
plt.show()

In [None]:
# Show the pseudo-bulk expression values of A2M in adata_2_pseudo as a list.
# NOTE(review): the variable is still named PRDM1 although the gene that is
# extracted is 'A2M' -- confirm which gene is intended.
PRDM1 = adata_2_pseudo[:, 'A2M'].X.tolist()
PRDM1

# Prepare the data matrix for the IV test

In [None]:
# Paired test between timepoints, run with identical settings on both objects.
_ttest_kwargs = dict(condition_key='timepoint', sample_id_col='sample_id', patient_id_col='patient_id')
DEG_1 = paird_ttest(adata_1, **_ttest_kwargs)
DEG_2 = paird_ttest(adata_2, **_ttest_kwargs)

In [10]:
# Keep the genes with p < 0.05 in each paired test.
DEG_1 = DEG_1.loc[DEG_1['pval'] < 0.05]
DEG_2 = DEG_2.loc[DEG_2['pval'] < 0.05]

# Of those, keep only the genes that exist in the pseudo-bulk objects.
gene_1 = [name for name in DEG_1.index.tolist() if name in adata_1_pseudo.var_names]
gene_2 = [name for name in DEG_2.index.tolist() if name in adata_2_pseudo.var_names]

# Pseudo-bulk expression of the selected genes, one row per sample_id.
# Columns get a suffix so the two gene sets stay distinguishable after merging.
gene_1_matrix = adata_1_pseudo[:, gene_1].X
gene_2_matrix = adata_2_pseudo[:, gene_2].X
gene_1_df = pd.DataFrame(
    gene_1_matrix,
    index=adata_1_pseudo.obs['sample_id'],
    columns=[name + '_T' for name in gene_1],
)
gene_2_df = pd.DataFrame(
    gene_2_matrix,
    index=adata_2_pseudo.obs['sample_id'],
    columns=[name + '_Endo' for name in gene_2],
)

# Combine both tables on sample_id.
gene_df = gene_1_df.merge(gene_2_df, on='sample_id')

In [None]:
# Derive both annotation columns from the sample labels held in the index.
_sample_labels = gene_df.index
# 1 when the label contains 'On', 0 otherwise.
gene_df['treatment'] = _sample_labels.str.contains('On').astype(int)
# Patient id = sample label with the '_On' / '_Pre' part removed.
gene_df['patient_id'] = _sample_labels.str.replace('_On', '').str.replace('_Pre', '')
gene_df['treatment'].value_counts()

In [12]:
# Write gene_df (expression columns plus the treatment / patient_id columns
# added above) to a CSV file inside data_dir_NHDP.
gene_df.to_csv(data_dir_NHDP + 'gene_df_T_Endo.csv')
# Commented-out alternative output file:
# gene_df.to_csv(data_dir_NHDP + 'gene_df_CD4EX_Endo_filtered.csv')

# Load the IV results and perform the CIT

In [27]:
# Load the IV regression results, discard the p_value / r_squared columns and
# rename the two gene columns to g1 and g2.
# (Commented-out alternative input: IV_regression_results_CD4EXvsEndo_filtered.xlsx)
DEG_pairs = (
    pd.read_excel('/home/qiuaodon/Desktop/project_data_new/IV_regression_results_TvsEndo.xlsx')
    .drop(columns=['p_value', 'r_squared'])
    .rename(columns={'gene_T': 'g1', 'gene_cell': 'g2'})
)

In [None]:
DEG_pairs

In [28]:
# A sample can be present in only one of the two pseudo-bulk objects. Add the
# absent ones to each object so that no sample is filtered out, then put both
# objects in the same sample order.
from scRNA_utils import add_missing_samples
import anndata

# Sample ids currently present in each object
sample_ids_1 = set(adata_1_pseudo.obs['sample_id'])
sample_ids_2 = set(adata_2_pseudo.obs['sample_id'])

# Ids that each object lacks relative to the other
missing_sample_ids_1 = sample_ids_2.difference(sample_ids_1)
missing_sample_ids_2 = sample_ids_1.difference(sample_ids_2)

adata_1_pseudo = add_missing_samples(adata_1_pseudo, missing_sample_ids_1)
adata_2_pseudo = add_missing_samples(adata_2_pseudo, missing_sample_ids_2)

# Reorder the observations of both objects by sample_id
_order_1 = adata_1_pseudo.obs.sort_values('sample_id').index
_order_2 = adata_2_pseudo.obs.sort_values('sample_id').index
adata_1_pseudo = adata_1_pseudo[_order_1]
adata_2_pseudo = adata_2_pseudo[_order_2]

# Both objects must now list the same sample_ids in the same order
assert (adata_1_pseudo.obs['sample_id'].values == adata_2_pseudo.obs['sample_id'].values).all()

In [None]:

from scRNA_utils import calculate_gene_correlation, calculate_geneLR_correlation, calculate_residuals_correlation

# Correlation between g1 and g2 for every row of DEG_pairs.
# NOTE(review): the return value is not assigned; presumably the helper adds
# the 'g1vsg2_correlation' column to DEG_pairs in place (the next cell filters
# on that column) -- confirm.
calculate_gene_correlation(DEG_pairs, adata_1_pseudo, adata_2_pseudo)

In [None]:
# Keep the pairs whose correlation exceeds 0.4 in absolute value, strongest
# positive correlation first.
_is_strong = DEG_pairs['g1vsg2_correlation'].abs() > 0.4
DEG_pairs = DEG_pairs.loc[_is_strong].sort_values('g1vsg2_correlation', ascending=False)
DEG_pairs

In [None]:
# Keep only the pairs whose two genes are columns of gene_df and have fewer
# than 20 zero entries there.
g1 = DEG_pairs['g1'].tolist()
g2 = DEG_pairs['g2'].tolist()
g1_in_CD4EX = [name for name in g1 if name in gene_df.columns and (gene_df[name] == 0).sum() < 20]
g2_in_Endo = [name for name in g2 if name in gene_df.columns and (gene_df[name] == 0).sum() < 20]
_keep_pair = DEG_pairs['g1'].isin(g1_in_CD4EX) & DEG_pairs['g2'].isin(g2_in_Endo)
DEG_pairs = DEG_pairs[_keep_pair]
DEG_pairs


In [None]:
# Ligand-receptor table: keep the ligand ('from') and receptor ('to') columns
# and rename them to L and R. No filtering by database is applied here.
lrpair = (
    pd.read_csv('/home/qiuaodon/Desktop/project_data_new/lr_network_unique.tsv', sep='\t')
    .loc[:, ['from', 'to']]
    .rename(columns={'from': 'L', 'to': 'R'})
)
lrpair

In [34]:
# Make sure the CXCL13-ACKR1 pair is part of the ligand-receptor table.
# The pair is appended only when it is absent, so re-running this cell (or a
# table that already lists the pair) does not create duplicate rows.
new_pair = pd.DataFrame({'L': ['CXCL13'], 'R': ['ACKR1']})
_pair_present = ((lrpair['L'] == 'CXCL13') & (lrpair['R'] == 'ACKR1')).any()
if not _pair_present:
    lrpair = pd.concat([lrpair, new_pair], ignore_index=True)

In [None]:
# Ligands must be genes of adata_1 and receptors genes of adata_2; in addition
# each must be zero in fewer than 75% of the pseudo-bulk samples.
L = lrpair['L'].tolist()
R = lrpair['R'].tolist()
_max_zeros_1 = 0.75 * adata_1_pseudo.shape[0]
_max_zeros_2 = 0.75 * adata_2_pseudo.shape[0]
L_in_1 = [
    gene for gene in set(L).intersection(adata_1.var_names)
    if np.sum(adata_1_pseudo[:, gene].X == 0) < _max_zeros_1
]
R_in_2 = [
    gene for gene in set(R).intersection(adata_2.var_names)
    if np.sum(adata_2_pseudo[:, gene].X == 0) < _max_zeros_2
]
print('L_in_CD4EX:', len(L_in_1))
print('R_in_2:', len(R_in_2))

In [None]:
# Keep the pairs whose ligand passed the L_in_1 filter and whose receptor
# passed the R_in_2 filter.
_pair_ok = lrpair['L'].isin(L_in_1) & lrpair['R'].isin(R_in_2)
lrpair = lrpair[_pair_ok]
print('lrpair_kegg:', lrpair.shape)

## CIT using fisher-z

In [None]:
lrpair

In [None]:
DEG_pairs

In [39]:
# Run CIT_test (presumably provided by the star import from scRNA_utils) on the
# filtered DEG pairs and ligand-receptor pairs, using method='fisherz' and a
# p-value threshold of 0.05.
results = CIT_test(DEG_pairs, lrpair, adata_1_pseudo, adata_2_pseudo, gene_df, method = 'fisherz', p_value_threshold=0.05)

In [None]:
# Correlation between g1 and g2 for every row of results.
# NOTE(review): neither return value is assigned; presumably both helpers add
# their columns to `results` in place (the next cell sorts by
# 'g1vsg2_correlation') -- confirm.
calculate_gene_correlation(results, adata_1_pseudo, adata_2_pseudo)
# Correlation of the residuals of g1 and g2 with respect to the LR pair.
calculate_residuals_correlation(results, adata_1_pseudo, adata_2_pseudo)

In [None]:
# Order the CIT results by the g1-vs-g2 correlation, largest first, and show them.
results = results.sort_values(by='g1vsg2_correlation', ascending=False)
results

In [None]:
# Scatter plots of the two genes before and after regressing out ITGB1 x ITGA11.
# NOTE(review): where gene_df is built in this notebook its columns carry the
# suffixes '_T' and '_Endo'; 'ITGA11_M' would not be among them -- confirm the
# column name.
plot_CIT_DEGcorr('COL10A1_T', 'ITGA11_M', 'ITGB1', 'ITGA11', gene_df, adata_1_pseudo, adata_2_pseudo)

In [None]:
# PRDM1 and TXNIP on the UMAP of adata_1.
sc.pl.umap(adata_1, color = ['PRDM1','TXNIP'], cmap = 'Blues')

In [None]:
# SESN1 and CXCR4 on the UMAP of adata_2.
sc.pl.umap(adata_2, color = ['SESN1','CXCR4'], cmap = 'Blues')

In [None]:
results["R"].unique()

In [41]:
# Save the Fisher-z CIT results. The DataFrame index is written as the first
# column of the sheet (to_excel default).
results.to_excel('/home/qiuaodon/Desktop/project_data_new/CIT_results_TvsEndo_fisherz_corrmorethan0.4.xlsx')

In [3]:
# Reload the saved Fisher-z CIT results. The file was written with the
# DataFrame index as its first column; index_col=0 restores it as the index
# instead of leaving a stray 'Unnamed: 0' data column.
results = pd.read_excel(
    '/home/qiuaodon/Desktop/project_data_new/CIT_results_TvsEndo_fisherz_corrmorethan0.4.xlsx',
    index_col=0,
)

In [None]:
# Scatter plots of the two genes before and after regressing out CXCL13 x P2RY14.
# NOTE(review): where gene_df is built in this notebook its columns carry the
# suffixes '_T' and '_Endo'; 'SOCS1_M' would not be among them -- confirm the
# column name.
plot_CIT_DEGcorr('BCL2L11_T', 'SOCS1_M', 'CXCL13', 'P2RY14', gene_df, adata_1_pseudo, adata_2_pseudo)

## CIT using KCI

In [None]:
DEG_pairs

In [None]:
lrpair

In [None]:
# Same CIT_test call as for the Fisher-z section, but with method='kci'.
# This overwrites the `results` variable.
results = CIT_test(DEG_pairs, lrpair, adata_1_pseudo, adata_2_pseudo, gene_df, method = 'kci', p_value_threshold=0.05)

In [None]:
# Correlation between g1 and g2 for every row of the KCI results.
# NOTE(review): neither return value is assigned; presumably both helpers
# modify `results` in place -- confirm.
calculate_gene_correlation(results, adata_1_pseudo, adata_2_pseudo)
# Correlation of the residuals of g1 and g2 with respect to the LR pair.
calculate_residuals_correlation(results, adata_1_pseudo, adata_2_pseudo)

In [19]:
# Save the current `results` to Excel.
# NOTE(review): the file name says 'TvsMyeloid' while the objects loaded in
# this notebook are the T-cell and Endo-cell files -- confirm the target name.
results.to_excel('/home/qiuaodon/Desktop/project_data_new/CIT_results_TvsMyeloid_KCI.xlsx')

In [None]:
from causallearn.utils.cit import CIT
import pandas as pd

# Suffix carried by the g2 columns of gene_df. gene_df is built in this
# notebook with '_T' and '_Endo' suffixes; the previous '_M' check could never
# match a gene_df column, so no pair was ever tested.
G2_SUFFIX = '_Endo'
RESULT_COLUMNS = ['g1', 'g2', 'L1', 'R1', 'L2', 'R2', 'pValue']


def _lr_product(ligand, receptor):
    """Per-sample product ligand (adata_1_pseudo) x receptor (adata_2_pseudo).

    Returns a Series indexed by sample_id; samples with a missing value in
    either gene are dropped before multiplying.
    """
    shared_name = receptor + '_vs_' + ligand
    ligand_df = pd.DataFrame(
        adata_1_pseudo[:, ligand].X, columns=[shared_name], index=adata_1_pseudo.obs['sample_id']
    ).dropna()
    receptor_df = pd.DataFrame(
        adata_2_pseudo[:, receptor].X, columns=[shared_name], index=adata_2_pseudo.obs['sample_id']
    ).dropna()
    return ligand_df.multiply(receptor_df, axis=0).iloc[:, 0]


# Ligand-receptor rows whose genes exist in the pseudo-bulk objects. Rows are
# kept in table order (duplicates included) so the output order is unchanged.
valid_pairs = [
    (lig, rec) for lig, rec in lrpair.values
    if lig in adata_1_pseudo.var_names and rec in adata_2_pseudo.var_names
]

# The L x R product depends only on the pair, so compute it once per pair
# instead of once per (g1, g2, pair, pair) combination.
lr_products = {}
for lig, rec in valid_pairs:
    if (lig, rec) not in lr_products:
        lr_products[(lig, rec)] = _lr_product(lig, rec)

# Collect the hits in a list and build the DataFrame once at the end.
records = []
for index, row in DEG_pairs.iterrows():
    g1 = row['g1']
    g2 = row['g2']
    if not g2.endswith(G2_SUFFIX):
        continue
    if g1 not in gene_df.columns or g2 not in gene_df.columns:
        continue

    g1_T = gene_df[g1]
    g2_M = gene_df[g2]
    for L1, R1 in valid_pairs:
        for L2, R2 in valid_pairs:
            if L1 == L2 and R1 == R2:
                continue  # Skip if the pairs are the same

            # Columns: 0 = g1, 1 = g2, 2 = L1*R1, 3 = L2*R2 (aligned on sample id)
            df = pd.DataFrame({
                g1: g1_T,
                g2: g2_M,
                'L1*R1': lr_products[(L1, R1)],
                'L2*R2': lr_products[(L2, R2)],
            }).dropna()
            if df.empty:
                continue  # no sample has all four values; nothing to test

            # Test g1 vs g2 conditioned on both LR products.
            fisherz_obj = CIT(df.to_numpy(), "fisherz")
            pValue = fisherz_obj(0, 1, [2, 3])
            # Keep the combinations with pValue above 0.01.
            if pValue > 0.01:
                records.append({
                    'g1': g1, 'g2': g2,
                    'L1': L1, 'R1': R1,
                    'L2': L2, 'R2': R2,
                    'pValue': pValue,
                })

results = pd.DataFrame(records, columns=RESULT_COLUMNS)
print(results)


In [114]:
# Save the current `results` to Excel.
# NOTE(review): the file name says 'TvsMyeloid' and names PRDM1/TSC22D3 while
# the objects loaded in this notebook are the T-cell and Endo-cell files --
# confirm the target name.
results.to_excel('/home/qiuaodon/Desktop/project_data_new/CIT_results_TvsMyeloid_fisherz_PRDM1_TSC22D3.xlsx')

## Group the DEG pairs by ligand (L) and receptor (R)

In [5]:
# results = pd.read_excel('/home/qiuaodon/Desktop/project_data_new/CIT_results_CD4EXvsDC_fisherz_nosamplefiltered.xlsx')
# results = results.drop(columns=['Unnamed: 0'])
# Remove the cell-type suffix (the text after the LAST underscore) from g1 and g2.
# The pattern is anchored to the end of the name so that a gene symbol which
# itself contains an underscore is not truncated at its first underscore.
# Run this cell only once per `results`: a second run would strip another segment.
results['g1'] = results['g1'].str.replace(r'_[^_]+$', '', regex=True)
results['g2'] = results['g2'].str.replace(r'_[^_]+$', '', regex=True)

In [None]:
results

In [None]:
results['R'].unique()

In [None]:
results['R'].nunique()

In [None]:
# One group of result rows per receptor
grouped_results = results.groupby("R")
# Receptor -> sub-DataFrame lookup
grouped_dict = dict(tuple(grouped_results))


def _distinct(column):
    """Values of a column in order of first appearance, without repeats."""
    return column.drop_duplicates().tolist()


# For each receptor collect the distinct g1 genes, g2 genes and ligands
grouped_data = [
    {
        "Receptor": receptor,
        "GEM_g1": _distinct(group["g1"]),
        "GEM_g2": _distinct(group["g2"]),
        "Ligand": _distinct(group["L"]),
    }
    for receptor, group in grouped_results
]

# One row per receptor
grouped_df = pd.DataFrame(grouped_data)
grouped_df

In [10]:
# Save the per-receptor grouping to Excel.
grouped_df.to_excel('/home/qiuaodon/Desktop/project_data_new/TvsEndo_grouped.xlsx')

In [2]:
# Read a grouped table back from Excel.
# NOTE(review): this is a different file (CD4EXvsMyeloid, "virsion3") from the
# TvsEndo_grouped.xlsx written in the previous cell -- confirm which file is
# intended.
grouped_df = pd.read_excel('/home/qiuaodon/Desktop/project_data_new/CIT_results_CD4EXvsMyeloid_grouped_virsion3.xlsx')

In [3]:
# Drop every '_<word characters>' run from the GEM_g1 and GEM_g2 strings.
for _col in ('GEM_g1', 'GEM_g2'):
    grouped_df[_col] = grouped_df[_col].str.replace(r'_\w+', '', regex=True)

In [14]:
# Trim leading/trailing '[' and ']' characters from the GEM_g1 and GEM_g2 strings.
for _col in ('GEM_g1', 'GEM_g2'):
    grouped_df[_col] = grouped_df[_col].str.strip('[]')


In [None]:
grouped_df

In [None]:
# UMAP of adata_1 coloured by each gene of the panel (default colour map).
sc.pl.umap(adata_1, color =['SESN1', 'NFKBIA', 'PMAIP1', 'IL6ST', 'TXNIP', 'PELI1', 'PRDM1', 'CRYBG1', 'TSC22D3', 'TP53INP1'])

## draw heatmap

In [None]:
# Genes shown on the x axis of the correlation heatmap.
top_genes_M = [
    'RNF144B',
    'SLC1A3',
    'TNFAIP3',
    'CH25H',
    'FKBP5',
    'CLN8',
    'GRASP',
    'SMIM3',
    'DUSP6',
    'HIST1H4H'
]
# Genes shown on the y axis of the correlation heatmap.
# 'ZFP36' was listed twice, which produced a duplicated heatmap row; each gene
# now appears once.
top_genes_T = [
    'CH25H',
    'TSC22D3',
    'TNFAIP3',
    'BIRC3',
    'SPON2',
    'ZFP36',
    'SLA',
    'TXNIP',
    'IER2',
    'KLF10'
]

In [None]:
import seaborn as sns
import scipy.stats as stats
from scipy.stats import spearmanr 
# Restrict both pseudo-bulk objects to the samples they share and put them in
# the same sample order, so that entry k of every expression vector refers to
# the same sample.
# `adata_T_pseudo` / `adata_M_pseudo` are never defined in this notebook; the
# pseudo-bulk objects built here are adata_1_pseudo / adata_2_pseudo. Local
# names are used so that those notebook-level objects are not overwritten.
# NOTE(review): plot titles and labels below still say 'CD4EX' / 'M' -- confirm.
heatmap_pseudo_M = adata_2_pseudo[adata_2_pseudo.obs['sample_id'].isin(adata_1_pseudo.obs['sample_id'])]
heatmap_pseudo_T = adata_1_pseudo[adata_1_pseudo.obs['sample_id'].isin(heatmap_pseudo_M.obs['sample_id'])]
heatmap_pseudo_M = heatmap_pseudo_M[heatmap_pseudo_M.obs.sort_values('sample_id').index]
heatmap_pseudo_T = heatmap_pseudo_T[heatmap_pseudo_T.obs.sort_values('sample_id').index]

# Matrices for the correlation coefficients and the -log10 p-values
corr_matrix = np.zeros((len(top_genes_T), len(top_genes_M)))
pval_matrix = np.zeros((len(top_genes_T), len(top_genes_M)))

# Spearman correlation for every (row gene, column gene) combination
for i, gene_T in enumerate(top_genes_T):
    # remove a '_T' marker from the gene name, then fetch its expression once per row
    gene_T = gene_T.replace('_T', '')
    gene_T_expression = heatmap_pseudo_T[:, heatmap_pseudo_T.var_names == gene_T].X.flatten()
    for j, gene_M in enumerate(top_genes_M):
        # remove a '_M' marker from the gene name
        gene_M = gene_M.replace('_M', '')
        gene_M_expression = heatmap_pseudo_M[:, heatmap_pseudo_M.var_names == gene_M].X.flatten()

        corr, pval = spearmanr(gene_T_expression, gene_M_expression)

        corr_matrix[i, j] = corr
        pval_matrix[i, j] = -np.log10(pval)


# Plot the correlation matrix as a heatmap
plt.figure(figsize=(10, 8))
sns.heatmap(corr_matrix, annot=False, cmap='coolwarm', vmin=-1, vmax=1,
            xticklabels=top_genes_M, yticklabels=top_genes_T)
plt.title('Correlation between Highly Correlated Genes of CD4EX and M')
plt.xlabel('M Genes')
plt.ylabel('CD4EX Genes')
plt.show()

In [None]:

# Correlation heatmap without seaborn's own annotations; the cell text is
# written manually below.
plt.figure(figsize=(12, 10))
heatmap = sns.heatmap(corr_matrix, annot=False,
                      xticklabels=top_genes_M, yticklabels=top_genes_T,
                      cmap='coolwarm', vmin=-1, vmax=1)

# In the centre of every cell write the correlation and, in parentheses,
# the -log10 p-value. The text colour is black for every cell.
for (row, col), corr_value in np.ndenumerate(corr_matrix):
    cell_label = f'{corr_value:.2f}\n({pval_matrix[row, col]:.2f})'
    plt.text(col + 0.5, row + 0.5, cell_label,
             horizontalalignment='center', verticalalignment='center',
             fontsize=11, color='black')

plt.title('Correlation between Top DEGs of CD4EX and M')
plt.xlabel('M Genes')
plt.ylabel('CD4EX Genes')
plt.show()

In [None]:
# Two-colour linear colormap named "grey_to_blue", running from grey to blue;
# passed as color_map to the UMAP plots below.
colors = ["grey", "blue"]  # Start with grey and end with blue
cmap = mcolors.LinearSegmentedColormap.from_list("grey_to_blue", colors)


In [None]:
# SPON2 and SELL on the T-cell UMAP with the grey-to-blue colormap.
# `adata_CD4EX` is never defined in this notebook; `adata_T` (the T-cell object
# as loaded from disk) is used instead.
sc.pl.umap(adata_T, color=['SPON2', 'SELL'], color_map=cmap)

In [None]:
# FKBP5 and MADCAM1 on the Endo-cell UMAP with the grey-to-blue colormap.
# `adata_M` is never defined in this notebook; `adata_Endo` (the Endo-cell
# object as loaded from disk) is used instead.
sc.pl.umap(adata_Endo, color=['FKBP5', 'MADCAM1'], color_map=cmap)

In [None]:
# Timepoint annotation on both UMAPs.
# `adata_CD4EX` and `adata_M` are never defined in this notebook; the objects
# as loaded from disk (`adata_T`, `adata_Endo`) are used instead.
sc.pl.umap(adata_T, color='timepoint')
sc.pl.umap(adata_Endo, color='timepoint')

#### plot the scatter plots


In [55]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import statsmodels.api as sm
from scipy.stats import spearmanr, pearsonr

def _dense_1d(matrix):
    """Return ``matrix`` (dense array or scipy-sparse matrix) as a flat 1-D numpy array."""
    if hasattr(matrix, 'toarray'):
        matrix = matrix.toarray()
    return np.asarray(matrix).ravel()


def plot_CIT_DEGcorr(g1, g2, L, R, DEG_exp, pseudo_1, pseudo_2):
    """Plot g1 vs g2 before and after regressing out a ligand-receptor product.

    Two seaborn joint plots are shown:

    1. ``g1`` against ``g2`` with a regression line; the title reports their
       Pearson correlation.
    2. The OLS residuals of ``g1`` and ``g2`` after each is regressed on the
       per-sample product ``L * R``; the title reports the Pearson correlation
       of the residuals.

    Parameters
    ----------
    g1, g2 : str
        Column names in ``DEG_exp``.
    L, R : str
        Ligand gene in ``pseudo_1`` and receptor gene in ``pseudo_2``.
    DEG_exp : pandas.DataFrame
        One row per sample with columns ``g1``, ``g2`` and ``'treatment'``
        (0/1 or ``'pre'``/``'on'``). It is not modified.
    pseudo_1, pseudo_2 : AnnData
        Pseudo-bulk objects with ``obs['sample_id']``; both must list the same
        samples in the same order. ``.X`` may be dense or sparse.

    Raises
    ------
    ValueError
        If no sample has complete values for ``g1``, ``g2``, treatment and
        ``L * R``.

    Notes
    -----
    The L*R values are matched to the rows of ``DEG_exp`` by sample id, so the
    two inputs may differ in sample order or contain different sample sets.
    When ``DEG_exp``'s index shares no label with the sample ids but has the
    same length, rows are paired by position.
    """
    palette = {'pre': "#E69F00", 'on': "#56B4E9"}
    title_fontsize = 14

    # Relabel treatment (0 -> 'pre', 1 -> 'on') on a local Series so that the
    # caller's DataFrame is left unchanged.
    treatment = DEG_exp['treatment'].replace({0: 'pre', 1: 'on'})

    def adjust_plot_limits(ax, x_data, y_data):
        # Pad the data range by 10% on every side.
        x_min, x_max = x_data.min(), x_data.max()
        y_min, y_max = y_data.min(), y_data.max()
        x_padding = (x_max - x_min) * 0.1
        y_padding = (y_max - y_min) * 0.1
        ax.set_xlim(x_min - x_padding, x_max + x_padding)
        ax.set_ylim(y_min - y_padding, y_max + y_padding)

    def style_ticks(grid):
        # Hide the tick marks on the joint and marginal axes.
        grid.ax_joint.tick_params(left=False, bottom=False)
        grid.ax_marg_x.tick_params(bottom=False)
        grid.ax_marg_y.tick_params(left=False)

    # --- Plot 1: g1 vs g2 -------------------------------------------------
    g1_exp = DEG_exp[g1]
    g2_exp = DEG_exp[g2]
    df1 = pd.DataFrame({g1: g1_exp, g2: g2_exp, 'treatment': treatment}).dropna()
    g = sns.jointplot(x=g1, y=g2, data=df1, hue='treatment', palette=palette, kind='scatter', marginal_kws=dict(fill=True))
    sns.regplot(x=g1, y=g2, data=df1, scatter=False, ax=g.ax_joint, color='black')
    adjust_plot_limits(g.ax_joint, df1[g1], df1[g2])
    style_ticks(g)
    cor = pearsonr(df1[g1], df1[g2])[0]
    plt.suptitle(f'Correlation between {g1} and {g2} is {cor:.3f}', fontsize=title_fontsize)
    plt.subplots_adjust(top=0.95)
    plt.show()

    # --- Plot 2: residuals after regressing out L*R -----------------------
    # Per-sample ligand x receptor product, labelled with the sample ids.
    L_R = pd.Series(
        _dense_1d(pseudo_1[:, L].X) * _dense_1d(pseudo_2[:, R].X),
        index=pseudo_1.obs['sample_id'].values,
    )
    if not L_R.index.isin(DEG_exp.index).any() and len(L_R) == len(DEG_exp):
        # No shared labels: pair rows by position.
        L_R = pd.Series(L_R.values, index=DEG_exp.index)
    else:
        L_R = L_R.reindex(DEG_exp.index)

    fit_df = pd.DataFrame({
        g1: g1_exp,
        g2: g2_exp,
        'treatment': treatment,
        'L_R': L_R,
    }).dropna()
    if fit_df.empty:
        raise ValueError(
            f'No sample has complete values for {g1}, {g2}, treatment and {L}*{R}'
        )

    X = sm.add_constant(fit_df['L_R'])
    residuals_g1 = sm.OLS(fit_df[g1], X).fit().resid
    residuals_g2 = sm.OLS(fit_df[g2], X).fit().resid

    df_res = pd.DataFrame({
        'G_1_residual': residuals_g1,
        'G_2_residual': residuals_g2,
        'treatment': fit_df['treatment']
    })

    g = sns.jointplot(x='G_1_residual', y='G_2_residual', data=df_res, hue='treatment', palette=palette, kind='scatter', marginal_kws=dict(fill=True))
    sns.regplot(x='G_1_residual', y='G_2_residual', data=df_res, scatter=False, ax=g.ax_joint, color='black')
    adjust_plot_limits(g.ax_joint, df_res['G_1_residual'], df_res['G_2_residual'])
    style_ticks(g)

    corr_res = pearsonr(df_res['G_1_residual'], df_res['G_2_residual'])[0]
    plt.suptitle(f'Correlation between residuals of G_1 and G_2 is {corr_res:.3f}', fontsize=title_fontsize)
    plt.subplots_adjust(top=0.95)
    plt.show()



In [None]:
# Scatter plots of ABCC10_T vs AP3S2_Endo before and after regressing out the
# CXCL10 x CXCR3 product.
plot_CIT_DEGcorr('ABCC10_T', 'AP3S2_Endo', 'CXCL10', 'CXCR3', gene_df, adata_1_pseudo, adata_2_pseudo)