In [1]:
# Widen the notebook container so wide outputs use the full browser width.
# `display`/`HTML` are imported from the public `IPython.display` module;
# importing them from `IPython.core.display` is deprecated.
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))


# Gumbel Softmax


* Pre-print, published in ICLR 2017 https://arxiv.org/pdf/1611.01144.pdf

* https://arxiv.org/pdf/1406.2989.pdf


## Experiments

* **Dataset**: We use the MNIST dataset with fixed binarization for training and evaluation
* **Tricks**: We also found that variance normalization was necessary
* **Network**: We used sigmoid activation functions for binary (Bernoulli) neural networks and softmax activations for categorical variables.
* **Training**: Models were trained using stochastic gradient descent with momentum 0.9.
* **Learning rates**: chosen from {3e−5, 1e−5, 3e−4, 1e−4, 3e−3, 1e−3}; the best learning rate for each estimator is selected using the MNIST validation set, and performance is reported on the test set.
* **Tasks**: Each estimator is evaluated on two tasks: (1) structured output prediction and (2) variational training of generative models.

### 1) Structured output prediction with stochastic binary networks

### 2) Generative modelling with variational Autoencoders

### 3) Generative semi-supervised classification

# Requirements

In [10]:
import tensorflow as tf
# NOTE: a bare expression is only echoed when it is the last line of a cell,
# so this version lookup produces no visible output in its current position.
tf.__version__
# Sentinel passed as `num_parallel_calls` when mapping over datasets later in
# the notebook.
AUTOTUNE = tf.data.experimental.AUTOTUNE

In [3]:
import pathlib
import os
import matplotlib.pyplot as plt
import numpy as np

# Print NumPy floats with 4 digits of precision to keep array output compact.
np.set_printoptions(precision=4)

In [4]:
import pandas as pd

# Load and preprocess data

In [5]:
import tensorflow_datasets as tfds

# Download binarized MNIST on first use (later calls reuse the copy stored
# under /tf/data) and pick out the train and test splits.
mnist_data = tfds.load("binarized_mnist", data_dir="/tf/data")
mnist_train = mnist_data["train"]
mnist_test = mnist_data["test"]
# Sanity check: the splits are expected to be tf.data pipelines.
assert isinstance(mnist_train, tf.data.Dataset)

[1mDownloading and preparing dataset binarized_mnist/1.0.0 (download: 104.68 MiB, generated: Unknown size, total: 104.68 MiB) to /tf/data/binarized_mnist/1.0.0...[0m


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Completed...', max=1.0, style=Progre…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Size...', max=1.0, style=ProgressSty…







HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Shuffling and writing examples to /tf/data/binarized_mnist/1.0.0.incompleteFZRZTW/binarized_mnist-train.tfrecord


HBox(children=(FloatProgress(value=0.0, max=50000.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Shuffling and writing examples to /tf/data/binarized_mnist/1.0.0.incompleteFZRZTW/binarized_mnist-validation.tfrecord


HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Shuffling and writing examples to /tf/data/binarized_mnist/1.0.0.incompleteFZRZTW/binarized_mnist-test.tfrecord


HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))

[1mDataset binarized_mnist downloaded and prepared to /tf/data/binarized_mnist/1.0.0. Subsequent calls will reuse this data.[0m


## split upper half / lower half

In [18]:
# Peek at the dataset structure: the repr shows each element's keys, shape and dtype.
mnist_train.take(1)

<TakeDataset shapes: {image: (28, 28, 1)}, types: {image: tf.uint8}>

In [27]:
def split_lower_upper_half(image):
    """Split one MNIST example into its upper and lower halves.

    Args:
        image: a dataset element, i.e. a dict whose "image" entry holds the
            image tensor (shape (28, 28, 1), dtype uint8 for this dataset).
    Returns:
        A tuple (upper_half, lower_half) of flat tensors, each holding half of
        the pixels (392 for a 28x28 image). The dtype of the input is kept.
        Used as (input, target) for structured output prediction.
    """
    # Row-major flattening keeps image rows contiguous, so the first half of
    # the flat vector is the top half of the image and the second half the
    # bottom half.
    flat_image = tf.reshape(image["image"], [-1])
    upper_half, lower_half = tf.split(
        flat_image,
        num_or_size_splits=2,
        axis=0,
        name='split_image_to_upper_lower_half')
    # The former debug print was removed: inside Dataset.map it only ran once,
    # at trace time, and printed a symbolic tensor.
    return upper_half, lower_half  # (x, y)

In [28]:
# Turn every image record into an (upper_half, lower_half) pair, i.e. (input, target).
# NOTE(review): the training loop further down iterates over `train_ds`, not
# `labeled_train_ds` — confirm which name is intended.
labeled_train_ds = mnist_train.map(split_lower_upper_half, num_parallel_calls=AUTOTUNE)
# Show the element structure of the mapped dataset.
labeled_train_ds.take(1)

Tensor("args_0:0", shape=(28, 28, 1), dtype=uint8)


<TakeDataset shapes: (784,), types: tf.uint8>

# Gumbel Softmax 
* https://gist.github.com/ericjang/1001afd374c2c3b7752545ce6d9ed349

In [None]:
def sample_gumbel(shape, eps=1e-20):
    """Sample from Gumbel(0, 1).

    Args:
        shape: shape of the noise tensor to draw.
        eps: small constant added inside both logarithms so that log(0) is
            never evaluated.
    Returns:
        Tensor of the requested shape with i.i.d. Gumbel(0, 1) samples,
        computed as -log(-log(U)) with U ~ Uniform(0, 1).
    """
    # TF2 names: `tf.random_uniform` and `tf.log` no longer exist in the
    # top-level namespace.
    uniform = tf.random.uniform(shape, minval=0, maxval=1)
    return -tf.math.log(-tf.math.log(uniform + eps) + eps)

def gumbel_softmax_sample(logits, temperature):
    """Draw one relaxed (soft) sample from the Gumbel-Softmax distribution.

    Gumbel(0, 1) noise of the same shape as `logits` is added to the logits,
    and the result is divided by `temperature` before applying softmax.
    """
    gumbel_noise = sample_gumbel(tf.shape(logits))
    perturbed_logits = logits + gumbel_noise
    return tf.nn.softmax(perturbed_logits / temperature)

def gumbel_softmax(logits, temperature, hard=False):
    """Sample from the Gumbel-Softmax distribution and optionally discretize.

    Args:
        logits: [batch_size, n_class] unnormalized log-probs
        temperature: positive scalar; the perturbed logits are divided by it
        hard: if True, take argmax, but differentiate w.r.t. soft sample y
    Returns:
        [batch_size, n_class] sample from the Gumbel-Softmax distribution.
        If hard=True, then the returned sample will be one-hot, otherwise it will
        be a probability distribution that sums to 1 across classes
    """
    y = gumbel_softmax_sample(logits, temperature)
    if hard:
        # Mask that is 1 where a row attains its maximum and 0 elsewhere.
        # `keepdims` is the TF2 spelling of the removed `keep_dims`; the class
        # axis is addressed as -1, matching the axis tf.nn.softmax normalizes.
        # If several entries tie for the maximum, each of them is set to 1.
        y_hard = tf.cast(
            tf.equal(y, tf.reduce_max(y, axis=-1, keepdims=True)), y.dtype)
        # Straight-through estimator: the forward value is y_hard, while the
        # gradient flows through the soft sample y.
        y = tf.stop_gradient(y_hard - y) + y
    return y

### 1) Structured output prediction with stochastic binary networks

* **Task**: Predict lower half of mnist image given top half.
* The minimization objective for this conditional generative model is an importance-sampled estimate of the likelihood objective, $\mathbb{E}_{h \sim p_\theta(h_i \mid x_{\text{upper}})}\left[\frac{1}{m}\sum_{i=1}^{m} \log p_\theta(x_{\text{lower}} \mid h_i)\right]$,

*  where m = 1 is used for training and m = 1000 is used for evaluation.

* For Bernoulli variables they use sigmoid activations.
* For categorical variables they use softmax activations.



In [None]:
class GumbelSoftmaxStructuredOutputPrediciton(tf.keras.Model):
    """Predicts the lower half of an MNIST image given the upper half.

    Architecture: a sigmoid dense layer, a linear dense layer producing
    logits, a Gumbel-Softmax sample of those logits, and a sigmoid dense
    output layer with one unit per pixel of the lower half.
    """

    def __init__(self):
        super(GumbelSoftmaxStructuredOutputPrediciton, self).__init__()
        # Build the layers right away. Nothing else calls setup_model(), so
        # without this call() would fail on missing attributes.
        self.setup_model()

    def setup_model(self):
        """Create the dense layers used by call()."""
        self.input_layer = tf.keras.layers.Dense(200, activation=tf.nn.sigmoid)  # [bs,392] => [bs,200]
        self.categorical_layer = tf.keras.layers.Dense(200, activation=None)  # [bs,200] => [bs,200]
        self.output_layer = tf.keras.layers.Dense(392, activation=tf.nn.sigmoid)  # [bs,200] => [bs,392]

    def call(self, upper_image_half, temperature=1.0, training=None):
        """Run the forward pass.

        Args:
            upper_image_half: [bs, 392] batch of flattened upper image halves.
                Integer inputs (the dataset yields uint8) are cast to float32.
            temperature: positive Gumbel-Softmax temperature. The default of
                1.0 is only a placeholder so the model can be called with the
                inputs alone — NOTE(review): choose or anneal as required.
            training: accepted for Keras compatibility; the forward pass is
                the same in both modes.
        Returns:
            [bs, 392] per-pixel probabilities (sigmoid outputs) for the lower
            image half.
        """
        # Dense layers need floating-point input.
        inputs = tf.cast(upper_image_half, tf.float32)
        h1 = self.input_layer(inputs)
        logits = self.categorical_layer(h1)
        # Relaxed categorical sample over the 200 logits.
        h2 = gumbel_softmax_sample(logits, temperature)
        lower_image_half = self.output_layer(h2)
        return lower_image_half

sop_model = GumbelSoftmaxStructuredOutputPrediciton()

In [None]:
# SGD with momentum 0.9, as stated in the experiment notes at the top.
# NOTE(review): learning_rate=0.01 is not one of the candidate values listed
# there ({3e-5, 1e-5, 3e-4, 1e-4, 3e-3, 1e-3}) — confirm the intended value.
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.9, nesterov=False, name='SGD')

# Loss function

* They use negative log likelihood. 

TODO not sure of what, it's not quite obvious from the paper

tf.nn.sigmoid_cross_entropy_with_logits 

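`tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=x_logit, labels=x))` — note: the sigmoid cross-entropy already is the Bernoulli negative log-likelihood, so as a loss it is summed without a leading minus sign (`-tf.reduce_sum(...)` would be the log-likelihood).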
 

In [None]:
def loss_object(y_true, y_pred):
    """Bernoulli negative log-likelihood of the target pixels.

    Args:
        y_true: [bs, n_pixels] binary target pixels (any numeric dtype; cast
            to the dtype of `y_pred`).
        y_pred: [bs, n_pixels] predicted per-pixel probabilities, i.e. the
            output of a sigmoid layer (not logits).
    Returns:
        Scalar: the per-pixel binary cross-entropy summed over the pixels of
        each image and averaged over the batch.
    """
    y_true = tf.cast(y_true, y_pred.dtype)
    per_pixel_nll = tf.keras.backend.binary_crossentropy(y_true, y_pred)
    return tf.reduce_mean(tf.reduce_sum(per_pixel_nll, axis=-1))

# Train loss, test loss

In [None]:
# Running mean of the per-batch training loss.
train_loss = tf.keras.metrics.Mean(name='train_loss')
# Fraction of pixels whose predicted probability, thresholded at 0.5, matches
# the binary target pixel.
train_accuracy = tf.keras.metrics.BinaryAccuracy(name='train_accuracy')

# Same two metrics for the held-out test split.
test_loss = tf.keras.metrics.Mean(name='test_loss')
test_accuracy = tf.keras.metrics.BinaryAccuracy(name='test_accuracy')

# Train step

In [None]:
@tf.function
def train_step(images_upper_half, images_lower_half, temperature=1.0):
  """Run one optimization step on a batch and update the training metrics.

  Args:
    images_upper_half: [bs, 392] model inputs (flattened upper image halves).
    images_lower_half: [bs, 392] targets (flattened lower image halves).
    temperature: Gumbel-Softmax temperature forwarded to the model. Pass a
      tensor rather than varying Python floats to avoid retracing.
      NOTE(review): the default of 1.0 is a placeholder — confirm.
  """
  # The dataset yields uint8 pixels; the model and the loss need floats.
  upper = tf.cast(images_upper_half, tf.float32)
  lower = tf.cast(images_lower_half, tf.float32)
  with tf.GradientTape() as tape:
    # `sop_model` is the model instance created above; the previously used
    # name `model` is not defined anywhere in the notebook.
    predictions = sop_model(upper, temperature, training=True)
    loss = loss_object(lower, predictions)
  gradients = tape.gradient(loss, sop_model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, sop_model.trainable_variables))

  train_loss(loss)
  # The targets are the lower image halves (there is no `labels` variable here).
  train_accuracy(lower, predictions)

# Training

In [None]:
EPOCHS = 5
# NOTE(review): batch size, shuffle buffer and evaluation temperature are not
# specified anywhere in this notebook — the values below are placeholders.
BATCH_SIZE = 100
SHUFFLE_BUFFER = 10000
EVAL_TEMPERATURE = 1.0

# Batched (upper_half, lower_half) pipelines. These names were used by the loop
# below but never defined.
train_ds = labeled_train_ds.shuffle(SHUFFLE_BUFFER).batch(BATCH_SIZE).prefetch(AUTOTUNE)
test_ds = (mnist_test
           .map(split_lower_upper_half, num_parallel_calls=AUTOTUNE)
           .batch(BATCH_SIZE)
           .prefetch(AUTOTUNE))


@tf.function
def test_step(images_upper_half, images_lower_half):
  """Evaluate one batch and update the test metrics (no parameter update)."""
  upper = tf.cast(images_upper_half, tf.float32)
  lower = tf.cast(images_lower_half, tf.float32)
  predictions = sop_model(upper, EVAL_TEMPERATURE, training=False)
  test_loss(loss_object(lower, predictions))
  test_accuracy(lower, predictions)


for epoch in range(EPOCHS):
  # Reset the metrics at the start of the next epoch
  train_loss.reset_states()
  train_accuracy.reset_states()
  test_loss.reset_states()
  test_accuracy.reset_states()

  for images, labels in train_ds:
    train_step(images, labels)

  for test_images, test_labels in test_ds:
    test_step(test_images, test_labels)

  template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'
  print(template.format(epoch + 1,
                        train_loss.result(),
                        train_accuracy.result() * 100,
                        test_loss.result(),
                        test_accuracy.result() * 100))