In [2]:
import numpy as np
import tensorflow as tf
from tensorflow.contrib.layers import fully_connected, batch_norm
from tensorflow.examples.tutorials.mnist import input_data
from datetime import datetime
import os
import sys

### Some Functions

In [3]:
def show_reconstructed_digits(X, outputs, model_path = None, n_test_digits = 2):
    with tf.Session() as sess:
        if model_path:
            saver.restore(sess, model_path)
        X_test = mnist.test.images[:n_test_digits]
        outputs_val = outputs.eval(feed_dict={X: X_test})

    fig = plt.figure(figsize=(8, 3 * n_test_digits))
    for digit_index in range(n_test_digits):
        plt.subplot(n_test_digits, 2, digit_index * 2 + 1)
        plot_image(X_test[digit_index])
        plt.subplot(n_test_digits, 2, digit_index * 2 + 2)
        plot_image(outputs_val[digit_index])

### Construction Phase

We will construct the graph for the VFAE architecture:

- Input: X = [X_without_s, s], where s is the sensitive feature
- Middle Encodings: We're learning the parameters for the distribution of the encodings. What's different here is that we inject both the response y and the sensitive features in the middle layers.
- Output: X_copy

In [4]:
# Construction phase
n_s = 10 # number of sensitive features
n_inputs = 28*28 - n_s # number of non-sensitive features

# encoders
n_hidden1 = 500
n_hidden2 = 20 # codings
n_hidden3 = 500
n_hidden4 = 20

# decoders
n_hidden5 = 500
n_hidden6 = 20
n_hidden7 = 500

# final output can take a random sample from the posterior
n_outputs = n_inputs + n_s

In [7]:
### Training rates
alpha = 1
learning_rate = 0.001

In [8]:
### Setting up the graph
with tf.contrib.framework.arg_scope(
        [fully_connected],
        activation_fn = tf.nn.elu,
        weights_initializer = tf.contrib.layers.variance_scaling_initializer()):
    X = tf.placeholder(tf.float32, shape = [None, n_inputs], name="X_wo_s")
    s = tf.placeholder(tf.float32, shape = [None, n_s], name="s")
    X_full = tf.concat([X,s], axis=1)
    y = tf.placeholder(tf.int32, shape = [None, 1], name="y") # for your example, switch this to tf.float32 bc you'll be doing reg
    is_unlabelled = tf.placeholder(tf.bool, shape=(), name='is_training') # don't worry about this
    with tf.name_scope("X_encoder"):
        hidden1 = fully_connected(tf.concat([X, s], axis=1), n_hidden1)
        hidden2_mean = fully_connected(hidden1, n_hidden2, activation_fn = None)
        hidden2_gamma = fully_connected(hidden1, n_hidden2, activation_fn = None)
        hidden2_sigma = tf.exp(0.5 * hidden2_gamma)
    noise1 = tf.random_normal(tf.shape(hidden2_sigma), dtype=tf.float32)
    hidden2 = hidden2_mean + hidden2_sigma * noise1         # z1
    with tf.name_scope("Z1_encoder"):
        hidden3_ygz1 = fully_connected(hidden2, n_hidden4, activation_fn = tf.nn.tanh)
        hidden4_softmax_mean = fully_connected(hidden3_ygz1, 10, activation_fn = tf.nn.softmax)
        if is_unlabelled == True:
            # impute by sampling from q(y|z1)
            y = tf.assign(y, tf.multinomial(hidden4_softmax_mean, 1,
                                output_type = tf.int32))
        hidden3 = fully_connected(tf.concat([hidden2, tf.cast(y, tf.float32)], axis=1),
                        n_hidden3, activation_fn=tf.nn.tanh)
        hidden4_mean = fully_connected(hidden3, n_hidden4, activation_fn = None)
        hidden4_gamma = fully_connected(hidden3, n_hidden4, activation_fn = None)
        hidden4_sigma = tf.exp(0.5 * hidden4_gamma)
    noise2 = tf.random_normal(tf.shape(hidden4_sigma), dtype=tf.float32)
    hidden4 = hidden4_mean + hidden4_sigma * noise2     # z2
    with tf.name_scope("Z1_decoder"):
        hidden5 = fully_connected(tf.concat([hidden4, tf.cast(y, tf.float32)], axis=1 ),
                    n_hidden5, activation_fn = tf.nn.tanh)
        hidden6_mean = fully_connected(hidden5, n_hidden6, activation_fn = None)
        hidden6_gamma = fully_connected(hidden5, n_hidden6, activation_fn = None)
        hidden6_sigma = tf.exp(0.5 * hidden6_gamma)
    noise3 = tf.random_normal(tf.shape(hidden6_sigma), dtype=tf.float32)
    hidden6 = hidden6_mean + hidden6_sigma * noise3     # z1 (decoded)
    with tf.name_scope("X_decoder"):
        hidden7 = fully_connected(tf.concat([hidden6, s], axis=1), n_hidden7,
                                 activation_fn = tf.nn.tanh)
        hidden8 = fully_connected(hidden7, n_outputs, activation_fn = None)
    outputs = tf.sigmoid(hidden8, name="decoded_X")

### Loss Function: ELBO

In [9]:
# expected lower bound
with tf.name_scope("ELB"):
    kl_z2 = 0.5 * tf.reduce_sum(
                    tf.exp(hidden4_gamma)
                    + tf.square(hidden4_mean)
                    - 1
                    - hidden4_gamma
                    )

    kl_z1 = 0.5 * (tf.reduce_sum(
                    (1 / (1e-10 + tf.exp(hidden6_gamma))) * tf.exp(hidden2_gamma)
                    - 1
                    + hidden6_gamma
                    - hidden2_gamma
                    ) + tf.einsum('ij,ji -> i', # this might not work for you depending on version of tflow
                        (hidden6_mean-hidden2_mean) * (1 / (1e-10 + tf.exp(hidden6_gamma))),
                        tf.transpose((hidden6_mean-hidden2_mean))))

    indices = tf.range(tf.shape(y)[0])
    indices = tf.concat([indices[:, tf.newaxis], y], axis=1)
    eps = 1e-10
    log_q_y_z1 = tf.reduce_sum(tf.log(eps + tf.gather_nd(hidden4_softmax_mean, indices)))

    # Bernoulli log-likelihood
    reconstruction_loss = -(tf.reduce_sum(X_full * tf.log(outputs)
                            + (1 - X_full) * tf.log(1 - outputs)))
    cost = kl_z2 + kl_z1 + reconstruction_loss + alpha * log_q_y_z1

In [10]:
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
training_op = optimizer.minimize(cost)

### Initialize Graph & Load Data

In [15]:
init = tf.global_variables_initializer()
mnist = input_data.read_data_sets("/tmp/data/")

Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.
Extracting /tmp/data/train-images-idx3-ubyte.gz
Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.
Extracting /tmp/data/train-labels-idx1-ubyte.gz
Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.
Extracting /tmp/data/t10k-images-idx3-ubyte.gz
Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.
Extracting /tmp/data/t10k-labels-idx1-ubyte.gz


In [16]:
# Training
n_epochs = 50
batch_size = 100
n_digits = 60

with tf.Session() as sess:
    init.run()
    for epoch in range(n_epochs):
        n_batches = mnist.train.num_examples // batch_size
        for iteration in range(n_batches):
            print("\r{}%".format(100 * iteration // n_batches), end="")
            X_batch, y_batch = mnist.train.next_batch(batch_size)
            sess.run(training_op, feed_dict={X: X_batch[:,:-n_s],
                                    s: X_batch[:,-n_s:],
                                    y: y_batch[:,np.newaxis],
                                    is_unlabelled: False})
        kl_z2_val, kl_z1_val, log_q_y_z1_val, reconstruction_loss_val, loss_val = sess.run([
                kl_z2,
                kl_z1,
                log_q_y_z1,
                reconstruction_loss,
                cost],
                feed_dict={X: X_batch[:,:-n_s],
                        s: X_batch[:,-n_s:],
                        y: y_batch[:,np.newaxis]})
        print("\r{}".format(epoch), "Train total loss:", loss_val,
         "\tReconstruction loss:", reconstruction_loss_val,
          "\tKL-z1:", kl_z1_val,
          "\tKL-z2:", kl_z2_val,
          "\tlog_q(y|z1):", log_q_y_z1_val)

0 Train total loss: 15808.6 	Reconstruction loss: 16510.4 	KL-z1: 544.288 	KL-z2: 767.143 	log_q(y|z1): -2013.19
1 Train total loss: 14580.0 	Reconstruction loss: 15223.1 	KL-z1: 548.365 	KL-z2: 786.87 	log_q(y|z1): -1978.38
2 Train total loss: 13775.5 	Reconstruction loss: 14337.7 	KL-z1: 552.086 	KL-z2: 842.155 	log_q(y|z1): -1956.4
3 Train total loss: 13613.7 	Reconstruction loss: 14258.4 	KL-z1: 563.264 	KL-z2: 840.809 	log_q(y|z1): -2048.83
4 Train total loss: 12805.3 	Reconstruction loss: 13532.6 	KL-z1: 485.266 	KL-z2: 836.376 	log_q(y|z1): -2049.03
5 Train total loss: 12624.4 	Reconstruction loss: 13176.3 	KL-z1: 504.406 	KL-z2: 877.674 	log_q(y|z1): -1933.99
6 Train total loss: 11958.7 	Reconstruction loss: 12733.8 	KL-z1: 474.23 	KL-z2: 845.952 	log_q(y|z1): -2095.22
7 Train total loss: 11915.2 	Reconstruction loss: 12431.5 	KL-z1: 606.83 	KL-z2: 833.997 	log_q(y|z1): -1957.11
8 Train total loss: 12707.2 	Reconstruction loss: 13417.9 	KL-z1: 550.244 	KL-z2: 857.368 	log_q(y|z