# Linear regression with SVI

Resources:
- http://pyro.ai/examples/svi_part_i.html
- http://pyro.ai/examples/svi_part_ii.html
- http://pyro.ai/examples/svi_part_iii.html
- http://pyro.ai/examples/svi_part_iv.html
- http://num.pyro.ai/en/stable/svi.html

Discourse thread:
- https://forum.pyro.ai/t/large-variance-in-svi-losses-to-be-expected/3435

In [None]:
# Notebook plotting setup: render matplotlib figures inline, in SVG format
%matplotlib inline
%config InlineBackend.figure_format = 'svg'

In [None]:
import sys
import warnings

import numpy as np

import jax
import jax.numpy as jnp
from jax.experimental import optimizers

import numpyro
from numpyro.infer import SVI, Trace_ELBO, TraceMeanField_ELBO, Predictive
import numpyro.distributions as dist
from numpyro import handlers

import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Use seaborn's "darkgrid" theme for the figures drawn in later cells
sns.set_style(style='darkgrid')

In [None]:
# Run JAX on the CPU and expose 8 host devices.
# NOTE: numpyro documents that these settings only take effect if applied
# before JAX initializes its backend, i.e. before any JAX computation.
numpyro.set_platform(platform='cpu')
numpyro.set_host_device_count(8)

In [None]:
# Parameters used to simulate the dataset; the fitted posterior is compared against them.
ground_truth_params = dict(
    slope=2.32,
    intercept=4.11,
    noise_std=0.5,
)

# Create Dataset

In [None]:
# Simulate n noisy observations of the line y = intercept + slope * x
np.random.seed(42)  # fixed seed so the dataset is reproducible
n = 51  # number of samples
slope_true = ground_truth_params["slope"]
intercept_true = ground_truth_params["intercept"]


def fn(x_):
    """Noise-free ground-truth line evaluated at ``x_``."""
    return x_ * slope_true + intercept_true


# Gaussian observation noise (drawn before x, so the RNG stream order is fixed)
err = ground_truth_params["noise_std"] * np.random.randn(n)
# Inputs x ~ Uniform(-1, 1) and the corresponding noisy targets y
x_data = np.random.uniform(-1., 1., n)
y_data = fn(x_data) + err

# Plot the samples together with the true line over the observed x-range
x_bound = (float(x_data.min()), float(x_data.max()))
plt.figure(figsize=(7, 4), dpi=100)
plt.scatter(x_data, y_data, label='data: $(x,y)$', color='tab:blue')
plt.plot(
    x_bound,
    [fn(x_end) for x_end in x_bound],
    color='black',
    linestyle='-',
    label=f'$y = {intercept_true:.2f} + {slope_true:.2f} x$',
)
plt.xlim(x_bound)
plt.xlabel('$x$')
plt.ylabel('$y$')
plt.title('Noisy data samples from linear line')
plt.legend()
plt.show()

# Define model and variational distribution

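$$
\text{slope} \sim \mathcal{N}(0, 10), \quad \text{intercept} \sim \mathcal{N}(0, 10), \quad \sigma \sim \text{Exponential}(1) \\
\mu_i = \text{intercept} + \text{slope} \cdot x_i \\
y_i \sim \mathcal{N}(\mu_i, \sigma) \quad (i = 1, \ldots, n)
$$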

In [None]:
def model(x, y):
    """Bayesian linear regression: y ~ Normal(intercept + slope * x, noise_std).

    Args:
        x: array of inputs; one observation per entry along axis 0
            (presumably 1-D).
        y: observed targets aligned with ``x``, or ``None`` to sample ``y``
            instead of conditioning on it (as done for posterior predictions).
    """
    # Broad Normal(0, 10) priors on the line, Exponential(1) (positive) prior on the noise
    slope = numpyro.sample('slope', dist.Normal(0., 10.))
    intercept = numpyro.sample('intercept', dist.Normal(0., 10.))
    noise_std = numpyro.sample('noise_std', dist.Exponential(1.))

    num_obs = x.shape[0]
    with numpyro.plate('obs', num_obs):
        mean_line = intercept + slope * x
        # Record the noise-free line as a deterministic site so Predictive returns it
        y_loc = numpyro.deterministic('y_loc', mean_line)
        numpyro.sample('y', dist.Normal(y_loc, noise_std), obs=y)

In [None]:
def guide(x, y):
    """Mean-field variational guide for ``model``.

    Each latent site gets its own independent factor with learnable
    parameters: Normal for ``slope`` and ``intercept``, and LogNormal for
    ``noise_std``, whose samples are positive like those of its Exponential
    prior. ``x`` and ``y`` are unused; they are accepted so the guide takes
    the same arguments as the model.
    """
    positive = dist.constraints.positive

    # q(slope) = Normal(slope_loc, slope_scale), starting with a small scale
    slope_mu = numpyro.param("slope_loc", 0.)
    slope_sigma = numpyro.param("slope_scale", 0.01, constraint=positive)
    numpyro.sample('slope', dist.Normal(slope_mu, slope_sigma))

    # q(intercept) = Normal(intercept_loc, intercept_scale)
    icpt_mu = numpyro.param("intercept_loc", 0.)
    icpt_sigma = numpyro.param("intercept_scale", 0.01, constraint=positive)
    numpyro.sample('intercept', dist.Normal(icpt_mu, icpt_sigma))

    # q(noise_std) = LogNormal(noise_std_log_loc, noise_std_scale): log(noise_std) is Normal
    log_noise_mu = numpyro.param("noise_std_log_loc", 0.1)
    log_noise_sigma = numpyro.param("noise_std_scale", 0.01, constraint=positive)
    numpyro.sample('noise_std', dist.LogNormal(log_noise_mu, log_noise_sigma))

Use a `LogNormal` guide distribution for `noise_std`. Like the `Exponential` prior, it only produces positive values, but using an `Exponential` guide leads to high variance.
Exponential guide that leads to high variance:
```
noise_std_rate = numpyro.param("noise_std_rate", 1., constraint=dist.constraints.positive)
noise_std = numpyro.sample('noise_std', dist.Exponential(noise_std_rate))
```

In [None]:
# Histogram of 1000 draws from a LogNormal with hard-coded loc/scale.
# NOTE(review): these look like fitted noise_std guide parameters from an earlier
# run (median exp(-0.72457) ~= 0.48, close to the true noise_std of 0.5) -- confirm.
lognormal_key = jax.random.PRNGKey(42)
_s = dist.LogNormal(loc=-0.72457, scale=0.13578562).sample(lognormal_key, sample_shape=(1000,))
sns.histplot(_s)
plt.show()

## Fit SVI

In [None]:
# Learning-rate schedule: cosine annealing from lr_max (step 0) down to lr_min (step num_steps)
def cosine_annealing(lr_min, lr_max, num_steps, i):
    """Return the cosine-annealed learning rate at step ``i``.

    The value is ``lr_max`` at ``i == 0`` and ``lr_min`` at ``i == num_steps``.
    ``i`` may be a scalar or an array of step indices.
    """
    phase = jnp.pi * i / num_steps
    half_range = 0.5 * (lr_max - lr_min)
    return lr_min + half_range * (1 + jnp.cos(phase))


lr_max = 2e-3
lr_min = 1e-4
num_steps = 5000

# Tabulate the schedule once for every optimizer step
iterations = jnp.arange(num_steps)
lr_steps = cosine_annealing(lr_min, lr_max, num_steps, iterations)


def lr_schedule(idx):
    """Return the learning rate for optimizer step ``idx`` (a lookup into ``lr_steps``)."""
    return lr_steps[idx]

In [None]:
# Use a clipped optimizer to deal with unstable gradients: per the numpyro docs,
# ClippedAdam clips gradient values to [-clip_norm, clip_norm].
# http://num.pyro.ai/en/stable/optimizers.html#clippedadam
optimizer = numpyro.optim.ClippedAdam(step_size=lr_schedule, clip_norm=1.0)

# Set up the inference algorithm. The mean-field ELBO matches the guide, which
# uses an independent factor per latent site.
svi = SVI(
    model=model,
    guide=guide,
    optim=optimizer,
    loss=TraceMeanField_ELBO(num_particles=1),
)

# Run for exactly as many steps as the learning-rate table was built for:
# lr_schedule indexes lr_steps, which has num_steps entries, so the step count
# must come from the same variable rather than a duplicated literal.
svi_result = svi.run(
    jax.random.PRNGKey(0),
    num_steps=num_steps,
    x=x_data,
    y=y_data,
)

In [None]:
# Training curve of the SVI loss. The symlog y-axis presumably keeps the large
# early losses from flattening the rest of the curve.
fig, ax = plt.subplots(1, 1, figsize=(7, 3))
ax.plot(svi_result.losses)
ax.set(title="losses", yscale="symlog")
plt.show()

In [None]:
# Draw latent-parameter samples from the fitted guide (the approximate posterior).
# x and y are forwarded to the guide, which does not use them.
guide_key = jax.random.PRNGKey(0)
svi_predictive = Predictive(guide, params=svi_result.params, num_samples=2000)
posterior_samples = svi_predictive(guide_key, x=x_data, y=y_data)

In [None]:
# One histogram per sampled site, marking the sample mean (black) and the true value (red)
fig, axs = plt.subplots(1, len(posterior_samples), figsize=(12, 4))

for ax, (param_name, param_samples) in zip(axs, posterior_samples.items()):
    sns.histplot(param_samples, kde=True, stat='probability', ax=ax)
    ax.set_xlabel(param_name)
    ax.set_title(f"Samples from {param_name!r}")
    reference_lines = {
        "mean": (np.mean(param_samples), "black"),
        "true": (ground_truth_params[param_name], "red"),
    }
    for label, (value, color) in reference_lines.items():
        ax.axvline(value, color=color, label=label)
# Every panel carries the same two labelled lines, so the last panel's handles
# are enough for a single shared figure legend.
fig.legend(*ax.get_legend_handles_labels(), bbox_to_anchor=(0., 0.7, 1.0, -.0))
plt.show()

In [None]:
# Print, for each sampled site: the true value, median, central 95% interval and mean ± std
for param_name, param_samples in posterior_samples.items():
    truth = ground_truth_params[param_name]
    q_lo, q_hi = np.quantile(param_samples, (.025, .975))
    median = np.median(param_samples)
    mean, std = np.mean(param_samples), np.std(param_samples)
    summary = (
        f"{param_name:>13}: true={truth:.2f}\t median={median:.2f}\t "
        f"95%-interval: {q_lo:+.2f} - {q_hi:+.2f}\t "
        f"mean:{mean:.2f}±{std:.2f}"
    )
    print(summary)

In [None]:
# Compare the line through the posterior means with the ground-truth line
mean_slope = np.mean(posterior_samples["slope"])
mean_intercept = np.mean(posterior_samples["intercept"])

# Evaluate the fitted line at the same x-positions it is drawn at (the data's
# observed x-range). Evaluating at a different x-range than the one passed to
# plt.plot would draw a line that is not intercept + slope * x.
x_bound = (float(x_data.min()), float(x_data.max()))
y_mean_pred = jnp.array(x_bound) * mean_slope + mean_intercept

# Show mean fit vs data
plt.figure(figsize=(7, 4), dpi=100)
plt.scatter(x_data, y_data, label='data: $(x,y)$', color='tab:blue')
plt.plot(
    x_bound, [fn(x_bound[0]), fn(x_bound[1])], color='black', linestyle='-',
    label='true'
)
plt.plot(
    x_bound, y_mean_pred, color='red', linestyle='-',
    label=f'$pred = {mean_intercept:.2f} + {mean_slope:.2f} x$'
)
plt.xlim(x_bound)
plt.xlabel('$x$')
plt.ylabel('$y$')
plt.title('Mean fit vs ground-truth data')
plt.legend()
plt.show()

## Posterior predictions

In [None]:
def plot_predictions(x_samples, predictions, name):
    """Plot predictive samples of the line mean and of the observations.

    Args:
        x_samples: 1-D array of x-positions the predictions were made at; it
            also sets the x-limits of both panels.
        predictions: mapping with ``'y_loc'`` (samples of the noise-free line;
            in ``model`` this is ``intercept + slope * x``) and ``'y'`` (samples
            of the noisy observations). Both are indexed ``[sample, x]``.
        name: label used in the panel titles, e.g. ``'Posterior'``.

    The top panel shows up to 10 individual ``y_loc`` draws plus their mean and
    central 90% / 99% bands. The bottom panel does the same for ``y`` with up
    to 100 draws.
    """
    x_bound = (float(x_samples.min()), float(x_samples.max()))
    # Percentiles [5, 95] bound the central 90% band; [0.5, 99.5] bound the central 99% band
    band_q = np.array([5., 95., 0.5, 99.5])
    _, (ax1, ax2) = plt.subplots(2, 1, figsize=(7, 8))

    # Top panel: distribution of the noise-free line mu_y ('y_loc')
    y_mu_mean = jnp.mean(predictions['y_loc'], 0)
    y_mu_pct = jnp.percentile(predictions['y_loc'], q=band_q, axis=0)
    for i in range(min(10, predictions['y_loc'].shape[0])):
        # Label only the first draw so the legend has a single 'samples' entry
        ax1.plot(x_samples, predictions['y_loc'][i], color='tab:gray', linestyle='-', alpha=0.5,
                 label='samples' if i == 0 else None)
    # Raw strings: the mathtext backslashes (\mu, \;, \%) are not valid Python escapes
    ax1.plot(x_samples, y_mu_mean, color='tab:blue', linestyle='-', label=r'mean($\mu_y$)', linewidth=2)
    ax1.fill_between(x_samples, y_mu_pct[0], y_mu_pct[1], color='tab:blue', alpha=0.2, label=r'$\mu_y \; 90\%$')
    ax1.fill_between(x_samples, y_mu_pct[2], y_mu_pct[3], color='tab:blue', alpha=0.1, label=r'$\mu_y \; 99\%$')
    ax1.set_xlim(x_bound)
    ax1.set_xlabel('$x$')
    ax1.set_ylabel('$y$')
    ax1.set_title(f'{name} parameter distribution')
    ax1.legend(loc='lower right')

    # Bottom panel: distribution of the noisy observations y
    y_mean = jnp.mean(predictions['y'], 0)
    y_pct = jnp.percentile(predictions['y'], q=band_q, axis=0)
    for i in range(min(100, predictions['y'].shape[0])):
        ax2.plot(x_samples, predictions['y'][i], color='tab:blue', marker='o', alpha=0.03,
                 label='samples' if i == 0 else None)
    ax2.plot(x_samples, y_mean, 'k-', label='mean($y$)')
    ax2.fill_between(x_samples, y_pct[0], y_pct[1], color='k', alpha=0.2, label=r'$y \; 90\%$')
    ax2.fill_between(x_samples, y_pct[2], y_pct[3], color='k', alpha=0.1, label=r'$y \; 99\%$')
    ax2.set_xlim(x_bound)
    ax2.set_xlabel('$x$')
    ax2.set_ylabel('$y$')
    ax2.set_title(f'{name} predictive distribution')
    ax2.legend(loc='lower right')
    plt.tight_layout()
    plt.show()

In [None]:
# Get posterior predictive samples
# https://forum.pyro.ai/t/svi-version-of-mcmc-get-samples/3069/4
posterior_predictive = Predictive(
    model=model,
    guide=guide,
    params=svi_result.params,
    num_samples=1000
)

x_samples = np.linspace(-1.5, 1.5, 100)
posterior_predictions = posterior_predictive(jax.random.PRNGKey(42), x=x_samples, y=None)

In [None]:
# Visualize the posterior predictive samples over the extended x-grid
plot_predictions(x_samples=x_samples, predictions=posterior_predictions, name='Posterior')