In [None]:
%load_ext autoreload
%autoreload 2

import logging
logging.basicConfig(format='%(message)s', level=logging.INFO)
import pickle


import pyro
pyro.enable_validation(True)
import pyro.distributions as dist
from pyro.nn import PyroSample, PyroModule
from torch import nn


In [None]:
import sys

# Make the project root importable (for the `src` package).
# Guarded so that re-running this cell does not keep appending duplicate entries.
if './' not in sys.path:
    sys.path.append('./')

In [None]:
from src.utils import SviPredictive
from src.visualization import plot_predictions

In [None]:
# Load the pickled dataset for this notebook.
# NOTE(review): pickle.load can execute arbitrary code - only open trusted files.
with open('./data/logistic-regression.pkl', 'rb') as f:
    data = pickle.load(f)

# Unpack the train/test inputs and targets by key.
x_train, y_train, x_test, y_test = (
    data[key] for key in ('x_train', 'y_train', 'x_test', 'y_test')
)

# Bayesian Logistic Regression


### [Logistic Regression](https://en.wikipedia.org/wiki/Logistic_regression)
An example of a [Generalized Linear Model](https://en.wikipedia.org/wiki/Generalized_linear_model) (GLM) with the [logit](https://en.wikipedia.org/wiki/Logit) as its link function:

$$ \mathrm{logit}(p) = \log \left( \frac{p}{1 - p} \right) $$

The inverse of the logit is the [logistic function](https://en.wikipedia.org/wiki/Logistic_function):

$$ \mathrm{logit}^{-1}(\alpha) = \frac{1}{1 + \exp{(-\alpha)}}$$

### Bayesian Logistic Regression

$$ y \sim \mathrm{Bernoulli}(p) $$

$$ \mathrm{logit}(p) = \beta^\intercal X $$

In [None]:
class BayesianLogisticRegression(PyroModule):
    """Logistic regression whose linear weight and bias are given Normal priors.

    This version turns the linear output into a probability with an explicit
    ``nn.Sigmoid`` before passing it to the Bernoulli likelihood.
    """

    def __init__(self, in_features, out_features):
        """Build the linear layer and replace its parameters with priors.

        Args:
            in_features: input size of the underlying ``nn.Linear``.
            out_features: output size of the underlying ``nn.Linear``.
        """
        super().__init__()
        self.linear = PyroModule[nn.Linear](in_features, out_features)

        # Normal(0, 1) prior over the full (out_features, in_features) weight matrix.
        weight_prior = dist.Normal(0., 1.).expand([out_features, in_features]).to_event(2)
        self.linear.weight = PyroSample(weight_prior)

        # Normal prior with scale 10 over the bias vector.
        bias_prior = dist.Normal(0., 10.).expand([out_features]).to_event(1)
        self.linear.bias = PyroSample(bias_prior)

        self.sigmoid = nn.Sigmoid()  # logistic function

    def forward(self, x, y=None):
        """Run the linear layer and register the Bernoulli ``"obs"`` site.

        Args:
            x: inputs; ``x.shape[0]`` is used as the size of the ``"data"`` plate.
            y: optional observations passed as ``obs`` to the ``"obs"`` site.

        Returns:
            The linear layer's output with its last dimension squeezed,
            i.e. the value before the sigmoid is applied.
        """
        logits = self.linear(x).squeeze(-1)
        with pyro.plate("data", x.shape[0]):
            probs = self.sigmoid(logits)
            pyro.sample("obs", dist.Bernoulli(probs), obs=y)
        return logits

But `Sigmoid` is deterministic and has no parameters. Can't `Bernoulli` handle the raw `mean` directly? Yes, it can! Just pass it as `logits` instead of `probs`.

In [None]:
class BayesianLogisticRegression(PyroModule):
    """Logistic regression whose linear weight and bias are given Normal priors.

    The linear output is passed to the Bernoulli likelihood as ``logits``,
    so no separate sigmoid module is needed.
    """

    def __init__(self, in_features, out_features):
        """Build the linear layer and replace its parameters with priors.

        Args:
            in_features: input size of the underlying ``nn.Linear``.
            out_features: output size of the underlying ``nn.Linear``.
        """
        super().__init__()
        self.linear = PyroModule[nn.Linear](in_features, out_features)

        # Normal(0, 1) prior over the full (out_features, in_features) weight matrix.
        weight_prior = dist.Normal(0., 1.).expand([out_features, in_features]).to_event(2)
        self.linear.weight = PyroSample(weight_prior)

        # Normal prior with scale 10 over the bias vector.
        bias_prior = dist.Normal(0., 10.).expand([out_features]).to_event(1)
        self.linear.bias = PyroSample(bias_prior)

    def forward(self, x, y=None):
        """Run the linear layer and register the Bernoulli ``"obs"`` site.

        Args:
            x: inputs; ``x.shape[0]`` is used as the size of the ``"data"`` plate.
            y: optional observations passed as ``obs`` to the ``"obs"`` site.

        Returns:
            The linear layer's output with its last dimension squeezed
            (the logits given to the Bernoulli).
        """
        logits = self.linear(x).squeeze(-1)
        with pyro.plate("data", x.shape[0]):
            pyro.sample("obs", dist.Bernoulli(logits=logits), obs=y)
        return logits

In [None]:
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoDiagonalNormal

# Size the linear layer from the training inputs instead of hard-coding the
# number of feature columns (assumes x_train's last dimension holds the features).
model = BayesianLogisticRegression(x_train.shape[-1], 1)
guide = AutoDiagonalNormal(model)

adam = pyro.optim.Adam({"lr": 0.03})
svi = SVI(model, guide, adam, loss=Trace_ELBO())

num_iterations = 5_000
num_train = len(x_train)  # number of training examples, used to normalise the printed loss

# Start from an empty parameter store so re-running this cell trains from scratch.
pyro.clear_param_store()
for j in range(num_iterations):
    loss = svi.step(x_train, y_train)
    if j % 500 == 0:
        # Normalise by the number of training examples. `data` is the dict
        # loaded from the pickle, so len(data) would be its number of keys.
        print("[iteration %04d] loss: %.4f" % (j + 1, loss / num_train))

In [None]:
# Display the contents of Pyro's parameter store as a plain {name: value} dict.
{name: value for name, value in pyro.get_param_store().items()}

In [None]:
# Number of samples requested from the predictive helper.
num_samples = 1_000

# Predictive wrapper around the model and guide; it is asked to return
# the "obs" site and the model's return value ("_RETURN").
svi_predictive = SviPredictive(
    model, guide=guide, num_samples=num_samples, return_sites=('obs', '_RETURN')
)

# Evaluate on the test inputs; the result is shown as the cell output.
svi_predictive(x_test)

In [None]:
# Plot configuration handed to plot_predictions: which columns to use and how to label them.
properties = dict(
    x=9,
    x_label="Percent of sugar",
    y_label="Average output of the model",
    y_labels={
        0: 'Has no chocolate (0)',
        1: 'Has chocolate (1)',
    },
    category=7,
    category_labels={
        0: "Is not a bar",
        1: "Is bar",
    },
)

# Predictors to draw, keyed by the label to display.
predictors = dict(SVI=svi_predictive)

# NOTE(review): this rebinds `data`, shadowing the dataset dict loaded from the
# pickle earlier in the notebook; cells above that read `data` will see this
# two-key dict if they are re-run afterwards.
data = dict(x=x_test, y=y_test)

plot_predictions(data, predictors, properties, regression='logistic')