In [29]:
from functools import partial
from typing import Callable

import arviz
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpyro
import pandas as pd
import seaborn as sns
import sklearn
from jax import Array
from numpyro import distributions as dist
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm

In [30]:
# Load the Iris dataset.
iris = load_iris()

# Hold out test_size=0.33 of the samples for evaluation; the fixed
# random_state makes the split reproducible across runs.
x_train, x_test, y_train, y_test = train_test_split(
    iris.data.astype(float),
    iris.target.astype(int),
    test_size=0.33,
    random_state=42,
)

In [31]:
class Model:
    """Bayesian neural-network classifier with two tanh hidden layers.

    Every weight matrix and bias vector gets an independent zero-mean Normal
    prior with a shared standard deviation; labels follow a Categorical
    likelihood parameterised by the network's output logits.

    Args:
        x_dim: Number of input features (columns of ``x``).
        y_dim: Number of output classes (columns of the logits).
        h_dim: Width of both hidden layers.
        prior_sd: Standard deviation shared by all Normal priors. Defaults to
            1, the value that was previously hard-coded in ``__call__``.
    """

    def __init__(self, x_dim=4, y_dim=3, h_dim=5, prior_sd=1):
        self.x_dim = x_dim
        self.y_dim = y_dim
        self.h_dim = h_dim
        self.prior_sd = prior_sd

    def __call__(self, x, y=None):
        """Run the probabilistic program for a batch of inputs.

        Args:
            x: Input matrix whose first axis indexes observations and whose
                second axis has ``x_dim`` entries (it is multiplied by ``w1``).
            y: Observed class labels, one per row of ``x``. Leave as ``None``
                to have the ``obs`` site sampled instead of conditioned on,
                which is how the model is used for prediction.

        Returns:
            The value of the ``obs`` sample site.
        """
        x_dim = self.x_dim
        y_dim = self.y_dim
        h_dim = self.h_dim
        sd = self.prior_sd
        # Number of observations
        n = x.shape[0]

        def normal_prior(*shape):
            # Zero-mean Normal over a tensor of the given shape, with every
            # dimension declared as an event dimension.
            return dist.Normal(0, sd).expand(shape).to_event(len(shape))

        # Sample sites are declared in the same order as before
        # (w1, b1, w2, b2, w3, b3).
        # Layer 1
        w1 = numpyro.sample("w1", normal_prior(x_dim, h_dim))
        b1 = numpyro.sample("b1", normal_prior(h_dim))
        # Layer 2
        w2 = numpyro.sample("w2", normal_prior(h_dim, h_dim))
        b2 = numpyro.sample("b2", normal_prior(h_dim))
        # Layer 3 (output)
        w3 = numpyro.sample("w3", normal_prior(h_dim, y_dim))
        b3 = numpyro.sample("b3", normal_prior(y_dim))
        # Forward pass of the network
        h1 = jnp.tanh((x @ w1) + b1)
        h2 = jnp.tanh((h1 @ w2) + b2)
        logits = h2 @ w3 + b3
        # Save the deterministic variable (logits) in the trace
        numpyro.deterministic("logits", logits)
        # Categorical likelihood, one label per observation
        with numpyro.plate("labels", n):
            obs = numpyro.sample("obs", dist.Categorical(logits=logits), obs=y)
        return obs

In [32]:
# The recorded run of this cell emitted a warning on the MCMC(...) call and
# drew the four chains one after another. NumPyro's default "parallel" chain
# method needs one device per chain and otherwise falls back to sequential
# sampling with a warning. Picking the method explicitly gives the same
# sampling behaviour without the warning.
num_chains = 4
chain_method = "parallel" if jax.local_device_count() >= num_chains else "sequential"

kernel = numpyro.infer.NUTS(Model())
mcmc = numpyro.infer.MCMC(
    kernel,
    num_samples=1000,
    num_warmup=200,
    num_chains=num_chains,
    chain_method=chain_method,
)
mcmc.run(jax.random.key(0), x=x_train, y=y_train)

  mcmc = numpyro.infer.MCMC(kernel, num_samples=1000, num_warmup=200, num_chains=4)
sample: 100%|██████████| 1200/1200 [00:08<00:00, 146.22it/s, 511 steps of size 1.22e-02. acc. prob=0.83] 
sample: 100%|██████████| 1200/1200 [00:07<00:00, 164.89it/s, 511 steps of size 9.71e-03. acc. prob=0.95]
sample: 100%|██████████| 1200/1200 [00:06<00:00, 175.61it/s, 511 steps of size 1.10e-02. acc. prob=0.92]
sample: 100%|██████████| 1200/1200 [00:07<00:00, 160.10it/s, 511 steps of size 8.91e-03. acc. prob=0.97]


In [33]:
# Wrap the finished MCMC run in an ArviZ container.
arv = arviz.from_numpyro(posterior=mcmc)
# Summarise the per-variable diagnostics table (mean, sd, ess, r_hat, ...)
# across all variables instead of printing every row.
arviz.summary(data=arv).describe()

Unnamed: 0,mean,sd,hdi_3%,hdi_97%,mcse_mean,mcse_sd,ess_bulk,ess_tail,r_hat
count,373.0,373.0,373.0,373.0,373.0,373.0,373.0,373.0,373.0
mean,-0.00234,1.44881,-2.692094,2.732397,0.024965,0.018461,3898.935657,3470.597855,1.000509
std,2.141836,0.223968,2.364947,2.044342,0.00727,0.004809,1238.349594,581.353629,0.002434
min,-2.992,0.818,-6.353,-0.134,0.011,0.013,300.0,845.0,1.0
25%,-1.657,1.351,-4.541,1.236,0.021,0.016,3882.0,3490.0,1.0
50%,-0.089,1.477,-2.584,2.128,0.023,0.018,4389.0,3709.0,1.0
75%,2.18,1.572,-0.276,4.649,0.027,0.019,4559.0,3821.0,1.0
max,3.722,1.885,1.063,6.703,0.065,0.046,6039.0,3974.0,1.02


In [42]:
# Average the stored logits over posterior draws (axis 0), take the
# highest-scoring class per observation, and compare with the training labels.
acc_train = jnp.mean(
    jnp.argmax(jnp.mean(mcmc.get_samples()["logits"], axis=0), axis=1) == y_train
)
print(f"Training accuracy: {acc_train:.1%}")

Training accuracy: 98.0%


In [50]:
# Re-run the model on the held-out inputs for every posterior draw and keep
# the resulting logits.
pred_logits = numpyro.infer.Predictive(
    model=Model(),
    posterior_samples=mcmc.get_samples(),
)(jax.random.key(0), x_test)["logits"]
# Same scoring rule as for the training set: mean logits over draws, then argmax.
acc_test = jnp.mean(jnp.argmax(jnp.mean(pred_logits, axis=0), axis=1) == y_test)
print(f"Test accuracy: {acc_test:.1%}")

Test accuracy: 98.0%
