In [2]:
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))


# Gumbel Softmax


* Pre-print, published in ICLR 2017 https://arxiv.org/pdf/1611.01144.pdf

* https://arxiv.org/pdf/1406.2989.pdf


## Experiments

* **Dataset**: We use the MNIST dataset with fixed binarization for training and evaluation
* **Tricks**: We also found that variance normalization was necessary
* **Network**: We used sigmoid activation functions for binary (Bernoulli) neural networks and softmax activations for categorical variables.
* **Training**: Models were trained using stochastic gradient descent with momentum 0.9.
* **Learning rates**:  are chosen from {3e−5, 1e−5, 3e−4, 1e−4, 3e−3, 1e−3}; we select the best learning rate for each estimator using the MNIST validation set, and report performance on the test set.
* **Tasks** Each estimator is evaluated on two tasks: (1) structured output prediction and (2) variational training of generative models. 

### 1) Structured output prediction with stochastic binary networks

### 2) Generative modelling with variational Autoencoders

### 3) Generative semi supervised classification

# Requirements

In [3]:
import tensorflow as tf
tf.__version__

'2.1.0'

In [4]:
import pathlib
import os
import matplotlib.pyplot as plt
import numpy as np

np.set_printoptions(precision=4)

In [5]:
import pandas as pd

# Load and preprocess data

In [6]:
import tensorflow_datasets as tfds
mnist_data = tfds.load("binarized_mnist", data_dir="/tf/data")
mnist_train, mnist_test = mnist_data["train"], mnist_data["test"]
assert isinstance(mnist_train, tf.data.Dataset)

In [10]:
## split upper half / lower half

In [27]:
mnist_train.unbatch().shuffle(10000).batch(32)
https://www.tensorflow.org/tutorials/quickstart/advanced


TypeError: 'BatchDataset' object does not support indexing

In [9]:

train_ds = tf.data.Dataset.from_tensor_slices(
    (x_train, y_train)).shuffle(10000).batch(32)

test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)



<BatchDataset shapes: {image: (None, 28, 28, 1)}, types: {image: tf.uint8}>

# Gumbel Softmax 
* https://gist.github.com/ericjang/1001afd374c2c3b7752545ce6d9ed349

In [6]:
def sample_gumbel(shape, eps=1e-20): 
    """Sample from Gumbel(0, 1)"""
    U = tf.random_uniform(shape,minval=0,maxval=1)
    return -tf.log(-tf.log(U + eps) + eps)

def gumbel_softmax_sample(logits, temperature): 
    """ Draw a sample from the Gumbel-Softmax distribution"""
    y = logits + sample_gumbel(tf.shape(logits))
    return tf.nn.softmax( y / temperature)

def gumbel_softmax(logits, temperature, hard=False):
    """Sample from the Gumbel-Softmax distribution and optionally discretize.
    Args:
        logits: [batch_size, n_class] unnormalized log-probs
        temperature: non-negative scalar
        hard: if True, take argmax, but differentiate w.r.t. soft sample y
    Returns:
        [batch_size, n_class] sample from the Gumbel-Softmax distribution.
        If hard=True, then the returned sample will be one-hot, otherwise it will
        be a probabilitiy distribution that sums to 1 across classes
    """
    y = gumbel_softmax_sample(logits, temperature)
    if hard:
        k = tf.shape(logits)[-1]
        #y_hard = tf.cast(tf.one_hot(tf.argmax(y,1),k), y.dtype)
        y_hard = tf.cast(tf.equal(y,tf.reduce_max(y,1,keep_dims=True)),y.dtype)
        y = tf.stop_gradient(y_hard - y) + y
    return y

### 1) Structured output prediction with stochastic binary networks

* **Task**: Predict lower half of mnist image given top half.
* The minimization objective for this conditional generative model is an importance-sampled estimate of the likelihood objective, Eh∼pθ(hi|xupper)  m Pm i=1 log pθ(xlower|hi)

*  where m = 1 is used for training and m = 1000 is used for evaluation.

In [7]:
class GumbelSoftmaxStructuredOutputPrediciton(tf.keras.Model):
    """
        Predicts lower half of an mnist image given the top half.
    """

    def __init__(self):
        super(GumbelSoftmaxStructuredOutputPrediciton, self).__init__()
    
    def setup_model(self):
        self.input_layer = tf.keras.layers.Dense(200, activation=tf.nn.sigmoid) # [bs,392] => [bs,200]
        self.categorical_layer = tf.keras.layers.Dense(200, activation=None) # [bs,200] => [bs,200]
        self.output_layer = tf.keras.layers.Dense(392, activationtf.nn.sigmoid) # [bs,200] => [bs,392]
        
    def call(self, upper_image_half, temperature):
        h1 = self.input_layer(output_layer)
        logits = self.categorical_layer(h1) 
        h2 = gumbel_softmax_sample(logits, temperature)
        lower_image_half = self.output_layer(h3)       
        return lower_image_half

sop_model = GumbelSoftmaxStructuredOutputPrediciton()

In [8]:
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.9, nesterov=False, name='SGD') # m {3e−5, 1e−5, 3e−4, 1e−4, 3e−3, 1e−3};

In [None]:
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)