# MNIST VIB Example

Here I demonstrate the Variational Information Bottleneck method on the MNIST dataset.

In [25]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import tensorflow as tf
import sys
import math

layers = tf.contrib.layers
ds = tf.contrib.distributions

In [26]:
tf.reset_default_graph()

# Turn on xla optimization
config = tf.ConfigProto()
config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
sess = tf.InteractiveSession(config=config)

In [27]:
from tensorflow.examples.tutorials.mnist import input_data
mnist_data = input_data.read_data_sets('/tmp/mnistdata', validation_size=0)

Extracting /tmp/mnistdata\train-images-idx3-ubyte.gz
Extracting /tmp/mnistdata\train-labels-idx1-ubyte.gz
Extracting /tmp/mnistdata\t10k-images-idx3-ubyte.gz
Extracting /tmp/mnistdata\t10k-labels-idx1-ubyte.gz


In [28]:
"""mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
mnist"""

'mnist = tf.keras.datasets.mnist\n(x_train, y_train),(x_test, y_test) = mnist.load_data()\nmnist'

In [29]:
# data
images = tf.placeholder(tf.float32, [None, 784], 'images')
labels = tf.placeholder(tf.int64, [None], 'labels')
one_hot_labels = tf.one_hot(labels, 10)

In [None]:
# model

def encoder(images):
    net = layers.relu(2*images-1, 1024)
    net = layers.relu(net, 1024)
    params = layers.linear(net, 512)
    mu, rho = params[:, :256], params[:, 256:]
    encoding = ds.NormalWithSoftplusScale(mu, rho - 5.0)
    return encoding

with tf.variable_scope('encoder'):  # not important for the maths
    encoding = encoder(images)
    
def decoder(encoding_sample):
    net = layers.linear(encoding_sample, 10)
    return net

with tf.variable_scope('decoder'):
    logits = decoder(encoding.sample())  
    # logits are real numbers that will then be put into a softmax function to produce probabilities

In [30]:
# the second term in the loss 

prior = ds.Normal(0.0, 1.0)    
    
info_loss = tf.reduce_sum(tf.reduce_mean(ds.kl_divergence(encoding, prior), axis=0)) / math.log(2)

IZX_bound = info_loss  

In [31]:
# the first term in the loss 

class_loss = tf.losses.softmax_cross_entropy(
    logits=logits, onehot_labels=one_hot_labels) / math.log(2)

IZY_bound = math.log(10, 2) - class_loss  # first term is H(Y)

In [None]:
# calculating loss J

BETA = 1e-3    

total_loss = class_loss + BETA * info_loss  

In [None]:
accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(logits, 1), labels), tf.float32))

with tf.variable_scope('decoder', reuse=True):
    many_logits = decoder(encoding.sample(12))

avg_accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(tf.reduce_mean(tf.nn.softmax(many_logits), 0), 1), labels), tf.float32))

In [None]:
batch_size = 100
steps_per_batch = int(mnist_data.train.num_examples / batch_size)

In [None]:
global_step = tf.contrib.framework.get_or_create_global_step()
learning_rate = tf.train.exponential_decay(1e-4, global_step,
                                           decay_steps=2*steps_per_batch,
                                           decay_rate=0.97, staircase=True)
opt = tf.train.AdamOptimizer(learning_rate, 0.5)

ma = tf.train.ExponentialMovingAverage(0.999, zero_debias=True)
ma_update = ma.apply(tf.model_variables())

saver = tf.train.Saver()
saver_polyak = tf.train.Saver(ma.variables_to_restore())

train_tensor = tf.contrib.training.create_train_op(total_loss, opt,
                                                   global_step,
                                                   update_ops=[ma_update])

In [None]:
tf.global_variables_initializer().run()

In [13]:
def evaluate():
    IZY, IZX, acc, avg_acc = sess.run([IZY_bound, IZX_bound, accuracy, avg_accuracy],
                             feed_dict={images: mnist_data.test.images, labels: mnist_data.test.labels})
    return IZY, IZX, acc, avg_acc, 1-acc, 1-avg_acc

In [14]:
for epoch in range(10):
    for step in range(steps_per_batch):
        im, ls = mnist_data.train.next_batch(batch_size)
        sess.run(train_tensor, feed_dict={images: im, labels: ls})
    print("{}: IZY={:.2f}\tIZX={:.2f}\tacc={:.4f}\tavg_acc={:.4f}\terr={:.4f}\tavg_err={:.4f}".format(epoch, *evaluate()))
    sys.stdout.flush()
    
savepth = saver.save(sess, '/tmp/mnistvib', global_step)

0: IZY=3.03	IZX=124.24	acc=0.9407	avg_acc=0.9497	err=0.0593	avg_err=0.0503
1: IZY=3.13	IZX=102.63	acc=0.9563	avg_acc=0.9660	err=0.0437	avg_err=0.0340
2: IZY=3.14	IZX=90.19	acc=0.9610	avg_acc=0.9708	err=0.0390	avg_err=0.0292
3: IZY=3.16	IZX=82.35	acc=0.9647	avg_acc=0.9732	err=0.0353	avg_err=0.0268
4: IZY=3.17	IZX=78.39	acc=0.9670	avg_acc=0.9740	err=0.0330	avg_err=0.0260
5: IZY=3.19	IZX=71.32	acc=0.9691	avg_acc=0.9768	err=0.0309	avg_err=0.0232
6: IZY=3.19	IZX=77.62	acc=0.9720	avg_acc=0.9773	err=0.0280	avg_err=0.0227
7: IZY=3.19	IZX=66.89	acc=0.9742	avg_acc=0.9803	err=0.0258	avg_err=0.0197
8: IZY=3.20	IZX=69.95	acc=0.9740	avg_acc=0.9810	err=0.0260	avg_err=0.0190
9: IZY=3.20	IZX=62.87	acc=0.9758	avg_acc=0.9810	err=0.0242	avg_err=0.0190


In [15]:
saver_polyak.restore(sess, savepth)
evaluate()

Instructions for updating:
Use standard file APIs to check for files with this prefix.
INFO:tensorflow:Restoring parameters from /tmp/mnistvib-6000


(3.228585,
 71.078026,
 0.9798,
 0.9839,
 0.020200014114379883,
 0.016099989414215088)

In [16]:
saver.restore(sess, savepth)
evaluate()

INFO:tensorflow:Restoring parameters from /tmp/mnistvib-6000


(3.1953857,
 62.871254,
 0.9749,
 0.9821,
 0.025099992752075195,
 0.01789999008178711)