# United States v. Shonubi (1997)

In [12]:
# packages

import pandas as pd
import seaborn as sns
import torch

import pyro
import pyro.distributions as dist
import pyro.ops.stats as stats

from rethinking import (LM, MAP, coef, extract_samples, glimmer,
                        link, precis, replicate, sim, vcov)


In [81]:
# Load the Shonubi case dataset. The path is relative to the notebook's
# working directory.
sh = pd.read_csv('ShonubiCaseDataset.csv')

# Display the frame (pandas truncates the middle rows). Only the last
# expression of a cell is rendered, so the earlier bare `sh.head()` call was
# dead code and has been dropped.
sh

Unnamed: 0,obs,dataset,balloons,gross_wt,net_wt,purity,age_yrs,gender
0,1.0,1.0,79.0,742.4,503.2,0.51,,
1,2.0,1.0,90.0,901.7,576.9,0.32,,
2,3.0,12.0,90.0,800.2,573.3,0.85,38.0,1.0
3,4.0,12.0,1.0,706.2,439.8,0.75,41.0,1.0
4,5.0,1.0,5.0,72.2,23.1,0.62,,
...,...,...,...,...,...,...,...,...
243,,,,,,,,
244,,,,,,,,
245,,,,,,,,
246,,,,,,,,


First, a strategy that uses posterior means only


Note that there are 107 cases with a recorded gross_wt but no net_wt. We first
predict those missing net_wt values from gross_wt, then fill them in so that
weight can be predicted from the number of balloons.

In [79]:
def model(gross_wt, net_wt):
    """Model net weight as proportional to gross weight (no intercept).

    Sample sites, in order:
        "m"      -- proportionality coefficient, prior Normal(0.8, 0.3)
        "sigma"  -- observation scale, prior Uniform(0, 150)
        "net_wt" -- Normal(m * gross_wt, sigma), conditioned on ``net_wt``

    Args:
        gross_wt: gross weights used as the predictor
            (presumably a 1-D float tensor -- confirm against callers).
        net_wt: observed net weights, passed as ``obs`` to the likelihood.
    """
    slope = pyro.sample("m", dist.Normal(0.8, 0.3))
    noise = pyro.sample("sigma", dist.Uniform(0, 150))

    # Expected net weight for each case: a fraction `slope` of its gross weight.
    expected_net = slope * gross_wt

    with pyro.plate("plate"):
        pyro.sample("net_wt", dist.Normal(expected_net, noise), obs=net_wt)

# Training rows for the gross -> net regression: cases where both weights are
# recorded. `netData` was previously used without being defined in this
# notebook, so the cell failed on a fresh kernel.
# NOTE(review): reconstructed from the `netData.dtypes` output (two float64
# columns, gross_wt and net_wt) -- confirm this matches the original subset.
netData = sh[["gross_wt", "net_wt"]].dropna()

# Pass the underlying NumPy arrays rather than the Series themselves:
# torch.tensor() treats a Series as a generic sequence and looks items up by
# label, which can fail once rows have been dropped and the index has gaps.
gross_wt = torch.tensor(netData["gross_wt"].values, dtype=torch.float)
net_wt = torch.tensor(netData["net_wt"].values, dtype=torch.float)

# Fit the model with the MAP helper from `rethinking`.
netModel = MAP(model).run(gross_wt, net_wt)

# Summary table of the fitted parameters (m, sigma).
precis(netModel)

In [87]:
# Summary of the fitted gross -> net model: per-parameter mean, standard
# deviation and 89% interval bounds for `m` and `sigma`.
precis(netModel)

Unnamed: 0,Mean,StdDev,|0.89,0.89|
m,0.8,0.03,0.75,0.85
sigma,62.55,4.56,54.79,69.43


In [102]:
# Sanity check: column dtypes of the frame used to fit the model.
netData.dtypes

gross_wt    float64
net_wt      float64
dtype: object

In [152]:
# Cases to fill in: gross_wt is recorded but net_wt is missing.
needs_net = sh['gross_wt'].notna() & sh['net_wt'].isna()

# Filter first, then copy. The previous `sh.copy()[mask]` duplicated the whole
# frame and still left `fillNet` as a slice, so the assignment below could
# raise SettingWithCopyWarning.
fillNet = sh.loc[needs_net].copy()

# Placeholder only: the zeros stand in for the missing net_wt so the column
# converts to a tensor. They are NOT observations and are not passed to sim().
fillNet['net_wt'] = fillNet['net_wt'].fillna(0)

# sim() wants tensors; float32 matches the dtype used when fitting.
gross_wtFILL = torch.tensor(fillNet["gross_wt"].values, dtype=torch.float)
net_wtFILL = torch.tensor(fillNet["net_wt"].values, dtype=torch.float)  # all-zero placeholder

# Only the predictor is supplied, so net_wt is simulated rather than conditioned on.
fillNetPair = {"gross_wt": gross_wtFILL}

# Use the fitted model to simulate net_wt for the cases above (1e4 draws).
netPred = sim(netModel, fillNetPair, n=int(1e4))




# example
# pred_data = {"marriage": R_seq, "median_age_marriage": A_avg.expand_as(R_seq)}

# # compute counterfactual mean divorce (mu)
# mu = link(m5_3, data=pred_data)
# mu_mean = mu.mean(0)
# mu_PI = stats.pi(mu, 0.89, dim=0)

# # simulate counterfactual divorce outcomes
# R_sim = sim(m5_3, data=pred_data, n=int(1e4))

In [153]:
# Simulated net_wt values returned by sim().
# NOTE(review): presumably one row per draw and one column per case in
# fillNet -- confirm with netPred.shape.
netPred

tensor([[270.1053, 356.0332, 604.5569,  ..., 631.4575, 315.5721, 545.8282],
        [252.6828, 430.4397, 657.4655,  ..., 641.9639, 355.4886, 627.8801],
        [241.9848, 452.2124, 511.4755,  ..., 563.9434, 282.2535, 575.4324],
        ...,
        [190.5428, 386.3636, 615.5149,  ..., 669.1935, 282.1034, 563.7313],
        [280.5052, 590.8213, 534.6596,  ..., 767.5511, 216.3159, 678.9118],
        [438.5371, 695.6119, 669.7462,  ..., 604.7857, 322.1996, 582.6583]])

In [123]:
# Replace any missing gross_wt with 0 so the column converts cleanly.
fillNet['gross_wt'] = fillNet['gross_wt'].fillna(0)

# Column -> NumPy array -> float32 torch tensor.
gross_wt_array = fillNet['gross_wt'].to_numpy()
gross_wtFILL = torch.tensor(gross_wt_array, dtype=torch.float)