In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import pymc3 as pm
import arviz as az

# Global plotting theme: seaborn's standard "notebook" context with the
# dark-grid style and "deep" palette (these are seaborn.set()'s defaults,
# written out so the chosen look is visible to the reader).
sns.set(context="notebook", style="darkgrid", palette="deep")

In [None]:
def logit_pure(x):
    """Map log-odds to a probability: the logistic (sigmoid / inverse-logit) function.

    NOTE: despite its name this is *not* the logit; it is the inverse of the
    logit. The name is kept because other cells call it.

    Parameters
    ----------
    x : scalar, array_like, or tensor
        Log-odds. The model cell also passes a PyMC3/Theano tensor here, so
        the body sticks to negation, ``np.exp`` and scalar arithmetic -- the
        same kinds of operations the original expression used.

    Returns
    -------
    Probability in [0, 1], elementwise over ``x``.
    """
    # 1 / (1 + exp(-x)) is algebraically exp(x) / (1 + exp(x)), but it cannot
    # produce inf / inf = nan for large positive x. For large negative x,
    # exp(-x) overflows to inf and the result correctly saturates at 0.0;
    # that overflow is expected, so its RuntimeWarning is silenced.
    with np.errstate(over='ignore'):
        return 1.0 / (1.0 + np.exp(-x))

def logistic_pure(p):
    """Map a probability to log-odds, i.e. the logit function log(p / (1 - p)).

    NOTE: despite its name this is the *logit*, the inverse of the logistic
    function; it undoes ``logit_pure`` (which is itself the logistic).

    Parameters
    ----------
    p : scalar or array_like
        Probability values; p = 0 or p = 1 give -inf / inf.

    Returns
    -------
    Log-odds, elementwise over ``p``.
    """
    odds = p / (1 - p)
    return np.log(odds)

In [None]:
### Divergences reportedly disappear with ~40 data points as opposed to a handful. ###
# NOTE(review): the original note said "4" data points, but this frame has only
# 2 rows (one per tank). With so few groups, the between-tank spread `sigma` in
# the model below is weakly informed, which is presumably what causes divergences.
# Columns: tank id (row position), tank population, and observed survivors.
df = pd.DataFrame(
    [[0, 100, 10],
     [1, 100, 50]],
    columns=['tank_id', 'population', 'survivors'],
)

df

In [None]:
# Hierarchical (partially pooled) binomial model of per-tank survival:
# each tank's log-odds `alpha` is drawn from a shared Normal(alpha_bar, sigma).
coords = {'tank_name' : ['tank_a','tank_b']}
# `obs` is labelled with the tank_name dim, so there must be exactly one
# coords label per row of df (rows and labels are matched by position).
assert len(coords['tank_name']) == len(df), "expected one tank_name label per df row"

RANDOM_SEED = 42  # fixed so the posterior draws are reproducible across Restart & Run All

with pm.Model(coords=coords) as model:
    alpha_bar = pm.Normal('alpha_bar', mu=0, sigma=1.5)
    sigma = pm.Exponential('sigma', 1)

    alpha = pm.Normal('alpha', mu=alpha_bar, sigma=sigma, dims='tank_name')

    # pm.math.invlogit maps log-odds to a probability with native Theano ops,
    # instead of pushing a symbolic tensor through NumPy ufuncs. `.values`
    # hands plain arrays (not pandas Series) to the tensor index and likelihood.
    obs = pm.Binomial(
        'obs',
        n=df['population'].values,
        p=pm.math.invlogit(alpha[df['tank_id'].values]),
        observed=df['survivors'].values,
        dims='tank_name',
    )

    # target_accept=0.99 is presumably set high to curb divergences.
    # return_inferencedata=False keeps a MultiTrace, which later cells index
    # as trace['alpha'].
    trace = pm.sample(
        500,
        tune=500,
        target_accept=0.99,
        random_seed=RANDOM_SEED,
        return_inferencedata=False,
    )

In [None]:
# Posterior summary with 89% highest-density intervals, plus trace plots.
# The model context is kept so ArviZ can associate the MultiTrace with `model`.
with model:
    trace_summary = az.summary(trace, hdi_prob=0.89)
    az.plot_trace(trace)

# Leaving the table as the last expression renders it as a rich HTML table
# (print() would only give plain text).
trace_summary

In [None]:
# Convert the PyMC3 trace into an ArviZ InferenceData object. The model is
# passed so ArviZ can attach its coords/dims; later plotting selects by the
# 'tank_name' coordinate.
idata = az.from_pymc3(trace=trace, model=model)
idata

In [None]:
# Posterior of tank_a's log-odds. hdi_prob=0.89 matches the interval used in
# the summary table; without it, ArviZ falls back to its own default interval
# (0.94 unless rcParams say otherwise), and the figure would disagree with the table.
# The trailing ';' suppresses the returned-axes repr.
az.plot_posterior(
    idata,
    var_names=['alpha'],
    coords={'tank_name': ['tank_a']},
    hdi_prob=0.89,
);


In [None]:
# Posterior-mean survival probability per tank. logit_pure maps each draw of
# alpha from log-odds to a probability, and axis 0 (the draw axis of a PyMC3
# MultiTrace lookup) is averaged out.
p_survival_mean = logit_pure(trace['alpha']).mean(axis=0)
print(p_survival_mean)

# The same per-tank means mapped back to the log-odds scale. Because the
# transform is nonlinear, this generally differs from the posterior mean of alpha.
logistic_pure(p_survival_mean)