# Train a BNN to classify MNIST using neural SVGD

In [1]:
# for leonhard
# Point XLA at the CUDA installation named by CUDA_HOME. The variable may be
# unset (e.g. on a machine without that CUDA module loaded); in that case the
# flag is left alone instead of aborting the notebook with a KeyError.
import os
cuda_home = os.environ.get("CUDA_HOME")
if cuda_home is not None:
    os.environ['XLA_FLAGS'] = "--xla_gpu_cuda_data_dir=" + cuda_home
os.environ.get('XLA_FLAGS')

'--xla_gpu_cuda_data_dir=/cluster/apps/gcc-6.3.0/cuda-10.1.243-n6qg6z5js3zfnhp2cfg5yjccej636czm'

In [2]:
# Train a Bayesian neural network to classify MNIST using
# Neural SVGD
#
# If using pmap, set the environment variable
# `export XLA_FLAGS="--xla_force_host_platform_device_count=8"`
# before running on CPU (this enables pmap to "see" multiple cores).
%load_ext autoreload
import sys
import os
# Every HOME other than the one hard-coded below counts as "on the cluster".
on_cluster = os.getenv("HOME") != "/home/lauro"
# Collect the extra import locations first, then register them in one go
# (cluster checkout first, relative experiments directory last).
extra_import_paths = ["../../experiments/"]
if on_cluster:
    extra_import_paths.insert(
        0, "/cluster/home/dlauro/projects-2020-Neural-SVGD/learning_particle_gradients/")
sys.path.extend(extra_import_paths)


from functools import partial
from itertools import cycle

import numpy as onp
import jax
from jax import numpy as jnp
from jax import jit, grad, value_and_grad, vmap, pmap, config, random
config.update("jax_debug_nans", False)
from jax.ops import index_update, index
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.calibration import calibration_curve
import tensorflow_datasets as tfds

import haiku as hk
import optax

import nets
import utils
import models
import metrics
from convnet import make_model, accuracy, crossentropy_loss, log_prior, ensemble_accuracy

import os
# Network to be trained; the architecture is defined in the project-local `convnet` module.
model = make_model("large")

# Config
key = random.PRNGKey(0)  # root PRNG key; it is split before each random operation below
EPOCHS = 1  # only referenced by a commented-out num_steps formula further down
# NUM_VALS = 20

META_LEARNING_RATE = 1e-3  # passed as learning_rate to models.SDLearner

NUM_SAMPLES = 100  # number of particles (flat parameter vectors) that get initialised
DISABLE_PROGRESS_BAR = False  # NOTE(review): not referenced in the visible code
USE_PMAP = False  # if True, `vpmap` below is pmap instead of vmap

BATCH_SIZE = 128  # minibatch size for the train/test iterators and the likelihood rescaling in `loss`
LAMBDA_REG = 10**2  # passed as lambda_reg to models.SDLearner
STEP_SIZE = 1e-3 # step size handed to optax.sgd; 1e-3, 1e-4 too large (w adam)
PATIENCE = 5  # passed as patience to models.SDLearner
MAX_TRAIN_STEPS = 100 #15 -- n_steps for each neural_grad.train call


# Pick the mapping transform once, controlled by the USE_PMAP switch.
vpmap = pmap if USE_PMAP else vmap

# Load MNIST
data_dir = './data' if on_cluster else '/tmp/tfds'
# batch_size=-1 requests each split as one full batch instead of an iterator;
# as_numpy then converts the result to numpy arrays.
mnist_data, info = tfds.load(name="mnist", batch_size=-1, data_dir=data_dir, with_info=True)
mnist_data = tfds.as_numpy(mnist_data)
train_data, test_data = mnist_data['train'], mnist_data['test']

# Full train and test set
train_images, train_labels = train_data['image'], train_data['label']
test_images, test_labels = test_data['image'], test_data['label']

# Split off the validation set
# (10% of the training examples; random_state=0 makes the split reproducible)
train_images, val_images, train_labels, val_labels = train_test_split(
    train_images, train_labels, test_size=0.1, random_state=0)
# Number of training examples left after the split; `loss` uses it to rescale
# the minibatch likelihood.
data_size = len(train_images)

def make_batches(images, labels, batch_size, cyclic=True):
    """Returns an iterator through tuples (image_batch, label_batch).

    The data is cut into consecutive chunks of `batch_size` examples. When
    `len(images)` is not a multiple of `batch_size`, the final batch holds
    the remaining (fewer) examples. For non-empty input no empty batch is
    ever produced.

    args:
        images, labels: arrays with the same length along axis 0.
        batch_size: number of examples per full batch.
        cyclic: if True, return an endless iterator that cycles back after
            exhausting the batches; otherwise return a list of the batches.
    """
    # Cut points strictly inside the data. Using len(images) itself as a cut
    # point would make onp.split emit a trailing empty batch whenever the
    # length is an exact multiple of batch_size.
    split_idx = onp.arange(batch_size, len(images), batch_size)
    batches = zip(*[onp.split(data, split_idx, axis=0) for data in (images, labels)])
    return cycle(batches) if cyclic else list(batches)

def loss(params, images, labels):
    """Minibatch estimate of the (unnormalized) Bayesian negative
    log-posterior at `params`:

        data_size/BATCH_SIZE * crossentropy(model(images), labels) - log_prior(params)

    The crossentropy term plays the role of the negative log-likelihood of
    the batch; the factor data_size/BATCH_SIZE rescales it from one batch to
    the full training set.
    """
    batch_rescaling = data_size/BATCH_SIZE
    neg_loglikelihood = crossentropy_loss(model.apply(params, images), labels)
    return batch_rescaling * neg_loglikelihood - log_prior(params)

In [3]:
%autoreload

# Neural SVGD Model
* Input: model parameters
* Output: 'gradient' of same shape as parameters

### Memory
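Assume the convnet model has 50,000 parameters and the NSVGD layers have hidden dimensions `[1024, 1024]`. Then the input and output layers of the NSVGD (meta) model together have `2 * 50,000 * 1024 ≈ 100M` parameters (the `1024 * 1024 ≈ 1M` hidden-to-hidden weights are negligible by comparison). At 8 bytes each, that is about 800MB.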

The activations: `50,000 + 1024 + 1024 + 50,000` floats, i.e. about 800KB per particle; multiply by `NUM_SAMPLES` for the whole particle set.

In [4]:
# Utility functions for dealing with parameters
# `import jax` alone does not guarantee that the `jax.flatten_util` submodule
# has been loaded, so import it explicitly before using it as an attribute.
import jax.flatten_util

key, subkey = random.split(key)
# Initialise the network once on two training images to obtain a parameter
# pytree; `unravel` maps a flat vector back to that pytree structure.
params_tree = model.init(subkey, train_images[:2])
params_flat, unravel = jax.flatten_util.ravel_pytree(params_tree)


def ravel(tree):
    """Flatten the pytree `tree` into a single 1-D array.

    Only the flat array is returned; the matching unravel function produced
    by `ravel_pytree` is discarded.
    """
    # Import the submodule explicitly: the attribute access `jax.flatten_util`
    # only works if something else has already imported that submodule.
    from jax.flatten_util import ravel_pytree
    return ravel_pytree(tree)[0]


def init_flat_params(key):
    """Initialise the network with PRNG `key` and return its parameters
    flattened into one vector."""
    example_images = train_images[:2]
    fresh_params = model.init(key, example_images)
    return ravel(fresh_params)


def get_minibatch_logp(batch):
    """Build the log-density of the target posterior for one minibatch.

    args:
        batch = (images, labels)

    Returns a callable that takes a flattened parameter vector and returns
    the negated `loss` on `batch`, i.e. the (unnormalized) minibatch
    log-posterior.
    """
    def minibatch_logp(params_flat):
        params = unravel(params_flat)
        neg_logp = loss(params, *batch)
        return -neg_logp
    return minibatch_logp


def sample_tv(key):
    """return two sets of particles at initialization, for
    training and validation in the warmup phase

    NUM_SAMPLES freshly initialised flat parameter vectors are generated and
    split into two equally sized halves along the particle axis; the result
    is a list of two arrays. NUM_SAMPLES must be even for the split to work.
    """
    keys = random.split(key, NUM_SAMPLES)
    all_particles = vmap(init_flat_params)(keys)
    # jnp.split instead of the array method `.split`, which newer JAX
    # releases no longer provide.
    return jnp.split(all_particles, 2)


# @jit
# def acc(param_set_flat):
#     param_set = vmap(unravel)(param_set_flat)
#     logits = vmap(model.apply, (0, None))(param_set, val_images[:BATCH_SIZE])
#     return ensemble_accuracy(logits, val_labels[:BATCH_SIZE])


def vmean(fun):
    """vmap, but computes mean along mapped axis

    Returns a function that applies `fun` to every slice of its arguments
    along their leading axis and averages the results over that axis. For a
    scalar-valued `fun` the result is a scalar; for an array-valued `fun` the
    result has the shape of a single output of `fun`.
    """
    def compute_mean(*args, **kwargs):
        # vmap stacks the per-slice outputs along axis 0, so that is the axis
        # to average over (axis=-1 only coincides with it for scalar outputs).
        return jnp.mean(vmap(fun)(*args, **kwargs), axis=0)
    return compute_mean

In [5]:
# Fixed (non-cyclic) list of validation minibatches of up to 512 examples.
# make_batches produces len(val_images) // 512 full batches followed by
# whatever remainder is left over.
validation_batches = make_batches(val_images, val_labels, 512, cyclic=False) # 11 full validation batches plus the remainder -- TODO confirm count


def minibatch_accuracy(param_set_flat, images, labels):
    """Accuracy of the whole particle ensemble on one minibatch.

    Every row of `param_set_flat` is unravelled into a parameter pytree, the
    network is applied with each parameter set to the same `images`, and the
    stacked logits are scored against `labels` by `ensemble_accuracy`.
    """
    apply_per_particle = vmap(model.apply, (0, None))
    stacked_logits = apply_per_particle(vmap(unravel)(param_set_flat), images)
    return ensemble_accuracy(stacked_logits, labels)


def compute_acc(param_set_flat):
    """Ensemble accuracy averaged over all validation minibatches.

    Each batch contributes equally to the mean, regardless of its size.
    """
    per_batch = [minibatch_accuracy(param_set_flat, images, labels)
                 for images, labels in validation_batches]
    return jnp.mean(jnp.array(per_batch))

In [6]:
# Endless (cyclic) iterators over minibatches of BATCH_SIZE examples; each
# next(...) call yields one (images, labels) tuple.
train_batches = make_batches(train_images, train_labels, BATCH_SIZE)
test_batches  = make_batches(test_images,  test_labels, BATCH_SIZE)

# init particles and dynamics model

In [7]:
key, subkey = random.split(key)
# One flat parameter vector per particle, each from an independent network
# initialisation: shape (NUM_SAMPLES, number of network parameters).
init_particles = vmap(init_flat_params)(random.split(subkey, NUM_SAMPLES))

# Optimizer handed to models.Particles for the particle updates.
opt = optax.sgd(STEP_SIZE)

# opt = optax.chain(
#     optax.scale_by_adam(),
#     optax.scale(-STEP_SIZE),
# )

key, subkey1, subkey2 = random.split(key, 3)
# Learner for the particle update ("gradient") function. The last entry of
# `sizes` equals the particle dimension, so the learner's output has the same
# shape as a particle.
neural_grad = models.SDLearner(target_dim=init_particles.shape[1],
                               get_target_logp=get_minibatch_logp,
                               learning_rate=META_LEARNING_RATE,
                               key=subkey1,
                               sizes=[64, 64, init_particles.shape[1]],
                               aux=False,
                               use_hutchinson=True,
                               lambda_reg=LAMBDA_REG,
                               patience=PATIENCE,
                               dropout=True)
# Particle container: starts from init_particles and is driven by the learner's
# gradient function together with `opt`.
particles = models.Particles(subkey2, neural_grad.gradient, init_particles, custom_optimizer=opt)

In [9]:
init_particles.shape

(100, 24480)

In [8]:
# maximal initial stein discrepancy (theoretical ideal)
import plot
import stein
from jax.scipy import stats

n, d = init_particles.shape
n = 100  # NOTE(review): overrides the particle count read on the previous line (same value as NUM_SAMPLES at present)
key, subkey = random.split(key)
# Reference particles: i.i.d. normal draws scaled to standard deviation 1/100,
# the same scale that `logq` below uses.
xs = random.normal(subkey, (n, d)) / 100

# Keep the first training batch around; the warmup cell further down reuses it.
first_batch = next(train_batches)
images, labels = first_batch

def logp(x):
    """Unnormalized log-posterior of the flat parameter vector `x` on the
    first training batch (`images`, `labels`): rescaled log-likelihood plus
    log-prior."""
    params = unravel(x)
    logits = model.apply(params, images)
    loglikelihood = -data_size/BATCH_SIZE * crossentropy_loss(logits, labels)
    return loglikelihood + log_prior(params)


def logq(x):
    """Log-density of `x` under independent zero-mean normals with standard
    deviation 1/100 in every coordinate."""
    per_coordinate = stats.norm.logpdf(x, loc=0, scale=1/100)
    return jnp.sum(per_coordinate)


def f(x):
    """Gradient, with respect to the flat parameter vector `x`, of the
    rescaled first-batch log-likelihood divided by 2*LAMBDA_REG."""
    def scaled_loglikelihood(flat_params):
        logits = model.apply(unravel(flat_params), images)
        return -data_size/BATCH_SIZE * crossentropy_loss(logits, labels) / (2*LAMBDA_REG)
    return grad(scaled_loglikelihood)(x)

# double-check stein discrepancy
# Compute the discrepancy of `f` on the reference particles in two ways and
# print both along with their ratio.

# a) true
# Closed form derived from the squared L2 norm of f over xs.
l2 = utils.l2_norm_squared(xs, f) # = sd(f*) / (2 LAMBDA_REG)
true_sd = 2 * LAMBDA_REG * l2
# Reference value drawn as a horizontal line in the loss plots further down.
min_loss = -l2 * LAMBDA_REG

# b) hutchinson
# Randomised estimate of the same quantity (needs its own PRNG key).
key, subkey = random.split(key)
random_estimate_sd = stein.stein_discrepancy_hutchinson(subkey, xs, logp, f)

print("analytical sd:", true_sd)
print("hutchinson estimate:", random_estimate_sd)
print("ratio:", true_sd / random_estimate_sd)

RuntimeError: Resource exhausted: Out of memory while trying to allocate 1284505600 bytes.: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).

# train

In [None]:
# Warmup on first batch
# Pre-train the learner before any particle step is taken.
NUM_WARMUP_ITER = 5
key, subkey = random.split(key)
neural_grad.warmup(key=subkey,
                   # callable returning a train/validation pair of freshly initialised particle sets
                   sample_split_particles=sample_tv,
                   # always hand back the same (first) training batch
                   next_data=lambda: first_batch,
                   n_iter=NUM_WARMUP_ITER,
                   n_inner_steps=100,
                   progress_bar=True)

In [None]:
# Diagnostics of the learner: loss, stein discrepancy and L2 norm, each with
# the reference value computed on the first batch as a dashed line.
fig, axs = plt.subplots(1, 3, figsize=[15, 5])

ax = axs[0]
ax.plot(neural_grad.rundata['training_loss'], label="training loss")
ax.plot(neural_grad.rundata['validation_loss'], label="validation loss")
ax.axhline(y=min_loss, label="true first-batch loss", linestyle="--")
ax.legend()  # labels are only drawn once legend() is called

ax = axs[1]
ax.plot(neural_grad.rundata['training_sd'], label="training sd")
ax.plot(neural_grad.rundata['validation_sd'], label="validation sd")
ax.axhline(y=true_sd, linestyle="--", label="true first-batch stein discrepancy")
ax.legend()

ax = axs[2]
ax.plot(neural_grad.rundata['l2_norm'], label="l2 norm")
ax.axhline(y=l2, linestyle="--", label="true l2 norm")
ax.legend()

In [None]:
neural_grad.rundata['train_steps']

In [None]:
N_PARTICLE_STEPS = 5  # NOTE(review): not referenced anywhere in the visible code

# NOTE(review): fixed_key is not used in the visible code; presumably intended
# for the fixed train/val split experiment -- confirm.
key, fixed_key = random.split(key)
def step(key, train_batch):
    """One iteration of the particle trajectory simulation.

    First train the learner (up to MAX_TRAIN_STEPS steps) on the particle
    split obtained from `particles.next_batch(key)` together with
    `train_batch`, then advance the particles using the learner's current
    parameters.
    """
    particle_split = particles.next_batch(key)
    neural_grad.train(split_particles=particle_split,
                      n_steps=MAX_TRAIN_STEPS,
                      data=train_batch)
    learner_params = neural_grad.get_params()
    particles.step(learner_params)


@jit
def _compute_eval(train_batch, test_batch, ps):
    """Jitted core of `compute_eval`; both batches are explicit arguments so
    that no Python-side state is read while tracing."""
    test_logp = vmean(get_minibatch_logp(test_batch))
    train_logp = vmean(get_minibatch_logp(train_batch))
    return {
        "accuracy": compute_acc(ps),
        "test_logp": test_logp(ps),
        "training_logp": train_logp(ps),
        "particle_mean": ps.mean(),
    }


def compute_eval(train_batch, ps):
    """Evaluate the particle set `ps`.

    Returns a dict with the validation accuracy, the mean log-posterior of
    the particles on the next test batch and on `train_batch`, and the mean
    over all particle entries.
    """
    # Draw the test batch outside the jitted function. Inside @jit,
    # next(test_batches) would run only while tracing, so every later call
    # would silently reuse the batch captured at compile time.
    return _compute_eval(train_batch, next(test_batches), ps)

* currently testing what happens if we **fix** train/val splits
    * outcome: no overfit!

In [None]:
# Main loop: one learner-training + particle-update iteration per training
# batch, with evaluation metrics logged every second iteration.
# num_steps = EPOCHS * data_size // BATCH_SIZE // 5
num_steps = 101
for step_counter in tqdm(range(num_steps)):
    key, subkey = random.split(key)
    train_batch = next(train_batches)
    step(subkey, train_batch)
#     if step_counter % (num_steps//NUM_VALS) == 0:
    if step_counter % 2 == 0:
        # evaluate the current particles and append the results to particles.rundata
        metrics.append_to_log(particles.rundata,
                              compute_eval(train_batch, particles.particles))
# neural_grad.done()
# particles.done()

In [None]:
# Summary of the run: logged accuracy, learner losses, training log-posterior.
fig, axs = plt.subplots(1, 3, figsize=[15, 5])

# accuracy values logged by the training loop
ax = axs[0]
ax.plot(particles.rundata["accuracy"], "--.")

# learner losses against the first-batch reference value min_loss
ax = axs[1]
ax.plot(neural_grad.rundata['training_loss'], ".", label="training loss")
ax.plot(neural_grad.rundata['validation_loss'], ".", label="validation loss")
ax.axhline(y=min_loss, label="loss should never go below here", linestyle="--")
ax.legend()

# mean log-posterior of the particles on the training batches
ax = axs[2]
ax.plot(particles.rundata['training_logp'])
# ax.plot(particles.rundata['test_logp'])

In [None]:
# NOTE(review): this cell repeats the code of the preceding plotting cell verbatim.
fig, axs = plt.subplots(1, 3, figsize=[15, 5])

# accuracy values logged by the training loop
ax = axs[0]
ax.plot(particles.rundata["accuracy"], "--.")

# learner losses against the first-batch reference value min_loss
ax = axs[1]
ax.plot(neural_grad.rundata['training_loss'], ".", label="training loss")
ax.plot(neural_grad.rundata['validation_loss'], ".", label="validation loss")
ax.axhline(y=min_loss, label="loss should never go below here", linestyle="--")
ax.legend()

# mean log-posterior of the particles on the training batches
ax = axs[2]
ax.plot(particles.rundata['training_logp'])
# ax.plot(particles.rundata['test_logp'])

In [None]:
onp.mean(particles.rundata["accuracy"][-10:])

In [None]:
plt.plot(neural_grad.rundata["global_gradient_norm"])

In [None]:
particles.rundata.keys()

In [None]:
# Stack the logged particle sets into one numpy array.
# assumes the result is (logged steps, particles, parameter dims), which is how
# the trajectory plots index it -- TODO confirm
trajectories = onp.array(particles.rundata['particles'])
trajectories.shape

In [None]:
# visualize trajectory avg across dimensions (distinguish particles)
# Both panels pass a 2-D array to plot, so each column (particle) becomes its own line.
fig, axs = plt.subplots(2, 1, figsize=[10, 8])

ax = axs[0]
ax.plot(trajectories.mean(axis=2));  # avg across dims

ax = axs[1]
ax.plot(trajectories[:, :, 1]);  # watch single param (aka single dimension): index 1 of the last axis

In [None]:
# visualize trajectory avg across particles (distinguish dims, ie parameters)
# Both panels pass a 2-D array to plot, so each column (parameter) becomes its own line.
fig, axs = plt.subplots(2, 1, figsize=[10, 8])

ax = axs[0]
ax.plot(trajectories.mean(axis=1)); # avg across particles

ax = axs[1]
ax.plot(trajectories[:, 11, :]); # watch single particle: index 11 of the middle axis

In [None]:
# generate markings based on num_steps
plt.subplots(figsize=[15, 5])
# Running total of the logged 'train_steps' entries, with the first
# NUM_WARMUP_ITER entries dropped; presumably these are the positions on the
# loss curve at which a post-warmup training call ended -- TODO confirm.
markings = onp.cumsum(neural_grad.rundata['train_steps'])[NUM_WARMUP_ITER:]

plt.plot(neural_grad.rundata['training_loss'])
plt.plot(neural_grad.rundata['validation_loss'])
# draw the markings as dots on the y=0 line
plt.scatter(markings, onp.zeros(len(markings)))
plt.axhline(y=min_loss, label="loss should never go below here", linestyle="--", color="g")

In [None]:
# Zoom into the index window [a, b) of the learner's loss / l2-norm curves.
a = 800
b = 1000
# number of markings that lie before a and before b, respectively
marker_a = (markings < a).sum()
marker_b = (markings < b).sum()
print(f"step {marker_a} to step {marker_b}.")

fig, axs = plt.subplots(2, figsize=[10, 10])

ax = axs[0]
ax.plot(neural_grad.rundata['training_loss'][a:b], ".")
ax.plot(neural_grad.rundata['validation_loss'][a:b], ".")
# shift the markings by a so they line up with the sliced curves
ax.scatter(markings[marker_a:marker_b] - a, onp.zeros(marker_b - marker_a))
ax.axhline(y=min_loss, label="loss should never go below here", linestyle="--", color="g")
ax.legend()
# ax.set_yscale('log')

ax = axs[1]
ax.plot(neural_grad.rundata['l2_norm'][a:b])
ax.scatter(markings[marker_a:marker_b] - a, onp.zeros(marker_b - marker_a))
# ax.set_yscale('log')

# particle gradient norms

In [None]:
gs = neural_grad.grads(particles.particles)
optax.global_norm(gs)

In [None]:
# First logged global gradient norm before and after the update
# (spelling of "update" in the printed labels corrected).
print("norm pre update :", particles.rundata['global_grad_norm'][0])
print("norm post update:", particles.rundata['global_grad_norm_post_update'][0])

In [None]:
plt.plot(particles.rundata['global_grad_norm'])

In [None]:
plt.plot(particles.rundata['global_grad_norm_post_update'])

In [None]:
optax.global_norm(vmap(unravel)(particles.particles))

In [None]:
jnp.sqrt(jnp.sum(particles.particles**2))

# model gradient norms

### interlude: why does the accuracy not change?

In [None]:
# Predicted class (argmax over the logits) of every particle in the last logged
# particle set on the first 128 validation images; the counts below show how
# the predictions are distributed over the classes.
all_preds = vmap(lambda ps: model.apply(unravel(ps), val_images[:128]).argmax(axis=1))(particles.rundata['particles'][-1])
onp.unique(all_preds, return_counts=True)

### continue with scheduled programming