In [9]:
%matplotlib inline
import os
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats
from scipy.special import logsumexp
import pandas as pd
import seaborn as sns
sns.set_context('paper', font_scale=1.3)
red, blue, green = sns.color_palette('Set1', 3)

import os
from datetime import datetime, timedelta

from rakott.mpl import fig_panel_labels, fig_xlabel, fig_ylabel, savefig_bbox

from inference import find_start_day

def load_chain(job_id, country, burn_fraction=0.6):
    fname = os.path.join(output_folder, job_id, 'inference', '{}.npz'.format(country))
    inference_data = np.load(fname)
    nsteps, ndim, N, Td1, Td2, model_type = inference_data['params']
    logliks = inference_data['logliks']
    nchains = logliks.size // nsteps
    logliks = logliks.reshape(nchains, nsteps)
    nburn = int(nsteps*burn_fraction)
    logliks = logliks[:, nburn:]
    return logliks

def inliers(logliks, PLOT=False):
    chain_mean_loglik = logliks.mean(axis=1)
    std_mean_loglikg = chain_mean_loglik.std(ddof=1)
    mean_mean_loglikg = chain_mean_loglik.mean()
    idx = abs(chain_mean_loglik - mean_mean_loglikg) < 3*std_mean_loglikg
    if PLOT:
        if idx.any():
            plt.plot(logliks[idx, ::1000].T, '.k', label='inliers')
        if (~idx).any():
            plt.plot(logliks[~idx, ::1000].T, '.r', label='outliers')
        plt.ylabel('Log-likelihood')
        plt.legend()
    return idx

def WAIC(logliks):
    logliks = logliks[inliers(logliks)]
    S = logliks.size
    llpd = -np.log(S) + logsumexp(logliks)
    p1 = 2*(-np.log(S) + logsumexp(logliks) - logliks.mean())
    p2 = np.var(logliks, ddof=1)
    return -2*(llpd + -p1), -2*(llpd + -p2)

In [10]:
job_ids = ['2020-05-14-n1-normal-1M', '2020-05-14-n1-uniform-1M']
countries = 'Austria Belgium Denmark France Germany Italy Norway Spain Sweden Switzerland United_Kingdom Wuhan'.split(' ')
output_folder = r'../output'

In [16]:
%%time
results = []
for country in countries:
    for job_id in job_ids:
        chain_fname = os.path.join(output_folder, job_id, 'inference', '{}.npz'.format(country))
        logliks = load_chain(job_id, country)
        waic1, waic2 = WAIC(logliks)
        results.append(dict(
            country=country,
            job_id=job_id,
            WAIC1=waic1,
            WAIC2=waic2
        ))

CPU times: user 36.7 s, sys: 14.1 s, total: 50.8 s
Wall time: 55 s


In [17]:
df = pd.DataFrame(results)
df.loc[df['job_id'] == '2020-05-14-n1-normal-1M', 'job_id'] = 'Normal'
df.loc[df['job_id'] == '2020-05-14-n1-uniform-1M', 'job_id'] = 'Uniform'
df = df.rename(columns={'country':'Country', 'job_id':'Model'})
df['Country'] = [x.replace('_', ' ') for x in df['Country']]
df.loc[df['Country']=='Wuhan', 'Country'] = 'Wuhan China'
df.head()

Unnamed: 0,Country,Model,WAIC1,WAIC2
0,Austria,Normal,26.791404,28.401017
1,Austria,Uniform,26.876521,28.620925
2,Belgium,Normal,29.118949,30.62063
3,Belgium,Uniform,28.46068,29.671831
4,Denmark,Normal,34.969676,37.336706


In [18]:
df = pd.pivot(df, index='Country', columns='Model')
df

Unnamed: 0_level_0,WAIC1,WAIC1,WAIC2,WAIC2
Model,Normal,Uniform,Normal,Uniform
Country,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
Austria,26.791404,26.876521,28.401017,28.620925
Belgium,29.118949,28.46068,30.62063,29.671831
Denmark,34.969676,35.413878,37.336706,36.86029
France,47.749859,47.732696,49.598799,49.386935
Germany,156.852629,162.293898,158.901193,188.450518
Italy,230.995553,231.156769,233.072121,233.425663
Norway,33.640207,33.21612,36.072503,35.206514
Spain,58.049608,60.924096,59.542263,97.36206
Sweden,24.089671,23.532478,25.910516,25.595768
Switzerland,70.322332,71.941434,72.965631,81.359444


In [20]:
df = df.drop(columns='WAIC1')
df = df.droplevel(0, axis=1)
df.head()

KeyError: "['WAIC1'] not found in axis"

In [23]:
idx = df['Normal']==df.min(axis=1)
df.loc[idx, 'Normal'] = ['\\textbf{'+'{:.2f}'.format(x)+'}' for x in df.loc[idx, 'Normal']] 
df.loc[~idx, 'Normal'] = ['{:.2f}'.format(x) for x in df.loc[~idx, 'Normal']] 
df.head()

Model,Normal,Uniform
Country,Unnamed: 1_level_1,Unnamed: 2_level_1
Austria,\textbf{28.40},28.620925
Belgium,30.62,29.671831
Denmark,37.34,36.86029
France,49.60,49.386935
Germany,\textbf{158.90},188.450518


In [25]:
df.to_csv('Table-WAIC-Uniform.csv', index='Country', escapechar='@', float_format="%.2f")