In [1]:
from tqdm import tqdm_notebook as tqdm

import numpy as np

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
tfe = tf.contrib.eager

from sonnet import Linear, AbstractModule, BatchFlatten, reuse_variables
tf.enable_eager_execution()


For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
If you depend on functionality not listed there, please file an issue.



In [2]:
class HmcNet(AbstractModule):
    
    def __init__(self, 
                 hidden_units=200,
                 prior_sigma=1.,
                 mass=1.,
                 name="hmc_net"):
        
        super(HmcNet, self).__init__(name=name)
        
        self._hidden_units = hidden_units
        self.prior_sigma = prior_sigma
        self.mass = mass
        
    @reuse_variables
    def log_prob(self, labels):
        self._ensure_is_connected()
        
        return tf.reduce_sum(self._likelihood.log_prob(labels))
    
    
    def _build(self, inputs):
        
        flatten = BatchFlatten()
        flattened = flatten(inputs)
        
        linear1 = Linear(output_size=self._hidden_units)
        activations = tf.nn.relu(linear1(flattened))
        
        linear2 = Linear(output_size=self._hidden_units)
        activations = tf.nn.relu(linear2(activations))
        
        linear_out = Linear(output_size=10) 
        logits = linear_out(activations)

        # Regularising term
        self.prior_term = 0
        for variable in tf.trainable_variables():
            self.prior_term += tf.reduce_sum(variable**2)
        
        self._likelihood = tfd.Categorical(logits=logits)
        
        return logits
    
    def potential_energy(self, labels):

        return self.log_prob(labels) + (0.5 / self.prior_sigma**2) * self.prior_term


    def hamiltonian(self, potential, momenta):

        s = 0
        
        for k, v in momenta.items():
            s = s + tf.reduce_sum(v**2)

        return potential + 0.5 * s / self.mass

In [3]:
def mnist_dataset(data, labels, batch_size=128, shuffle_buffer=2000):
    
    mnist_ds = tf.data.Dataset.from_tensor_slices((data, labels))
    mnist_ds = mnist_ds.map(mnist_process)
    mnist_ds = mnist_ds.shuffle(buffer_size=shuffle_buffer)
    mnist_ds = mnist_ds.batch(batch_size)
    
    return mnist_ds

def mnist_process(data, label):
    
    return tf.cast(data, tf.float32) / 255., tf.cast(label, tf.int64)
    

In [4]:
((train_data, train_labels),
(test_data, test_labels)) = tf.keras.datasets.mnist.load_data()

In [5]:
num_epochs = 2

hmc_net = HmcNet(hidden_units=200)

train_dataset = mnist_dataset(train_data, train_labels)

optimizer = tf.train.GradientDescentOptimizer(3e-4)

accuracy_metric = tfe.metrics.Accuracy()

for epoch in range(num_epochs):

    print("Epoch {}: ".format(epoch + 1))
    for data_batch, label_batch in train_dataset:
    
        with tf.GradientTape() as tape:
            
            logits = hmc_net(data_batch)
        
            loss = -hmc_net.log_prob(label_batch) + hmc_net.prior_term
        
        grads = tape.gradient(loss, hmc_net.get_all_variables())
        optimizer.apply_gradients(zip(grads, hmc_net.get_all_variables()))
        
        accuracy_metric(labels=label_batch,
                        predictions=tf.argmax(logits, axis=1))
        
        print("Loss: {:.2f}, Accuracy: {:.2f} \r".format(loss, accuracy_metric.result() * 100), end="")
        
    print()

Epoch 1: 
Instructions for updating:
Colocations handled automatically by placer.
Loss: 241.21, Accuracy: 79.92 
Epoch 2: 
Loss: 169.32, Accuracy: 84.81 


In [6]:
with tf.GradientTape(persistent=True) as tape:
    hmc_net(data_batch)
    pe = hmc_net.potential_energy(label_batch)
    
    
for var in hmc_net.get_all_variables():
    print(hash(var))
    
del tape

-9223372036538493538
316282260
-9223372036538365934
-9223372036538365948
316409941
316409913


In [5]:
def run_dynamics(hmc_net, momenta, x_train, y_train, epsilon, num_steps):
    """
    Leapfrog integration of the dynamics for

    time = num_steps * epsilon
    """

    # Computing dE_dz
    with tf.GradientTape(persistent=True) as tape:
        logits = hmc_net(x_train)
        E = hmc_net.potential_energy(y_train)

    # Sample integration direction
    epsilon = epsilon * (2 * tf.cast(tf.random.categorical(tf.log([[0.5, 0.5]]), 1), tf.float32) - 1)[0, 0]

    # First leapfrog update
    for var in tf.trainable_variables():
        momenta[var] -= (epsilon / 2.) * tape.gradient(E, var)

    # delete tape here
    del tape

    for i in range(num_steps - 1):

        # Middle leapfrog steps
        # z = z + epsilon * r
        for var in tf.trainable_variables():
            var.assign_add(epsilon * momenta[var] / hmc_net.mass)

        # Computing dE_dz
        with tf.GradientTape(persistent=True) as tape:
            logits = hmc_net(x_train)
            E = hmc_net.potential_energy(y_train)

        # r = r - epsilon * dE_dz(z)
        for var in tf.trainable_variables():
            momenta[var] -= epsilon * tape.gradient(E, var)

        # delete tape here
        del tape

    # Final leapfrog steps
    # z = z + epsilon * r
    for var in tf.trainable_variables():
        var.assign_add(epsilon * tf.squeeze(momenta[var]) / hmc_net.mass)

    # Computing dE_dz
    with tf.GradientTape(persistent=True) as tape:
        logits = hmc_net(x_train)
        E = hmc_net.potential_energy(y_train)

    # r = r - epsilon * dE_dz(z)
    for var in tf.trainable_variables():
        momenta[var] -= (epsilon / 2.) * tape.gradient(E, var)

    # delete tape here
    del tape

    return momenta


def hmc_sample(hmc_net, x_train, y_train, mixing_time=50, burn_in_time=10,
               num_integral_steps=10, num_samples=100, epsilon=1e-3, log_every=1):

    num_iters = num_samples * mixing_time + burn_in_time
    num_accepted = 0

    accuracy_metric = tfe.metrics.Accuracy()
    
    for i in tqdm(range(num_iters)):

        vars_0 = {v: tf.identity(v) for v in tf.trainable_variables()}
        momenta_0 = {v: tf.random.normal(v.shape, stddev=hmc_net.mass**0.5)
                     for v in tf.trainable_variables()}
        

        hmc_net(x_train)
        H_0 = hmc_net.hamiltonian(hmc_net.potential_energy(y_train), momenta_0)

        momenta = run_dynamics(hmc_net=hmc_net,
                               momenta=momenta_0,
                               x_train=x_train,
                               y_train=y_train,
                               epsilon=epsilon,
                               num_steps=num_integral_steps)

        hmc_net(x_train)
        H = hmc_net.hamiltonian(hmc_net.potential_energy(y_train), momenta)

        threshold = tf.minimum(1, tf.exp(H_0 - H))

        u = tf.random.uniform(shape=(1,), minval=0., maxval=1.)
        is_accepted = u <= threshold

        if not is_accepted:
            for var in tf.trainable_variables():
                var.assign(vars_0[var])
        else:
            num_accepted += 1


        if i % log_every == 0:
            predictions = tf.argmax(hmc_net(x_train), axis=1)
            accuracy_metric(predictions=predictions,
                            labels=y_train)

            print('Train accuracy {:.3}, accepted {} out of {}'.format(accuracy_metric.result(), num_accepted, i + 1))

In [None]:
hmc_net = HmcNet(hidden_units=100, mass=1e-2)

train_dataset = mnist_dataset(train_data, train_labels, batch_size=128)

for data_batch, label_batch in train_dataset.take(1):
    hmc_net(data_batch)

    hmc_sample(hmc_net=hmc_net,
               x_train=data_batch, 
               y_train=label_batch,
               mixing_time=100, 
               burn_in_time=10,
               num_integral_steps=10, 
               num_samples=5, 
               epsilon=1e-4, 
               log_every=1)

HBox(children=(IntProgress(value=0, max=510), HTML(value='')))