In [1]:
from jax import nn, tree_leaves, random, numpy as jnp
from jax_sgmc import data, potential, adaption, scheduler, integrator, solver
import tensorflow as tf
import haiku as hk

from jax import config
# config.update('TF_CPP_MIN_LOG_LEVEL', 0)

## Configuration parameters

In [2]:
# Data pipeline: mini-batch size given to the NumpyDataLoaders and the number
# of batches passed to data.random_reference_data as its cache size.
batch_size = 32
cached_batches = 128

# Model: width of the ResNet logit layer and of the one-hot label encoding.
num_classes = 100

# Prior: scale of the L2 penalty on the network weights.
weight_decay = 5e-4


## Load dataset

In [3]:
# CIFAR-100 train/test split with the 'fine' label set.
(train_images, train_labels), (test_images, test_labels) = (
    tf.keras.datasets.cifar100.load_data(label_mode='fine')
)

# TODO: Consider using a TensorFlow data loader here instead.
train_loader = data.NumpyDataLoader(
    batch_size, X=train_images, Y=train_labels)
test_loader = data.NumpyDataLoader(
    batch_size, X=test_images, Y=test_labels)

# Random mini-batch access to the training data; the result is unpacked
# below as an (init, get) pair of functions.
train_batch_fn = data.random_reference_data(train_loader, cached_batches)

# Draw one batch up front: it is needed to initialise the network.
# TODO: A convenience function for this common use case might be worthwhile.
batch_init, batch_get = train_batch_fn
init_batch_state = batch_init()
# batch_get returns a pair; only its second element (the batch) is kept.
first_batch = batch_get(init_batch_state)[1]



{'X': ShapeDtypeStruct(shape=(32, 32, 32, 3), dtype=uint8)}
{'X': ShapeDtypeStruct(shape=(32, 32, 32, 3), dtype=uint8), 'Y': ShapeDtypeStruct(shape=(32, 1), dtype=int64)}
{'X': ShapeDtypeStruct(shape=(32, 32, 32, 3), dtype=uint8)}
{'X': ShapeDtypeStruct(shape=(32, 32, 32, 3), dtype=uint8), 'Y': ShapeDtypeStruct(shape=(32, 1), dtype=int64)}


## ResNet Model

In [4]:
def init_resnet():
    """Build the stateful Haiku transform of a ResNet-50 classifier.

    The wrapped forward function takes a batch dict, reads the images from
    ``batch["X"]``, converts them to float32 scaled by 1/255, and returns the
    logits of ``hk.nets.ResNet50(num_classes)``.

    Returns:
        The ``(init, apply)`` pair of the ``hk.transform_with_state`` result.
    """
    def forward(batch, is_training=True):
        scaled_images = batch["X"].astype(jnp.float32) / 255.
        network = hk.nets.ResNet50(num_classes)
        return network(scaled_images, is_training=is_training)

    transformed = hk.transform_with_state(forward)
    return transformed.init, transformed.apply

init, apply_resnet = init_resnet()
init_params, init_resnet_state = init(random.PRNGKey(0), first_batch)

# Sanity check: one forward pass on the first batch. No RNG key is supplied
# (None); only the logits (first element of the returned pair) are kept.
logits = apply_resnet(init_params, init_resnet_state, None, first_batch)[0]

# The recorded output of this cell is exactly 0.0. The original concern was
# that an all-zero output might lead to zero gradients.
# NOTE(review): Haiku's ResNet appears to zero-initialise its final logits
# layer by default, which would explain the 0.0 without implying that every
# gradient vanishes - confirm against the installed Haiku version.
print(jnp.sum(logits))

0.0


## Initialize potential

Note: everything below is still implemented without threading the network state through the sampler.
Open question: can additional arguments (such as the network state) be passed to the likelihood?

In [None]:
def likelihood(sample, observations):
    """Softmax cross-entropy of the ResNet predictions for ``observations``.

    Args:
        sample: Dict holding the network parameters under key ``"w"``.
        observations: Dict holding the images under ``"X"`` and the integer
            class labels under ``"Y"``.

    Returns:
        The cross-entropy summed over all entries and divided by the size of
        the leading label axis.

    NOTE(review): this returns a cross-entropy (a negative log-likelihood).
    Presumably jax_sgmc expects a log-likelihood here - confirm the sign
    convention of ``potential.minibatch_potential``.
    NOTE(review): the network is evaluated with the fixed initial state
    ``init_resnet_state`` and the updated state is discarded, because the
    state is not yet threaded through the sampler. With ``strategy="vmap"``
    the likelihood may be called per observation, while the network is
    initialised on a batched input - confirm.
    """
    # Labels arrive with a trailing singleton axis (the loader reports Y as
    # (batch, 1)); flatten them so the one-hot encoding lines up with the
    # (batch, num_classes) logits instead of broadcasting to an extra axis.
    labels = nn.one_hot(jnp.reshape(observations["Y"], (-1,)), num_classes)
    # Haiku's stateful apply takes (params, state, rng, *args); the wrapped
    # network reads batch["X"] itself, so the whole observation dict is passed.
    logits, _ = apply_resnet(sample["w"], init_resnet_state, None, observations)
    softmax_xent = -jnp.sum(labels * nn.log_softmax(logits))
    softmax_xent /= labels.shape[0]
    return softmax_xent


def prior(sample):
    """Weight-decay term: ``weight_decay`` times half the squared L2 norm.

    Corresponds to a Gaussian prior over the weights.

    Args:
        sample: Dict holding the network parameters (a pytree) under ``"w"``.

    Returns:
        ``weight_decay * 0.5 * sum of squared entries`` over all leaves.

    NOTE(review): this is a positive penalty; presumably jax_sgmc expects a
    log-prior here - confirm the sign convention.
    """
    squared_norms = (
        jnp.sum(jnp.square(leaf)) for leaf in tree_leaves(sample["w"])
    )
    l2_loss = 0.5 * sum(squared_norms)
    return weight_decay * l2_loss

# Combine prior and likelihood into the mini-batch potential that is handed
# to the integrator below.
potential_fn = potential.minibatch_potential(
    prior=prior,
    likelihood=likelihood,
    strategy="vmap",
)

## Setup Integrator


In [None]:
# Total number of sampler iterations
iterations = 50_000

# Adaption strategy (RMSprop from jax_sgmc)
rms_prop = adaption.rms_prop()

# Langevin-diffusion integrator built from the potential, the training
# batch functions and the adaption strategy
rms_integrator = integrator.langevin_diffusion(
    potential_fn, train_batch_fn, rms_prop)

# Starting point: the freshly initialised network parameters
sample = {"w": init_params}

# Step size schedule, specified by its first (0.05) and last (0.001) value
rms_step_size = scheduler.polynomial_step_size_first_last(
    first=0.05, last=0.001)

# Burn-in configured with 10000 (presumably the number of initial
# iterations to discard - confirm)
burn_in = scheduler.initial_burn_in(10000)

# Random thinning configured with 4000 (presumably the number of samples
# to keep - confirm).
# NOTE(review): each sample is a full ResNet-50 parameter set; check that
# keeping this many fits in memory.
rms_random_thinning = scheduler.random_thinning(
    rms_step_size, burn_in, 4000)

# Bundle the three schedules into one scheduler
rms_scheduler = scheduler.init_scheduler(
    step_size=rms_step_size,
    burn_in=burn_in,
    thinning=rms_random_thinning)

# Wrap the integrator into a solver and combine it with the scheduler
rms_sgld = solver.sgmc(rms_integrator)
rms_run = solver.mcmc(rms_sgld, rms_scheduler)

# The first element of the integrator is called on the starting sample to
# produce the argument for the run; only the "samples" entry is kept.
rms = rms_run(
    rms_integrator[0](sample),
    iterations=iterations,
)["samples"]