In [152]:
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats
from scipy.special import logsumexp
import pandas as pd
import seaborn as sns
sns.set_context('paper', font_scale=1.3)
red, blue, green = sns.color_palette('Set1', 3)

import os
from datetime import datetime, timedelta

from rakott.mpl import fig_panel_labels, fig_xlabel, fig_ylabel, savefig_bbox

from inference import find_start_day
from ppc import load_data

def load_chain(job_id, country, burn_fraction=0.6, output_dir=None):
    """Load the post-burn-in log-likelihood samples stored for one country.

    Reads ``<output_dir>/<job_id>/inference/<country>.npz``, reshapes the
    stored ``logliks`` array to ``(nchains, nsteps)`` and drops the first
    ``int(nsteps * burn_fraction)`` steps of every chain.

    Parameters
    ----------
    job_id : str
        Name of the job sub-folder inside the output directory.
    country : str
        Base name (without extension) of the ``.npz`` file to read.
    burn_fraction : float, optional
        Fraction of each chain's steps to discard from its start.
    output_dir : str, optional
        Root folder that contains the job sub-folders. When omitted, the
        notebook-level ``output_folder`` variable is used; it is looked up
        when the function is called, not when it is defined.

    Returns
    -------
    numpy.ndarray
        1-D array with the retained log-likelihood values of all chains.

    Notes
    -----
    The first entry of the stored ``params`` array is taken as the number of
    steps per chain; the other entries are not needed here. The reshape
    treats every consecutive run of ``nsteps`` values as one chain
    (assumes the file stores chains one after another -- TODO confirm
    against the code that writes these files).
    """
    if output_dir is None:
        output_dir = output_folder
    fname = os.path.join(output_dir, job_id, 'inference', '{}.npz'.format(country))
    # The context manager closes the underlying zip file once the arrays are read.
    with np.load(fname) as inference_data:
        # int() keeps the reshape below working even if `params` was saved as floats.
        nsteps = int(inference_data['params'][0])
        logliks = inference_data['logliks']
    nchains = logliks.size // nsteps
    nburn = int(nsteps * burn_fraction)
    return logliks.reshape(nchains, nsteps)[:, nburn:].ravel()

def WAIC(logliks):
    """Compute two WAIC values from an array of log-likelihood samples.

    With ``S = logliks.size``, the shared fit term is the log of the mean
    likelihood, ``lppd = logsumexp(logliks) - log(S)``. Two penalty terms are
    formed from the same samples:

    * ``penalty_mean = 2 * (lppd - mean(logliks))``
    * ``penalty_var  = var(logliks, ddof=1)`` (sample variance)

    Parameters
    ----------
    logliks : numpy.ndarray
        Log-likelihood values, one per sample. Each value is used as a
        single observation.

    Returns
    -------
    tuple
        ``(-2 * (lppd - penalty_mean), -2 * (lppd - penalty_var))``.
    """
    n_samples = logliks.size
    # log(mean(exp(logliks))), evaluated stably through logsumexp
    lppd = logsumexp(logliks) - np.log(n_samples)
    penalty_mean = 2 * (lppd - logliks.mean())
    penalty_var = np.var(logliks, ddof=1)
    waic_mean = -2 * (lppd - penalty_mean)
    waic_var = -2 * (lppd - penalty_var)
    return waic_mean, waic_var

In [185]:
# Job folders to compare; each one is a sub-folder of `output_folder`.
job_ids = [
    '2020-05-14-n1-normal-1M',
    '2020-05-14-n1-notau-1M',
    '2020-05-15-n1-fixed-tau-1M',
]
# Base names of the per-country .npz files read from each job folder.
countries = [
    'Austria', 'Belgium', 'Denmark', 'France', 'Germany', 'Italy',
    'Norway', 'Spain', 'Sweden', 'Switzerland', 'United_Kingdom', 'Wuhan',
]
# Root folder holding the job sub-folders (absolute path on the author's machine).
output_folder = r'/Users/yoavram/Library/Mobile Documents/com~apple~CloudDocs/EffectiveNPI-Data/output'

In [None]:
# Collect one record per (country, job) pair holding both WAIC values.
# `load_chain` builds the path of the .npz file itself, so no file name
# needs to be assembled in this loop.
results = []
for country in countries:
    for job_id in job_ids:
        logliks = load_chain(job_id, country)
        waic1, waic2 = WAIC(logliks)
        results.append(dict(
            country=country,
            job_id=job_id,
            WAIC1=waic1,
            WAIC2=waic2
        ))

In [None]:
# Build the long-format table: one row per (country, job) record in `results`.
df = (
    pd.DataFrame(results)
    # Swap the job folder names for short model labels (exact matches only).
    .replace({'job_id': {
        '2020-05-14-n1-normal-1M': 'Free',
        '2020-05-14-n1-notau-1M': 'No',
        '2020-05-15-n1-fixed-tau-1M': 'Fixed',
    }})
    .rename(columns={'country':'Country', 'job_id':'Model'})
)
# Display names: underscores become spaces, and Wuhan gets its country appended.
df['Country'] = (
    df['Country']
    .str.replace('_', ' ', regex=False)
    .replace('Wuhan', 'Wuhan, China')
)
df.head()

In [None]:
# Reshape to one row per country. No `values` column is given, so every
# remaining column is spread across the models, giving two-level column labels.
df = df.pivot(index='Country', columns='Model')
df

In [None]:
# Discard the WAIC1 columns, then remove the outer column level so that
# only the model labels remain as column names.
df = (
    df
    .drop(columns='WAIC1')
    .droplevel(0, axis=1)
)
df.head()

In [None]:
# Write the table with two decimals. `index` is a boolean switch; the header
# text for the index column is set through `index_label`.
df.to_csv('../figures/Table-WAIC.csv', index=True, index_label='Country', float_format="%.2f")

In [184]:
# Print the CSV file written by the previous cell to check its contents.
# NOTE(review): `%cat` relies on IPython's alias for the shell `cat` command,
# which is presumably only available on POSIX systems -- confirm if portability matters.
%cat ../figures/Table-WAIC.csv

Country,Fixed,Free,No
Austria,26.68,29.80,39.73
Belgium,29.38,30.62,28.81
Denmark,38.56,37.29,49.63
France,50.04,50.59,72.17
Germany,214.99,174.24,310.69
Italy,301.52,233.13,609.26
Norway,34.21,36.07,37.54
Spain,59.90,92.55,141.60
Sweden,25.93,25.86,28.36
Switzerland,74.85,73.07,99.74
United Kingdom,38.10,37.49,35.77


In [161]:
# Display the table with index level 0 moved back into an ordinary column.
# The result is not assigned, so `df` itself is left unchanged.
df.reset_index(level=0)

Unnamed: 0_level_0,Country,WAIC2,WAIC2,WAIC2
Model,Unnamed: 1_level_1,Fixed τ,Free τ,Νο τ
0,Austria,26.678583,29.796479,39.730292
1,Belgium,29.383127,30.62063,28.81343
2,Denmark,38.561618,37.288872,49.630238
3,France,50.036671,50.593087,72.172448
4,Germany,214.989565,174.240162,310.693855
5,Italy,301.521641,233.131936,609.255455
6,Norway,34.205331,36.072503,37.540257
7,Spain,59.897205,92.550938,141.597982
8,Sweden,25.932879,25.8591,28.363601
9,Switzerland,74.848207,73.068505,99.744745


Model,Fixed τ,Free τ,Νο τ
Country,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
Austria,26.678583,29.796479,39.730292
Belgium,29.383127,30.62063,28.81343
Denmark,38.561618,37.288872,49.630238
France,50.036671,50.593087,72.172448
Germany,214.989565,174.240162,310.693855
Italy,301.521641,233.131936,609.255455
Norway,34.205331,36.072503,37.540257
Spain,59.897205,92.550938,141.597982
Sweden,25.932879,25.8591,28.363601
Switzerland,74.848207,73.068505,99.744745
