In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import pymc3 as pm
import requests
import scipy.stats as sps
import arviz as az

az.rcParams["stats.hdi_prob"] = 0.89  # sets default credible interval used by arviz

sns.set()

In [None]:
start_date = '2020-03-01'
end_date = '2020-11-01'

url = 'https://covidtrackerapi.bsg.ox.ac.uk/api/v2/stringency/date-range/{}/{}'.format(start_date,end_date)

r = requests.get(url,timeout=5.0)
r.status_code

In [None]:
keys = pd.date_range(start_date,end_date)
keys

In [None]:
json = r.json()
data = json['data']
countries = json['countries']

In [None]:
data_list = []

for k in keys:
    date = k.date().strftime('%Y-%m-%d')
    for c in countries:
        try:
            data_list.append((date,c,data[date][c]['confirmed'],data[date][c]['deaths'],data[date][c]['stringency']))
        except KeyError:
            pass

In [None]:
df = pd.DataFrame(data_list,columns=['date','country','confirmed','dead','oxford_stringency'])
df['date'] = pd.to_datetime(df['date'])
df.set_index(['country','date'],inplace=True)
df.dropna(inplace=True)

df

In [None]:
population = pd.read_csv('world_pop.csv',sep=';',thousands=',',header=None,index_col=0)
population.columns=['pop']
population.index.name='country'
population

In [None]:
three_letter_abb = pd.read_csv('three_letter_country_abb.csv',sep=';',header=None,index_col=2)
three_letter_abb = three_letter_abb[0]
three_letter_abb.loc['GBR']

In [None]:
df = df.reset_index(level=1)

In [None]:
df = df.join(three_letter_abb)
df.rename(columns={0 : 'country'},inplace=True)
df

In [None]:
df = df.merge(population,left_on='country',right_on=population.index)
df['dead_per_M'] = df['dead'] / (df['pop'] / 1e6)
df

In [None]:
three_letter_abb.name='country'
three_letter_abb.index.name='abb'
three_letter_abb = pd.DataFrame(three_letter_abb)
three_letter_abb.columns=['country']
three_letter_abb.reset_index(inplace=True)

In [None]:
def standardize(series):
    """Return the z-score of ``series``: centre on its mean, scale by its std.

    Uses pandas' default sample standard deviation (ddof=1). A constant series
    therefore yields NaN (0/0), and a single-element series yields NaN too.
    """
    centred = series - series.mean()
    spread = series.std()
    return centred / spread

In [None]:
df = df.merge(three_letter_abb,left_on='country',right_on='country')

df['dead_per_M_std'] = standardize(df['dead_per_M'])
df['ox_std'] = standardize(df['oxford_stringency'])

df.dropna(inplace=True)

mask = (df['confirmed'] > 100)
df = df[mask]

df

In [None]:
swe = df.loc[df['abb'] == 'SWE']
uk = df.loc[df['abb'] == 'GBR']
us = df.loc[df['abb'] == 'USA']
bel = df.loc[df['abb'] == 'BEL']
aus = df.loc[df['abb'] == 'AUS']
isr = df.loc[df['abb'] == 'ISR']

In [None]:
df

In [None]:
isr.tail(40)

In [None]:

def shift_and_merge(shift=0, data=None):
    """Derive per-country daily death increments, lag them, and stack all countries.

    For each country (grouped by the ``abb`` column) these columns are added:

    * ``ox_mean``: the country's mean ``oxford_stringency``;
    * ``dead_inc`` / ``dead_per_M_inc``: row-over-row differences of ``dead`` /
      ``dead_per_M``;
    * ``dead_per_M_change``: row-over-row difference of ``dead_per_M_inc``;
    * ``dead_per_M_inc_std`` / ``dead_per_M_change_std``: the above standardized
      *within the country* and then shifted by ``shift`` rows;
    * ``day_idx``: 0-based running row counter, assigned after NaN rows are dropped.

    Every row containing a NaN (the leading difference rows and the rows vacated
    by the shift) is dropped per country.

    Parameters
    ----------
    shift : int, default 0
        Passed to ``Series.shift`` for the two standardized death columns. It
        counts rows, which equal days only if a country's dates are contiguous.
    data : pandas.DataFrame, optional
        Frame to process; defaults to the notebook-level ``df``.

    Returns
    -------
    pandas.DataFrame
        All countries concatenated in sorted ``abb`` order with the original
        index kept; an empty frame when ``data`` has no countries.
    """
    if data is None:
        data = df

    per_country = []
    for _, group in data.groupby('abb'):
        country = group.copy()
        country['ox_mean'] = country['oxford_stringency'].mean()

        country['dead_inc'] = country['dead'] - country['dead'].shift()
        country['dead_per_M_inc'] = country['dead_per_M'] - country['dead_per_M'].shift()
        country['dead_per_M_inc_std'] = standardize(country['dead_per_M_inc'])
        country['dead_per_M_change'] = country['dead_per_M_inc'] - country['dead_per_M_inc'].shift()
        country['dead_per_M_change_std'] = standardize(country['dead_per_M_change'])

        # NOTE(review): a positive ``shift`` moves values DOWN, so each row carries
        # the deaths from ``shift`` rows earlier, i.e. stringency is paired with
        # EARLIER deaths. If the intent is "stringency today -> deaths ``shift``
        # days later", this should be ``.shift(-shift)`` -- confirm. Only the
        # standardized columns are shifted; raw ``dead_per_M_inc`` is not.
        country['dead_per_M_change_std'] = country['dead_per_M_change_std'].shift(shift)
        country['dead_per_M_inc_std'] = country['dead_per_M_inc_std'].shift(shift)
        country.dropna(inplace=True)

        country['day_idx'] = range(len(country))
        per_country.append(country)

    # Concatenate once rather than growing a frame inside the loop (quadratic copying).
    if not per_country:
        return pd.DataFrame()
    return pd.concat(per_country)

In [None]:
### PARAM ###

shift = 21

all_shifted = shift_and_merge(shift=shift)

all_shifted.dropna(inplace=True)
all_shifted.reset_index(inplace=True,drop=True)

all_shifted

In [None]:
isr.head()

In [None]:
isr_shifted = all_shifted.loc[all_shifted['country'] == 'Israel']
swe_shifted = all_shifted.loc[all_shifted['country'] == 'Sweden']
isr_shifted.head()
swe_shifted.tail()

In [None]:
country_idx_map = dict(zip(all_shifted['abb'].unique(),range(len(all_shifted['abb']))))

inv_map = {v: k for k, v in country_idx_map.items()}

all_shifted['country_idx'] = all_shifted['abb'].apply(lambda x : country_idx_map[x])
all_shifted

In [None]:
swe = all_shifted.loc[all_shifted['abb'] == 'SWE']
print (swe.max())
print (swe.min())

In [None]:
inv_map

In [None]:
swe = all_shifted.loc[all_shifted['abb'] == 'SWE']
per = all_shifted.loc[all_shifted['abb'] == 'PER']
swe

In [None]:
# hierarchical model, conditioned on country (varying intercept and slope per country)

# all values are standardized

# MODEL
# dead_per_M_inc_std ~ Normal(reg, obs_sigma)
# reg = alpha[country_idx] + beta[country_idx] * ox_std
# obs_sigma ~ Exponential(1)
# alpha ~ Normal(alpha_bar,alpha_sd)
# alpha_bar ~ Normal(0,1)
# alpha_sd ~ Exponential(1)
# beta ~ Normal(beta_bar,beta_sd)
# beta_bar ~ Normal(0,1)
# beta_sd ~ Exponential(1)
#
x = all_shifted['ox_std'].values
country_idx = all_shifted['country_idx'].values

model = pm.Model()

summary = pd.DataFrame()
result = pd.DataFrame()

with model:
    
    
    obs_sigma = pm.Exponential('obs_sigma',1)

    alpha_bar = pm.Normal('alpha_bar',mu=0,sd=1)
    alpha_sd = pm.Exponential('alpha_sd',1)

    beta_bar = pm.Normal('beta_bar',mu=0,sd=1)
    beta_sd = pm.Exponential('beta_sd',1)

    alpha = pm.Normal('alpha',mu=alpha_bar,sd=alpha_sd,shape=len(country_idx_map))
    beta = pm.Normal('beta',mu=beta_bar,sd=beta_sd,shape=len(country_idx_map))

    reg = pm.Deterministic('reg',alpha[country_idx] + beta[country_idx] * x)
    
    lkh = pm.Normal('lkh',mu=reg,sd=obs_sigma,observed=all_shifted['dead_per_M_inc_std'])
    
    trace = pm.sample(500,tune=500)
    
    result = pm.trace_to_dataframe(trace)
    summary = az.summary(trace)


In [None]:
summary

In [None]:
trace['reg'].shape

In [None]:
np.array((range(4),range(4)))

In [None]:
keys = np.array(list(country_idx_map.keys()))
keys

In [None]:
with model:
    idata = az.from_pymc3(trace,coords={'country_idx': keys,
                                        'reg_idx' : range(trace['reg'].shape[1])}, 
    dims={'alpha': ['country_idx'], 'beta': ['country_idx'],'reg' : ['reg_idx']})
    
    _=az.plot_pair(idata,point_estimate='mean',textsize=12,
               marginals=True,figsize=(18,12))
    
    plt.close()
 
 
idata

In [None]:
with model:
    az.plot_trace(trace)
    plt.close()

In [None]:

az.plot_forest(idata,var_names=['beta'],coords={'country_idx' : keys},
               combined=True,rope=[-0.01,0.01])
plt.close()

In [None]:
result.describe()

In [None]:
def param_type(colName):
    """Map a posterior-dataframe column name to the parameter group it belongs to.

    Hyperparameter columns (any name containing ``'bar'`` or ``'sd'``) keep
    their own name. Otherwise a name containing ``'beta'`` collapses to
    ``'beta'`` and, failing that, one containing ``'alpha'`` collapses to
    ``'alpha'``, so per-element columns (e.g. ``alpha__3``) share one group.
    Anything else (e.g. ``obs_sigma``) is returned unchanged.
    """
    is_hyper = ('bar' in colName) or ('sd' in colName)
    if is_hyper:
        return colName
    # order matters: 'beta' is tested before 'alpha'
    for family in ('beta', 'alpha'):
        if family in colName:
            return family
    return colName
    
grp = result.groupby(param_type,axis=1)

alphas = grp.get_group('alpha')
betas = grp.get_group('beta')
obs_sigma = grp.get_group('obs_sigma')
alpha_bar = grp.get_group('alpha_bar')
alpha_sd = grp.get_group('alpha_sd')
beta_bar = grp.get_group('beta_bar')
beta_sd = grp.get_group('beta_sd')


In [None]:
alphas

In [None]:
nr_rows = 500
nr_x = 100

rows = np.random.choice(range(len(alphas)),replace=True,size=nr_rows)

X = np.linspace(-3,3,nr_x)

alpha_samples = alphas.iloc[rows]
beta_samples = betas.iloc[rows]

lines = np.array([X[i] * beta_samples for i in range(len(X))])

lines = lines + alpha_samples.values
#lines[:,:,country_idx_map['SWE']]
lines[:,:,0]

In [None]:
_= plt.plot(X,lines[:,:,country_idx_map['SWE']],color='r',alpha=0.01)

plt.close()

In [None]:
# Mean standardized stringency per country. Select the column before
# aggregating: ``groupby(...).mean()`` over the whole frame also tries the
# non-numeric ``country`` column, which raises TypeError in pandas >= 2.0.
ox_std_means = all_shifted.groupby('abb')['ox_std'].mean()

In [None]:
mean_betas = betas.mean()
mean_betas.index = list(country_idx_map.keys())
mean_betas

In [None]:
mean_alphas = alphas.mean()
mean_alphas.index = list(country_idx_map.keys())
mean_alphas.loc['SWE']

In [None]:
all_means = pd.concat([ox_std_means,mean_alphas,mean_betas],axis=1)
all_means.columns = ['ox_std_mean','alpha','beta']
all_means

In [None]:
all_means.sort_values('beta').plot(y=['beta','ox_std_mean'],figsize=(18,12))
plt.xticks(rotation=90)
 

In [None]:
all_means.sort_values('ox_std_mean',ascending=False).head(20)

In [None]:
# regression ox_mean --> beta mean

model2 = pm.Model()
with model2:
    
    x2 = all_means.sort_values('beta')['ox_std_mean'].values
    y2 = all_means.sort_values('beta')['beta']
    
    alpha2 = pm.Normal('alpha2',0,2)
    beta2 = pm.Normal('beta2',0,2)
    
    sigma2 = pm.Exponential('sigma2',1)
    
    reg2 = alpha2 + beta2 * x2
    
    obs2 = pm.Normal('obs2',reg2,sigma2,observed=y2)
    
    trace2 = pm.sample(500,tune=500)
    
    summary2 = az.summary(trace2)
    result2 = pm.trace_to_dataframe(trace2)
    
    az.plot_trace(trace2)
    
    print (summary2)
    
     

In [None]:
with model2:
    idata = az.from_pymc3(trace2)
    _=az.plot_pair(idata,point_estimate='mean',textsize=12,
               marginals=True,figsize=(18,12))
    
     

In [None]:
xx = x2


with model2:
    ppc2 = pm.sample_posterior_predictive(trace2,500,model2,var_names=['alpha2',
                                                                    'beta2','obs2'])
    
    mu_CI = ppc2['alpha2'] + ppc2['beta2'] * xx[:,None]
    

In [None]:
import scipy.stats as sps

slope,intercept,_,_,_ = sps.linregress(all_means.sort_values('beta')['ox_std_mean'],
                                      all_means.sort_values('beta')['beta'])

lsq = intercept + slope * xx 

plt.figure(figsize=(18,12))
plt.title('avg. OXIDX vs change in daily inc deaths per million\n"countries with low OXIDX mean have + slope, countries with high OXIDX mean have - slope"')
plt.scatter(all_means.sort_values('beta')['ox_std_mean'],
            all_means.sort_values('beta')['beta'],color='r')

plt.plot(xx,summary2.loc['alpha2','mean'] + summary2.loc['beta2','mean'] * xx, ls='--',color='r')

ax = plt.gca()
az.plot_hdi(x=xx,y=mu_CI.T,ax=ax,hdi_prob=0.89)

az.plot_hpd(x=xx,y=ppc2['obs2'],ax=ax,hdi_prob=0.89)

plt.plot(xx,lsq,'--',color='k')



plt.xlabel('ox_std_mean')
plt.ylabel('beta')

 

In [None]:
ax = all_means.sort_values('beta')[:50].plot(y='beta',kind='bar',figsize=(18,12))

ax2 = plt.twinx()

all_means.sort_values('beta')[:50].plot(ax=ax2,y='ox_std_mean',style='ro--',)

plt.savefig('ox_hierarchical_0_50_shift_{}.jpg'.format(shift),format='jpg')
 

In [None]:
mean_betas.sort_values()[50:100].plot(kind='bar',figsize=(18,12))
plt.savefig('ox_hierarchical_50__100_shift_{}.jpg'.format(shift),format='jpg')
 

In [None]:
mean_betas.sort_values()[100:].plot(kind='bar',figsize=(18,12))
plt.savefig('ox_hierarchical_100_shift_{}.jpg'.format(shift),format='jpg')
 

In [None]:
mean_alphas

In [None]:
betas.columns = list(country_idx_map.keys())
alphas.columns = list(country_idx_map.keys())

beta_CIs = betas.quantile([0.055,0.945])
alpha_CIs = alphas.quantile([0.055,0.945])

beta_means = betas.mean()
alpha_means = alphas.mean()
alpha_CIs.loc[0.055,'SWE']

In [None]:
def plot_reg(country_abb):
    """Plot one country's data together with the hierarchical model's fit.

    For the country whose ``abb`` equals ``country_abb`` this draws:

    * a scatter of ``ox_std`` vs ``dead_per_M_inc_std`` for that country's rows,
    * the 89% HDI band of the ``reg`` deterministic over those rows,
    * the posterior-mean line ``mean_alphas + X * mean_betas``, labelled with
      the 89% intervals from ``alpha_CIs`` / ``beta_CIs``,
    * dashed orange lines at the country's mean x and mean y,

    and saves the figure as ``ox_hierarchical_reg_<abb>_shift_<shift>.jpg``.

    Relies on notebook globals: ``all_shifted``, ``trace``, ``mean_alphas``,
    ``mean_betas``, ``alpha_CIs``, ``beta_CIs`` and ``shift``.
    """
    # Pooled mean/sd of the raw quantities, printed in the axis labels.
    # NOTE(review): ``ox_std`` was standardized on ``df`` before the
    # ``confirmed > 100`` filter, and ``dead_per_M_inc_std`` is standardized
    # *within each country* in shift_and_merge, so these pooled values do not
    # exactly undo the standardization of the plotted axes -- confirm intent.
    global_ox_mu = all_shifted['oxford_stringency'].mean()
    global_ox_sd = all_shifted['oxford_stringency'].std()
    global_dead_inc_mu = all_shifted['dead_per_M_inc'].mean()
    global_dead_inc_sd = all_shifted['dead_per_M_inc'].std()

    fill_kwargs = {'alpha': 0.3}

    in_country = all_shifted['abb'] == country_abb
    country = all_shifted.loc[in_country]

    X = np.linspace(country['ox_std'].min(), country['ox_std'].max(), 100)

    ax = country.plot(x='ox_std', y='dead_per_M_inc_std', style='o',
                      figsize=(18, 12), title=country_abb.upper())

    label = r'$\alpha$ {:.2f} $\alpha$-CI {:.2f} {:.2f} $\beta$ {:.2f} $\beta$-CI {:.2f} {:.2f}'.format(
        mean_alphas[country_abb], alpha_CIs.loc[0.055, country_abb], alpha_CIs.loc[0.945, country_abb],
        mean_betas[country_abb], beta_CIs.loc[0.055, country_abb], beta_CIs.loc[0.945, country_abb])

    # ``reg`` has one column per row of ``all_shifted`` (built from
    # ``all_shifted['ox_std'].values`` after its index was reset), so the
    # boolean row mask selects exactly this country's columns.
    country_reg = trace['reg'][:, in_country.values]
    az.plot_hdi(country['ox_std'], country_reg, hdi_prob=0.89,
                fill_kwargs=fill_kwargs, ax=ax)

    ax.plot(X, mean_alphas[country_abb] + X * mean_betas[country_abb],
            '--', color='k', label=label)

    ax.axvline(country['ox_std'].mean(), color='orange', ls='dashed')
    ax.axhline(country['dead_per_M_inc_std'].mean(), color='orange', ls='dashed')

    ax.set_ylim([-3, 3])
    ax.legend(loc='upper center')

    ax.set_xlabel(r'ox_stringency [standardized] Global $\mu$: {:.2f} Global $\sigma$: {:.2f}'.format(
        global_ox_mu, global_ox_sd))
    # the plotted y column is the daily increment itself, not its change
    ax.set_ylabel(r'daily increment deaths per Million [standardized] Global $\mu$: {:.2f} Global $\sigma$: {:.2f}'.format(
        global_dead_inc_mu, global_dead_inc_sd))

    ax.figure.savefig('ox_hierarchical_reg_{}_shift_{}.jpg'.format(country_abb, shift), format='jpg')
     

country_abbs = ['SWE','ARG','PER','GBR','USA','ITA','ISR','ZAF','URY','CHL','BEL','ESP','AUS','DNK','BRA']

country_abbs_high = ['ARG','BOL','BRA','CHL','COL','DOM','HND','IND','IRQ','KAZ','OMN','PAN','PER','QAT','ZAF']

country_abbs_low = ['BLR','EST','HRV','JPN','LTU','MUS','NER','NIC','NOR','NZL','SMR','SOM','TUN','TZA','SWE']

country_abbs = country_abbs_high


for c in country_abbs:
    plot_reg(c)



In [None]:
betas.columns = list(country_idx_map.keys())
alphas.columns = list(country_idx_map.keys())

beta_CIs = betas.quantile([0.055,0.945])
alpha_CIs = alphas.quantile([0.055,0.945])

beta_means = betas.mean()
alpha_means = alphas.mean()
alpha_CIs.loc[0.055,'SWE']

In [None]:
ca = 'SWE'
alpha_CIs.loc[0.055,ca]

In [None]:
# Grid of per-country fits from the first hierarchical model: data points,
# the sampled posterior regression lines (orange), the posterior-mean line
# (black, dashed) and the 89% HDI of the sampled lines (magenta).
r,c = 5,3 #rows,cols

fill_kwargs = {'alpha': 1}

# One panel per country; extra countries would otherwise be silently lost.
assert len(country_abbs) <= r * c, 'more countries than subplot panels'

fig,axes = plt.subplots(r,c,sharex=True,sharey=True,figsize=(18,12))

# Same x grid the sampled ``lines`` were evaluated on (-3..3, nr_x points).
X = np.linspace(-3,3,nr_x)

# ``axes.flat`` fills the grid row by row. Indexing with ``axes[i % r, i % c]``
# only reaches every panel when r and c are coprime, and scrambles the order.
for ca, ax in zip(country_abbs, axes.flat):

    country = all_shifted.loc[all_shifted['abb'] == ca]
    k = country_idx_map[ca]

    label = r'$\alpha$: {:.2f} CI: {:.2f}..{:.2f} $\beta$: {:.2f} CI: {:.2f}..{:.2f}'.format(
        alpha_means[ca],alpha_CIs.loc[0.055,ca],alpha_CIs.loc[0.945,ca],
        beta_means[ca],beta_CIs.loc[0.055,ca],beta_CIs.loc[0.945,ca])

    ax.plot(country['ox_std'],country['dead_per_M_inc_std'],'o',color='crimson')
    ax.plot(X,lines[:,:,k],color='orange',alpha=0.01)
    ax.plot(X,alpha_means[ca] + X * beta_means[ca],color='k',ls='dashed',label=label)

    ax.set_title(ca)
    ax.set_ylabel('inc dead per M [std]')
    ax.set_xlabel('Oxford Stringency [std]')
    ax.legend(loc='upper left')

    # plot_hdi replaces ArviZ's deprecated plot_hpd alias
    az.plot_hdi(X, lines[:,:,k].T, ax=ax, color='m', fill_kwargs=fill_kwargs, hdi_prob=0.89)

fig.suptitle('Regression Oxford Index --> Daily Inc deaths per M, deaths shifted {} days\n standardized values'.format(shift))

fig.tight_layout()
fig.savefig('ox_hierarchical_multi_shifted_{}.jpg'.format(shift),format='jpg')
 

In [None]:
all_shifted.loc[all_shifted['abb'] == 'CRI']

In [None]:
betas['CRI'].plot(kind='hist')

In [None]:
all_shifted

In [None]:
swe = all_shifted.loc[all_shifted['country'] == 'Belgium']
swe

In [None]:
plt.scatter(swe['ox_std'],swe['dead_per_M_inc'])

In [None]:
all_shifted

In [None]:
day_idx_map = all_shifted['day_idx'].unique()

In [None]:

plt.figure(figsize=(18,12))
plt.title('Oxford Index vs daily increment deaths per million')
x = all_shifted['ox_std']
plt.scatter(x,all_shifted['dead_per_M_inc_std'],alpha=0.1,label='data',color='r')
plt.ylim([-3,3])
plt.ylabel('daily inc deaths per million [standardized]')
plt.xlabel('Oxford Index [standardized]')
all_shifted.describe()
plt.savefig('ox_index_pure_data.jpg',format='jpg')

In [None]:
# continuous condition, no interaction, cmp model 9

x = all_shifted['ox_std'].values
country_idx = standardize(all_shifted['country_idx']).values
day_idx = standardize(all_shifted['day_idx'].astype(int)).values

model3 = pm.Model()

summary3 = pd.DataFrame()
result3 = pd.DataFrame()

with model3:
    
    obs_sigma = pm.Exponential('obs_sigma',2)

    alpha_bar = pm.Normal('alpha_bar',mu=0,sd=1)
    alpha_sd = pm.Exponential('alpha_sd',1)

    beta_bar = pm.Normal('beta_bar',mu=0,sd=1,shape=3)
    beta_sd = pm.Exponential('beta_sd',1,shape=3)
    
    beta = pm.Normal('beta',mu=beta_bar[0],sd=beta_sd[0])
    beta2 = pm.Normal('beta2',beta_bar[1],beta_sd[1])
    beta3 = pm.Normal('beta3',beta_bar[2],beta_sd[2])
    
    alpha = pm.Normal('alpha',mu=alpha_bar,sd=alpha_sd)
    

    reg = pm.Deterministic('reg',alpha + beta * x + beta2 * day_idx + beta3 * country_idx)
    
    lkh = pm.Normal('lkh',mu=reg,sd=obs_sigma,observed=all_shifted['dead_per_M_inc_std'])
    
    trace3 = pm.sample(500,tune=500,target_accept=0.99)
    
    result3 = pm.trace_to_dataframe(trace3)
    summary3 = az.summary(trace3,var_names=['~reg'])
    print (summary3)

In [None]:
trace3['beta'].mean()

In [None]:
with model3:
    idata = az.from_pymc3(trace3)
    
    _=az.plot_pair(idata,var_names=['~reg'],point_estimate='mean',textsize=12,
               marginals=True,figsize=(18,12))
    
    _ = az.plot_forest(idata,var_names=['~reg'],hdi_prob=0.89)
     

In [None]:
ppc3 = pm.sample_posterior_predictive(trace3,model=model3)
 

In [None]:
fill_kwargs = {'label':'ppc 89%',
            'color':'m',
               'alpha':0.2}

fill_kwargs2 = {'label':'CI 89%',
            'color':'orange',
               'alpha':0.2}


x = all_shifted['ox_std'].values
xx = np.linspace(all_shifted['ox_std'].min(),all_shifted['ox_std'].max(),100)

plt.figure(figsize=(18,12))
plt.title('Oxford Index [std] as predictor for daily inc deaths per M [std]\nConditioned \
for country and day into pandemic\nHierachical (Pooled) Model\nshift={} days'.format(shift))

plt.scatter(x,all_shifted['dead_per_M_inc_std'],alpha=0.4,label='data',color='lightgrey')

plt.plot(xx,trace3['alpha'].mean() + trace3['beta'].mean() * xx + trace3['beta2'].mean() *\
         day_idx.mean() + trace3['beta3'].mean() * country_idx.mean(),color='k',ls='dashed',label='mean')
ax = plt.gca()

az.plot_hdi(x,trace3['reg'],hdi_prob=0.89,color='orange',ax=ax,fill_kwargs=fill_kwargs2)
az.plot_hpd(x,ppc3['lkh'],hdi_prob=0.89,ax=ax,fill_kwargs=fill_kwargs)

plt.ylabel('daily inc deaths per M [std]')
plt.xlabel('daily Oxford Index [std]')
plt.ylim([-4,5])
plt.legend(loc='upper left')

plt.savefig('ox_index_cond_country_day.jpg',format='jpg',dpi=400)
 

In [None]:
##### BINNED PREDICTOR #####

bins=10



all_shifted['ox_bin'] = pd.cut(x=all_shifted['oxford_stringency'],bins=bins,labels=range(bins))
all_shifted['ox_bin'] = all_shifted['ox_bin'].astype(int)
all_shifted['ox_bin_std'] = standardize(all_shifted['ox_bin'])
all_shifted

In [None]:
# no pooling no condition, not binned predictor; cmp 7 with binned predictor

model4 = pm.Model()
with model4:
    
    alpha = pm.Normal('alpha',0,1)
    beta = pm.Normal('beta',0,1)
    sigma = pm.Exponential('sigma',1)
    
    req = pm.Deterministic('req',alpha + beta * all_shifted['ox_bin_std'].values)
    obs = pm.Normal('obs',mu=req,sd=sigma,observed=all_shifted['dead_per_M_inc_std'])
    
    trace4 = pm.sample(500,tune=500)
    summary4 = pm.summary(trace4)
    #result4 = pm.trace_to_dataframe(trace4)
    #az.plot_trace(trace4)

In [None]:
summary4
    

In [None]:
with model4:
    idata = az.from_pymc3(trace4)
    _=az.plot_pair(idata,var_names=['~req'],point_estimate='mean',textsize=12,
               marginals=True,figsize=(18,12))
    
     

In [None]:
ppc = pm.sample_posterior_predictive(trace4,model=model4)
ppc['obs'].shape

In [None]:
# category means
foo = all_shifted.groupby('ox_bin_std')
keys = list(foo.groups.keys())

means = foo['dead_per_M_inc_std'].mean()
means

In [None]:
e_b = np.zeros((bins,2))

for i in range(bins):
    g_idx = keys[i]
    g = foo.get_group(g_idx)['dead_per_M_inc_std']
    print (g.describe(percentiles=[0.055,0.945]))
    mean = g.mean()
    ci = np.percentile(g,[5.5,94.5])
    ci[0] = np.abs(ci[0] - mean)
    ci[1] = np.abs(ci[1] - mean)
    e_b[i] = ci


e_b

In [None]:
# Thin the posterior for the spaghetti lines. ``len(trace4)`` is the number of
# draws in ONE chain, whereas ``trace4['alpha']`` stacks all chains, so thin
# over the stacked length to take lines from every chain.
draws = np.arange(0, len(trace4['alpha']), 10)

fill_kwargs = {'label':'ppc 89%',
            'color':'m',
               'alpha':0.2}


plt.figure(figsize=(18,12))

plt.title('Oxford Index [std] as predictor for daily inc deaths per M [std]\nTotal non-hierachical,non-conditioned\n \
Using {} bins for Oxford Index\nshift={} days'.format(bins,shift))

x = all_shifted['ox_bin_std'].values
# The binned predictor takes only a few distinct values; draw the straight
# lines over those (sorted) instead of over every data row.
x_levels = np.unique(x)

plt.scatter(x,all_shifted['dead_per_M_inc_std'],color='lightgrey',alpha=0.4,label='data')

plt.scatter(means.index,means,s=100,color='k',label='bin mean')

plt.errorbar(means.index,means,yerr=e_b.T,capsize=5,capthick=5,color='k',fmt='none',label='data bin 89% range')

plt.plot(x_levels,trace4['alpha'][draws] + trace4['beta'][draws] * x_levels[:,np.newaxis],color='orange',alpha=0.2)

ax = plt.gca()

plt.plot(x_levels,trace4['alpha'].mean() + trace4['beta'].mean() * x_levels,color='k',ls='dashed')

az.plot_hdi(x=x,y=trace4['req'],hdi_prob=0.89,ax=ax,color='lime')

# Re-sample here instead of reusing the earlier ``ppc``: later cells overwrite
# ``ppc`` with other models' predictions.
ppc = pm.sample_posterior_predictive(trace4,model=model4)

az.plot_hdi(x,ppc['obs'],hdi_prob=0.89,color='m',ax=ax,fill_kwargs=fill_kwargs)

plt.ylabel('dead_per_M_inc_std')
plt.xlabel('oxford index [std] [data in {} categories]'.format(bins))

plt.ylim([-4,5])
plt.legend(loc='upper left')
plt.savefig('oxford_index_binned.jpg',format='jpg')
 

In [None]:
freq = all_shifted.groupby(['abb','ox_bin']).count()['country']
freq['ZAF']

In [None]:
# Number of days each country spent at every binned Oxford-index level.
r,c = 5,3

# One panel per country; extra countries would otherwise be silently lost.
assert len(country_abbs) <= r * c, 'more countries than subplot panels'

fig,axes = plt.subplots(r,c,sharey=True,figsize=(18,12))

# ``axes.flat`` fills the grid row by row (``axes[i % r, i % c]`` only covers
# every panel when r and c are coprime).
for ca, ax in zip(country_abbs, axes.flat):
    try:
        days_per_bin = freq[ca]
    except KeyError:
        # country has no rows in ``freq``: leave its panel empty
        continue
    # reindex onto every bin so all panels share the same 0..bins-1 x range
    days_per_bin.reindex(range(bins)).fillna(0).plot(kind='bar',ax=ax,title=ca)
    ax.set_ylabel('nr of days')

fig.suptitle('Number of days spent at binned [0-{}] Oxford Index Level'.format(bins-1))
fig.tight_layout()
fig.savefig('oxford_index_days_in_prison_levels_shift_{}.jpg'.format(shift),format='jpg')
 

In [None]:
# Multiindex slice
idx = pd.IndexSlice
freq.loc[idx[country_abbs]].plot(kind='bar',figsize=(18,12))
 

In [None]:
freq = all_shifted.groupby('ox_bin')['country'].count()
freq

In [None]:
d_scale = np.repeat(2.0,bins-1)

d = pm.Dirichlet.dist(d_scale).random()
    
d

In [None]:
# ordered categorical, cmp model 8

import theano.tensor as tt

E = all_shifted['ox_bin'].values

model6 = pm.Model()
with model6:
    
    bE = pm.Normal('bE',0,1)
    delta = pm.Dirichlet("delta", np.repeat(2.0, bins-1), shape=bins-1)
    delta_j = tt.concatenate([tt.zeros(1), delta])
    delta_j_cumulative = tt.cumsum(delta_j)
    sigma = pm.Exponential('sigma',1)
    
    phi = pm.Deterministic('phi',bE * delta_j_cumulative[E])
    
    obs = pm.Normal('obs',mu=phi,sd=sigma,observed=all_shifted['dead_per_M_inc_std'])
    
    trace6 = pm.sample(1000,tune=1000)
    
    summary6 = pm.summary(trace6)
    #az.plot_trace(trace6)

In [None]:
summary6

In [None]:
#summary6.loc['delta[0]' : 'delta[8]']

In [None]:
ppc = pm.sample_posterior_predictive(trace=trace6,model=model6)

In [None]:
ppc['obs'].shape

In [None]:
cat_means = np.zeros(bins)

for i in range(bins):
    cat_means[i] = all_shifted.loc[all_shifted['ox_bin'] == i]['dead_per_M_inc_std'].mean()
    
cat_means

In [None]:
# error bar example

e_bars = np.zeros((bins,2))

for i in range(bins):
    CI = np.percentile(all_shifted.loc[all_shifted['ox_bin']==i]['dead_per_M_inc_std'],[5.5,94.5])
    mean = all_shifted.loc[all_shifted['ox_bin'] == i]['dead_per_M_inc_std'].mean()
    CI[0] = np.abs(CI[0] - mean)
    CI[1] = np.abs(CI[1] - mean)
    e_bars[i] = CI

print (e_bars)

e_bars.T

#plt.scatter(range(bins),np.repeat(2,bins))
#plt.errorbar(range(bins),np.repeat(2,bins),e_bars.T,fmt='none',capsize=5)

plt.scatter(range(bins),cat_means)
plt.errorbar(range(bins),cat_means,e_bars.T,fmt='none',capsize=5)

In [None]:
fill_kwargs = {'label':'ppc 89%',
            'color':'orange',
               'alpha':0.4}

fill_kwargs2 = {'label':'GLM 89% CI',
            'color':'yellow',
               'alpha':0.4}


plt.figure(figsize=(18,12))

title = 'COVID19 Oxford Stringency Index : \
association with daily increment dead per million\nOxford Index as Ordered Categorical Predictor,\
10 levels\ndaily data for {} countries during {} - {}'.format(len(country_idx_map),
                                                             start_date,end_date)
plt.title(title)

#plt.scatter(E,all_shifted['dead_per_M_inc_std'],color='lightgrey',alpha=0.3,label='data')
ax = plt.gca()
#az.plot_hpd(E,ppc['obs'],ax=ax,hdi_prob=0.89,fill_kwargs=fill_kwargs)
az.plot_hdi(E,trace6['phi'],ax=ax,color='r',hdi_prob=0.89,fill_kwargs=fill_kwargs2,)


for i in range(bins):
    label = 'Category Mean daily inc dead per M' if i == 0 else ''
    plt.plot(i,all_shifted.loc[all_shifted['ox_bin']==i]['dead_per_M_inc_std'].mean(),'o--',
                color='k',label=label)    
    
#plt.errorbar(range(bins),cat_means,e_bars.T,fmt='none',capsize=5,
             #capthick=5,label='Data 89% range',color='k')
    
plt.xlabel('Oxford Index Ordered Category')
plt.ylabel('daily dead inc per Million [std]')
plt.ylim([-0.3,0.3])
plt.legend(loc='upper left')
_= plt.xticks(range(bins))
plt.savefig('oxford_index_ordered_categorical.jpg',format='jpg',dpi=400)

In [None]:
(all_shifted.loc[all_shifted['abb']=='ISR']).plot(x='date',y='ox_bin')

In [None]:
high_ox= ((all_shifted.loc[all_shifted['ox_bin'] > 6]).groupby('abb')['dead_per_M_inc_std'].count()) > 140
high_ox.loc[high_ox == True]



In [None]:
low_ox = ((all_shifted.loc[all_shifted['ox_bin'] < 4]).groupby('abb')['dead_per_M_inc_std'].count()) > 100
low_ox.loc[low_ox == True]

In [None]:
all_shifted.loc[all_shifted['abb']== 'ISR']

In [None]:
with model6:
    idata = az.from_pymc3(trace6)
    _=az.plot_pair(idata,var_names=['delta'],point_estimate='mean',textsize=12,
               marginals=True,figsize=(18,12))

In [None]:
with model6:
    az.plot_posterior(trace6,var_names=['bE','delta'],figsize=(18,12))
    
plt.savefig('oxford_index_ordered_categorical_deltas.jpg',format='jpg',dpi=400)

In [None]:
ox_select = all_shifted.loc[all_shifted['ox_bin'] > all_shifted['ox_bin'].max() - 3]
ox_select

In [None]:
# bin_std, no condition, but selected high index countries

model7 = pm.Model()
with model7:
    alpha = pm.Normal('alpha',0,1)
    beta = pm.Normal('beta',0,1)
    sigma = pm.Exponential('sigma',1)
    
    req = pm.Deterministic('req',alpha + beta * ox_select['ox_bin_std'].values)
    obs = pm.Normal('obs',mu=req,sd=sigma,observed=ox_select['dead_per_M_inc_std'])
    
    trace7 = pm.sample(1000,tune=1000,target_accept=0.95)
    summary7 = pm.summary(trace7)

In [None]:
summary7


In [None]:
with model7:
    
    az.plot_posterior(trace7,
                      var_names=['alpha','beta','sigma'],
                      figsize=(18,12),
                      ref_val=[summary7.loc['alpha','mean'],0,summary7.loc['sigma','mean']])
    
plt.savefig('oxford_index_high_index_last_3_idx_grps.jpg',format='jpg',dpi=400)

In [None]:
az.plot_hdi(E,trace6['phi'])

In [None]:
def normalize_0_1(series):
    """Min-max scale ``series`` onto the closed interval [0, 1].

    The smallest value maps to 0 and the largest to 1. For non-negative data
    whose minimum is 0 (such as ``ox_bin``) this equals ``series / series.max()``;
    unlike that plain division it also stays within [0, 1] when the minimum is
    not 0 (e.g. negative values).

    A constant series has no spread and is mapped to all zeros instead of the
    NaN/inf a division by zero would produce.
    """
    lo = series.min()
    span = series.max() - lo
    if span == 0:
        return series * 0.0
    return (series - lo) / span

all_shifted['ox_bin_0_1'] = all_shifted['ox_bin'] / all_shifted['ox_bin'].max()
all_shifted

In [None]:
#metric predictor, cmp model6

model8 = pm.Model()
with model8:
    
    bE = pm.Normal('bE',0,1)
    
    sigma = pm.Exponential('sigma',1)
    
    phi = pm.Deterministic('phi',bE * all_shifted['ox_bin_0_1'].values)
    
    obs = pm.Normal('obs',mu=phi,sd=sigma,observed=all_shifted['dead_per_M_inc_std'])
    
    trace8 = pm.sample(1000,tune=1000)
    
    summary8 = pm.summary(trace8)


In [None]:
summary8

In [None]:
with model8:
    
    az.plot_posterior(trace8,
                      var_names=['bE','sigma'],
                      figsize=(18,12),
                      ref_val=[summary8.loc['bE','mean'],summary8.loc['sigma','mean']])
    

In [None]:
with model8:
    az.plot_hdi(all_shifted['ox_bin_0_1'].values,trace8['phi'])

In [None]:
ppc = pm.sample_posterior_predictive(trace=trace8,model=model8)
ppc

In [None]:
az.plot_hpd(all_shifted['ox_bin_0_1'],ppc['obs'])

In [None]:
ox_0_1_values = np.sort(all_shifted['ox_bin_0_1'].unique())
ox_0_1_values

In [None]:
plt.figure(figsize=(18,12))
plt.title('comparison Metric vs Ordered Category Predictor [dummy data]')
plt.plot(ox_0_1_values,summary6.loc['bE','mean'] * ox_0_1_values,label='ordered category',color='r',ls='dashed')
plt.plot(ox_0_1_values,summary8.loc['bE','mean'] * ox_0_1_values,label='metric category',color='g',ls='dashed')
plt.legend(loc='lower left')
plt.xlabel('predictor')
plt.ylabel('outcome')
plt.savefig('oxford_cmp_metric_vs_ordered_category.jpg',format='jpg')

In [None]:
E_values=np.unique(E)
E_values

In [None]:
# Posterior curves of the ordered-categorical model (model6). Its mean is
# phi = bE * cumsum([0, delta])[E], so each curve must include the Dirichlet
# step sizes ``delta``; ``bE * E_values`` alone would be a straight metric line
# on a 0..bins-1 scale rather than the model's 0..1 cumulative scale.
bE_draws = trace6['bE']          # all chains stacked
delta_draws = trace6['delta']    # one row of bins-1 steps per draw
delta_cum = np.cumsum(
    np.concatenate([np.zeros((len(delta_draws), 1)), delta_draws], axis=1),
    axis=1)
phi_draws = bE_draws[:, np.newaxis] * delta_cum[:, E_values]

# Thin over the stacked draws; ``len(trace6)`` counts only one chain's draws.
draws = np.arange(0, len(bE_draws), 10)

plt.figure(figsize=(18,12))
plt.title('Ordered Categorical Predictor - uncertainty in regression line')
_ = plt.plot(E_values,phi_draws[draws].T,color='r',alpha=0.1)

plt.plot(E_values,phi_draws.mean(axis=0),
         color='k',ls='dashed',label='mean')

plt.ylabel('outcome')
plt.xlabel('ordered categorical predictor')
plt.legend(loc='lower left')
plt.savefig('oxford_ordered_categorical_predictor.jpg',format='jpg')


In [None]:
# Thin over all stacked chains; ``len(trace8)`` counts only one chain's draws,
# while ``trace8['bE']`` concatenates every chain.
draws = np.arange(0, len(trace8['bE']), 10)

plt.figure(figsize=(18,12))
plt.title('Metric predictor - uncertainty in prediction line')
_ = plt.plot(ox_0_1_values,trace8['bE'][draws] * ox_0_1_values[:,np.newaxis],color='r',alpha=0.1)
plt.plot(ox_0_1_values,summary8.loc['bE','mean'] * ox_0_1_values,label='mean',
         color='k',ls='dashed')

plt.ylabel('outcome')
plt.xlabel('metric predictor')
plt.legend(loc='lower left')
plt.savefig('oxford_metric_predictor.jpg',format='jpg')



In [None]:
plt.figure(figsize=(18,12))
plt.title('Oxford Index bin distribution')
sns.violinplot(x='ox_bin',y='dead_per_M_inc_std',data=all_shifted,color='r',scale='count')
plt.savefig('oxford_data_variability_violin.jpg',format='jpg')

In [None]:
date_mask = all_shifted['date'] > '2020-06-01'
level_mask = all_shifted['ox_bin'] > 7

low_ox = ((all_shifted[date_mask & level_mask]).groupby('abb')['dead_per_M_inc_std'].count()) > 10
low_ox.loc[low_ox == True]

In [None]:
# with interaction 

x = all_shifted['ox_std'].values
country_idx = standardize(all_shifted['country_idx']).values
day_idx = standardize(all_shifted['day_idx'].astype(int)).values

model9 = pm.Model()

summary9 = pd.DataFrame()
result9 = pd.DataFrame()

with model9:
    
    obs_sigma = pm.Exponential('obs_sigma',2)

    alpha_bar = pm.Normal('alpha_bar',mu=0,sd=1)
    alpha_sd = pm.Exponential('alpha_sd',1)

    beta_bar = pm.Normal('beta_bar',mu=0,sd=1,shape=3)
    beta_sd = pm.Exponential('beta_sd',1,shape=3)
    
    beta = pm.Normal('beta',mu=beta_bar[0],sd=beta_sd[0])
    beta2 = pm.Normal('beta2',beta_bar[1],beta_sd[1])
    beta3 = pm.Normal('beta3',beta_bar[2],beta_sd[2])
    
    beta4 = pm.Normal('beta4',0,1)
    beta5 = pm.Normal('beta5',0,1)
    beta6 = pm.Normal('beta6',0,1)
    
    alpha = pm.Normal('alpha',mu=alpha_bar,sd=alpha_sd)
    

    reg = pm.Deterministic('reg',alpha + beta * x + beta2 * day_idx + beta3 * country_idx + \
                          beta4 * x * day_idx + beta5 * x * country_idx + beta6 * day_idx * country_idx)
    
    lkh = pm.Normal('lkh',mu=reg,sd=obs_sigma,observed=all_shifted['dead_per_M_inc_std'])
    
    trace9 = pm.sample(500,tune=500,target_accept=0.99)
    
    summary9 = az.summary(trace9)
    print (summary9)

In [None]:
with model9:
    _ = az.plot_posterior(trace9,var_names=['alpha_bar','beta_bar','alpha','beta','beta2','beta3',
                                       'beta4','beta5','beta6','alpha_sd','beta_sd'])

In [None]:
with model9:
    idata = az.from_pymc3(trace9)
    _ = az.plot_forest(idata,var_names=['~reg'])

In [None]:
ppc9 = pm.sample_posterior_predictive(trace9,model=model9)

fill_kwargs = {'label':'ppc 89%',
            'color':'m',
               'alpha':0.2}

fill_kwargs2 = {'label':'CI 89%',
            'color':'orange',
               'alpha':0.2}


x = all_shifted['ox_std'].values
xx = np.linspace(all_shifted['ox_std'].min(),all_shifted['ox_std'].max(),100)

plt.figure(figsize=(18,12))
plt.title('Oxford Index [std] as predictor for daily inc deaths per M [std]\nConditioned \
for country and day into pandemic\nHierachical (Pooled) Model\nshift={} days\nInteractions'.format(shift))

plt.scatter(x,all_shifted['dead_per_M_inc_std'],alpha=0.4,label='data',color='lightgrey')

ax = plt.gca()

az.plot_hdi(x,trace9['reg'],hdi_prob=0.89,color='orange',ax=ax,fill_kwargs=fill_kwargs2,)
az.plot_hpd(x,ppc9['lkh'],hdi_prob=0.89,ax=ax,fill_kwargs=fill_kwargs)

plt.ylabel('daily inc deaths per M [std]')
plt.xlabel('daily Oxford Index [std]')
plt.ylim([-4,5])
plt.legend(loc='upper left')

In [None]:
len(all_shifted['country_idx'].unique())

In [None]:
# Model 10: per-country effect bE of lockdown stringency, partially pooled across
# countries. Stringency enters as an ordered bin index (ox_bin) through a monotonic
# Dirichlet-increment transform (delta), so the model has no intercept.
import theano.tensor as tt

country_idx = all_shifted['country_idx'].astype(int)
n_countries = len(country_idx.unique())
# bE[country_idx] needs codes exactly 0..n_countries-1: negative codes would silently
# wrap around and gapped codes would index past the end of bE.
assert set(country_idx.unique()) == set(range(n_countries)), 'country_idx must be 0..n-1'

E = all_shifted['ox_bin'].values
# E indexes delta_j_cumulative, which has `bins` entries.
assert np.asarray(E).min() >= 0 and np.asarray(E).max() < bins, 'ox_bin must lie in 0..bins-1'

model10 = pm.Model()
with model10:
    # Hyperpriors shared by all country-level effects.
    bE_bar = pm.Normal('bE_bar', 0, 1)
    bE_sd_bar = pm.Exponential('bE_sd_bar', 1)

    bE = pm.Normal('bE', bE_bar, bE_sd_bar, shape=n_countries)

    # Simplex of bins-1 non-negative increments summing to 1. Prepending 0 and taking
    # the cumulative sum gives a non-decreasing weight per bin, from 0 (first bin) to 1 (last).
    delta = pm.Dirichlet("delta", np.repeat(2.0, bins - 1), shape=bins - 1)
    delta_j = tt.concatenate([tt.zeros(1), delta])
    delta_j_cumulative = tt.cumsum(delta_j)
    sigma = pm.Exponential('sigma', 1)

    # Index with a plain numpy array rather than a pandas Series.
    phi = pm.Deterministic('phi', bE[country_idx.values] * delta_j_cumulative[E])

    obs = pm.Normal('obs', mu=phi, sd=sigma, observed=all_shifted['dead_per_M_inc_std'])

    trace10 = pm.sample(500, tune=500)

    # Same summary function as used for model9.
    summary10 = az.summary(trace10)

In [None]:
# Display model10's posterior summary table (computed in the sampling cell).
summary10


In [None]:
# Country abbreviations ordered by their integer index, so that keys[i] labels column i
# of trace10['bE'] (the 'country_idx' coordinate below). Sorting by value does not rely
# on country_idx_map's insertion order matching its index values.
keys = np.array(sorted(country_idx_map, key=country_idx_map.get))

In [None]:
# Wrap model10's trace as InferenceData, labelling bE's country axis with abbreviations.
idata = az.from_pymc3(
    trace10,
    model=model10,
    coords={'country_idx': keys},
    dims={'bE': ['country_idx']},
)
idata

In [None]:
# Posterior of model10's Dirichlet increments per stringency bin.
delta_vars = ['delta']
_ = az.plot_posterior(idata, var_names=delta_vars)

In [None]:
# Posterior of bE for the countries at positions 19..28 of keys, with a reference at 0.
subset_countries = keys[19:29]
_ = az.plot_posterior(idata, var_names='bE', ref_val=0,
                      coords={'country_idx': subset_countries})

In [None]:
# bE posterior samples as a table: one column per country abbreviation.
result_bE = pd.DataFrame(data=trace10['bE'], columns=keys)
result_bE

In [None]:
# Countries singled out for a closer look at their bE posterior.
country_abbs_high = [
    'ARG', 'BOL', 'BRA', 'CHL', 'COL', 'DOM', 'HND', 'IND',
    'IRQ', 'KAZ', 'OMN', 'PAN', 'PER', 'QAT', 'ZAF',
]

slope_highs = result_bE.loc[:, country_abbs_high]
# 5.5% / 94.5% percentiles bound the central 89% interval.
slope_highs.describe(percentiles=[0.055, 0.945])

In [None]:
# Central 89% percentile interval of bE per selected country: row 0 = lower, row 1 = upper.
interval_bounds = [5.5, 94.5]
CI = np.percentile(slope_highs, interval_bounds, axis=0)
CI

In [None]:
# Asymmetric error-bar lengths for plt.errorbar: distance from each country's posterior
# mean to its lower (row 0) and upper (row 1) 89% percentile bound in CI.
means = slope_highs.mean(axis=0)
# Build a new array instead of aliasing CI (the original `errs = CI` followed by row
# assignment overwrote CI in place, so re-running this cell gave wrong error bars).
errs = np.abs(CI - means.values)

errs

In [None]:
# Posterior mean of bE for the selected countries with asymmetric 89% percentile error bars.
fig, ax = plt.subplots(figsize=(18, 12))
ax.errorbar(x=slope_highs.columns, y=means,
            yerr=errs, fmt='o', capsize=5, color='k')
ax.set_xlabel('country')
ax.set_ylabel('bE (posterior mean, 89% percentile interval)')
ax.set_title('model10: lockdown effect bE for selected countries');

In [None]:
# Integer bE positions of the selected countries, in the order of country_abbs_high.
high_country_keys = np.array(
    [country_idx_map[abb] for abb in country_abbs_high]
).astype(int)
high_country_keys

In [None]:
# Posterior of bE for the selected countries, with a reference at 0.
high_country_labels = keys[high_country_keys]
_ = az.plot_posterior(idata, var_names='bE', ref_val=0,
                      coords={'country_idx': high_country_labels})

In [None]:
# Posterior of model10's deterministic phi (the per-observation mean) with an 89% HDI.
# NOTE(review): trace10['phi'] is presumably a bare 2-D ndarray (samples x observations);
# arviz reads a bare 2-D array as (chain, draw), so this likely pools every phi value into
# one distribution rather than plotting one per observation. Confirm this is intended.
_ = az.plot_posterior(trace10['phi'],hdi_prob=0.89)

In [None]:
# Forest plot of model10's Dirichlet increments `delta`.
az.plot_forest(data=idata, var_names=['delta'])

In [None]:
# Forest plot of bE (89% HDI) for the selected countries, with a narrow ROPE around 0.
high_coords = {'country_idx': keys[high_country_keys]}
az.plot_forest(
    idata,
    var_names=['bE'],
    rope=[-0.01, 0.01],
    hdi_prob=0.89,
    coords=high_coords,
    combined=True,
    figsize=(18, 12),
)
plt.xlabel('Total mean effect of lockdown in terms  of change standardized daily deaths per M \
\ndelta per bin to be observed')
plt.ylabel('country')

In [None]:
# Forest plots of the per-country lockdown effect bE, in chunks of countries_per_plot
# countries so each figure stays legible. Every chunk, including the last, is saved
# as lockdown_<start>_<stop>.jpg.
nr_countries = len(country_idx_map)
countries_per_plot = 30

# Loop bounds come from the real country count, so no chunk is empty and no
# hard-coded upper bound is needed.
for start in range(0, nr_countries, countries_per_plot):
    stop = start + countries_per_plot
    az.plot_forest(idata, var_names=['bE'], rope=[-0.01, 0.01],
                   coords={'country_idx': keys[start:stop]}, combined=True, figsize=(18, 12))
    plt.savefig('lockdown_{}_{}.jpg'.format(start, stop), format='jpg', dpi=400)


In [None]:
# One forest plot of bE for every country, saved as a single high-resolution image.
az.plot_forest(
    idata,
    var_names=['bE'],
    rope=[-0.01, 0.01],
    coords={'country_idx': keys},
    combined=True,
)

plt.savefig('lockdown_impact.jpg', format='jpg', dpi=400)

In [None]:
# Observed data for one country plus a line through the origin whose slope is the
# posterior mean of that country's bE. sps.linregress does not apply: model10's
# predictor is the binned, non-linear ox_bin transform and the model has no intercept.
# NOTE(review): the line is therefore a visual guide, not model10's fitted curve.
c_name = 'ZAF'

c = all_shifted.loc[all_shifted['abb'] == c_name]
ax = c.plot(x='ox_std', y='dead_per_M_inc_std', style='o')

bE_draws = trace10['bE'][:, country_idx_map[c_name]]
slope = bE_draws.mean()
CI = np.percentile(bE_draws, [5.5, 94.5])
CI2 = np.percentile(bE_draws, [25, 75])
hdi = az.hdi(bE_draws, hdi_prob=0.50)
for stat in (slope, CI, CI2, hdi):
    print(stat)

x = np.linspace(-2, 2, 100)

ax.plot(x, slope * x)


In [None]:
# Compare the per-country slopes: `beta` from `model` vs `bE` from model10.
# Convert each MultiTrace with its own model. Passing raw traces makes arviz convert
# them without a model (deprecated, and yields less complete InferenceData).
az.plot_forest(data=[az.from_pymc3(trace, model=model),
                     az.from_pymc3(trace10, model=model10)],
               model_names=['model', 'model10'], var_names=['beta', 'bE'],
               combined=True, colors='cycle')

In [None]:
# InferenceData for the baseline model (`beta`) and for model10 (`bE`), with each
# per-country coefficient labelled by country abbreviation.
idata0 = az.from_pymc3(
    trace,
    model=model,
    coords=dict(country_idx=keys),
    dims=dict(beta=['country_idx']),
)

idata10 = az.from_pymc3(
    trace10,
    model=model10,
    coords=dict(country_idx=keys),
    dims=dict(bE=['country_idx']),
)


In [None]:
# Side-by-side forest plot of the per-country slope in both models.
both_models = [idata0, idata10]
az.plot_forest(both_models, model_names=['model', 'model10'], var_names=['beta', 'bE'],
               combined=True, textsize=12, rope=[-0.01, 0.01])

In [None]:
# model10-only forest plots of bE: four panels of 30 countries, then one panel
# holding every remaining country (stop=None means "to the end").
panel_bounds = [(0, 30), (30, 60), (60, 90), (90, 120), (120, None)]
for lo, hi in panel_bounds:
    az.plot_forest([idata10], model_names=['model10'], var_names=['bE'],
                   combined=True, textsize=12, rope=[-0.01, 0.01],
                   coords={'country_idx': keys[lo:hi]})

In [None]:
# Inspect the rows for Mexico.
all_shifted[all_shifted['abb'].eq('MEX')]


In [None]:
# Countries where the two models disagree on the sign of the mean slope:
# `beta` from `model` vs `bE` from model10 (zero counts as no sign).
beta_mean = trace['beta'].mean(axis=0)
bE_mean = trace10['bE'].mean(axis=0)
sign_flip = ((beta_mean < 0) & (bE_mean > 0)) | ((beta_mean > 0) & (bE_mean < 0))

diffs = pd.Series(sign_flip, index=keys)
diffs = diffs[diffs]
diffs

In [None]:
# Posterior-mean slope per country under both models. A legend and axis labels make
# the two marker series distinguishable; each mean is computed once.
beta_means = trace['beta'].mean(axis=0)
bE_means = trace10['bE'].mean(axis=0)

fig, ax = plt.subplots(figsize=(18, 12))
ax.plot(keys, beta_means, 'o', label='model: beta')
ax.plot(keys, bE_means, 'x', label='model10: bE')
ax.set_xlabel('country')
ax.set_ylabel('posterior mean')
ax.legend()
ax.tick_params(axis='x', labelrotation=90)

In [None]:
# Print the abbreviation of each sign-disagreement country, one per line.
for country_abb in diffs.index.tolist():
    print(country_abb)

In [None]:
# Integer bE positions of the sign-disagreement countries.
diff_country_keys = np.array(
    [country_idx_map[abb] for abb in diffs.index]
).astype(int)
diff_country_keys

In [None]:
# Both models' slopes for only the countries where their signs disagree.
diff_coords = {'country_idx': keys[diff_country_keys]}
az.plot_forest([idata0, idata10], model_names=['model', 'model10'], var_names=['beta', 'bE'],
               coords=diff_coords, combined=True, rope=[-0.01, 0.01])