<a href="https://colab.research.google.com/github/seanreed1111/BDA_py_demos/blob/master/pyro_logistic_regression_Bangladeshi_Wells_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Source: *Regression and Other Stories*, Gelman et al., Chapter 13

https://avehtari.github.io/ROS-Examples/Arsenic/arsenic_logistic_residuals.html

https://github.com/avehtari/ROS-Examples/tree/master/Arsenic/

In [None]:
pip install pyro-ppl

In [None]:
import os
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import statsmodels.api as sm

import torch
from torch.distributions import constraints
from torch import tensor

import pyro
import pyro.distributions as dist
from pyro.infer import SVI,Trace_ELBO
from pyro.infer.autoguide  import AutoMultivariateNormal, AutoNormal, init_to_mean
from pyro.optim import ClippedAdam

assert pyro.__version__.startswith('1.8.0')
pyro.set_rng_seed(1)
torch.manual_seed(1)

# Set matplotlib settings
%matplotlib inline
plt.style.use('default')
plt.rcParams['figure.figsize'] = [12, 8]

A research team marked wells as safe or unsafe according to their arsenic levels. People with unsafe wells were encouraged to switch. A few years later, they returned to study who had actually switched.

Inputs
- The distance to the closest known safe well
- The arsenic level of the respondent's well
- Whether any members of the household are active in community organizations
- The education level of the head of household

In [None]:
df = pd.read_csv("https://raw.githubusercontent.com/seanreed1111/RAOS-Examples/master/Arsenic/data/wells.csv")
df.head()

In [None]:
df['switch'].value_counts()[1] / df['switch'].count() * 100    # percentage of dataset that switched

In [None]:
df = df[['switch', 'arsenic', 'dist100', 'assoc', 'educ4']] #remove dist and educ
df.describe().T

Note that the maximum safe arsenic level is 0.5 on this scale, so every respondent in this dataset uses an unsafe well,
and the worst well has roughly 20× the safe level of arsenic!

In [None]:
X = df.copy() # note: data not centered

In [None]:
# Everyone in this dataset used a well considered at an unsafe arsenic level (0.5 is considered safe). 
# Let's examine the arsenic levels in the dataset

In [None]:
sns.kdeplot(data=X['arsenic'])

In [None]:
#let's try to cut the arsenic levels into buckets

In [None]:
X['arsenic'].describe()

In [None]:
quantiles = pd.qcut(X['arsenic'], q=4)
quantiles.value_counts()

In [None]:
quantiles = pd.qcut(X['arsenic'], q=4, labels=False)
quantiles.value_counts()

In [None]:
quantiles

In [None]:
deciles = pd.qcut(X['arsenic'], q=10)
deciles.value_counts()

In [None]:
deciles = pd.qcut(X['arsenic'], q=10, labels=False)
deciles.value_counts()

In [None]:
X['arsenic_quantiles'] = quantiles
X['arsenic_deciles'] = deciles

In [None]:
sns.boxplot(data=X, x='switch', y='arsenic', hue='arsenic_quantiles',palette='colorblind');

In [None]:
sns.boxplot(data=X, x='switch', y='arsenic', hue='arsenic_deciles', palette='colorblind');

In [None]:
sns.stripplot(data=X, x='switch', y='arsenic', hue='arsenic_quantiles', palette="colorblind");

In [None]:
sns.stripplot(data=X, x='switch', y='arsenic', hue='arsenic_deciles', palette="colorblind");

In [None]:
# So looking at the data, more people in the higher arsenic levels switched

Let's move on to a logistic regression of the form `switch ~ dist100 + arsenic`.

In [None]:
data = X[['dist100', 'arsenic']]
target = X['switch']

In [None]:
from sklearn.linear_model import LogisticRegression
# C = 1e9 means no L2 regularization
clf = LogisticRegression(C=1e9, random_state=0).fit(data, target)


In [None]:
# these are MLE estimates of parameters we expect to recover
print(clf.intercept_)
print(clf.coef_)

## Using statsmodels

In [None]:
import statsmodels.formula.api as smf

# Same model fitted by maximum likelihood with statsmodels, whose summary adds standard errors.
statsmod = smf.logit('switch ~ dist100 + arsenic', data=X)
result = statsmod.fit()
print(result.summary())

## Using Bayesian Regression with SVI

In [None]:
# convert data and target to torch tensors
# Convert predictors and target to float32 tensors for Pyro.
# `np.asarray` accepts both the pandas objects from the feature-selection cell
# and the tensors this cell already produced, so re-running the cell no longer
# fails with AttributeError ('Tensor' object has no attribute 'values').
data = torch.as_tensor(np.asarray(data), dtype=torch.float)
target = torch.as_tensor(np.asarray(target), dtype=torch.float)

In [None]:
data.shape, target.shape  # (rows, features) and (rows,)

In [None]:
from torch import nn
from pyro.nn import PyroSample, PyroModule

class BayesianLogisticRegression(PyroModule):
    """Bayesian logistic regression with a Bernoulli likelihood on linear logits.

    The weight (and, when enabled, the bias) of the inner linear layer get
    independent Normal(0, prior_scale) priors via PyroSample, so they appear as
    latent sites named ``linear.weight`` and ``linear.bias``.

    Args:
        in_features: number of predictor columns.
        out_features: number of linear outputs; forward() squeezes the last
            dimension, so the default of 1 gives one logit per row.
        bias: whether to include an intercept.
        prior_scale: standard deviation of the Normal priors (default 5.0).
    """

    def __init__(self, in_features, out_features=1, bias=True, prior_scale=5.):
        super().__init__()
        # Pass `bias` through so bias=False really removes the intercept;
        # otherwise nn.Linear would keep its own learnable bias parameter.
        self.linear = PyroModule[nn.Linear](in_features, out_features, bias=bias)
        if bias:
            self.linear.bias = PyroSample(
                dist.Normal(0., prior_scale).expand([out_features]).to_event(1))
        self.linear.weight = PyroSample(
            dist.Normal(0., prior_scale).expand([out_features, in_features]).to_event(2))

    def forward(self, x, y=None):
        """Return the logits; the ``obs`` site is conditioned on ``y`` when given."""
        logits = self.linear(x).squeeze(-1)

        with pyro.plate("data", x.shape[0]):
            # No debug print here: it fired on every SVI and MCMC step.
            pyro.sample("obs", dist.Bernoulli(logits=logits), obs=y)
        return logits

In [None]:
data.shape[1]  # number of predictor columns fed to the linear layer

2

In [None]:
model = BayesianLogisticRegression(data.size(1)) 

In [None]:
from pyro.infer.autoguide import AutoMultivariateNormal

guide = AutoMultivariateNormal(model, init_loc_fn=init_to_mean)

In [None]:
def train(model, guide, lr=0.01, n_steps=4000, x=None, y=None, log_every=500):
    """Fit `guide` to `model` with SVI, ClippedAdam and exponential learning-rate decay.

    Args:
        model, guide: Pyro model and guide, both called as f(x, y).
        lr: initial learning rate.
        n_steps: number of SVI steps; must be >= 1.
        x, y: predictors and target passed to every `svi.step`. When omitted they
            default to the notebook-level `data` / `target` tensors.
        log_every: print the ELBO loss every `log_every` steps.

    Raises:
        ValueError: if n_steps < 1 (the decay rate would divide by zero).
    """
    if n_steps < 1:
        raise ValueError(f"n_steps must be >= 1, got {n_steps}")
    if x is None:
        x = data
    if y is None:
        y = target

    pyro.set_rng_seed(1)
    pyro.clear_param_store()

    gamma = 0.01  # final learning rate will be gamma * initial_lr
    lrd = gamma ** (1 / n_steps)  # per-step decay: lr * lrd**n_steps == gamma * lr
    adam = ClippedAdam({'lr': lr, 'lrd': lrd})

    svi = SVI(model, guide, adam, loss=Trace_ELBO())

    for i in range(n_steps):
        elbo = svi.step(x, y)
        if i % log_every == 0:
            print(f"Elbo loss: {elbo}")
    print(f"Final Elbo loss: {elbo}")

In [None]:
%%time
train(model, guide)

tensor([1., 1., 0.,  ..., 0., 0., 1.])


AssertionError: ignored

In [None]:
from pyro.infer import Predictive

num_samples = 1000
predictive = Predictive(model, guide=guide, num_samples=num_samples)

# Draw guide samples for every site except the observed `obs`, flatten each
# to (num_samples, -1) and move it to NumPy.
svi_samples = {
    site_name: site_draws.reshape((num_samples, -1)).detach().cpu().numpy()
    for site_name, site_draws in predictive(data, target).items()
    if site_name != "obs"
}

In [None]:
svi_samples.keys()

In [None]:
svi_samples['linear.bias'].mean()

In [None]:
svi_samples['linear.weight'].mean(axis=0)

In [None]:
guide.quantiles([0.05,0.50,0.95])

In [None]:
print(clf.intercept_)
print(clf.coef_)

In [None]:
sns.kdeplot(data = svi_samples['linear.bias']);

In [None]:
sns.kdeplot(data = svi_samples['linear.weight']);

In [None]:
# So all three methods (sklearn, statsmodels and SVI) seem to agree on the central tendencies of the coefficients.

# What about MCMC?

In [None]:
from pyro.infer import MCMC, NUTS
nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_samples=2000, warmup_steps=500)

In [None]:
%%time
mcmc.run(data, target)

In [None]:
model

In [None]:
hmc_samples = {k: v.detach().cpu().numpy() for k, v in mcmc.get_samples().items()}

In [None]:
hmc_samples.keys()

In [None]:
np.median(hmc_samples['linear.bias'])

In [None]:
sns.kdeplot(data=hmc_samples['linear.bias']);

In [None]:
hmc_samples['linear.weight'].shape

In [None]:
print(np.median(hmc_samples['linear.weight'][:,0,0]))
sns.kdeplot(data=hmc_samples['linear.weight'][:,0,0]);

In [None]:
print(np.median(hmc_samples['linear.weight'][:,0,1]))
sns.kdeplot(data=hmc_samples['linear.weight'][:,0,1]);

In [None]:
# sklearn estimate
print(clf.intercept_)
print(clf.coef_)


In [None]:
import arviz as az

# Convert the NUTS run to ArviZ InferenceData and plot un-compacted traces.
az_data = az.from_pyro(posterior=mcmc)
az.plot_trace(data=az_data, compact=False)
plt.tight_layout()