In [None]:
%%HTML
<!-- Mejorar visualización en proyector -->
<style>
.rendered_html {font-size: 1.2em; line-height: 150%;}
div.prompt {min-width: 0ex; padding: 0px;}
.container {width:95% !important;}
</style>

In [3]:
%autosave 0
%matplotlib notebook
import matplotlib.pyplot as plt
import numpy as np
import torch
from tqdm.notebook import tqdm
import pyro

Autosave disabled


In [None]:
# Synthetic data
se = 0.1
np.random.seed(0)
x = np.linspace(0, 1, num=20) 
x_test = np.linspace(-0.05, 1.05, num=200)
f = lambda x : x*np.sin(10*x)

x = np.delete(x, slice(9, 14))
y = f(x) + se*np.random.randn(len(x))
fig, ax = plt.subplots(figsize=(7, 3), tight_layout=True)
ax.scatter(x, y);

x_torch = torch.from_numpy(x.astype('float32')).unsqueeze(1)
x_test = torch.from_numpy(x_test.astype('float32')).unsqueeze(1)
y_torch = torch.from_numpy(y.astype('float32')).unsqueeze(1)

In [None]:
torch.tensor([[0., 0.]]).shape

In [8]:
prior = GaussianScaleMixture(component_logits=torch.tensor([0.5, 0.5]), coord_scale=torch.tensor([1.]), 
                             component_scale=torch.tensor([1., 0.01]))

NotImplementedError: This distribution does not support D = 1

In [None]:
prior.batch_shape

In [None]:
prior.event_shape

In [None]:
prior.rsample()

In [None]:
prior.expand_by([1, 10])

In [None]:
prior.expand([10]).to_event(2).shape()

In [None]:
prior.expand_by([10]).to_event(1).rsample().shape

In [None]:
prior.expand_by?

In [None]:
PyroSample(prior.expand_by((10, 1)).to_event(2))

In [None]:
MixtureOfDiagNormals??

In [1]:
from pyro.nn import PyroSample, PyroModule
from pyro.distributions import Uniform, Normal, GaussianScaleMixture

class BayesianMLPRegression(PyroModule):
    def __init__(self, n_hidden=10, prior_scale=1.):
        super().__init__()
        #prior = Normal(0, prior_scale)
        prior = MixtureOfDiagNormals(locs=torch.tensor([[0.], [0.]]),
                                     coord_scale=torch.tensor([[prior_scale], [0.01]]),
                                     component_logits=torch.tensor([0.5, 0.5]))
        # Hidden layer
        #display(prior.expand([n_hidden, 1]).to_event(2).shape())
        self.hidden = PyroModule[torch.nn.Linear](1, n_hidden)
        self.hidden.weight = PyroSample(prior.expand_by([n_hidden, 1]).to_event(2))
        self.hidden.bias = PyroSample(prior.expand_by([n_hidden]).to_event(1))
        # Output layer
        self.output = PyroModule[torch.nn.Linear](n_hidden, 1)
        self.output.weight = PyroSample(prior.expand([1, n_hidden]).to_event(2))
        self.output.bias = PyroSample(prior.expand([1]).to_event(1))
        # activation function
        self.activation = torch.nn.Tanh()
        
    def forward(self, x, y=None):
        z = self.activation(self.hidden(x))
        mean = self.output(z).squeeze(-1)
        sigma = pyro.sample("sigma", Uniform(0.0, 0.1))
        with pyro.plate("data", x.shape[0]):
            obs = pyro.sample("obs", Normal(mean, sigma), obs=y) #likelihood
        return mean

In [None]:
pyro.enable_validation(True)

model = BayesianMLPRegression()

print(pyro.poutine.trace(model).get_trace(x_torch, y_torch).format_shapes())

In [None]:
fig, ax = plt.subplots(1, 2, figsize=(10, 3), tight_layout=True, dpi=80)
#ax[0].set_yscale('log')

def update_plot(k, epoch_loss, samples):
    ax[0].cla()
    ax[0].plot(range(k), epoch_loss[:k])
    #ax[0].autoscale_view()
    ax[1].cla()
    ax[1].plot(x, y, 'k.');
    med = np.median(samples, axis=[0])
    qua = np.quantile(samples, (0.05, 0.95), axis=0)
    ax[1].plot(x_test.numpy()[:, 0], med)
    ax[1].fill_between(x_test.numpy()[:, 0], qua[0], qua[1], alpha=0.5)
    fig.canvas.draw()

In [None]:
pyro.enable_validation(True) # Turn this on for additional debugging
pyro.clear_param_store() 
model = BayesianMLPRegression(n_hidden=10, prior_scale=1.) # Declare the neural network

# Create a guide
from pyro.infer.autoguide import AutoDiagonalNormal
guide = AutoDiagonalNormal(model, init_scale=1e-2)

# Create SVI object
svi = pyro.infer.SVI(model, 
                     guide, 
                     optim=pyro.optim.ClippedAdam({'lr':1e-2, 'clip_norm': 10.0}), # Optimizer
                     loss=pyro.infer.Trace_ELBO()) # Loss function 

epoch_loss = np.zeros(shape=(10000,))
for k in tqdm(range(len(epoch_loss))):
    loss = svi.step(x=x_torch, y=y_torch.squeeze(-1)) # Actual training step
    epoch_loss[k] = loss / len(x_torch)
        
    if k % 100 == 0:
        # Compute predictive posterior
        predictive = pyro.infer.Predictive(model, guide=guide, num_samples=100)
        samples = predictive(x_test, None)['obs'].detach().numpy()
        # Plot it
        update_plot(k, epoch_loss, samples)        

In [None]:
%matplotlib notebook
import matplotlib.pyplot as plt

In [None]:
import numpyro
import numpy as np
import jax.numpy as jnp
import jax.random as random
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

In [None]:
# Synthetic data
se = 0.1
np.random.seed(0)
x = np.linspace(0, 1, num=20) 
x_test = np.linspace(-0.05, 1.05, num=200)
f = lambda x : x*np.sin(10*x)

x = np.delete(x, slice(9, 14))
y = f(x) + se*np.random.randn(len(x))
fig, ax = plt.subplots(figsize=(7, 3), tight_layout=True)
ax.scatter(x, y);

#x_torch = torch.from_numpy(x.astype('float32')).unsqueeze(1)
#x_test = torch.from_numpy(x_test.astype('float32')).unsqueeze(1)
#y_torch = torch.from_numpy(y.astype('float32')).unsqueeze(1)

In [None]:
def model(x, y, n_hidden=10):

    
    w1 = numpyro.sample("w1", dist.Normal(jnp.zeros((1, n_hidden)), 
                                          jnp.ones((1, n_hidden))))  # D_X D_H
    
    b1 = numpyro.sample("b1", dist.Normal(jnp.zeros((n_hidden,)), 
                                          jnp.ones((n_hidden,))))  # D_X D_H
    
    z1 = jnp.tanh(jnp.matmul(x, w1) + b1)   
    
    w2 = numpyro.sample("w2", dist.Normal(jnp.zeros((n_hidden, n_hidden)), 
                                          jnp.ones((n_hidden, n_hidden))))  # D_H D_H
    b2 = numpyro.sample("b2", dist.Normal(jnp.zeros((n_hidden,)), 
                                          jnp.ones((n_hidden,))))  # D_X D_H
    z2 = jnp.tanh(jnp.matmul(z1, w2) + b2)  # N D_H  <= second layer of activations

    # sample final layer of weights and neural network output
    w3 = numpyro.sample("w3", dist.Normal(jnp.zeros((n_hidden, 1)), 
                                          jnp.ones((n_hidden, 1)))) 
    b3 = numpyro.sample("b3", dist.Normal(jnp.zeros((1,)), 
                                          jnp.ones((1,))))  # D_H D_Y
    z3 = jnp.matmul(z2, w3) + b3  # N D_Y  <= output of the neural network

    # we put a prior on the observation noise
    prec_obs = numpyro.sample("prec_obs", dist.Gamma(3.0, 1.0))
    sigma_obs = 1.0 / jnp.sqrt(prec_obs)

    # observe data
    numpyro.sample("obs", dist.Normal(z3, sigma_obs), obs=y)

In [None]:
rng_key, rng_key_predict = random.split(random.PRNGKey(0))

numpyro.set_host_device_count(2)
kernel = NUTS(model)
mcmc = MCMC(kernel, 100, 1000, num_chains=2, progress_bar=True)
mcmc.run(rng_key, x.reshape(-1,1), y.reshape(-1,1), 10)
mcmc.print_summary()

In [None]:
samples = mcmc.get_samples()

def predict(model, rng_key, samples, X, D_H):
    model = numpyro.handlers.substitute(numpyro.handlers.seed(model, rng_key), samples)
    # note that Y will be sampled in the model because we pass Y=None here
    model_trace = numpyro.handlers.trace(model).get_trace(x=X, y=None, n_hidden=D_H)
    return model_trace['obs']['value']

from jax import vmap
vmap_args = (samples, random.split(rng_key, 1000 * 2))
predictions = vmap(lambda samples, rng_key: predict(model, rng_key, samples, x_test.reshape(-1,1), 10))(*vmap_args)
predictions = predictions[..., 0]

# compute mean prediction and confidence interval around median
mean_prediction = jnp.mean(predictions, axis=0)
percentiles = np.percentile(predictions, [5.0, 95.0], axis=0)

In [None]:
predictions

In [None]:
plt.figure()
plt.plot(x_test, mean_prediction)
plt.fill_between(x_test, percentiles[0, :], percentiles[1, :], alpha=0.5)
plt.scatter(x, y)

- https://petuum.com/2019/01/15/intro-to-modern-bayesian-learning-and-probabilistic-programming/

- https://github.com/pyro-ppl/numpyro/blob/master/examples/bnn.py