# Scaled Linear Regression

In [1]:
import arviz as az
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from jax import random, ops
from numpyro import diagnostics, infer

# Root PRNG key (seed 0); split it with `random.split` before each random
# operation rather than reusing it directly.
rng_key = random.PRNGKey(0)

# Original data

In [10]:
# Split the running key: the first piece replaces `rng_key`, the remaining
# four each seed exactly one of the draws below.
(
    rng_key,
    rng_key_eta,
    rng_key_x,
    rng_key_y,
    rng_key_z,
) = random.split(rng_key, num=5)

# Latent factor shared by both predictors: 500 draws from Normal(50, 20).
eta = dist.Normal(loc=50, scale=20).sample(rng_key_eta, sample_shape=(500,))

# Predictors are noisy functions of eta (noise scale 10 for each).
x = dist.Normal(loc=eta, scale=10).sample(rng_key_x)
y = dist.Normal(loc=5 * eta, scale=10).sample(rng_key_y)

# Response: z = 2x + y + 3 with Normal noise of scale 0.1.
z = dist.Normal(loc=2 * x + y + 3, scale=0.1).sample(rng_key_z)

# Regression on the raw data: $z = 2x + y + 3$

In [11]:
def linear_regression(x, y, z=None):
    """Bayesian linear model ``z ~ Normal(a * x + b * y + c, sigma)``.

    Sample sites:
        "a", "b", "c": slope for ``x``, slope for ``y`` and intercept, each
            with a Normal(0, 1) prior.
        "sigma": observation noise scale with a HalfCauchy(10) prior.
        "z": the response, conditioned on ``z`` when it is given.

    Args:
        x: values of the first predictor.
        y: values of the second predictor.
        z: observed response. Defaults to ``None``, in which case the "z"
            site is sampled rather than conditioned on, so the same model
            can be reused for prior/posterior predictive draws.
    """
    a = numpyro.sample("a", dist.Normal(0, 1))
    b = numpyro.sample("b", dist.Normal(0, 1))
    c = numpyro.sample("c", dist.Normal(0, 1))
    sigma = numpyro.sample("sigma", dist.HalfCauchy(10))
    # Linear predictor; passing obs=None makes this an ordinary sample site.
    mean = a * x + b * y + c
    numpyro.sample("z", dist.Normal(mean, sigma), obs=z)

In [12]:
# Take a fresh subkey for this run and advance `rng_key` for later cells.
rng_key, rng_key_infer = random.split(rng_key, num=2)

# NUTS over the regression model: 500 warmup steps, then 500 kept draws.
kernel = infer.NUTS(model=linear_regression)
mcmc = infer.MCMC(
    kernel,
    num_warmup=500,
    num_samples=500,
)

# Fit on the raw (unscaled) predictors and response.
mcmc.run(rng_key_infer, x, y, z)
posterior_samples = mcmc.get_samples()

sample: 100%|██████████| 1000/1000 [00:05<00:00, 192.02it/s, 1003 steps of size 3.11e-03. acc. prob=0.94]


In [13]:
# Print per-site posterior statistics and sampler diagnostics for the last run.
mcmc.print_summary()


                mean       std    median      5.0%     95.0%     n_eff     r_hat
         a      2.00      0.00      2.00      2.00      2.00    398.87      1.00
         b      1.00      0.00      1.00      1.00      1.00    330.93      1.00
         c      2.99      0.01      2.99      2.97      3.00    138.87      1.00
     sigma      0.10      0.00      0.10      0.09      0.10    167.98      1.00

Number of divergences: 0


# Scaled predictors (x and y standardized, z left on its raw scale)

In [14]:
# Take a fresh subkey for this run and advance `rng_key` for later cells.
rng_key, rng_key_infer = random.split(rng_key, num=2)

# Standardize each predictor to zero mean and unit standard deviation.
x_scl = (x - jnp.mean(x)) / jnp.std(x)
y_scl = (y - jnp.mean(y)) / jnp.std(y)

# Same sampler configuration as the unscaled fit: 500 warmup + 500 draws.
kernel = infer.NUTS(model=linear_regression)
mcmc = infer.MCMC(
    kernel,
    num_warmup=500,
    num_samples=500,
)

# NOTE(review): only the predictors are standardized; `z` is passed on its
# raw scale while the model keeps Normal(0, 1) priors on a, b and c. If the
# resulting sigma is very large, the priors are probably too tight for the
# rescaled problem -- consider standardizing `z` as well or widening them.
mcmc.run(rng_key_infer, x_scl, y_scl, z)
posterior_samples = mcmc.get_samples()

sample: 100%|██████████| 1000/1000 [00:04<00:00, 217.20it/s, 3 steps of size 7.90e-01. acc. prob=0.90]


In [15]:
# Print per-site posterior statistics and sampler diagnostics for the scaled fit.
# NOTE(review): a, b and c here are on the standardized-predictor scale, so
# they are not directly comparable to the coefficients of the unscaled fit.
mcmc.print_summary()


                mean       std    median      5.0%     95.0%     n_eff     r_hat
         a      0.57      0.98      0.59     -1.02      2.18    431.43      1.00
         b      0.49      0.99      0.47     -1.21      2.02    451.91      1.00
         c      1.26      1.08      1.29     -0.69      2.68    400.85      1.00
     sigma    376.90     11.79    376.17    356.97    394.08    474.56      1.00

Number of divergences: 0
