In [None]:
import pandas as pd
import numpy as np
import gseapy as gp
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import rcParams
import re

In [None]:
import pandas as pd
def _df_append(self, other, ignore_index=False, verify_integrity=False, sort=False):
    """Stand-in for ``DataFrame.append`` (removed in pandas 2.0), built on ``pd.concat``.

    It is patched onto ``pd.DataFrame`` below so older code that still calls
    ``df.append(...)`` keeps working. NOTE(review): presumably needed by gseapy;
    confirm which caller relies on it.

    Parameters
    ----------
    other : DataFrame, Series, dict, or a list/tuple of these
        Rows to append. A Series or dict is treated as ONE row, as the removed
        pandas method did (bare ``pd.concat`` would add it as a column instead).
    ignore_index : bool
        Renumber the result's index 0..n-1.
    verify_integrity : bool
        Raise ``ValueError`` if the combined index contains duplicates.
    sort : bool
        Sort the column axis when the inputs' columns differ.

    Returns
    -------
    DataFrame
        A new frame; ``self`` is not modified.
    """
    items = list(other) if isinstance(other, (list, tuple)) else [other]
    frames = [self]
    for item in items:
        if isinstance(item, dict):
            item = pd.Series(item)
        if isinstance(item, pd.Series):
            # One Series -> one row; infer_objects restores numeric dtypes after the transpose.
            item = item.to_frame().T.infer_objects()
        frames.append(item)
    return pd.concat(frames, ignore_index=ignore_index, verify_integrity=verify_integrity,
                     sort=sort, axis=0)


pd.DataFrame.append = _df_append

In [None]:
# Enrichr over-representation analysis of each group's top-attributed genes.
RESULT_DIR = "/home/wuqinhua/Project/PHASE_1r/AttnMoE_test/result/COVID"
GSEA_OUTDIR = f"{RESULT_DIR}/Analysis/Figure/GSEA_plot"
ENRICHR_GENE_SETS = ['GO_Biological_Process_2025', 'KEGG_2021_Human',
                     'Reactome_Pathways_2024', "COVID-19_Related_Gene_Sets_2021"]
TOP_N_GENES = 20


def run_group_enrichr(group_code, top_n=TOP_N_GENES):
    """Run Enrichr on one group's ``top_n`` highest-attribution genes.

    Parameters
    ----------
    group_code : str
        Suffix of the attribution file: ``'H'``, ``'M'`` or ``'S'``.
    top_n : int
        Number of genes, ranked by ``mean_attribution`` (descending), to submit.

    Returns
    -------
    tuple
        ``(gene_df, top_genes, enr)``: the attribution table, the submitted gene
        list, and the gseapy enrichr result object.
    """
    gene_df = pd.read_csv(f"{RESULT_DIR}/ensemble_gene_attributions_{group_code}.csv")
    top_genes = gene_df.sort_values('mean_attribution', ascending=False)['gene'].head(top_n).tolist()
    # One sub-folder per group. gseapy's report file names depend on the library name,
    # not on the gene list, so a shared outdir lets later groups overwrite earlier groups'
    # reports. This follows gseapy's output naming; confirm it for the installed version.
    enr = gp.enrichr(gene_list=top_genes,
                     gene_sets=ENRICHR_GENE_SETS,
                     organism='Human',
                     outdir=f"{GSEA_OUTDIR}/{group_code}",
                     cutoff=1)
    return gene_df, top_genes, enr


gene_df_H, top_genes_H, enr_H = run_group_enrichr('H')
gene_df_M, top_genes_M, enr_M = run_group_enrichr('M')
gene_df_S, top_genes_S, enr_S = run_group_enrichr('S')


In [None]:
# Tag each group's full Enrichr table with its group label so the three can be stacked.
# `assign` returns a new frame, so the enrichr result objects are left untouched.
enr_H_results = enr_H.results.assign(Group='Healthy')
enr_M_results = enr_M.results.assign(Group='Moderate')
enr_S_results = enr_S.results.assign(Group='Severe')

In [None]:
# Stack the three groups, then keep only rows whose adjusted p-value is below 0.05.
combined_results = pd.concat(
    [enr_H_results, enr_M_results, enr_S_results],
    ignore_index=True,
)
significant_results = combined_results.loc[combined_results['Adjusted P-value'].lt(0.05)]

In [None]:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns


# Prefer Arial for sans-serif text; draw minus signs as a plain ASCII hyphen.
rcParams.update({
    'font.sans-serif': ['Arial'],
    'axes.unicode_minus': False,
})

def build_pathway_heatmap_matrix(enr_H, enr_M, enr_S, gene_set_filter='GO_Biological_Process_2025',
                                 top_n=10, alpha=0.05):
    """Build the -log10(adjusted p) matrix behind the pathway heatmap.

    For each group, the ``top_n`` terms of library ``gene_set_filter`` with
    adjusted p < ``alpha`` are selected. The rows are the union of those terms, in
    order of first appearance (Healthy, then Moderate, then Severe). Each cell holds
    -log10(adjusted p) of the term in that group's full results for the library,
    or 0 when the term does not appear there.

    Returns
    -------
    heatmap_data : DataFrame
        Float matrix indexed by term, with columns Healthy / Moderate / Severe.
    all_pathways : ndarray
        The row terms.
    top_counts : dict
        Group name -> number of significant top terms selected for that group.
    """
    groups = {'Healthy': enr_H, 'Moderate': enr_M, 'Severe': enr_S}
    library_rows = {}
    top_terms = []
    top_counts = {}
    for name, enr in groups.items():
        res = enr.results
        rows = res[res['Gene_set'] == gene_set_filter]
        top = rows[rows['Adjusted P-value'] < alpha].nsmallest(top_n, 'Adjusted P-value')
        library_rows[name] = rows
        top_terms.append(top['Term'])
        top_counts[name] = len(top)
    all_pathways = pd.concat(top_terms).unique()

    columns = {}
    for name, rows in library_rows.items():
        # Terms are presumably unique within one library. Keeping the first row per term
        # guarantees reindex never sees duplicate labels.
        padj = rows.drop_duplicates('Term').set_index('Term')['Adjusted P-value'].astype(float)
        columns[name] = -np.log10(padj.reindex(all_pathways))  # NaN where the term is absent
    heatmap_data = pd.DataFrame(columns, index=all_pathways).fillna(0.0)
    return heatmap_data, all_pathways, top_counts


def create_top_pathways_heatmap(enr_H, enr_M, enr_S, gene_set_filter='GO_Biological_Process_2025', top_n=10,
                                alpha=0.05, savepath=None):
    """Plot -log10(adjusted p) of each group's top pathways as a heatmap and save it as a PDF.

    Parameters
    ----------
    enr_H, enr_M, enr_S : enrichr result objects
        Healthy / Moderate / Severe results. Each needs a ``.results`` DataFrame
        with 'Gene_set', 'Term' and 'Adjusted P-value' columns.
    gene_set_filter : str
        Library (value of 'Gene_set') to plot.
    top_n : int
        Number of significant terms taken from each group before the union.
    alpha : float
        Adjusted p-value threshold for a term to count as significant.
    savepath : str or None
        Output PDF path. ``None`` keeps the default location and file name.

    Returns
    -------
    tuple
        ``(heatmap_data, all_pathways)``. When no term passes ``alpha`` in any group,
        nothing is drawn and the empty matrix is returned.
    """
    heatmap_data, all_pathways, _ = build_pathway_heatmap_matrix(
        enr_H, enr_M, enr_S, gene_set_filter, top_n, alpha)
    if len(all_pathways) == 0:
        # seaborn cannot draw a heatmap with zero rows.
        print(f"No {gene_set_filter} terms with Adjusted P-value < {alpha}; heatmap skipped.")
        return heatmap_data, all_pathways

    # Drop a trailing parenthesised suffix (e.g. a GO ID) from each term for readability.
    row_labels = [re.sub(r'\s*\([^)]*\)$', '', term).strip() for term in heatmap_data.index]

    fig, ax = plt.subplots(figsize=(10, max(8, len(all_pathways) * 0.15)))
    # Pass the labels straight to seaborn. With the default yticklabels='auto', seaborn may
    # show only every k-th row, and relabelling those ticks afterwards would misname rows.
    sns.heatmap(heatmap_data,
                annot=True,
                fmt='.2f',
                cmap='Reds',
                cbar_kws={'label': '-log10(Adjusted P-value)'},
                linewidths=0.5,
                linecolor='white',
                yticklabels=row_labels,
                ax=ax)
    ax.set_title(f'Top {top_n} Pathways Heatmap\n{gene_set_filter}', fontsize=14, pad=20)
    ax.set_xlabel('Groups', fontsize=12)
    ax.set_ylabel('Pathways', fontsize=12)
    ax.tick_params(axis='y', labelrotation=0, labelsize=9)
    ax.tick_params(axis='x', labelrotation=0, labelsize=11)

    fig.tight_layout()
    if savepath is None:
        savepath = ("/home/wuqinhua/Project/PHASE_1r/AttnMoE_test/result/COVID/Analysis/Figure/GSEA_plot/"
                    f"top_{top_n}_pathways_heatmap_{gene_set_filter}_20.pdf")
    fig.savefig(savepath, dpi=300, bbox_inches='tight')
    plt.show()

    return heatmap_data, all_pathways


# GO Biological Process heatmap built from the top 5 significant terms of each group.
heatmap_go, pathways_go = create_top_pathways_heatmap(
    enr_H, enr_M, enr_S,
    gene_set_filter='GO_Biological_Process_2025',
    top_n=5,
)


## OUT

In [None]:
# Multi-group comparison with a dotplot: one x column per disease group.
# Imported here as well because the cell that imports dotplot comes later in the
# notebook; without this, the cell fails on Restart & Run All.
from gseapy import dotplot

if significant_results.empty:
    # dotplot cannot draw anything from an empty table.
    print("No terms with Adjusted P-value < 0.05 in any group; dotplot skipped.")
else:
    ax = dotplot(significant_results,
                 column="Adjusted P-value",
                 x='Group',  # one column per group
                 size=10,
                 top_term=10,  # show the 10 most significant terms
                 figsize=(8, 20),
                 title="GSEA Enrichment Comparison Across Groups",
                 xticklabels_rot=45,
                 show_ring=True,
                 marker='o')
    plt.tight_layout()
    plt.show()

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import rcParams

# Font fallback list for sans-serif text (SimHei first so Chinese glyphs can render),
# and plain ASCII hyphens for minus signs.
rcParams.update({
    'font.sans-serif': ['SimHei', 'Arial Unicode MS', 'DejaVu Sans'],
    'axes.unicode_minus': False,
})

def build_pathway_heatmap_matrix(enr_H, enr_M, enr_S, gene_set_filter='GO_Biological_Process_2025',
                                 top_n=10, alpha=0.05):
    """Build the -log10(adjusted p) matrix behind the pathway heatmap.

    For each group, the ``top_n`` terms of library ``gene_set_filter`` with
    adjusted p < ``alpha`` are selected. The rows are the union of those terms, in
    order of first appearance (Healthy, then Moderate, then Severe). Each cell holds
    -log10(adjusted p) of the term in that group's full results for the library,
    or 0 when the term does not appear there.

    Returns
    -------
    heatmap_data : DataFrame
        Float matrix indexed by term, with columns Healthy / Moderate / Severe.
    all_pathways : ndarray
        The row terms.
    top_counts : dict
        Group name -> number of significant top terms selected for that group.
    """
    groups = {'Healthy': enr_H, 'Moderate': enr_M, 'Severe': enr_S}
    library_rows = {}
    top_terms = []
    top_counts = {}
    for name, enr in groups.items():
        res = enr.results
        rows = res[res['Gene_set'] == gene_set_filter]
        top = rows[rows['Adjusted P-value'] < alpha].nsmallest(top_n, 'Adjusted P-value')
        library_rows[name] = rows
        top_terms.append(top['Term'])
        top_counts[name] = len(top)
    all_pathways = pd.concat(top_terms).unique()

    columns = {}
    for name, rows in library_rows.items():
        # Terms are presumably unique within one library. Keeping the first row per term
        # guarantees reindex never sees duplicate labels.
        padj = rows.drop_duplicates('Term').set_index('Term')['Adjusted P-value'].astype(float)
        columns[name] = -np.log10(padj.reindex(all_pathways))  # NaN where the term is absent
    heatmap_data = pd.DataFrame(columns, index=all_pathways).fillna(0.0)
    return heatmap_data, all_pathways, top_counts


def create_top_pathways_heatmap(enr_H, enr_M, enr_S, gene_set_filter='GO_Biological_Process_2025', top_n=10,
                                alpha=0.05, savepath=None):
    """
    Extract the top pathways of the three groups and plot them as a heatmap.

    Parameters:
    - enr_H, enr_M, enr_S: enrichr result objects of the three groups. Each needs a
      ``.results`` DataFrame with 'Gene_set', 'Term' and 'Adjusted P-value' columns.
    - gene_set_filter: the gene-set library (value of 'Gene_set') to analyse
    - top_n: number of top pathways taken from each group
    - alpha: adjusted p-value threshold for a pathway to count as significant
    - savepath: output PDF path; None keeps the default location and file name

    Returns:
    - (heatmap_data, all_pathways): the -log10(adjusted p) matrix and its row terms.
      If no pathway is significant in any group, nothing is plotted and the empty
      matrix is returned.
    """
    # Steps 1-3: select each group's top pathways, take their union, fill the matrix.
    heatmap_data, all_pathways, top_counts = build_pathway_heatmap_matrix(
        enr_H, enr_M, enr_S, gene_set_filter, top_n, alpha)

    print(f"Healthy组找到 {top_counts['Healthy']} 个显著通路")
    print(f"Moderate组找到 {top_counts['Moderate']} 个显著通路")
    print(f"Severe组找到 {top_counts['Severe']} 个显著通路")
    print(f"\n总共发现 {len(all_pathways)} 个独特通路")

    if len(all_pathways) == 0:
        # seaborn cannot draw a heatmap with zero rows.
        print(f"No {gene_set_filter} terms with Adjusted P-value < {alpha}; heatmap skipped.")
        return heatmap_data, all_pathways

    # Step 4: draw the heatmap. Pathway names longer than 60 characters are truncated.
    pathway_labels = [pathway[:60] + '...' if len(pathway) > 60 else pathway
                      for pathway in heatmap_data.index]

    fig, ax = plt.subplots(figsize=(10, max(8, len(all_pathways) * 0.4)))
    # Pass the labels straight to seaborn. With the default yticklabels='auto', seaborn may
    # show only every k-th row, and relabelling those ticks afterwards would misname rows.
    sns.heatmap(heatmap_data,
                annot=True,   # show the values
                fmt='.2f',
                cmap='Reds',
                cbar_kws={'label': '-log10(Adjusted P-value)'},
                linewidths=0.5,
                linecolor='white',
                yticklabels=pathway_labels,
                ax=ax)

    ax.set_title(f'Top {top_n} Pathways Heatmap\n{gene_set_filter}', fontsize=14, pad=20)
    ax.set_xlabel('Groups', fontsize=12)
    ax.set_ylabel('Pathways', fontsize=12)
    ax.tick_params(axis='y', labelrotation=0, labelsize=9)
    ax.tick_params(axis='x', labelrotation=0, labelsize=11)

    fig.tight_layout()
    if savepath is None:
        savepath = ("/home/wuqinhua/Project/PHASE_1r/AttnMoE_test/result/COVID/Analysis/Figure/GSEA_plot/"
                    f"top_{top_n}_pathways_heatmap_{gene_set_filter}_150.pdf")
    fig.savefig(savepath, dpi=300, bbox_inches='tight')

    plt.show()

    return heatmap_data, all_pathways

# Heatmap for GO Biological Process
print("=== GO Biological Process Top 10 通路热图 ===")
heatmap_go, pathways_go = create_top_pathways_heatmap(
    enr_H, enr_M, enr_S, gene_set_filter='GO_Biological_Process_2025', top_n=10)

# Heatmap for KEGG pathways
print("\n=== KEGG Top 10 通路热图 ===")
heatmap_kegg, pathways_kegg = create_top_pathways_heatmap(
    enr_H, enr_M, enr_S, gene_set_filter='KEGG_2021_Human', top_n=10)

# Heatmap for Reactome pathways
print("\n=== Reactome Top 10 通路热图 ===")
heatmap_reactome, pathways_reactome = create_top_pathways_heatmap(
    enr_H, enr_M, enr_S, gene_set_filter='Reactome_Pathways_2024', top_n=10)

In [None]:
# List the Enrichr gene-set library names gseapy can query (default organism: Human),
# e.g. to check the library names passed as gene_sets elsewhere in this notebook.
names = gp.get_library_name()
names

In [None]:
# Per-gene attribution scores for the Healthy group, read from the Model_result run.
# NOTE(review): this directory (AttnMOE/Code/Result/COVID/Model_result) differs from the
# AttnMoE_test/result/COVID files used earlier; confirm both runs are meant to be compared.
gene_list = pd.read_csv(
    "/home/wuqinhua/Project/PHASE_1r/AttnMOE/Code/Result/COVID/Model_result/ensemble_gene_attributions_H.csv"
)
gene_list

In [None]:
# Rank genes by mean attribution (highest first) and keep the first 100.
# NOTE(review): the variable name says 50, but 100 genes are taken; confirm which is intended.
top_50_genes = (
    gene_list
    .sort_values('mean_attribution', ascending=False)
    .head(100)['gene']
    .tolist()
)
print(top_50_genes)

In [None]:
# Enrichr on the Healthy top-attributed genes (Model_result run).
# Own sub-folder per group: gseapy's report file names depend on the library name, not on
# the gene list, so a shared outdir lets the Moderate/Severe runs overwrite these reports.
# This follows gseapy's output naming; confirm it for the installed version.
enr = gp.enrichr(gene_list=top_50_genes,
                 gene_sets=['GO_Biological_Process_2025', 'KEGG_2021_Human', 'Reactome_Pathways_2024', "COVID-19_Related_Gene_Sets_2021"],
                 organism='Human',
                 outdir='/home/wuqinhua/Project/PHASE_1r/AttnMOE/Code/Result/COVID/Analysis/Figure/GSEA_plot/H',
                 cutoff=1,
                 )

In [None]:
# Enrichr results table for the Healthy top genes, covering every requested library.
enr.results

In [None]:
from gseapy import barplot, dotplot

# Categorical scatterplot of the Healthy-group results with one x column per library
# (x='Gene_set'). The title names the group: the plot spans all four requested
# libraries, not only KEGG.
ax = dotplot(enr.results,
             column="Adjusted P-value",
             x='Gene_set',  # one column per library, for a multi-library comparison
             size=10,
             top_term=5,
             figsize=(5, 10),
             title="Healthy: top terms by library",
             xticklabels_rot=45,  # rotate x tick labels
             show_ring=True,  # set to False to remove the outer ring
             marker='o',
             )

In [None]:
# Moderate group (Model_result run): Enrichr on the top-attributed genes, then a per-library dotplot.
# A group-specific name is used so the Healthy `gene_list` read earlier is not clobbered.
# NOTE(review): this reassigns enr_M, which the heatmap cells above also use. Re-running those
# cells afterwards will plot this result instead; confirm that is intended.
# NOTE(review): the variable name says 50, but 100 genes are taken; confirm which is intended.
gene_list_M = pd.read_csv("/home/wuqinhua/Project/PHASE_1r/AttnMOE/Code/Result/COVID/Model_result/ensemble_gene_attributions_M.csv")
top_50_genes_M = gene_list_M.sort_values('mean_attribution', ascending=False)['gene'].head(100).tolist()
# Own sub-folder per group: gseapy's report file names depend on the library name, not on
# the gene list, so a shared outdir lets groups overwrite each other's reports.
# This follows gseapy's output naming; confirm it for the installed version.
enr_M = gp.enrichr(gene_list=top_50_genes_M,
                   gene_sets=['GO_Biological_Process_2025', 'KEGG_2021_Human', 'Reactome_Pathways_2024', "COVID-19_Related_Gene_Sets_2021"],
                   organism='Human',
                   outdir='/home/wuqinhua/Project/PHASE_1r/AttnMOE/Code/Result/COVID/Analysis/Figure/GSEA_plot/M',
                   cutoff=1,
                   )
# x='Gene_set' puts every requested library on the x-axis, so the title names the group, not KEGG.
ax = dotplot(enr_M.results,
             column="Adjusted P-value",
             x='Gene_set',  # one column per library, for a multi-library comparison
             size=10,
             top_term=5,
             figsize=(5, 10),
             title="Moderate: top terms by library",
             xticklabels_rot=45,  # rotate x tick labels
             show_ring=True,  # set to False to remove the outer ring
             marker='o',
             )

In [None]:
# Severe group (Model_result run): Enrichr on the top-attributed genes, then a per-library dotplot.
# A group-specific name is used so the Healthy `gene_list` read earlier is not clobbered.
# NOTE(review): this reassigns enr_S, which the heatmap cells above also use. Re-running those
# cells afterwards will plot this result instead; confirm that is intended.
# NOTE(review): the variable name says 50, but 100 genes are taken; confirm which is intended.
gene_list_S = pd.read_csv("/home/wuqinhua/Project/PHASE_1r/AttnMOE/Code/Result/COVID/Model_result/ensemble_gene_attributions_S.csv")
top_50_genes_S = gene_list_S.sort_values('mean_attribution', ascending=False)['gene'].head(100).tolist()
# Own sub-folder per group: gseapy's report file names depend on the library name, not on
# the gene list, so a shared outdir lets groups overwrite each other's reports.
# This follows gseapy's output naming; confirm it for the installed version.
enr_S = gp.enrichr(gene_list=top_50_genes_S,
                   gene_sets=['GO_Biological_Process_2025', 'KEGG_2021_Human', 'Reactome_Pathways_2024', "COVID-19_Related_Gene_Sets_2021"],
                   organism='Human',
                   outdir='/home/wuqinhua/Project/PHASE_1r/AttnMOE/Code/Result/COVID/Analysis/Figure/GSEA_plot/S',
                   cutoff=1,
                   )
# x='Gene_set' puts every requested library on the x-axis, so the title names the group, not KEGG.
ax = dotplot(enr_S.results,
             column="Adjusted P-value",
             x='Gene_set',  # one column per library, for a multi-library comparison
             size=10,
             top_term=5,
             figsize=(5, 10),
             title="Severe: top terms by library",
             xticklabels_rot=45,  # rotate x tick labels
             show_ring=True,  # set to False to remove the outer ring
             marker='o',
             )

In [None]:
# Severe group, KEGG only. `.results` holds every requested library (note the Gene_set
# column), so filter it to KEGG to make the plot match its title.
kegg_results_S = enr_S.results[enr_S.results['Gene_set'] == 'KEGG_2021_Human']
if kegg_results_S.empty:
    print("No KEGG_2021_Human terms for the Severe group; dotplot skipped.")
else:
    ax = dotplot(kegg_results_S, title='KEGG_2021_Human', cmap='viridis_r', size=10, figsize=(3, 5))
