# United States v. Shonubi (1997)

In [135]:
# packages

import pyro.ops.stats as stats
from pyro.infer import Predictive, SVI, Trace_ELBO
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.nn import PyroSample, PyroModule
from pyro.optim import Adam
from pyro.infer.autoguide import AutoMultivariateNormal
from torch import nn

import os
from functools import partial
import torch
import logging
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

import pyro
import pyro.distributions as dist
from pyro.infer import Predictive


# Emit bare log messages (no level/name prefix) at INFO and above.
logging.basicConfig(level=logging.INFO, format="%(message)s")

# When a CI environment variable is present, run only a 2-iteration smoke test.
smoke_test = "CI" in os.environ
num_iterations = 2 if smoke_test else 1500

In [137]:
def run_svi_inference(
    model,
    num_steps=500,
    verbose=True,
    lr=0.03,
    guide=None,
    blocked_sites=None,
    **model_kwargs,
):
    """Fit a variational guide to ``model`` with SVI and plot the loss curve.

    Args:
        model: Pyro model callable; it is invoked as ``model(**model_kwargs)``.
        num_steps: Number of optimisation steps to take.
        verbose: When true, print the loss at step 1 and at every 100th step.
        lr: Learning rate handed to ``torch.optim.Adam``.
        guide: Guide to train. When ``None`` an ``AutoMultivariateNormal`` is
            built over ``model``.
        blocked_sites: Names of sites to hide from the automatically built
            guide. ``None`` (the default) hides nothing. Ignored when ``guide``
            is supplied.
        **model_kwargs: Keyword arguments forwarded to the model and guide on
            every ELBO evaluation.

    Returns:
        The trained guide.
    """
    if guide is None:
        # poutine.block treats hide=None as "no selection given" and then
        # blocks every site, which would leave the autoguide with nothing to
        # fit. An empty list hides nothing.
        hidden_sites = [] if blocked_sites is None else blocked_sites
        guide = AutoMultivariateNormal(pyro.poutine.block(model, hide=hidden_sites))
    elbo = pyro.infer.Trace_ELBO()(model, guide)

    # Evaluate once before building the optimiser; presumably this is what
    # creates the guide's parameters so that elbo.parameters() is non-empty.
    elbo(**model_kwargs)
    adam = torch.optim.Adam(elbo.parameters(), lr=lr)

    losses = []
    print(f"Running SVI for {num_steps} steps...")
    for step in range(1, num_steps + 1):
        adam.zero_grad()
        loss = elbo(**model_kwargs)
        loss.backward()
        loss_value = loss.item()
        losses.append(loss_value)
        adam.step()
        # `verbose` has to gate both conditions; the earlier form
        # `A or B & verbose` only gated the step == 1 case.
        if verbose and (step == 1 or step % 100 == 0):
            print("[iteration %04d] loss: %.4f" % (step, loss_value))

    plt.figure()
    plt.plot(losses, label="ELBO loss")
    sns.despine()
    plt.title("ELBO Loss")
    if losses:
        # max() of an empty list raises, which happens when num_steps == 0.
        plt.ylim(0, max(losses))
    plt.legend()
    plt.show()

    return guide
   



# def get_samples(
#     distance,
#     proximity,
#     how_far,
#     model= model_sigmavar_proximity,
#     num_svi_iters=num_svi_iters,
#     num_samples=num_samples,
# ):
#     guide = AutoMultivariateNormal(model, init_loc_fn=init_to_mean)
#     svi = SVI(
#         model_sigmavar_proximity, guide, optim.Adam({"lr": 0.01}), loss=Trace_ELBO()
#     )

#     iterations = []
#     losses = []

#     logging.info(f"Starting SVI inference with {num_svi_iters} iterations.")
#     start_time = time.time()
#     pyro.clear_param_store()
#     for i in range(num_svi_iters):
#         elbo = svi.step(distance, proximity, how_far)
#         iterations.append(i)
#         losses.append(elbo)
#         if i % 50 == 0:
#             logging.info("Elbo loss: {}".format(elbo))
#     end_time = time.time()
#     elapsed_time = end_time - start_time
#     logging.info("SVI inference completed in %.2f seconds.", elapsed_time)

#     # uncomment if you want to see the ELBO loss plots
#     # fig = px.line(x=iterations, y=losses, title="ELBO loss", template="presentation")
#     # labels = {"iterations": "iteration", "losses": "loss"}
#     # fig.update_xaxes(showgrid=False, title_text=labels["iterations"])
#     # fig.update_yaxes(showgrid=False, title_text=labels["losses"])
#     # fig.update_layout(width=700)
#     # fig.show()

#     predictive = Predictive(model, guide=guide, num_samples=num_samples)

#     proximity_svi = {
#         k: v.flatten().reshape(num_samples, -1).detach().cpu().numpy()
#         for k, v in predictive(distance, proximity, how_far).items()
#         if k != "obs"
#     }

#     print("SVI-based coefficient marginals:")
#     for site, values in ft.summary(proximity_svi, ["d", "p"]).items():
#         print("Site: {}".format(site))
#         print(values, "\n")

#     return {
#         "svi_samples": proximity_svi,
#         "svi_guide": guide,
#         "svi_predictive": predictive,
#     }
    
    
    
    
# svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

In [144]:
def summary(samples):
    """Build a one-table-per-site summary of sampled values.

    Args:
        samples: Mapping from site name to the values drawn for that site
            (anything ``pd.DataFrame`` accepts).

    Returns:
        Dict mapping each site name to a DataFrame holding the mean, standard
        deviation and the 5/25/50/75/95% quantiles of that site's values.
    """
    quantiles = [0.05, 0.25, 0.5, 0.75, 0.95]
    kept_columns = ["mean", "std", "5%", "25%", "50%", "75%", "95%"]
    return {
        name: pd.DataFrame(draws).describe(percentiles=quantiles).T[kept_columns]
        for name, draws in samples.items()
    }

def precis(samples):
    """Print the ``summary`` table of every site in ``samples``.

    Each site is printed as a ``Site: <name>`` header line followed by its
    summary table and a blank separator.
    """
    for site_name, table in summary(samples).items():
        print(f"Site: {site_name}")
        print(table, "\n")

In [145]:
# Load the case dataset from the working directory and preview its first rows.

sh = pd.read_csv("ShonubiCaseDataset.csv")
sh.head(n=5)

Unnamed: 0,obs,dataset,balloons,gross_wt,net_wt,purity,age_yrs,gender
0,1.0,1.0,79.0,742.4,503.2,0.51,,
1,2.0,1.0,90.0,901.7,576.9,0.32,,
2,3.0,12.0,90.0,800.2,573.3,0.85,38.0,1.0
3,4.0,12.0,1.0,706.2,439.8,0.75,41.0,1.0
4,5.0,1.0,5.0,72.2,23.1,0.62,,


FIRST: A STRATEGY THAT USES POSTERIOR MEANS ONLY


Note: there are 107 cases that have a `gross_wt` but no `net_wt`. Let's first
predict the missing net weights, then fill them in so that we can predict weight based on the number of balloons.

In [146]:
# Keep only the cases where both gross and net weight were recorded.
compData = sh.loc[:, ["gross_wt", "net_wt"]].dropna()

# Two-column float tensor: column 0 is gross weight, column 1 is net weight.
data = torch.tensor(compData.to_numpy(), dtype=torch.float)

gross_wt = data[:, 0]
net_wt = data[:, 1]
    
    
def model(gross_wt, net_wt=None):
    """Regress net weight on gross weight through the origin.

    The mean of each observation is ``bA * gross_wt`` with a shared Normal
    noise scale ``sigma``.

    Args:
        gross_wt: 1-D tensor of gross weights (its length sets the plate size).
        net_wt: Observed net weights aligned with ``gross_wt``. When ``None``
            (the default) the "obs" site is sampled instead of conditioned on,
            which is what is needed to predict net weight for cases where only
            the gross weight is known.
    """
    # Slope prior: Normal centred at 0.8 with standard deviation 0.3.
    b_a = pyro.sample("bA", dist.Normal(0.8, 0.3))
    # Noise scale prior: flat on [0, 150].
    sigma = pyro.sample("sigma", dist.Uniform(0., 150.))
    mean = b_a * gross_wt
    with pyro.plate("data", len(gross_wt)):
        pyro.sample("obs", dist.Normal(mean, sigma), obs=net_wt)
        
        

In [147]:
# Start from an empty parameter store so that re-running this cell and the
# training loop fits a fresh guide instead of resuming from parameters left
# over from an earlier run in the same kernel.
pyro.clear_param_store()

# Mean-field Normal guide over the model's latent sites.
guide = AutoDiagonalNormal(model)

adam = pyro.optim.Adam({"lr": 0.03})
svi = SVI(model, guide, adam, loss=Trace_ELBO())

In [148]:
# pyro.clear_param_store()
# for j in range(num_iterations):
#     # calculate the loss and take a gradient step
#     loss = svi.step(gross_wt, net_wt)
#     if j % 100 == 0:
#         print("[iteration %04d] loss: %.4f" % (j + 1, loss / len(data)))
        
        
# Optimise the guide, logging the value returned by svi.step every 500 iterations.
for i in range(num_iterations):
    elbo = svi.step(gross_wt, net_wt)
    if not i % 500:
        logging.info(f"Elbo loss: {elbo}")

Elbo loss: 776.4009214639664


Elbo loss: 776.560912668705
Elbo loss: 775.7735349535942


In [149]:
# Draw samples from the fitted guide and keep every returned site except "obs",
# flattening each one to a 1-D NumPy array of length num_samples.
num_samples = 1000
predictive = Predictive(model, guide=guide, num_samples=num_samples)
svi_samples = {
    site: draws.reshape(num_samples).detach().cpu().numpy()
    for site, draws in predictive(gross_wt, net_wt).items()
    if site != "obs"
}

In [150]:
# Print the per-site summary tables for the sampled sites.
precis(svi_samples)

Site: bA
       mean       std        5%       25%       50%       75%       95%
0  0.608505  0.016319  0.581728  0.597156  0.608675  0.620185  0.635263 

Site: sigma
         mean       std          5%         25%        50%         75%  \
0  125.463554  5.080407  116.717117  121.811892  125.95512  129.343185   

         95%  
0  132.89352   



# Experimenting with creating predictions

In [152]:
# Cases with a recorded gross weight but a missing net weight: the rows to fill in.
# Select first and copy afterwards so the assignment below writes to an
# independent frame rather than to a slice of a temporary copy.
fillNet = sh.loc[sh['gross_wt'].notna() & sh['net_wt'].isna()].copy()
# Placeholder zeros so the column holds no NaNs; these values are NOT passed
# to the model below.
fillNet.loc[:, 'net_wt'] = fillNet['net_wt'].fillna(0)

gross_wtFILL = torch.tensor(fillNet["gross_wt"].values, dtype=torch.float)
net_wtFILL = torch.tensor(fillNet["net_wt"].values, dtype=torch.float)

fillNetPair = {"gross_wt": gross_wtFILL}

# Use the trained model to fill in the missing values. `sim` and `netModel`
# are not defined anywhere in this notebook, so draw from the fitted
# model/guide pair with Predictive instead. Passing net_wt=None leaves the
# "obs" site unobserved, so it is sampled for every row in gross_wtFILL.
netPredictive = Predictive(
    model, guide=guide, num_samples=int(1e4), return_sites=["obs"]
)
# Expected shape: (num_samples, number of rows in fillNet) -- TODO confirm.
netPred = netPredictive(net_wt=None, **fillNetPair)["obs"]




# example
# pred_data = {"marriage": R_seq, "median_age_marriage": A_avg.expand_as(R_seq)}

# # compute counterfactual mean divorce (mu)
# mu = link(m5_3, data=pred_data)
# mu_mean = mu.mean(0)
# mu_PI = stats.pi(mu, 0.89, dim=0)

# # simulate counterfactual divorce outcomes
# R_sim = sim(m5_3, data=pred_data, n=int(1e4))