# Train a BNN to classify MNIST using SVGD

In [7]:
%load_ext autoreload
# for leonhard: when CUDA_HOME is advertised, tell XLA where the CUDA
# installation lives; otherwise leave XLA_FLAGS untouched.
import os
if "CUDA_HOME" in os.environ:
    os.environ['XLA_FLAGS'] = "--xla_gpu_cuda_data_dir=" + os.environ["CUDA_HOME"]

# Train a Bayesian neural network to classify MNIST using SVGD
# (the kernel-based baseline, not Neural SVGD).
#
# If using pmap, set the environment variable
# `export XLA_FLAGS="--xla_force_host_platform_device_count=8"`
# before running on CPU (this enables pmap to "see" multiple cores).
import sys
import os
on_cluster = not os.getenv("HOME") == "/home/lauro"
if on_cluster:
    sys.path.append("/cluster/home/dlauro/projects-2020-Neural-SVGD/learning_particle_gradients/")
sys.path.append("../../experiments/")

import argparse
import matplotlib.pyplot as plt
from jax import vmap, random
import jax.numpy as jnp
import numpy as onp
from tqdm import tqdm
import optax
import bnn
import models
import metrics
import mnist
import config as cfg
import utils
from jax import jit, grad, value_and_grad

# Config
key = random.PRNGKey(0)  # root PRNG key; later cells split it further

# Run settings
NUM_SAMPLES = 100            # number of particles drawn by init_particles_fn
DISABLE_PROGRESS_BAR = False
USE_PMAP = False

# Training hyper-parameters.
# NOTE: STEP_SIZE is overwritten with the best value from the step-size sweep.
# NOTE(review): BATCH_SIZE, LAMBDA_REG, PATIENCE, MAX_TRAIN_STEPS and USE_PMAP
# are not referenced in the visible cells.
BATCH_SIZE = 128
LAMBDA_REG = 0.5
STEP_SIZE = 1e-5
PATIENCE = 15
MAX_TRAIN_STEPS = 50

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [8]:
%autoreload

# init particles and dynamics model

In [9]:
from svgd_bnn import train as train_svgd

In [10]:
final_accs = list()  # (final_accuracy, particle_stepsize) pairs collected by the sweep

In [11]:
# Candidate particle step sizes for the SVGD step-size sweep.
lrs = [
    1e-6, 5e-6, 8e-6,
    1e-5, 2e-5,
    5e-5,
]

In [13]:
# Fresh subkey for the sweep; `key` is advanced so later splits use new randomness.
key, subkey = random.split(key, 2)

In [None]:
# Sweep the SVGD particle step size. Every run reuses the same `subkey`, so
# runs differ only in their step size. Results are appended to `final_accs`,
# which is created in a separate cell; re-running this cell alone appends
# duplicate entries.
for particle_stepsize in tqdm(lrs, disable=DISABLE_PROGRESS_BAR):
    final_acc = train_svgd(
        key=subkey,
        particle_stepsize=particle_stepsize,
        n_iter=200,
        evaluate_every=-1,
        results_file=os.devnull,  # discard per-run logs (portable /dev/null)
        optimizer="sgd",
    )
    final_accs.append((final_acc, particle_stepsize))

  0%|          | 0/6 [00:00<?, ?it/s]

Training...
Starting epoch 1


In [None]:
# Tabulate the sweep: column 0 = final accuracy, column 1 = step size.
fa = onp.asarray(final_accs)
fa

In [None]:
fa.T[1]  # step sizes tried (column 1)

In [None]:
fa.T[0]  # final accuracies (column 0)

In [None]:
# Keep the step size whose sweep run reached the highest final accuracy.
STEP_SIZE = fa[onp.argmax(fa[:, 0]), 1]

In [None]:
STEP_SIZE  # display the step size selected from the sweep

In [None]:
def init_particles_fn(subkey):
    """Initialise NUM_SAMPLES particles from a single PRNG key.

    `subkey` is split into NUM_SAMPLES keys, and `bnn.init_flat_params` is
    vmapped over them. The stacked result is returned.
    """
    per_particle_keys = random.split(subkey, NUM_SAMPLES)
    return vmap(bnn.init_flat_params)(per_particle_keys)


# Sample the initial particle set and build a plain-SGD optimizer that uses
# the step size selected from the sweep.
key, subkey = random.split(key)
init_particles = init_particles_fn(subkey)
opt = optax.sgd(STEP_SIZE)

# NOTE(review): subkey1 is never used below. Keep this 3-way split as is so
# that subkey2 (and therefore the run) stays reproducible.
key, subkey1, subkey2 = random.split(key, 3)
# Kernel-based (SVGD) gradient with target log-density bnn.get_minibatch_logp.
# NOTE(review): the meaning of scaled=False is defined in models.KernelGradient — confirm there.
svgd_grad = models.KernelGradient(get_target_logp=bnn.get_minibatch_logp,
                                  scaled=False)

# Particle container driven by svgd_grad.gradient and the SGD optimizer `opt`
# (presumably applied on each particles.step(...) call — confirm in models.Particles).
particles = models.Particles(key=subkey2,
                             gradient=svgd_grad.gradient,
                             init_samples=init_particles,
                             custom_optimizer=opt)

# minibatch_vdlogp = jit(vmap(value_and_grad(bnn.minibatch_logp), (0, None)))

@jit
def compute_eval(step_counter, ps, loglikelihood):
    """Build one step's evaluation record for metrics.append_to_log.

    Args:
        step_counter: current training step, stored alongside each metric.
        ps: particle array; its accuracy comes from bnn.compute_acc_from_flat,
            and its overall mean is also recorded.
        loglikelihood: array whose mean is recorded.

    Returns:
        dict with "accuracy" and "particle_mean" as (step, value) pairs.
        NOTE(review): "loglikelihood" is stored as a bare value, without a step.
    """
    accuracy = bnn.compute_acc_from_flat(ps)
    return dict(
        accuracy=(step_counter, accuracy),
        particle_mean=(step_counter, ps.mean()),
        loglikelihood=loglikelihood.mean(),
    )


SGLD_STEPSIZE = 5e-8
# The reported "noise" is sqrt(2 * SGLD_STEPSIZE). That is the usual Langevin
# noise scale; this assumes utils.sgld follows the same convention — confirm.
print('SGLD noise   :', jnp.sqrt(2*SGLD_STEPSIZE))
# STEP_SIZE drives the SGD optimizer of the kernel (SVGD) particles, so label it as such.
print('SVGD stepsize:', STEP_SIZE)
sgld = utils.sgld(SGLD_STEPSIZE)
# NOTE(review): sgld_state / sgld_step are not used by the training loop in this notebook.
sgld_state = sgld.init(init_particles)


@jit
def sgld_step(particles, dlogp, sgld_state):
    """Advance all particles by one Langevin (SGLD) update.

    The update is computed by the module-level `sgld` transformation from the
    negated log-density gradient `-dlogp`, then applied with optax.

    Returns:
        (updated particles, new SGLD state, aux dict holding the global norm
        of the applied update under "global_grad_norm").
    """
    updates, new_state = sgld.update(-dlogp, sgld_state, particles)
    new_particles = optax.apply_updates(particles, updates)
    return new_particles, new_state, {"global_grad_norm": optax.global_norm(updates)}

step_counter = 0

# num_steps = EPOCHS * data_size // BATCH_SIZE // 5
num_steps = 200
sgld_aux = {}
# Honour the DISABLE_PROGRESS_BAR config flag (previously ignored).
for _ in tqdm(range(num_steps), disable=DISABLE_PROGRESS_BAR):
    step_counter += 1
    train_batch = next(mnist.training_batches)
    particles.step(train_batch)

    # Record evaluation metrics every 10 steps.
    if step_counter % 10 == 0:
        # NOTE(review): jnp.array(1) looks like a placeholder log-likelihood;
        # the logged "loglikelihood" is therefore always 1.
        metrics.append_to_log(particles.rundata,
                              compute_eval(step_counter,
                                           particles.particles,
                                           jnp.array(1)))
        
# neural_grad.done()
# particles.done()

In [None]:
# Plot logged (step, accuracy) pairs. A label alone is not shown without a legend.
fig, ax = plt.subplots(figsize=[15, 5])
ax.plot(*zip(*particles.rundata['accuracy']), "--.", label="accuracy")
ax.set_xlabel("step")
ax.set_ylabel("accuracy")
ax.legend();

In [None]:
particles.rundata['accuracy'][-5:]  # last five logged accuracy entries

In [None]:
particles.rundata.keys()  # which quantities have been logged so far

In [None]:
# Particle snapshots recorded in rundata. Indexing below requires rank 3;
# assumes (n_logged_steps, n_particles, n_params) — TODO confirm in models.Particles.
trajectories = onp.array(particles.rundata['particles'])
# Jupyter only auto-displays a cell's last expression, so print the shape explicitly.
print("trajectories shape:", trajectories.shape)

# visualize trajectory avg across dimensions (distinguish particles)
fig, axs = plt.subplots(2, 1, figsize=[10, 8])

ax = axs[0]
ax.set_title("mean over parameter dimensions")
ax.plot(trajectories.mean(axis=2));  # avg across dims

ax = axs[1]
ax.set_title("parameter index 1")
ax.plot(trajectories[:, :, 1]);  # watch single param (aka single dimension)