In [54]:
import pandas as pd
import plotly.express as px
from plotly.subplots import make_subplots

In [146]:
# Load per-cohort haplotype carrier counts; the "sum" column is the cohort size N.
df = pd.read_excel("~/local/populations.xlsx")

# Carrier frequency of each haplotype class = count / cohort size.
for col in ["hapA", "hapB", "none"]:
    df[f"{col}_f"] = df[col] / df["sum"]
# NOTE(review): cohort2 is not referenced by the plotting code visible here;
# drop it if nothing else uses it.
df["cohort2"] = df["cohort"] + " [N=" + df["sum"].astype(str) + "]"

# Long format: one row per (phenotype, cohort, haplotype frequency).
# Rows are reversed, presumably so the spreadsheet's first cohort ends up at the
# top of the horizontal bar chart (plotly lays categorical y values out bottom-up).
pats = df.melt(
    id_vars=["phenotype", "cohort"],
    value_vars=["hapA_f", "hapB_f", "none_f"],
).iloc[::-1]
pats["variable"] = pats["variable"].str.replace("_f", "")
# Ordered categorical so sorting follows hapA -> hapB -> none rather than alphabet.
pats["variable"] = pd.Categorical(
    pats["variable"], categories=["hapA", "hapB", "none"], ordered=True
)

# aFTLD-U rows go first, grouped by haplotype; Controls follow in their existing
# order. Presumably this makes hapA appear first in the data so plotly stacks and
# lists haplotypes in that order. The sort must be stable: within each haplotype
# the cohorts keep their (reversed) order, which the y-axis category order is
# derived from; the default quicksort does not guarantee that.
is_patient = pats["phenotype"] == "aFTLD-U"
is_control = pats["phenotype"] == "Controls"
pats = pd.concat(
    [
        pats[is_patient].sort_values(by="variable", kind="stable"),
        pats[is_control],
    ]
)

In [192]:
# Horizontal stacked bars: one bar per cohort, one segment per haplotype class
# (length = carrier frequency), one facet column per phenotype.
# category_orders pins the stacking/legend order and the facet order explicitly
# instead of relying on the row order of `pats`. The per-facet annotations
# address facets by number (col=1 -> aFTLD-U, col=2 -> Controls).
fig = px.bar(
    pats,
    x="value",
    y="cohort",
    color="variable",
    orientation="h",
    color_discrete_map={"hapA": "red", "hapB": "orange", "none": "green"},
    category_orders={
        "variable": ["hapA", "hapB", "none"],
        "phenotype": ["aFTLD-U", "Controls"],
    },
    labels={"value": "Proportion", "variable": "Haplotype", "cohort": "Cohort"},
    facet_col="phenotype",
    facet_col_spacing=0.03,
)
# Facet titles come out as "phenotype=<value>"; keep only the value, enlarged.
fig.for_each_annotation(
    lambda a: a.update(text=a.text.split("=")[-1], font=dict(size=24), xanchor="left")
)

# Overall layout in one call: no figure or x-axis titles, white plot
# background, 16 pt base font, no left/right/bottom margin, and a horizontal
# legend centred just beneath the plotting area.
fig.update_layout(
    title="",
    title_x=0.5,
    xaxis_title="",
    xaxis2_title="",
    yaxis_title="Cohort",
    plot_bgcolor="white",
    font=dict(size=16),
    margin=dict(l=0, r=0, b=0),
    legend=dict(
        title="Carrier frequency",
        orientation="h",
        x=0.5,
        xanchor="center",
        y=-0.15,
        yanchor="bottom",
    ),
)

# No x tick labels (segment labels carry the numbers). On the y axis, white
# "outside" ticks are invisible on the white background and only add padding
# between the cohort names and the bars.
fig.update_xaxes(showticklabels=False)
fig.update_yaxes(ticks="outside", tickcolor="white")

def annotate_facet(fig, vals, col, n_x, styles):
    """Label each stacked segment of one facet with its raw count, plus "N=<sum>".

    fig    -- plotly figure to annotate (modified in place via add_annotation)
    vals   -- rows of ``df`` for one phenotype; must have "cohort", "sum", the
              "hapA"/"hapB"/"none" counts and their "*_f" frequency columns
    col    -- facet column (1-based) that these rows are drawn in
    n_x    -- x position (data units) of the "N=" label; negative values put it
              to the left of the bars
    styles -- per-haplotype dict of extra ``add_annotation`` keyword arguments
              (font, arrow settings, rotation, ...)
    """
    for _, rec in vals.iterrows():
        # Segments stack hapA -> hapB -> none starting at 0, so each segment's
        # centre is the total width before it plus half its own width.
        left = 0.0
        for hap in ("hapA", "hapB", "none"):
            frac = rec[f"{hap}_f"]
            fig.add_annotation(
                x=left + frac / 2,
                y=rec["cohort"],
                text=f"{rec[hap]}",
                col=col, row=1,
                **styles[hap],
            )
            left = left + frac

        fig.add_annotation(
            x=n_x,
            y=rec["cohort"],
            text=f"N={rec['sum']}",
            font=dict(color="black", size=16),
            col=col, row=1,
            showarrow=False,
        )


white_label = dict(showarrow=False, font=dict(color="white", size=16))

# Facet column numbers assume px placed aFTLD-U first and Controls second.
# The N= offsets (-0.1 / -0.16) look hand-tuned to the label widths.
# TODO confirm if the phenotype order or cohort sizes change.

# aFTLD-U: plain white labels centred on every segment.
annotate_facet(
    fig,
    df[df["phenotype"] == "aFTLD-U"],
    col=1,
    n_x=-0.1,
    styles={"hapA": white_label, "hapB": white_label, "none": white_label},
)

# Controls: the hapA label is black with an arrow, its text offset by
# (ax, ay) = (-20, -15) px. The hapB label is rotated -60 degrees, presumably
# because those segments are too narrow for horizontal text.
annotate_facet(
    fig,
    df[df["phenotype"] == "Controls"],
    col=2,
    n_x=-0.16,
    styles={
        "hapA": dict(
            showarrow=True,
            arrowhead=1,
            font=dict(color="black", size=16),
            yshift=0,
            ax=-20,
            ay=-15,
        ),
        "hapB": dict(
            showarrow=False,
            font=dict(color="white", size=16),
            textangle=-60,
        ),
        "none": white_label,
    },
)

fig.show()