In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import pymc3 as pm
import arviz as az

# Apply seaborn's default plot styling to the matplotlib figures drawn below.
sns.set()

In [None]:
# Howell1 height/weight data; the file is semicolon-separated.
HOWELL1_URL = ('https://raw.githubusercontent.com/tolex3/resources/master/'
               'Rethinking_2/Data/Howell1.csv')
df = pd.read_csv(HOWELL1_URL, sep=';')

df.head()

In [None]:
# Restrict the analysis to adults (age >= 18) and eyeball height vs. weight.
df = df.query('age >= 18')
df.plot.scatter(x='weight', y='height', color='b')


In [None]:
# Linear regression of height on weight with a separate intercept (alpha)
# and slope (beta) per gender: index 0 is used for rows with male == 0,
# index 1 for rows with male == 1.
RANDOM_SEED = 42  # fixed so that Restart & Run All reproduces the same trace

x = df['weight'].values
# Plain integer ndarray (not a pandas Series) so the advanced indexing of
# alpha/beta below is positional, independent of df's filtered index labels.
gender_idx = df['male'].values.astype(int)

with pm.Model() as model:
    alpha = pm.Normal('alpha', mu=170, sigma=30, shape=2)
    beta = pm.Normal('beta', mu=1, sigma=2, shape=2)
    sigma = pm.Uniform('sigma', lower=0, upper=50)

    # Expected height of every observation, using that row's gender-specific
    # intercept and slope; stored in the trace so it can be plotted later.
    reg = pm.Deterministic('reg', alpha[gender_idx] + beta[gender_idx] * x)

    obs = pm.Normal('obs', mu=reg, sigma=sigma, observed=df['height'].values)

    trace = pm.sample(500, tune=500, random_seed=RANDOM_SEED)

In [None]:
# Convert the PyMC3 trace to ArviZ InferenceData, labelling the length-2
# axis of `alpha` and `beta` with the gender it belongs to
# (position 0 <-> male == 0 -> 'female', position 1 <-> male == 1 -> 'male').
gender_coords = {'gender_idx': np.array(['female', 'male'])}
param_dims = {'alpha': ['gender_idx'], 'beta': ['gender_idx']}

with model:
    idata = az.from_pymc3(trace, coords=gender_coords, dims=param_dims)
    


In [None]:
# Posterior densities of the per-gender intercepts and slopes.
gender_labels = ['female', 'male']
with model:
    _ = az.plot_posterior(
        idata,
        var_names=['alpha', 'beta'],
        coords={'gender_idx': gender_labels},
        figsize=(18, 12),
    )

In [None]:
# Boolean row masks for the two groups (male == 0 -> female, male == 1 -> male).
f_idx = df['male'] == 0
m_idx = df['male'] == 1

# Body weight of each group, reusing the masks instead of recomputing them
# (and using a single .loc[rows, col] lookup instead of chained indexing).
x_f = df.loc[f_idx, 'weight']
x_m = df.loc[m_idx, 'weight']

# trace['reg'] has one column per observation, in df row order; select each
# group's columns. `.values` makes the boolean indexing explicitly positional.
trace_f = trace['reg'][:, f_idx.values]
trace_m = trace['reg'][:, m_idx.values]

In [None]:


# Posterior-mean regression line per gender (index 0 = female, 1 = male)
# plus the HDI band of the per-observation mean `reg` within each gender.
# NOTE(open question from the author): how to make arviz itself aware of which
# entries of trace['reg'] belong to the male/female categories (i.e. which
# coords/dims to pass when building `idata`) so az.plot_hdi could select them,
# instead of splitting the array by hand in the previous cell.
fig, ax = plt.subplots()

# A straight line only needs its two end points across the observed weights.
x_line = np.array([x.min(), x.max()])

groups = [
    (0, 'female', 'crimson', x_f, trace_f),
    (1, 'male', 'navy', x_m, trace_m),
]
for g, label, color, x_g, reg_g in groups:
    mean_line = trace['alpha'][:, g].mean() + x_line * trace['beta'][:, g].mean()
    ax.plot(x_line, mean_line, color=color, ls='dashed',
            label=label + ' (posterior mean)')
    # Add a leading chain axis so arviz reads the samples as
    # (chain, draw, observation) rather than guessing the 2-D layout;
    # color-match the band to the group's mean line.
    az.plot_hdi(x_g, reg_g[np.newaxis], color=color, ax=ax)

ax.set(xlabel='weight', ylabel='height', title='Height vs. weight by gender')
ax.legend()

fig.savefig('arviz_question.jpg', format='jpg')