In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import pymc3 as pm
import requests
import scipy.stats as sps
import arviz as az

# Default credible-interval mass used by arviz summaries and HDI plots.
az.rcParams.update({"stats.hdi_prob": 0.89})

sns.set()

In [None]:
# Fetch per-day, per-country confirmed cases, deaths and stringency from the Oxford covidtracker API.
start_date = '2020-03-01'
end_date = '2020-08-27'

url = 'https://covidtrackerapi.bsg.ox.ac.uk/api/v2/stringency/date-range/{}/{}'.format(start_date, end_date)

r = requests.get(url, timeout=5.0)
# Fail here on HTTP errors instead of later with an obscure error in r.json() / data lookups.
r.raise_for_status()
r.status_code

In [None]:
# One entry per calendar day between the two dates (both included).
keys = pd.date_range(start=start_date, end=end_date)
keys

In [None]:
# Decode the response once; 'data' is indexed by date string, then by country
# (presumably the three-letter code, given the lookups below).
json = r.json()
data, countries = json['data'], json['countries']

In [None]:
# Flatten the nested payload into (date, country, confirmed, deaths, stringency) records.
# A date/country pair that is absent, or lacks one of the three fields, is skipped.
data_list = []

for day in keys:
    day_str = day.date().strftime('%Y-%m-%d')
    for code in countries:
        try:
            entry = data[day_str][code]
            record = (day_str, code, entry['confirmed'], entry['deaths'], entry['stringency'])
        except KeyError:
            continue
        data_list.append(record)

In [None]:
# Long-format frame indexed by (country, date); rows with any missing value are dropped.
df = (
    pd.DataFrame(data_list, columns=['date', 'country', 'confirmed', 'dead', 'oxford_stringency'])
    .assign(date=lambda frame: pd.to_datetime(frame['date']))
    .set_index(['country', 'date'])
    .dropna()
)

df

In [None]:
# Population per country name (';'-separated file, ',' as thousands separator).
population = pd.read_csv('world_pop.csv', sep=';', thousands=',', header=None, index_col=0)
population.columns = ['pop']
population = population.rename_axis('country')
population

In [None]:
# Series mapping three-letter code (file column 2, used as index) to country name (file column 0).
three_letter_abb = pd.read_csv('three_letter_country_abb.csv', sep=';', header=None, index_col=2)[0]
three_letter_abb.loc['GBR']

In [None]:
# Keep the country code as the index (needed for the join below); 'date' becomes a column again.
df = df.reset_index(level='date')

In [None]:
# Attach the country name by country-code index and give the new column a readable label.
df = df.join(three_letter_abb).rename(columns={0: 'country'})
df

In [None]:
# Add population by country name and normalise cumulative deaths to deaths per million inhabitants.
df = df.merge(population, left_on='country', right_on=population.index)
df = df.assign(dead_per_M=df['dead'] / (df['pop'] / 1e6))
df

In [None]:
# Reshape the code -> name Series into a two-column lookup table with columns 'abb' and 'country'.
three_letter_abb = (
    three_letter_abb
    .rename('country')
    .rename_axis('abb')
    .to_frame()
    .reset_index()
)

In [None]:
def standardize(series):
    """Return the z-score of ``series``: centred on its mean, scaled by its (sample) std.

    Missing values propagate and are ignored by the mean/std computation.
    """
    centred = series - series.mean()
    return centred / series.std()

In [None]:
# Re-attach the three-letter code by country name and add z-scored features.
df = df.merge(three_letter_abb, on='country')

# NOTE: the z-scores are computed over all rows, i.e. before the confirmed-cases filter below.
df = df.assign(
    dead_per_M_std=standardize(df['dead_per_M']),
    ox_std=standardize(df['oxford_stringency']),
)

df = df.dropna()

# Keep only days with more than 100 confirmed cases.
mask = df['confirmed'].gt(100)
df = df.loc[mask]

df

In [None]:
# Per-country slices for ad-hoc inspection.
swe, uk, us, bel, aus = (df.loc[df['abb'] == code] for code in ('SWE', 'GBR', 'USA', 'BEL', 'AUS'))

In [None]:
# Inspect the merged, standardized and filtered frame.
df

In [None]:

def _values_lagged_by_days(country, columns, days):
    """Return ``country[columns]`` re-aligned by 'date' so each row gets the values from ``days`` days earlier.

    Rows whose source date is absent get NaN. Alignment is by calendar date, not
    by row position, so gaps in the daily series do not silently mis-pair values.
    """
    source = country.drop_duplicates('date').set_index('date')[columns]
    source.index = source.index + pd.Timedelta(days=days)
    return source.reindex(country['date']).to_numpy()


def shift_and_merge(shift=0, data=None):
    """Build per-country death-increment features and lag stringency by ``shift`` days.

    For every country code in ``data['abb']`` this computes:
      * the daily increment of deaths (absolute and per million),
      * the day-over-day change of the per-million increment,
      * per-country z-scores of both.

    It then re-aligns 'oxford_stringency' and 'ox_std' by calendar date. The row
    for day ``d`` then carries the stringency of day ``d - shift``, i.e. deaths are
    paired with the stringency ``shift`` days earlier. ``shift=0`` keeps same-day
    pairing. Rows with any missing value are dropped: the first differences, and
    days whose lagged stringency is unavailable.

    Args:
        shift: lag in days between stringency and deaths.
        data: frame with columns 'abb', 'date', 'dead', 'dead_per_M',
            'oxford_stringency' and 'ox_std'. Defaults to the notebook-level ``df``.

    Returns:
        All per-country frames concatenated (an empty DataFrame when there are no rows).
    """
    frame = df if data is None else data
    lag_columns = ['oxford_stringency', 'ox_std']
    per_country = []

    for _, group in frame.groupby('abb'):
        country = group.copy()
        country['dead_inc'] = country['dead'] - country['dead'].shift()
        country['dead_per_M_inc'] = country['dead_per_M'] - country['dead_per_M'].shift()
        country['dead_per_M_inc_std'] = standardize(country['dead_per_M_inc'])
        country['dead_per_M_change'] = country['dead_per_M_inc'] - country['dead_per_M_inc'].shift()
        country['dead_per_M_change_std'] = standardize(country['dead_per_M_change'])
        if shift:
            country[lag_columns] = _values_lagged_by_days(country, lag_columns, shift)
        per_country.append(country.dropna())

    # Concatenate once: growing a frame with concat inside the loop is quadratic.
    return pd.concat(per_country) if per_country else pd.DataFrame()

In [None]:
# Lag in days passed to shift_and_merge; also used in the plot titles and file names below.
shift = 21

all_shifted = (
    shift_and_merge(shift=shift)
    .dropna()
    .reset_index(drop=True)
)

all_shifted

In [None]:
# Dense integer id per country code (in order of first appearance) plus the reverse lookup.
country_idx_map = {abb: idx for idx, abb in enumerate(all_shifted['abb'].unique())}

inv_map = {idx: abb for abb, idx in country_idx_map.items()}

all_shifted['country_idx'] = all_shifted['abb'].map(country_idx_map)
all_shifted

In [None]:
# Column-wise maxima and minima of the Swedish rows.
swe = all_shifted[all_shifted['abb'].eq('SWE')]
for extreme in (swe.max(), swe.min()):
    print(extreme)

In [None]:
# Reverse lookup: integer country index -> three-letter code.
inv_map

In [None]:
# Country slices for Sweden and Peru.
swe, per = (all_shifted.loc[all_shifted['abb'] == code] for code in ('SWE', 'PER'))
swe

In [None]:
# Hierarchical (partially pooled) linear model on standardized values.
#
# MODEL
# dead_per_M_inc_std ~ Normal(reg, obs_sigma)
# reg = alpha[country_idx] + beta[country_idx] * ox_std
# alpha ~ Normal(alpha_bar, alpha_sd)
# alpha_bar ~ Normal(0, 1)
# alpha_sd ~ Exponential(1)
# beta ~ Normal(beta_bar, beta_sd)
# beta_bar ~ Normal(0, 1)
# beta_sd ~ Exponential(1)
# obs_sigma ~ Exponential(1)
#
# Number of draws / tuning steps, and a fixed seed so a re-run reproduces the posterior.
N_DRAWS = 500
N_TUNE = 500
SAMPLER_SEED = 1234

x = all_shifted['ox_std'].values
country_idx = all_shifted['country_idx'].values
n_countries = len(country_idx_map)

model = pm.Model()

with model:
    obs_sigma = pm.Exponential('obs_sigma', 1)

    alpha_bar = pm.Normal('alpha_bar', mu=0, sigma=1)
    alpha_sd = pm.Exponential('alpha_sd', 1)

    beta_bar = pm.Normal('beta_bar', mu=0, sigma=1)
    beta_sd = pm.Exponential('beta_sd', 1)

    # One intercept and one slope per country, drawn around the shared means.
    alpha = pm.Normal('alpha', mu=alpha_bar, sigma=alpha_sd, shape=n_countries)
    beta = pm.Normal('beta', mu=beta_bar, sigma=beta_sd, shape=n_countries)

    reg = alpha[country_idx] + beta[country_idx] * x

    lkh = pm.Normal('lkh', mu=reg, sigma=obs_sigma, observed=all_shifted['dead_per_M_inc_std'])

    trace = pm.sample(N_DRAWS, tune=N_TUNE, random_seed=SAMPLER_SEED)

    result = pm.trace_to_dataframe(trace)
    summary = az.summary(trace)

summary

In [None]:
# Trace plots for the sampled variables; run inside the model context
# (presumably so arviz can pick up `model` when converting the trace).
with model:
    az.plot_trace(trace)

In [None]:
# Summary statistics of the posterior draws collected by pm.trace_to_dataframe.
result.describe()

In [None]:
def param_type(colName):
    """Map a posterior-draw column name to the parameter family it belongs to.

    Hyper-parameter columns (name contains 'bar' or 'sd') keep their own name.
    Otherwise a name containing 'beta' maps to 'beta', then one containing
    'alpha' maps to 'alpha' (checked in that order). Anything else is returned
    unchanged.
    """
    if 'bar' in colName or 'sd' in colName:
        return colName
    family = next((name for name in ('beta', 'alpha') if name in colName), None)
    return colName if family is None else family
    
# Split the posterior draws column-wise by parameter family.
# NOTE(review): groupby(..., axis=1) is deprecated in pandas >= 2.1; selecting columns
# by param_type would be the forward-compatible alternative.
grp = result.groupby(param_type, axis=1)

alphas, betas = grp.get_group('alpha'), grp.get_group('beta')
obs_sigma = grp.get_group('obs_sigma')
alpha_bar, alpha_sd = grp.get_group('alpha_bar'), grp.get_group('alpha_sd')
beta_bar, beta_sd = grp.get_group('beta_bar'), grp.get_group('beta_sd')


In [None]:
# Per-country intercept draws.
alphas

In [None]:
# Resample posterior draws (with replacement) and evaluate each sampled regression line
# alpha + beta * X on a grid of standardized stringency values.
nr_rows = 500
nr_x = 100
LINE_SAMPLE_SEED = 42  # fixed so the sampled lines and HDI bands are reproducible

rng = np.random.default_rng(LINE_SAMPLE_SEED)
rows = rng.choice(len(alphas), replace=True, size=nr_rows)

X = np.linspace(-2, 2, nr_x)

alpha_samples = alphas.iloc[rows]
beta_samples = betas.iloc[rows]

# lines[i, s, k] = alpha_samples[s, k] + X[i] * beta_samples[s, k]
# shape (nr_x, nr_rows, n_countries); broadcasting replaces the Python loop over X.
lines = X[:, None, None] * beta_samples.values[None, :, :] + alpha_samples.values

lines.shape

In [None]:
# Sampled posterior regression lines for Sweden; the low alpha lets line density show the uncertainty.
_= plt.plot(X,lines[:,:,country_idx_map['SWE']],color='r',alpha=0.01)

In [None]:
# Posterior-mean slope per country, labelled by three-letter code.
# Assumes betas' columns are in country-index order (beta__0, beta__1, ...) — TODO confirm.
mean_betas = pd.Series(betas.mean().to_numpy(), index=list(country_idx_map))
mean_betas.loc['SWE']

In [None]:
# Posterior-mean intercept per country, labelled by three-letter code.
# Assumes alphas' columns are in country-index order (alpha__0, alpha__1, ...) — TODO confirm.
mean_alphas = pd.Series(alphas.mean().to_numpy(), index=list(country_idx_map))
mean_alphas.loc['SWE']

In [None]:
# Posterior-mean slopes sorted ascending: positions 0-49.
ax = mean_betas.sort_values().iloc[:50].plot.bar(figsize=(18, 12))
ax.figure.savefig('ox_hierarchical_0_50_shift_{}.jpg'.format(shift), format='jpg')

In [None]:
# Posterior-mean slopes sorted ascending: positions 50-99.
ax = mean_betas.sort_values().iloc[50:100].plot.bar(figsize=(18, 12))
ax.figure.savefig('ox_hierarchical_50__100_shift_{}.jpg'.format(shift), format='jpg')

In [None]:
# Posterior-mean slopes sorted ascending: positions 100 onwards.
ax = mean_betas.sort_values().iloc[100:].plot.bar(figsize=(18, 12))
ax.figure.savefig('ox_hierarchical_100_shift_{}.jpg'.format(shift), format='jpg')

In [None]:
def plot_reg(country_abb):
    """Scatter one country's data with its sampled posterior regression lines and HDI band.

    Plots standardized stringency ('ox_std') against the standardized daily
    increment of deaths per million ('dead_per_M_inc_std') for ``country_abb``
    (a three-letter code in ``all_shifted``). It overlays that country's sampled
    lines from the notebook-level ``lines`` array (evaluated on ``X``), shades
    their 89% HDI, and saves the figure as
    ``ox_hierarchical_reg_<abb>_shift_<shift>.jpg``.

    Returns:
        The matplotlib Axes that was drawn on.
    """
    fill_kwargs = {'alpha': 0.3}
    country_lines = lines[:, :, country_idx_map[country_abb]]

    country = all_shifted.loc[all_shifted['abb'] == country_abb]

    ax = country.plot(x='ox_std', y='dead_per_M_inc_std', style='o', figsize=(18, 12),
                      title=country_abb.upper())

    ax.plot(X, country_lines, color='r', alpha=0.01)

    az.plot_hdi(X, country_lines.T, ax=ax, hdi_prob=0.89,
                fill_kwargs=fill_kwargs, color='m')

    # The target is the daily increment itself; its day-over-day change is a
    # different column (dead_per_M_change_std), so the old label was misleading.
    ax.set_xlabel('ox_stringency [standardized]')
    ax.set_ylabel('daily increment of deaths per Million [standardized]')
    ax.figure.savefig('ox_hierarchical_reg_{}_shift_{}.jpg'.format(country_abb, shift), format='jpg')
    return ax

# Countries highlighted in the individual and grid regression plots.
country_abbs = ['SWE','ARG','PER','GBR','USA','ITA','ISR','ZAF','URY','CHL','BEL','ESP','CRI','DNK','FIN']

for abb in country_abbs:
    plot_reg(abb)


In [None]:
# Label per-country draws by three-letter code (assumes columns are in country-index order)
# and summarise them: posterior means and 89% equal-tailed intervals (5.5% / 94.5% quantiles).
betas.columns = alphas.columns = list(country_idx_map)

beta_CIs, alpha_CIs = (draws.quantile([0.055, 0.945]) for draws in (betas, alphas))

beta_means, alpha_means = betas.mean(), alphas.mean()
alpha_CIs.loc[0.055, 'SWE']

In [None]:
# Lower bound (5.5% quantile) of the intercept interval for a single country.
ca = 'SWE'
alpha_CIs.loc[0.055,ca]

In [None]:
# Grid of per-country regression panels, filled row by row in the order of `country_abbs`.
# (Indexing with axes[i % rows, i % cols] scrambled the order and only avoided collisions
# because 5 and 3 happen to be coprime.)
n_cols = 3
n_rows = -(-len(country_abbs) // n_cols)  # ceiling division: enough rows for every country

fill_kwargs = {'alpha': 1}

X = np.linspace(-2, 2, nr_x)  # same grid the sampled `lines` were evaluated on

fig, axes = plt.subplots(n_rows, n_cols, sharex=True, sharey=True, figsize=(18, 12), squeeze=False)
flat_axes = axes.ravel()

for ax, ca in zip(flat_axes, country_abbs):

    country = all_shifted.loc[all_shifted['abb'] == ca]
    country_lines = lines[:, :, country_idx_map[ca]]

    label = r'$\alpha$: {:.2f} CI: {:.2f}..{:.2f} $\beta$: {:.2f} CI: {:.2f}..{:.2f}'.format(
        alpha_means[ca], alpha_CIs.loc[0.055, ca], alpha_CIs.loc[0.945, ca],
        beta_means[ca], beta_CIs.loc[0.055, ca], beta_CIs.loc[0.945, ca])

    ax.plot(country['ox_std'], country['dead_per_M_inc_std'], 'o', color='crimson')
    ax.plot(X, country_lines, color='orange', alpha=0.01)
    ax.plot(X, alpha_means[ca] + X * beta_means[ca], color='k', ls='dashed', label=label)

    ax.set_title(ca)
    ax.set_ylabel('inc dead per M [std]')
    ax.set_xlabel('Oxford Stringency [std]')
    ax.legend(loc='upper left')

    # plot_hdi replaces the deprecated plot_hpd alias (same call as in plot_reg).
    az.plot_hdi(X, country_lines.T, ax=ax, color='m', fill_kwargs=fill_kwargs, hdi_prob=0.89)

# Hide any panels left over when the grid has more cells than countries.
for ax in flat_axes[len(country_abbs):]:
    ax.set_visible(False)

fig.suptitle('Regression Oxford Index --> Daily Inc deaths per M, deaths shifted {} days\n standardized values'.format(shift))

# Leave room at the top so the two-line suptitle does not overlap the first row of panels.
fig.tight_layout(rect=(0, 0, 1, 0.95))
fig.savefig('ox_hierarchical_multi_shifted_{}.jpg'.format(shift), format='jpg')

In [None]:
# Costa Rica's rows after feature engineering and standardization.
all_shifted.loc[all_shifted['abb'] == 'CRI']

In [None]:
# Histogram of Costa Rica's posterior slope draws.
betas['CRI'].plot(kind='hist')