In [1]:
import pandas as pd

In [2]:
# Gene-set annotation table (Reactome 2025, disease version per the file name).
# NOTE(review): absolute, user-specific path — consider a configurable DATA_DIR.
geneset_csv = "/Users/polina/genetics_gsea/data/geneset_annotated/geneset_reactome_2025_disease_v2.csv"
geneset = pd.read_csv(geneset_csv)

In [3]:
# Distinct (Term, ID) combinations; each cell may still hold several ';'-joined values.
pathways = (
    geneset.loc[:, ["Term", "ID"]]
    .drop_duplicates()
    .reset_index(drop=True)
    .copy()
)

In [4]:
# Split the ';'-joined Term and ID values and explode both columns together so
# each row holds a single (Term, ID) pair. Multi-column explode requires each
# row's two lists to have equal length (pandas raises otherwise).
# `pathways` itself is not overwritten, so re-running this cell is safe
# (re-splitting already-split list columns would silently produce NaN rows).
pathways_exploded = (
    pathways
    .assign(
        Term=pathways["Term"].str.split(";"),
        ID=pathways["ID"].str.split(";"),
    )
    .explode(["Term", "ID"])
    .drop_duplicates()
    .reset_index(drop=True)
)

In [5]:
# Preview: one row per (Term, ID) pair after splitting the ';'-joined values.
pathways_exploded

Unnamed: 0,Term,ID
0,Signal Transduction,R-HSA-162582
1,Generic Transcription Pathway,R-HSA-212436
2,RNA Polymerase II Transcription,R-HSA-73857
3,Gene expression (Transcription),R-HSA-74160
4,SMAD2/SMAD3:SMAD4 heterotrimer regulates trans...,R-HSA-2173796
...,...,...
403,Serotonin receptors,R-HSA-390666
404,Glycogen storage diseases,R-HSA-3229121
405,SIRT1 negatively regulates rRNA expression,R-HSA-427359
406,Alpha-defensins,R-HSA-1462054


In [6]:
# Reactome pathway hierarchy edges: tab-separated, no header row, two columns
# named here as (parentId, childId).
hieararchy = pd.read_csv(
    "/Users/polina/genetics_gsea/data/gmt/Reactome_2025/Pathways_hierarchy_relationship.txt",
    sep="\t",
    header=None,
).set_axis(["parentId", "childId"], axis=1)

In [7]:
# Attach each pathway's parent from the hierarchy. The left join keeps pathways
# that are never a child (parentId = NaN); a pathway listed under several
# parents may appear once per parent, so "ID" is not guaranteed unique here.
pathway_hieararchy = (
    pathways_exploded
    .merge(hieararchy, how="left", left_on="ID", right_on="childId")
    .loc[:, ["ID", "parentId", "Term"]]
)

In [8]:
# Preview: pathway ID with its parent ID; top-level pathways show an empty (NaN)
# parentId, and a pathway with several parents may appear on several rows.
pathway_hieararchy

Unnamed: 0,ID,parentId,Term
0,R-HSA-162582,,Signal Transduction
1,R-HSA-212436,R-HSA-73857,Generic Transcription Pathway
2,R-HSA-73857,R-HSA-74160,RNA Polymerase II Transcription
3,R-HSA-74160,,Gene expression (Transcription)
4,R-HSA-2173796,R-HSA-2173793,SMAD2/SMAD3:SMAD4 heterotrimer regulates trans...
...,...,...,...
408,R-HSA-390666,R-HSA-375280,Serotonin receptors
409,R-HSA-3229121,R-HSA-5663084,Glycogen storage diseases
410,R-HSA-427359,R-HSA-5250941,SIRT1 negatively regulates rRNA expression
411,R-HSA-1462054,R-HSA-1461973,Alpha-defensins


In [9]:
# GSEA results keyed by Reactome pathway ID (tab-separated); the ID, Term,
# nes and fdr columns are used below.
gsea_disease_tsv = "/Users/polina/genetics_gsea/data/gsea/from_database/disease_zscore_reactome_2025.tsv"
gsea_disease = pd.read_csv(gsea_disease_tsv, sep="\t")

In [10]:
# Attach GSEA statistics (nes, fdr) to every hierarchy row; pathways absent from
# the GSEA table keep NaN statistics. Both frames carry a Term column, so the
# merge suffixes them; the hierarchy's copy (Term_x) is kept and renamed back.
pathway_hieararchy_gsea = (
    pathway_hieararchy
    .merge(gsea_disease, on="ID", how="left")
    .loc[:, ["ID", "parentId", "Term_x", "nes", "fdr"]]
    .rename(columns={"Term_x": "Term"})
)

In [17]:
import numpy as np
import plotly.express as px

In [117]:
import pandas as pd
import plotly.express as px

def _split_words_into_lines(text, n_lines):
    """Re-join the words of ``text`` onto at most ``n_lines`` lines separated by ``<br>``.

    Words are cut into chunks of ``len(words) // n_lines``; the last line also
    takes any remainder. When fewer than two lines are possible (0 or 1 word),
    ``text`` is returned unchanged.
    """
    words = text.split()
    n_lines = min(n_lines, len(words))
    if n_lines < 2:
        return text
    chunk = len(words) // n_lines
    cuts = [i * chunk for i in range(n_lines)] + [len(words)]
    return "<br>".join(" ".join(words[start:stop]) for start, stop in zip(cuts, cuts[1:]))


def _wrap_label(text, is_top_level):
    """Insert ``<br>`` breaks into a sunburst label according to its length.

    Labels at or under the single-line limit are kept as-is, labels up to the
    two-line limit are split into two lines, and longer ones into three.
    Top-level labels (children of ROOT) use tighter limits (10 / 20 characters)
    than deeper ones (20 / 30).
    """
    # NOTE(review): the top-level two-line limit (20) is a chosen value — tune
    # to the figure size if labels still overflow.
    single_line_max, two_line_max = (10, 20) if is_top_level else (20, 30)
    if len(text) <= single_line_max:
        return text
    if len(text) <= two_line_max:
        return _split_words_into_lines(text, 2)
    return _split_words_into_lines(text, 3)


def plot_sunburst(df):
    """Draw a Plotly sunburst of a pathway hierarchy coloured by NES.

    Parameters
    ----------
    df : pandas.DataFrame
        One row per node with at least the columns ``ID``, ``parentId``,
        ``Term`` and ``nes``. ``parentId`` may be NaN (top-level pathway) or
        name an ID missing from ``df``; both are attached to a synthetic
        ``"ROOT"`` node labelled "Reactome pathways". The caller's frame is
        not modified.

    Returns
    -------
    plotly.graph_objects.Figure
        1100x800 sunburst with a diverging colour scale centred on NES = 0.
    """
    df = df.copy()

    # Attach top-level pathways and nodes whose parent is not in df to ROOT so
    # the figure has a single centre.
    df["parentId"] = df["parentId"].fillna("ROOT")
    valid_ids = set(df["ID"])
    orphaned = ~df["parentId"].isin(valid_ids) & (df["parentId"] != "ROOT")
    df.loc[orphaned, "parentId"] = "ROOT"
    # NOTE(review): Plotly expects unique `ids`; a pathway listed under several
    # parents would appear on several rows — confirm the input has none.

    # Pre-wrap labels with <br>; text orientation is left to Plotly ('auto').
    df["wrapped_term"] = [
        term if term == "Reactome pathways" else _wrap_label(term, parent == "ROOT")
        for term, parent in zip(df["Term"], df["parentId"])
    ]

    # Synthetic centre node that every top-level pathway hangs from.
    root = pd.DataFrame([{
        "ID": "ROOT",
        "parentId": "",
        "wrapped_term": "Reactome pathways",
        "Term": "Reactome pathways",
        "nes": 0,
    }])
    df = pd.concat([df, root], ignore_index=True)

    # Colorblind-friendly diverging scale (low/negative NES blue, high/positive red).
    color_scale = ['#0571b0', '#92c5de', '#f7f7f7', '#f4a582', '#ca0020']

    # Per-node text colors: dark grey for the centre label, white otherwise.
    # Computed after the concat so the list lines up with df's final row order.
    text_colors = ['#2F4F4F' if term == "Reactome pathways" else 'white' for term in df["Term"]]

    fig = px.sunburst(
        df,
        names="wrapped_term",
        ids="ID",
        parents="parentId",
        values=None,
        color="nes",
        color_continuous_scale=color_scale,
        color_continuous_midpoint=0,
        branchvalues='total',
        width=1100,
        height=800
    )

    fig.update_layout(
        margin=dict(t=50, l=50, r=50, b=50),
        coloraxis_colorbar=dict(
            title=dict(text="NES", font=dict(size=16)),
            tickvals=[df["nes"].min(), 0, df["nes"].max()],
            tickfont=dict(size=14),
            ticks="outside",
            thickness=20,
            len=0.6
        ),
        font=dict(size=14)
    )

    # Let Plotly choose radial/tangential per sector; hover shows the unwrapped Term.
    fig.update_traces(
        insidetextorientation='auto',
        textfont=dict(size=14, family="Arial", color=text_colors),
        hovertemplate='<b>%{customdata[0]}</b><br>NES: %{color:.2f}<extra></extra>',
        customdata=df[["Term"]],
        textinfo='label',
        texttemplate='<b>%{label}</b>'
    )

    return fig


In [118]:
import pandas as pd

# Keep pathways with FDR below the threshold (rows with NaN fdr, i.e. no GSEA
# result, compare False and are dropped too), then attach parentless pathways
# to a synthetic ROOT node.
FDR_THRESHOLD = 0.005
pathway_hieararchy_gsea_filt = (
    pathway_hieararchy_gsea
    .loc[lambda frame: frame["fdr"] < FDR_THRESHOLD]
    .assign(parentId=lambda frame: frame["parentId"].fillna("ROOT"))
)

# Prune rows whose parentId is neither "ROOT" nor an ID still present; each pass
# can orphan further descendants, so repeat until a pass removes nothing.
def filter_orphaned_nodes(df):
    """Return a copy of ``df`` keeping only rows whose parent is still present.

    A row survives when its ``parentId`` is "ROOT" or the ``ID`` of another
    surviving row. Pruning repeats until a pass removes nothing, so entire
    branches beneath a missing node are dropped. The original index is kept.
    """
    current = df
    while True:
        allowed_parents = {"ROOT", *current["ID"]}
        pruned = current.loc[current["parentId"].isin(allowed_parents)].copy()
        if len(pruned) == len(current):
            return pruned
        current = pruned

# Drop rows whose parent chain was removed by the FDR filter.
pathway_hieararchy_gsea_filt_vis = filter_orphaned_nodes(pathway_hieararchy_gsea_filt)

# Add a catch-all 'Other' node directly under ROOT.
# NOTE(review): nes=-5 looks like an arbitrary value chosen only to colour this
# wedge (nes drives the colour scale in plot_sunburst) — confirm.
other_row = pd.DataFrame([{"ID": "Other", "Term": "Other", "nes": -5, "parentId": "ROOT"}])
pathway_hieararchy_gsea_filt_vis = pd.concat([pathway_hieararchy_gsea_filt_vis, other_row], ignore_index=True)

# Optional: save to CSV
# pathway_hieararchy_gsea_filt_vis.to_csv("/Users/polina/genetics_gsea/data/sunburst/pathway_hieararchy_gsea_filt.csv", index=False)

# Variant without the 'Other' category. The 'Other' node itself has parentId
# "ROOT", so it has to be matched on ID; any children of 'Other' are matched on
# parentId.
is_other = (
    pathway_hieararchy_gsea_filt_vis["ID"].eq("Other")
    | pathway_hieararchy_gsea_filt_vis["parentId"].eq("Other")
)
pathway_hieararchy_gsea_filt_vis_no_other = pathway_hieararchy_gsea_filt_vis.loc[~is_other].copy()

In [119]:
# Preview of the table passed to plot_sunburst below: pathways passing the FDR
# filter whose whole parent chain also passed (top-level parents shown as ROOT).
pathway_hieararchy_gsea_filt_vis_no_other

Unnamed: 0,ID,parentId,Term,nes,fdr
0,R-HSA-162582,ROOT,Signal Transduction,5.879025,7e-06
1,R-HSA-212436,R-HSA-73857,Generic Transcription Pathway,4.456707,0.000382
2,R-HSA-73857,R-HSA-74160,RNA Polymerase II Transcription,4.202785,0.000634
3,R-HSA-74160,ROOT,Gene expression (Transcription),4.022728,0.000928
4,R-HSA-1280215,R-HSA-168256,Cytokine Signaling in Immune system,4.523809,0.000357
5,R-HSA-168256,ROOT,Immune System,4.273209,0.000557
6,R-HSA-1266738,ROOT,Developmental Biology,4.117902,0.000737
7,R-HSA-1643685,ROOT,Disease,3.556823,0.003337
8,R-HSA-449147,R-HSA-1280215,Signaling by Interleukins,3.41002,0.004928
9,R-HSA-8963743,ROOT,Digestion and absorption,-3.887837,0.001404


In [120]:
# Render the NES-coloured sunburst of the filtered pathway hierarchy.
fig = plot_sunburst(pathway_hieararchy_gsea_filt_vis_no_other)
fig.show()