In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import os
import arviz as az
from cycler import cycler

from sphincter.data_preparation import load_prepared_data
from sphincter.plotting import plot_obs, plot_predictive, save_figure


# Qualitative "Set2" palette, installed as the default colour cycle for every figure below.
CMAP = plt.get_cmap("Set2")
plt.rcParams.update({"axes.prop_cycle": cycler(color=CMAP.colors)})

In [None]:
# Abbreviation table for long coordinate names.
# NOTE(review): not currently consulted by format_name.
abbreviations = {"post_ablation": "post abl."}


def format_name(name: str):
    """Turn a snake_case coordinate value into a human-readable label.

    Every underscore becomes a space, then ``str.capitalize`` upper-cases the
    first character and lower-cases the rest, e.g. ``"after_hyper1"`` ->
    ``"After hyper1"``.
    """
    spaced = name.translate(str.maketrans("_", " "))
    return spaced.capitalize()


In [None]:
def forestplot(ax, ts, xlabel="Value of test statistic", qlow=0.025, qhigh=0.975, xpad=0.1):
    """Draw a compact forest plot of several sampled quantities on ``ax``.

    Each entry of ``ts`` becomes one horizontal row: a line spanning the
    ``qlow``-``qhigh`` quantile interval and a dot at the mean. Rows are spaced
    evenly across the current y-limits of ``ax``, with an empty slot at each
    end. With a default (non-inverted) y-axis, the first entry is drawn lowest.
    The x-range is made symmetric about zero, and a dashed reference line is
    drawn at zero.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    ts : dict
        Maps row label -> samples. The samples can be anything ``np.quantile``
        and ``np.mean`` accept, e.g. an xarray DataArray of posterior draws.
    xlabel : str
        Label for the x-axis.
    qlow, qhigh : float
        Quantiles bounding each interval. The default is the central 95%.
    xpad : float
        Absolute padding added beyond the largest |interval endpoint| when
        setting the symmetric x-limits.

    Returns
    -------
    The same ``ax``, for chaining.

    Raises
    ------
    ValueError
        If ``ts`` is empty.
    """
    if not ts:
        raise ValueError("forestplot needs at least one entry in `ts`")
    labels = list(ts.keys())
    samples = list(ts.values())
    xlows = [np.quantile(t, qlow) for t in samples]
    xhighs = [np.quantile(t, qhigh) for t in samples]
    xmeans = [np.mean(t) for t in samples]

    # Evenly spaced row positions; the outermost two ticks are blank padding rows.
    ylimlow, ylimhigh = ax.get_ylim()
    ytickys = np.linspace(ylimlow, ylimhigh, len(ts) + 2)
    ys = ytickys[1:-1]

    # Symmetric x-range keeps the zero reference line centred.
    xbiggest = max(np.abs(xlows + xhighs)) + xpad
    ax.set_xlim(-xbiggest, xbiggest)
    for y, xlow, xhigh, xmean in zip(ys, xlows, xhighs, xmeans):
        line = ax.hlines(y=y, xmin=xlow, xmax=xhigh, linewidth=2)
        # Reuse the interval's colour for its mean marker so each row is one colour.
        ax.plot(xmean, y, marker="o", color=line.get_colors()[0])
    ax.set_yticks(ytickys, [""] + labels + [""])
    ax.axvline(0.0, linestyle="--", color="black")
    ax.set(title="", xlabel=xlabel)
    # Only the row labels are wanted on the y-axis, not tick marks.
    ax.tick_params(axis="y", which="both", left=False, right=False)
    return ax

### Mean arterial pressure (MAP), pulse pressure (PP) and heart rate (HR) (Figure 1)

In [None]:
# Posterior samples of the pressure model (MAP / PP / HR measurement types).
idata_f1 = az.from_zarr(os.path.join("..", "inferences", "pressure", "idata"))

In [None]:
# Display the loaded pressure-model InferenceData for inspection.
idata_f1

In [None]:
# Treatment levels available in the pressure model's posterior.
idata_f1.posterior.coords["treatment"].values

In [None]:
# For each pressure measurement, plot the adult - old age difference and, for
# every non-baseline treatment, its difference from the comparison treatment
# listed in ``treatment_to_compare``.
treatment_to_compare = {
    "hyper1": "baseline",
    "after_hyper1": "baseline",
    "ablation": "after_hyper1",
    "hyper2": "ablation",
}


for measurement_type in ["map", "pp", "hr"]:
    a_age = idata_f1.posterior["a_age"].sel(measurement_type=measurement_type)
    a_treatment = idata_f1.posterior["a_treatment"].sel(measurement_type=measurement_type)
    age_ts = {
        "Difference in age effect: Adult - Old": a_age.sel(age="adult") - a_age.sel(age="old")
    }
    treatment_ts = {
        f"Difference in treatment effect: {format_name(treatment)} - {format_name(treatment_to_compare[treatment])}": (
            a_treatment.sel(treatment=treatment)
            - a_treatment.sel(treatment=treatment_to_compare[treatment])
        )
        for treatment in idata_f1.posterior.coords["treatment"].values
        if treatment != "baseline"
    }
    ts = age_ts | treatment_ts

    f, ax = plt.subplots(figsize=[8, 4])
    forestplot(ax, ts)
    save_figure(f, f"supporting-f1-effects-{measurement_type}")

In [None]:
# Posterior predictive checks for the pressure model: observed MAP / PP / HR
# against model replicates, grouped by treatment, one figure per measurement.
prepared_data_f1 = load_prepared_data("../data/prepared/pressure.json")
msts_f1 = prepared_data_f1.measurements

ylabels = [
    "Mean arterial pressure (mmHg)",
    "Pulse pressure (mmHg)",
    "Heart rate (Hz)"
]
yrep = idata_f1.posterior_predictive["yrep"]

for col, ylabel in zip(["map", "pp", "hr"], ylabels):
    f, ax = plt.subplots(figsize=[8, 5])
    plot_obs(ax, msts_f1[col], cat=msts_f1["treatment"])
    plot_predictive(ax, yrep.sel(measurement_type=col), zorder=-1, cat=msts_f1["treatment"], label="model")
    ax.set(ylabel=ylabel)
    ax.set_xticks([])
    # Take legend entries from this figure's own axes rather than pyplot's
    # "current axes". Labels that format to the same text collapse into one entry.
    handles, labels = ax.get_legend_handles_labels()
    by_label = dict(zip(map(format_name, labels), handles))
    f.legend(by_label.values(), by_label.keys(), frameon=False, ncol=3)
    ax.semilogy()
    save_figure(f, f"supporting-f1-ppc-{col}")


### Whisker stimulation (Figure 2)

In [None]:
# Whisker-stimulation data plus the posteriors of the "whisker-ind" and "whisker-big" models.
idata_f2 = az.InferenceData.from_zarr(os.path.join("..", "inferences", "whisker-ind", "idata"))
idata_f2_big = az.InferenceData.from_zarr(os.path.join("..", "inferences", "whisker-big", "idata"))

prepared_data_f2 = load_prepared_data("../data/prepared/whisker.json")
msts_f2 = prepared_data_f2.measurements

idata_f2_big

In [None]:
# Adult - old difference in the posterior ``mu`` of the "whisker-ind" model.
# Label casing matches the other age-difference plots ("Adult - Old").
t = idata_f2.posterior["mu"].sel(age="adult") - idata_f2.posterior["mu"].sel(age="old")
f, ax = plt.subplots(figsize=[8, 3])
forestplot(ax, {"Difference in age effect: Adult - Old": t});

In [None]:
# Each vessel type's ``a_vessel_type`` effect relative to the "pen_art" level.
diff = dict(
    (
        f"Difference in vessel type effect: {format_name(vt)} - {format_name('pen_art')}",
        idata_f2.posterior["a_vessel_type"].sel(vessel_type=vt)
        - idata_f2.posterior["a_vessel_type"].sel(vessel_type="pen_art"),
    )
    for vt in ("sphincter", "bulb", "cap1", "cap2")
)
f, ax = plt.subplots(figsize=[8, 5])
forestplot(ax, diff)

save_figure(f, "supporting-f2-effects-vesseltype")

In [None]:
# One row per (vessel type, treatment) interaction coefficient of the "whisker-big" model.
# The loop variable is ``treatment`` so it does not shadow the dict ``t`` being built.
f, ax = plt.subplots(figsize=[6, 10])
t = {
    f"Interaction effect {format_name(vt)}:{format_name(treatment)}": (
        idata_f2_big.posterior["a_vessel_type_treatment"].sel(vessel_type=vt, treatment=treatment)
    )
    for vt in idata_f2_big.posterior.coords["vessel_type"].values
    for treatment in idata_f2_big.posterior.coords["treatment"].values
}
forestplot(ax, t);
save_figure(f, "supporting-f2-effects-vesseltype-treatment")

In [None]:
# One row per (age, treatment) interaction coefficient of the "whisker-big" model.
# The loop variable is ``treatment`` so it does not shadow the dict ``t`` being built.
f, ax = plt.subplots(figsize=[6, 5])
t = {
    f"Interaction effect {format_name(age)}:{format_name(treatment)}": (
        idata_f2_big.posterior["a_age_treatment"].sel(age=age, treatment=treatment)
    )
    for age in idata_f2_big.posterior.coords["age"].values
    for treatment in idata_f2_big.posterior.coords["treatment"].values
}
forestplot(ax, t);
save_figure(f, "supporting-f2-effects-age-treatment")

In [None]:
# Posterior predictive check for the "whisker-ind" model, grouped by treatment.
f, ax = plt.subplots(figsize=[12, 5])
plot_obs(ax, msts_f2["diam_log_ratio"], cat=msts_f2["treatment"])
plot_predictive(ax, idata_f2.posterior_predictive["yrep"], cat=msts_f2["treatment"], zorder=-1, label="model")
ax.set(ylabel="Whisker response")
ax.set_xticks([])
# Take legend entries from this figure's own axes rather than pyplot's
# "current axes". Labels that format to the same text collapse into one entry.
handles, labels = ax.get_legend_handles_labels()
by_label = dict(zip(map(format_name, labels), handles))
ax.legend(by_label.values(), by_label.keys(), frameon=False)
save_figure(f, "supporting-f2-ppc")

In [None]:
# LOO comparison of the "big" and "ind" whisker models. It gets its own name
# because ``comparison`` is later rebound to a vessel-type string.
whisker_comparison = az.compare({"big": idata_f2_big, "ind": idata_f2})
whisker_comparison

## Figure 3

In [None]:
# Posteriors of the "hypertension-big" and "hypertension-basic" models, then the prepared data.
idata_f3_big = az.InferenceData.from_zarr(os.path.join("..", "inferences", "hypertension-big", "idata"))
idata_f3_basic = az.InferenceData.from_zarr(os.path.join("..", "inferences", "hypertension-basic", "idata"))
prepped_f3 = load_prepared_data("../data/prepared/hypertension.json")
msts_f3 = prepped_f3.measurements
idata_f3_big

In [None]:
# Inspect the prepared hypertension measurements.
msts_f3

In [None]:
# Posterior predictive check for the basic hypertension model, grouped by treatment.
catcol = "treatment"
f, ax = plt.subplots(figsize=[12, 5])
plot_obs(ax, msts_f3["atanh_corr_bp_diam"], cat=msts_f3[catcol])
plot_predictive(ax, idata_f3_basic.posterior_predictive["yrep"], cat=msts_f3[catcol], zorder=-1, label="model")
# The column is atanh-transformed (inverse *hyperbolic* tangent), so label it
# tanh^{-1}. "tan^{-1}" would denote the ordinary arctangent.
ax.set(ylabel="$\\tanh^{-1}$(correlation coefficient)")
ax.set_xticks([])
# Take legend entries from this figure's own axes rather than pyplot's
# "current axes". Labels that format to the same text collapse into one entry.
handles, labels = ax.get_legend_handles_labels()
by_label = dict(zip(map(format_name, labels), handles))
ax.legend(by_label.values(), by_label.keys(), frameon=False)
save_figure(f, "supporting-f3-ppc")

In [None]:
# Disabled cell: Figure 3 forest plots of the adult - old age difference and the
# hyper2 - hyper1 treatment difference from the basic hypertension model.
# Uncomment to regenerate these figures.
# t = {
#     "Age effect difference: Adult - Old": (
#         idata_f3_basic.posterior["mu"].sel(age="adult")
#         - idata_f3_basic.posterior["mu"].sel(age="old")
#     )
# }
# f, ax = plt.subplots(figsize=[8, 3])
# forestplot(ax, t);
# save_figure(f, "supporting-f3-effects-age")

# t = {
#     f"Treatment effect difference: {format_name('hyper2')} - {format_name('hyper1')}": (
#         idata_f3_basic.posterior["a_treatment"].sel(treatment="hyper2")
#         - idata_f3_basic.posterior["a_treatment"].sel(treatment="hyper1")
#     )
# }
# f, ax = plt.subplots(figsize=[8, 3])
# forestplot(ax, t);
# save_figure(f, "supporting-f3-effects-hyper1hyper2")

In [None]:
# Posterior of the per-vessel-type ``sigma`` (measurement error) in the basic hypertension model.
t = dict(
    (
        f"Measurement error parameter: {format_name(vt)}",
        idata_f3_basic.posterior["sigma"].sel(vessel_type=vt),
    )
    for vt in ("sphincter", "bulb", "cap1", "cap2")
)
f, ax = plt.subplots(figsize=[8, 5])
forestplot(ax, t)
# Presumably sigma is a non-negative scale, so start the x-axis at zero
# instead of forestplot's symmetric range; the right limit is kept.
ax.set_xlim(left=0)

save_figure(f, "supporting-f3-sds")

## Figure 4

In [None]:
# RBC speed and flux: posteriors of a "basic" and a "big" model for each, plus the prepared data.
idata_f4_basic_speed = az.from_zarr(os.path.join("..", "inferences", "flow-basic-speed", "idata"))
idata_f4_basic_flux = az.from_zarr(os.path.join("..", "inferences", "flow-basic-flux", "idata"))
idata_f4_big_speed = az.from_zarr(os.path.join("..", "inferences", "flow-big-speed", "idata"))
idata_f4_big_flux = az.from_zarr(os.path.join("..", "inferences", "flow-big-flux", "idata"))

prepped_f4_speed = load_prepared_data("../data/prepared/flow-speed.json")
msts_f4_speed = prepped_f4_speed.measurements

prepped_f4_flux = load_prepared_data("../data/prepared/flow-flux.json")
msts_f4_flux = prepped_f4_flux.measurements

idata_f4_basic_speed

In [None]:
# Number of non-missing flux values per vessel in the raw spreadsheet.
raw = pd.read_csv("../data/raw/data_sphincter_paper.csv")

# GroupBy.count excludes missing values. This replaces a groupby-apply lambda,
# which pandas now warns about because it operates on the grouping column too.
raw.groupby("vessel")["flux"].count()

In [None]:
# Inspect the prepared RBC-speed measurements.
msts_f4_speed

In [None]:
# Posterior predictive check for the basic RBC-speed model, grouped by vessel type (log scale).
catcol = "vessel_type"
f, ax = plt.subplots(figsize=[12, 5])
plot_obs(ax, msts_f4_speed["speed"], cat=msts_f4_speed[catcol])
plot_predictive(ax, idata_f4_basic_speed.posterior_predictive["yrep"], cat=msts_f4_speed[catcol], zorder=-1, label="model")
ax.set(ylabel="Speed")
ax.set_xticks([])
# Take legend entries from this figure's own axes rather than pyplot's
# "current axes". Labels that format to the same text collapse into one entry.
handles, labels = ax.get_legend_handles_labels()
by_label = dict(zip(map(format_name, labels), handles))
ax.legend(by_label.values(), by_label.keys(), frameon=False)
ax.semilogy()
save_figure(f, "supporting-f4-ppc-speed")

In [None]:
# Posterior predictive check for the basic RBC-flux model, grouped by vessel type (log scale).
catcol = "vessel_type"
f, ax = plt.subplots(figsize=[12, 5])
plot_obs(ax, msts_f4_flux["flux"], cat=msts_f4_flux[catcol])
plot_predictive(ax, idata_f4_basic_flux.posterior_predictive["yrep"], cat=msts_f4_flux[catcol], zorder=-1, label="model")
ax.set(ylabel="Flux")
ax.set_xticks([])
# Take legend entries from this figure's own axes rather than pyplot's
# "current axes". Labels that format to the same text collapse into one entry.
handles, labels = ax.get_legend_handles_labels()
by_label = dict(zip(map(format_name, labels), handles))
ax.legend(by_label.values(), by_label.keys(), frameon=False)
ax.semilogy()
save_figure(f, "supporting-f4-ppc-flux")

In [None]:
# LOO comparison of the "big" and "basic" models, separately for RBC speed and RBC flux.
cols = ["elpd_loo", "se", "elpd_diff", "dse"]
speed_comp = az.compare({"big": idata_f4_big_speed, "basic": idata_f4_basic_speed})
flux_comp = az.compare({"big": idata_f4_big_flux, "basic": idata_f4_basic_flux})

display(speed_comp[cols], flux_comp[cols])

In [None]:
# Vessel-type levels available in the basic flux model's posterior.
idata_f4_basic_flux.posterior.coords["vessel_type"].values

In [None]:
# Each vessel type's effect on RBC flux and on RBC speed relative to the
# ``comparison`` level, from the basic flux and speed models. Both sides of each
# label go through format_name so the two names are styled consistently.
comparison = "cap5"
vts = {
    "flux": ["sphincter", "cap2", "cap3", "cap4", "cap5"],
    "speed": ["sphincter", "cap1", "cap2", "cap3", "cap4", "cap5"],
}
t = {
    f"Difference in vessel type effect on RBC flux: {format_name(vt)} - {format_name(comparison)}": (
        idata_f4_basic_flux.posterior["a_vessel_type"].sel(vessel_type=vt)
        - idata_f4_basic_flux.posterior["a_vessel_type"].sel(vessel_type=comparison)
    )
    for vt in vts["flux"] if vt != comparison
} | {
    f"Difference in vessel type effect on RBC speed: {format_name(vt)} - {format_name(comparison)}": (
        idata_f4_basic_speed.posterior["a_vessel_type"].sel(vessel_type=vt)
        - idata_f4_basic_speed.posterior["a_vessel_type"].sel(vessel_type=comparison)
    )
    for vt in vts["speed"] if vt != comparison
}
f, ax = plt.subplots(figsize=[8, 5])
forestplot(ax, t);
save_figure(f, "supporting-f4-effects-vesseltype")


In [None]:
# Posterior of the per-treatment ``sigma`` (measurement error) in the basic flux and speed models.
t = dict(
    (
        f"RBC {mt} Measurement error parameter: {format_name(treatment)}",
        idata.posterior["sigma"].sel(treatment=treatment),
    )
    for mt, idata in (("flux", idata_f4_basic_flux), ("speed", idata_f4_basic_speed))
    for treatment in idata.posterior.coords["treatment"].values
)
f, ax = plt.subplots(figsize=[8, 5])
forestplot(ax, t)
# Presumably sigma is a non-negative scale, so start the x-axis at zero
# instead of forestplot's symmetric range; the right limit is kept.
ax.set_xlim(left=0)

save_figure(f, "supporting-f4-sds")

In [None]:
# Per vessel type, the change in the vessel-type x treatment interaction of the
# big RBC-speed model between each treatment and its comparison treatment.
vts = ["bulb", "sphincter", "cap1", "cap2", "cap3", "cap4", "cap5"]
treatments = ["hyper", "after_hyper", "after_ablation", "hyper2"]
treatment_to_compare = {
    "hyper": "baseline",
    "after_hyper": "baseline",
    "after_ablation": "after_hyper",
    "hyper2": "after_ablation",
}
t = {
    f"RBC speed effect: {format_name(vt)}:{format_name(treatment)} - {format_name(vt)}:{format_name(treatment_to_compare[treatment])}": (
        idata_f4_big_speed.posterior["a_vessel_type_treatment"].sel(vessel_type=vt, treatment=treatment)
        - idata_f4_big_speed.posterior["a_vessel_type_treatment"].sel(vessel_type=vt, treatment=treatment_to_compare[treatment])
    )
    for vt in vts
    for treatment in treatments
}
# A single Axes: sharex has nothing to share with, so it is omitted.
f, ax = plt.subplots(figsize=[8, 11])
forestplot(ax, t);
save_figure(f, "supporting-f4-interaction-effects-speed")

In [None]:
# Per vessel type, the change in the vessel-type x treatment interaction of the
# big RBC-flux model between each treatment and its comparison treatment.
vts = ['sphincter', 'cap2', 'cap3', 'cap4', 'cap5']
treatments = ['hyper', 'after_hyper', 'after_ablation', 'hyper2']
# Bound here too, so this cell does not depend on which earlier cell last set
# ``treatment_to_compare``. The Figure 1 cell binds a mapping whose keys
# ("hyper1", ...) would raise KeyError below.
treatment_to_compare = {
    "hyper": "baseline",
    "after_hyper": "baseline",
    "after_ablation": "after_hyper",
    "hyper2": "after_ablation",
}
t = {
    f"RBC flux effect: {format_name(vt)}:{format_name(treatment)} - {format_name(vt)}:{format_name(treatment_to_compare[treatment])}": (
        idata_f4_big_flux.posterior["a_vessel_type_treatment"].sel(vessel_type=vt, treatment=treatment)
        - idata_f4_big_flux.posterior["a_vessel_type_treatment"].sel(vessel_type=vt, treatment=treatment_to_compare[treatment])
    )
    for vt in vts
    for treatment in treatments
}
# A single Axes: sharex has nothing to share with, so it is omitted.
f, ax = plt.subplots(figsize=[8, 10])
forestplot(ax, t);
save_figure(f, "supporting-f4-interaction-effects-flux")

## Figure 5

In [None]:
# Prepared pulsatility data and the posterior of the "pulsatility-basic-full" model.
data_f5 = load_prepared_data("../data/prepared/pulsatility.json")
msts_f5 = data_f5.measurements
idata_f5 = az.InferenceData.from_zarr(os.path.join("..", "inferences", "pulsatility-basic-full", "idata"))



In [None]:
# Posterior predictive check: diameter pulsatility (first harmonic) by vessel type, log scale.
f, ax = plt.subplots(figsize=[12, 5])
plot_obs(ax, msts_f5["pd1"], cat=msts_f5["vessel_type"]);
plot_predictive(ax, idata_f5.posterior_predictive["yrep"].sel(measurement_type="diameter"), cat=msts_f5["vessel_type"], zorder=-1, label="model")
ax.semilogy()
ax.set(ylabel="Diameter pulsatility (first harmonic)")
ax.set_xticks([]);
# Take legend entries from this figure's own axes rather than pyplot's
# "current axes". Labels that format to the same text collapse into one entry.
handles, labels = ax.get_legend_handles_labels()
by_label = dict(zip(map(format_name, labels), handles))
ax.legend(by_label.values(), by_label.keys(), frameon=False, ncol=3);
save_figure(f, "supporting-f5-ppc-diameter")

In [None]:
# Posterior predictive check: center pulsatility (first harmonic) by vessel type, log scale.
f, ax = plt.subplots(figsize=[12, 5])
plot_obs(ax, msts_f5["pc1"], cat=msts_f5["vessel_type"]);
plot_predictive(ax, idata_f5.posterior_predictive["yrep"].sel(measurement_type="center"), cat=msts_f5["vessel_type"], zorder=-1, label="model")
ax.semilogy()
ax.set(ylabel="Center pulsatility (first harmonic)")
ax.set_xticks([])
# Take legend entries from this figure's own axes rather than pyplot's
# "current axes". Labels that format to the same text collapse into one entry.
handles, labels = ax.get_legend_handles_labels()
by_label = dict(zip(map(format_name, labels), handles))
ax.legend(by_label.values(), by_label.keys(), frameon=False, ncol=3);
save_figure(f, "supporting-f5-ppc-center")

In [None]:
# Display the pulsatility model's InferenceData for inspection.
idata_f5

In [None]:
# Each vessel type's effect relative to ``comparison``: all diameter-pulsatility
# rows first, then all center-pulsatility rows.
comparison = "pen_art"
vts = ["pen_art", "bulb", "cap1", "cap2", "cap3", "cap4", "cap5"]
t = {
    **{
        f"Difference in vessel type effect on diameter pulsatility: {format_name(vt)} - {format_name(comparison)}": (
            idata_f5.posterior["a_vessel_type"].sel(measurement_type="diameter", vessel_type=vt)
            - idata_f5.posterior["a_vessel_type"].sel(measurement_type="diameter", vessel_type=comparison)
        )
        for vt in vts
        if vt != comparison
    },
    **{
        f"Difference in vessel type effect on center pulsatility: {format_name(vt)} - {format_name(comparison)}": (
            idata_f5.posterior["a_vessel_type"].sel(measurement_type="center", vessel_type=vt)
            - idata_f5.posterior["a_vessel_type"].sel(measurement_type="center", vessel_type=comparison)
        )
        for vt in vts
        if vt != comparison
    },
}
f, ax = plt.subplots(figsize=[8, 5])
forestplot(ax, t)
save_figure(f, "supporting-f5-effects-vesseltype")


In [None]:

# Posterior of the diameter coefficient ``b_diameter`` for each pulsatility measurement.
# The labels are built from the measurement type, so the f-strings now have real
# placeholders (previously they had none).
t = {
    f"Effect of diameter on {measurement_type} pulsatility": (
        idata_f5.posterior["b_diameter"].sel(measurement_type=measurement_type)
    )
    for measurement_type in ("diameter", "center")
}
f, ax = plt.subplots(figsize=[8, 3])
forestplot(ax, t);
save_figure(f, "supporting-f5-effects-diameter")


## Figure 6

In [None]:
# Prepared pulsatility data (its diameter column) and the posterior of the "diameter" model.
prepared_data_f6 = load_prepared_data(os.path.join("..", "data", "prepared", "pulsatility.json"))
msts_f6 = prepared_data_f6.measurements
idata_f6 = az.InferenceData.from_zarr(os.path.join("..", "inferences", "diameter", "idata"))
idata_f6

In [None]:
# Row counts of msts_f6 per treatment (rows) x vessel type (columns);
# combinations with no rows show as NaN after unstack.
msts_f6.groupby(["treatment", "vessel_type"]).size().unstack()

In [None]:
# Adult - old difference: the overall term from ``mu`` alone, then one row per
# vessel type with the age x vessel-type interaction difference added on top.
t_overall = {
    "Difference in age effect: Adult - Old (Overall)": (
        idata_f6.posterior["mu"].sel(age="adult")
        - idata_f6.posterior["mu"].sel(age="old")
    )
}
t_vt = {
    f"Difference in age effect: Adult - Old ({format_name(vt)})": (
        t_overall["Difference in age effect: Adult - Old (Overall)"]
        + idata_f6.posterior["a_age_vessel_type"].sel(age="adult", vessel_type=vt)
        - idata_f6.posterior["a_age_vessel_type"].sel(age="old", vessel_type=vt)
    )
    for vt in idata_f6.posterior.coords["vessel_type"].values
}
t = {**t_overall, **t_vt}
f, ax = plt.subplots(figsize=[8, 5])
forestplot(ax, t)
save_figure(f, "supporting-f6-effects-age")

In [None]:
# Change from baseline for every non-baseline treatment, per vessel type: the
# main treatment difference plus the vessel-type x treatment interaction difference.
t = {
    f"Treatment effect: {format_name(treatment)} - Baseline ({format_name(vt)})": (
        idata_f6.posterior["a_treatment"].sel(treatment=treatment)
        - idata_f6.posterior["a_treatment"].sel(treatment="baseline")
        + idata_f6.posterior["a_vessel_type_treatment"].sel(treatment=treatment, vessel_type=vt)
        - idata_f6.posterior["a_vessel_type_treatment"].sel(treatment="baseline", vessel_type=vt)
    )
    for treatment in idata_f6.posterior.coords["treatment"].values
    if treatment != "baseline"
    for vt in idata_f6.posterior.coords["vessel_type"].values
}
f, ax = plt.subplots(figsize=[8, 9])
# NOTE(review): unlike the neighbouring forest plots, this figure is not passed to save_figure.
forestplot(ax, t);

In [None]:
# Posterior predictive check for the diameter model, grouped by vessel type (log scale).
yrep = idata_f6.posterior_predictive["yrep"]

f, ax = plt.subplots(1, 1, figsize=[16, 5])

plot_obs(ax, msts_f6["diameter"], cat=msts_f6["vessel_type"]);
plot_predictive(ax, yrep, zorder=-1, cat=msts_f6["vessel_type"], label="model")
ax.semilogy();
ax.set_xticks([])
ax.set(ylabel="Diameter ($\\mu$m), log scale")
# One legend call is enough: each ax.legend() replaces the previous legend.
# Entries come from this figure's own axes, and labels that format to the same
# text collapse into one entry.
handles, labels = ax.get_legend_handles_labels()
by_label = dict(zip(map(format_name, labels), handles))
ax.legend(by_label.values(), by_label.keys(), frameon=False, ncol=3);
save_figure(f, "supporting-f6-ppc")

## Figure 7

In [None]:
# Collateral models (one NetCDF file per outcome) and the prepared tables they are compared against.
# All paths use os.path.join; one was previously a hard-coded "/"-separated string.
idata_f7_density = az.from_netcdf(os.path.join("..", "inferences", "collaterals", "ctls_per_area.nc"))
idata_f7_diameter = az.from_netcdf(os.path.join("..", "inferences", "collaterals", "ln_diameter_mean.nc"))
idata_f7_curved_length = az.from_netcdf(os.path.join("..", "inferences", "collaterals", "ln_curved_length.nc"))
idata_f7_tortuosity = az.from_netcdf(os.path.join("..", "inferences", "collaterals", "ln_m1_tortuosity.nc"))
idata_f7_craniotomy_diameter = az.from_netcdf(os.path.join("..", "inferences", "collaterals", "ln_craniotomy_diameter.nc"))


msts_f7_bpts = pd.read_csv(os.path.join("..", "data", "prepared", "branchpoints.csv"))
msts_f7_ctls = pd.read_csv(os.path.join("..", "data", "prepared", "collaterals.csv"))
msts_f7_mice = pd.read_csv(os.path.join("..", "data", "prepared", "collaterals-mice.csv"))
msts_f7_mice.head()

In [None]:
# Posterior predictive checks for the mouse-level collateral outcomes, grouped by age.
# The ``ln_`` columns are on the log scale, so observations and replicates are both exponentiated.
for idata, ycol, ylabel in zip(
    [idata_f7_density, idata_f7_craniotomy_diameter],
    ["ln_collaterals_per_area", "ln_craniotomy_diameter"],
    ["Collaterals per $mm^2$", "Craniotomy diameter"],
):

    f, ax = plt.subplots(1, 1, figsize=[7, 4])

    yrep = np.exp(idata.posterior_predictive[ycol])
    plot_obs(ax, np.exp(msts_f7_mice[ycol]), cat=msts_f7_mice["age"])
    plot_predictive(ax, yrep, zorder=-1, cat=msts_f7_mice["age"], label="model")
    # Take legend entries from this figure's own axes rather than pyplot's
    # "current axes". Labels that format to the same text collapse into one entry.
    handles, labels = ax.get_legend_handles_labels()
    by_label = dict(zip(map(format_name, labels), handles))
    ax.semilogy()
    ax.set_ylabel(ylabel)
    ax.legend(by_label.values(), by_label.keys(), frameon=False, ncol=3)
    save_figure(f, f"supporting-f7-ppc-{ycol}")

In [None]:
# Posterior predictive checks for per-collateral diameter and length, grouped by age.
# The ``ln_`` columns are on the log scale, so observations and replicates are both exponentiated.
ycol_to_title = {
    "ln_diameter_mean": "Collateral Diameter",
    "ln_curved_length": "Collateral Length",
}

for (ycol, title), idata in zip(
    ycol_to_title.items(),
    [idata_f7_diameter, idata_f7_curved_length]
):
    f, ax = plt.subplots(figsize=[7, 4])
    y = np.exp(msts_f7_ctls[ycol])
    yrep = np.exp(idata.posterior_predictive[ycol])
    plot_obs(ax, y, cat=msts_f7_ctls["age"])
    plot_predictive(ax, yrep, zorder=-1, cat=msts_f7_ctls["age"], label="model")
    # Take legend entries from this figure's own axes rather than pyplot's
    # "current axes". Labels that format to the same text collapse into one entry.
    handles, labels = ax.get_legend_handles_labels()
    by_label = dict(zip(map(format_name, labels), handles))
    ax.semilogy()
    ax.set_xticks([])
    ax.set_ylabel(title.lower().capitalize())
    ax.legend(by_label.values(), by_label.keys(), frameon=False, ncol=3)
    # Use single quotes inside the replacement field: reusing the f-string's own
    # double quotes is a SyntaxError before Python 3.12 (PEP 701).
    save_figure(f, f"supporting-f7-ppc-{title.lower().replace(' ', '-')}")

In [None]:
# Posterior predictive check for collateral tortuosity, grouped by age.
ycol_to_title = {
    "ln_m1_tortuosity": "Collateral Tortuosity"
}

for (ycol, title), idata in zip(ycol_to_title.items(), [idata_f7_tortuosity]):
    f, ax = plt.subplots(figsize=[7, 4])
    # NOTE(review): assumes ln_m1_tortuosity = ln(tortuosity - 1) ("m1" = minus one),
    # whose inverse is exp(x) + 1. The previous exp(x + 1) equals e * (tortuosity - 1),
    # which is not a tortuosity. Confirm against the data-preparation step.
    y = np.exp(msts_f7_ctls[ycol]) + 1
    yrep = np.exp(idata.posterior_predictive[ycol]) + 1
    plot_obs(ax, y, cat=msts_f7_ctls["age"])
    plot_predictive(ax, yrep, zorder=-1, cat=msts_f7_ctls["age"], label="model")
    # Take legend entries from this figure's own axes rather than pyplot's
    # "current axes". Labels that format to the same text collapse into one entry.
    handles, labels = ax.get_legend_handles_labels()
    by_label = dict(zip(map(format_name, labels), handles))
    ax.semilogy()
    ax.set_xticks([])
    ax.set_ylabel(title.lower().capitalize())
    ax.legend(by_label.values(), by_label.keys(), frameon=False, ncol=3)
    # Use single quotes inside the replacement field: reusing the f-string's own
    # double quotes is a SyntaxError before Python 3.12 (PEP 701).
    save_figure(f, f"supporting-f7-ppc-{title.lower().replace(' ', '-')}")

In [None]:
# Branch-point models for the is_sphincter and is_bulb indicators.
idata_f7_is_sphincter, idata_f7_is_bulb = (
    az.from_netcdf("../inferences/branchpoints/is_sphincter.nc"),
    az.from_netcdf("../inferences/branchpoints/is_bulb.nc"),
)

In [None]:
# Posterior predictive check for the branch-point indicators. For each age group,
# plot the observed mean of is_sphincter / is_bulb against the 1%-99% range of
# the replicated group mean.
gcol = "age"
msts = msts_f7_bpts

for ycol, idata in [("is_sphincter", idata_f7_is_sphincter), ("is_bulb", idata_f7_is_bulb)]:

    qs = (
        idata.posterior_predictive[ycol]
        .to_dataframe()
        # one column per (chain, draw) replicate
        .unstack(["chain", "draw"])
        # NOTE(review): assumes replicate rows are in the same order as msts rows
        .set_index(msts.index)
        .groupby(msts[gcol])
        .mean()
        .quantile([0.01, 0.99], axis=1)
        .T
        .add_prefix("q")
        .join(msts.groupby(gcol)[ycol].mean().rename("obs"))
    )
    f, ax = plt.subplots()

    # One evenly spaced x slot per group over the default x-range, with an empty slot at each edge.
    xlimlow, xlimhigh = ax.get_xlim()
    xtickxs = np.linspace(xlimlow, xlimhigh, len(qs) + 2)
    xs = xtickxs[1:-1]
    ax.set_xlim(xtickxs[0], xtickxs[-1])
    for x, (label, obs) in zip(xs, qs["obs"].items()):
        ax.scatter(x, obs, label=f"Obs ({gcol}: {label})")
    ax.vlines(xs, qs["q0.01"], qs["q0.99"], color="gainsboro", zorder=0, label="model")
    ax.legend(frameon=False)
    ax.set_xticks([])
    ax.set(ylabel=f"Proportion {ycol[3:].capitalize()}")
    save_figure(f, f"supporting-f7-ppc-{ycol}")