In [1]:
import os
import pandas as pd
import plotly.graph_objects as go
import plotly.io as pio

wompwomp_vignettes_dir = os.path.abspath("")
output_dir = os.path.join(wompwomp_vignettes_dir, "output")

In [2]:
genes = ["Akr1c21", "Slc7a12"]
csv_path = os.path.join(output_dir, "Akr1c21_Slc7a12.csv")

In [None]:
if not os.path.exists(csv_path):
    raise FileNotFoundError(f"CSV file not found at {csv_path}")

df = pd.read_csv(csv_path)
df.head()

Unnamed: 0,cellID,Gene,individual,Sex,Genotype,Tissue,Celltype_final,umi_total
0,E4_F5_E11_Subpool_4_igvf_010,Akr1c21,41,Male,129S1J,Kidney,Kidney-smooth muscle cell,1.432402
1,B8_B10_G8_Subpool_4_igvf_010,Akr1c21,73,Male,NODJ,Kidney,Kidney-loop of Henle thick ascending limb epit...,1.268769
2,H2_H3_F11_Subpool_4_igvf_010,Akr1c21,47,Male,NZOJ,Kidney,Kidney-smooth muscle cell,1.130386
3,A7_A5_E3_Subpool_4_igvf_010,Akr1c21,24,Female,B6J,Kidney,Kidney-collecting duct beta-intercalated cell,0.554298
4,H2_D8_G8_Subpool_4_igvf_010,Akr1c21,47,Male,NZOJ,Kidney,Kidney-collecting duct beta-intercalated cell,1.560454


In [None]:
def hex_to_rgba(hex_color, alpha=0.6):
    hex_color = hex_color.lstrip("#")
    r, g, b = [int(hex_color[i:i+2], 16) for i in (0, 2, 4)]
    return f"rgba({r}, {g}, {b}, {alpha})"

def plot_data(df, groupings, plot_out_path = None, df_out_path = None):
    umi_group_sum = df.groupby(groupings)["umi_total"].sum().reset_index(name="value")
    if df_out_path is not None:
        umi_group_sum.to_csv(df_out_path, index=False)
    labels = pd.unique(umi_group_sum[groupings].values.ravel())
    label_to_idx = {label: i for i, label in enumerate(labels)}
    gene_colors = {genes[0]: 'rgba(213, 94, 0, 0.6)', genes[1]: 'rgba(86, 180, 233, 0.6)'}
    sources, targets, values, link_colors = [], [], [], []

    for i in range(len(groupings) - 1):
        cols = [groupings[i], groupings[i + 1]]
        if "Gene" not in cols:
            cols.append("Gene")
        temp = umi_group_sum.groupby(cols, as_index=False)["value"].sum()
        for _, row in temp.iterrows():
            sources.append(label_to_idx[row[groupings[i]]])
            targets.append(label_to_idx[row[groupings[i + 1]]])
            values.append(row["value"])
            link_colors.append(gene_colors.get(row["Gene"], 'rgba(200,200,200,0.4)'))

    fig = go.Figure(data=[go.Sankey(
        node=dict(
            pad=15,
            thickness=20,
            label=list(label_to_idx.keys()),
            color='lightblue'
        ),
        link=dict(
            source=sources,
            target=targets,
            value=values,
            color=link_colors
        )
    )])

    fig.update_layout(
        # title_text='UMI-Weighted Alluvial Plot',
        font_size=10,
        height=1700,
        width=1500,  
    )

    fig.write_image(plot_out_path)

In [5]:
groupings = ["Gene", "individual", "Sex", "Genotype", "Tissue", "Celltype_final"]
plot_data(df, groupings, df_out_path = os.path.join(output_dir, "Akr1c21_Slc7a12_grouped.csv"), plot_out_path = os.path.join(output_dir, "Akr1c21_Slc7a12_sankey_plot.pdf"))

In [6]:
groupings = ["Gene", "Sex", "Genotype", "Tissue"]
plot_data(df, groupings, plot_out_path = os.path.join(output_dir, "Akr1c21_Slc7a12_sankey_plot_fixed_columns.pdf"))

In [9]:
groupings = ["Gene", "Sex", "Tissue", "Genotype"]
plot_data(df, groupings, plot_out_path = os.path.join(output_dir, "Akr1c21_Slc7a12_sankey_plot_optimized_columns.pdf"))