In [1]:
# Widen the notebook container to the full browser width.
# `IPython.display` is the public home of display/HTML; importing `display`
# from `IPython.core.display` is deprecated in recent IPython releases.
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))


# Gumbel Softmax


* Jang, Gu & Poole, *Categorical Reparameterization with Gumbel-Softmax*, ICLR 2017 (arXiv preprint: https://arxiv.org/pdf/1611.01144.pdf)


## Experiments

* **Dataset**: We use the MNIST dataset with fixed binarization for training and evaluation
* **Tricks**: We also found that variance normalization was necessary
* **Network**: We used sigmoid activation functions for binary (Bernoulli) neural networks and softmax activations for categorical variables.
* **Training**: Models were trained using stochastic gradient descent with momentum 0.9.
* **Learning rates**: chosen from {3e-5, 1e-5, 3e-4, 1e-4, 3e-3, 1e-3}; we select the best learning rate for each estimator using the MNIST validation set, and report performance on the test set.
* **Tasks**: Each estimator is evaluated on two tasks: (1) structured output prediction and (2) variational training of generative models. Section (3) below additionally applies Gumbel-Softmax to semi-supervised classification.

### 1) Structured output prediction with stochastic binary networks

### 2) Generative modelling with variational autoencoders

### 3) Generative semi-supervised classification

# Requirements

In [2]:
import tensorflow as tf

# Show the installed TensorFlow version; the cells below use the TF 2.x API.
tf.version.VERSION

'2.1.0'

In [3]:
import os
import pathlib

import matplotlib.pyplot as plt
import numpy as np

# Print numpy arrays with 4 decimal places to keep cell output compact.
np.set_printoptions(precision=4)

In [4]:
import pandas as pd

# Load and preprocess data

In [5]:
import tensorflow_datasets as tfds

# Root directory for the TFDS download/cache. Override it with the
# GUMBEL_DATA_DIR environment variable rather than editing an absolute path.
DATA_DIR = pathlib.Path(os.environ.get("GUMBEL_DATA_DIR", "/tf/2016/gumbel-softmax/data"))

mnist_data = tfds.load("binarized_mnist", data_dir=str(DATA_DIR))

# binarized_mnist provides train / validation / test splits. The validation
# split is the one the experiment notes use for learning-rate selection.
mnist_train = mnist_data["train"]
mnist_validation = mnist_data["validation"]
mnist_test = mnist_data["test"]
assert all(
    isinstance(ds, tf.data.Dataset)
    for ds in (mnist_train, mnist_validation, mnist_test)
)

[1mDownloading and preparing dataset binarized_mnist/1.0.0 (download: 104.68 MiB, generated: Unknown size, total: 104.68 MiB) to /tf/2016/gumbel-softmax/data/binarized_mnist/1.0.0...[0m


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Completed...', max=1.0, style=Progre…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Size...', max=1.0, style=ProgressSty…







HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Shuffling and writing examples to /tf/2016/gumbel-softmax/data/binarized_mnist/1.0.0.incompleteRB7X4H/binarized_mnist-train.tfrecord


HBox(children=(FloatProgress(value=0.0, max=50000.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Shuffling and writing examples to /tf/2016/gumbel-softmax/data/binarized_mnist/1.0.0.incompleteRB7X4H/binarized_mnist-validation.tfrecord


HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Shuffling and writing examples to /tf/2016/gumbel-softmax/data/binarized_mnist/1.0.0.incompleteRB7X4H/binarized_mnist-test.tfrecord


HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))

[1mDataset binarized_mnist downloaded and prepared to /tf/2016/gumbel-softmax/data/binarized_mnist/1.0.0. Subsequent calls will reuse this data.[0m


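# Gumbel-Softmax sampling functions
* Adapted from Eric Jang's reference implementation (originally written for TensorFlow 1.x): https://gist.github.com/ericjang/1001afd374c2c3b7752545ce6d9ed349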

In [6]:
def sample_gumbel(shape, eps=1e-20, dtype=tf.float32):
    """Sample i.i.d. noise from the standard Gumbel(0, 1) distribution.

    Uses inverse-transform sampling: if U ~ Uniform(0, 1) then
    -log(-log(U)) ~ Gumbel(0, 1).

    Args:
        shape: shape of the returned tensor (list/tuple or 1-D int tensor).
        eps: small constant added inside both logarithms so that neither
            log ever receives exactly 0.
        dtype: floating-point dtype of the returned tensor.
    Returns:
        Tensor of the given shape and dtype holding Gumbel(0, 1) samples.
    """
    # TF 2.x names: tf.random_uniform / tf.log only exist in the TF1 API.
    u = tf.random.uniform(shape, minval=0, maxval=1, dtype=dtype)
    return -tf.math.log(-tf.math.log(u + eps) + eps)

def gumbel_softmax_sample(logits, temperature):
    """Draw a relaxed (soft one-hot) sample from the Gumbel-Softmax distribution.

    Computes softmax((logits + g) / temperature) over the last axis, where g is
    i.i.d. Gumbel(0, 1) noise with the same shape as ``logits``.

    Args:
        logits: [..., n_class] unnormalized log-probabilities.
        temperature: positive scalar; smaller values sharpen the softmax.
    Returns:
        Tensor shaped like ``logits`` whose last axis sums to 1.
    """
    logits = tf.convert_to_tensor(logits)
    # sample_gumbel draws float32 noise by default; cast it so float16/float64
    # logits do not hit a dtype mismatch in the addition below.
    noise = tf.cast(sample_gumbel(tf.shape(logits)), logits.dtype)
    return tf.nn.softmax((logits + noise) / temperature, axis=-1)

def gumbel_softmax(logits, temperature, hard=False):
    """Sample from the Gumbel-Softmax distribution and optionally discretize.
    Args:
        logits: [batch_size, n_class] unnormalized log-probs (extra leading
            dimensions are also accepted; classes are on the last axis)
        temperature: positive scalar
        hard: if True, take argmax, but differentiate w.r.t. soft sample y
            (straight-through estimator)
    Returns:
        [batch_size, n_class] sample from the Gumbel-Softmax distribution.
        If hard=True, then the returned sample will be one-hot, otherwise it will
        be a probability distribution that sums to 1 across classes
    """
    y = gumbel_softmax_sample(logits, temperature)
    if hard:
        # One-hot of the argmax over the class axis. Unlike comparing y with its
        # reduce_max (which yields several ones on ties), this is always one-hot.
        y_hard = tf.one_hot(tf.argmax(y, axis=-1), depth=tf.shape(y)[-1], dtype=y.dtype)
        # Straight-through: the forward value equals y_hard, while gradients
        # flow through the soft sample y.
        y = tf.stop_gradient(y_hard - y) + y
    return y

### 1) Structured output prediction with stochastic binary networks

* **Task**: Predict the lower half of an MNIST image given the top half.
* The minimization objective for this conditional generative model is an importance-sampled estimate of the likelihood objective, $\mathbb{E}_{h_i \sim p_\theta(h_i \mid x_{\text{upper}})}\left[\frac{1}{m}\sum_{i=1}^{m} \log p_\theta(x_{\text{lower}} \mid h_i)\right]$,

* where $m = 1$ is used for training and $m = 1000$ is used for evaluation.

In [7]:
class GumbelSoftmaxStructuredOutputPrediciton(tf.keras.Model):
    """
        Predicts lower half of an mnist image given the top half.

        Two sigmoid hidden layers (200 units each) map the input to 392 logits,
        from which a Gumbel-Softmax sample is drawn.
    """

    def __init__(self):
        super(GumbelSoftmaxStructuredOutputPrediciton, self).__init__()
        # call() uses self.dense1..3, so the layers must exist as soon as the
        # model is constructed.
        self.setup_model()

    def setup_model(self):
        """Create the three dense layers used by call()."""
        self.dense1 = tf.keras.layers.Dense(200, activation=tf.nn.sigmoid)  # [bs,392] => [bs,200]
        self.dense2 = tf.keras.layers.Dense(200, activation=tf.nn.sigmoid)  # [bs,200] => [bs,200]
        self.dense3 = tf.keras.layers.Dense(392, activation=None)  # [bs,200] => [bs,392]

    def call(self, inputs, temperature):
        """Forward pass.

        Args:
            inputs: presumably the flattened upper half of the image, [bs, 392]
                per the shape comments in setup_model — TODO confirm.
            temperature: Gumbel-Softmax temperature (must be positive).
        Returns:
            [bs, 392] Gumbel-Softmax sample; the softmax runs over the last
            axis, so the 392 outputs of each row sum to 1.
        """
        h1 = self.dense1(inputs)
        h2 = self.dense2(h1)
        logits = self.dense3(h2)
        # NOTE(review): the experiment notes describe binary (Bernoulli) units,
        # but this draws one categorical sample across all 392 pixels instead of
        # an independent relaxed Bernoulli per pixel — confirm intent.
        return gumbel_softmax_sample(logits, temperature)


sop_model = GumbelSoftmaxStructuredOutputPrediciton()

In [8]:
# SGD with momentum 0.9, as in the training setup described above.
# NOTE(review): the stated learning-rate grid is
# {3e-5, 1e-5, 3e-4, 1e-4, 3e-3, 1e-3}; 0.01 lies outside it — confirm.
optimizer = tf.keras.optimizers.SGD(
    learning_rate=0.01,
    momentum=0.9,
    nesterov=False,
    name='SGD',
)

In [None]:
# The targets are the binarized lower-half pixels (one binary value per output
# unit), and the model's call() returns a softmax sample (probabilities), not
# logits. A per-pixel Bernoulli likelihood therefore calls for binary
# cross-entropy on probabilities, not sparse categorical cross-entropy on logits.
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=False)