In [73]:
%matplotlib inline
import os
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats
from scipy.special import logsumexp
import pandas as pd
import seaborn as sns
# Publication-style plot scaling, plus three named colours from the Set1 palette.
sns.set_context(context='paper', font_scale=1.3)
red, blue, green = sns.color_palette(palette='Set1', n_colors=3)

import os
from datetime import datetime, timedelta

from rakott.mpl import fig_panel_labels, fig_xlabel, fig_ylabel, savefig_bbox

from inference import find_start_day

def load_chain(job_id, country, burn_fraction=0.6, folder=None):
    """Load post-burn-in log-likelihood traces for one inference run.

    Reads ``<folder>/<job_id>/inference/<country>.npz``, reshapes its flat
    ``logliks`` array to one row per chain and drops the first
    ``burn_fraction`` of every chain's steps.

    Parameters
    ----------
    job_id : str
        Name of the run's sub-folder.
    country : str
        File stem of the ``.npz`` archive.
    burn_fraction : float
        Fraction of each chain's steps discarded as burn-in.
    folder : str, optional
        Root output folder. Defaults to the notebook-level ``output_folder``.

    Returns
    -------
    numpy.ndarray
        Shape ``(nchains, nsteps - int(nsteps * burn_fraction))``.
    """
    if folder is None:
        folder = output_folder  # notebook-level configuration variable
    fname = os.path.join(folder, job_id, 'inference', '{}.npz'.format(country))
    # An NpzFile keeps the archive open until closed; the context manager
    # closes it (the arrays read inside are fully loaded into memory).
    with np.load(fname) as inference_data:
        params = inference_data['params']
        logliks = inference_data['logliks']
    # params[0] is the number of steps per chain. It may be stored as a
    # float (or string) array element, but reshape needs a Python int.
    nsteps = int(params[0])
    nchains = logliks.size // nsteps
    logliks = logliks.reshape(nchains, nsteps)
    nburn = int(nsteps * burn_fraction)
    return logliks[:, nburn:]

def inliers(logliks, PLOT=False):
    """Flag chains whose mean log-likelihood is within 3 sample-std of the rest.

    Parameters
    ----------
    logliks : numpy.ndarray
        2-D array with one row per chain (as returned by ``load_chain``).
    PLOT : bool
        If True, plot every 1000th sample of each chain on the current
        matplotlib axes: inliers in black, outliers in red.

    Returns
    -------
    numpy.ndarray of bool
        One entry per chain, True for chains that are kept.

    Notes
    -----
    The spread is the ddof=1 std of the chain means, which includes any
    outlier itself. By Samuelson's inequality no chain can then lie more
    than ``(n - 1) / sqrt(n)`` std from the mean, so with 10 or fewer chains
    nothing is ever flagged. With fewer than two chains, or identical chain
    means, every chain is kept. (The std is NaN or 0 there, and the strict
    comparison would otherwise reject all chains.)
    """
    chain_mean_loglik = logliks.mean(axis=1)
    nchains = chain_mean_loglik.size
    if nchains < 2:
        # ddof=1 std is undefined for fewer than two chains.
        idx = np.ones(nchains, dtype=bool)
    else:
        std_mean_loglik = chain_mean_loglik.std(ddof=1)
        mean_mean_loglik = chain_mean_loglik.mean()
        if std_mean_loglik == 0:
            # All chains agree exactly: nothing to reject.
            idx = np.ones(nchains, dtype=bool)
        else:
            idx = np.abs(chain_mean_loglik - mean_mean_loglik) < 3 * std_mean_loglik
    if PLOT:
        # Label only the first line of each group so the legend has one
        # entry per group instead of one per chain.
        if idx.any():
            lines = plt.plot(logliks[idx, ::1000].T, '.k')
            lines[0].set_label('inliers')
        if (~idx).any():
            lines = plt.plot(logliks[~idx, ::1000].T, '.r')
            lines[0].set_label('outliers')
        plt.ylabel('Log-likelihood')
        plt.legend()
    return idx

def WAIC(logliks):
    """Compute two WAIC variants from the log-likelihood draws of inlier chains.

    Chains rejected by ``inliers`` are dropped and the remaining draws are
    pooled (``S`` draws in total).

    Returns
    -------
    (float, float)
        ``-2 * (lppd - p)`` with two penalties:
        ``p1 = 2 * (lppd - mean log-lik)`` and ``p2 = var(log-lik, ddof=1)``.

    Notes
    -----
    Each draw carries a single total log-likelihood, so the whole data set
    is treated as one observation. The textbook WAIC sums these terms over
    individual data points. NOTE(review): confirm this is intended.
    """
    logliks = logliks[inliers(logliks)]
    S = logliks.size
    # log of the posterior-mean likelihood, computed stably in log space;
    # reused for p1 instead of recomputing the logsumexp.
    lppd = logsumexp(logliks) - np.log(S)
    p1 = 2 * (lppd - logliks.mean())
    p2 = np.var(logliks, ddof=1)
    return -2 * (lppd - p1), -2 * (lppd - p2)

In [74]:
# Inference runs (sub-folders of output_folder) compared in this notebook.
job_ids = [
    '2020-05-07-deltat0-normal',
    '2020-04-30-prior-walkers-model2-normal',
    '2020-05-09-zero-2-normal',
    '2020-05-09-zero-5-normal',
    '2020-05-09-zero-10-normal',
    '2020-05-09-zero-15-normal',
]
# Country names as used in the .npz file stems.
countries = [
    'Austria', 'Belgium', 'Denmark', 'France', 'Germany', 'Italy',
    'Norway', 'Spain', 'Sweden', 'Switzerland', 'United_Kingdom',
]
output_folder = '../output'

In [116]:
%%time
results = []
for country in countries:
    for job_id in job_ids:
        chain_fname = os.path.join(output_folder, job_id, 'inference', '{}.npz'.format(country))
        logliks = load_chain(job_id, country)
        waic1, waic2 = WAIC(logliks)
        results.append(dict(
            country=country,
            job_id=job_id,
#             WAIC1=waic1,
            WAIC2=waic2
        ))

CPU times: user 7.79 s, sys: 1.15 s, total: 8.94 s
Wall time: 9.8 s


In [117]:
# Short model labels for each inference run (used as table column headers).
MODEL_LABELS = {
    '2020-05-07-deltat0-normal': 'free',
    '2020-04-30-prior-walkers-model2-normal': '1',
    '2020-05-09-zero-2-normal': '2',
    '2020-05-09-zero-5-normal': '5',
    '2020-05-09-zero-10-normal': '10',
    '2020-05-09-zero-15-normal': '15',
}

df = pd.DataFrame(results)
# Runs missing from MODEL_LABELS keep their original job_id.
df['job_id'] = df['job_id'].replace(MODEL_LABELS)
df = df.rename(columns={'country': 'Country', 'job_id': 'Model'})
df['Country'] = df['Country'].str.replace('_', ' ', regex=False)
# Only has an effect if 'Wuhan' is among the analysed countries.
df.loc[df['Country'] == 'Wuhan', 'Country'] = 'Wuhan China'
df.head()

Unnamed: 0,Country,Model,WAIC2
0,Austria,free,28.329929
1,Austria,1,32.426785
2,Austria,2,28.87825
3,Austria,5,47.718124
4,Austria,10,31.038937


In [118]:
# Wide table: one row per country, one column per model, in a fixed order.
# NOTE: overwrites the long-format df, so this cell only works once per run
# of the previous cell.
MODEL_ORDER = ['free', '1', '2', '5', '10', '15']
df = df.pivot(index='Country', columns='Model', values='WAIC2')
df = df[MODEL_ORDER]
df = df.round(2)
df.head()

Model,free,1,2,5,10,15
Country,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
Austria,28.33,32.43,28.88,47.72,31.04,31.11
Belgium,30.97,30.23,30.35,32.46,33.45,34.24
Denmark,39.48,37.44,53.46,54.47,59.06,109.28
France,78.79,49.48,58.36,198.9,87.43,71.61
Germany,211.06,830.03,531.26,258.08,271.99,457.76


In [121]:
# Best (lowest-WAIC) model per country; ties resolve to the first column.
df.idxmin(axis='columns')


Country
Austria           free
Belgium              1
Denmark              1
France               1
Germany           free
Italy             free
Norway               1
Spain                1
Sweden               1
Switzerland          2
United Kingdom       1
dtype: object

In [111]:
def _bold_cells(df, mask, column):
    r"""Replace, in place, the cells of ``column`` selected by ``mask`` with ``\textbf{x.xx}``.

    The column is cast to object first, so storing strings in a float column
    does not rely on pandas' deprecated implicit dtype upcasting. A column with
    no selected cell is left untouched and keeps its dtype.
    """
    if not mask.any():
        return
    df[column] = df[column].astype(object)
    df.loc[mask, column] = ['\\textbf{' + '{:.2f}'.format(x) + '}'
                            for x in df.loc[mask, column]]


def bold_one(df, column_str):
    r"""Bold (``\textbf``), in place, cells of ``column_str`` equal to their row minimum.

    The row minimum is taken over all columns of ``df``. Every tied minimum
    in ``column_str`` is bolded.
    """
    idx = df[column_str] == df.min(axis=1)
    _bold_cells(df, idx, column_str)


def bold_all(df, columns):
    r"""Bold (``\textbf``), in place, each row's minimum cell if it lies in ``columns``.

    The minimum is located with ``idxmin`` over all columns of ``df`` before
    anything is modified; on ties only the first column is bolded. After
    this call the touched columns hold strings, so calling it again on the
    same frame fails.
    """
    minidxs = df.idxmin(axis=1)
    for col in columns:
        _bold_cells(df, minidxs == col, col)

In [112]:
# Mark each country's best model in bold. This mutates df in place, so it
# must run after the pivot cell and only once.
bold_all(df, ['free', '1', '2', '5', '10', '15'])

In [132]:
# The 'Country' index is written as the first column (its name becomes the
# header). NOTE(review): float_format presumably applies only to columns that
# are still float; columns holding \textbf strings are object dtype and may
# print plain numbers with fewer than two decimals. Confirm.
df.to_csv('Table-WAIC-Zeros-Normal.csv', index=True, escapechar='@', float_format="%.2f")