In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from jax import random
from numpyro import diagnostics, infer
from sklearn import datasets
from sklearn import linear_model

from myapp import spike_and_slab

# Seed for reproducibility. np.random.seed only seeds NumPy's legacy global RNG;
# JAX randomness (jax.random) is driven by explicit PRNG keys and is unaffected.
SEED = 12345
np.random.seed(SEED)

# Data

In [None]:
# sklearn's diabetes regression data: `x` is the feature matrix, `y` the 1-D target.
x, y = datasets.load_diabetes(return_X_y=True)
# Turn the target into a column vector (n_samples, 1); presumably the shape
# spike_and_slab expects — confirm against the module.
y = y.reshape(-1, 1)

x.shape, y.shape

# Gibbs sampler (full from-scratch implementation)

In [None]:
# Prior hyperparameters for the spike-and-slab regression.
# NOTE(review): meanings below are inferred from the names only — confirm in
# myapp.spike_and_slab:
#   a_w, b_w      presumably Beta(a_w, b_w) prior on the inclusion probability w
#   nu_psi, q_psi presumably the inverse-gamma prior parameters of the slab variance psi
#   r             presumably the spike-to-slab variance ratio
hyperparams = spike_and_slab.SpikeAndSlabHyperParams(
    a_w=1,
    b_w=1,
    nu_psi=5,
    q_psi=4,
    r=0.001,
)

In [None]:
# Draw posterior samples with the from-scratch Gibbs sampler. The result exposes
# the traces mu, alpha, sigma_2, delta, psi and w.
# NOTE(review): which RNG the sampler uses is not visible here; np.random.seed
# only makes this reproducible if it draws from NumPy's global RNG.
posterior_samples = spike_and_slab.gibbs_sampling(
    x,
    y,
    hyperparams,
)

In [None]:
# Trace plot of every sampled quantity (one panel each) to eyeball mixing and
# convergence of the Gibbs chain.
TRACE_PARAMS = ("mu", "alpha", "sigma_2", "delta", "psi", "w")

fig, axes = plt.subplots(3, 2, figsize=(12, 8))
# axes.ravel() walks the grid row by row, matching the original 321..326 layout.
for ax, name in zip(axes.ravel(), TRACE_PARAMS):
    # A 2-D trace is drawn as one line per column.
    ax.plot(getattr(posterior_samples, name))
    ax.set_title(name)
    ax.set_xlabel("sample index")

fig.tight_layout()
plt.show()

# Prediction

In [None]:
# Predictions on the training inputs from the posterior samples.
# NOTE(review): the plotting cell takes the median over axis 0, so axis 0 is
# presumably the posterior-sample axis — confirm against spike_and_slab.predict.
y_pred = spike_and_slab.predict(
    x,
    posterior_samples,
)

In [None]:
# Posterior-median prediction vs. true target; points on the dashed y = x line
# are perfect predictions.
y_true = y.ravel()
y_pred_median = np.median(y_pred, axis=0).ravel()

# Span the identity line over the observed data instead of hard-coding [20, 350],
# so it stays correct for any target range.
line_lo = min(y_true.min(), y_pred_median.min())
line_hi = max(y_true.max(), y_pred_median.max())

fig, ax = plt.subplots()
ax.scatter(y_true, y_pred_median)
ax.plot([line_lo, line_hi], [line_lo, line_hi], "--r")
ax.set_xlabel("True")
ax.set_ylabel("Prediction")
ax.set_title("Posterior median prediction vs. true target")
plt.show()