In [12]:
%matplotlib inline
import os
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats
from scipy.special import logsumexp
import pandas as pd
import seaborn as sns
sns.set_context('paper', font_scale=1.3)
red, blue, green = sns.color_palette('Set1', 3)

import os
from datetime import datetime, timedelta

from rakott.mpl import fig_panel_labels, fig_xlabel, fig_ylabel, savefig_bbox

from inference import find_start_day

def load_chain(job_id, country, burn_fraction=0.6):
    """Load the post-burn-in log-likelihoods stored for one inference run.

    Reads ``<output_folder>/<job_id>/inference/<country>.npz``, where
    ``output_folder`` is a notebook-level global, and uses the archive's
    ``params`` and ``logliks`` entries.

    Parameters
    ----------
    job_id : str
        Name of the job sub-folder inside ``output_folder``.
    country : str
        Base name (without the ``.npz`` suffix) of the archive to read.
    burn_fraction : float, optional
        Fraction of the steps at the start of every chain that is discarded.

    Returns
    -------
    numpy.ndarray
        2-D array with one row per chain and one column per retained step.
    """
    fname = os.path.join(output_folder, job_id, 'inference', '{}.npz'.format(country))
    # np.load returns an NpzFile that keeps the archive open; the context
    # manager closes it as soon as the arrays we need have been read.
    with np.load(fname) as inference_data:
        # `params` is laid out as (nsteps, ndim, N, Td1, Td2, model_type);
        # only the number of steps per chain is needed here.  The int() cast
        # makes it safe to use as a reshape dimension whatever dtype is stored.
        nsteps = int(inference_data['params'][0])
        logliks = inference_data['logliks']
    nchains = logliks.size // nsteps
    logliks = logliks.reshape(nchains, nsteps)
    nburn = int(nsteps*burn_fraction)
    return logliks[:, nburn:]

def inliers(logliks, PLOT=False):
    """Flag the chains whose mean log-likelihood is not an outlier.

    A chain is an inlier when its mean log-likelihood lies strictly within
    three sample standard deviations of the mean over all chain means.

    Parameters
    ----------
    logliks : numpy.ndarray
        2-D array; each row is averaged along axis 1 to get one mean per chain.
    PLOT : bool, optional
        When true, plot every 1000th log-likelihood of each chain, inliers in
        black and outliers in red, on the current matplotlib axes.

    Returns
    -------
    numpy.ndarray
        Boolean mask with one entry per chain, ``True`` for inliers.
    """
    chain_mean_loglik = logliks.mean(axis=1)
    if chain_mean_loglik.size < 2:
        # The sample std (ddof=1) is undefined for fewer than two chains, so
        # there is nothing to compare against: keep whatever is there.
        idx = np.ones(chain_mean_loglik.shape, dtype=bool)
    else:
        std_of_means = chain_mean_loglik.std(ddof=1)
        mean_of_means = chain_mean_loglik.mean()
        if std_of_means == 0:
            # All chain means are identical; the strict `<` test below would
            # otherwise reject every chain (0 < 0 is false).
            idx = np.ones(chain_mean_loglik.shape, dtype=bool)
        else:
            idx = np.abs(chain_mean_loglik - mean_of_means) < 3*std_of_means
    if PLOT:
        if idx.any():
            plt.plot(logliks[idx, ::1000].T, '.k', label='inliers')
        if (~idx).any():
            plt.plot(logliks[~idx, ::1000].T, '.r', label='outliers')
        plt.ylabel('Log-likelihood')
        plt.legend()
    return idx

def WAIC(logliks):
    """Compute two WAIC-style scores from per-chain log-likelihood samples.

    Chains rejected by ``inliers`` are dropped first; all remaining samples
    are then pooled.  Both scores are ``-2 * (llpd - penalty)`` where ``llpd``
    is the log of the mean likelihood over the pooled samples.

    Parameters
    ----------
    logliks : numpy.ndarray
        2-D array of log-likelihoods, one row per chain.

    Returns
    -------
    tuple of float
        ``(waic1, waic2)``: ``waic1`` uses twice the gap between ``llpd`` and
        the mean log-likelihood as penalty, ``waic2`` uses the sample variance
        (``ddof=1``) of the log-likelihoods.
    """
    kept = logliks[inliers(logliks)]
    n_samples = kept.size
    # log(mean(exp(loglik))) evaluated stably in log space
    log_mean_lik = -np.log(n_samples) + logsumexp(kept)
    penalty_mean = 2*(log_mean_lik - kept.mean())
    penalty_var = np.var(kept, ddof=1)
    waic_mean = -2*(log_mean_lik - penalty_mean)
    waic_var = -2*(log_mean_lik - penalty_var)
    return waic_mean, waic_var

In [14]:
# Inference runs to compare; each is a sub-folder of `output_folder`.
job_ids = [
    '2020-05-14-n1-normal-1M',
    '2020-05-27-more1week',
    '2020-05-26-more2weeks',
    '2020-05-25-normal-endapril-1M',
]
# Country names exactly as they appear in the per-country archive file names.
countries = [
    'Austria',
    'Belgium',
    'Denmark',
    'France',
    'Germany',
    'Italy',
    'Norway',
    'Spain',
    'Sweden',
    'Switzerland',
    'United_Kingdom',
]
# Root folder that holds one sub-folder per job id.
output_folder = '../output'

In [15]:
%%time
# One record per (country, job) pair.  Only the variance-penalised score
# (second value returned by WAIC) is kept for the table.
results = []
for country in countries:
    for job_id in job_ids:
        # load_chain builds the archive path itself from job_id and country
        logliks = load_chain(job_id, country)
        waic1, waic2 = WAIC(logliks)
        results.append(dict(
            country=country,
            job_id=job_id,
            WAIC2=waic2
        ))

CPU times: user 1min 25s, sys: 31.3 s, total: 1min 56s
Wall time: 2min 59s


In [16]:
# job_ids = ['2020-05-14-n1-normal-1M','2020-05-27-more1week','2020-05-26-more2weeks','2020-05-25-normal-endapril-1M']

# Short display label for each inference job id.
model_labels = {
    '2020-05-14-n1-normal-1M': '28M-March',
    '2020-05-27-more1week': 'more 1 week',
    '2020-05-26-more2weeks': 'more 2 weeks',
    '2020-05-25-normal-endapril-1M': 'end-April',
}

df = pd.DataFrame(results)
df['job_id'] = df['job_id'].replace(model_labels)
df = df.rename(columns={'country': 'Country', 'job_id': 'Model'})
# Underscores in country names become spaces; a bare 'Wuhan' entry is
# expanded to 'Wuhan China' (exact match only).
df['Country'] = (
    df['Country']
    .str.replace('_', ' ', regex=False)
    .replace('Wuhan', 'Wuhan China')
)
df.head()

Unnamed: 0,Country,Model,WAIC2
0,Austria,28M-March,28.401017
1,Austria,more 1 week,34.601192
2,Austria,more 2 weeks,35.95847
3,Austria,end-April,36.145825
4,Belgium,28M-March,30.62063


In [17]:
# Reshape to one row per country and one column per model, then drop the
# outer column level left by the pivot and round to two decimals.
df = (
    df.pivot(index='Country', columns='Model')
      .droplevel(0, axis=1)
      .round(2)
)
df.head()

Model,28M-March,end-April,more 1 week,more 2 weeks
Country,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
Austria,28.4,36.15,34.6,35.96
Belgium,30.62,70.39,43.7,49.16
Denmark,37.34,42.92,41.8,43.11
France,49.6,249.55,162.66,172.08
Germany,158.9,195.94,161.68,174.9


In [19]:
# For every country, the model (column label) holding the smallest value.
df.idxmin(axis='columns')


Country
Austria              28M-March
Belgium              28M-March
Denmark              28M-March
France               28M-March
Germany              28M-March
Italy             more 2 weeks
Norway               28M-March
Spain                28M-March
Sweden               28M-March
Switzerland          end-April
United Kingdom       28M-March
dtype: object

In [27]:
def bold_one(df, column_str):
    """Wrap in ``\\textbf{}`` the cells of one column that are their row's minimum.

    The frame is modified in place.  Every cell of ``column_str`` that equals
    the minimum of its row is replaced by the string
    ``\\textbf{<value with two decimals>}``; all other cells are left as they are.

    Parameters
    ----------
    df : pandas.DataFrame
        Table to modify.  The row minimum is taken over all columns of ``df``
        as they are at call time, so the frame should still be numeric.
    column_str : hashable
        Label of the column to process.

    Returns
    -------
    None
    """
    is_row_min = df[column_str] == df.min(axis=1)
    if not is_row_min.any():
        # nothing to mark; leave the column and its dtype untouched
        return
    # Strings cannot be stored in a float column without an upcast, which
    # newer pandas deprecates for .loc assignment; switch to object dtype first.
    df[column_str] = df[column_str].astype(object)
    df.loc[is_row_min, column_str] = [
        '\\textbf{' + '{:.2f}'.format(x) + '}' for x in df.loc[is_row_min, column_str]
    ]
def bold_all(df, columns):
    """Wrap each row's minimum in ``\\textbf{}``, restricted to the given columns.

    The frame is modified in place.  The column holding the minimum of every
    row is determined once, up front, with ``df.idxmin(axis=1)``; for each
    label in ``columns`` the cells of the rows whose minimum sits in that
    column are replaced by ``\\textbf{<value with two decimals>}``.

    Parameters
    ----------
    df : pandas.DataFrame
        Table to modify; it has to support ``idxmin`` across columns when the
        function is called.
    columns : iterable
        Column labels to process.

    Returns
    -------
    None
    """
    minidxs = df.idxmin(axis=1)
    for col in columns:
        holds_min = col == minidxs
        if not holds_min.any():
            # no row has its minimum here; keep the column numeric
            continue
        # Strings cannot be stored in a float column without an upcast, which
        # newer pandas deprecates for .loc assignment; switch to object first.
        df[col] = df[col].astype(object)
        df.loc[holds_min, col] = [
            '\\textbf{' + '{:.2f}'.format(x) + '}' for x in df.loc[holds_min, col]
        ]

In [28]:
# Mark the minimum of every row across all model columns (modifies df in place).
bold_all(df, df.columns.tolist())

In [29]:
# Show the whole table after bold_all has wrapped each row's minimum in \textbf{}.
df

Model,28M-March,end-April,more 1 week,more 2 weeks
Country,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
Austria,\textbf{28.40},36.15,34.6,35.96
Belgium,\textbf{30.62},70.39,43.7,49.16
Denmark,\textbf{37.34},42.92,41.8,43.11
France,\textbf{49.60},249.55,162.66,172.08
Germany,\textbf{158.90},195.94,161.68,174.9
Italy,233.07,82.29,81.12,\textbf{80.18}
Norway,\textbf{36.07},40.53,37.42,39.79
Spain,\textbf{59.54},143.16,123.56,129.57
Sweden,\textbf{25.91},39.51,26.6,31.11
Switzerland,72.97,\textbf{63.01},66.62,63.89


In [None]:
# Export the table.  `index` is a boolean switch; the header written for the
# index column is controlled by `index_label`.
df.to_csv('Table-WAIC-more-dates.csv', index=True, index_label='Country', escapechar='@', float_format="%.2f")