In [None]:
import argparse
import functools as ft
import itertools as it
from localscope import localscope
import logging
logging.basicConfig(level='WARNING')
from matplotlib import pyplot as plt
import matplotlib as mpl
import numpy as np
import os
import pypolychord as pc
import pypolychord.output
import pypolychord.settings
import shedding
from scipy import stats, special
from tqdm.notebook import tqdm
import sys
import pickle

mpl.rcParams['figure.dpi'] = 144

# Command-line arguments are taken from the ARGS environment variable (so the notebook can be
# executed non-interactively); the default string below is used for interactive sessions.
argv = os.environ.get('ARGS', "-i -s 0 -l 1 -r 1 -f --day-noise 1 general workspace/test").split()
print(argv)
parser = argparse.ArgumentParser()
# Positional
parser.add_argument('parametrisation', type=shedding.Parametrisation, 
                    help='parametrisation to use')
parser.add_argument('basedir', help='output directory for the model, data, and PolyChord files')

# Keyword
parser.add_argument('--inflated', '-i', action='store_true', 
                    help='whether to use a zero-inflated model')
parser.add_argument('--temporal', '-t', choices=['gamma', 'exponential', 'teunis'],
                    help='temporal profile to use; omit for a time-independent model', default=False)
parser.add_argument('--force', '-f', action='store_true', 
                    help='force regeneration of samples')
parser.add_argument('--seed', '-s', type=int, help='random number generator seed')
parser.add_argument('--nlive-factor', '-l', type=float, default=25, 
                    help='multiplicative factor for number of live points')
parser.add_argument('--nrepeat-factor', '-r', type=float, default=5, 
                    help='multiplicative factor for number of Monte Carlo steps')
parser.add_argument('--evidence', '-e', action='store_true', help='focus on evaluating the evidence, '
                    'make sure we use the same data for temporal and constant models')
parser.add_argument('--day-noise', help='amount of noise to add to days past symptom onset',
                    type=int, default=0)
args = parser.parse_args(argv)
# The per-patient day noise is drawn with np.random.randint(-day_noise, day_noise + 1), which
# requires a non-empty range; reject negative values up front with a clear message.
if args.day_noise < 0:
    parser.error('--day-noise must be non-negative')
os.makedirs(args.basedir, exist_ok=True)
args

In [None]:
np.random.seed(args.seed)

# Load and flatten some of the datasets
datasets = [
    'Woelfel2020', 
    'Lui2020', 
    'Han2020',
]
# Include Wang for non-temporal models if we just want samples, not evidence
# (evidence runs must use the same data for temporal and constant models).
if not args.evidence and not args.temporal:
    datasets.append('Wang2020')
    
datasets = shedding.load_datasets(datasets, 'publications/')

# Add random bits of noise to the days to test whether there's sensitivity to misreporting
# the number of days since symptom onset. Each patient gets one offset drawn uniformly from
# {-day_noise, ..., day_noise}, which is applied to every sample of that patient that has a day.
for dataset in datasets.values():
    noise_by_patient = {}
    for load in dataset['loads']:
        patient = load['patient']
        # Draw lazily: `dict.setdefault(key, default)` evaluates `default` on every call, which
        # consumed (and discarded) a random number per load instead of one per patient.
        if patient not in noise_by_patient:
            noise_by_patient[patient] = np.random.randint(-args.day_noise, args.day_noise + 1)
        if 'day' in load:
            load['day'] += noise_by_patient[patient]


data = shedding.flatten_datasets(datasets, loq_fill_value=-99)
print(f'Number of patients: {data["num_patients"]}')
print(f'Number of patients with one or more positive samples: '
      f'{(data["num_positives_by_patient"] > 0).sum()}')
print(f'Number of samples: {data["num_samples"]}')
print(f'Number of positive samples: {data["positive"].sum()}')

model = shedding.Model(data['num_patients'], parametrisation=args.parametrisation, inflated=args.inflated,
                       temporal=args.temporal)
print(f'Number of parameters: {model.size}')

# Persist the model and flattened data next to the chains for later analysis.
with open(os.path.join(args.basedir, 'model.pkl'), 'wb') as fp:
    pickle.dump(model, fp)
    
with open(os.path.join(args.basedir, 'data.pkl'), 'wb') as fp:
    pickle.dump(data, fp)

In [None]:
# PolyChord configuration: `model.size` sampled dimensions and zero derived parameters.
settings = pc.settings.PolyChordSettings(model.size, 0)
num_repeats = int(args.nrepeat_factor * model.size)
# Options are applied in this order; the live-point and repeat counts scale with model size.
for option, value in {
    'base_dir': args.basedir,
    'file_root': 'chain',
    'read_resume': False,
    'feedback': 3,
    'num_repeats': num_repeats,
    'boost_posterior': min(num_repeats, 10),
    'nlive': int(args.nlive_factor * model.size),
    'seed': args.seed if args.seed is not None else -1,
    'write_resume': False,
}.items():
    setattr(settings, option, value)

# Write the parameter-names file alongside the chain files.
filename = os.path.join(settings.base_dir, f'{settings.file_root}.paramnames')
shedding.write_paramnames_file(model.parameters, filename)

vars(settings)

In [None]:
# Sanity check before sampling: map a random point of the unit hypercube through the prior
# transform (the same function handed to PolyChord) and evaluate the likelihood there.
x = np.random.uniform(0.0, 1.0, model.size)
y = model.sample_params_from_vector(x)
model.evaluate_log_likelihood_from_vector(y, data=data)

In [None]:
# Convert the random parameter vector `y` from the validation cell into per-parameter values for inspection.
shedding.vector_to_values(model.parameters, y)

In [None]:
# Run the nested sampler unless a previous chain exists and --force was not given.
filename = os.path.join(settings.base_dir, settings.file_root + '.txt')
if os.path.isfile(filename) and not args.force:
    print(f'{filename} already exists; remove it to regenerate samples')
    # Reload the summary of the previous run so that later cells, which read `output.logZ`,
    # `output.logZs`, etc., also work on a fresh kernel when sampling is skipped.
    output = pc.output.PolyChordOutput(settings.base_dir, settings.file_root)
else:
    log_likelihood = ft.partial(model.evaluate_log_likelihood_from_vector, data=data)
    # The final argument is a callback PolyChord invokes during the run; use it to tick the
    # progress bar (its arguments are ignored).
    with tqdm() as progress: 
        output = pc.run_polychord(log_likelihood, model.size, 0, settings, model.sample_params_from_vector, 
                                  lambda *_: progress.update())

In [None]:
# Load the equally weighted posterior samples. The first two columns are dropped (presumably
# PolyChord's weight and -2 log-likelihood columns — confirm); the remainder are parameters.
# `ndmin=2` keeps a single-row file two-dimensional so the column slice cannot fail.
samples = np.loadtxt(os.path.join(settings.base_dir, settings.file_root + '_equal_weights.txt'),
                     ndmin=2)[:, 2:]
num_samples = len(samples)
print(f'Obtained {num_samples} samples.')
samples = shedding.transpose_samples(samples, model.parameters)


# Save the results in a format that makes our life easier
with open(os.path.join(args.basedir, 'result.pkl'), 'wb') as fp:
    pickle.dump({
        'samples': samples,
        'model': model,
        'args': args,
        'evidence': (output.logZ, output.logZerr),
        'local_evidences': (output.logZs, output.logZerrs),
        'data': data,
    }, fp)

In [None]:
# Pair plot (upper triangle) of the parameters whose names do not end in an underscore and
# whose sample arrays have fewer than two dimensions.
model_pars = [par for par, value in samples.items() if not par.endswith('_') and
              np.ndim(value) < 2]
if len(model_pars) < 2:
    print(f'Skipping pair plot: need at least two parameters, got {model_pars}')
else:
    # Create a figure with the right number of columns. squeeze=False keeps `axes` 2-D even
    # when there is only a single panel (exactly two parameters).
    nrows = ncols = len(model_pars) - 1
    fig, axes = plt.subplots(nrows, ncols, sharex='col', sharey='row',
                             figsize=(6, 6), squeeze=False)

    # Iterate over the pairs, thinning so that each panel shows at most ~1000 points.
    step = max(1, num_samples // 500)
    for i in range(len(model_pars) - 1):
        a = model_pars[i]
        for j in range(1, len(model_pars)):
            ax = axes[i, j - 1]
            # Only the upper triangle is populated.
            if j <= i:
                ax.set_axis_off()
                continue

            b = model_pars[j]
            x = samples[a][::step]
            y = samples[b][::step]

            # Column parameter `b` on the horizontal axis, row parameter `a` on the vertical.
            ax.scatter(y, x, marker='.', alpha=.1)
            if i == 0:
                # Label the columns along the top row.
                ax.set_xlabel(b, size='small')
                ax.xaxis.set_label_position('top')
                ax.xaxis.tick_top()
                ax.xaxis.set_tick_params(which='both', labeltop=True)

            if j == ncols:
                # Label the rows along the right-most column.
                ax.set_ylabel(a, size='small')
                ax.yaxis.set_label_position('right')
                ax.yaxis.tick_right()
                ax.yaxis.set_tick_params(which='both', labelright=True)

            # Draw prior bounds if they exist: vertical lines for `b`, horizontal for `a`;
            # bounds that are None are skipped.
            for k, line in zip([b, a], [ax.axvline, ax.axhline]):
                prior = model.priors.get(k)
                if not prior:
                    continue
                bounds = prior.bounds
                for bound in bounds:
                    if bound is not None:
                        line(bound, color='k', ls=':')

    fig.tight_layout()

In [None]:
# Scatter the shape and scale posteriors of the population- and patient-level distributions.
use_log_axes = False  # set to True to show both axes on a logarithmic scale

fig, (ax1, ax2) = plt.subplots(1, 2)
for ax, key in zip((ax1, ax2), ('population', 'patient')):
    a = f'{key}_shape'
    b = f'{key}_scale'
    x = samples[a]
    y = samples[b]
    ax.scatter(x, y, marker='.', alpha=.5)
    ax.set_xlabel(r'Shape $q$')

    # Dashed reference line where shape equals scale, over the range covered by both.
    xmin = max(x.min(), y.min())
    xmax = min(x.max(), y.max())
    lin = np.linspace(xmin, xmax)
    ax.plot(lin, lin, color='k', ls='--')

    # Dotted vertical guide at q = 1.
    ax.axvline(1, color='k', ls=':')
    if use_log_axes:
        ax.set_xscale('log')
        ax.set_yscale('log')

ax1.set_ylabel(r'Scale $\sigma$')
ax1.set_title('Population')
ax2.set_title('Patient')

fig.tight_layout()

In [None]:
# Posterior of the log10 population-mean load, computed per sample via `shedding.gengamma_mean`;
# for zero-inflated models the mean is additionally multiplied by rho (added in log space).
x = np.log10([
    shedding.gengamma_mean(q, mu, sigma)
    for q, mu, sigma in zip(samples['population_shape'],
                            samples['population_loc'],
                            samples['population_scale'])
])
if model.inflated:
    x += np.log10(samples['rho'])

fig, ax = plt.subplots()
ax.hist(x, density=True, alpha=.5)
shedding.plot_kde(x, color='C0', ax=ax, label=r'population mean $\langle y\rangle$')
# Reference lines: the largest observed load and the mean load of the positive samples.
max_load = data['load'].max()
positive_mean = data['load'][data['positive']].mean()
ax.axvline(np.log10(max_load), color='k', ls=':', label=r'$\max x$')
ax.axvline(np.log10(positive_mean), color='k', ls='--',
           label=r'sample mean $\bar x$')
ax.set_xlabel(r'$\log_{10}$ gene copies per mL')
ax.legend()
fig.tight_layout()

In [None]:
if model.inflated:
    # Posterior of rho, shown on the interval from its smallest sample up to 1.
    rho_samples = samples['rho']
    fig, ax = plt.subplots()
    ax.hist(rho_samples, density=True, range=(rho_samples.min(), 1), alpha=.5)
    # Reference: observed fraction of patients with at least one positive sample.
    observed_fraction = np.mean(data['num_positives_by_patient'] > 0)
    ax.axvline(observed_fraction, color='k', ls='--')
    ax.set_xlabel(r'$\rho$')
    fig.tight_layout()

In [None]:
if model.temporal != shedding.Profile.CONSTANT:
    fig, ax = plt.subplots()

    # Positive samples per dataset, each with a dotted line at that dataset's LOQ. The keys are
    # sorted because set iteration order is not stable across sessions (string hashes are
    # randomised per process), which would shuffle colours and legend order between runs.
    keys = sorted(set(data['dataset']))
    for key in keys:
        fltr = data['positive'] & (data['dataset'] == key)
        if not fltr.any():
            # Nothing to plot and no LOQ value to read for a dataset without positive samples.
            continue
        pts = ax.scatter(data['day'][fltr], data['load'][fltr], alpha=.5, 
                         label=key)
        c = np.squeeze(pts.get_facecolor())
        ax.axhline(data['loq'][fltr][0], color=c, ls=':')

    # Posterior median and central 95% band of the population profile, from 21 days before
    # symptom onset to the last day with a positive sample.
    lin = np.linspace(-21, data['day'][data['positive']].max(), 200)
    profile = np.exp(samples['population_loc'] + 
                     model.temporal.evaluate_offset(lin[:, None], samples)).T
    line, = ax.plot(lin, np.median(profile, axis=0), color='k')
    ax.fill_between(lin, *np.percentile(profile, [2.5, 97.5], axis=0), alpha=.2, color=line.get_color())
    ax.set_yscale('log')
    ax.set_xlabel('Days since symptom onset')
    ax.set_ylabel('Gene copies per mL')
    ax.legend()

In [None]:
if model.temporal != shedding.Profile.CONSTANT:
    fig, ax = plt.subplots()
    # Use the first decay-type parameter present in the samples; fall back to the negated slope.
    scale = None
    for key in ['profile_decay', 'profile_scale']:
        if key in samples:
            scale = samples[key]
            break
    if scale is None:
        scale = - samples['slope']
    # NOTE(review): the factor 24 presumably converts a per-day rate into a half-life in
    # hours — confirm the units of the decay parameter.
    halflife = 24 * np.log(2) / scale
    # Draw on the explicit axes rather than through the pyplot state machine.
    ax.hist(halflife, density=True, bins=20)
    ax.set_xlabel('Half-life')
    fig.tight_layout()