In [None]:
%matplotlib notebook
import matplotlib.pyplot as plt
import numpy as np
import torch
import pyro
from tqdm.notebook import tqdm
import sklearn.datasets

Create synthetic data

In [None]:
# Fix the generator seed so the same noisy two-moons sample (and therefore every
# figure derived from it) is produced on each Restart & Run All.
X, y = sklearn.datasets.make_moons(200, noise=0.2, random_state=0)

# Bounding box of the data padded by 0.5 on every side, sampled with step 0.01.
# `xx` / `yy` are the mesh on which the classifiers are evaluated in later cells.
x_min, x_max = X[:, 0].min() - 0.5, X[:, 0].max() + 0.5
y_min, y_max = X[:, 1].min() - 0.5, X[:, 1].max() + 0.5
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.01), np.arange(y_min, y_max, 0.01))

# Scatter of the two features, coloured by class label.
fig, ax = plt.subplots(figsize=(6, 3))
ax.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.Spectral_r, alpha=0.5);

Neural network model in PyTorch
- class that inherits from `torch.nn.Module`
- `__init__(self, args):` Define layers
- `forward(self, x):` Define how layers are connected

In [None]:
class NNet_classifier(torch.nn.Module):
    """Fully connected binary classifier: 2 inputs -> two tanh hidden layers -> 1 output.

    Parameters
    ----------
    num_hidden : int
        Width of each of the two hidden layers.
    """

    def __init__(self, num_hidden=10):
        super().__init__()
        self.layer1 = torch.nn.Linear(in_features=2, out_features=num_hidden)
        self.layer2 = torch.nn.Linear(in_features=num_hidden, out_features=num_hidden)
        self.layer3 = torch.nn.Linear(in_features=num_hidden, out_features=1)
        self.activation = torch.nn.Tanh()

    def forward(self, x):
        """Return the raw output of the last linear layer for inputs `x`."""
        hidden = x
        for layer in (self.layer1, self.layer2):
            hidden = self.activation(layer(hidden))
        # No activation after the last layer: the value is returned as-is.
        return self.layer3(hidden)

Neural network training

- `criterion`: Cost function to be minimized, *e.g.* binary cross-entropy (BCE) for binary classification
- `optimizer`: Optimization algorithm, typically based on stochastic gradient descent ($\eta$ is the learning rate)
$$
\theta_{t+1} = \theta_{t} - \eta \nabla_\theta L(\theta_t)
$$


Training is performed by
1. Evaluating the network using `forward`
1. Calculating the error/loss selected in `criterion`
1. Computing the derivatives of the error by calling its `backward` method
1. Updating parameters according to `optimizer`

In [None]:
fig, ax = plt.subplots(1, 2, figsize=(7, 3), tight_layout=True)
line2 = ax[1].plot([], [])

def update_plot(model):
    """Redraw the decision surface of `model` (left) and the loss curve (right).

    Reads the notebook globals `xx`, `yy`, `X`, `y`, `fig`, `ax`, `line2`,
    `epoch_loss` and the current epoch index `k`. It is meant to be called from
    the training loop right after `epoch_loss[k]` has been stored.
    """
    ax[0].cla()
    grid = torch.from_numpy(np.c_[xx.ravel(), yy.ravel()].astype('float32'))
    # Plot-only evaluation: no gradients are needed, and calling the module
    # (instead of `.forward`) keeps any registered hooks active.
    with torch.no_grad():
        Z = torch.sigmoid(model(grid)).numpy().reshape(xx.shape)
    ax[0].contourf(xx, yy, Z, cmap=plt.cm.RdBu_r, alpha=0.75, vmin=0, vmax=1)
    ax[0].scatter(X[y==0, 0], X[y==0, 1], c='k', marker='o', s=20, alpha=0.25)
    ax[0].scatter(X[y==1, 0], X[y==1, 1], c='k', marker='x', s=20, alpha=0.25)

    # Include the loss of the epoch that just finished (index k), not only 0..k-1.
    n_done = k + 1
    line2[0].set_xdata(range(n_done))
    line2[0].set_ydata(epoch_loss[:n_done])
    for ax_ in ax:
        ax_.relim()
        ax_.autoscale_view()
    fig.canvas.draw()

In [None]:
model = NNet_classifier(num_hidden=20)
display(model)
# Summed (reduction='sum') binary cross-entropy computed on the raw network output.
criterion = torch.nn.BCEWithLogitsLoss(reduction='sum')
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

def train_one_epoch(x, y, phase='train'):
    """Run one full-batch pass over (x, y) and return the loss as a Python float.

    With phase == 'train' (the default) the parameters of the global `model`
    are updated through the global `optimizer`. Any other value only evaluates
    the loss, without tracking gradients.
    """
    is_train = phase == 'train'
    # Skip autograd bookkeeping when the loss is only being evaluated.
    with torch.set_grad_enabled(is_train):
        haty = model(x)  # Evaluate the model (via __call__ so hooks run)
        loss = criterion(haty, y)  # Calculate errors
    if is_train:
        optimizer.zero_grad()
        loss.backward()  # Compute derivatives
        optimizer.step()  # Update parameters
    return loss.item()

x_train = torch.from_numpy(X.astype('float32'))
# Targets as a column so they match the (n_samples, 1) network output.
y_train = torch.from_numpy(y.astype('float32')).reshape(-1, 1)
epoch_loss = np.zeros(shape=(6000,))

for k in tqdm(range(len(epoch_loss))):
    epoch_loss[k] = train_one_epoch(x_train, y_train)
    # Refresh the figure every 100 epochs.
    if k % 100 == 0:
        update_plot(model)

In [None]:
import pyro.distributions as dist

class BayesianNNet_classifier(pyro.nn.PyroModule):
    """Bayesian binary classifier for 2-D inputs with one tanh hidden layer.

    Every weight and bias is a Pyro sample site whose prior is an elementwise
    Normal(0, prior_std).

    Parameters
    ----------
    num_hidden : int
        Number of hidden units.
    prior_std : float
        Standard deviation shared by all weight and bias priors.
    """

    def __init__(self, num_hidden=10, prior_std=10.):
        super().__init__()
        prior = dist.Normal(0, prior_std)
        # (in_features, out_features) of each Bayesian linear layer, in creation order.
        layer_sizes = {"layer1": (2, num_hidden), "layer3": (num_hidden, 1)}
        for attr_name, (n_in, n_out) in layer_sizes.items():
            layer = pyro.nn.PyroModule[torch.nn.Linear](n_in, n_out)
            setattr(self, attr_name, layer)
            # Replace the deterministic parameters by random variables.
            layer.weight = pyro.nn.PyroSample(prior.expand([n_out, n_in]).to_event(2))
            layer.bias = pyro.nn.PyroSample(prior.expand([n_out]).to_event(1))
        self.activation = torch.nn.Tanh()

    def forward(self, x, y=None):
        """Sample the "obs" site (conditioned on `y` when given) and return the logits."""
        hidden = self.activation(self.layer1(x))
        logits = self.layer3(hidden).squeeze(1)
        # One Bernoulli observation per row of x.
        with pyro.plate("data", size=x.shape[0], dim=-1):
            pyro.sample("obs", dist.Bernoulli(logits=logits), obs=y)
        return logits

In [None]:
fig, ax = plt.subplots(1, 2, figsize=(7, 3), tight_layout=True)
line2 = ax[1].plot([], [])

def update_plot(samples):
    """Redraw the mean of the sampled "obs" values (left) and the loss curve (right).

    `samples` is indexed with the key "obs"; its first axis is averaged out and
    the result is reshaped to the mesh `xx`. Also reads the notebook globals
    `yy`, `X`, `y`, `fig`, `ax`, `line2`, `epoch_loss` and the current step
    index `k`. It is meant to be called right after `epoch_loss[k]` is stored.
    """
    ax[0].cla()
    Z = samples["obs"].mean(0).reshape(xx.shape).detach().numpy()
    ax[0].contourf(xx, yy, Z, cmap=plt.cm.RdBu_r, alpha=0.75, vmin=0, vmax=1)
    ax[0].scatter(X[y==0, 0], X[y==0, 1], c='k', marker='o', s=20, alpha=0.25)
    ax[0].scatter(X[y==1, 0], X[y==1, 1], c='k', marker='x', s=20, alpha=0.25)

    # Include the loss of the step that just finished (index k), not only 0..k-1.
    n_done = k + 1
    line2[0].set_xdata(range(n_done))
    line2[0].set_ydata(epoch_loss[:n_done])
    for ax_ in ax:
        ax_.relim()
        ax_.autoscale_view()
    fig.canvas.draw()

In [None]:
pyro.enable_validation(True)
pyro.clear_param_store()
model = BayesianNNet_classifier(num_hidden=20, prior_std=5.)
# Print the shapes of every sample site as a sanity check of the model definition.
print(pyro.poutine.trace(model).get_trace(x_train, y_train.squeeze(1)).format_shapes())

from pyro.infer.autoguide import AutoDiagonalNormal
guide = AutoDiagonalNormal(model)

svi = pyro.infer.SVI(model,
                     guide,
                     optim=pyro.optim.ClippedAdam({'lr':1e-3}),
                     loss=pyro.infer.Trace_ELBO())

# Loop-invariant pieces of the periodic preview: the evaluation grid and the
# predictive wrapper are built once instead of at every refresh.
grid = torch.from_numpy(np.c_[xx.ravel(), yy.ravel()].astype('float32'))
predictive = pyro.infer.Predictive(model, guide=guide, num_samples=10)

epoch_loss = np.zeros(shape=(10000,))
# `tqdm` is the progress bar this notebook imports from `tqdm.notebook`.
for k in tqdm(range(len(epoch_loss))):
    epoch_loss[k] = svi.step(x_train, y_train.squeeze(1))
    # Refresh the figure every 100 steps.
    if k % 100 == 0:
        samples = predictive(grid)
        update_plot(samples)

In [None]:
# Evaluation points: every node of the (xx, yy) mesh, one (x, y) pair per row, as float32.
grid_points = torch.from_numpy(np.column_stack([xx.ravel(), yy.ravel()]).astype('float32'))
# 500 draws from the model, with the latent variables sampled from the guide.
predictive = pyro.infer.Predictive(model, guide=guide, num_samples=500)
samples = predictive(grid_points)

In [None]:
fig, ax = plt.subplots(1, 4, figsize=(9, 2), tight_layout=True)

# One panel per draw: the first four entries of samples["obs"], each reshaped to the mesh.
for k, panel in enumerate(ax):
    zz = samples["obs"][k].reshape(xx.shape).detach().numpy()
    panel.pcolormesh(xx, yy, zz, cmap=plt.cm.coolwarm)
    # Overlay the training points: class 0 as circles, class 1 as crosses.
    for label, marker in ((0, 'o'), (1, 'x')):
        panel.scatter(X[y==label, 0], X[y==label, 1], c='k', marker=marker, s=2, alpha=0.25)

In [None]:
fig, ax = plt.subplots(1, 2, figsize=(7, 3), tight_layout=True)

obs_draws = samples["obs"]
# (panel, statistic over the draws, colormap, colour-scale limits)
panel_specs = (
    (ax[0], obs_draws.mean(0), plt.cm.RdBu_r, dict(vmin=0, vmax=1)),
    (ax[1], obs_draws.std(0), plt.cm.Greys, dict(vmin=0)),
)
for panel, statistic, cmap, limits in panel_specs:
    zz = statistic.reshape(xx.shape).detach().numpy()
    panel.contourf(xx, yy, zz, cmap=cmap, alpha=0.75, **limits)
    # Overlay the training points: class 0 as circles, class 1 as crosses.
    for label, marker in ((0, 'o'), (1, 'x')):
        panel.scatter(X[y==label, 0], X[y==label, 1], c='k', marker=marker, s=5, alpha=0.25)

In [None]:
import torchvision
# NOTE(review): only the train=False split is loaded; the cells below both fit the
# model on `mnist_loader` and predict on `mnist_test` -- confirm this is intended.
to_tensor = torchvision.transforms.ToTensor()
mnist_test = torchvision.datasets.MNIST(root='~/datasets', train=False, transform=to_tensor)
# Shuffled mini-batches of 128 images.
mnist_loader = torch.utils.data.DataLoader(dataset=mnist_test, batch_size=128, shuffle=True)

In [None]:
import pyro.distributions as dist

class BayesianNNet_classifier(pyro.nn.PyroModule):
    """Bayesian multi-class classifier with one tanh hidden layer.

    Every weight and bias is a Pyro sample site with an elementwise
    Normal(0, 1) prior. Note that this definition reuses (and therefore
    replaces) the class name used for the two-moons model earlier in the notebook.

    Parameters
    ----------
    ninput : int
        Number of input features (28*28 by default, i.e. a flattened image).
    num_hidden : int
        Number of hidden units.
    num_classes : int
        Number of output classes (10 by default).
    """

    def __init__(self, ninput=28*28, num_hidden=10, num_classes=10):
        super().__init__()
        self.layer1 = pyro.nn.PyroModule[torch.nn.Linear](ninput, num_hidden)
        self.layer1.weight = pyro.nn.PyroSample(dist.Normal(0., 1.).expand([num_hidden, ninput]).to_event(2))
        self.layer1.bias = pyro.nn.PyroSample(dist.Normal(0., 1.).expand([num_hidden]).to_event(1))

        self.layer2 = pyro.nn.PyroModule[torch.nn.Linear](num_hidden, num_classes)
        self.layer2.weight = pyro.nn.PyroSample(dist.Normal(0., 1.).expand([num_classes, num_hidden]).to_event(2))
        self.layer2.bias = pyro.nn.PyroSample(dist.Normal(0., 1.).expand([num_classes]).to_event(1))

        self.activation = torch.nn.Tanh()

    def forward(self, x, y=None):
        """Sample the "obs" site (conditioned on `y` when given) and return the class logits."""
        # The last axis holds the per-class logits and must be kept intact. The
        # previous `.squeeze(1)` did nothing for 10 classes and would have
        # dropped the class axis for a single-class output.
        logits = self.layer2(self.activation(self.layer1(x)))
        # One categorical observation per row of x.
        with pyro.plate("data", size=x.shape[0], dim=-1):
            pyro.sample("obs", dist.Categorical(logits=logits), obs=y)
        return logits

In [None]:
pyro.enable_validation(True)
pyro.clear_param_store()
model = BayesianNNet_classifier(num_hidden=100)

from pyro.infer.autoguide import AutoDiagonalNormal
guide = AutoDiagonalNormal(model)

svi = pyro.infer.SVI(model,
                     guide,
                     optim=pyro.optim.ClippedAdam({'lr':1e-2}),
                     loss=pyro.infer.Trace_ELBO())


fig, ax = plt.subplots(1, 2, figsize=(7, 3), tight_layout=True)
line2 = ax[1].plot([], [])

epoch_loss = np.zeros(shape=(100,))
# `tqdm` is the progress bar this notebook imports from `tqdm.notebook`.
for k in tqdm(range(len(epoch_loss))):
    for images, labels in mnist_loader:
        # Flatten each image, then calculate the loss and take a gradient step;
        # the per-batch losses are accumulated into the epoch total.
        epoch_loss[k] += svi.step(images.reshape(-1, 28*28), labels)

    # Refresh the loss curve after every epoch, including the one just finished
    # (index k), not only 0..k-1.
    n_done = k + 1
    ax[0].cla()
    line2[0].set_xdata(range(n_done))
    line2[0].set_ydata(epoch_loss[:n_done])
    for ax_ in ax:
        ax_.relim()
        ax_.autoscale_view()
    fig.canvas.draw()

In [None]:
# Flatten every image to a 28*28 vector and divide the stored pixel values by 255.
flat_pixels = mnist_test.data.reshape(-1, 28*28) / 255.
# 100 draws from the model, with the latent variables sampled from the guide.
predictive = pyro.infer.Predictive(model, guide=guide, num_samples=100)
samples = predictive(flat_pixels)

In [None]:
import ipywidgets as widgets

fig, ax = plt.subplots(1, 2, figsize=(5, 3))
idx = 0
def update(x):
    """Button callback: show image `idx`, the histogram of its sampled labels and their entropy.

    `x` is the argument passed by the button's click handler and is not used.
    The global `idx` is advanced by one on every call.
    """
    global idx
    for ax_ in ax:
        ax_.cla()
    ax[0].imshow(mnist_test.data[idx], cmap=plt.cm.Greys_r)
    res = ax[1].hist(samples['obs'][:, idx], range=(0, 10))
    # Normalise the bin counts by the actual number of draws instead of a
    # hard-coded 100, so the entropy stays correct if `num_samples` changes.
    num_draws = samples['obs'].shape[0]
    probs = res[0] / num_draws
    # Entropy of the empirical label distribution; 1e-10 guards log(0) for empty bins.
    H = np.sum(-probs*np.log(probs+1e-10))
    ax[1].set_title("%0.4f" %(H))
    ax[1].set_xticks(range(10));
    idx+=1

bnext = widgets.Button(description='next')
bnext.on_click(update)
bnext