In [None]:
import os
import tarfile
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import gseapy as gp

# Relative locations: input data lives under <root>/data/nicole_lu and GSEA
# output goes to <root>/gsea_output, where <root> is two directories up.
project_root = os.path.join("..", "..")
base_data_path = os.path.join(project_root, "data", "nicole_lu")
gsea_output_path = os.path.join(project_root, "gsea_output")

# Inputs and intermediate files inside the data directory.
(
    tar_path,
    extract_path,
    deg_path,
    sig_gene_list_path,
    ranked_list_path,
) = (
    os.path.join(base_data_path, file_name)
    for file_name in (
        "GSE171524_RAW.tar",
        "GSE171524_RAW",
        "DESeq2_results_cov_vs_ctrl.csv",
        "significant_gene_list.txt",
        "ranked_gene_list.rnk",
    )
)

# Report CSV that is read back from the output directory after prerank runs.
gsea_report_path = os.path.join(gsea_output_path, "gseapy.gene_set.prerank.report.csv")

# Extract the GEO RAW TAR archive into extract_path.
# Use tarfile's 'data' extraction filter when this Python provides it
# (3.12+, plus security backports in later 3.8-3.11 patch releases). It
# rejects members with absolute paths, '..' components, or links that would
# land outside extract_path, instead of silently writing them to disk.
with tarfile.open(tar_path, 'r') as tar:
    if hasattr(tarfile, 'data_filter'):
        tar.extractall(path=extract_path, filter='data')
    else:
        # Older interpreter without extraction filters: unfiltered fallback.
        tar.extractall(path=extract_path)

# Load DESeq2 results. The gene identifiers are presumably an unnamed first
# column (e.g. from R's write.csv on a results table), which pandas reads as
# 'Unnamed: 0'; rename it to 'gene'. TODO confirm against the exporting script.
deg_df = pd.read_csv(deg_path)
deg_df.rename(columns={'Unnamed: 0': 'gene'}, inplace=True)

# Filter significant DEGs: adjusted p < 0.05 and |log2 fold change| > 1.
# Rows with a missing padj or log2FoldChange compare False and are excluded.
sig_degs = deg_df[(deg_df['padj'] < 0.05) & (deg_df['log2FoldChange'].abs() > 1)]

# Save significant gene list: one gene per line, no header.
sig_degs['gene'].to_csv(sig_gene_list_path, index=False, header=False)

# Create and save the ranked gene list used as GSEA prerank input.
# DESeq2 reports NA log2FoldChange for genes with zero counts in every sample.
# Those rows would be written to the .rnk file with a blank rank, so drop any
# row missing a gene or a fold change before sorting.
# NOTE(review): DESeq2's Wald 'stat' (or a shrunken LFC) is often preferred as
# the ranking metric; raw log2FoldChange is kept here.
ranked_list = (
    deg_df[['gene', 'log2FoldChange']]
    .dropna()
    .sort_values(by='log2FoldChange', ascending=False)
)
ranked_list.to_csv(ranked_list_path, sep="\t", index=False, header=False)

print("Files saved:")
print(f"- {sig_gene_list_path} ({len(sig_degs)} genes)")
print(f"- {ranked_list_path} ({len(ranked_list)} genes)")

# Run preranked GSEA on the ranked gene list against two gene-set libraries.
# The names are presumably Enrichr libraries resolved by gseapy; verify.
# Plots are produced as PNG, and the report CSV read later lives under
# gsea_output_path. seed=42 keeps the permutations reproducible.
# NOTE(review): permutation_num=100 is far below gseapy's default of 1000,
# which makes nominal p-values / FDR estimates coarse. Consider raising it
# for final results.
prerank_options = dict(
    gene_sets=['KEGG_2021_Human', 'GO_Biological_Process_2021'],
    outdir=gsea_output_path,
    permutation_num=100,
    seed=42,
    format='png',
)
gsea_results = gp.prerank(rnk=ranked_list_path, **prerank_options)

# Read back the prerank report CSV from the GSEA output directory.
gsea_df = pd.read_csv(gsea_report_path)

# Keep terms whose name mentions an immune/antiviral keyword
# (case-insensitive), then take the ten with the smallest FDR q-value.
immune_mask = gsea_df['Term'].str.contains(
    "immune|cytokine|interferon|inflamm|virus|coronavirus", case=False
)
immune_df = gsea_df[immune_mask]
top_immune = immune_df.sort_values(by='FDR q-val').head(10)

# Horizontal bar chart of NES per term. The y-axis is inverted so the
# smallest-FDR term is drawn at the top.
fig, ax = plt.subplots(figsize=(10, 6))
ax.barh(top_immune['Term'], top_immune['NES'], color='skyblue')
ax.set_xlabel("Normalized Enrichment Score (NES)")
ax.set_title("Top Enriched Immune-Related Pathways")
ax.invert_yaxis()
fig.tight_layout()
plt.show()

# Get top genes from enriched immune terms.
# Only terms with an FDR q-value at or below this cutoff contribute
# leading-edge genes; 0.25 is the conventional GSEA significance threshold.
# Set it to 1.0 to use every keyword-matching term (terms whose FDR is
# missing are still skipped).
lead_gene_fdr_cutoff = 0.25

# NOTE(review): this keyword pattern is narrower than the one used for the
# top-pathway plot (no "inflamm"/"virus"); confirm that is intentional.
is_immune_term = gsea_df['Term'].str.contains(
    "cytokine|interferon|immune|coronavirus", case=False
)
is_enriched = gsea_df['FDR q-val'] <= lead_gene_fdr_cutoff
target_terms = gsea_df[is_immune_term & is_enriched]

# Lead_genes holds ';'-separated gene symbols. pd.read_csv turns an empty
# cell into NaN, which would make str.join raise TypeError, so skip missing
# values and drop empty tokens. Sorting makes the list order deterministic.
genes_str = ";".join(target_terms['Lead_genes'].dropna().astype(str))
genes_list = sorted({g.strip() for g in genes_str.split(';') if g.strip()})

# Filter DEGs for those immune-related genes, most up-regulated first.
immune_gene_df = deg_df[deg_df['gene'].isin(genes_list)].sort_values(
    by='log2FoldChange', ascending=False
)

# Display full immune gene table (IPython/Jupyter `display`).
display(immune_gene_df)

# Plot immune gene log2 fold changes.
# Colour encodes the direction of change explicitly. Mapping a sequential
# palette onto hue='gene' colours bars by row order rather than by value.
# With rows sorted by descending log2FoldChange, that draws the most
# up-regulated genes blue and the most down-regulated genes red.
if immune_gene_df.empty:
    print("No immune-related genes to plot.")
else:
    plot_df = immune_gene_df.assign(
        direction=immune_gene_df['log2FoldChange']
        .gt(0)
        .map({True: 'Up in COVID', False: 'Down in COVID'})
    )
    # Grow the figure with the number of genes so y-axis labels stay legible.
    plt.figure(figsize=(10, max(6, 0.25 * len(plot_df))))
    sns.barplot(
        x='log2FoldChange',
        y='gene',
        data=plot_df,
        hue='direction',
        palette={'Up in COVID': 'firebrick', 'Down in COVID': 'steelblue'},
        dodge=False,
    )
    # Reference line at zero change.
    plt.axvline(0, color='grey', linewidth=0.8)
    plt.xlabel("Log2 Fold Change (COVID vs. Control)")
    plt.ylabel("Immune-related Gene")
    plt.title("Expression Changes in Immune Driver Genes")
    plt.tight_layout()
    plt.show()
