In [1]:
%matplotlib inline

import pandas as pd
import pymc3 as pm
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from scipy import stats
import numpy as np

plt.style.use('bmh')

In [2]:
howell1 = pd.read_csv('data/Howell1.csv', sep=';')
howell1.head()

Unnamed: 0,height,weight,age,male
0,151.765,47.825606,63.0,1
1,139.7,36.485807,63.0,0
2,136.525,31.864838,65.0,0
3,156.845,53.041915,41.0,1
4,145.415,41.276872,51.0,0


### 4H1

The weights listed below were recorded in the !Kung census, but heights were not recorded for
these individuals. Provide predicted heights and 89% intervals (either HPDI or PI) for each of these
individuals.
That is, fill in the table below, using model-based predictions.

Individual,  weight,  expected height,  89% interval 

weight:
 46.95,
 43.72,
 64.78,
 32.59,
 54.63


In [3]:
weight_missing = pd.Series([46.95, 43.72, 64.78, 32.59, 54.63], name='weight_test')

In [4]:
with pm.Model() as model_4h1:
    alpha = pm.Bound(pm.Normal, lower=0)('alpha', mu=howell1.height.mean(), sd=200)
    beta = pm.Normal('beta', mu=0, sd=20)

    _mu_height = alpha + beta * (howell1.weight - howell1.weight.mean())

    sigma = pm.HalfCauchy('sigma', beta=10)
    height = pm.Normal('height', mu=_mu_height, sd=sigma, observed=howell1.height)

    trace = pm.sample(10000, tune=2000, cores=2, chains=2)    

Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [sigma, beta, alpha]
Sampling 2 chains: 100%|██████████| 24000/24000 [00:19<00:00, 1246.92draws/s]


In [5]:
(
    pd.merge(
        pm.trace_to_dataframe(trace, varnames=['alpha', 'beta']).assign(key=1),
        weight_missing.reset_index().assign(key=1),
        how='outer',
        on='key'
    )
    .drop('key', axis=1)
    .assign(
        height = lambda _df: _df.alpha + _df.beta * (_df.weight_test - howell1.weight.mean())
    )
    .groupby('index')
    .height.describe(percentiles=[0.055, 0.945])
    .join(weight_missing)
    [['mean', '5.5%', '94.5%']]
)

Unnamed: 0_level_0,mean,5.5%,94.5%
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
0,158.270361,157.444325,159.074321
1,152.570906,151.831506,153.304569
2,189.732058,188.2916,191.14708
3,132.931608,132.262894,133.59572
4,171.822006,170.761773,172.862929
