# Train a BNN to classify MNIST using neural SVGD

In [2]:
# Train a Bayesian neural network to classify MNIST using
# Neural SVGD
#
# If using pmap, set the environment variable
# `export XLA_FLAGS="--xla_force_host_platform_device_count=8"`
# before running on CPU (this enables pmap to "see" multiple cores).
import sys
sys.path.append("../../learning_particle_gradients/")
sys.path.append("../../experiments/")
from functools import partial
from itertools import cycle

import numpy as onp
import jax
from jax import numpy as jnp
from jax import jit, grad, value_and_grad, vmap, pmap, config, random
config.update("jax_debug_nans", False)
from jax.ops import index_update, index
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.calibration import calibration_curve
import tensorflow_datasets as tfds

import haiku as hk
import optax

import nets
import utils
import models
from convnet import model, accuracy, crossentropy_loss, log_prior

# Config
key = random.PRNGKey(0)  # root PRNG key; later cells split sub-keys off it

# Training schedule
EPOCHS = 1
BATCH_SIZE = 2
NUM_SAMPLES = 4  # number of parameter samples (particles) kept in parallel

# Step sizes
LEARNING_RATE = 1e-8  # handed to the optax optimizer that moves the particles
META_LEARNING_RATE = 1e-4  # handed to the Neural SVGD learner

# Runtime switches
DISABLE_PROGRESS_BAR = False
USE_PMAP = False

# `vpmap` is the transform used to map over the particle axis:
# pmap when USE_PMAP is set, vmap otherwise.
vpmap = pmap if USE_PMAP else vmap

# Load MNIST (the whole dataset in one batch, converted to numpy arrays)
data_dir = '/tmp/tfds'
mnist_data, info = tfds.load(
    name="mnist",
    batch_size=-1,
    data_dir=data_dir,
    with_info=True,
)
mnist_data = tfds.as_numpy(mnist_data)
train_data = mnist_data['train']
test_data = mnist_data['test']

# Full train and test set
train_images = train_data['image']
train_labels = train_data['label']
test_images = test_data['image']
test_labels = test_data['label']

# Hold out 10% of the training data as a validation set (fixed seed).
train_images, val_images, train_labels, val_labels = train_test_split(
    train_images,
    train_labels,
    test_size=0.1,
    random_state=0,
)
# Number of training examples remaining after the validation split.
data_size = len(train_images)


def make_batches(images, labels, batch_size):
    """Return an endless iterator over (image_batch, label_batch) tuples.

    The data is cut into consecutive chunks of `batch_size` examples along
    axis 0. If `len(images)` is not a multiple of `batch_size`, the final
    chunk is smaller; an empty chunk is never produced. After the last chunk
    the iterator starts over from the first one.

    Args:
        images: array-like of examples, indexed along axis 0.
        labels: array-like of labels aligned with `images`.
        batch_size: positive number of examples per batch.

    Raises:
        ValueError: if `batch_size` is not positive or there are no examples.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    images = onp.asarray(images)
    labels = onp.asarray(labels)
    num_examples = len(images)
    if num_examples == 0:
        raise ValueError("cannot build batches from an empty dataset")

    # Slicing by start offset (instead of splitting at every multiple of
    # batch_size, including the array length itself) avoids the empty
    # trailing batch that appeared whenever the dataset size was an exact
    # multiple of batch_size.
    batches = [
        (images[start:start + batch_size], labels[start:start + batch_size])
        for start in range(0, num_examples, batch_size)
    ]
    return cycle(batches)


def loss(params, images, labels):
    """Minibatch estimate of the (unnormalized) Bayesian negative
    log-posterior at `params`:

        (data_size / BATCH_SIZE) * crossentropy_loss(logits, labels) - log_prior(params)

    where `logits = model.apply(params, images)`. Note that the rescaling
    constant uses the global `BATCH_SIZE`, not the length of `images`.
    """
    rescale = data_size / BATCH_SIZE
    batch_nll = rescale * crossentropy_loss(model.apply(params, images), labels)
    return batch_nll - log_prior(params)


@jit
def ensemble_accuracy(param_set):
    """Accuracy of the averaged ensemble prediction on validation data.

    Args:
        param_set: model parameters stacked along a leading axis, one slice
            per ensemble member (this axis is the one mapped over).

    Returns:
        Fraction of correct predictions, where the prediction is the argmax
        of the softmax probabilities averaged over ensemble members. Only
        the first `BATCH_SIZE` validation examples are evaluated.
    """
    # `in_axes` must be passed by keyword: vmap accepts it positionally, but
    # pmap's second positional parameter is `axis_name`, so the positional
    # form silently did the wrong thing when USE_PMAP is enabled.
    vapply = vpmap(model.apply, in_axes=(0, None))
    logits = vapply(param_set, val_images[:BATCH_SIZE])
    preds = jnp.mean(vmap(jax.nn.softmax)(logits), axis=0)  # mean prediction
    return jnp.mean(preds.argmax(axis=1) == val_labels[:BATCH_SIZE])




# Neural SVGD Model
* Input: model parameters
* Output: 'gradient' of same shape as parameters

### Memory
Assume the BNN model has 50,000 parameters and the NSVGD layers have hidden dimensions `[1024, 1024]`. Then the NSVGD (meta) model has roughly `2*50,000*1024` parameters (the input and output layers dominate). At 8 bytes each, that is about 800MB.

The activations take `50,000 + 1024 + 1024 + 50,000` floats, i.e. about 800KB per particle; multiply by `NUM_SAMPLES` for the total.

In [3]:
# initialize model parameters so we know their shape
# `jax.flatten_util` is a submodule; a bare `import jax` does not reliably
# make it available as an attribute, so import it explicitly here.
import jax.flatten_util

key, subkey = random.split(key)
params_tree = model.init(subkey, train_images[:2])
# `params_flat` is the 1-D parameter vector; `unravel` maps such a vector
# back to a tree with the structure of `params_tree`.
params_flat, unravel = jax.flatten_util.ravel_pytree(params_tree)

# def model_fn(params):
#     mlp = nets.MLP([1024, 1024, params_flat.shape[0]])
#     scale = hk.get_parameter("scale", (), init=lambda *args: jnp.ones(*args))
#     return scale*mlp(params)

# dynamics_model = hk.without_apply_rng(hk.transform(model_fn))
# dyn_params = dynamics_model.init(subkey, params_flat)

# Training with Neural SVGD

In [4]:
def ravel(tree):
    """Flatten a pytree of arrays into a single 1-D array.

    Only the flat vector is returned; the inverse (unravel) function that
    `ravel_pytree` also produces is discarded.
    """
    # Import the submodule explicitly: `jax.flatten_util` is not reliably
    # available as an attribute after a bare `import jax`.
    from jax.flatten_util import ravel_pytree

    flat, _ = ravel_pytree(tree)
    return flat


def init_flat_params(key):
    """Initialise one set of model parameters from `key` and return it flattened.

    The parameter tree comes from `model.init` called on the first two
    training images and is then flattened with `ravel`.
    """
    fresh_tree = model.init(key, train_images[:2])
    return ravel(fresh_tree)


def get_minibatch_loss(batch):
    """
    Bind a data minibatch to `loss`.

    args:
        batch = (images, labels); unpacked into the data arguments of `loss`

    Returns a callable that takes a flattened parameter vector, restores
    the parameter tree with `unravel`, and evaluates `loss` on `batch`.
    """
    def minibatch_loss(params_flat):
        param_tree = unravel(params_flat)
        return loss(param_tree, *batch)

    return minibatch_loss

In [5]:
# Plain SGD optimizer used to move the particles.
opt = optax.sgd(learning_rate=LEARNING_RATE)
# opt_state = opt.init(init_particles)

# One flat parameter vector per particle, each initialised from its own key.
key, subkey = random.split(key)
init_particles = vmap(init_flat_params, in_axes=0)(
    random.split(subkey, NUM_SAMPLES)
)

In [8]:
# Build the learner of the particle update direction and the particle container.
# `key` is split into one key per object; `key` itself is not rebound here.
key1, key2 = random.split(key)
# The learner's target and output dimension both equal the length of one flat
# particle (init_particles.shape[1]), with two hidden layers of width 1024.
# NOTE(review): `get_target_logp` receives `get_minibatch_loss`, which wraps
# `loss` (a *negative* log-posterior) — confirm against models.SDLearner that
# this sign convention is what the learner expects.
neural_grad = models.SDLearner(target_dim=init_particles.shape[1],
                               get_target_logp=get_minibatch_loss,
                               learning_rate=META_LEARNING_RATE,
                               key=key1,
                               sizes=[1024, 1024, init_particles.shape[1]],
                               aux=False)
# The particles are updated along `neural_grad.gradient` using the optax optimizer `opt`.
particles = models.Particles(key2, neural_grad.gradient, init_particles, custom_optimizer=opt)

In [9]:
def sample_tv(key):
    """Draw NUM_SAMPLES freshly initialised flat parameter vectors and split them in two.

    Args:
        key: PRNG key; it is split into one sub-key per sample.

    Returns:
        The result of `.split(2)` on the stacked samples, i.e. two equal parts
        (presumably the training and validation particle sets — confirm
        against the `next_batch` contract of the learner).
    """
    # Use the caller's key. Previously the argument was ignored in favour of
    # the global `subkey`, so every call returned the same particles.
    particle_keys = random.split(key, NUM_SAMPLES)
    return vmap(init_flat_params)(particle_keys).split(2)

# Cycling iterator over (image_batch, label_batch) training minibatches of size BATCH_SIZE.
batches = make_batches(train_images, train_labels, BATCH_SIZE)

In [1]:
# Warmup on first batch: train the learner for a few steps before the main loop.
# Particles are supplied by `sample_tv`; the data is the next minibatch drawn from `batches`.
# NOTE(review): the trailing `# 100` presumably records the step count intended
# for a real run (3 being a quick-debug value) — confirm.
neural_grad.train(next_batch=sample_tv,
                  n_steps=3, # 100
                  early_stopping=False,
                  data=next(batches))

NameError: name 'neural_grad' is not defined

In [None]:
# Particle supplier handed to `neural_grad.train` in the training loop
# (`partial` is applied with no bound arguments).
next_particles = partial(particles.next_batch)
# NOTE(review): `get_batches`, `x_test`, `y_test`, `x_val`, `y_val`, `xx`, `yy`,
# `x_train`, `y_train`, `full_data`, `NUM_VALS` and `NUM_STEPS` are not defined
# in the visible part of this notebook (which uses `make_batches`,
# `train_images`, `val_images`, ... instead) — this cell looks carried over
# from another script; confirm or adapt the names before running.
test_batches = get_batches(x_test, y_test, 2*NUM_VALS) if full_data else get_batches(x_val, y_val, 2*NUM_VALS)
train_batches = get_batches(xx, yy, NUM_STEPS+1) if full_data else get_batches(x_train, y_train, NUM_STEPS+1)

In [None]:
# Main training loop: for every data batch, refit the learner, then move the particles.
# NOTE(review): `NUM_STEPS`, `NUM_VALS`, `progress_bar`, `get_minibatch_logp`,
# `compute_test_accuracy` and `metrics` are not defined in the visible part of
# this notebook — confirm they exist before running this cell.
for i, data_batch in tqdm(enumerate(train_batches), total=NUM_STEPS, disable=not progress_bar):
    # 10 learner updates on the current data batch, then one particle step
    # using the learner's current parameters.
    neural_grad.train(next_batch=next_particles, n_steps=10, data=data_batch)
    particles.step(neural_grad.get_params())
    # Log metrics every NUM_STEPS//NUM_VALS iterations.
    # NOTE(review): if NUM_VALS > NUM_STEPS the integer division is 0 and the
    # modulo raises ZeroDivisionError — confirm the intended values.
    if i % (NUM_STEPS//NUM_VALS)==0:
        test_logp = get_minibatch_logp(*next(test_batches))
        train_logp = get_minibatch_logp(*data_batch)
        stepdata = {
            "accuracy": compute_test_accuracy(unravel(particles.particles.training)[0]),
            "test_logp": test_logp(particles.particles.training),
            "training_logp": train_logp(particles.particles.training),
        }
        metrics.append_to_log(particles.rundata, stepdata)
# Signal to both objects that training has finished.
neural_grad.done()
particles.done()

In [None]:


@jit
def step(param_set, opt_state, images, labels):
    """Apply one optimizer update to every element of `param_set` in parallel.

    Each element follows the gradient of `loss` on the shared minibatch and
    is updated through the global optax optimizer `opt`. This function adds
    no noise term itself, so by itself it is a plain gradient step per
    element rather than Langevin dynamics.

    Args:
        param_set: model parameters stacked along a leading axis (one slice
            per ensemble member).
        opt_state: optimizer state matching `param_set`.
        images, labels: the minibatch, shared by all ensemble members.

    Returns:
        (updated param_set, updated opt_state, per-member loss values).
    """
    # `in_axes` must be passed by keyword: pmap's second positional parameter
    # is `axis_name`, so the positional form was wrong when USE_PMAP is enabled.
    loss_and_grad = vpmap(value_and_grad(loss), in_axes=(0, None, None))
    step_losses, g = loss_and_grad(param_set, images, labels)
    g, opt_state = opt.update(g, opt_state, param_set)
    return optax.apply_updates(param_set, g), opt_state, step_losses

In [6]:
# initialize set of parameters: one parameter tree per sample, all
# initialised on the same five training images but with different keys
key, subkey = random.split(key)
param_set = vmap(model.init, in_axes=(0, None))(
    random.split(subkey, NUM_SAMPLES),
    train_images[:5],
)
opt_state = opt.init(param_set)


# bookkeeping for the training run
losses, accuracies = [], []
n_train_steps = EPOCHS * data_size // BATCH_SIZE
batches = make_batches(train_images, train_labels, BATCH_SIZE)