# Linear regression

Based on Peter Roelants' post: https://peterroelants.github.io/posts/linear-regression-four-ways/


- https://towardsdatascience.com/introduction-to-bayesian-linear-regression-e66e60791ea7
- http://num.pyro.ai/en/latest/tutorials/bayesian_hierarchical_linear_regression.html
- https://www.hellocybernetics.tech/entry/2020/02/23/034551
- https://arviz-devs.github.io/arviz/user_guide/numpyro_refitting.html


In [None]:
# Notebook display configuration: render matplotlib figures inline as SVG
%matplotlib inline
%config InlineBackend.figure_format = 'svg'

In [None]:
import sys
import warnings

import numpy as np

import jax
import jax.numpy as jnp

import numpyro
from numpyro.infer import MCMC, NUTS, Predictive
import numpyro.distributions as dist

import matplotlib
import matplotlib.pyplot as plt
from matplotlib import cm  # Colormaps
import seaborn as sns
import arviz as az
from tqdm import tqdm_notebook as tqdm
from IPython.display import display

# Global plot styling. The ArviZ style sheet is applied after the seaborn
# style, so the order of these calls matters where the two overlap.
sns.set_style('darkgrid')
# Default probability mass used by ArviZ for HDI summaries (90%).
az.rcParams['stats.hdi_prob'] = 0.90
az.style.use("arviz-darkgrid")

In [None]:
# Run on the CPU backend and request 8 host devices from NumPyro.
# NOTE(review): this presumably has to execute before JAX initialises its
# backend for the device count to take effect - confirm against NumPyro docs.
numpyro.set_platform('cpu')
numpyro.set_host_device_count(8)

In [None]:
# Root JAX PRNG key (seed 42). Cells that need randomness split it into a new
# `rng_key` and a one-off subkey `rng_key_`.
rng_key = jax.random.PRNGKey(42)

In [None]:
# Define the data
# Split off a one-off subkey and use it to seed NumPy's global RNG, so the
# synthetic data below is reproducible from the notebook's JAX key.
rng_key, rng_key_ = jax.random.split(rng_key)
np.random.seed(rng_key_)
# Generate random data
n = 50  # Number of samples
# Underlying linear relation
m = 2.32  # slope
b = 4.11  # bias


def fn(x_):
    """Return the noise-free linear response ``x_ * m + b``."""
    return x_ * m + b


# Noise
e_std = 0.5  # Standard deviation of the noise
# The noise is drawn before the features; keep this order so the generated
# data stays the same for a given seed.
err = e_std * np.random.randn(n)  # Noise
# Features and output
x_data = np.random.uniform(-1, 1, n)  # Independent variable x
y_data = fn(x_data) + err  # Dependent variable

# Show data: noisy observations together with the true underlying line
plt.figure(figsize=(7, 4), dpi=100)
plt.scatter(x_data, y_data, label='data: $(x,y)$', color='tab:blue')
plt.plot(
    [-1, 1], [fn(-1), fn(1)], color='black', linestyle='-',
    label=f'$y = {b:.2f} + {m:.2f} x$')
plt.xlim((-1, 1))
plt.xlabel('$x$')
plt.ylabel('$y$')
plt.title('Noisy data samples from linear line')
plt.legend()
plt.show()

In [None]:
def plot_predictions(x_samples, predictions, name):
    """Plot the regression-mean and predictive distributions of a set of draws.

    Two stacked panels are drawn. The top panel summarises
    ``predictions['y_mu']`` (the regression mean), the bottom panel
    ``predictions['y']`` (simulated observations). Each panel shows some
    individual draws, the mean over draws, and central 90% and 99% bands.

    Args:
        x_samples: x locations used as the x coordinate of every curve.
        predictions: Mapping with entries ``'y_mu'`` and ``'y'``. Each entry
            is reduced over its first axis, so that axis is taken to index
            the draws.
        name: Prefix for the panel titles, e.g. ``'Prior'`` or ``'Posterior'``.
    """
    # Percentile pairs bounding the central 90% and 99% intervals, so that the
    # shaded bands actually match their legend labels.
    interval_q = np.array([5., 95., 0.5, 99.5])

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(7, 8), dpi=100)

    # Top panel: distribution of the regression mean y_mu.
    y_mu = predictions['y_mu']
    y_mu_mean = jnp.mean(y_mu, 0)
    y_mu_pct = jnp.percentile(y_mu, q=interval_q, axis=0)
    # Show at most 10 individual draws; only the first one gets a legend entry.
    for i, yi in enumerate(y_mu[:10]):
        ax1.plot(
            x_samples, yi, color='tab:gray', linestyle='-', alpha=0.5,
            label='samples' if i == 0 else None)
    ax1.plot(
        x_samples, y_mu_mean, color='tab:blue', linestyle='-',
        label=r'mean($\mu_y$)', linewidth=2)
    ax1.fill_between(
        x_samples, y_mu_pct[0], y_mu_pct[1], color='tab:blue', alpha=0.2,
        label=r'$\mu_y \; 90\%$')
    ax1.fill_between(
        x_samples, y_mu_pct[2], y_mu_pct[3], color='tab:blue', alpha=0.1,
        label=r'$\mu_y \; 99\%$')
    ax1.set_xlim((-1, 1))
    ax1.set_xlabel('$x$')
    ax1.set_ylabel('$y$')
    ax1.set_title(f'{name} parameter distribution')
    ax1.legend(loc='lower right')

    # Bottom panel: distribution of the simulated observations y.
    y = predictions['y']
    y_mean = jnp.mean(y, 0)
    y_pct = jnp.percentile(y, q=interval_q, axis=0)
    # Show at most 200 draws, nearly transparent so that density is visible.
    for i, yi in enumerate(y[:200]):
        ax2.plot(
            x_samples, yi, color='tab:blue', marker='o', alpha=0.02,
            label='samples' if i == 0 else None)
    ax2.plot(x_samples, y_mean, 'k-', label='mean($y$)')
    ax2.fill_between(
        x_samples, y_pct[0], y_pct[1], color='k', alpha=0.2,
        label=r'$y \; 90\%$')
    ax2.fill_between(
        x_samples, y_pct[2], y_pct[3], color='k', alpha=0.1,
        label=r'$y \; 99\%$')
    ax2.set_xlim((-1, 1))
    ax2.set_xlabel('$x$')
    ax2.set_ylabel('$y$')
    ax2.set_title(f'{name} predictive distribution')
    ax2.legend(loc='lower right')
    plt.show()

$$
y_i \sim \mathcal{N}(\theta_0 + \theta_1 x_i, \sigma^2) \quad (i = 1, \ldots, n)
$$

In [None]:
def model(x, y):
    """Bayesian linear regression: ``y ~ Normal(theta_0 + theta_1 * x, y_sigma)``.

    Args:
        x: Values of the independent variable.
        y: Observed responses passed as ``obs`` to the ``'y'`` site; the
            notebook passes ``None`` when it wants ``'y'`` to be sampled.
    """
    # Intercept and slope share the same Normal(0, 10) prior. The sample
    # statements are kept in their original order (theta_0, theta_1, y_sigma).
    coefficient_prior = dist.Normal(0., 10.)
    intercept = numpyro.sample('theta_0', coefficient_prior)
    slope = numpyro.sample('theta_1', coefficient_prior)
    noise_scale = numpyro.sample('y_sigma', dist.Exponential(1.))
    # Register the regression mean under the name 'y_mu'.
    regression_mean = numpyro.deterministic('y_mu', intercept + slope * x)
    numpyro.sample('y', dist.Normal(regression_mean, noise_scale), obs=y)

In [None]:
# Fresh subkey for the prior predictive draws.
rng_key, rng_key_ = jax.random.split(rng_key)

# Evenly spaced grid of 100 x values on [-1, 1] at which to evaluate the model.
x_samples = np.linspace(-1, 1, 100)
num_prior_predictive_samples = 1000
# No posterior samples are passed to Predictive here, only a number of draws;
# y=None means the model's 'y' site receives no observations.
prior_predictive = Predictive(model, num_samples=num_prior_predictive_samples)
prior_predictions = prior_predictive(rng_key_, x=x_samples, y=None)

In [None]:
# Two-panel summary of the prior draws: 'y_mu' on top, 'y' below.
plot_predictions(x_samples, prior_predictions, 'Prior')

In [None]:
# Fresh subkey for the MCMC run.
rng_key, rng_key_ = jax.random.split(rng_key)

# 1000 warm-up iterations followed by 2000 retained draws.
# NOTE(review): NumPyro presumably applies these counts per chain - confirm.
num_warmup, num_samples = 1000, 2000

# Run NUTS.
kernel = NUTS(model)
# NOTE(review): chain_method='parallel' presumably needs at least as many
# devices as chains; 8 host devices are requested earlier in the notebook.
mcmc = MCMC(
    kernel,
    num_warmup=num_warmup,
    num_samples=num_samples,
    num_chains=4,
    chain_method='parallel'
)
# Fit the model to the generated data set.
mcmc.run(rng_key_, x=x_data, y=y_data)
mcmc.print_summary()
# Keep the draws; they are reused for the posterior predictive cell.
mcmc_samples = mcmc.get_samples()

In [None]:
# Convert the fitted MCMC object into an ArviZ object for the plots below,
# and render it in the notebook.
az_posterior = az.from_numpyro(posterior=mcmc)
display(az_posterior)

In [None]:
# One axis per parameter: three axes for the three variables listed below.
fig, ax = plt.subplots(1, 3, figsize=(12, 4))
az.plot_posterior(az_posterior, var_names=['theta_0', 'theta_1', 'y_sigma'], ax=ax)
plt.suptitle('Posterior plots', fontsize=18)
plt.show()

In [None]:
# Trace plots drawn onto a pre-built 4x2 grid of axes.
# NOTE(review): presumably one row per recorded variable (theta_0, theta_1,
# y_sigma and the deterministic y_mu) - confirm the grid matches what
# az.plot_trace expects for this posterior.
fig, axes = plt.subplots(4, 2, figsize=(12, 8))
az.plot_trace(az_posterior, compact=True, axes=axes)
plt.suptitle('Trace plots', fontsize=18)
plt.show()

In [None]:
# fig, ax = plt.subplots(1, 3, figsize=(14, 4), dpi=70)
# az.plot_posterior(az_posterior, var_names=['theta_0', 'theta_1', 'y_sigma'], ax=ax)
# plt.suptitle('Posterior plots', fontsize=18)
# plt.tight_layout()
# plt.show()

In [None]:
# Fresh subkey for the posterior predictive draws.
rng_key, rng_key_ = jax.random.split(rng_key)

# Same evaluation grid as used for the prior predictive plots.
x_samples = np.linspace(-1, 1, 100)
# Predict on the grid using the MCMC draws; y=None means the model's 'y' site
# receives no observations.
# NOTE(review): mcmc_samples presumably also holds the deterministic 'y_mu'
# evaluated on x_data (50 points) - confirm Predictive recomputes it for the
# 100-point x_samples grid rather than reusing the stored values.
posterior_predictive = Predictive(model, posterior_samples=mcmc_samples)
posterior_predictions = posterior_predictive(rng_key_, x=x_samples, y=None)

In [None]:
# Two-panel summary of the posterior draws: 'y_mu' on top, 'y' below.
plot_predictions(x_samples, posterior_predictions, 'Posterior')