# Question 2

This notebook investigates the pulsatility data.

## Load data

In [1]:
import os

import arviz as az
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from sphincter.data_preparation import load_prepared_data

In [2]:
PLOTS_DIR = os.path.join("..", "plots")
PREPARED_DATA_DIR = os.path.join("..", "data", "prepared")

# Full dataset (all treatments) and the reduced dataset without the hypertension treatments.
full_data = load_prepared_data(os.path.join(PREPARED_DATA_DIR, "pulsatility.json"))
no_hyper_data = load_prepared_data(os.path.join(PREPARED_DATA_DIR, "pulsatility-no-hyper.json"))

mts, mts_full = no_hyper_data.measurements, full_data.measurements

FileNotFoundError: [Errno 2] No such file or directory: '../data/prepared/pulsatility.json'

In [None]:
# Raw CSV from data/raw, loaded for direct inspection alongside the prepared data.
raw = pd.read_csv(filepath_or_buffer="../data/raw/data_sphincter_paper.csv")

## Plot measurements

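The next cell plots the full dataset: diameter (`pd1`) and center (`pc1`) pulsatility measurements, split by treatment, age, vessel type and mouse.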

In [None]:
def plot_obs_cat(ax, obs, catcol, cmap, extra_obs=None, **scatter_kwargs):
    """Scatter observations on `ax`, grouped and coloured by a categorical column.

    Rows are sorted by category and spread evenly along x in [0, 1], so each
    category occupies a contiguous horizontal band. Rows whose category is
    missing are not drawn (``groupby(observed=True)`` drops them).

    Parameters
    ----------
    ax : Axes to draw on.
    obs : Series of observed values (the y coordinates).
    catcol : Series of categories, index-aligned with `obs`.
    cmap : colormap with a ``colors`` attribute (e.g. a ListedColormap); its
        colours are cycled if there are more categories than colours.
    extra_obs : optional Series, index-aligned with `obs`, drawn as black
        crosses at the same x positions.
    **scatter_kwargs : forwarded to ``ax.scatter`` for the main observations.

    Returns
    -------
    list
        The scatter artists in drawing order (main, then extra if given, per category).
    """
    colors = list(cmap.colors)
    d_dict = {"obs": obs, "cat": catcol}
    if extra_obs is not None:
        d_dict["extra"] = extra_obs
    d = pd.DataFrame(d_dict).sort_values("cat")
    d = d.assign(x=np.linspace(0, 1, len(d)))
    scts = []
    for i, (cat, subdf) in enumerate(d.groupby("cat", observed=True)):
        color = colors[i % len(colors)]
        scts.append(ax.scatter(subdf["x"], subdf["obs"], label=cat, color=color, **scatter_kwargs))
        if extra_obs is not None:
            # Label only the first extra scatter; matplotlib ignores labels starting
            # with "_", so the legend shows one entry for extra_obs, not one per category.
            extra_label = extra_obs.name if i == 0 else "_nolegend_"
            scts.append(ax.scatter(subdf["x"], subdf["extra"], marker="x", label=extra_label, color="black"))
    return scts



# One 2x2 figure per first-harmonic measurement (diameter "pd1", center "pc1") for the
# full dataset, with the observations split by each categorical column in turn.
cmap = mpl.colormaps["Set2"]
for obs_col, label in [("pd1", "Diameter"), ("pc1", "Center")]:
    obs = mts_full[obs_col]
    f, axes = plt.subplots(2, 2, figsize=[14, 8], sharey=True)
    axes = axes.ravel()
    for (i, ax), col in zip(enumerate(axes), ["treatment", "age", "vessel_type", "mouse"]):
        catcol = mts_full[col]
        plot_obs_cat(ax, obs, catcol, cmap)
        if col != "mouse":
            # No legend for the mouse panel (presumably too many mice to be readable).
            ax.legend(frameon=False)
        if i % 2 == 0:
            # Only the left-hand column needs a y label; the y axis is shared.
            ax.set_ylabel(obs_col)
        ax.set_xticks([])
        ax.set_title(col.capitalize())
        ax.semilogy()
    f.suptitle(f"{label} pulsatility measurements (first harmonic)")
    f.tight_layout()

To go one step at a time, I'm first going to look at the data excluding the hypertension case.

Here's the same plot for this reduced dataset.

In [None]:
# Same plots as above for the reduced dataset (no hypertension treatments); each
# measurement's figure is also saved to PLOTS_DIR.
cmap = mpl.colormaps["Set2"]
for obs_col, label, filename in [
    ("pd1", "Diameter", "pulsatility-diameter-measurements.png"),
    ("pc1", "Center", "pulsatility-center-measurements.png"),
]:
    obs = mts[obs_col]
    f, axes = plt.subplots(2, 2, figsize=[14, 8], sharey=True)
    axes = axes.ravel()
    for (i, ax), col in zip(enumerate(axes), ["treatment", "age", "vessel_type", "mouse"]):
        catcol = mts[col]
        plot_obs_cat(ax, obs, catcol, cmap)
        if col != "mouse":
            # No legend for the mouse panel (presumably too many mice to be readable).
            ax.legend(frameon=False)
        if i % 2 == 0:
            # Only the left-hand column needs a y label; the y axis is shared.
            ax.set_ylabel(obs_col)
        ax.set_xticks([])
        ax.set_title(col.capitalize())
        ax.semilogy()
    f.suptitle(f"{label} pulsatility measurements (first harmonic)")
    f.tight_layout()
    f.savefig(os.path.join(PLOTS_DIR, filename), bbox_inches="tight")

There is a clear pattern: the diameter power harmonics get lower as the vessel order increases.

The treatment regime seems to make the diameter measurements closer together and somewhat higher.

The center power harmonics have an interesting cluster: one old mouse seems to have had consistently high measurements.

The next cell gets that mouse's id and displays all of its center power harmonic measurements.


In [None]:
# All measurements of the mouse whose maximum center harmonic (pc1) is the largest.
mts.loc[
    mts["mouse"].eq(mts.groupby("mouse", observed=True)["pc1"].max().idxmax()),
    ["age", "mouse", "vessel_type", "pd1", "pc1", "pressure_d"],
]

There is also a blood pressure measurement for each datapoint. The next plot shows that there isn't an obvious correlation between pressure and our measurements.

NOTE: try coarsening the data by binning.

In [None]:
f, axes = plt.subplots(nrows=1, ncols=2, figsize=[15, 5])

# x-axis column and grouping column for the pressure plot below.
col, groupcol = "pressure_d", "age"


def plots_with_bins_and_groups(axes, mts, col, groupcol):
    """Scatter pd1 and pc1 against `col`, one series per `groupcol` group, plus binned means.

    For each group, the raw points are drawn faintly. The group is then split
    into (at most) 5 quantile bins of `col`, and the per-bin means of `col` and of
    the measurement are overlaid as solid markers in the same colour.

    Parameters
    ----------
    axes : sequence of two Axes; the first gets "pd1", the second "pc1".
    mts : DataFrame containing `col`, `groupcol`, "pd1" and "pc1".
    col : name of the x-axis column; rows where it is missing are dropped.
    groupcol : name of the column to group (and colour) by.

    Returns
    -------
    The `axes` argument.
    """
    data = mts.dropna(subset=col)
    for ax, ycol in zip(axes, ["pd1", "pc1"]):
        for group, subdf in data.groupby(groupcol, observed=True):
            # duplicates="drop" avoids a ValueError when a small or heavily tied group
            # produces repeated quantile edges; such groups just get fewer bins.
            bins = pd.qcut(subdf[col], 5, duplicates="drop")
            coarse = subdf.groupby(bins, observed=True)[[col, ycol]].mean()
            sct = ax.scatter(subdf[col], subdf[ycol], label=group, alpha=0.2)
            ax.scatter(
                coarse[col],
                coarse[ycol],
                marker="o",
                # f-string so non-string group keys (e.g. integers) also work.
                label=f"{group} coarse",
                color=sct.get_facecolor(),
                alpha=1,
                s=50,
            )
        ax.legend(frameon=False)
        ax.semilogy()
        ax.set(title=ycol, xlabel=col, ylabel=ycol)
    return axes

# Pressure vs pd1/pc1, coloured by age, with binned means overlaid.
plots_with_bins_and_groups(axes, mts, col=col, groupcol=groupcol);
f.suptitle("Pressure and age vs pulsatility measurements");
f.savefig(os.path.join(PLOTS_DIR, "pressure-data.png"), bbox_inches="tight")

In [None]:
f, axes = plt.subplots(nrows=1, ncols=2, figsize=[15, 5])

# Vessel diameter vs pd1/pc1, coloured by vessel type; log x only on the pd1 panel.
plots_with_bins_and_groups(axes, mts, col="diameter", groupcol="vessel_type");
axes[0].semilogx();
f.savefig(os.path.join(PLOTS_DIR, "pulsatility-diameter-data.png"), bbox_inches="tight")

In [None]:
# Mean diameter per vessel type, one line per age group.
(
    mts.groupby(["age", "vessel_type"], observed=True)["diameter"]
    .mean()
    .unstack("age")
    .plot()
);

There are no measurements beyond cap1 for the hyper or hyper2 treatments:

In [None]:
# Measurement counts: treatments as rows, vessel types as columns.
mts_full.groupby(["treatment", "vessel_type"], observed=True).size().unstack(level="vessel_type")

## Speed

Where possible, the speed of red blood cells through the vessels was also measured.

The next few cells look at how many speed measurements are available, how speed is distributed depending on age and vessel type, and how measured speed is related to pulsatility.

In [None]:
# Per vessel type: measurements with a recorded speed, all measurements, and mean speed.
(
    mts.groupby("vessel_type", observed=True)["speed"]
    .agg(["count", "size", "mean"])
    .rename(
        columns={
            "count": "measurements with speed",
            "size": "total measurements",
            "mean": "mean speed",
        }
    )
)

In [None]:
f, axes = plt.subplots(2, 3, figsize=[15, 10], sharex=True, sharey=True)
axes = axes.ravel()
# Common bin edges (pandas min/max skip NaN) so every panel uses the same bins.
bins = np.linspace(mts["speed"].min(), mts["speed"].max(), 10)

for ax, (vt, vt_df) in zip(axes, mts.dropna(subset="speed").groupby(["vessel_type"], observed=True)):
    ax.grid(alpha=0.5)
    ax.set_axisbelow(True)
    for age, age_df in vt_df.groupby("age", observed=True):
        ax.hist(age_df["speed"], bins=bins, alpha=0.5, label=age, stacked=True)
    # Grouping by a one-element list yields 1-tuple keys (pandas >= 2), hence vt[0].
    ax.set(title=vt[0])
    ax.legend(frameon=False)
for ax in axes[[0, 3]]:
    ax.set_ylabel("Count")
for ax in axes[3:]:
    ax.set_xlabel("Speed")
f.suptitle("Distribution of measured red blood cell speeds for each vessel type", y=0.95);

In [None]:
f, axes = plt.subplots(nrows=1, ncols=2, figsize=[15, 5])
plots_with_bins_and_groups(axes, mts, col="speed", groupcol="vessel_type");
# Log-scale speed on both panels.
for ax in axes:
    ax.semilogx()

## Models

The first model I fit to this dataset predicts diameter and center power harmonics independently and in the same way, using the sum of four (also independent) parameters: an intercept $\mu$, an age effect $\alpha^{\text{age}}$, a treatment effect $\alpha^{\text{treatment}}$ and a vessel type effect $\alpha^{\text{vessel type}}$. The model creates a linear predictor out of these parameters for each measurement and fits it using an exponential GLM.

I called this model the "basic" model.

The next cell loads the results of fitting the basic model, together with the interaction, pressure and pressure-no-age models, and compares their estimated out-of-sample predictive performance. The main output metric (`elpd_loo`) is in principle absolute: it is the estimated total out-of-sample log likelihood under leave-one-out cross-validation. However, it doesn't mean much in isolation, because without another model to compare with it isn't clear in advance what would count as a good value. Still, the fact that the check runs without warnings is a good sign, indicating that there weren't many very influential observations.

In [None]:
# Fitted models keyed by name; each InferenceData is stored at ../inferences/<name>/idata.
idatas = {
    model_name: az.InferenceData.from_zarr(os.path.join("..", "inferences", model_name, "idata"))
    for model_name in [
        "pulsatility-basic",
        "pulsatility-interaction",
        "pulsatility-pressure",
        "pulsatility-pressure-no-age",
    ]
}
cmp = az.compare(idatas)
cmp

In [None]:
# Divergent transitions per model; non-zero counts would cast doubt on the fit.
for model_name, model_idata in idatas.items():
    n_divergent = model_idata.sample_stats.diverging.values.sum()
    print(f"Number of diverging transitions for {model_name} model: {n_divergent}")

In [None]:
# Pointwise PSIS-LOO for the pressure model, to inspect per-observation diagnostics.
loo = az.loo(data=idatas["pulsatility-pressure"], pointwise=True)
loo

In [None]:
f, ax = plt.subplots()
# Also show in-sample ELPD; standard-error bars are hidden.
az.plot_compare(cmp, ax=ax, plot_standard_error=False, insample_dev=True);
f.savefig(os.path.join(PLOTS_DIR, "pulsatility-elpd-comparison.png"), bbox_inches="tight")

The next cell plots the marginal distributions of the main shared parameters for each of the four models (basic, interaction, pressure and pressure-no-age).

In [None]:
# Forest plot of the parameters shared across models, one panel per measurement type.
# Named `shared_vars` rather than `vars` to avoid shadowing the builtin.
shared_vars = ["tau_vessel_type", "tau_treatment", "mu", "a_vessel_type", "a_treatment", "b_pressure", "b_diameter"]
# Display name -> key in `idatas`, so names and models cannot drift apart.
forest_models = {
    "basic": "pulsatility-basic",
    "interaction": "pulsatility-interaction",
    "pressure": "pulsatility-pressure",
    "pressure-no-age": "pulsatility-pressure-no-age",
}
f, axes = plt.subplots(1, 2, figsize=[15, 10], sharex=True)
for ax, mt in zip(axes, ["diameter", "center"]):
    az.plot_forest(
        [idatas[key] for key in forest_models.values()],
        model_names=list(forest_models),
        combined=True,
        ax=ax,
        var_names=shared_vars,
        coords={"measurement_type": mt},
    )
    ax.axvline(0, color="black")
    ax.set(title=mt.capitalize())
# Add the suptitle before tight_layout so the layout can leave room for it.
f.suptitle("Shared effects")
f.tight_layout()
f.savefig(os.path.join(PLOTS_DIR, "pulsatility-effects.png"), bbox_inches="tight")

In [None]:
# Effects that only the interaction model has, one panel per measurement type.
# Named `interaction_vars` rather than `vars` to avoid shadowing the builtin.
interaction_vars = ["a_age_treatment", "a_age_treatment_vessel_type"]
f, axes = plt.subplots(1, 2, figsize=[15, 10], sharex=True)
for ax, mt in zip(axes, ["diameter", "center"]):
    az.plot_forest(
        idatas["pulsatility-interaction"],
        combined=True,
        ax=ax,
        var_names=interaction_vars,
        coords={"measurement_type": mt},
    )
    ax.axvline(0, color="black")
    ax.set(title=mt.capitalize())
f.suptitle("Interaction model unique effect distributions")
f.tight_layout()
f.savefig(os.path.join(PLOTS_DIR, "pulsatility-interaction-effects.png"), bbox_inches="tight")

The next cell does a posterior predictive check, comparing measurements simulated using the model with the actually realised observations.

This shows an overall fairly good fit, though there are quite a few extreme center measurements that the model can't capture, and it seems to underfit for some mice as a result.

In [None]:
def plot_lines_cat(ax, yrep, catcol, cmap, **vlines_kwargs):
    colors = list(cmap.colors)
    d = pd.DataFrame(
        {
            "cat": catcol,
            "q1": yrep.quantile(0.01, dim=["chain", "draw"]).values,
            "q99": yrep.quantile(0.99, dim=["chain", "draw"]).values,
        }
    ).sort_values("cat").assign(x=np.linspace(0, 1, len(catcol)))
    linesets = []
    for i, (cat, subdf) in enumerate(d.groupby("cat", observed=True)):
        lines = ax.vlines(subdf["x"], subdf["q1"], subdf["q99"], color="gainsboro", **vlines_kwargs)
    return lines

def plot_ppc(axes, mts, obs, yrep, cmap, ylabel=None):
    """Plot observations against simulated 1%-99% intervals, split by four categorical columns.

    Parameters
    ----------
    axes : flat sequence of Axes, filled in the order treatment, age,
        vessel_type, mouse.
    mts : DataFrame with categorical "treatment", "age", "vessel_type" and
        "mouse" columns (``.cat`` accessor is used).
    obs : Series of observed measurements, index-aligned with `mts`.
    yrep : simulated measurements for the same observations (see `plot_lines_cat`).
    cmap : colormap passed on to `plot_obs_cat`.
    ylabel : y-axis label for the left-hand panels. Defaults to ``obs.name``,
        falling back to "pd1" if the Series is unnamed.

    Returns
    -------
    The `axes` argument, so ``axes = plot_ppc(axes, ...)`` keeps working.
    """
    if ylabel is None:
        ylabel = obs.name if obs.name is not None else "pd1"
    for (i, ax), col in zip(enumerate(axes), ["treatment", "age", "vessel_type", "mouse"]):
        # Drop unused categories so legend labels match the categories actually drawn.
        catcol = mts[col].cat.remove_unused_categories()
        scts = plot_obs_cat(ax, obs, catcol, cmap)
        lines = plot_lines_cat(ax, yrep, catcol, cmap, zorder=0, label="model")
        if col != "mouse":
            ax.legend(
                scts + [lines],
                list(catcol.cat.categories) + ["model"],
                frameon=False,
                bbox_to_anchor=[1, 0.5],
                loc="center left",
            )
        if i % 2 == 0:
            ax.set_ylabel(ylabel)
        ax.set_xticks([])
        ax.set_title(col.capitalize())
        ax.semilogy()
        ax.set_ylim(1e-2, 1e5)
    return axes

mt_to_obs = dict(diameter=mts["pd1"], center=mts["pc1"])
cmap = mpl.colormaps["Set2"]

# One 2x2 figure per (measurement type, prior/posterior) pair for the pressure model.
for mt, obs in mt_to_obs.items():
    for mode, group in (
        ("prior", idatas["pulsatility-pressure"].prior_predictive),
        ("posterior", idatas["pulsatility-pressure"].posterior_predictive),
    ):
        yrep = group.sel(measurement_type=mt)["yrep"]
        f, axes = plt.subplots(2, 2, figsize=[16, 8], sharey=True)
        axes = plot_ppc(axes.ravel(), mts, obs, yrep, cmap)
        f.suptitle(f"{mt.capitalize()} pulsatility measurements (first harmonic) vs {mode} simulations")
        f.tight_layout()
        f.savefig(os.path.join(PLOTS_DIR, f"pulsatility-{mode}-check-{mt}.png"), bbox_inches="tight")

In [None]:
# 1% and 99% quantiles of the pressure model's prior predictive diameter draws.
(
    idatas["pulsatility-pressure"]
    .prior_predictive["yrep"]
    .sel(measurement_type="diameter")
    .quantile(q=[0.01, 0.99])
)

In [None]:
f, ax = plt.subplots(figsize=[12, 5])

# Posterior predictive check for diameter: observed "y" against simulated "yrep".
az.plot_ppc(
    idatas["pulsatility-pressure"],
    group="posterior",
    var_names=["y"],
    data_pairs={"y": "yrep"},
    coords={"measurement_type": "diameter"},
    ax=ax,
);

In [None]:
f, ax = plt.subplots()
# Histogram of all log prior predictive draws, with a log-scaled count axis.
ax.hist(np.log(idatas["pulsatility-pressure"].prior_predictive["yrep"].values.ravel()), bins=100);
ax.semilogy()

The next cell plots the differences in age effects for each measurement type in the basic model. 

The plots show that, according to the basic model, diameter power harmonics tended to be higher for adult mice while center harmonics tended to be higher for old mice.

In [None]:
f, axes = plt.subplots(1, 2, figsize=[12, 5])
f.suptitle("Age effect differences")
for ax, mt in zip(axes, ["diameter", "center"]):
    mu_draws = idatas["pulsatility-pressure"].posterior["mu"].sel(measurement_type=mt)
    # Posterior draws of adult minus old, flattened over chains and draws.
    age_diff = mu_draws.sel(age="adult").values.flatten() - mu_draws.sel(age="old").values.flatten()
    ax.hist(age_diff, bins=30)
    ax.set(title=mt.capitalize(), xlabel="Difference (adult minus old, arbitrary units)")
f.savefig(os.path.join(PLOTS_DIR, "pulsatility-age-effects.png"), bbox_inches="tight")

In [None]:
f, axes = plt.subplots(1, 2, figsize=[12, 5])
f.suptitle("Pressure effect differences (adult minus old)")
for ax, mt in zip(axes, ["diameter", "center"]):
    b_pressure_draws = idatas["pulsatility-pressure"].posterior["b_pressure"].sel(measurement_type=mt)
    # Posterior draws of the adult minus old pressure coefficient.
    pressure_diff = (
        b_pressure_draws.sel(age="adult").values.flatten()
        - b_pressure_draws.sel(age="old").values.flatten()
    )
    ax.hist(pressure_diff, bins=30)
    ax.set(title=mt.capitalize(), xlabel="Difference (adult minus old, arbitrary units)")
f.savefig(os.path.join(PLOTS_DIR, "pulsatility-pressure-effects.png"), bbox_inches="tight")

In [None]:
f, axes = plt.subplots(1, 2, figsize=[12, 5])
f.suptitle("Treatment effect differences relative to baseline")
treatments = mts["treatment"].cat.remove_unused_categories().cat.categories
for ax, mt in zip(axes, ["diameter", "center"]):
    a_treatment = idatas["pulsatility-pressure"].posterior["a_treatment"].sel(measurement_type=mt)
    # Baseline draws do not depend on the treatment, so extract them once per measurement type.
    base = a_treatment.sel(treatment="baseline").values.flatten()
    for treatment in treatments:
        if treatment == "baseline":
            continue
        diff = a_treatment.sel(treatment=treatment).values.flatten() - base
        # Posterior probability that the treatment effect exceeds baseline.
        pr = (diff > 0).mean()
        ax.hist(diff, bins=30, alpha=0.6, label=treatment + f"\nprob +ve: {pr.round(2)}")
    ax.set(title=mt.capitalize(), xlabel="Difference (treatment minus baseline, arbitrary units)")
    ax.legend(frameon=False)
# NOTE(review): position is in data coordinates and presumably hand-tuned; it will
# drift if the histograms change.
axes[0].text(0.4, 200, "The model thinks that\nsphincter ablation increases\nthe first diameter harmonic", fontsize="small");
f.savefig(os.path.join(PLOTS_DIR, "pulsatility-treatment-effects.png"), bbox_inches="tight")

There is probably a relationship between the diameter and center pulsatility. This is because the tissue around the vessel is likely not uniform in stiffness: in this case the vessel does not expand and contract uniformly in all directions, so the center moves due to this. The movement of the center therefore has two components: the macro component due to the pressure waves that propagate through the brain and the micro component due to non-isotropic surrounding tissue.

Maybe there should be an effect of absolute diameter, as the treatments tend to change this. As the diameter increases, the vessel becomes less elastic, which you might expect to reduce the pulsatility.

In [None]:
# Basic model fitted to the full dataset, including the hypertension treatments.
idata_full = az.InferenceData.from_zarr(
    os.path.join("..", "inferences", "pulsatility-basic-full", "idata")
)
idata_full


In [None]:
# Basic-model parameters fitted without vs with the hypertension treatments.
# Named `basic_vars` rather than `vars` to avoid shadowing the builtin.
basic_vars = ["tau_vessel_type", "tau_treatment", "mu", "a_vessel_type", "a_treatment", "b_diameter"]
f, axes = plt.subplots(1, 2, figsize=[15, 10], sharex=True)
for ax, mt in zip(axes, ["diameter", "center"]):
    az.plot_forest(
        [idatas["pulsatility-basic"], idata_full],
        model_names=["no hypertension", "all treatments"],
        combined=True,
        ax=ax,
        var_names=basic_vars,
        coords={"measurement_type": mt},
    )
    ax.axvline(0, color="black")
    ax.set(title=mt.capitalize())
# Add the suptitle before tight_layout so the layout can leave room for it.
f.suptitle("Shared effects")
f.tight_layout()
f.savefig(os.path.join(PLOTS_DIR, "pulsatility-effects-basic-datasets.png"), bbox_inches="tight")

In [None]:
mt_to_obs = dict(diameter=mts_full["pd1"], center=mts_full["pc1"])
cmap = mpl.colormaps["Set2"]

# Prior and posterior predictive checks of the basic model fitted to the full dataset.
for mt, obs in mt_to_obs.items():
    for mode, group in (
        ("prior", idata_full.prior_predictive),
        ("posterior", idata_full.posterior_predictive),
    ):
        yrep = group.sel(measurement_type=mt)["yrep"]
        f, axes = plt.subplots(2, 2, figsize=[16, 8], sharey=True)
        axes = plot_ppc(axes.ravel(), mts_full, obs, yrep, cmap)
        f.suptitle(f"{mt.capitalize()} pulsatility measurements (first harmonic) vs {mode} simulations")
        f.tight_layout()
        f.savefig(os.path.join(PLOTS_DIR, f"pulsatility-{mode}-full-check-{mt}.png"), bbox_inches="tight")

In [None]:
f, axes = plt.subplots(1, 2, figsize=[12, 5])
f.suptitle("Treatment effect differences relative to baseline (basic model with full dataset)")
treatments_full = mts_full["treatment"].cat.remove_unused_categories().cat.categories
for ax, mt in zip(axes, ["diameter", "center"]):
    a_treatment = idata_full.posterior["a_treatment"].sel(measurement_type=mt)
    # Baseline draws do not depend on the treatment, so extract them once per measurement type.
    base = a_treatment.sel(treatment="baseline").values.flatten()
    for treatment in treatments_full:
        if treatment == "baseline":
            continue
        diff = a_treatment.sel(treatment=treatment).values.flatten() - base
        # Posterior probability that the treatment effect exceeds baseline.
        pr = (diff > 0).mean()
        ax.hist(diff, bins=30, alpha=0.6, label=treatment + f"\nprob +ve: {pr.round(2)}")
    ax.set(title=mt.capitalize(), xlabel="Difference (treatment minus baseline, arbitrary units)")
    ax.legend(frameon=False)
f.savefig(os.path.join(PLOTS_DIR, "pulsatility-treatment-effects-full.png"), bbox_inches="tight")