# Load packages

In [1]:
import pandas as pd
import numpy as np
import os
os.chdir('/root/host_home')
import glob
import matplotlib.pyplot as plt
# import matplotlib.colors as mcolors # Not directly used
import random # For mock pathway file creation

# gseapy and decoupler imports
import gseapy as gs
import decoupler as dc
from gprofiler import GProfiler
from scipy.stats import hypergeom
from statsmodels.stats.multitest import multipletests # For decoupler and integration


# networkx import for network plots
import networkx as nx

# adjustText import (optional, for Manhattan plots)
try:
    from adjustText import adjust_text
    ADJUST_TEXT_AVAILABLE = True
except ImportError:
    print("Warning: adjustText library not found. Labels on Manhattan plots might overlap.")
    ADJUST_TEXT_AVAILABLE = False

# Global variables

In [2]:
# Global Variables
# --- Input / output locations (relative to the current working directory) ---
save_dir = 'hp_NPCs/SCENIC/'
DEG_dir = 'hp_NPCs/tables/diffxpy/'
output_dir = os.path.join(save_dir, 'enrichment_analysis/')
# Create the root output folder up front; the functions below add
# "<tool>/tables" and "<tool>/figures" sub-folders underneath it.
os.makedirs(output_dir, exist_ok=True)

# Cell-type labels of the analysis. load_pathways() opens the Excel sheet
# whose name equals the cell type it is given.
cell_types = [
    'NSC1a',
    'NSC1b',
    'NSC2a',
    'NSC2b',
    'Apop.-NSC',
    'NCSC',
    'Apop.-NCSC',
    'Glial-precursors',
    'Immature-neurons',
    'bulk_like',
]
# NOTE(review): presumably the gene-symbol column of the DEG tables -- confirm against the caller.
DEG_SYMBOL_COL = 'gene_symbols'

# (database key, tag used in the workbook file name). The key and the tag
# differ for GO_BP, so the mapping is spelled out rather than derived.
_PATHWAY_FILE_TAGS = (
    ('GO_BP', 'GO'),
    ('KEGG', 'KEGG'),
    ('WP', 'WP'),
    ('Reactome', 'Reactome'),
    ('mKEGG', 'mKEGG'),
)
# Database key -> path of the pre-computed pathway enrichment workbook.
PATHWAY_FILES = {
    db_key: os.path.join(DEG_dir, f'cell_types_DEGs_diffxpy_pathways_{file_tag}_terms.xlsx')
    for db_key, file_tag in _PATHWAY_FILE_TAGS
}

# Column names load_pathways() requires in each pathway sheet.
PATHWAY_ID_COL = 'ID'
PATHWAY_NAME_COL = 'Description'
PATHWAY_PADJ_COL = 'qvalue'
PATHWAY_GENES_COL = 'geneID'

# Thresholds and plot limits (used as filters and as default plot arguments below).
SIGNIFICANCE_CUTOFF = 0.05   # pathway q-value filter; default Manhattan threshold line
LABEL_QVAL_THRESHOLD = 0.01  # default: only links below this adjusted p-value get a text label
TOP_N_LABEL = 20             # default maximum number of labelled points per Manhattan plot
TOP_N_PIE = 10               # default number of TFs given their own slice in the coverage pie

# Functions

## Enrichment

In [3]:
# --- TF ENRICHMENT FUNCTIONS ---

def run_gseapy_enrichment(deg_list, cell_type, tf_libs, organism='Human', sig_cutoff=0.05):
    """Run an Enrichr query through gseapy and keep the significant terms.

    Args:
        deg_list: Gene symbols to test; an empty list skips the query.
        cell_type: Label used in log messages and in the output file name.
        tf_libs: Enrichr gene-set libraries, passed to ``gs.enrichr`` as ``gene_sets``.
        organism: Passed through to ``gs.enrichr``.
        sig_cutoff: Terms are kept when ``'Adjusted P-value'`` is below this
            value; it is also forwarded to ``gs.enrichr`` as ``cutoff``.

    Returns:
        DataFrame with columns ``Term``, ``Genes`` and ``Adjusted P-value``,
        sorted by adjusted p-value. It is empty (same columns) when the gene
        list is empty, nothing is significant, or the query raises.

    Side effects:
        Creates ``<output_dir>/gseapy/tables`` and ``.../figures``; writes
        ``gseapy_<cell_type>_sig_tf.csv`` into the tables folder when at least
        one term is significant.
    """
    tool = "gseapy"
    wanted_columns = ['Term', 'Genes', 'Adjusted P-value']

    # --- Create directory structure (tables first, then figures) ---
    tables_path = os.path.join(output_dir, tool, "tables")
    figures_path = os.path.join(output_dir, tool, "figures")
    for folder in (tables_path, figures_path):
        os.makedirs(folder, exist_ok=True)

    if not deg_list:
        print(f"Warning: Empty DEG list for {cell_type}. Skipping gseapy.")
        return pd.DataFrame(columns=wanted_columns)

    print(f"Running gseapy for {cell_type} with {len(deg_list)} genes. Sources: {tf_libs}.")
    try:
        enrichr_run = gs.enrichr(
            gene_list=deg_list,
            gene_sets=tf_libs,
            organism=organism,
            outdir=None,
            cutoff=sig_cutoff,
            no_plot=True,
        )
        all_terms = enrichr_run.results

        # Discard terms whose name mentions mouse (case-insensitive match).
        mentions_mouse = all_terms['Term'].str.contains('mouse|Mouse', case=False, na=False)
        kept_terms = all_terms[~mentions_mouse]

        below_cutoff = kept_terms['Adjusted P-value'] < sig_cutoff
        significant = kept_terms[below_cutoff].sort_values('Adjusted P-value')

        if significant.empty:
            print(f"  Gseapy returned no enrichment results for {cell_type} with sources {tf_libs}.")
            return pd.DataFrame(columns=wanted_columns)

        csv_path = os.path.join(tables_path, f"{tool}_{cell_type}_sig_tf.csv")
        significant.to_csv(csv_path, index=False)
        print(f"  Gseapy analysis for {cell_type} complete. {len(significant)} terms found.")
        return significant[wanted_columns].copy()

    except Exception as e:
        # Best effort: a failed query is reported and treated as "no result".
        print(f"  Error during gseapy for {cell_type}: {e}")
        return pd.DataFrame(columns=wanted_columns)


## TF-Pathway interaction

In [4]:
# --- PATHWAY ENRICHMENT LOADING ---
def load_pathways(cell_type, pathway_db_key='GO_BP'):
    """Load the significant pre-computed pathway terms of one database for one cell type.

    Args:
        cell_type: Name of the Excel sheet to read.
        pathway_db_key: Key into ``PATHWAY_FILES`` selecting the workbook.

    Returns:
        DataFrame with columns ``Term``, ``Description``, ``Genes``,
        ``Adjusted P-value`` and ``Database``. ``Genes`` is the source gene
        column cast to ``str`` with ``'/'`` replaced by ``';'``. An empty
        frame with the same columns is returned when the key is unknown, the
        sheet or a required column is missing, no row is below
        ``SIGNIFICANCE_CUTOFF``, or any other error occurs while loading.
    """
    output_columns = ['Term', 'Description', 'Genes', 'Adjusted P-value', 'Database']
    no_result = pd.DataFrame(columns=output_columns)

    if pathway_db_key not in PATHWAY_FILES:
        return no_result
    workbook_path = PATHWAY_FILES[pathway_db_key]

    try:
        try:
            sheet_df = pd.read_excel(workbook_path, sheet_name=cell_type)
        except ValueError:
            # pandas reports a missing worksheet as ValueError.
            print(f"  Sheet '{cell_type}' not found in {workbook_path}.")
            return no_result

        needed = (PATHWAY_PADJ_COL, PATHWAY_GENES_COL, PATHWAY_NAME_COL, PATHWAY_ID_COL)
        if any(column not in sheet_df.columns for column in needed):
            print(f"  Error: Missing cols in {workbook_path}")
            return no_result

        is_significant = sheet_df[PATHWAY_PADJ_COL] < SIGNIFICANCE_CUTOFF
        significant = sheet_df[is_significant].copy()
        if significant.empty:
            print(f"  No significant pathways for {pathway_db_key} in {cell_type}.")
            return no_result

        significant['Database'] = pathway_db_key
        column_map = {
            PATHWAY_ID_COL: 'Term',
            PATHWAY_NAME_COL: 'Description',
            PATHWAY_GENES_COL: 'Genes_temp',
            PATHWAY_PADJ_COL: 'Adjusted P-value',
        }
        renamed = significant.rename(columns=column_map)
        # Source sheets separate genes with '/'; the rest of the notebook splits on ';'.
        renamed['Genes'] = renamed['Genes_temp'].astype(str).str.replace('/', ';')
        return renamed[output_columns].copy()
    except Exception as e:
        print(f"  Error loading {pathway_db_key} for {cell_type}: {e}")
        return no_result


# --- TF-PATHWAY INTEGRATION ---
def integrate_tf_path(sig_tfs_df, sig_paths_df, cell_type, bg_genes, tf_tool, interaction_qval_cutoff=0.05, overlap_fraction_threshold=0.1):
    """Link significant TFs to significant pathways through their shared genes.

    Every TF/pathway pair that shares at least one background gene is scored
    with a hypergeometric upper-tail test. The p-values are then
    Benjamini-Hochberg adjusted, and links are kept when the adjusted p-value
    is below ``interaction_qval_cutoff`` AND the overlap exceeds
    ``overlap_fraction_threshold`` of either the pathway or the TF gene set.

    Args:
        sig_tfs_df: TF table; rows are read via ``Term``, ``Genes``
            (';'-separated) and ``Adjusted P-value``.
        sig_paths_df: Pathway table; rows are read via ``Term``,
            ``Description``, ``Database``, ``Genes`` (';'-separated) and
            ``Adjusted P-value``.
        cell_type: Label used in log messages and in the output file name.
        bg_genes: Iterable of background genes (``None`` is treated as empty).
            All gene sets are restricted to this background before testing.
        tf_tool: Name of the TF method; selects the output sub-folder.
        interaction_qval_cutoff: FDR threshold for keeping a link.
        overlap_fraction_threshold: Minimum overlap fraction (vs. pathway or
            vs. TF gene set) for keeping a link.

    Returns:
        DataFrame of retained links sorted by ``interaction_Adj_Pval``, or an
        empty DataFrame when an input is empty or no pair overlaps.

    Side effects:
        Creates ``<output_dir>/<tf_tool>/tables`` and writes
        ``<tf_tool>_<cell_type>_tf_path_links.csv`` there (also when the
        filtered table is empty).
    """
    print(f"--- Integrating TF ({tf_tool}) & Pathways for {cell_type} ---")
    integ_out_dir = os.path.join(output_dir, tf_tool, "tables")
    os.makedirs(integ_out_dir, exist_ok=True)

    if sig_tfs_df.empty or sig_paths_df.empty:
        print("  Skipping: Empty TFs/Pathways.")
        return pd.DataFrame()

    # Build the set first so that any iterable (list, set, array, Series) works.
    bg_set = set(bg_genes) if bg_genes is not None else set()
    if not bg_set:
        print("  Warning: Empty background for integration.")

    def _background_genes(cell):
        """Split a ';'-separated gene cell and keep only background genes."""
        if cell is None or (isinstance(cell, float) and np.isnan(cell)):
            return set()  # a missing cell must not become the gene 'nan'
        text = str(cell)
        return bg_set.intersection(text.split(';')) if text else set()

    # Parse every pathway once instead of once per TF.
    pathways = []
    for _, p_r in sig_paths_df.iterrows():
        p_members = _background_genes(p_r.get('Genes', ''))
        if p_members:
            pathways.append((
                p_r.get('Term', '?PathID'),
                p_r.get('Description', '?PathN'),
                p_r.get('Database', '?DB'),
                p_r.get('Adjusted P-value', np.nan),
                p_members,
            ))

    M = len(bg_set)  # population size: all background genes
    links = []
    for _, tf_r in sig_tfs_df.iterrows():
        tf_targets = _background_genes(tf_r.get('Genes', ''))
        if not tf_targets:
            continue
        tf_n = tf_r.get('Term', '?TF')
        tf_padj = tf_r.get('Adjusted P-value', np.nan)
        n_tf = len(tf_targets)  # "success" genes in the population: TF-related genes
        for p_id, p_n, p_db, p_padj, p_members in pathways:
            shared = tf_targets & p_members
            if not shared:
                continue
            k_o_bg = len(shared)   # observed overlap
            N_p = len(p_members)   # number of draws: pathway genes
            # P(overlap >= k): sf(k - 1) is the inclusive upper tail.
            pval = hypergeom.sf(k_o_bg - 1, M, n_tf, N_p)
            links.append({'TF_Tool': tf_tool,
                          'TF': tf_n,
                          'Pathway_DB': p_db,
                          'Pathway_ID': p_id,
                          'Pathway_Name': p_n,
                          'TF_Adj_Pval': tf_padj,
                          'Pathway_Adj_Pval': p_padj,
                          'tf_regulon_size': n_tf,
                          'pathway_size': N_p,
                          'overlap_size': k_o_bg,
                          'overlap_frac_vs_pathway': k_o_bg / N_p,
                          'overlap_frac_vs_tf': k_o_bg / n_tf,
                          'Shared_Genes': ';'.join(sorted(shared)),
                          'interaction_Pval': pval})

    if not links:
        print(f"  No links found for {cell_type} ({tf_tool}).")
        return pd.DataFrame()

    links_df = pd.DataFrame(links)
    links_df.dropna(subset=['interaction_Pval'], inplace=True)
    if links_df.empty:
        return pd.DataFrame()

    # Index 1 of the multipletests result holds the corrected p-values.
    p_adj = multipletests(links_df['interaction_Pval'], method='fdr_bh', alpha=interaction_qval_cutoff)[1]
    links_df = links_df.assign(interaction_Adj_Pval=p_adj).sort_values(by='interaction_Adj_Pval')

    is_significant = links_df['interaction_Adj_Pval'] < interaction_qval_cutoff
    has_enough_overlap = (
        (links_df['overlap_frac_vs_pathway'] > overlap_fraction_threshold)
        | (links_df['overlap_frac_vs_tf'] > overlap_fraction_threshold)
    )
    links_df = links_df[is_significant & has_enough_overlap]

    links_df.to_csv(os.path.join(integ_out_dir, f"{tf_tool}_{cell_type}_tf_path_links.csv"), index=False)
    print(f"  Found {len(links_df)} TF-Pathway links for {cell_type} ({tf_tool}).")
    return links_df

## Plotting

In [5]:
def plot_manhattan(links_df, cell_type, tf_tool, sig_thr=SIGNIFICANCE_CUTOFF, n_lab=TOP_N_LABEL, lab_thr=LABEL_QVAL_THRESHOLD):
    """Draw a Manhattan-style scatter of TF-pathway links and save it as SVG and PDF.

    Each link is one point. Links are ordered along x by pathway database and
    label; y is ``-log10(interaction_Adj_Pval)``. Points are coloured by
    ``Pathway_DB``. Marker size scales with the number of shared genes, read
    from ``N_Shared`` or, when that column is absent, from ``overlap_size``
    (the column ``integrate_tf_path`` writes); without a usable column all
    markers have the same size.

    Args:
        links_df: Link table; needs ``interaction_Adj_Pval``, ``Pathway_DB``,
            ``TF`` and ``Pathway_Name``.
        cell_type: Label used in the title and the file name.
        tf_tool: TF method name; selects ``<output_dir>/<tf_tool>/figures``.
        sig_thr: Adjusted p-value drawn as a horizontal dashed line.
        n_lab: Maximum number of points that receive a text label.
        lab_thr: Only links with an adjusted p-value below this are labelled.

    Returns:
        None. Nothing is drawn when the table is empty, lacks the required
        columns, or holds no numeric adjusted p-value.
    """
    if links_df.empty or 'interaction_Adj_Pval' not in links_df.columns or 'Pathway_DB' not in links_df.columns:
        print(f"  Manhattan: Invalid data for {cell_type} ({tf_tool}). Missing 'interaction_Adj_Pval' or 'Pathway_DB'. Skipping.")
        return
    print(f"--- Plotting Manhattan: {cell_type} ({tf_tool}) ---")
    plot_dir = os.path.join(output_dir, tf_tool, "figures")
    os.makedirs(plot_dir, exist_ok=True)

    df_plot_ready = links_df.copy()
    df_plot_ready['interaction_Adj_Pval'] = pd.to_numeric(df_plot_ready['interaction_Adj_Pval'], errors='coerce')
    df_plot_ready = df_plot_ready.dropna(subset=['interaction_Adj_Pval']).copy()
    if df_plot_ready.empty:
        print("  Manhattan: No valid interaction_Adj_Pval values after coercion.")
        return

    # The tiny offset keeps log10 finite when an adjusted p-value is exactly 0.
    df_plot_ready['-logP_Overlap'] = -np.log10(df_plot_ready['interaction_Adj_Pval'] + np.finfo(float).tiny)
    df_plot_ready['Label'] = df_plot_ready['TF'].astype(str) + " - " + df_plot_ready['Pathway_Name'].astype(str)
    df_plot_ready = df_plot_ready.sort_values(by=['Pathway_DB', 'Label'])
    df_plot_ready['x'] = range(len(df_plot_ready))

    dbs = sorted(df_plot_ready['Pathway_DB'].unique())
    cmap = plt.colormaps['tab10'] if len(dbs) <= 10 else plt.colormaps['tab20']
    cols = cmap(np.linspace(0, 1, len(dbs)))
    db_cols = {db: cols[i] for i, db in enumerate(dbs)}
    df_plot_ready['color'] = df_plot_ready['Pathway_DB'].map(db_cols)

    # Marker size from the shared-gene count; 'overlap_size' is the fallback column.
    size_col = next((c for c in ('N_Shared', 'overlap_size') if c in df_plot_ready.columns), None)
    if size_col is not None and df_plot_ready[size_col].notna().all() and df_plot_ready[size_col].nunique() > 1:
        min_s, max_s = df_plot_ready[size_col].min(), df_plot_ready[size_col].max()
        scaled = 150 * (df_plot_ready[size_col] - min_s + 1e-6) / (max_s - min_s + 1e-6)
        df_plot_ready['size'] = np.maximum(20, scaled)
    else:
        df_plot_ready['size'] = 50

    figsize_width = min(100, max(15, len(df_plot_ready) * 0.05))  # Capped width
    fig = plt.figure(figsize=(figsize_width, 7))
    try:
        plt.axhline(-np.log10(sig_thr), c='r', ls='--', lw=1, label=f'Adj. P-val = {sig_thr}')
        plt.scatter(df_plot_ready['x'], df_plot_ready['-logP_Overlap'], c=df_plot_ready['color'],
                    s=df_plot_ready['size'], alpha=.7, ec='k', lw=.5)

        lab_df = (df_plot_ready[df_plot_ready['interaction_Adj_Pval'] < lab_thr]
                  .sort_values(by='-logP_Overlap', ascending=False)
                  .head(n_lab))
        txts = [plt.text(r['x'], r['-logP_Overlap'],
                         f"{str(r['TF'])[:10]}...\n{str(r['Pathway_Name'])[:15]}...",
                         fontsize=5, ha='center')
                for _, r in lab_df.iterrows()]
        if ADJUST_TEXT_AVAILABLE and txts:
            adjust_text(txts, arrowprops=dict(arrowstyle='-', color='grey', lw=.3))

        plt.xticks([])
        plt.xlabel('TF-Pathway Links (Grouped by Pathway DB)', fontsize=9)
        plt.ylabel('-log10 (TF-Pathway Overlap Adj. P-value)', fontsize=9)
        plt.title(f'TF-Pathway Links: {cell_type} ({tf_tool})', fontsize=11)
        plt.ylim(bottom=0)
        plt.grid(axis='y', ls='--', alpha=.6)
        hndls = [plt.Line2D([0], [0], marker='o', c='w', label=db, markerfacecolor=c, ms=7) for db, c in db_cols.items()]
        plt.legend(handles=hndls, title="Pathway DB", bbox_to_anchor=(1.02, 1), loc='upper left', fontsize=7)
        plt.tight_layout(rect=[0, 0, 0.88, 1])
        for ext in ['svg', 'pdf']:
            fpath = os.path.join(plot_dir, f"{tf_tool}_{cell_type}_tf_path_manhattan.{ext}")
            plt.savefig(fpath)
            print(f"  Saved Manhattan: {fpath}")
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close(fig)

def plot_tf_path_network(links_df, cell_type, tf_tool, top_n_links=30):
    """Draw the strongest TF-pathway links as a two-colour network and save it as SVG and PDF.

    The ``top_n_links`` links with the smallest ``interaction_Adj_Pval`` are
    drawn. TF nodes are sky blue and pathway nodes light green; node size
    grows with node degree, edge width with ``-log10(interaction_Adj_Pval)``.
    When a TF/pathway-name pair occurs more than once, the edge keeps the
    strongest score.

    Args:
        links_df: Link table with ``TF``, ``Pathway_Name`` and
            ``interaction_Adj_Pval`` columns.
        cell_type: Label used in the title and the file name.
        tf_tool: TF method name; selects ``<output_dir>/<tf_tool>/figures``.
        top_n_links: Number of links to draw.

    Returns:
        None. Nothing is drawn when the table is empty or lacks a required column.
    """
    if links_df.empty or not all(c in links_df.columns for c in ['TF', 'Pathway_Name', 'interaction_Adj_Pval']):
        print(f"  Network: Invalid data for {cell_type} ({tf_tool}). Skipping.")
        return
    print(f"--- Plotting TF-Pathway Network: {cell_type} ({tf_tool}) ---")
    plot_dir = os.path.join(output_dir, tf_tool, "figures")
    os.makedirs(plot_dir, exist_ok=True)

    plot_df = links_df.nsmallest(top_n_links, 'interaction_Adj_Pval').copy()
    if plot_df.empty:
        print(f"  Network: No links after filtering top {top_n_links}.")
        return

    # The tiny offset keeps log10 finite when an adjusted p-value is exactly 0.
    plot_df['-logP_Overlap'] = -np.log10(plot_df['interaction_Adj_Pval'] + np.finfo(float).tiny)

    B = nx.Graph()
    B.add_nodes_from(plot_df['TF'].unique(), bipartite=0, type='TF')
    B.add_nodes_from(plot_df['Pathway_Name'].unique(), bipartite=1, type='Pathway')

    for tf_name, path_name, score in zip(plot_df['TF'], plot_df['Pathway_Name'], plot_df['-logP_Overlap']):
        # A repeated pair must not overwrite a stronger edge with a weaker one.
        if not B.has_edge(tf_name, path_name) or B[tf_name][path_name]['weight'] < score:
            B.add_edge(tf_name, path_name, weight=score)

    if B.number_of_edges() == 0:
        print(f"  Network: No edges formed for {cell_type} ({tf_tool}).")
        return

    fig = plt.figure(figsize=(12, 10))
    try:
        pos = nx.spring_layout(B, k=0.6, iterations=70, seed=42)

        tf_nodes = [n for n, d in B.nodes(data=True) if d['bipartite'] == 0]
        path_nodes = [n for n, d in B.nodes(data=True) if d['bipartite'] == 1]

        degree_of = dict(B.degree())
        min_degree, max_degree = min(degree_of.values()), max(degree_of.values())

        def _node_size(node):
            """Map node degree to 100..1600; mid size when all degrees are equal."""
            if max_degree > min_degree:
                return 100 + 1500 * (degree_of[node] - min_degree) / (max_degree - min_degree + 1e-6)
            return 100 + 1500 * 0.5

        nx.draw_networkx_nodes(B, pos, nodelist=tf_nodes, node_color='skyblue',
                               node_size=[_node_size(n) for n in tf_nodes], alpha=0.9)
        nx.draw_networkx_nodes(B, pos, nodelist=path_nodes, node_color='lightgreen',
                               node_size=[_node_size(n) for n in path_nodes], alpha=0.9)

        # Read the weights in the graph's own edge order and pass that order
        # explicitly, so width i always belongs to edge i.
        edge_list = list(B.edges())
        edge_scores = [B[u][v]['weight'] for u, v in edge_list]
        min_w, max_w = min(edge_scores), max(edge_scores)
        if max_w > min_w:
            norm_widths = [0.5 + 3 * (w - min_w) / (max_w - min_w + 1e-6) for w in edge_scores]
        else:
            norm_widths = [1] * len(edge_scores)

        nx.draw_networkx_edges(B, pos, edgelist=edge_list, width=norm_widths, alpha=0.4, edge_color='dimgray')
        nx.draw_networkx_labels(B, pos, font_size=6, font_weight='normal')

        plt.title(f'Top {min(top_n_links, len(plot_df))} TF-Pathway Links: {cell_type} ({tf_tool})', fontsize=14)
        plt.axis('off')
        tf_patch = plt.Line2D([0], [0], marker='o', color='w', label='TF', markersize=10, markerfacecolor='skyblue')
        path_patch = plt.Line2D([0], [0], marker='o', color='w', label='Pathway', markersize=10, markerfacecolor='lightgreen')
        plt.legend(handles=[tf_patch, path_patch], loc='best', fontsize=9)
        plt.tight_layout()
        for ext in ['svg', 'pdf']:
            fpath = os.path.join(plot_dir, f"{cell_type}_tf_path_network.{ext}")
            plt.savefig(fpath)
            print(f"  Saved Network plot: {fpath}")
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close(fig)

# --- DEG COVERAGE PLOTTING (Donut Plot Logic Revised) ---
def plot_deg_pies(deg_list, cell_type, tf_results, top_n=TOP_N_PIE):
    """Draw one DEG-coverage donut chart per TF method and report the coverage.

    For each method the chart has one slice per top-``top_n`` TF (ranked by
    adjusted p-value) sized by the share of DEGs among that TF's genes, an
    ``'Other TFs'`` slice for DEGs reached only by the remaining TFs, and a
    ``'Not Explained'`` slice. Slices of at most 0.01 % are dropped. Because
    one DEG can belong to several TFs, the raw shares may sum to more than
    100 %; they are rescaled to sum to 100 % for display.

    Args:
        deg_list: Differentially expressed gene symbols for the cell type.
        cell_type: Label used in titles and file names.
        tf_results: Mapping of method name to a TF table with ``Term``,
            ``Genes`` (';'-separated) and ``Adjusted P-value`` columns. A
            ``None``, empty, or incomplete table gives a 100 %
            ``'Not Explained'`` chart.
        top_n: Number of TFs that get their own slice.

    Returns:
        Dict mapping each method to the percentage of unique DEGs found in the
        gene list of at least one of its TFs (0.0 for every method when
        ``deg_list`` is empty; no chart is drawn in that case).

    Side effects:
        Writes ``<method>_<cell_type>_deg_cov_pie.svg`` and ``.pdf`` into
        ``<output_dir>/<method>/figures``.
    """
    print(f"--- Generating DEG Coverage Pie Charts for {cell_type} ---")
    if not deg_list:
        return {method: 0.0 for method in tf_results.keys()}

    deg_set = set(deg_list)
    total_deg_count = len(deg_set)  # always >= 1 here: deg_list is non-empty
    required_cols = ('Term', 'Genes', 'Adjusted P-value')
    explained_stats_for_summary_bar = {}

    def _deg_hits(genes_cell):
        """Return the DEGs listed in a ';'-separated gene cell (missing/empty -> none)."""
        if genes_cell is None or (isinstance(genes_cell, float) and np.isnan(genes_cell)):
            return set()
        text = str(genes_cell)
        return deg_set.intersection(text.split(';')) if text else set()

    for method, sig_tfs_df in tf_results.items():
        print(f"  Plotting for method: {method} ({cell_type})")
        plot_dir = os.path.join(output_dir, method, "figures")
        os.makedirs(plot_dir, exist_ok=True)

        has_tf_table = (
            sig_tfs_df is not None
            and not sig_tfs_df.empty
            and all(c in sig_tfs_df.columns for c in required_cols)
        )

        if not has_tf_table:
            explained_stats_for_summary_bar[method] = 0.0
            pie_labels_final, pie_sizes_final_normalized = ['Not Explained'], [100.0]
            plot_title = f'DEG Coverage: {cell_type} ({method})' + '\n(No significant TFs or malformed TF data)'
        else:
            sorted_tfs = sig_tfs_df.sort_values(by='Adjusted P-value')
            # Intersect every TF's genes with the DEGs once; reused below.
            hits_per_tf = [(term, _deg_hits(genes))
                           for term, genes in zip(sorted_tfs['Term'], sorted_tfs['Genes'])]

            degs_explained_by_any_tf = set().union(*(hits for _, hits in hits_per_tf))
            num_total_explained_degs = len(degs_explained_by_any_tf)
            # Overall % of DEGs explained by this method, for the summary bar.
            explained_stats_for_summary_bar[method] = (num_total_explained_degs / total_deg_count) * 100

            # Raw slice sizes in % of all DEGs; dict order is the slice order.
            raw_slices = {}

            # 1. One slice per top-N TF.
            degs_shown_by_top_n = set()
            for tf_name, tf_hits in hits_per_tf[:top_n]:
                pct_contribution = (len(tf_hits) / total_deg_count) * 100
                if pct_contribution > 0.01:
                    raw_slices[tf_name] = pct_contribution
                    degs_shown_by_top_n.update(tf_hits)

            # 2. DEGs explained only by TFs that did not get their own slice.
            percent_others_explained = (len(degs_explained_by_any_tf - degs_shown_by_top_n) / total_deg_count) * 100
            if percent_others_explained > 0.01:
                raw_slices['Other TFs'] = percent_others_explained

            # 3. DEGs no TF explains (also added when it would be the only slice).
            percent_not_explained = ((total_deg_count - num_total_explained_degs) / total_deg_count) * 100
            if percent_not_explained > 0.01 or not raw_slices:
                raw_slices['Not Explained'] = percent_not_explained

            plot_title = f'DEG Coverage by Top {top_n} TFs: {cell_type} ({method})'

            # Rescale so the displayed slices sum to 100 %.
            raw_total = sum(raw_slices.values())
            if raw_total > 0:
                pie_labels_final = list(raw_slices.keys())
                pie_sizes_final_normalized = [(s / raw_total) * 100 for s in raw_slices.values()]
            else:
                pie_labels_final, pie_sizes_final_normalized = ['Not Explained'], [100.0]

        cmap = plt.colormaps['tab20'] if len(pie_labels_final) <= 20 else plt.colormaps['viridis']
        colors = cmap(np.linspace(0, 1, len(pie_labels_final)))
        fig, ax = plt.subplots(figsize=(10, 8))
        try:
            wedges, txs, auts = ax.pie(pie_sizes_final_normalized, autopct='%1.1f%%', startangle=140, colors=colors,
                                       wedgeprops={"ec": "w", 'lw': .7, 'antialiased': True}, pctdistance=.85)
            for t in txs:
                t.set_fontsize(9)
            for at in auts:
                at.set_color('white')
                at.set_fontsize(8)
                at.set_fontweight('bold')
            # The white disc in the middle turns the pie into a donut.
            ax.add_artist(plt.Circle((0, 0), .7, fc='white'))
            ax.axis('equal')
            ax.set_title(plot_title, fontsize=12)
            ax.legend(wedges, pie_labels_final, title="Categories", loc="center left",
                      bbox_to_anchor=(1, 0, .5, 1), fontsize=9)
            plt.tight_layout(rect=[0, 0, .85, 1])
            for ext in ['svg', 'pdf']:
                fpath = os.path.join(plot_dir, f"{method}_{cell_type}_deg_cov_pie.{ext}")
                plt.savefig(fpath)
                print(f"    Pie chart ({ext.upper()}) saved: {fpath}")
        finally:
            # Release the figure even when drawing or saving fails.
            plt.close(fig)
    return explained_stats_for_summary_bar

def plot_summary_bars(all_cov_stats, ordered_cts):
    """Draw one stacked explained / not-explained bar chart per TF method.

    Args:
        all_cov_stats: Mapping of cell type to a dict ``{method: explained %}``.
            Values that are not dicts are ignored when collecting method names.
        ordered_cts: Cell types in the order they appear on the x-axis; a cell
            type or method missing from ``all_cov_stats`` counts as 0.0 %.

    Returns:
        None. Each chart is saved as ``<method>_summary_bars.svg`` and ``.pdf``
        under ``<output_dir>/<method>/figures``.
    """
    print("\n--- Generating Summary DEG Coverage Stacked Bars ---")
    if not all_cov_stats:
        print("No data for summary. Skipping.")
        return

    # Collect every method name that occurs for any cell type.
    tf_methods = set()
    for per_cell_type in all_cov_stats.values():
        if isinstance(per_cell_type, dict):
            tf_methods.update(per_cell_type.keys())
    if not tf_methods:
        print("No TF methods in data. Skipping.")
        return

    explained_col = 'Explained (%)'
    unexplained_col = 'Not Explained (%)'

    for method in sorted(tf_methods):
        plot_dir = os.path.join(output_dir, method, "figures")
        os.makedirs(plot_dir, exist_ok=True)

        print(f"  Bar plot for: {method}")
        explained_by_ct = {}
        for ct in ordered_cts:
            explained_by_ct[ct] = all_cov_stats.get(ct, {}).get(method, 0.0)
        if not explained_by_ct:
            continue

        coverage = pd.DataFrame.from_dict(explained_by_ct, orient='index', columns=[explained_col])
        coverage = coverage.reindex(ordered_cts)
        coverage[unexplained_col] = 100 - coverage[explained_col]

        fig_width = max(8, len(ordered_cts) * .7)
        fig, ax = plt.subplots(figsize=(fig_width, 7))
        ax.bar(coverage.index, coverage[explained_col], label=f'Explained ({method})', color='skyblue')
        # Stack the remainder on top of the explained share.
        ax.bar(coverage.index, coverage[unexplained_col], bottom=coverage[explained_col],
               label='Not Explained', color='lightcoral')

        ax.set_xlabel('Cell Type')
        ax.set_ylabel('% DEGs')
        ax.set_title(f'DEG Coverage ({method})', fontsize=14)
        ax.tick_params(axis='x', rotation=45)
        ax.legend(loc='upper right')
        ax.grid(axis='y', ls='--', alpha=.7)
        # Right-align the rotated tick labels.
        for tick_label in ax.get_xticklabels():
            tick_label.set_ha('right')

        plt.tight_layout()
        for ext in ['svg', 'pdf']:
            fpath = os.path.join(plot_dir, f"{method}_summary_bars.{ext}")
            plt.savefig(fpath)
            print(f"    Stacked bar saved: {fpath}")
        plt.close(fig)

In [6]:
def plot_manhattan_ind(links_df, cell_type, tf_tool, save_path, sig_thr=SIGNIFICANCE_CUTOFF, n_lab=TOP_N_LABEL, lab_thr=LABEL_QVAL_THRESHOLD):
    """Draw a Manhattan-style scatter of TF-pathway links under a caller-chosen root folder.

    Each link is one point. Links are ordered along x by pathway database and
    label; y is ``-log10(interaction_Adj_Pval)``. Points are coloured by
    ``Pathway_DB``. Marker size scales with the number of shared genes, read
    from ``N_Shared`` or, when that column is absent, from ``overlap_size``;
    without a usable column all markers have the same size.

    Args:
        links_df: Link table; needs ``interaction_Adj_Pval``, ``Pathway_DB``,
            ``TF`` and ``Pathway_Name``.
        cell_type: Label used in the title and the file name.
        tf_tool: TF method name; figures go to ``<save_path>/<tf_tool>/figures``.
        save_path: Root output folder.
        sig_thr: Adjusted p-value drawn as a horizontal dashed line.
        n_lab: Maximum number of points that receive a text label.
        lab_thr: Only links with an adjusted p-value below this are labelled.

    Returns:
        None. Nothing is drawn when the table is empty, lacks the required
        columns, or holds no numeric adjusted p-value.
    """
    if links_df.empty or 'interaction_Adj_Pval' not in links_df.columns or 'Pathway_DB' not in links_df.columns:
        print(f"  Manhattan: Invalid data for {cell_type} ({tf_tool}). Missing 'interaction_Adj_Pval' or 'Pathway_DB'. Skipping.")
        return
    print(f"--- Plotting Manhattan: {cell_type} ({tf_tool}) ---")
    plot_dir = os.path.join(save_path, tf_tool, "figures")
    # Create the folder the figures are written to (this also creates
    # save_path); creating only save_path made savefig fail on a fresh tree.
    os.makedirs(plot_dir, exist_ok=True)

    df_plot_ready = links_df.copy()
    df_plot_ready['interaction_Adj_Pval'] = pd.to_numeric(df_plot_ready['interaction_Adj_Pval'], errors='coerce')
    df_plot_ready = df_plot_ready.dropna(subset=['interaction_Adj_Pval']).copy()
    if df_plot_ready.empty:
        print("  Manhattan: No valid interaction_Adj_Pval values after coercion.")
        return

    # The tiny offset keeps log10 finite when an adjusted p-value is exactly 0.
    df_plot_ready['-logP_Overlap'] = -np.log10(df_plot_ready['interaction_Adj_Pval'] + np.finfo(float).tiny)
    df_plot_ready['Label'] = df_plot_ready['TF'].astype(str) + " - " + df_plot_ready['Pathway_Name'].astype(str)
    df_plot_ready = df_plot_ready.sort_values(by=['Pathway_DB', 'Label'])
    df_plot_ready['x'] = range(len(df_plot_ready))

    dbs = sorted(df_plot_ready['Pathway_DB'].unique())
    cmap = plt.colormaps['tab10'] if len(dbs) <= 10 else plt.colormaps['tab20']
    cols = cmap(np.linspace(0, 1, len(dbs)))
    db_cols = {db: cols[i] for i, db in enumerate(dbs)}
    df_plot_ready['color'] = df_plot_ready['Pathway_DB'].map(db_cols)

    # Marker size from the shared-gene count; 'overlap_size' is the fallback column.
    size_col = next((c for c in ('N_Shared', 'overlap_size') if c in df_plot_ready.columns), None)
    if size_col is not None and df_plot_ready[size_col].notna().all() and df_plot_ready[size_col].nunique() > 1:
        min_s, max_s = df_plot_ready[size_col].min(), df_plot_ready[size_col].max()
        scaled = 150 * (df_plot_ready[size_col] - min_s + 1e-6) / (max_s - min_s + 1e-6)
        df_plot_ready['size'] = np.maximum(20, scaled)
    else:
        df_plot_ready['size'] = 50

    figsize_width = min(100, max(15, len(df_plot_ready) * 0.05))  # Capped width
    fig = plt.figure(figsize=(figsize_width, 7))
    try:
        plt.axhline(-np.log10(sig_thr), c='r', ls='--', lw=1, label=f'Adj. P-val = {sig_thr}')
        plt.scatter(df_plot_ready['x'], df_plot_ready['-logP_Overlap'], c=df_plot_ready['color'],
                    s=df_plot_ready['size'], alpha=.7, ec='k', lw=.5)

        lab_df = (df_plot_ready[df_plot_ready['interaction_Adj_Pval'] < lab_thr]
                  .sort_values(by='-logP_Overlap', ascending=False)
                  .head(n_lab))
        txts = [plt.text(r['x'], r['-logP_Overlap'],
                         f"{str(r['TF'])[:10]}...\n{str(r['Pathway_Name'])[:15]}...",
                         fontsize=5, ha='center')
                for _, r in lab_df.iterrows()]
        if ADJUST_TEXT_AVAILABLE and txts:
            adjust_text(txts, arrowprops=dict(arrowstyle='-', color='grey', lw=.3))

        plt.xticks([])
        plt.xlabel('TF-Pathway Links (Grouped by Pathway DB)', fontsize=9)
        plt.ylabel('-log10 (TF-Pathway Overlap Adj. P-value)', fontsize=9)
        plt.title(f'TF-Pathway Links: {cell_type} ({tf_tool})', fontsize=11)
        plt.ylim(bottom=0)
        plt.grid(axis='y', ls='--', alpha=.6)
        hndls = [plt.Line2D([0], [0], marker='o', c='w', label=db, markerfacecolor=c, ms=7) for db, c in db_cols.items()]
        plt.legend(handles=hndls, title="Pathway DB", bbox_to_anchor=(1.02, 1), loc='upper left', fontsize=7)
        plt.tight_layout(rect=[0, 0, 0.88, 1])
        for ext in ['svg', 'pdf']:
            fpath = os.path.join(plot_dir, f"{tf_tool}_{cell_type}_tf_path_manhattan.{ext}")
            plt.savefig(fpath)
            print(f"  Saved Manhattan: {fpath}")
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close(fig)

def plot_tf_path_network_ind(links_df, cell_type, tf_tool, save_path, top_n_links=30):
    """Draw the strongest TF-pathway links as a network under a caller-chosen root folder.

    The ``top_n_links`` links with the smallest ``interaction_Adj_Pval`` are
    drawn. TF nodes are sky blue and pathway nodes light green; node size
    grows with node degree, edge width with ``-log10(interaction_Adj_Pval)``.
    When a TF/pathway-name pair occurs more than once, the edge keeps the
    strongest score.

    Args:
        links_df: Link table with ``TF``, ``Pathway_Name`` and
            ``interaction_Adj_Pval`` columns.
        cell_type: Label used in the title and the file name.
        tf_tool: TF method name; figures go to ``<save_path>/<tf_tool>/figures``.
        save_path: Root output folder.
        top_n_links: Number of links to draw.

    Returns:
        None. Nothing is drawn when the table is empty or lacks a required column.
    """
    if links_df.empty or not all(c in links_df.columns for c in ['TF', 'Pathway_Name', 'interaction_Adj_Pval']):
        print(f"  Network: Invalid data for {cell_type} ({tf_tool}). Skipping.")
        return
    print(f"--- Plotting TF-Pathway Network: {cell_type} ({tf_tool}) ---")
    plot_dir = os.path.join(save_path, tf_tool, "figures")
    os.makedirs(plot_dir, exist_ok=True)

    plot_df = links_df.nsmallest(top_n_links, 'interaction_Adj_Pval').copy()
    if plot_df.empty:
        print(f"  Network: No links after filtering top {top_n_links}.")
        return

    # The tiny offset keeps log10 finite when an adjusted p-value is exactly 0.
    plot_df['-logP_Overlap'] = -np.log10(plot_df['interaction_Adj_Pval'] + np.finfo(float).tiny)

    B = nx.Graph()
    B.add_nodes_from(plot_df['TF'].unique(), bipartite=0, type='TF')
    B.add_nodes_from(plot_df['Pathway_Name'].unique(), bipartite=1, type='Pathway')

    for tf_name, path_name, score in zip(plot_df['TF'], plot_df['Pathway_Name'], plot_df['-logP_Overlap']):
        # A repeated pair must not overwrite a stronger edge with a weaker one.
        if not B.has_edge(tf_name, path_name) or B[tf_name][path_name]['weight'] < score:
            B.add_edge(tf_name, path_name, weight=score)

    if B.number_of_edges() == 0:
        print(f"  Network: No edges formed for {cell_type} ({tf_tool}).")
        return

    fig = plt.figure(figsize=(12, 10))
    try:
        pos = nx.spring_layout(B, k=0.6, iterations=70, seed=42)

        tf_nodes = [n for n, d in B.nodes(data=True) if d['bipartite'] == 0]
        path_nodes = [n for n, d in B.nodes(data=True) if d['bipartite'] == 1]

        degree_of = dict(B.degree())
        min_degree, max_degree = min(degree_of.values()), max(degree_of.values())

        def _node_size(node):
            """Map node degree to 100..1600; mid size when all degrees are equal."""
            if max_degree > min_degree:
                return 100 + 1500 * (degree_of[node] - min_degree) / (max_degree - min_degree + 1e-6)
            return 100 + 1500 * 0.5

        nx.draw_networkx_nodes(B, pos, nodelist=tf_nodes, node_color='skyblue',
                               node_size=[_node_size(n) for n in tf_nodes], alpha=0.9)
        nx.draw_networkx_nodes(B, pos, nodelist=path_nodes, node_color='lightgreen',
                               node_size=[_node_size(n) for n in path_nodes], alpha=0.9)

        # Read the weights in the graph's own edge order and pass that order
        # explicitly, so width i always belongs to edge i.
        edge_list = list(B.edges())
        edge_scores = [B[u][v]['weight'] for u, v in edge_list]
        min_w, max_w = min(edge_scores), max(edge_scores)
        if max_w > min_w:
            norm_widths = [0.5 + 3 * (w - min_w) / (max_w - min_w + 1e-6) for w in edge_scores]
        else:
            norm_widths = [1] * len(edge_scores)

        nx.draw_networkx_edges(B, pos, edgelist=edge_list, width=norm_widths, alpha=0.4, edge_color='dimgray')
        nx.draw_networkx_labels(B, pos, font_size=6, font_weight='normal')

        plt.title(f'Top {min(top_n_links, len(plot_df))} TF-Pathway Links: {cell_type} ({tf_tool})', fontsize=14)
        plt.axis('off')
        tf_patch = plt.Line2D([0], [0], marker='o', color='w', label='TF', markersize=10, markerfacecolor='skyblue')
        path_patch = plt.Line2D([0], [0], marker='o', color='w', label='Pathway', markersize=10, markerfacecolor='lightgreen')
        plt.legend(handles=[tf_patch, path_patch], loc='best', fontsize=9)
        plt.tight_layout()
        for ext in ['svg', 'pdf']:
            fpath = os.path.join(plot_dir, f"{cell_type}_tf_path_network.{ext}")
            plt.savefig(fpath)
            print(f"  Saved Network plot: {fpath}")
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close(fig)

# --- DEG COVERAGE PLOTTING (Donut Plot Logic Revised) ---
def _split_tf_genes(genes_value):
    """Return the gene symbols of one ';'-separated 'Genes' cell as a set.

    Missing (NaN/None) and empty cells yield an empty set, so every caller
    treats them the same way.
    """
    if pd.isna(genes_value) or genes_value == '':
        return set()
    return set(str(genes_value).split(';'))


def plot_deg_pies_ind(deg_list, cell_type, sig_tfs_df, tf_tool, save_path, top_n=TOP_N_PIE):
    """Draw a donut chart of how many DEGs of one cell type are covered by TFs.

    Slices are: one per TF among the `top_n` rows with the smallest
    'Adjusted P-value' (kept only if it covers more than 0.01% of the DEGs),
    'Other TFs' (DEGs covered only by TFs not shown individually) and
    'Not Explained'. Each TF slice is computed independently, so a DEG hit
    by several shown TFs is counted in each of them; the slice sizes are
    therefore rescaled to sum to 100 before plotting.

    Parameters
    ----------
    deg_list : list of str
        DEG symbols; duplicates are collapsed.
    cell_type : str
        Used in log messages, the plot title and the output file name.
    sig_tfs_df : pandas.DataFrame or None
        TF enrichment rows with columns 'Term', 'Genes' (';'-separated
        symbols) and 'Adjusted P-value'. None, empty, or missing any of
        these columns produces a single 'Not Explained' slice.
    tf_tool : str
        Label of the TF resource; used as sub-directory name, file prefix
        and key of the returned dict.
    save_path : str
        Base directory; figures go to `<save_path>/<tf_tool>/figures`.
    top_n : int
        Number of top-ranked TFs that get their own slice.

    Returns
    -------
    dict
        `{tf_tool: pct}` where `pct` is the percentage of unique DEGs found
        in the 'Genes' of any row of `sig_tfs_df` (0.0 when there are no
        DEGs or no usable TF table).
    """
    print(f"--- Generating DEG Coverage Pie Charts for {cell_type} ---")
    method = tf_tool
    explained_stats_for_summary_bar = {}

    # Guard once on the de-duplicated set: everything below may then rely on
    # total_deg_count > 0, and nothing is drawn when there is nothing to show.
    deg_set = set(deg_list) if deg_list else set()
    if not deg_set:
        explained_stats_for_summary_bar[method] = 0.0
        return explained_stats_for_summary_bar
    total_deg_count = len(deg_set)

    print(f"  Plotting for method: {method} ({cell_type})")
    plot_dir = os.path.join(save_path, method, "figures")
    os.makedirs(plot_dir, exist_ok=True)

    required_cols = ('Term', 'Genes', 'Adjusted P-value')
    has_tf_data = (
        sig_tfs_df is not None
        and not sig_tfs_df.empty
        and all(c in sig_tfs_df.columns for c in required_cols)
    )

    if not has_tf_data:
        explained_stats_for_summary_bar[method] = 0.0
        plot_title = f'DEG Coverage: {cell_type} ({method})'
        plot_title += '\n(No significant TFs or malformed TF data)'
        pie_labels_final, pie_sizes_final_normalized = ['Not Explained'], [100.0]
    else:
        plot_title = f'DEG Coverage by Top {top_n} TFs: {cell_type} ({method})'
        sorted_tfs = sig_tfs_df.sort_values(by='Adjusted P-value')

        # Union of DEGs hit by any TF row: drives the summary-bar statistic.
        degs_explained_by_any_tf = set()
        for genes_value in sorted_tfs['Genes']:
            degs_explained_by_any_tf |= deg_set & _split_tf_genes(genes_value)
        num_total_explained_degs = len(degs_explained_by_any_tf)
        explained_stats_for_summary_bar[method] = (num_total_explained_degs / total_deg_count) * 100

        # Raw slice sizes as a percentage of all DEGs; dict keeps insertion order.
        current_pie_slices_raw_pct = {}

        # 1. One slice per top-N TF.
        degs_covered_by_top_n_explicitly_shown = set()
        for _, r_top_tf in sorted_tfs.head(top_n).iterrows():
            degs_for_this_tf = deg_set & _split_tf_genes(r_top_tf['Genes'])
            pct_contribution = (len(degs_for_this_tf) / total_deg_count) * 100
            if pct_contribution > 0.01:
                current_pie_slices_raw_pct[r_top_tf['Term']] = pct_contribution
                degs_covered_by_top_n_explicitly_shown.update(degs_for_this_tf)

        # 2. DEGs explained only by TFs that did not get their own slice.
        degs_for_other_tfs_slice = degs_explained_by_any_tf - degs_covered_by_top_n_explicitly_shown
        percent_others_explained = (len(degs_for_other_tfs_slice) / total_deg_count) * 100
        if percent_others_explained > 0.01:
            current_pie_slices_raw_pct['Other TFs'] = percent_others_explained

        # 3. DEGs no TF row accounts for (also added when it is the only category).
        num_not_explained = total_deg_count - num_total_explained_degs
        percent_not_explained = (num_not_explained / total_deg_count) * 100
        if percent_not_explained > 0.01 or not current_pie_slices_raw_pct:
            current_pie_slices_raw_pct['Not Explained'] = percent_not_explained

        pie_labels_final = list(current_pie_slices_raw_pct)
        pie_sizes_raw_pct_total_degs = list(current_pie_slices_raw_pct.values())

        # TF slices can overlap, so rescale the raw percentages to sum to 100.
        current_sum_pie = sum(pie_sizes_raw_pct_total_degs)
        if current_sum_pie > 0:
            pie_sizes_final_normalized = [(s / current_sum_pie) * 100 for s in pie_sizes_raw_pct_total_degs]
        else:
            pie_labels_final, pie_sizes_final_normalized = ['Not Explained'], [100.0]

    # tab20 only has 20 distinct colours; fall back to a continuous map beyond that.
    cmap = plt.colormaps['tab20'] if len(pie_labels_final) <= 20 else plt.colormaps['viridis']
    colors = cmap(np.linspace(0, 1, len(pie_labels_final)))
    fig, ax = plt.subplots(figsize=(10, 8))
    wedges, txs, auts = ax.pie(
        pie_sizes_final_normalized,
        autopct='%1.1f%%',
        startangle=140,
        colors=colors,
        wedgeprops={"ec": "w", 'lw': .7, 'antialiased': True},
        pctdistance=.85,
    )
    for t in txs:
        t.set_fontsize(9)
    for at in auts:
        at.set_color('white')
        at.set_fontsize(8)
        at.set_fontweight('bold')
    # White disc in the centre turns the pie into a donut.
    ax.add_artist(plt.Circle((0, 0), .7, fc='white'))
    ax.axis('equal')
    ax.set_title(plot_title, fontsize=12)
    ax.legend(wedges, pie_labels_final, title="Categories", loc="center left",
              bbox_to_anchor=(1, 0, .5, 1), fontsize=9)
    plt.tight_layout(rect=[0, 0, .85, 1])
    for ext in ['svg', 'pdf']:
        fpath = os.path.join(plot_dir, f"{method}_{cell_type}_deg_cov_pie.{ext}")
        plt.savefig(fpath)
        print(f"    Pie chart ({ext.upper()}) saved: {fpath}")
    plt.close(fig)
    return explained_stats_for_summary_bar

def plot_summary_bars_ind(all_cov_stats, ordered_cts, tf_tool, save_path):
    """Save a stacked bar chart of explained vs. not-explained DEG percentages.

    `all_cov_stats` maps cell type -> {method: explained %}. One bar is drawn
    per entry of `ordered_cts` for the method named by `tf_tool`; cell types
    or methods missing from `all_cov_stats` count as 0% explained. The figure
    is written as SVG and PDF to `<save_path>/<tf_tool>/figures`.

    Returns None; nothing is plotted when `all_cov_stats` is empty or when
    none of its values is a dict with at least one method key.
    """
    print("\n--- Generating Summary DEG Coverage Stacked Bars ---")
    if not all_cov_stats:
        print("No data for summary. Skipping.")
        return

    any_method_present = any(
        isinstance(per_ct, dict) and bool(per_ct) for per_ct in all_cov_stats.values()
    )
    if not any_method_present:
        print("No TF methods in data. Skipping.")
        return

    method = tf_tool
    figures_dir = os.path.join(save_path, method, "figures")
    os.makedirs(figures_dir, exist_ok=True)

    print(f"  Bar plot for: {method}")
    explained_by_ct = {}
    for ct in ordered_cts:
        explained_by_ct[ct] = all_cov_stats.get(ct, {}).get(method, 0.0)

    coverage = pd.DataFrame.from_dict(explained_by_ct, orient='index', columns=['Explained (%)'])
    coverage = coverage.reindex(ordered_cts)
    coverage['Not Explained (%)'] = 100 - coverage['Explained (%)']

    fig_width = max(8, len(ordered_cts) * .7)
    fig, ax = plt.subplots(figsize=(fig_width, 7))
    ax.bar(coverage.index, coverage['Explained (%)'],
           label=f'Explained ({method})', color='skyblue')
    # Second series sits on top of the first to form the stacked bar.
    ax.bar(coverage.index, coverage['Not Explained (%)'],
           bottom=coverage['Explained (%)'], label='Not Explained', color='lightcoral')
    ax.set_xlabel('Cell Type')
    ax.set_ylabel('% DEGs')
    ax.set_title(f'DEG Coverage ({method})', fontsize=14)
    ax.tick_params(axis='x', rotation=45)
    ax.legend(loc='upper right')
    ax.grid(axis='y', ls='--', alpha=.7)
    for tick_label in ax.get_xticklabels():
        tick_label.set_ha('right')
    plt.tight_layout()

    for ext in ['svg', 'pdf']:
        fpath = os.path.join(figures_dir, f"{method}_summary_bars.{ext}")
        plt.savefig(fpath)
        print(f"    Stacked bar saved: {fpath}")
    plt.close(fig)

# Run

In [None]:
# --- MAIN WORKFLOW ---
def main():
    """Run the TF-enrichment and TF-pathway integration workflow for every cell type.

    For each entry of `cell_types` whose DEG table exists in `DEG_dir`:
      1. read the DEG symbols from column `DEG_SYMBOL_COL`;
      2. try to read the matching background table (warnings only on failure);
      3. run `run_gseapy_enrichment` on the DEGs;
      4. draw DEG-coverage pies via `plot_deg_pies`;
      5. load pathway tables for every key of `PATHWAY_FILES`, and for each TF
         tool with significant TFs call `integrate_tf_path`, then
         `plot_manhattan` and `plot_tf_path_network` on the resulting links.
    Finally `plot_summary_bars` is called with the per-cell-type coverage stats.

    Returns None. Returns early (after printing a message) when no DEG file
    is found at all.
    """
    deg_file_infos = []
    print("Constructing DEG file paths based on `cell_types` list and expected naming convention.")
    for ct in cell_types:
        # Construct filename directly using the original cell_type string
        deg_filename = f"cell_types_diffxpy_{ct}_diff_res.csv"
        deg_filepath = os.path.join(DEG_dir, deg_filename)

        if os.path.exists(deg_filepath):
            deg_file_infos.append({'name': ct, 'path': deg_filepath})
            print(f"  Found DEG file for processing: {ct} at {deg_filepath}")
        else:
            print(f"  Warning: DEG file NOT FOUND for cell_type '{ct}' at expected path: {deg_filepath}")

    if not deg_file_infos:
        print("CRITICAL: No DEG files found based on the `cell_types` list and expected naming. Please check paths and filenames. Exiting.")
        return

    all_ct_coverage_stats = {}

    for item in deg_file_infos:
        cell_type = item['name']
        print(f"\n\n<<<< Processing cell type: {cell_type} >>>>")

        # Best-effort load: a broken DEG table skips this cell type, not the whole run.
        try:
            deg_df = pd.read_csv(item['path'])
            if DEG_SYMBOL_COL not in deg_df.columns:
                print(f"  Error: Gene symbol column '{DEG_SYMBOL_COL}' missing in {item['path']}. Skipping.")
                continue
            current_degs = deg_df[DEG_SYMBOL_COL].dropna().astype(str).unique().tolist()
        except Exception as e:
            print(f"  Error loading DEGs for {cell_type}: {e}")
            current_degs = []

        if not current_degs:
            print(f"  No DEGs for {cell_type}. Skipping.")
            all_ct_coverage_stats[cell_type] = {}
            continue
        print(f"  Loaded {len(current_degs)} DEGs for {cell_type}.")

        # NOTE(review): ct_bg_genes is loaded and reported here but is not passed
        # to any call below in this function - confirm whether it is still needed.
        ct_bg_genes = []
        # Construct background filename directly using original cell_type string
        bg_file = os.path.join(DEG_dir, f"cell_types_diffxpy_{cell_type}_res-table.csv")
        try:
            if os.path.exists(bg_file):
                bg_df = pd.read_csv(bg_file)
                if DEG_SYMBOL_COL in bg_df.columns:
                    ct_bg_genes = list(set(bg_df[DEG_SYMBOL_COL].dropna().astype(str)))
                    print(f"  Loaded {len(ct_bg_genes)} background genes for {cell_type} from {bg_file}")
                else:
                    print(f"  Warning: Gene symbol column '{DEG_SYMBOL_COL}' not in background file {bg_file}.")
            else:
                print(f"  Warning: Background file {bg_file} not found for {cell_type}.")
        except Exception as e_bg:
            print(f"  Warning: Error loading background {bg_file}: {e_bg}")
        if not ct_bg_genes:
            print(f"  Warning: Using empty background for {cell_type} for tools that require it (g:Profiler will use its default).")

        tf_libs = ['ChEA_2022','TRRUST_Transcription_Factors_2019','ENCODE_and_ChEA_Consensus_TFs_from_ChIP-X','ARCHS4_TFs_Coexp','TRANSFAC_and_JASPAR_PWMs']
        sig_tfs_gseapy = run_gseapy_enrichment(current_degs, cell_type, tf_libs)

        tf_results = {'gseapy': sig_tfs_gseapy}
        ct_coverage_stats = plot_deg_pies(current_degs, cell_type, tf_results, top_n=TOP_N_PIE)
        all_ct_coverage_stats[cell_type] = ct_coverage_stats

        print(f"\n  --- Loading & Integrating Pathways for {cell_type} ---")
        ct_paths_list = [load_pathways(cell_type, pt_type) for pt_type in PATHWAY_FILES.keys()]
        # pd.concat raises on an empty list, so filter first and fall back to an
        # empty frame. ct_paths_df must always be (re)bound here: otherwise a
        # failure would leave it undefined or still holding the previous cell
        # type's pathways.
        try:
            ct_path_frames = [df for df in ct_paths_list if df is not None and not df.empty]
            ct_paths_df = pd.concat(ct_path_frames, ignore_index=True) if ct_path_frames else pd.DataFrame()
        except Exception as e_ct:
            print(f"  Warning: Could not consolidate pathways for {cell_type}: {e_ct}")
            ct_paths_df = pd.DataFrame()

        if not ct_paths_df.empty:
            print(f"    Consolidated {len(ct_paths_df)} significant pathways for {cell_type}.")
            for tool, sig_tfs in tf_results.items():
                if sig_tfs is not None and not sig_tfs.empty:
                    links = integrate_tf_path(sig_tfs, ct_paths_df, cell_type, current_degs, tool, interaction_qval_cutoff=0.01, overlap_fraction_threshold=0.1)
                    if links is not None and not links.empty:
                        plot_manhattan(links, cell_type, tool)
                        plot_tf_path_network(links, cell_type, tool)
        else:
            print(f"    No pathways loaded for {cell_type}, skipping TF-Pathway integration.")

    plot_summary_bars(all_ct_coverage_stats, cell_types)
    print("\nWorkflow complete. Check output directory for results.")

if __name__ == '__main__':
    main()


# Explore data and visualize

In [8]:
# Enrichr TF gene-set libraries to inspect one at a time in the cells below.
tf_libs = [
    'ChEA_2022',
    'TRRUST_Transcription_Factors_2019',
    'ENCODE_and_ChEA_Consensus_TFs_from_ChIP-X',
    'ARCHS4_TFs_Coexp',
    'TRANSFAC_and_JASPAR_PWMs',
]

In [None]:
# Per-library plots from the precalculated gseapy tables.
# The DEG / TF / link tables of a cell type do not depend on `lib`, so they are
# read once up front instead of once per library.
path = os.path.join(output_dir, "gseapy")

ct_tables = {}
for ct in cell_types:
    # Load precalculated files
    deg_filepath = os.path.join(DEG_dir, f"cell_types_diffxpy_{ct}_diff_res.csv")
    try:
        deg_df = pd.read_csv(deg_filepath)
        tf = pd.read_csv(os.path.join(output_dir, 'gseapy', 'tables', f"gseapy_{ct}_sig_tf.csv"))
        link = pd.read_csv(os.path.join(output_dir, 'gseapy', 'tables', f"gseapy_{ct}_tf_path_links.csv"))
    except FileNotFoundError as err:
        # Same policy as the summary-table cell: warn and skip this cell type.
        print(f"  Warning: precalculated file missing for cell type '{ct}' ({err}). Skipping.")
        continue
    current_degs = deg_df[DEG_SYMBOL_COL].dropna().astype(str).unique().tolist()
    ct_tables[ct] = (current_degs, tf, link)

for lib in tf_libs:
    tf_results = {}
    all_ct_coverage_stats = {}
    for ct, (current_degs, tf, link) in ct_tables.items():
        # Subset to lib
        tf_sub = tf[tf.Gene_set == lib]
        tf_results[ct] = tf_sub
        ct_coverage_stats = plot_deg_pies_ind(current_degs, ct, tf_sub, lib, path, top_n=TOP_N_PIE)
        all_ct_coverage_stats[ct] = ct_coverage_stats

        # Keep only TF-pathway links whose TF term belongs to this library's subset.
        link_sub = link[link.TF.isin(list(tf_sub.Term))]
        plot_manhattan_ind(link_sub, ct, lib, path)
        plot_tf_path_network_ind(link_sub, ct, lib, path, top_n_links=100)

    # Skipped cell types are absent from the stats and are drawn as 0% explained.
    plot_summary_bars_ind(all_ct_coverage_stats, cell_types, lib, path)

In [3]:
lib = 'ENCODE_and_ChEA_Consensus_TFs_from_ChIP-X'
dbs_to_count = ['Reactome', 'GO_BP', 'WP', 'KEGG']

# One summary dict per cell type, appended in `cell_types` order.
summary_data = []

# --- Analysis Loop ---
for ct in cell_types:
    # Precalculated gseapy tables for this cell type
    tf_file = os.path.join(output_dir, 'gseapy', 'tables', f"gseapy_{ct}_sig_tf.csv")
    link_file = os.path.join(output_dir, 'gseapy', 'tables', f"gseapy_{ct}_tf_path_links.csv")

    try:
        tf = pd.read_csv(tf_file)
        link = pd.read_csv(link_file)
    except FileNotFoundError:
        print(f"⚠️ Warning: Files not found for cell type '{ct}'. Skipping.")
        # Placeholder row so the cell type still shows up in the table.
        error_row = {'Cell Type': ct, 'Significant TFs': 'File not found'}
        for col in ['Total Unique Pathways'] + dbs_to_count:
            error_row[col] = 'N/A'
        summary_data.append(error_row)
        continue

    # Restrict TFs to the chosen library, then links to those TFs.
    tf_sub = tf[tf['Gene_set'] == lib]
    significant_tf_terms = list(tf_sub['Term'])
    link_sub = link[link['TF'].isin(significant_tf_terms)]

    num_tfs = len(tf_sub)
    row = {'Cell Type': ct, 'Significant TFs': num_tfs}

    if link_sub.empty:
        # No linked pathways: every pathway count is zero.
        row['Total Unique Pathways'] = 0
        row.update(dict.fromkeys(dbs_to_count, 0))
    else:
        # Count each pathway once, however many TFs link to it.
        unique_pathways = link_sub.drop_duplicates(subset=['Pathway_ID'])
        row['Total Unique Pathways'] = len(unique_pathways)

        # Per-database tallies; databases with no hit default to 0.
        db_counts = unique_pathways['Pathway_DB'].value_counts()
        row.update({db: db_counts.get(db, 0) for db in dbs_to_count})

    summary_data.append(row)

# --- Final Table ---
summary_df = pd.DataFrame(summary_data)

# Fixed column order for display
final_cols = ['Cell Type', 'Significant TFs', 'Total Unique Pathways'] + dbs_to_count
summary_df = summary_df[final_cols]

# Display the final summary table
summary_df

Unnamed: 0,Cell Type,Significant TFs,Total Unique Pathways,Reactome,GO_BP,WP,KEGG
0,NSC1a,80,477,167,282,15,13
1,NSC1b,84,652,215,400,16,21
2,NSC2a,85,848,289,492,30,37
3,NSC2b,80,442,199,222,10,11
4,Apop.-NSC,34,55,34,18,1,2
5,NCSC,0,0,0,0,0,0
6,Apop.-NCSC,0,0,0,0,0,0
7,Glial-precursors,0,0,0,0,0,0
8,Immature-neurons,0,0,0,0,0,0
9,bulk_like,88,1445,363,946,50,86
