In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import scanpy as sc
import scipy.io as sio
import anndata as ad
import seaborn as sns
import os as os
import sys as sys
sys.path.append('/home/qiuaodon/Desktop/PanCancer_scRNA_analysis/utils/')
from scRNA_utils import *
import operator as op
import matplotlib.colors as mcolors

# load in sc data

In [2]:
# Directory holding the cohort-1 .h5ad inputs; several tables written later in this
# notebook are also saved under it.
data_dir_NHDP = "/home/qiuaodon/Desktop/project_data_new/"
# Cohort-1 object stored in the "T_cells" file.
adata_T = sc.read(data_dir_NHDP + '1863-counts_cells_cohort1_T_cells.h5ad')
# Population 1: only the cells whose PDCD1 value in .X is greater than zero.
adata_1 = adata_T[adata_T[:,'PDCD1'].X > 0, :]
# Cohort-1 object stored in the "M_cells" file (called "Myeloid" in the output file names).
adata_M = sc.read(data_dir_NHDP + '1863-counts_cells_cohort1_M_cells.h5ad')
# Population 2 is the full M-cell object; adata_2 is just a second name for it.
adata_2 = adata_M

In [3]:
colors = ["grey", "blue"]  # Start with grey and end with blue
cmap = mcolors.LinearSegmentedColormap.from_list("grey_to_blue", colors)


In [None]:
adata_1_pseudo = scRNA2PseudoBulkAnnData(adata_1, sample_id_col='sample_id')
adata_2_pseudo = scRNA2PseudoBulkAnnData(adata_2, sample_id_col='sample_id')

# prepare data matrix for IVtest

In [None]:
DEG_1 = paird_ttest(adata_1, condition_key = 'timepoint', sample_id_col = 'sample_id', patient_id_col = 'patient_id')
DEG_2 = paird_ttest(adata_2, condition_key = 'timepoint', sample_id_col = 'sample_id', patient_id_col = 'patient_id')

In [7]:
# Restrict both DEG tables to nominally significant genes (p < 0.05).
DEG_1 = DEG_1.loc[DEG_1['pval'] < 0.05]
DEG_2 = DEG_2.loc[DEG_2['pval'] < 0.05]

# Significant genes that are present in the matching pseudo-bulk object, kept in
# the order in which they appear in the DEG table.
gene_1 = [name for name in DEG_1.index.tolist() if name in adata_1_pseudo.var_names]
gene_2 = [name for name in DEG_2.index.tolist() if name in adata_2_pseudo.var_names]

# Expression tables (rows = pseudo-bulk samples, columns = genes) indexed by
# sample_id; the column suffix records which population a gene was measured in.
gene_1_matrix = adata_1_pseudo[:, gene_1].X
gene_2_matrix = adata_2_pseudo[:, gene_2].X
gene_1_df = pd.DataFrame(
    gene_1_matrix, columns=gene_1, index=adata_1_pseudo.obs['sample_id']
).add_suffix('_T')
gene_2_df = pd.DataFrame(
    gene_2_matrix, columns=gene_2, index=adata_2_pseudo.obs['sample_id']
).add_suffix('_M')

# Join the two populations on the shared sample_id index level.
gene_df = pd.merge(gene_1_df, gene_2_df, on='sample_id')

In [None]:
# The row labels of gene_df are sample ids; the time point is read from them.
sample_labels = gene_df.index
# 1 when the sample id contains 'On', 0 otherwise.
gene_df['treatment'] = sample_labels.str.contains('On').astype(int)
# Patient id = sample id with the '_On' / '_Pre' marker removed.
gene_df['patient_id'] = sample_labels.str.replace('_On', '').str.replace('_Pre', '')
gene_df['treatment'].value_counts()

In [38]:
# export the gene_df to csv
# gene_df.to_csv(data_dir_NHDP + 'gene_df_PD1_Myeloid.csv')
# NOTE(review): gene_df_GEM is not assigned anywhere in the visible notebook, so this
# line raises NameError on a fresh kernel unless it is created elsewhere — TODO confirm
# where the GEM table is built before relying on this export.
gene_df_GEM.to_csv(data_dir_NHDP + 'gene_df_PD1_Myeloid_GEM.csv')

# load in IV result and perform CIT

In [9]:
# Alternative input kept for reference:
# DEG_pairs = pd.read_excel('/home/qiuaodon/Desktop/project_data_new/IV_regression_results_CD4EXvsEndo_filtered.xlsx')
DEG_pairs = (
    pd.read_excel('/home/qiuaodon/Desktop/project_data_new/IV_regression_results_PD1vsM.xlsx')
    # the regression statistics are not used downstream
    .drop(columns=['p_value', 'r_squared'])
    # the rest of the notebook addresses the two gene columns as g1 / g2
    .rename(columns={'gene_T': 'g1', 'gene_cell': 'g2'})
)

In [None]:

calculate_gene_correlation(DEG_pairs, adata_1_pseudo, adata_2_pseudo)

In [None]:
# Keep only the positively correlated pairs (r > 0.3). Every such pair also
# satisfies |r| > 0.3, so a single comparison selects exactly the same rows as
# filtering on |r| first and on r afterwards.
positively_correlated = DEG_pairs['g1vsg2_correlation'] > 0.3
DEG_pairs = DEG_pairs[positively_correlated]
# Strongest correlation first.
DEG_pairs = DEG_pairs.sort_values(by='g1vsg2_correlation', ascending=False)
DEG_pairs

In [None]:
g1 = DEG_pairs['g1'].tolist()
g2 = DEG_pairs['g2'].tolist()
# Columns of gene_df that are exactly zero in fewer than 15 samples. The same
# criterion is applied to both members of a pair, so it is evaluated once.
expressed_columns = [col for col in gene_df.columns if (gene_df[col] == 0).sum() < 15]
g1_in_CD4EX = expressed_columns
g2_in_Endo = list(expressed_columns)
# A pair survives only when both of its genes pass the zero-count criterion.
both_expressed = DEG_pairs['g1'].isin(g1_in_CD4EX) & DEG_pairs['g2'].isin(g2_in_Endo)
DEG_pairs = DEG_pairs[both_expressed]
DEG_pairs

In [None]:
# Ligand-receptor network: only the 'from' (ligand) and 'to' (receptor) columns
# are kept and renamed to L / R. No filtering by source database happens here.
lrpair = (
    pd.read_csv('/home/qiuaodon/Desktop/project_data_new/lr_network_unique.tsv', sep='\t')
    .loc[:, ['from', 'to']]
    .rename(columns={'from': 'L', 'to': 'R'})
)
lrpair

In [None]:
# A ligand must be a variable of adata_1 and a receptor a variable of adata_2, and
# each must be zero in fewer than 75% of the rows of its pseudo-bulk object.
L = lrpair['L'].tolist()
R = lrpair['R'].tolist()


def _mostly_expressed(pseudo, gene):
    """Return True when `gene` is zero in fewer than 75% of the rows of `pseudo`."""
    return np.sum(pseudo[:, gene].X == 0) < 0.75 * pseudo.shape[0]


L_in_1 = [
    ligand for ligand in set(L).intersection(adata_1.var_names)
    if _mostly_expressed(adata_1_pseudo, ligand)
]
R_in_2 = [
    receptor for receptor in set(R).intersection(adata_2.var_names)
    if _mostly_expressed(adata_2_pseudo, receptor)
]
print('L_in_CD4EX:', len(L_in_1))
print('R_in_2:', len(R_in_2))

In [None]:
# Keep a pair only when its ligand and its receptor both passed the expression filter.
ligand_ok = lrpair['L'].isin(L_in_1)
receptor_ok = lrpair['R'].isin(R_in_2)
lrpair = lrpair[ligand_ok & receptor_ok]
print('lrpair_kegg:', lrpair.shape)

In [17]:
# Make sure the CXCL13-ACKR1 pair is part of the tested ligand-receptor pairs.
# It is appended only when it is missing: appending unconditionally would create
# a duplicate row (and therefore duplicate CIT tests) whenever the pair is already
# in the table or this cell is run a second time.
has_cxcl13_ackr1 = ((lrpair['L'] == 'CXCL13') & (lrpair['R'] == 'ACKR1')).any()
if not has_cxcl13_ackr1:
    new_pair = pd.DataFrame({'L': ['CXCL13'], 'R': ['ACKR1']})
    lrpair = pd.concat([lrpair, new_pair], ignore_index=True)
# Same 0..n-1 index as the unconditional concat produced, whichever branch ran.
lrpair = lrpair.reset_index(drop=True)

# validate the IV

In [38]:
import numpy as np
import statsmodels.api as sm

# NOTE(review): this cell looks like an unfinished draft of the BIC cell that follows:
# the loop below only extracts Y and Z and discards them; nothing is fitted or stored.
# generate X for treatment data
X = adata_1_pseudo.obs['timepoint']
# change the pre to 0 and on to 1
X = X.replace('pre', 0)
X = X.replace('on', 1)
# prepend the intercept column
X = sm.add_constant(X)

# loop through DEG_pairs
for index, row in DEG_pairs.iterrows():
    # row[0] / row[1] pick the first two fields of the row by position; this relies on
    # pandas' positional fallback for integer keys (deprecated in recent pandas — TODO confirm).
    Y = row[0]
    Z = row[1]
    # get the gene expression data for Y and Z from gene_df
    Y = gene_df[Y].to_frame()
    Z = gene_df[Z].to_frame()


In [None]:
import numpy as np
import statsmodels.api as sm

# Treatment indicator ('pre' -> 0, 'on' -> 1) for exactly the samples in gene_df,
# in gene_df's row order. gene_df is an inner join of the two pseudo-bulk tables
# on sample_id, so it can hold fewer rows than adata_1_pseudo.obs; looking the
# time point up by sample_id keeps the design matrix row-aligned with Y and Z.
timepoint_by_sample = pd.Series(
    adata_1_pseudo.obs['timepoint'].to_numpy(),
    index=adata_1_pseudo.obs['sample_id'].to_numpy(),
)
X = timepoint_by_sample.reindex(gene_df.index.to_numpy())
# Cast to float so statsmodels receives a numeric column whatever dtype the
# labels were stored with; an unexpected label fails loudly here.
X = X.replace({'pre': 0, 'on': 1}).astype(float)
X.index = gene_df.index  # identical index objects for endog and exog
X.name = 'timepoint'
X = sm.add_constant(X)  # Add constant term for intercept

bic_model_1_list = []  # Store BICs for model X -> Y -> Z
bic_model_2_list = []  # Store BICs for model Y -> Z

# loop through DEG_pairs
for index, row in DEG_pairs.iterrows():
    # Address the genes by column name (as CIT_test does) rather than by position.
    Y_gene = row['g1']  # T cell DEG
    Z_gene = row['g2']  # M cell DEG

    # Single-column frames holding the expression of each gene across samples.
    Y = gene_df[Y_gene].to_frame()
    Z = gene_df[Z_gene].to_frame()

    # Model 1: X -> Y -> Z, scored as BIC(Y ~ X) + BIC(Z ~ Y).
    model_Y_X = sm.OLS(Y, X).fit()  # treatment -> T cells
    Y_with_const = sm.add_constant(Y)  # Add constant to Y for intercept
    model_Z_Y = sm.OLS(Z, Y_with_const).fit()  # T cells -> M cells

    bic_model_1 = model_Y_X.bic + model_Z_Y.bic
    bic_model_1_list.append(bic_model_1)

    # Model 2: Y -> Z. This is the very same Z ~ Y regression, so the fit is
    # reused instead of being estimated a second time.
    # NOTE(review): because of that, Model 1's score is always Model 2's score plus
    # BIC(Y ~ X); the comparison below therefore only reflects the sign of BIC(Y ~ X).
    model_Z_Y_direct = model_Z_Y
    bic_model_2_list.append(model_Z_Y_direct.bic)

# Convert lists to numpy arrays for easier operations
bic_model_1_list = np.array(bic_model_1_list)
bic_model_2_list = np.array(bic_model_2_list)

# Calculate average BIC for both models
avg_bic_model_1 = np.mean(bic_model_1_list)
avg_bic_model_2 = np.mean(bic_model_2_list)

# Output BIC results
print(f"Average BIC for Model X -> Y -> Z: {avg_bic_model_1}")
print(f"Average BIC for Model Y -> Z: {avg_bic_model_2}")

# Compare the models (lower BIC is better)
if avg_bic_model_1 < avg_bic_model_2:
    print("Model X -> Y -> Z is better on average.")
else:
    print("Model Y -> Z is better on average.")


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# One column of per-pair BIC values for each of the two models.
data = pd.DataFrame({
    'Model 1: X->Y->Z': bic_model_1_list,
    'Model 2: Y->Z': bic_model_2_list,
})

# One violin per model; labels are set on the axes seaborn drew into.
violin_ax = sns.violinplot(data=data)
violin_ax.set_xlabel('Model')
violin_ax.set_ylabel('BIC')
violin_ax.set_title('BIC Comparison between Model 1 and Model 2')
plt.show()


In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import statsmodels.api as sm
from scipy.stats import spearmanr, pearsonr

def plot_CIT_DEGcorr(g1, g2, L, R, DEG_exp, pseudo_1, pseudo_2):
    """Plot g1 against g2 before and after regressing out a ligand-receptor product.

    Two seaborn joint plots are shown, both coloured by treatment:

    1. the expression of ``g1`` vs ``g2`` with their Pearson correlation;
    2. the residuals of ``g1`` and ``g2`` after each has been regressed (OLS with
       intercept) on the product L * R, with the Pearson correlation of the residuals.

    Parameters
    ----------
    g1, g2 : str
        Column names in ``DEG_exp``.
    L, R : str
        Variable names looked up in ``pseudo_1`` and ``pseudo_2`` respectively.
    DEG_exp : pandas.DataFrame
        Table indexed by sample id with columns ``g1``, ``g2`` and ``'treatment'``.
        ``'treatment'`` may be coded 0/1 or 'pre'/'on'. The frame is not modified.
    pseudo_1, pseudo_2 : AnnData-like
        Pseudo-bulk objects with ``obs['sample_id']``; ``.X`` may be dense or sparse.

    Returns
    -------
    None
        The figures are displayed with ``plt.show()``.
    """
    palette = {'pre': "#E69F00", 'on': "#56B4E9"}
    title_fontsize = 14
    treatment_labels = {0: 'pre', 1: 'on'}

    def adjust_plot_limits(ax, x_data, y_data):
        # Pad both axes by 10% of the data range.
        x_min, x_max = x_data.min(), x_data.max()
        y_min, y_max = y_data.min(), y_data.max()
        x_padding = (x_max - x_min) * 0.1
        y_padding = (y_max - y_min) * 0.1
        ax.set_xlim(x_min - x_padding, x_max + x_padding)
        ax.set_ylim(y_min - y_padding, y_max + y_padding)

    def expression_by_sample(pseudo, gene):
        # One value per pseudo-bulk sample, labelled by sample_id. `.X` is
        # densified only when it actually is a sparse matrix.
        values = pseudo[:, gene].X
        if hasattr(values, 'toarray'):
            values = values.toarray()
        return pd.Series(np.asarray(values).ravel(), index=pseudo.obs['sample_id'])

    def draw_joint(frame, x_col, y_col, title):
        # Scatter with marginal densities per treatment plus one overall regression line.
        grid = sns.jointplot(x=x_col, y=y_col, data=frame, hue='treatment', palette=palette,
                             kind='scatter', marginal_kws=dict(fill=True))
        sns.regplot(x=x_col, y=y_col, data=frame, scatter=False, ax=grid.ax_joint, color='black')
        adjust_plot_limits(grid.ax_joint, frame[x_col], frame[y_col])
        grid.ax_joint.tick_params(left=False, bottom=False)
        grid.ax_marg_x.tick_params(bottom=False)
        grid.ax_marg_y.tick_params(left=False)
        plt.suptitle(title, fontsize=title_fontsize)
        plt.subplots_adjust(top=0.95)
        plt.show()

    # Recode 0/1 to the palette keys on a private Series: the caller's frame (and
    # any global table) keeps its numeric coding, so later `== 0` / `== 1`
    # selections on it continue to work.
    treatment = DEG_exp['treatment'].map(lambda code: treatment_labels.get(code, code))

    # Plot 1: raw expression of g1 vs g2.
    df1 = pd.DataFrame({g1: DEG_exp[g1], g2: DEG_exp[g2], 'treatment': treatment}).dropna()
    cor = pearsonr(df1[g1], df1[g2])[0]
    draw_joint(df1, g1, g2, f'Correlation between {g1} and {g2} is {cor:.3f}')

    # Plot 2: residuals after removing the L * R signal. The product and the two
    # genes are aligned on sample id (the same way CIT_test builds its table), so
    # samples missing from either pseudo-bulk object are dropped instead of
    # silently shifting rows against each other.
    L_R = expression_by_sample(pseudo_1, L) * expression_by_sample(pseudo_2, R)
    df_lr = pd.DataFrame({
        g1: DEG_exp[g1],
        g2: DEG_exp[g2],
        'L*R': L_R,
        'treatment': treatment,
    }).dropna()
    X = sm.add_constant(df_lr['L*R'])

    df_res = pd.DataFrame({
        'G_1_residual': sm.OLS(df_lr[g1], X).fit().resid,
        'G_2_residual': sm.OLS(df_lr[g2], X).fit().resid,
        'treatment': df_lr['treatment'],
    })
    corr_res = pearsonr(df_res['G_1_residual'], df_res['G_2_residual'])[0]
    draw_joint(df_res, 'G_1_residual', 'G_2_residual',
               f'Correlation between residuals of G_1 and G_2 is {corr_res:.3f}')



In [8]:
# Split the samples by the numeric treatment indicator (0 = pre, 1 = on-treatment).
# These comparisons expect the 0/1 coding; they select no rows if the column holds
# 'pre'/'on' labels instead.
gene_df_pre = gene_df[gene_df['treatment'] == 0]
gene_df_on = gene_df[gene_df['treatment'] == 1]

In [None]:
import statsmodels.api as sm

# Residuals of TSC22D3_M ~ TXNIP_T, fitted separately within each time point.
# Pre-treatment samples:
X_pre = sm.add_constant(gene_df_pre['TXNIP_T'])  # intercept + TXNIP_T
model = sm.OLS(gene_df_pre['TSC22D3_M'], X_pre).fit()
residuals = model.resid

# On-treatment samples. The second model has to be fitted on gene_df_on: its
# residuals are plotted against gene_df_on['TXNIP_T'] below, so fitting the
# pre-treatment data again would pair x and y values from different samples.
X_on = sm.add_constant(gene_df_on['TXNIP_T'])
model2 = sm.OLS(gene_df_on['TSC22D3_M'], X_on).fit()
residuals2 = model2.resid

# combine the residuals and residuals2
residual = pd.concat([residuals, residuals2], axis=0)

# Residuals against TXNIP_T, coloured by treatment: pre is orange, on is blue.
plt.scatter(gene_df_pre['TXNIP_T'], residuals, c='orange', label='pre')
plt.scatter(gene_df_on['TXNIP_T'], residuals2, c='skyblue', label='on')
plt.xlabel('TXNIP_T')
plt.ylabel('Residuals of TSC22D3_M')
plt.title('Residuals of TSC22D3_M vs. TXNIP_T')
plt.legend()
plt.show()

In [9]:
def adjust_plot_limits(ax, x_data, y_data):
    """Set the limits of `ax` to the data range widened by 10% on every side.

    `x_data` and `y_data` only need `.min()` and `.max()`; `ax` only needs
    `set_xlim` and `set_ylim`.
    """
    x_lo, x_hi = x_data.min(), x_data.max()
    y_lo, y_hi = y_data.min(), y_data.max()
    # Margin = one tenth of the observed span along each axis.
    x_margin = (x_hi - x_lo) * 0.1
    y_margin = (y_hi - y_lo) * 0.1
    ax.set_xlim(x_lo - x_margin, x_hi + x_margin)
    ax.set_ylim(y_lo - y_margin, y_hi + y_margin)

In [None]:
import statsmodels.api as sm
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
palette = {'pre': "#E69F00", 'on': "#56B4E9"}

# The gene pair is named once here so the data selection and the plot title
# cannot drift apart.
g1_col = 'PRDM1_T'   # regressor (column of gene_df with the '_T' suffix)
g2_col = 'TUBA1A_M'  # response (column of gene_df with the '_M' suffix)

# gene_df_pre / gene_df_on are the pre- and on-treatment subsets of gene_df.
g1_pre = gene_df_pre[g1_col]
g2_pre = gene_df_pre[g2_col]
g1_on = gene_df_on[g1_col]
g2_on = gene_df_on[g2_col]

# Add constant term to g1 for intercept
g1_pre_with_const = sm.add_constant(g1_pre)
g1_on_with_const = sm.add_constant(g1_on)

# Fit g2 ~ g1 separately within each time point
model_pre = sm.OLS(g2_pre, g1_pre_with_const).fit()
model_on = sm.OLS(g2_on, g1_on_with_const).fit()

# Calculate the residuals
residuals_pre = model_pre.resid
residuals_on = model_on.resid

# Combine both datasets into a single DataFrame, labelled with the palette keys
combined_residuals_pre = pd.DataFrame({'g1': g1_pre, 'residuals': residuals_pre, 'treatment': 'pre'})
combined_residuals_on = pd.DataFrame({'g1': g1_on, 'residuals': residuals_on, 'treatment': 'on'})
combined_residuals = pd.concat([combined_residuals_pre, combined_residuals_on])

# Create the joint plot with residuals
g = sns.jointplot(data=combined_residuals, x='g1', y='residuals', hue='treatment', kind='scatter', height=6, palette=palette)

# Add title
g.fig.suptitle(f'Residuals of g2 ({g2_col}) vs. g1 ({g1_col}) by Treatment', y=1.02)
adjust_plot_limits(g.ax_joint, combined_residuals['g1'], combined_residuals['residuals'])
# Adjust layout to fit the title
plt.tight_layout()
plt.show()


# CIT using fisher-z

In [None]:
DEG_pairs

In [None]:
lrpair

In [16]:
def CIT_test(DEG_pairs, lrpairs, adata_1_pseudo, adata_2_pseudo, gene_df, method="kci", p_value_threshold=0.05):
    """Test g1 and g2 for independence conditional on a ligand-receptor product.

    For every row of ``DEG_pairs`` and every usable (L, R) pair, a three-column
    table [g1, g2, L*R] is built on the samples shared by all three (rows with
    missing values are dropped) and ``CIT(table, method)(0, 1, [2])`` is evaluated.

    Parameters
    ----------
    DEG_pairs : pandas.DataFrame
        Must have columns ``'g1'`` and ``'g2'`` naming columns of ``gene_df``.
        Rows whose genes are not both in ``gene_df`` are skipped.
    lrpairs : pandas.DataFrame
        Two columns: ligand, receptor. A pair is used only when the ligand is in
        ``adata_1_pseudo.var_names`` and the receptor in ``adata_2_pseudo.var_names``.
    adata_1_pseudo, adata_2_pseudo : AnnData-like
        Pseudo-bulk objects with ``var_names`` and ``obs['sample_id']``; ``.X`` may
        be dense or sparse.
    gene_df : pandas.DataFrame
        Expression table indexed by sample id.
    method : {"kci", "fisherz"}
        Independence test passed to ``CIT``.
    p_value_threshold : float
        Accepted for backward compatibility; the function does not use it.

    Returns
    -------
    pandas.DataFrame
        Columns ``['g1', 'g2', 'L', 'R', 'pValue']``, one row per tested
        combination, ordered by DEG pair and then by ligand-receptor pair.

    Raises
    ------
    ValueError
        If ``method`` is not "kci" or "fisherz" (checked before any work is done).
    """
    if method not in ("kci", "fisherz"):
        raise ValueError("Unsupported method. Use 'kci' or 'fisherz'.")

    result_columns = ['g1', 'g2', 'L', 'R', 'pValue']

    def expression_frame(pseudo, gene, label):
        # Single-column frame of `gene`, indexed by sample_id and named `label`.
        values = pseudo[:, gene].X
        if hasattr(values, "toarray"):  # sparse .X
            values = values.toarray()
        return pd.DataFrame(values, columns=[label], index=pseudo.obs['sample_id']).dropna()

    # The L*R product does not depend on the DEG pair, so it is computed once per
    # ligand-receptor pair rather than once per (DEG pair, LR pair) combination.
    lr_products = []
    for L, R in lrpairs.values:
        if L in adata_1_pseudo.var_names and R in adata_2_pseudo.var_names:
            label = R + '_vs_' + L  # shared column name so the frames multiply column-wise
            L_T = expression_frame(adata_1_pseudo, L, label)
            R_M = expression_frame(adata_2_pseudo, R, label)
            lr_products.append((L, R, L_T.multiply(R_M, axis=0).iloc[:, 0]))

    # Collect plain records and build the frame once: growing a DataFrame with
    # pd.concat inside the loop is quadratic and can leave pValue with object dtype.
    records = []
    for _, row in DEG_pairs.iterrows():
        g1 = row['g1']
        g2 = row['g2']
        if g1 not in gene_df.columns or g2 not in gene_df.columns:
            continue
        g1_T = gene_df[g1]
        g2_M = gene_df[g2]

        for L, R, L_R in lr_products:
            # Align the three variables on sample id and drop incomplete samples.
            df = pd.DataFrame({g1: g1_T, g2: g2_M, 'L*R': L_R}).dropna()
            cit_obj = CIT(df.to_numpy(), method)
            # Column 0 (g1) vs column 1 (g2), conditioning on column 2 (L*R).
            pValue = cit_obj(0, 1, [2])
            records.append({'g1': g1, 'g2': g2, 'L': L, 'R': R, 'pValue': pValue})

    return pd.DataFrame(records, columns=result_columns)

In [17]:
results = CIT_test(DEG_pairs, lrpair, adata_1_pseudo, adata_2_pseudo, gene_df, method = 'fisherz', p_value_threshold=0.05)

In [None]:
results

In [None]:
#calculate the correlation between g1 and g2
calculate_gene_correlation(results, adata_1_pseudo, adata_2_pseudo)

In [None]:
# For every (g1, g2) pair keep the single row whose pValue is largest
# (idxmax returns the row label of the maximum within each group).
results_filtered = results.loc[results.groupby(['g1', 'g2'])['pValue'].idxmax()]
results_filtered

In [None]:
# Build the 'g1_g2_L_R' identifier once and use it both as a column and as the row index.
combined_key = results_filtered['g1'] + '_' + results_filtered['g2'] + '_' + results_filtered['L'] + '_' + results_filtered['R']
results_filtered['g1_g2_L_R'] = combined_key
results_filtered.index = combined_key
results_filtered

In [None]:
# Plot g1vsg2_correlation against -log10(pValue).
# The sign is applied here so the stored values match the '-log10(pValue)' axis
# label: small p-values end up at the top of the plot. The cast guards against a
# pValue column stored with object dtype.
results_filtered['log_pValue'] = -np.log10(results_filtered['pValue'].astype(float))

# Plot the distribution
plt.figure(figsize=(10, 6))
sns.scatterplot(x='g1vsg2_correlation', y='log_pValue', data=results_filtered)
plt.xlabel('g1 vs g2 Correlation')
plt.ylabel('-log10(pValue)')
plt.title('Distribution of g1 vs g2 Correlation vs -log10(pValue)')
plt.show()


In [None]:
#calculate the correlation between g1 and g2
calculate_gene_correlation(results, adata_1_pseudo, adata_2_pseudo)
#calculate the residuals correlation between g1, g2 and LR
calculate_residuals_correlation(results, adata_1_pseudo, adata_2_pseudo)

In [23]:
results.to_excel('/home/qiuaodon/Desktop/project_data_new/CIT_results_PD1vsMyeloid_fisherz_corr03.xlsx')

In [None]:
# results = pd.read_excel('/home/qiuaodon/Desktop/project_data_new/CIT_results_CD4EXvsEndo_fisherz_nosamplefiltered.xlsx')

## CIT using KCI


In [None]:
DEG_pairs

In [None]:
lrpair

In [None]:
results = CIT_test(DEG_pairs, lrpair, adata_1_pseudo, adata_2_pseudo, gene_df, method = 'kci', p_value_threshold=0.05)

In [None]:
#calculate the correlation between g1 and g2
calculate_gene_correlation(results, adata_1_pseudo, adata_2_pseudo)
#calculate the residuals correlation between g1, g2 and LR
calculate_residuals_correlation(results, adata_1_pseudo, adata_2_pseudo)

In [None]:
# save results
results.to_excel('/home/qiuaodon/Desktop/project_data_new/CIT_results_CD4EXvsMyeloid_KCI_nosamplefiltered.xlsx')

# plot the corr For GEM

In [None]:
gene_1 = 'CD69'
gene_2 = 'HERPUD1'
L = 'ANXA1'
R = 'CXCR4'
# make the color for umap from grey to blue
colors = ["grey", "blue"]
cmap = mcolors.LinearSegmentedColormap.from_list("grey_to_blue", colors)
# plot the umap with the color of gene_1 and gene_2
sc.pl.umap(adata_1, color = [gene_1, L], cmap = cmap)
sc.pl.umap(adata_2, color = [gene_2, R], cmap = cmap)

## group the DEGpairs together using L and R

In [36]:
# results = pd.read_excel('/home/qiuaodon/Desktop/project_data_new/CIT_results_CD4EXvsDC_fisherz_nosamplefiltered.xlsx')
# results = results.drop(columns=['Unnamed: 0'])
# Remove the population suffix ('_T' / '_M') that was appended when gene_df was built.
# The pattern is anchored to the end of the name and limited to those two
# suffixes, so a name that contains an underscore elsewhere is not truncated at
# its first underscore.
results['g1'] = results['g1'].str.replace(r'_(?:T|M)$', '', regex=True)
results['g2'] = results['g2'].str.replace(r'_(?:T|M)$', '', regex=True)

In [None]:
results

In [None]:
results['R'].unique()

In [None]:
results['R'].nunique()

In [None]:
import pandas as pd
# One group of result rows per (ligand, receptor) combination.
grouped_results = results.groupby(["L", "R"])

# Display grouped data
# NOTE(review): grouped_dict is built but not read again in this cell.
grouped_dict = { (ligand, receptor): group for (ligand, receptor), group in grouped_results }

# Create a new DataFrame to hold the GEM1 and GEM2 groupings
# NOTE(review): grouped_data is never filled in this cell; temp_data is what gets used.
grouped_data = []
# Genes already placed in an accepted group (union over all accepted groups so far).
tracked_gem1 = set()
tracked_gem2 = set()

# Size cut-off applied to the g1 gene list below.
min_gene_num = 3

# Temporary list to hold data before creating DataFrame
temp_data = []

for (ligand, receptor), group in grouped_results:
    # Unique g1 / g2 genes of this ligand-receptor group, in order of first appearance.
    gem1 = group["g1"].drop_duplicates().tolist()
    gem2 = group["g2"].drop_duplicates().tolist()
    
    # Check the overlap between current GEM_g1 and previously tracked GEM_g1
    # NOTE(review): these two counts are computed but not used in the condition below.
    overlap_gem1 = len(set(gem1) & tracked_gem1)
    overlap_gem2 = len(set(gem2) & tracked_gem2)
    
    # Filter to include only those with at least three genes in GEM_g1 and GEM_g2
    # and where the overlap is below the maximum allowable threshold
    # NOTE(review): the condition as written keeps a group only when GEM_g1 has MORE than
    # min_gene_num genes (i.e. 4 or more) and does not look at GEM_g2 or at the overlap —
    # TODO confirm which of the comment or the code states the intent.
    if len(gem1) > min_gene_num :
        temp_data.append({
            "Ligand": ligand,
            "Receptor": receptor,
            "GEM_g1": gem1,
            "GEM_g2": gem2,
            "Num_genes_g1": len(gem1),
            "Num_genes_g2": len(gem2)
        })
        # Update the tracked GEM_g1 and GEM_g2 sets
        tracked_gem1.update(gem1)
        tracked_gem2.update(gem2)

# Convert the temporary data to a DataFrame
grouped_df = pd.DataFrame(temp_data)

# Function to find overlapping rows
def find_overlapping_rows(df, column_name):
    """For each row, list the other rows that share many items with it.

    `df[column_name]` holds one collection per row. Row B is reported for row A
    when the number of items they share is greater than one third of the number
    of distinct items in A. The relation is therefore not symmetric.

    Returns a list with one entry per row of `df` (in row order); each entry is
    the list of index labels of the overlapping rows.
    """
    # Turn every row's collection into a set once, keeping the row label next to it.
    labelled_sets = [(label, set(row[column_name])) for label, row in df.iterrows()]

    overlap_info = []
    for label, current_set in labelled_sets:
        cutoff = len(current_set) / 3  # Set the threshold for overlap
        overlap_info.append([
            other_label
            for other_label, other_set in labelled_sets
            if label != other_label and len(current_set & other_set) > cutoff
        ])
    return overlap_info

# Add columns for overlapping rows in GEM_g1 and GEM_g2
grouped_df["Overlapping_g1"] = find_overlapping_rows(grouped_df, "GEM_g1")
grouped_df["Overlapping_g2"] = find_overlapping_rows(grouped_df, "GEM_g2")

grouped_df


In [None]:
grouped_df


In [50]:
# rank each GEM_g1 by order of name inside
grouped_df['GEM_g1'] = grouped_df['GEM_g1'].apply(lambda x: sorted(x))
grouped_df['GEM_g2'] = grouped_df['GEM_g2'].apply(lambda x: sorted(x))

In [51]:
# save the grouped_df to excel
grouped_df.to_excel('/home/qiuaodon/Desktop/project_data_new/CIT_results_PD1vsM_grouped_corr04.xlsx')

In [None]:
# Remove the population suffix ('_T' / '_M') from every gene in GEM_g1 and GEM_g2.
# Each cell of these columns holds a Python list of gene names, and the pandas
# .str accessor returns NaN for list-valued cells, which would blank both
# columns. The suffix is therefore stripped gene by gene.
import re


def _strip_population_suffix(genes):
    """Strip a trailing '_T' / '_M' from a gene name or from each name in a list."""
    if isinstance(genes, str):
        return re.sub(r'_(?:T|M)$', '', genes)
    if isinstance(genes, (list, tuple)):
        return [re.sub(r'_(?:T|M)$', '', gene) for gene in genes]
    return genes  # e.g. NaN: leave untouched


grouped_df['GEM_g1'] = grouped_df['GEM_g1'].apply(_strip_population_suffix)
grouped_df['GEM_g2'] = grouped_df['GEM_g2'].apply(_strip_population_suffix)

In [None]:
# Show GEM_g1 and GEM_g2 without the surrounding [].
# A list-valued cell is rendered as a comma-separated string; the .str accessor
# cannot be used for that because it returns NaN for list-valued cells. A cell
# that already is a string (e.g. "[A, B]") just loses its brackets.
def _gene_list_to_text(genes):
    """Return the genes as 'A, B, C'; strip brackets from an existing string."""
    if isinstance(genes, (list, tuple)):
        return ', '.join(str(gene) for gene in genes)
    if isinstance(genes, str):
        return genes.strip('[]')
    return genes  # e.g. NaN: leave untouched


grouped_df['GEM_g1'] = grouped_df['GEM_g1'].apply(_gene_list_to_text)
grouped_df['GEM_g2'] = grouped_df['GEM_g2'].apply(_gene_list_to_text)


In [None]:
grouped_df

In [None]:
sc.pl.umap(adata_1, color =['SESN1', 'NFKBIA', 'PMAIP1', 'IL6ST', 'TXNIP', 'PELI1', 'PRDM1', 'CRYBG1', 'TSC22D3', 'TP53INP1'])

## draw heatmap

In [None]:
# sort DEG1_1 and DEG1_2 by the p_value
DEG_1 = DEG_1.sort_values(by='pval')
DEG_2 = DEG_2.sort_values(by='pval')
DEG_1

In [16]:
# First 25 index entries of DEG_1 and DEG_2 as top_genes_T and top_genes_M
# (these are the 25 smallest p-values only if the tables have been sorted by 'pval' beforehand).
top_genes_T = DEG_1.index[:25]
top_genes_M = DEG_2.index[:25]


In [None]:
# replace the ETS2 in top_genes_M with RHOB
top_genes_M = top_genes_M.drop('ETS2')
top_genes_M = top_genes_M.append(pd.Index(['RHOB']))
top_genes_M

In [19]:
top_genes_M = [
    'TSC22D3', 'DDIT4', 'FKBP5', 'SLC1A3', 'AREG', 'ACSL1',
    'RGS1', 'CH25H', 'SMIM3', 'TUBA1A', 'HERPUD1','RNF144B', 'KCNE1', 'TENT5A',
    'SLC19A2', 'SOCS1', 'DUSP6', 'TNFAIP3', 'PFKFB3', 'RHOB', 'SESN1', 'PDK4',
    'ARRDC2', 'MZF1-AS1', 'CEBPD'
]


In [None]:
# Hand-picked population-1 genes shown as heatmap rows. Each gene is listed
# once: a repeated entry would draw the same row twice.
top_genes_T = [
    'CH25H',
    'TSC22D3',
    'TNFAIP3',
    'BIRC3',
    'SPON2',
    'ZFP36',
    'SLA',
    'TXNIP',
    'IER2',
    'KLF10'
]

In [None]:
import seaborn as sns
import scipy.stats as stats
from scipy.stats import spearmanr 
# Restrict both pseudo-bulk objects to the sample_ids they have in common.
adata_T_pseudo = adata_1_pseudo
adata_M_pseudo = adata_2_pseudo
adata_M_pseudo = adata_M_pseudo[adata_M_pseudo.obs['sample_id'].isin(adata_T_pseudo.obs['sample_id'])]
adata_T_pseudo = adata_T_pseudo[adata_T_pseudo.obs['sample_id'].isin(adata_M_pseudo.obs['sample_id'])]
# NOTE(review): the two objects are filtered to the same samples but not re-ordered; the
# correlations below pair rows by position, which assumes both list the samples in the
# same order — TODO confirm.

# Initialize an empty matrix to store the correlation values
# (rows follow top_genes_T, columns follow top_genes_M)
corr_matrix = np.zeros((len(top_genes_T), len(top_genes_M)))
pval_matrix = np.zeros((len(top_genes_T), len(top_genes_M)))

# Calculate the correlation between the top genes of T and M cells
for i, gene_T in enumerate(top_genes_T):
    for j, gene_M in enumerate(top_genes_M):
        # remove the '_T' and '_M' from the gene names
        # (str.replace removes every occurrence of the substring, not only a trailing one)
        gene_T = gene_T.replace('_T', '')
        gene_M = gene_M.replace('_M', '')
        # Get the expression data of the two genes and flatten them to 1D arrays
        gene_T_expression = adata_T_pseudo[:, adata_T_pseudo.var_names == gene_T].X.flatten()
        gene_M_expression = adata_M_pseudo[:, adata_M_pseudo.var_names == gene_M].X.flatten()
        
        # Calculate the correlation (Spearman rank correlation and its p-value)
        corr, pval = spearmanr(gene_T_expression, gene_M_expression)
        
        # Store the correlation value in the matrix; p-values are stored as -log10(p)
        corr_matrix[i, j] = corr
        pval_matrix[i, j] = -np.log10(pval)


# Plot the correlation matrix as a heatmap
plt.figure(figsize=(10, 8))
sns.heatmap(corr_matrix, annot=False, cmap='coolwarm', vmin=-1, vmax=1,
            xticklabels=top_genes_M, yticklabels=top_genes_T)
plt.title('Correlation between Highly Correlated DEGs in T and M cells')
plt.xlabel('M Genes')
plt.ylabel('PD1+ T Genes')
plt.show()

In [None]:

# Heatmap of the correlation matrix with seaborn's own annotation switched off;
# the cell text is written manually below.
plt.figure(figsize=(12, 10))
heatmap = sns.heatmap(corr_matrix, cmap='coolwarm', vmin=-1, vmax=1,
                      xticklabels=top_genes_M, yticklabels=top_genes_T,
                      annot=False)

# Write "<correlation>\n(<-log10 p>)" in the centre of every cell. The text is
# black for every cell, whatever the strength of the correlation.
for i, j in np.ndindex(*corr_matrix.shape):
    cell_text = f'{corr_matrix[i, j]:.2f}\n({pval_matrix[i, j]:.2f})'
    plt.text(j + 0.5, i + 0.5, cell_text,
             horizontalalignment='center', verticalalignment='center',
             fontsize=11, color='black')

plt.title('Correlation between Top DEGs of CD4EX and M')
plt.xlabel('M Genes')
plt.ylabel('CD4EX Genes')
plt.show()

In [None]:
colors = ["grey", "blue"]  # Start with grey and end with blue
cmap = mcolors.LinearSegmentedColormap.from_list("grey_to_blue", colors)


In [None]:
# adata_CD4EX is not defined in this notebook; population 1 here is adata_1
# (the PDCD1-positive subset of the T-cell object), so that object is plotted.
sc.pl.umap(adata_1, color = ['SPON2', 'SELL'], color_map = cmap)

In [None]:
sc.pl.umap(adata_M, color = ['FKBP5', 'MADCAM1'], color_map = cmap)

In [None]:
# UMAPs coloured by time point for both populations. adata_CD4EX is not defined
# in this notebook; population 1 here is adata_1, so that object is plotted.
sc.pl.umap(adata_1, color = 'timepoint')
sc.pl.umap(adata_M, color = 'timepoint')

In [None]:
# The columns of gene_df carry the population suffix added when the table was
# built ('_T' for population 1, '_M' for population 2), so the two genes are
# addressed by their suffixed names. The ligand and receptor are looked up in
# the pseudo-bulk objects and keep their plain symbols.
plot_CIT_DEGcorr('PRDM1_T', 'TUBA1A_M', 'ANXA1', 'CXCR4', gene_df, adata_1_pseudo, adata_2_pseudo)

#### plot the scatter plots


In [21]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import statsmodels.api as sm
from scipy.stats import spearmanr, pearsonr

def plot_CIT_DEGcorr(g1, g2, L, R, DEG_exp, pseudo_1, pseudo_2):
    """Plot g1 against g2 before and after regressing out a ligand-receptor product.

    Two seaborn joint plots are shown, both coloured by treatment:

    1. the expression of ``g1`` vs ``g2`` with their Pearson correlation;
    2. the residuals of ``g1`` and ``g2`` after each has been regressed (OLS with
       intercept) on the product L * R, with the Pearson correlation of the residuals.

    Parameters
    ----------
    g1, g2 : str
        Column names in ``DEG_exp``.
    L, R : str
        Variable names looked up in ``pseudo_1`` and ``pseudo_2`` respectively.
    DEG_exp : pandas.DataFrame
        Table indexed by sample id with columns ``g1``, ``g2`` and ``'treatment'``.
        ``'treatment'`` may be coded 0/1 or 'pre'/'on'. The frame is not modified.
    pseudo_1, pseudo_2 : AnnData-like
        Pseudo-bulk objects with ``obs['sample_id']``; ``.X`` may be dense or sparse.

    Returns
    -------
    None
        The figures are displayed with ``plt.show()``.
    """
    palette = {'pre': "#E69F00", 'on': "#56B4E9"}
    title_fontsize = 14
    treatment_labels = {0: 'pre', 1: 'on'}

    def adjust_plot_limits(ax, x_data, y_data):
        # Pad both axes by 10% of the data range.
        x_min, x_max = x_data.min(), x_data.max()
        y_min, y_max = y_data.min(), y_data.max()
        x_padding = (x_max - x_min) * 0.1
        y_padding = (y_max - y_min) * 0.1
        ax.set_xlim(x_min - x_padding, x_max + x_padding)
        ax.set_ylim(y_min - y_padding, y_max + y_padding)

    def expression_by_sample(pseudo, gene):
        # One value per pseudo-bulk sample, labelled by sample_id. `.X` is
        # densified only when it actually is a sparse matrix.
        values = pseudo[:, gene].X
        if hasattr(values, 'toarray'):
            values = values.toarray()
        return pd.Series(np.asarray(values).ravel(), index=pseudo.obs['sample_id'])

    def draw_joint(frame, x_col, y_col, title):
        # Scatter with marginal densities per treatment plus one overall regression line.
        grid = sns.jointplot(x=x_col, y=y_col, data=frame, hue='treatment', palette=palette,
                             kind='scatter', marginal_kws=dict(fill=True))
        sns.regplot(x=x_col, y=y_col, data=frame, scatter=False, ax=grid.ax_joint, color='black')
        adjust_plot_limits(grid.ax_joint, frame[x_col], frame[y_col])
        grid.ax_joint.tick_params(left=False, bottom=False)
        grid.ax_marg_x.tick_params(bottom=False)
        grid.ax_marg_y.tick_params(left=False)
        plt.suptitle(title, fontsize=title_fontsize)
        plt.subplots_adjust(top=0.95)
        plt.show()

    # Recode 0/1 to the palette keys on a private Series: the caller's frame (and
    # any global table) keeps its numeric coding, so later `== 0` / `== 1`
    # selections on it continue to work.
    treatment = DEG_exp['treatment'].map(lambda code: treatment_labels.get(code, code))

    # Plot 1: raw expression of g1 vs g2.
    df1 = pd.DataFrame({g1: DEG_exp[g1], g2: DEG_exp[g2], 'treatment': treatment}).dropna()
    cor = pearsonr(df1[g1], df1[g2])[0]
    draw_joint(df1, g1, g2, f'Correlation between {g1} and {g2} is {cor:.3f}')

    # Plot 2: residuals after removing the L * R signal. The product and the two
    # genes are aligned on sample id (the same way CIT_test builds its table), so
    # samples missing from either pseudo-bulk object are dropped instead of
    # silently shifting rows against each other.
    L_R = expression_by_sample(pseudo_1, L) * expression_by_sample(pseudo_2, R)
    df_lr = pd.DataFrame({
        g1: DEG_exp[g1],
        g2: DEG_exp[g2],
        'L*R': L_R,
        'treatment': treatment,
    }).dropna()
    X = sm.add_constant(df_lr['L*R'])

    df_res = pd.DataFrame({
        'G_1_residual': sm.OLS(df_lr[g1], X).fit().resid,
        'G_2_residual': sm.OLS(df_lr[g2], X).fit().resid,
        'treatment': df_lr['treatment'],
    })
    corr_res = pearsonr(df_res['G_1_residual'], df_res['G_2_residual'])[0]
    draw_joint(df_res, 'G_1_residual', 'G_2_residual',
               f'Correlation between residuals of G_1 and G_2 is {corr_res:.3f}')



In [None]:
# NOTE(review): 'GEM_T_3' / 'GEM_M_3' are not columns of gene_df as it is built in the
# visible notebook; presumably they belong to a GEM-level table (cf. gene_df_GEM) that is
# created elsewhere — TODO confirm which frame should be passed here.
plot_CIT_DEGcorr('GEM_T_3', 'GEM_M_3', 'ITGA4', 'TFRC', gene_df, adata_1_pseudo, adata_2_pseudo)