In [None]:
%config Completer.use_jedi = False
import numpy as np
import pickle
import pathlib
import shedding
from matplotlib import pyplot as plt
import matplotlib as mpl
from scipy import stats, optimize, integrate
import glob
from tqdm.notebook import tqdm
from pathlib import Path
import re
import collections
import itertools as it
import textwrap

# Apply the project's Matplotlib style sheet (presumably a file next to the notebook — confirm).
mpl.style.use('scrartcl.mplstyle')


def style_violins(violins, **kwargs):
    """
    Apply artist properties to every component of a violin plot.

    Parameters
    ----------
    violins : dict
        Mapping of component names to artists (or lists of artists), as returned by
        `Axes.violinplot`.
    **kwargs
        Properties forwarded to `plt.setp`; `zorder` defaults to 1.
    """
    kwargs.setdefault('zorder', 1)
    # A plain loop rather than a list comprehension: `plt.setp` is only called for its side
    # effects, so there is no list to build.
    for artists in violins.values():
        plt.setp(artists, **kwargs)
    
    
def evaluate_pcolormesh_edges(x, scale='linear'):
    """
    Evaluate the edges of cells for a `pcolormesh` visualisation.

    Parameters
    ----------
    x : array_like
        Cell centres along the first axis (at least two of them).
    scale : str
        `linear` places inner edges halfway between neighbouring centres. `log` places them at
        the geometric mean of neighbouring centres.

    Returns
    -------
    edges : np.ndarray
        Cell edges with one more element than `x` along the first axis. The outermost edges are
        obtained by reflecting the first and last midpoints about the first and last centres.

    Raises
    ------
    ValueError
        If `scale` is not recognised or fewer than two centres are given.
    """
    # Convert to an array so that sequences such as lists are not concatenated by
    # `x[1:] + x[:-1]` below.
    x = np.asarray(x, dtype=float)
    if scale == 'log':
        x = np.log(x)
    elif scale != 'linear':
        raise ValueError(scale)
    if x.ndim == 0 or x.shape[0] < 2:
        raise ValueError('at least two cell centres are required to infer the cell edges')

    # Find the (n - 1) midpoints
    midpoints = (x[1:] + x[:-1]) / 2
    # Find the endpoints by reflecting the outermost midpoints about the outermost centres
    left = 2 * x[0] - midpoints[0]
    right = 2 * x[-1] - midpoints[-1]
    # Construct the edges
    edges = np.concatenate([[left], midpoints, [right]])

    if scale == 'log':
        edges = np.exp(edges)
    return edges

def evaluate_mode(x, lin=200):
    """
    Evaluate the mode of a univariate distribution using kernel density estimation.

    Parameters
    ----------
    x : array_like
        Samples from the distribution.
    lin : int or array_like
        Either the number of equally spaced points between the smallest and largest sample at
        which to evaluate the density (a Python or NumPy integer), or the evaluation points.

    Returns
    -------
    mode : float
        Evaluation point with the largest estimated density.
    """
    kde = stats.gaussian_kde(x)
    # Accept NumPy integers (e.g. from array arithmetic) as well as Python integers as the
    # number of grid points; otherwise they would be treated as a single evaluation point.
    if isinstance(lin, (int, np.integer)):
        lin = np.linspace(np.min(x), np.max(x), lin)
    y = kde(lin)
    return lin[np.argmax(y)]


def alpha_cmap(color):
    """
    Build a colormap that fades from a fully transparent to a fully opaque `color`.

    An integer `color` is interpreted as an index into the property cycle, i.e. `3` becomes `'C3'`.
    """
    color = f'C{color}' if isinstance(color, int) else color
    transparent, opaque = (mpl.colors.to_rgba(color, alpha=opacity) for opacity in (0, 1))
    return mpl.colors.LinearSegmentedColormap.from_list("", [transparent, opaque])


def binary_cmap(color1, color2):
    """
    Build a colormap that interpolates linearly from `color1` to `color2`.
    """
    stops = [color1, color2]
    return mpl.colors.LinearSegmentedColormap.from_list("", stops)


def evaluate_hpd_levels_discrete(pxf, pvals):
    """
    Evaluate the levels for given highest-probability-density regions without interpolation.

    Works for multimodal distributions. Unlike `evaluate_hpd_levels`, each level is one of the
    values of `pxf` rather than being interpolated between them.

    Parameters
    ----------
    pxf : array_like
        Probability density (or mass) evaluated over a mesh; any shape, it is flattened.
    pvals : array_like or int
        Probability mass to be enclosed by each level, or the number of equidistant levels.

    Returns
    -------
    levels : np.ndarray
        For each entry of `pvals`, the first value (in descending order of `pxf`) at which the
        accumulated, normalised mass exceeds that entry.
    """
    # Obtain equidistant masses if only the number of levels is given
    if isinstance(pvals, (int, np.integer)):
        pvals = (pvals - np.arange(pvals)) / (pvals + 1)
    pvals = np.atleast_1d(pvals)
    # Work on a flat float copy so that integer inputs can be normalised in place below
    values = np.asarray(pxf, dtype=float).ravel()
    # Sort in descending order and accumulate the normalised mass
    values = values[np.argsort(-values)]
    cum = np.cumsum(values)
    cum /= cum[-1]
    # First index whose accumulated mass exceeds each requested fraction
    j = np.argmax(cum[:, None] > pvals, axis=0)
    return values[j]


def evaluate_hpd_levels(pdf, pvals):
    """
    Evaluate the levels that include a given fraction of the probability mass.

    Parameters
    ----------
    pdf : array_like
        Probability density function evaluated over a mesh (any shape; it is flattened).
    pvals : array_like or int
        Probability mass to be included within the corresponding level or the number of levels.
        An integer `n` yields the `n` equidistant masses `(n - k) / (n + 1)` for `k = 0, ..., n - 1`.

    Returns
    -------
    levels : array_like
        Contour levels of the probability density function that enclose the desired probability
        mass, one per entry of `pvals`. Masses below 0 or above 1 map to the largest or smallest
        density value, respectively. A constant density yields that constant for every mass.

    Notes
    -----
    The enclosed mass is accumulated with the trapezoidal rule over the distinct density values
    sorted in descending order (weighted by their multiplicity). The levels are then obtained by
    linear interpolation of that curve.
    """
    # Obtain equidistant levels if only the number is given
    if isinstance(pvals, (int, np.integer)):
        pvals = (pvals - np.arange(pvals)) / (pvals + 1)
    pvals = np.atleast_1d(pvals)
    # Sort the distinct density values in descending order. Aggregating identical values keeps
    # the values strictly monotonic so we can interpolate.
    values, weights = np.unique(-np.asarray(pdf), return_counts=True)
    values = -values
    # A constant density cannot distinguish between levels (and would otherwise give 0 / 0).
    if values.size == 1:
        return np.full(pvals.shape, values[0], dtype=float)
    # `cumtrapz` was renamed to `cumulative_trapezoid` in SciPy 1.6 and removed in SciPy 1.14.
    cumulative_trapezoid = getattr(integrate, 'cumulative_trapezoid', None)
    if cumulative_trapezoid is None:
        cumulative_trapezoid = integrate.cumtrapz
    cum = cumulative_trapezoid(values * weights, initial=0)
    cum /= cum[-1]
    # For non-negative densities `cum` increases strictly from 0 to 1. Linearly interpolating the
    # density as a function of the enclosed mass gives the levels; `np.interp` clamps masses
    # outside [0, 1] instead of wrapping around to the end of the array.
    return np.interp(pvals, cum, values)



    
# Plot styling for each model, keyed by (likelihood, shedding profile): (color, marker).
_style_by_model = {
    ('standard', 'constant'): ('C0', 'o'),
    ('inflated', 'constant'): ('C2', 's'),
    ('standard', 'temporal'): ('C1', 'v'),
    ('inflated', 'temporal'): ('C3', 'D'),
}


def _with_general_keys(mapping):
    """
    Return a copy of `mapping` that additionally holds every entry under the same key prefixed
    with 'general', so lookups work with and without the leading parametrisation element.
    """
    general = {('general',) + model: value for model, value in mapping.items()}
    return {**mapping, **general}


colors_by_key = _with_general_keys({model: color for model, (color, _) in _style_by_model.items()})
markers_by_key = _with_general_keys({model: marker for model, (_, marker) in _style_by_model.items()})

violin_widths = 0.7


def _format_log10_tick(exponent, _):
    """
    Format a tick located at `exponent` (log10 units) as a power of ten, without a trailing '.0'.
    """
    if not exponent % 1:
        exponent = int(exponent)
    return f'$10^{{{exponent}}}$'


log10formatter = mpl.ticker.FuncFormatter(_format_log10_tick)


# Symbols used in labels for the shape and scale parameters at each level of the hierarchy.
population_shape, population_scale = 'Q', 'S'
patient_shape, patient_scale = 'q', r'\sigma'

shape_label = dict(population=population_shape, patient=patient_shape)
scale_label = dict(population=population_scale, patient=patient_scale)

# Illustration for the model structure

In [None]:
class DistHandler:
    """
    Generate a distribution legend handle.

    The handle is a Gaussian-shaped curve drawn with the color and line style of the original
    handle, filled underneath with a translucent patch of the same color.

    Parameters
    ----------
    alpha : float
        Opacity of the fill underneath the curve. The default matches the opacity used for the
        distributions in the model illustration.
    """
    def __init__(self, alpha=0.25):
        # Stored on the instance instead of being read from a notebook-level global, which
        # would have to be defined before the legend is drawn.
        self.alpha = alpha

    def legend_artist(self, legend, orig_handle, fontsize, handlebox):
        """
        Draw the distribution handle into `handlebox` and return the curve artist.
        """
        color = orig_handle.get_color()
        x0, y0 = handlebox.xdescent, handlebox.ydescent
        width, height = handlebox.width, handlebox.height
        # Gaussian centred in the box with a standard deviation of a quarter of the box width.
        scale = width / 4
        lin = x0 + np.linspace(0, width)
        z = (lin - (x0 + width / 2)) / scale
        pdf = height * np.exp(- z ** 2 / 2)

        line = mpl.lines.Line2D(lin, pdf, color=color, ls=orig_handle.get_linestyle())
        handlebox.add_artist(line)

        # Close the curve along the baseline to fill the area underneath it.
        xy = [(x0, y0)]
        xy.extend(zip(lin, pdf))
        xy.append((x0 + width, y0))
        fill = mpl.patches.Polygon(xy, color=color, alpha=self.alpha)
        handlebox.add_artist(fill)
        # Matplotlib's legend handler protocol expects the created artist to be returned.
        return line

fig, ax = plt.subplots()
# Seed the global generator so the illustrative draws below are reproducible.
np.random.seed(8)

num_patients = 3
time = np.linspace(0, 25, 200)
# Illustrative shedding profile g(t) in arbitrary units: a rise followed by exponential decay.
profile_func = lambda x: 1e4 * x * np.exp(-x / 2)
profile = profile_func(time)
# Day from which the profiles are drawn in black; earlier days are shaded gray.
earliest = 6
# Opacity of the shaded distributions.
alpha = 0.25
# Limit of quantification used for the illustration.
loq = 1e-1
ax.axvspan(-1, earliest, facecolor='gray', alpha=.15)
mpl.rcParams['hatch.color'] = mpl.colors.to_rgba('C0', alpha)
mpl.rcParams['hatch.linewidth'] = 8

# Patient-level amplitudes: log-normal offsets with their log-mean removed.
offset_scale = 3
offsets = np.random.normal(0, offset_scale, num_patients)
offsets = np.exp(offsets - np.mean(offsets))
# The patient with the smallest amplitude is highlighted with a solid line and sample draws.
focus_patient = np.argmin(offsets)

# Shedding profile of each patient: gray before `earliest`, black afterwards.
for i in np.arange(num_patients):
    ls = '-' if i == focus_patient else ':'
    offset = offsets[i]
    fltr = time <= earliest
    ax.plot(time[fltr], profile[fltr] * offset, color='gray', ls=ls)
    fltr = time >= earliest
    ax.plot(time[fltr], profile[fltr] * offset, color='black', ls=ls)

# Mark each patient's amplitude at `earliest`.
ax.scatter(np.ones(num_patients) * earliest, offsets * profile_func(earliest), zorder=9, marker='s')

# Population-level distribution of amplitudes, drawn sideways to the left of `earliest`.
lin = np.linspace(-3 * offset_scale, 3 * offset_scale)
y = np.exp(lin) * profile_func(earliest)
left = earliest - 10 * offset_scale * stats.norm.pdf(lin, 0, offset_scale)
right = earliest * np.ones_like(lin)
fill = ax.fill_betweenx(y, left, right, alpha=alpha, hatch='/')
fill.set_facecolor((1, 1, 1, 0))
ax.plot(left, y, ls='--')

# One- and two-sigma bands (in log space) of the patient-level distribution around the focus
# patient's profile.
sample_scale = 1
fltr = time >= earliest
y = offsets[focus_patient] * profile_func(time[fltr])
ax.fill_between(time[fltr], y / np.exp(sample_scale), y * np.exp(sample_scale), color='C1', alpha=alpha)
ax.fill_between(time[fltr], y / np.exp(2 * sample_scale), y * np.exp(2 * sample_scale), color='C1', alpha=alpha)

# Patient-level distribution, drawn sideways to the right of `earliest`.
lin = np.linspace(-3 * sample_scale, 3 * sample_scale)
y = np.exp(lin) * profile_func(earliest) * offsets[focus_patient]
right = earliest + 10 * sample_scale * stats.norm.pdf(lin, 0, sample_scale)
left = earliest * np.ones_like(lin)
fill = ax.fill_betweenx(y, left, right, color='C1', alpha=alpha)
ax.plot(right, y)

# Simulated samples of the focus patient; samples at or below the limit of quantification are
# drawn as crosses.
num_samples = 5
sample_offsets = np.random.normal(0, sample_scale, num_samples)
sample_offsets = np.exp(sample_offsets - np.mean(sample_offsets))
sample_times = np.random.randint(earliest, time.max() + 1, num_samples)
sample_values = profile_func(sample_times) * sample_offsets * offsets[focus_patient]
fltr = sample_values > loq
ax.scatter(sample_times[fltr], sample_values[fltr], zorder=9, marker='o', color='C1')
ax.scatter(sample_times[~fltr], sample_values[~fltr], zorder=9, marker='X', color='C1')
ax.axhline(loq, ls='--', color='k')

ax.set_yscale('log')
ax.set_xlim(-1)
# Dedicated handles that are drawn as distributions by `DistHandler` in the legend.
pop_handle = mpl.lines.Line2D([], [], color='C0', ls='--')
patient_handle = mpl.lines.Line2D([], [], color='C1')
handles_labels = [
    [pop_handle, 'population-level distribution, describing\nvariation in shedding between patients'],
    [mpl.lines.Line2D([], [], color='C0', ls='none', marker='s'),
     r'patient location parameter $\mu$, describing the''\namplitude of individual shedding profiles'],

    [patient_handle, 'patient-level distribution, describing variation\nbetween samples from the same patient'],
    [(mpl.lines.Line2D([], [], color='C1', ls='none', marker='o'),
      mpl.lines.Line2D([], [], color='C1', ls='none', marker='X')), r'sample-level RNA load $y$ quantified by''\nRT-qPCR assays'],
    [mpl.lines.Line2D([], [], color='k'), r'shedding profile $g(t)$, describing the''\nevolution of RNA loads'],
    [mpl.lines.Line2D([], [], color='k', ls='--'), r'limit of quantification $\theta$'],
]

ax.legend(*zip(*handles_labels), loc='best', fontsize='small', handler_map={
    tuple: mpl.legend_handler.HandlerTuple(None),
    pop_handle: DistHandler(),
    patient_handle: DistHandler(),
}, edgecolor='none')
# The illustration is qualitative, so hide the tick labels.
plt.setp(ax.xaxis.get_ticklabels(), visible=False)
plt.setp(ax.yaxis.get_ticklabels(), visible=False)
ax.set_xlabel('Days past symptom onset $t$')
ax.set_ylabel(r'$\log$ SARS-CoV-2 RNA load')

# Arrows pointing away from `earliest` that annotate the early and late shedding regimes.
pos = 5e-3
kwargs = {
    'arrowstyle': 'simple,tail_width=45,head_width=55,head_length=20',
    'zorder': 9,
    'shrinkA': 0,
    'shrinkB': 0,
    'edgecolor': 'none',
}
patch = mpl.patches.FancyArrowPatch((earliest, pos), (earliest - 6.5, pos),
                                    facecolor='w', **kwargs)
ax.add_artist(patch)
patch = mpl.patches.FancyArrowPatch((earliest, pos), (earliest + 6.5, pos), facecolor='gray',
                                    alpha=.15, **kwargs)
ax.add_artist(patch)

kwargs = dict(va='center', fontsize='small', zorder=10)
ax.text(earliest - 0.5, pos, 'early shedding;\ncurrently uncon-\nstrained by data', ha='right', **kwargs)
ax.text(earliest + 0.5, pos, 'late shedding;\nconsistent with ex-\nponential profile', ha='left', **kwargs)
fig.tight_layout()
fig.savefig('figures/model.pdf')

# Bayesian evidence for the different models

In [None]:
# Load all the results for a given seed.
seed = 0
# NOTE(review): the directory pattern presumably encodes the model configuration followed by the
# seed — confirm against the workspace layout.
filenames = glob.glob(f'workspace/*-*-*-{seed}*/polychord/result.pkl')
print(f'found {len(filenames)} results\n')
# Results keyed by (parametrisation, 'standard'/'inflated', 'constant'/'temporal'), with an
# extra trailing element for evidence-only runs (see below).
results = {}
# Evidence estimates, stored as the (value, uncertainty) pair found in each result.
evidences = {}

for filename in sorted(filenames):
    with open(filename, 'rb') as fp:
        # `pickle.load` can execute arbitrary code: only load trusted result files.
        result = pickle.load(fp)
        model = result['model']
        key = (model.parametrisation.value, 
               'inflated' if model.inflated else 'standard', 
               'temporal' if model.temporal.value else 'constant')
        
        if result['args'].evidence:
            x, err = result['evidence']
            print(f'model: {key}; evidence: {x:.2f} +- {err:.2f}')
            evidences[key] = result['evidence']
        
        # Constant-profile runs that computed the evidence get a distinct key, presumably so
        # they do not overwrite the regular run of the same model.
        if result['args'].evidence and model.temporal == shedding.Profile.CONSTANT:
            key = key + ('evidence-calculation',)
        results[key] = result

# Evaluate summary statistics for the mean shedding rate

In [None]:
# Summarise the posterior of the population mean shedding rate (in log10 units) for each model.
for key, result in results.items():
    # Summary lines shown in the figure and printed below.
    lines = []
    
    samples = result['samples']
    # Population mean for each posterior draw; `shedding.gengamma_mean` presumably returns the
    # mean of a generalized gamma distribution with the given shape, location and scale.
    means = np.asarray([
        shedding.gengamma_mean(x['population_shape'], x['population_loc'], x['population_scale'])
        for x in shedding.transpose_samples(samples)
    ])
    
    # Posterior mode of the log10 mean.
    mode = evaluate_mode(np.ravel(np.log10(means)))
    lines.append(f'mode: {mode:.3f}')
    
    fig, ax = plt.subplots()
    # Evaluate the density on a grid extending half the sample range beyond either end.
    kde = stats.gaussian_kde(np.log10(means))
    xmin = kde.dataset.min()
    xmax = kde.dataset.max()
    xrng = xmax - xmin
    lin = np.linspace(xmin - 0.5 * xrng, xmax + 0.5 * xrng, 500)
    pdf = kde(lin)
    plt.plot(lin, pdf)
    # Density level enclosing 95% of the posterior mass.
    level = evaluate_hpd_levels(pdf, 0.95)
    plt.axhline(level, color='C1', ls='--')
    plt.scatter(mode, kde(mode), label='mode', marker='o')
    plt.axvline(mode, ls='--')
    plt.title(str(key))
    
    # Find where the density crosses the 95% HPD level by minimising the squared difference,
    # starting from the 2.5th and 97.5th percentiles. NOTE(review): for multimodal posteriors
    # this only finds the crossings nearest to those starting points.
    x0s = np.percentile(kde.dataset, [2.5, 97.5])
    lims = [optimize.minimize(lambda x: (kde(x) - level) ** 2, x0)
            for x0 in x0s]
    lims = np.asarray([lim.x[0] for lim in lims])
    
    plt.axvspan(*lims, color='C1', alpha=.25, label='95% HPD')

    lines.append(f'95% HPD: {lims[0]:.3f} -- {lims[1]:.3f}')
    plt.text(0.95, 0.95, '\n'.join(lines), transform=ax.transAxes, va='top', ha='right')
    plt.legend(loc='upper left')
    
    print('\n'.join([str(key)] + lines + ['']))

# Faecal shedding profile and half-life

In [None]:
# Load the datasets.
# `shedding.load_datasets` presumably reads each publication's data from `publications/`; the
# result is indexed by dataset name later on (its 'loq' and 'key' entries are used).
datasets = shedding.load_datasets([
    'Woelfel2020', 
    'Lui2020', 
    'Wang2020',
    'Han2020',
], 'publications/')

In [None]:
fig, (ax_days, ax) = plt.subplots(2, 1, sharex=True, gridspec_kw={'height_ratios': [1, 5]})

key = ('general', 'standard', 'temporal')
result = results[key]
samples = result['samples']
data = result['data']

# Shade the first five days, annotated below as unconstrained early shedding.
ax.axvspan(0, 5, facecolor='#eee', zorder=0)

# Plot the fit: posterior median and 95% interval of the exponential profile.
lin = np.linspace(0, 40)
ys = np.exp(samples['population_loc'][:, None] + samples['slope'][:, None] * lin)
ax.fill_between(lin, *np.percentile(ys, [2.5, 97.5], axis=0), color='black', zorder=0.1, alpha=.1)
ax.plot(lin, np.median(ys, axis=0), color='k', label=r'shedding profile', zorder=0)

# Observed samples; negative samples are drawn at a nominal load of 50.
x = data['day']
y = np.where(data['positive'], data['load'], 50)
colors_by_dataset = {}
# Iterate over the datasets in sorted order. The iteration order of a set of strings depends on
# the interpreter's hash seed, which would otherwise shuffle the colors between runs.
for i, key in enumerate(sorted(set(data['dataset']))):
    color = f'C{i}'
    colors_by_dataset[key] = color
    # The limit of quantification is exponentiated, i.e. presumably stored in log10 units.
    ax.axhline(10 ** datasets[key]['loq'], color=color, ls='--', zorder=0)
    # Positive samples as circles (labelled once per dataset), negative samples as crosses.
    for p, m in zip([True, False], 'ox'):
        label = datasets[key]['key'] if p else None
        fltr = (data['dataset'] == key) & (data['positive'] == p)
        ax.scatter(x[fltr], y[fltr], alpha=.75, marker=m, color=color, label=label, s=10)

# Set to True to connect the samples belonging to each patient.
plot_patient_trajectories = False
if plot_patient_trajectories:
    # Iterate over the patient indices that actually occur rather than assuming 0- or 1-based
    # numbering.
    for patient in np.unique(data['idx']):
        f = data['idx'] == patient
        key = data['dataset'][f][0]
        ax.plot(x[f], y[f], color=colors_by_dataset[key], alpha=.5)

# White arrow across the shaded region pointing towards symptom onset.
pos = 4e5
kwargs = {
    'arrowstyle': 'simple,tail_width=40,head_width=50,head_length=15',
    'zorder': 0,
    'shrinkA': 0,
    'shrinkB': 0,
    'edgecolor': 'none',
}
patch = mpl.patches.FancyArrowPatch((5, pos), (.25, pos),
                                    facecolor='w', **kwargs)
ax.add_artist(patch)

kwargs = dict(va='center', fontsize='x-small', zorder=10)
ax.text(4.5, pos, 'uncon-\nstrained\nearly\nshedding', ha='right', **kwargs)


ax.set_yscale('log')
legend = ax.legend(ncol=2, loc='upper left', fontsize='x-small')
ax.set_ylim(20, 1e9)
ax.set_xlim(0, 36)
ax.set_xlabel('Days past symptom onset $t$')
ax.set_ylabel('Gene copies per mL')

# Inset with the posterior of the profile half-life; the slope is per day, so this is in days.
child = ax.inset_axes((.65, .5, .33, .48))
halflife = - np.log(2) / samples['slope'].ravel()
shedding.plot_kde(halflife, ax=child, xmax=50/24, numlin=100, color='k')
child.yaxis.set_ticks([])
child.set_ylabel(r'Posterior $P(\tau_½)$')
child.set_xlabel(r'Profile half-life $\tau_½$ (days)')

# Plot the histogram of sampling days
days_per_bin = 1
ax_days.hist(data['day'], range=(-.5, 35.5), bins=36 // days_per_bin,
             weights=np.ones_like(data['day']) / days_per_bin, color='#aaa')
ax_days.set_ylabel('Samples\nper day')


child.text(0.05, 0.95, '(c)', transform=child.transAxes, ha='left', va='top')
ax_days.text(0.02, 0.9, '(a)', transform=ax_days.transAxes, ha='left', va='top')
ax.text(0.02, 0.03, '(b)', transform=ax.transAxes, ha='left', va='bottom')

fig.tight_layout()
fig.savefig('figures/decay.pdf')

# Estimates of the half-life

In [None]:
# Posterior mode and 95% HPD interval of the profile half-life for the temporal models.
for inflated in ['standard', 'inflated']:
    key = ('general', inflated, 'temporal')
    result = results[key]
    # Half-life in hours (the slope is per day, hence the factor of 24).
    halflife = - np.log(2) / result['samples']['slope'] * 24
    
    # Evaluate the density on a fine grid covering the first three days (in hours).
    lin = np.linspace(0, 3 * 24, 5001)
    pdf = stats.gaussian_kde(halflife)(lin)
    plt.figure()
    plt.plot(lin, pdf)
    
    level = evaluate_hpd_levels(pdf, 0.95)
    plt.axhline(level)
    # Grid indices where the density crosses the HPD level bound the 95% HPD region.
    delta = np.diff(level > pdf)
    idx, = np.nonzero(delta)
    plt.title(inflated)
    
    
    print('model', inflated)
    print('mode', lin[np.argmax(pdf)])
    print('95% hpd', lin[idx])
    print()

# Shedding prevalence estimates

In [None]:
# Posterior mode and 95% HPD interval of `rho` for the inflated models.
for temporal in ['temporal', 'constant']:
    plt.figure()
    plt.title(temporal)
    key = ('general', 'inflated', temporal)
    result = results[key]
    rho = result['samples']['rho']
    
    # Evaluate a "reflected" kernel density estimate to account for the boundary.
    # The density is evaluated on [0, 2] and the upper half is folded back about 1, i.e.
    # pdf(x) + pdf(2 - x) on [0, 1).
    lin = np.linspace(0, 2, 2000)
    pdf = stats.gaussian_kde(rho)(lin)
    a, b = np.array_split(pdf, 2)
    pdf = a + b[::-1]
    plt.hist(rho, bins=20, density=True, alpha=.5)
    # Only the first half of the grid corresponds to the folded density.
    plt.plot(lin[:len(pdf)], pdf)
    
    print('model', temporal)
    print('mode', lin[np.argmax(pdf)])
    
    level = evaluate_hpd_levels(pdf, 0.95)
    plt.axhline(level)
    # Grid indices where the density crosses the HPD level bound the 95% HPD region.
    delta = np.diff(level > pdf)
    idx, = np.nonzero(delta)
    print('95% HPD', lin[idx])
    print()

# Shedding profile comparison (gamma/Teunis)

In [None]:
# Comparison for different shedding profiles
# Days relative to symptom onset at which each profile is evaluated.
lin = np.linspace(-15, 30, 200)
# Seed the global generator (used when subsampling profiles for plotting in the next cell).
np.random.seed(0)

with open('workspace/profile-gamma-0/result.pkl', 'rb') as fp:
    result = pickle.load(fp)
    
# Gamma profile exp(loc - scale * dt) * dt ** shape with dt = t - offset. For non-integer shapes,
# negative dt gives nan, so the profile is undefined before its offset.
samples = samples_gamma = result['samples']
offsets_gamma = samples['profile_offset'][:, None]
dt = lin - offsets_gamma
profiles_gamma = np.exp(samples['population_loc'][:, None] - samples['profile_scale'][:, None] * dt) * \
    dt ** samples['profile_shape'][:, None]
    
with open('workspace/profile-teunis-0/result.pkl', 'rb') as fp:
    result = pickle.load(fp)
    
# Teunis et al. profile exp(loc - decay * dt) * (1 - exp(-rise * dt)); values below 0.1 (which
# include the negative values before the offset) are masked as nan.
samples = samples_teunis = result['samples']
offsets_teunis = samples['profile_offset'][:, None]
dt = lin - offsets_teunis
profiles_teunis = np.exp(samples['population_loc'][:, None] - samples['profile_decay'][:, None] * dt) * \
    (1 - np.exp(-samples['profile_rise'][:, None] * dt))
profiles_teunis = np.where(profiles_teunis < .1, np.nan, profiles_teunis)

# Exponential profile of the general, standard, temporal model for reference.
result = results[('general', 'standard', 'temporal')]
samples = result['samples']
profiles_exponential = np.exp(samples['population_loc'][:, None] + samples['slope'][:, None] * lin)

In [None]:
fig, axes = plt.subplots(2, 2, sharex=True, sharey='row', gridspec_kw={'height_ratios': [2, 5]})

# Top row: posterior of the peak shedding time; bottom row: profiles compared with the
# exponential profile.
top_axes, bottom_axes = axes

# Shade the early shedding period and label the panels (a)-(d).
for ax, key in zip(axes.ravel(), 'abcd'):
    ax.axvspan(lin.min(), 5, facecolor='#eee')
    ax.text(0.95, 0.95, f'({key})', ha='right', va='top', transform=ax.transAxes)

labels = ['gamma', r'Teunis $et\ al.$']
for ax, profiles, label in zip(bottom_axes, [profiles_gamma, profiles_teunis], labels):
    # Draw 500 posterior profiles chosen at random without replacement.
    idx = np.random.choice(len(profiles), 500, False)
    ax.plot(lin, profiles[idx].T, color='C0', alpha=.1)
    # Median and 95% interval of the exponential profile for reference.
    line, = ax.plot(lin, np.median(profiles_exponential, axis=0), color='C1')
    bounds = np.percentile(profiles_exponential, [2.5, 97.5], axis=0)
    ax.plot(lin, np.transpose(bounds), color=line.get_color(), ls='--')
    ax.fill_between(lin, *bounds, color=line.get_color(), alpha=.2, zorder=9)
    ax.set_xlabel('Days past symptom onset $t$')
    handles_labels = [
        (mpl.lines.Line2D([], [], color='C0'), f'{label} profile'),
        (mpl.lines.Line2D([], [], color='C1'), 'exponential profile'),
    ]
    ax.legend(*zip(*handles_labels), loc='lower left', fontsize=8)

# The bottom row shares its y-axis, so this applies to both profile panels.
ax.set_yscale('log')
ax.set_ylim(10, 1e11)
ax.set_xlim(lin.min(), lin.max())
axes[0, 0].set_ylabel(r'Posterior $P(t_\mathrm{peak})$')
axes[1, 0].set_ylabel('Gene copies per mL')

# Peak shedding times: the gamma profile peaks at offset + shape / scale and the Teunis et al.
# profile at offset + log((rise + decay) / decay) / rise.
mode_gamma = samples_gamma['profile_offset'] + samples_gamma['profile_shape'] / samples_gamma['profile_scale']
mode_teunis = samples_teunis['profile_offset'] + np.log((samples_teunis['profile_rise'] + samples_teunis['profile_decay']) /
                                                samples_teunis['profile_decay']) / samples_teunis['profile_rise']

for ax, offsets, label in zip(top_axes, [mode_gamma, mode_teunis], labels):
    ax.hist(offsets.ravel(), range=(-14, 7), bins=21, density=True,
            label='peak shedding\ntime 'r'$t_\mathrm{peak}$')
    ax.legend(loc='center right', fontsize=8)

axes[0, 0].set_title('Gamma profile')
axes[0, 1].set_title(r'Teunis $et\ al.$ profile')

# Set to True to annotate the early/late shedding regimes on the profile panels. This used to
# loop over `ax1, ax2`, which are not defined by this cell (or refer to another figure's axes
# when the notebook is re-run).
annotate_shedding_regimes = False
if annotate_shedding_regimes:
    for ax in bottom_axes:
        pos = 50
        length = 17
        kwargs = {
            'arrowstyle': 'simple,tail_width=30,head_width=35,head_length=20',
            'zorder': 9,
            'shrinkA': 0,
            'shrinkB': 0,
            'edgecolor': 'none',
        }
        patch = mpl.patches.FancyArrowPatch((5, pos), (5 - length, pos),
                                            facecolor='w', **kwargs)
        ax.add_artist(patch)
        patch = mpl.patches.FancyArrowPatch((5, pos), (5 + length, pos), facecolor='gray',
                                            alpha=.15, **kwargs)
        ax.add_artist(patch)

        kwargs = dict(va='center', fontsize='x-small', zorder=10)
        ax.text(5 - 0.5, pos, 'early shedding;\ncurrently uncon-\nstrained by data', ha='right', **kwargs)
        ax.text(5 + 0.5, pos, 'late shedding;\nconsistent with ex-\nponential profile', ha='left', **kwargs)

fig.tight_layout()
fig.savefig('figures/profiles.pdf')




# Replicates and prediction

In [None]:
replicates_by_key = {}
predictions_by_study_and_key = {}
# Number of posterior draws used for simulation; `None` (or 0) uses all of them.
num_samples = None

# Design of the held-out studies for which we predict data.
prediction_hyperparameters = {
    'Kim et al.': {
        'num_patients': 38,
        'num_samples': 129,
        'loq': 1,  # loq is irrelevant because we only care about the maximum
        'min_samples_per_patient': 1,
    },
    'Ng et al.': {
        'num_patients': 21,
        'num_samples': 81,
        'loq': 10 ** 2.5403294748,
        'min_samples_per_patient': 1,
    },
}

for key, result in tqdm(results.items()):
    # Skip the runs used only for evidence calculations; their keys carry an extra
    # 'evidence-calculation' element.
    if len(key) > 3:
        continue

    # Shallow copy of the posterior samples, split into one parameter mapping per draw and
    # potentially subsampled.
    samples = dict(result['samples'])
    ys = shedding.transpose_samples(samples)
    if num_samples:
        ys = [ys[i] for i in np.random.choice(len(ys), num_samples, False)]
    model = result['model']

    # Replicate the data we have fit to for posterior predictive checks
    data = result['data']
    replicates_by_key[key] = model.simulate(ys, data, 'existing_patients')

    # External validation is only carried out for the constant-profile models.
    if key[-1] == 'temporal':
        continue

    # Predict data for unobserved studies for external validation
    with tqdm(total=len(ys) * len(prediction_hyperparameters), desc=str(key)) as progress:
        for study, hyperparameters in prediction_hyperparameters.items():
            # These quantities only depend on the study design, so evaluate them once per study.
            num_study_patients = hyperparameters['num_patients']
            min_samples_per_patient = hyperparameters['min_samples_per_patient']
            samples_to_distribute = hyperparameters['num_samples'] - \
                min_samples_per_patient * num_study_patients
            if samples_to_distribute < 0:
                raise ValueError(
                    f"{study}: cannot distribute {hyperparameters['num_samples']} samples among "
                    f"{num_study_patients} patients with at least {min_samples_per_patient} "
                    "samples each"
                )
            # The remaining samples are assigned to patients uniformly at random.
            p = np.ones(num_study_patients) / num_study_patients
            for y in ys:
                # Sample the number of samples per patient (respecting the minimum per patient)
                num_samples_by_patient = np.ones(num_study_patients, int) * min_samples_per_patient + \
                    np.random.multinomial(samples_to_distribute, p)
                # Create a lookup index (1-based patient index of each sample)
                idx = np.repeat(1 + np.arange(num_study_patients), num_samples_by_patient)
                # Generate a prediction and add it to the list
                prediction_data = {
                    **hyperparameters,
                    'num_samples_by_patient': num_samples_by_patient,
                    'idx': idx,
                    'day': 0,
                }
                prediction_data = model.simulate(y, prediction_data, 'new_patients')
                predictions_by_study_and_key.setdefault(study, {}).setdefault(key, []).append(
                    prediction_data
                )
                progress.update()

## Predictions for held-out datasets

In [None]:
def plot_violin(data_sequence, target, x, color, ax,
                reference=None, pax=None, **kwargs):
    """
    Plot a violin of a statistic evaluated for each dataset in a sequence and mark its mode.

    Parameters
    ----------
    data_sequence : iterable
        Datasets (e.g. replicates or predictions) for which to evaluate the statistic.
    target : callable
        Function mapping a dataset to a scalar statistic.
    x : float
        Horizontal position of the violin.
    color :
        Color of the violin and of the mode marker.
    ax : matplotlib.axes.Axes
        Axes to draw the violin on.
    reference : float or dict, optional
        Reference value of the statistic, or a dataset from which to evaluate it with `target`.
    pax : matplotlib.axes.Axes, optional
        Axes on which to plot the smaller tail probability of `reference` among the values
        (ties counted half).
    **kwargs
        Passed to `Axes.violinplot`, except `label` and `marker` (default 'o'), which are used
        for the mode marker. `showextrema` defaults to False and `widths` to `violin_widths`.

    Returns
    -------
    violins : dict
        Artists returned by `Axes.violinplot`.
    """
    kwargs.setdefault('showextrema', False)
    kwargs.setdefault('widths', violin_widths)
    label = kwargs.pop('label', None)
    marker = kwargs.pop('marker', 'o')

    values = np.fromiter(map(target, data_sequence), float)
    violins = ax.violinplot(values[:, None], [x], **kwargs)
    style_violins(violins, color=color)
    # Reuse the violin's `points` setting (if any) for the grid on which the mode is evaluated.
    mode = evaluate_mode(values, kwargs.get('points', 200))
    ax.scatter(x, mode, marker=marker, color=color,
               zorder=9, label=label, edgecolor='w')

    if isinstance(reference, dict):
        reference = target(reference)
    # Compare against None explicitly: a reference of exactly zero is a valid statistic.
    if reference is not None and pax is not None:
        pval = min(
            np.mean(reference < values),
            np.mean(reference > values),
        )
        # Correct for equality by splitting ties between both tails
        pval += np.mean(reference == values) / 2
        pax.scatter(x, pval, marker=marker, color=color, edgecolor='w')

    # Restore the edges: move the transparency into the face color and draw the edges in the
    # same hue without transparency.
    for violin in violins['bodies']:
        facecolor = violin.get_facecolor()
        violin.set_alpha(None)
        violin.set_facecolor(facecolor)
        violin.set_edgecolor(facecolor[:, :3])

    return violins


# Extension of reference lines beyond the violins, as a fraction of `violin_widths`.
pad = 0.6

## Prediction

In [None]:
fig = plt.figure()
gs = fig.add_gridspec(1, 3)
# (a) maximum load for Kim et al., (b) maximum load for Ng et al., (c) median number of positive
# samples per patient for Ng et al.
ax1 = fig.add_subplot(gs[0, 0])
ax2 = fig.add_subplot(gs[0, 1], sharey=ax1)
ax3 = fig.add_subplot(gs[0, 2])

# Reference values of the log10 maximum load for the held-out studies (drawn as dashed lines).
references_by_study = {
    ('Ng et al.', 'max'): 7.1,
    ('Kim et al.', 'max'): 7.4363217001,
}
studies = ['Kim et al.', 'Ng et al.']
# NOTE(review): `labels` is not used in this cell.
labels = [
    r'$\bar{x}^\mathrm{rep}$''\nconstant',
    r'$\bar{x}^\mathrm{rep}$''\ntemporal',
    r'$\max\,x^\mathrm{pred}$''\nKim et al.',
    r'$\max\,x^\mathrm{pred}$''\nNg et al.',
]


# The standard model is drawn at x = 0 and the inflated (subpopulation) model at x = 1.
for i, inflated in enumerate(['standard', 'inflated']):
    # NOTE(review): this value is overwritten by `j = 0` before it is used.
    j = 2
        
    # Redefined in every iteration because `target` is rebound to a different statistic below.
    # Non-finite values (e.g. log10(0)) are mapped to 1.
    def target(data):
        value = np.log10(data['load'].max())
        return value if np.isfinite(value) else 1
    
    for ax, study in zip([ax1, ax2], studies):
        j = 0
        reference = references_by_study[(study, 'max')]
        # Draw the reference line once, spanning both violins plus padding.
        if i == 0:
            ax.plot((2 * j - violin_widths * pad, 2 * j + 1 + violin_widths * pad), 
                     (reference, reference), color='k', ls='--')
        
        key = ('general', inflated, 'constant')
        predictions_by_key = predictions_by_study_and_key[study]
        color = colors_by_key[key]
        plot_violin(predictions_by_key[key], target, i + 2 * j, color, ax, 
                    marker=markers_by_key[key])
        
        j += 1
        
    # Add the median number of positive samples for Ng et al.
    predictions_by_key = predictions_by_study_and_key['Ng et al.']
    target = lambda data: np.median(data['num_positives_by_patient'])
    if i == 0:
        reference = 2
        ax3.plot((- violin_widths * pad, 1 + violin_widths * pad), 
                 (reference, reference), color='k', ls='--')

    key = ('general', inflated, 'constant')
    color = colors_by_key[key]
    values = list(map(target, predictions_by_key[key]))
    # Medians of integer counts are multiples of 0.5, so evaluate the violin on a grid with a
    # spacing of 0.5 and use a tiny bandwidth to show the discrete distribution.
    points = int((max(values) - min(values)) / .5 + 1)
    plot_violin(predictions_by_key[key], target, i, color, ax3, 
                bw_method=1e-9, points=points, marker=markers_by_key[key],
                label='subpopu-\nlation' if inflated == 'inflated' else 'standard')
         
ax1.yaxis.set_major_formatter(log10formatter)
ax1.set_ylim(5, 11)
ax1.set_ylabel(r'Maximum $\max\,x^\mathrm{pred}$ (gene copies per mL)')
ax1.set_xlabel('Kim et al.')

ax2.set_ylabel(r'Maximum $\max\,x^\mathrm{pred}$ (gene copies per mL)')
ax2.set_xlabel('Ng et al.')

ax3.set_ylabel('Median number of positive samples\nper patient'r' $\mathrm{median}\,m_{\bullet(+)}^\mathrm{pred}$')
ax3.set_yticks(np.arange(5))
ax3.set_xlabel('Ng et al.')
ax3.legend(loc='lower center', handletextpad=.25)
ax3.set_ylim(-.85)

for ax, label in zip([ax1, ax2, ax3], '(a) (b) (c)'.split()):
    ax.text(0.05, 0.975, label, ha='left', va='top', transform=ax.transAxes)
    ax.xaxis.set_ticks([])

fig.tight_layout()
fig.savefig('figures/prediction.pdf')

## Replication

In [None]:
# Stride used to thin replicates and posterior samples (also used by later cells).
step = 1
# fig, axes = plt.subplots(2, 2, sharex='col', sharey=True)
fig = plt.figure()
# Left: 2x2 grid of joint distributions of replicated counts; right: replicated sample means.
gs = fig.add_gridspec(1, 2, width_ratios=[2, 1])
gsl = gs[0, 0].subgridspec(2, 2, wspace=.2)
ax1 = fig.add_subplot(gsl[0, 0])
ax2 = fig.add_subplot(gsl[0, 1], sharey=ax1)
ax3 = fig.add_subplot(gsl[1, 0], sharex=ax1, sharey=ax1)
ax4 = fig.add_subplot(gsl[1, 1], sharex=ax2, sharey=ax1)
axes = np.asarray([
    [ax1, ax2],
    [ax3, ax4],
])

i = 0

keys = it.product(['standard', 'inflated'], ['constant', 'temporal'])
for ax, key in zip(np.ravel(axes), keys):
    key = ('general',) + key
    replicates = replicates_by_key[key][::step]
    # Number of patients with at least one positive sample and number of positive samples.
    shedders = [np.sum(replicate['num_positives_by_patient'] > 0) for replicate in replicates]
    positive = [np.sum(replicate['positive']) for replicate in replicates]

    # Histogram ranges with half-integer edges so that each count gets its own unit-width bin.
    if key[2] == 'temporal':
        rng = [
            [14.5, 23.5],
            [82.5, 111.5],
        ]
    else:
        rng = [
            [20.5, 34.5],
            [82.5, 123.5],
        ]
    steps = [1, 1]
    bins = [int((b - a) / width) for (a, b), width in zip(rng, steps)]
    z, x, y = np.histogram2d(shedders, positive, bins=bins, range=rng, density=True)
    # Bin centres for the contours.
    x = (x[1:] + x[:-1]) / 2
    y = (y[1:] + y[:-1]) / 2
    # ax.contour(x, y, z.T, cmap=alpha_cmap(colors_by_key[key]))
    ax.imshow(z.T, aspect='auto', extent=np.ravel(rng), origin='lower', cmap=alpha_cmap(colors_by_key[key]))

    # sb.kdeplot(x=shedders, y=positive, cmap=alpha_cmap(colors_by_key[key]), ax=ax)

    # Plot the actual value
    data = results[key]['data']
    ax.scatter(np.sum(data['num_positives_by_patient'] > 0),
               np.sum(data['positive']),
               marker='x', color='k', zorder=9)

    # Plot the 95% and 1 - 1/e credible contours
    alphas = [.95, -np.expm1(-1)]
    levels = evaluate_hpd_levels(z, alphas)
    cs = ax.contour(x, y, z.T,
                    levels=levels, colors='k', linestyles=[':', '--'])
    labels = ax.clabel(cs, fmt={level: f'{100 * alpha:.0f}%' for level, alpha in zip(levels, alphas)})

    label = key[2] + (" sub-\npopulation" if key[1] == "inflated" else "\nstandard")
    handle = mpl.lines.Line2D([], [], color=colors_by_key[key], marker=markers_by_key[key], ls='none')
    handle.set_markeredgecolor('w')
    # Model marker in the lower right corner of each panel.
    ax.scatter(0.9, 0.1, color=colors_by_key[key], marker=markers_by_key[key], transform=ax.transAxes,
               edgecolor='w')
    # ax.legend([handle], [label], loc='upper left', frameon=False, handletextpad=.25)

    i += 1

# Applies to the temporal column (shared x-axis).
ax.set_xlim(15, 23)

# Invisible axes spanning the grid to carry the shared axis labels.
ax = fig.add_subplot(gs[0, 0], frameon=False)
plt.tick_params(labelcolor='none', top=False, bottom=False, left=False, right=False)
# Raw strings: '\m' is an invalid escape sequence in regular string literals.
ax.set_xlabel(r'Number of positive patients $n_{(+)}^\mathrm{rep}$')
ax.set_ylabel(r'Number of positive samples $m_{(+)}^\mathrm{rep}$')

labels_ = [
    '(a) constant\nstandard',
    '(b) temporal\nstandard',
    '(c) constant sub-\npopulation',
    '(d) temporal sub-\npopulation',
]
for ax, key in zip(np.ravel(axes), labels_):
    ax.text(0.05, 0.95, key, ha='left', va='top', transform=ax.transAxes, size='small')

# Hide tick labels that are duplicated by the shared axes.
for ax in axes[:, 1]:
    plt.setp(ax.yaxis.get_ticklabels(), visible=False)
for ax in axes[0]:
    plt.setp(ax.xaxis.get_ticklabels(), visible=False)

ax = fig.add_subplot(gs[0, 1])
ax.text(0.04, 0.975, '(e)', ha='left', va='top', transform=ax.transAxes, size='small')
ax.set_ylabel(r'Sample mean $\bar{x}^\mathrm{rep}$ (gene copies per mL)')

for i, inflated in enumerate(['standard', 'inflated']):
    # Plot constant and temporal means
    j = 0
    for temporal in ['constant', 'temporal']:
        # log10 of the mean load of the positive samples.
        target = lambda data: np.log10(data['load'][data['positive']].mean())

        key = ('general', inflated, temporal)
        color = colors_by_key[key]
        result = results[key]
        reference = target(result['data'])
        if i == 0:
            x = (
                2 * j - violin_widths * pad,
                2 * j + 1 + violin_widths * pad
            )
            ax.plot(x, (reference, reference), color='k', ls='--')

        plot_violin(replicates_by_key[key], target, i + 2 * j, color, ax, marker=markers_by_key[key],
                    label=f'{temporal}\n{"subpopulation" if inflated == "inflated" else inflated}')
        j += 1

ax.yaxis.set_major_formatter(log10formatter)
ax.set_ylim(5, 8.5)
ax.set_xlim(-.75, 3.75)
ax.set_xticks([])
ax.yaxis.set_major_locator(mpl.ticker.MaxNLocator(integer=True))
ax.set_xlabel('Models')

fig.tight_layout()
fig.savefig('figures/replication.pdf')

# Effect of shape and scale parameters

In [None]:
def _reflected_kde_grid(kde, xmax, liny, nx=200):
    """
    Evaluate a bivariate KDE on a grid symmetric in the first coordinate and fold the
    density about zero.

    Folding adds the density at `-x` to the density at `x`. This is presumably a
    reflection boundary correction because the first coordinate (a shape parameter) is
    non-negative -- confirm against the model definition.

    Parameters
    ----------
    kde : scipy.stats.gaussian_kde
        Bivariate density estimate to evaluate.
    xmax : float
        Half-width of the first-coordinate grid, which spans `[-xmax, xmax]`.
    liny : array_like
        Grid points for the second coordinate.
    nx : int
        Number of first-coordinate grid points before folding. It must be even so that
        the two halves mirror each other exactly.

    Returns
    -------
    linx : ndarray
        Positive half of the first-coordinate grid, with `nx // 2` points.
    zz : ndarray
        Folded density with shape `(len(liny), nx // 2)`.
    """
    if nx % 2:
        raise ValueError(f'nx must be even to fold the grid about zero, got {nx}')
    half = nx // 2
    linx = np.linspace(-xmax, xmax, nx)
    xx, yy = np.meshgrid(linx, liny)
    zz = kde.evaluate((xx.ravel(), yy.ravel())).reshape(xx.shape)
    # Mirror the negative half onto the positive half and add.
    zz = zz[:, :half][:, ::-1] + zz[:, half:]
    return linx[half:], zz


# Precompute the density estimates to iterate faster on the plots.
# NOTE(review): `step` (the posterior thinning stride) is only assigned in a later cell;
# it must already be defined when this cell runs.
zzs = {}
# `xmin` is unused here: the shape grid spans [-xmax, xmax] and is folded about zero.
xmin, xmax = 0, 3
nx = 200
for level in ['population', 'patient']:
    if level == 'population':
        ymin, ymax = 1.5, 6.5
    else:
        ymin, ymax = 1.5, 4.75

    for temporal in ['constant', 'temporal']:
        key = ('general', 'standard', temporal)
        result = results[key]
        samples = result['samples']
        shape_draws = samples[f'{level}_shape']
        scale_draws = samples[f'{level}_scale']

        kde = stats.gaussian_kde((shape_draws[::step].ravel(), scale_draws[::step].ravel()))
        liny = np.linspace(ymin, ymax, 101)
        linx, zz = _reflected_kde_grid(kde, xmax, liny, nx)
        zzs.setdefault(key, {})[level] = (linx, liny, zz)

        # Fraction of draws with shape > 1 (printed as "Weibull") and with shape > scale
        # (printed as "Gamma"). These use all draws, not the thinned subset.
        print(level)
        print(key, 'Weibull', np.mean(shape_draws > 1))
        print(key, 'Gamma', np.mean(shape_draws > scale_draws))


# Joint density of the population shape and the log-mean of the population distribution.
key = ('general', 'standard', 'constant')
result = results[key]
samples = result['samples']
# NOTE(review): these thin the flattened draws (`ravel()[::step]`), whereas the loop
# above thins the leading axis first (`[::step].ravel()`). The two agree only when
# step == 1.
q = samples['population_shape'].ravel()[::step]
mu = samples['population_loc'].ravel()[::step]
sigma = samples['population_scale'].ravel()[::step]
mean = shedding.gengamma_mean(q, mu, sigma)

# The mean grid spans several orders of magnitude, so estimate its density in log space.
ymin, ymax = 1.5e5, 1e12
kde = stats.gaussian_kde((q, np.log(mean)))
liny = np.linspace(np.log(ymin), np.log(ymax), 101)
linx, zz = _reflected_kde_grid(kde, xmax, liny, nx)
# Store the second coordinate back on a linear scale for plotting on log axes.
liny = np.exp(liny)
zzs['mean'] = (linx, liny, zz)

In [None]:
def evaluate_hpd_mass(pdf):
    """
    Evaluate the highest posterior density mass excluded from isocontours.

    Cells are ranked by decreasing density. For each cell, the result is one minus the
    normalised cumulative mass of all cells up to and including it in that ranking.
    Contouring the result at level `alpha` therefore traces the boundary of the
    `1 - alpha` highest-density region, assuming all cells have equal area. Ties are
    broken by the order `np.argsort` returns.

    Parameters
    ----------
    pdf : array_like
        Non-negative density values of any shape. Integer input is accepted.

    Returns
    -------
    mass : ndarray
        Float array with the same shape as `pdf` and values in `[0, 1)`; the least
        dense cell gets 0. If `pdf` sums to zero, the result is NaN (numpy emits a
        RuntimeWarning). An empty input gives an empty array.
    """
    # Cast to float so the normalisation below also works for integer input.
    pdf = np.asarray(pdf, dtype=float)
    shape = pdf.shape
    flat = pdf.ravel()
    if flat.size == 0:
        return np.zeros(shape)
    # Rank cells from most to least dense and accumulate their normalised mass.
    order = np.argsort(-flat)
    cum = np.cumsum(flat[order])
    cum = cum / cum[-1]
    # Scatter back to the original cell positions (inverse permutation in O(n)).
    mass = np.empty_like(cum)
    mass[order] = cum
    return 1 - mass.reshape(shape)

In [None]:
fig = plt.figure()
gs = mpl.gridspec.GridSpec(2, 2)
# Posterior thinning stride. NOTE(review): the density-precompute cell also reads `step`
# but runs before this one, so make sure it is defined by then.
step = 1

# Right column: joint posterior of shape and scale at the population (top) and patient
# (bottom) level.
ax1 = fig.add_subplot(gs[0, 1])
ax2 = fig.add_subplot(gs[1, 1], sharex=ax1)
axes = [ax1, ax2]

# 95% contour sets in plotting order: population (constant, temporal), patient
# (constant, temporal), then the mean panel. This order must match `clabel_pos` below.
contours95 = []

# Loop-invariant settings for the shape/scale panels.
gamma = np.logspace(-2, 1)  # reference line where shape equals scale
xmin, xmax = 0, 2.1
nlevels = 9
# Contour levels on the excluded HPD mass: 0.1, 0.2, ..., 0.9.
levels = (1 + np.arange(nlevels)) / (nlevels + 1)

for ax, level in zip(axes, ['population', 'patient']):
    ax.plot(gamma, gamma, color='k', ls=':', zorder=0)

    if level == 'population':
        ymin, ymax = 1.5, 6.5
    else:
        ymin, ymax = 1.5, 4.75

    for temporal in ['constant', 'temporal']:
        key = ('general', 'standard', temporal)
        color = colors_by_key[key]
        # Density grid precomputed (and already folded about zero) in an earlier cell.
        linx, liny, zz = zzs[key][level]
        mass = evaluate_hpd_mass(zz)

        ax.contour(linx, liny, mass, levels=levels, cmap=alpha_cmap(color))
        contours95.append(
            ax.contour(linx, liny, mass, levels=[.05], colors=color, linestyles=':'))

    ax.set_xlim(xmin, xmax)
    ax.set_ylim(ymin, ymax)
    ax.yaxis.tick_right()
    ax.yaxis.set_label_position('right')
    ax.set_ylabel(f'{level.title()} scale ${scale_label[level]}$')
    ax.set_xlabel(f'{level.title()} shape ${shape_label[level]}$')


# Left column: mean of the population distribution against the population shape.
ax = fig.add_subplot(gs[:, 0], sharex=ax1)
key = ('general', 'standard', 'constant')
result = results[key]
data = result['data']

ymin, ymax = 1.5e5, 1e9
# Precomputed (shape, mean) density; `liny` holds linear values on a log-spaced grid.
linx, liny, zz = zzs['mean']

# Set to False to show the conditional density as a heatmap instead of HPD contours.
use_contours = True
if use_contours:
    # HPD thresholds are evaluated separately for each column of `zz`, i.e. conditional
    # on the shape value.
    cmap = alpha_cmap(colors_by_key[key])
    pvals = (np.arange(9) + 1) / 10
    for pval in pvals:
        levels = np.squeeze([evaluate_hpd_levels(z, pval) for z in zz.T])
        x = zz / levels
        ax.contour(linx, liny, x, levels=[1], colors=[cmap(1 - pval)])

    levels = np.squeeze([evaluate_hpd_levels(z, .95) for z in zz.T])
    x = zz / levels
    contours95.append(ax.contour(linx, liny, x, levels=[1], colors=colors_by_key[key], linestyles=':'))
else:
    conditional = zz / np.sum(zz, axis=0, keepdims=True)
    ax.pcolormesh(evaluate_pcolormesh_edges(linx), evaluate_pcolormesh_edges(liny, 'log'), conditional)

    ax.contour(linx, liny, np.transpose([evaluate_hpd_mass(z) for z in zz.T]), levels=[.05, np.exp(-1)],
               colors='w', linestyles=[':', '--'])

ax.set_ylim(ymin, ymax)
ax.set_yscale('log')

ax.set_ylabel(r'Mean RNA copies per mL')
ax.set_xlabel(f'Population shape ${shape_label["population"]}$')
ax.axhline(data['load'][data['positive']].mean(), color='k', ls='--')

# Legend keys use the same 3-tuple form as the contour colours above.
ax.legend([
    mpl.lines.Line2D([], [], color=colors_by_key[('general', 'standard', 'constant')]),
    mpl.lines.Line2D([], [], color=colors_by_key[('general', 'standard', 'temporal')]),
    mpl.lines.Line2D([], [], color='k', ls='-.'),
    mpl.lines.Line2D([], [], color='k', ls=':'),
    mpl.lines.Line2D([], [], color='k', ls='--'),
], [
    'constant\nstandard',
    'temporal\nstandard',
    'Weibull',
    'gamma',
    'sample\nmean',
], ncol=1, loc='upper right')


# Manual positions for the "95%" labels, one per entry of `contours95` (same order).
clabel_pos = [
    (1.75, 4),
    (0.75, 4.5),
    (1.25, 3.5),
    (0.75, 1.75),
    (1.5, 2e7),
]

for cs, manual in zip(contours95, clabel_pos):
    cs.axes.clabel(cs, fmt={cs.levels[0]: '95%'}, manual=[manual])

# Panel labels and the shape == 1 reference line (labelled "Weibull" in the legend).
y = 0.95
for ax, label in zip([ax, ax1, ax2], ['(a)', '(b) population', '(c) patient']):
    ax.text(0.05, y, label, transform=ax.transAxes, ha='left', va='top', size='small')
    ax.axvline(1, color='k', ls='-.', zorder=0)

fig.tight_layout()
fig.savefig('figures/shape-scale.pdf')