In [12]:
%matplotlib inline
import os
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats
from scipy.special import logsumexp
import pandas as pd
import seaborn as sns
# Global plot styling: 'paper' context with enlarged fonts, plus the first
# three colours of the Set1 palette bound to convenient names.
sns.set_context(context='paper', font_scale=1.3)
red, blue, green = sns.color_palette(palette='Set1', n_colors=3)

import os
from datetime import datetime, timedelta

from rakott.mpl import fig_panel_labels, fig_xlabel, fig_ylabel, savefig_bbox

from inference import find_start_day

def load_chain(job_id, country, burn_fraction=0.6, output_dir=None):
    """Load the post-burn-in log-likelihoods of one inference run.

    Reads ``<output_dir>/<job_id>/inference/<country>.npz`` and returns the
    stored ``logliks`` array reshaped to one row per chain, with the first
    ``burn_fraction`` of every chain's steps removed.

    Parameters
    ----------
    job_id : str
        Name of the job sub-folder.
    country : str
        Base name (without extension) of the ``.npz`` file.
    burn_fraction : float, optional
        Fraction of the steps of each chain to discard as burn-in.
    output_dir : str or None, optional
        Root folder holding the job folders. When ``None`` the notebook-level
        ``output_folder`` variable is used, as before.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(nchains, nsteps - int(nsteps * burn_fraction))``.

    Raises
    ------
    ValueError
        If the number of stored log-likelihoods is not a multiple of the
        number of steps recorded in ``params`` (raised by ``reshape``).
    """
    if output_dir is None:
        # Backward-compatible fallback to the notebook global.
        output_dir = output_folder
    fname = os.path.join(output_dir, job_id, 'inference', '{}.npz'.format(country))
    # An .npz archive keeps its file handle open until closed; the context
    # manager releases it as soon as the needed arrays have been read.
    with np.load(fname) as inference_data:
        # Only the first entry of `params` (the number of steps) is needed here.
        # Cast to int so it is a valid reshape dimension even if `params` was
        # saved as a floating-point array.
        nsteps = int(inference_data['params'][0])
        logliks = inference_data['logliks']
    nchains = logliks.size // nsteps
    # Each consecutive run of `nsteps` values becomes one chain (row).
    logliks = logliks.reshape(nchains, nsteps)
    nburn = int(nsteps * burn_fraction)
    return logliks[:, nburn:]

def inliers(logliks, PLOT=False):
    """Flag chains whose mean log-likelihood is close to the other chains'.

    A chain is an inlier when its mean log-likelihood lies within three
    sample standard deviations of the mean over all chain means.

    Parameters
    ----------
    logliks : numpy.ndarray
        2-D array with one row per chain.
    PLOT : bool, optional
        When true, plot every 1000th sample of each chain, inliers in black
        and outliers in red.

    Returns
    -------
    numpy.ndarray
        Boolean mask with one entry per chain; ``True`` marks an inlier.

    Notes
    -----
    With fewer than two chains no spread can be estimated, so every chain is
    kept. Chains with identical means (zero spread) are likewise all kept.
    Because the spread is estimated from the same chain means, the 3-sigma
    rule can only ever reject a chain when there are at least 11 chains.
    """
    chain_mean_loglik = logliks.mean(axis=1)
    n_chains = chain_mean_loglik.size
    if n_chains < 2:
        # std(ddof=1) would be NaN here and mark everything as an outlier.
        idx = np.ones(n_chains, dtype=bool)
    else:
        spread = chain_mean_loglik.std(ddof=1)
        center = chain_mean_loglik.mean()
        # `<=` (not `<`) so that zero spread keeps all chains instead of none.
        idx = np.abs(chain_mean_loglik - center) <= 3 * spread
    if PLOT:
        if idx.any():
            plt.plot(logliks[idx, ::1000].T, '.k', label='inliers')
        if (~idx).any():
            plt.plot(logliks[~idx, ::1000].T, '.r', label='outliers')
        plt.ylabel('Log-likelihood')
        plt.legend()
    return idx

def WAIC(logliks):
    """Return two WAIC-style scores computed from chain log-likelihoods.

    Chains rejected by `inliers` are dropped first. All remaining samples
    are then pooled, and the log of their mean likelihood is penalised in
    two ways:

    * twice the gap between that log-mean-likelihood and the mean
      log-likelihood, and
    * the sample variance (``ddof=1``) of the log-likelihoods.

    Each score is ``-2 * (log_mean_likelihood - penalty)``.

    Parameters
    ----------
    logliks : numpy.ndarray
        2-D array with one row per chain.

    Returns
    -------
    tuple of float
        ``(score_with_first_penalty, score_with_second_penalty)``.
    """
    kept = logliks[inliers(logliks)]
    n_samples = kept.size
    # log(mean(exp(loglik))) evaluated stably through logsumexp
    log_mean_lik = logsumexp(kept) - np.log(n_samples)
    penalty_mean_gap = 2 * (log_mean_lik - kept.mean())
    penalty_variance = np.var(kept, ddof=1)
    score_mean_gap = -2 * (log_mean_lik - penalty_mean_gap)
    score_variance = -2 * (log_mean_lik - penalty_variance)
    return score_mean_gap, score_variance

In [36]:
# Inference jobs to compare (sub-folders of `output_folder`).
job_ids = [
    '2020-05-14-n1-normal-1M',
    '2020-05-27-freeTd-1-15',
]
# Country names as used in the inference output file names.
countries = [
    'Austria',
    'Belgium',
    'Denmark',
    'France',
    'Germany',
    'Italy',
    'Norway',
    'Spain',
    'Sweden',
    'Switzerland',
    'United_Kingdom',
]
# Root folder containing one sub-folder per job.
output_folder = '../output'

In [37]:
%%time
# One record per (country, job) pair holding the WAIC score of that fit.
results = []
for country in countries:
    for job_id in job_ids:
        # load_chain builds the file path itself from job_id and country.
        logliks = load_chain(job_id, country)
        # WAIC returns two scores; only the second one (variance penalty)
        # is kept for the table.
        waic1, waic2 = WAIC(logliks)
        results.append(dict(
            country=country,
            job_id=job_id,
            WAIC2=waic2
        ))

CPU times: user 41.4 s, sys: 14.9 s, total: 56.3 s
Wall time: 1min 24s


In [38]:
# job_ids = ['2020-05-14-n1-normal-1M','2020-05-27-more1week','2020-05-26-more2weeks','2020-05-25-normal-endapril-1M']

# Human-readable model names for the table, keyed by job id.
model_labels = {
    '2020-05-14-n1-normal-1M': 'Tds: 9 and 6',
    '2020-05-27-freeTd-1-15': 'free Tds ~U[0,15]',
}

df = pd.DataFrame(results)
df['job_id'] = df['job_id'].replace(model_labels)
df = df.rename(columns={'country': 'Country', 'job_id': 'Model'})
# File names use underscores where the table should show spaces.
df['Country'] = df['Country'].str.replace('_', ' ', regex=False)
df['Country'] = df['Country'].replace({'Wuhan': 'Wuhan China'})
df.head()

Unnamed: 0,Country,Model,WAIC2
0,Austria,Tds: 9 and 6,28.401017
1,Austria,"free Tds ~U[0,15]",28.75536
2,Belgium,Tds: 9 and 6,30.62063
3,Belgium,"free Tds ~U[0,15]",31.224699
4,Denmark,Tds: 9 and 6,37.336706


In [39]:
# Wide layout: one row per country, one column per model, WAIC2 as the value,
# rounded to two decimals.
df = df.pivot(index='Country', columns='Model', values='WAIC2').round(2)
df.head()

Model,Tds: 9 and 6,"free Tds ~U[0,15]"
Country,Unnamed: 1_level_1,Unnamed: 2_level_1
Austria,28.4,28.76
Belgium,30.62,31.22
Denmark,37.34,39.99
France,49.6,50.07
Germany,158.9,67.67


In [40]:
# For every country, the model (column) with the smallest score.
df.idxmin(axis='columns')


Country
Austria                Tds: 9 and 6
Belgium                Tds: 9 and 6
Denmark                Tds: 9 and 6
France                 Tds: 9 and 6
Germany           free Tds ~U[0,15]
Italy             free Tds ~U[0,15]
Norway                 Tds: 9 and 6
Spain                  Tds: 9 and 6
Sweden                 Tds: 9 and 6
Switzerland       free Tds ~U[0,15]
United Kingdom         Tds: 9 and 6
dtype: object

In [41]:
def bold_one(df, column_str):
    """Wrap the entries of one column that are row minima in ``\\textbf{}``.

    Modifies ``df`` in place. Every cell of ``column_str`` equal to the
    minimum of its row is replaced by the string ``\\textbf{<value>}`` with
    the value formatted to two decimals. Other cells are left unchanged.

    Parameters
    ----------
    df : pandas.DataFrame
        Table to modify in place.
    column_str : hashable
        Label of the column to process.

    Returns
    -------
    None

    Notes
    -----
    The row minimum is taken over the frame as it currently is, so the frame
    should still be all-numeric when this is called.
    """
    idx = df[column_str] == df.min(axis=1)
    if not idx.any():
        # Nothing to bold: leave the column (and its dtype) untouched.
        return
    bolded = ['\\textbf{' + '{:.2f}'.format(x) + '}' for x in df.loc[idx, column_str]]
    # Writing strings into a float column relies on an implicit upcast that
    # pandas has deprecated; switch the column to object dtype explicitly.
    df[column_str] = df[column_str].astype(object)
    df.loc[idx, column_str] = bolded
def bold_all(df, columns):
    """Wrap the minimum of every row in ``\\textbf{}``.

    Modifies ``df`` in place. The column holding each row's minimum is
    determined once, up front, with ``idxmin``; the corresponding cell is
    then replaced by the string ``\\textbf{<value>}`` with the value
    formatted to two decimals.

    Parameters
    ----------
    df : pandas.DataFrame
        Numeric table to modify in place.
    columns : iterable
        Column labels to process; a row minimum located in a column not
        listed here is left unformatted.

    Returns
    -------
    None
    """
    # Computed before any cell is turned into a string.
    minidxs = df.idxmin(axis=1)
    for col in columns:
        idx = minidxs == col
        if not idx.any():
            # No row has its minimum here: keep the column numeric.
            continue
        bolded = ['\\textbf{' + '{:.2f}'.format(x) + '}' for x in df.loc[idx, col]]
        # Writing strings into a float column relies on an implicit upcast
        # that pandas has deprecated; switch to object dtype explicitly.
        df[col] = df[col].astype(object)
        df.loc[idx, col] = bolded

In [42]:
# Highlight the best (lowest) score of every country across all models.
bold_all(df, df.columns.tolist())

In [45]:
# Display the table after the per-row minima were wrapped in \textbf{}.
df

Model,Tds: 9 and 6,"free Tds ~U[0,15]"
Country,Unnamed: 1_level_1,Unnamed: 2_level_1
Austria,\textbf{28.40},28.76
Belgium,\textbf{30.62},31.22
Denmark,\textbf{37.34},39.99
France,\textbf{49.60},50.07
Germany,158.9,\textbf{67.67}
Italy,233.07,\textbf{61.69}
Norway,\textbf{36.07},36.91
Spain,\textbf{59.54},60.48
Sweden,\textbf{25.91},26.47
Switzerland,72.97,\textbf{60.85}


In [46]:
# Write the formatted table to CSV with the country names as the first column.
# `index` is a boolean switch; the header of the index column is set through
# `index_label`.
# NOTE(review): float_format appears to apply only to float-typed columns, so
# columns converted to object by the bolding step are written as-is — confirm.
df.to_csv('Table-WAIC-free-Tds.csv', index=True, index_label='Country', escapechar='@', float_format="%.2f")