In [152]:
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats
from scipy.special import logsumexp
import pandas as pd
import seaborn as sns
# Use seaborn's 'paper' plotting context with fonts scaled up by 1.3.
sns.set_context('paper', font_scale=1.3)
# Bind the first three colours of the 'Set1' palette to short names.
red, blue, green = sns.color_palette('Set1', 3)

import os
from datetime import datetime, timedelta

from rakott.mpl import fig_panel_labels, fig_xlabel, fig_ylabel, savefig_bbox

from inference import find_start_day
from ppc import load_data

def load_chain(job_id, country, burn_fraction=0.6, output_dir=None):
    """Load the post-burn-in log-likelihoods of one inference run.

    Reads ``<output_dir>/<job_id>/inference/<country>.npz``, reshapes the
    stored ``logliks`` array to ``(nchains, nsteps)`` and drops the first
    ``burn_fraction`` of the steps from every chain.

    Parameters
    ----------
    job_id : str
        Name of the job sub-folder inside the output folder.
    country : str
        Base name (without ``.npz``) of the inference file.
    burn_fraction : float, optional
        Fraction of the steps of each chain discarded as burn-in.
    output_dir : str, optional
        Root output folder. When omitted, the module-level ``output_folder``
        is used, which keeps existing two-argument calls working.

    Returns
    -------
    numpy.ndarray
        1-D array with the retained log-likelihoods of all chains.
    """
    if output_dir is None:
        # Fall back to the notebook-level configuration.
        output_dir = output_folder
    fname = os.path.join(output_dir, job_id, 'inference', '{}.npz'.format(country))
    # NpzFile keeps the archive open until closed; the context manager
    # releases the file handle once the needed arrays have been read.
    with np.load(fname) as inference_data:
        # The first entry of 'params' is the number of steps per chain; the
        # remaining entries are not needed here.
        nsteps = int(inference_data['params'][0])
        logliks = inference_data['logliks']
    nchains = logliks.size // nsteps
    logliks = logliks.reshape(nchains, nsteps)
    nburn = int(nsteps * burn_fraction)
    return logliks[:, nburn:].ravel()

def WAIC(logliks):
    """Compute two WAIC variants from per-sample log-likelihoods.

    With ``lppd = log(mean(exp(logliks)))`` the two penalty terms are
    ``p1 = 2 * (lppd - mean(logliks))`` and the sample variance of
    ``logliks`` (``ddof=1``) as ``p2``.

    NOTE(review): every entry of ``logliks`` is used as one posterior
    sample's log-likelihood, so the data appear to be treated as a single
    observation -- confirm this matches how the chain stores them.

    Parameters
    ----------
    logliks : numpy.ndarray
        Log-likelihood values, one per posterior sample.

    Returns
    -------
    tuple of float
        ``(-2 * (lppd - p1), -2 * (lppd - p2))``.
    """
    n_samples = logliks.size
    # log of the average likelihood, evaluated stably in log space
    lppd = logsumexp(logliks) - np.log(n_samples)
    penalty_mean = 2 * (lppd - logliks.mean())
    penalty_var = np.var(logliks, ddof=1)
    waic_mean = -2 * (lppd - penalty_mean)
    waic_var = -2 * (lppd - penalty_var)
    return waic_mean, waic_var

In [185]:
# Inference jobs whose chains are compared (one sub-folder per job).
job_ids = [
    '2020-05-14-n1-normal-1M',
    '2020-05-14-n1-notau-1M',
    '2020-05-15-n1-fixed-tau-1M',
]
# Regions analysed; each name is the base name of an inference .npz file.
countries = 'Austria Belgium Denmark France Germany Italy Norway Spain Sweden Switzerland United_Kingdom Wuhan'.split(' ')
# Root folder holding the job sub-folders. The default is a machine-specific
# path; set the EFFECTIVENPI_OUTPUT environment variable to point the
# notebook at another location without editing this cell.
output_folder = os.environ.get(
    'EFFECTIVENPI_OUTPUT',
    r'/Users/yoavram/Library/Mobile Documents/com~apple~CloudDocs/EffectiveNPI-Data/output',
)

In [186]:
# Collect one record per (country, model) pair. load_chain builds the file
# path itself, so no path is assembled here.
results = []
for country in countries:
    for job_id in job_ids:
        waic1, waic2 = WAIC(load_chain(job_id, country))
        results.append(dict(
            country=country,
            job_id=job_id,
            WAIC1=waic1,
            WAIC2=waic2
        ))

In [214]:
# Short display labels for the three inference jobs.
model_labels = {
    '2020-05-14-n1-normal-1M': 'Free',
    '2020-05-14-n1-notau-1M': 'No',
    '2020-05-15-n1-fixed-tau-1M': 'Fixed',
}
df = pd.DataFrame(results)
df['job_id'] = df['job_id'].replace(model_labels)
df = df.rename(columns={'country':'Country', 'job_id':'Model'})
# Human-readable region names: underscores become spaces, Wuhan gets its country.
df['Country'] = (
    df['Country']
    .str.replace('_', ' ', regex=False)
    .replace('Wuhan', 'Wuhan China')
)
df.head()

Unnamed: 0,Country,Model,WAIC1,WAIC2
0,Austria,Free,27.230978,29.796479
1,Austria,No,38.799347,39.730292
2,Austria,Fixed,25.592279,26.678583
3,Belgium,Free,29.118949,30.62063
4,Belgium,No,28.003337,28.81343


In [215]:
# Wide layout: one row per country, one column per (WAIC variant, model).
df = df.pivot(index='Country', columns='Model')
df

Unnamed: 0_level_0,WAIC1,WAIC1,WAIC1,WAIC2,WAIC2,WAIC2
Model,Fixed,Free,No,Fixed,Free,No
Country,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2
Austria,25.592279,27.230978,38.799347,26.678583,29.796479,39.730292
Belgium,28.190326,29.118949,28.003337,29.383127,30.62063,28.81343
Denmark,33.650988,34.950003,48.660911,38.561618,37.288872,49.630238
France,47.830982,48.04015,70.374954,50.036671,50.593087,72.172448
Germany,213.577046,159.520948,308.878349,214.989565,174.240162,310.693855
Italy,299.827885,231.031183,435.082876,301.521641,233.131936,609.255455
Norway,32.48134,33.640207,36.682201,34.205331,36.072503,37.540257
Spain,58.545959,60.329175,140.685038,59.897205,92.550938,141.597982
Sweden,23.51263,24.15307,27.472755,25.932879,25.8591,28.363601
Switzerland,72.751101,70.361504,98.566775,74.848207,73.068505,99.744745


In [216]:
# Discard the WAIC1 columns, then remove the outer column level so that only
# the model names remain as column labels.
df = df.drop(columns='WAIC1').droplevel(0, axis=1)
df.head()

Model,Fixed,Free,No
Country,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
Austria,26.678583,29.796479,39.730292
Belgium,29.383127,30.62063,28.81343
Denmark,38.561618,37.288872,49.630238
France,50.036671,50.593087,72.172448
Germany,214.989565,174.240162,310.693855


In [217]:
# `index` is a boolean switch; the header text for the index column is set
# through `index_label`. Values are written with two decimals.
df.to_csv('../figures/Table-WAIC.csv', index=True, index_label='Country', escapechar='@', float_format="%.2f")