In [1]:
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))


# Gumbel Softmax


* Pre-print, published in ICLR 2017 https://arxiv.org/pdf/1611.01144.pdf

* https://arxiv.org/pdf/1406.2989.pdf


## Experiments

* **Dataset**: We use the MNIST dataset with fixed binarization for training and evaluation
* **Tricks**: We also found that variance normalization was necessary
* **Network**: We used sigmoid activation functions for binary (Bernoulli) neural networks and softmax activations for categorical variables.
* **Training**: Models were trained using stochastic gradient descent with momentum 0.9.
* **Learning rates**:  are chosen from {3e−5, 1e−5, 3e−4, 1e−4, 3e−3, 1e−3}; we select the best learning rate for each estimator using the MNIST validation set, and report performance on the test set.
* **Tasks** Each estimator is evaluated on two tasks: (1) structured output prediction and (2) variational training of generative models. 

### 1) Structured output prediction with stochastic binary networks

### 2) Generative modelling with variational Autoencoders

### 3) Generative semi supervised classification

# Requirements

In [2]:
import tensorflow as tf
tf.__version__
AUTOTUNE = tf.data.experimental.AUTOTUNE

In [3]:
import pathlib
import os
import matplotlib.pyplot as plt
import numpy as np
import PIL.Image

np.set_printoptions(precision=4)

In [4]:
import pandas as pd # do we need that

# Load and preprocess data

In [5]:
import tensorflow_datasets as tfds
mnist_data = tfds.load("binarized_mnist", data_dir="/tf/data")
mnist_train, mnist_test = mnist_data["train"], mnist_data["test"]
assert isinstance(mnist_train, tf.data.Dataset)

## split upper half / lower half

In [6]:
mnist_train.take(1)

<TakeDataset shapes: {image: (28, 28, 1)}, types: {image: tf.uint8}>

In [7]:
def split_lower_upper_half(image):
    print(image["image"])
    flat_image = tf.reshape(image["image"], [-1], name=None)
    upper_half, lower_half = tf.split(flat_image, num_or_size_splits=2, axis=0, num=None, name='split_image_to_upper_lower_half')

    
    return upper_half, lower_half#upper_half, lower_half # x, y

In [8]:
labeled_train_ds = mnist_train.map(split_lower_upper_half, num_parallel_calls=AUTOTUNE)
labeled_train_ds.take(1)

Tensor("args_0:0", shape=(28, 28, 1), dtype=uint8)


<TakeDataset shapes: ((392,), (392,)), types: (tf.uint8, tf.uint8)>

In [9]:
labeled_train_ds = labeled_train_ds.shuffle(100,reshuffle_each_iteration=True).batch(100)

# Gumbel Softmax 
* https://gist.github.com/ericjang/1001afd374c2c3b7752545ce6d9ed349

In [10]:
def sample_gumbel(shape, eps=1e-20): 
    """Sample from Gumbel(0, 1)"""
    U = tf.random.uniform(shape,minval=0,maxval=1)
    return -tf.math.log(-tf.math.log(U + eps) + eps)

def gumbel_softmax_sample(logits, temperature): 
    """ Draw a sample from the Gumbel-Softmax distribution"""
    y = logits + sample_gumbel(tf.shape(logits))
    return tf.nn.softmax( y / temperature)

def gumbel_softmax(logits, temperature, hard=False):
    """Sample from the Gumbel-Softmax distribution and optionally discretize.
    Args:
        logits: [batch_size, n_class] unnormalized log-probs
        temperature: non-negative scalar
        hard: if True, take argmax, but differentiate w.r.t. soft sample y
    Returns:
        [batch_size, n_class] sample from the Gumbel-Softmax distribution.
        If hard=True, then the returned sample will be one-hot, otherwise it will
        be a probabilitiy distribution that sums to 1 across classes
    """
    y = gumbel_softmax_sample(logits, temperature)
    if hard:
        k = tf.shape(logits)[-1]
        #y_hard = tf.cast(tf.one_hot(tf.argmax(y,1),k), y.dtype)
        y_hard = tf.cast(tf.equal(y,tf.reduce_max(y,1,keep_dims=True)),y.dtype)
        y = tf.stop_gradient(y_hard - y) + y
    return y

### 1) Structured output prediction with stochastic binary networks

* **Task**: Predict lower half of mnist image given top half.
* The minimization objective for this conditional generative model is an importance-sampled estimate of the likelihood objective, Eh∼pθ(hi|xupper)  m Pm i=1 log pθ(xlower|hi)

*  where m = 1 is used for training and m = 1000 is used for evaluation.

* For bernoulli variables they use signmoid activation
* For categorical variables they use 



In [11]:
class GumbelSoftmaxStructuredOutputPrediciton(tf.keras.Model):
    """
        Predicts lower half of an mnist image given the top half.
    """

    def __init__(self):
        super(GumbelSoftmaxStructuredOutputPrediciton, self).__init__()
        self.setup_model()
    
    def setup_model(self):
        self.input_layer = tf.keras.layers.Dense(200, activation=tf.nn.sigmoid) # [bs,392] => [bs,200]
        self.categorical_layer = tf.keras.layers.Dense(200, activation=None) # [bs,200] => [bs,200]
        self.output_layer = tf.keras.layers.Dense(392, activation=tf.nn.sigmoid) # [bs,200] => [bs,392]
              
        
    def call(self, upper_image_half, temperature=0.5):
        h1 = self.input_layer(upper_image_half)
        logits = self.categorical_layer(h1) 
        h2 = gumbel_softmax_sample(logits, temperature)
        lower_image_half = self.output_layer(h2)       
        return lower_image_half

sop_model = GumbelSoftmaxStructuredOutputPrediciton()

In [12]:
optimizer = tf.keras.optimizers.SGD(learning_rate=0.0001, momentum=0.9, nesterov=False, name='SGD') # m {3e−5, 1e−5, 3e−4, 1e−4, 3e−3, 1e−3};

# Loss function

* They use negative log likelihood from a bernoulli distribution where the probability is a sigmoid.

* log likelihodd


tf.nn.sigmoid_cross_entropy_with_logits 
-tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=x_logit, labels=x))
 

In [13]:
loss_object =  tf.keras.losses.BinaryCrossentropy(
    from_logits=True, label_smoothing=0, name='binary_crossentropy'
) 

loss_object = tf.keras.losses.MeanSquaredError(
    name='mean_squared_error'
)

NameError: name 'losses_utils' is not defined

# Train loss, test loss

In [None]:
train_loss = tf.keras.metrics.Mean(name='train_loss')   
train_accuracy = tf.keras.metrics.BinaryAccuracy(name='binary_accuracy', dtype=None, threshold=0.5)

#test_loss =  tf.keras.metrics.BinaryCrossentropy( name='test_lsss', dtype=None, from_logits=False, label_smoothing=0)


# Train step

In [None]:
@tf.function
def train_step(images_upper_half, images_lower_half):
    #images_lower_half = tf.cast(x=images_lower_half, dtype=tf.float32)
    
    with tf.GradientTape() as tape:
        predicted_lower_half = sop_model(images_upper_half, training=True)
        #loss = tf.nn.sigmoid_cross_entropy_with_logits(
        #   labels=images_lower_half, 
        #   logits=predicted_lower_half, 
        #    name=None
        #)
        
        loss =  loss_object(images_lower_half, predicted_lower_half)
    gradients = tape.gradient(loss, sop_model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, sop_model.trainable_variables))

    train_loss(loss)
    train_accuracy(images_lower_half,predicted_lower_half)


# Plot Images

In [None]:
def plotImages(images_arr,num_images=10):
    num_imgages = images_arr.shape[0]
    fig, axes = plt.subplots(1, num_images, figsize=(20,20))
    axes = axes.flatten()
    for img, ax in zip( images_arr, axes):
        ax.imshow(img)
        ax.axis('off')
    plt.tight_layout()
    plt.show()

In [None]:
images_labels_batch = np.array(list(labeled_train_ds.take(1).as_numpy_iterator())[0])

num_images = 10
images = images_labels_batch[0][0:num_images].reshape([num_images, -1, 28]) 
labels = images_labels_batch[1][0:num_images].reshape([num_images, -1, 28 ]) 

In [None]:
plotImages(images)

In [None]:
plotImages(labels)

# Training

In [None]:
EPOCHS = 100

for epoch in range(EPOCHS):
    # Reset the metrics at the start of the next epoch
    train_loss.reset_states()
    #test_loss.reset_states()

    # plot a few images every now and then
    if epoch%10==0:
        images_labels = np.array(list(labeled_train_ds.take(1).as_numpy_iterator())[0])
        images, labels = images_labels
        predictions = sop_model( images ,  training=False)

        num_images = 10
        images = images_labels_batch[0][0:num_images].reshape([num_images, -1, 28]) 
        labels = images_labels_batch[1][0:num_images].reshape([num_images, -1, 28 ]) 
        predictions = predictions.numpy()[0:num_images].reshape([num_images, -1, 28 ]) 

        plotImages(images)
        plotImages(labels)
        plotImages(predictions)



    for images, labels in labeled_train_ds:         
        train_step(images, labels)

    #for test_images, test_labels in test_ds:
    #    test_step(test_images, test_labels)

    template = 'Epoch {}, Loss: {}, Accuracy: {}'
    print(template.format(epoch + 1,
                        train_loss.result(),
                        train_accuracy.result() * 100))#,
    #                    test_loss.result(),
     #                   test_accuracy.result() * 100))