# Categorical VAE with Gumbel-Softmax

Partial implementation of the paper [Categorical Reparameterization with Gumbel-Softmax](https://arxiv.org/abs/1611.01144) 

The implementation follows tightly the code from https://github.com/ericjang/gumbel-softmax/blob/master/Categorical%20VAE.ipynb

## Imports and Helper Functions

In [1]:
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability import distributions as tfpd
import pandas as pd

In [2]:
tf.__version__, tfp.__version__

('2.2.0', '0.9.0')

In [3]:
import sys
sys.path.append('../mdnf')

In [4]:
import aux
import time
import copy
import sys

## Configuration

In [5]:
# can be run as a script with args in format KEY=VAL,KEY=[STRVAL],...
args = aux.parse_script_args() 

In [6]:
SEED = args.get("SEED", 0)
OUT = args.get("OUT", "VAEConcrete.csv")

ST = bool(args.get("ST", 0))
LOSS = args.get("LOSS", 3)
N = args.get("N", 10) # how many latent variables
K = args.get("K", 20) # how many categories each

BASE_TEMP=args.get("BASE_TEMP", 1.0) # initial temperature
ANNEAL_RATE=args.get("ANNEAL_RATE", 0.00003)
MIN_TEMP=args.get("MIN_TEMP", 0.5)
PRIORS_TEMP=args.get("PRIORS_TEMP", 1.0) # ignored with the Jang's loss

OPTIMIZER = args.get("OPTIMIZER", "ADAM")
LR = args.get("LR", 0.001)
BATCH_SIZE=args.get("BATCH_SIZE", 256) # how many samples in minibatch
NUM_ITERS=args.get("NUM_ITERS", 100) # how many epochs

In [7]:
assert LOSS==3, "This implementation supports only LOSS==3"
ALG_NAME = {0: "VAE_CONCRETE", 
            1: "VAE_CONCRETE_MADDISON21", 
            2: "VAE_CONCRETE_MADDISON22_JANG2", 
            3: "VAE_CONCRETE_JANG",
            4: "VAE_CONCRETE_CATPRIORS"}[LOSS]
CFG = [ALG_NAME, SEED, OUT, ST, N, K, BATCH_SIZE, NUM_ITERS, 
       BASE_TEMP, PRIORS_TEMP, OPTIMIZER, LR, ANNEAL_RATE, MIN_TEMP]
print("CFG=%s" % CFG)

CFG=['VAE_CONCRETE_JANG', 0, 'VAEConcrete.csv', False, 10, 20, 256, 4, 1.0, 1.0, 'ADAM', 0.001, 3e-05, 0.5]


##  Gumbel-Softmax & Straight-Through

In [8]:
@tf.function       
def straight_through_sample(y, hard=True):
  """
  Args:
    y: [..., n_class] one-hot sample from the Gumbel-Softmax distribution. 
    hard: if True, take argmax, but differentiate w.r.t. soft sample y
  Returns:
    If hard=True, then the returned sample y will be one-hot, otherwise it will
    be a probabilitiy distribution that sums to 1 across classes
  """
  if hard:
    K = y.shape[-1]
    y_hard = tf.cast(tf.one_hot(tf.argmax(y,-1), K), y.dtype)
    y = tf.stop_gradient(y_hard - y) + y  
  return y


@tf.function        
def sample_gumbel(shape, eps=1e-20): 
  """Sample from Gumbel(0, 1)"""
  U = tf.random.uniform(shape,minval=0,maxval=1)
  return -tf.math.log(-tf.math.log(U + eps) + eps)


@tf.function   
def gumbel_softmax_sample(logits, temperature): 
  """ Draw a sample from the Gumbel-Softmax distribution"""
  y = logits + sample_gumbel(tf.shape(logits))
  return tf.nn.softmax( y / temperature)

## Model

In [9]:
class CategoricalVAE(tf.Module):

    def __init__(self, N, K, hard=False, name=None):
        """
        
            Args:
                N: number of variables
                K: number of categories for each variable
                hard: set hard=True for ST Gumbel-Softmax 
        """
        super(CategoricalVAE, self).__init__(name=name)

        self.N = N # number of categorical distributions
        self.K = K # number of classes
        self.hard = hard # set hard=True for ST Gumbel-Softmax 

        self.calc_logits_y = tf.keras.Sequential([ # encoder
                                             tf.keras.layers.Flatten(name="encoder0"),
                                             tf.keras.layers.Dense(512, activation="relu", 
                                                                   input_shape=(None, 784), name="encoder1"),
                                             tf.keras.layers.Dense(256, activation="relu", name="encoder2"),
                                             tf.keras.layers.Dense(K*N, activation=None, name="encoder3"),
                                             tf.keras.layers.Reshape( [N,K] , name="encoder4")
                                            ], name="encoder")

        self.calc_logits_x = tf.keras.Sequential( # decoder
                                            [tf.keras.layers.Flatten(name="decoder0"),
                                             tf.keras.layers.Dense(256, activation="relu", name="decoder1"),
                                             tf.keras.layers.Dense(512, activation="relu", name="decoder2"),
                                             tf.keras.layers.Dense(784, activation=None, name="decoder3"),
                                            ], name="decoder") 
    
    def __call__(self, x, temperature=5.0):
        tau = tf.constant(temperature, name="temperature", dtype=tf.float32)                        
        
        # variational posterior q(y|x), i.e. the encoder 
        # unnormalized logits for N separate K-categorical distributions 
        # (shape=(batch_size,N,K))
        logits_y = self.calc_logits_y(x)
        q_y = tf.nn.softmax(logits_y)                        
        y = gumbel_softmax_sample(logits_y, tau)  
        y = straight_through_sample(y, self.hard) # set hard=True for ST Gumbel-Softmax

        # generative model p(x|y)
        logits_x = self.calc_logits_x(y)    
        p_x = tfpd.Bernoulli(logits=logits_x)
        
        return p_x, q_y, y, logits_y
    

## Train

In [10]:
# get data
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data(path='mnist.npz')
frmt = lambda images: tf.cast(tf.reshape(images,(-1, 784)), tf.float32)/255.0
x_train, x_test = frmt(x_train), frmt(x_test)

train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(BATCH_SIZE)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

In [11]:
OPTIMIZERS = {"RMS": tf.keras.optimizers.RMSprop,
              "ADAM": tf.keras.optimizers.Adam}
if OPTIMIZER not in OPTIMIZERS: raise ValueError("Unknown optimizer!")
optimizer_class = OPTIMIZERS[OPTIMIZER]
optimizer = optimizer_class(learning_rate=LR)

print("optimizer=%s lr=%s" % (optimizer, LR))

optimizer=<tensorflow.python.keras.optimizer_v2.adam.Adam object at 0x7fb060446c40> lr=0.001


In [12]:
def loss_Jang(x, p_x, q_y, y):
    """ Matching loss from
        https://github.com/ericjang/gumbel-softmax/blob/master/Categorical%20VAE.ipynb
    """
    batch_size, N, K = q_y.shape

    log_q_y = tf.math.log(q_y+1e-20)       
    
    KL = tf.reshape(q_y*(log_q_y-tf.math.log(1.0/K)),[-1,N,K])
    KL = tf.reduce_sum(KL, [1,2])

    elbo = tf.reduce_sum(p_x.log_prob(x), 1) - KL        
    return tf.reduce_mean(-elbo)

In [13]:
def true_loss(x, p_x, q_y, logits_y, nsamples=1000, *args):
    batch_size, N, K = q_y.shape
    
    samples = tf.stack([gumbel_softmax_sample(logits_y, 1.0) for _ in range(nsamples)])
    py = tf.reduce_mean(straight_through_sample(samples), 0)
    H_q = -tf.reduce_sum(py*np.log(py+1e-16), [-1,-2]) # sum over categories and then over independent variables

    Eq_logp = np.sum(np.mean(np.ones((batch_size,N,K))*np.log(1/K), -1), -1)

    KL = -Eq_logp -H_q
    lik = tf.reduce_sum(p_x.log_prob(x), -1)
    return float(-tf.reduce_mean(lik)+tf.reduce_mean(KL))

In [14]:
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [15]:
vae = CategoricalVAE(N=N, K=K, hard=ST)
results = []
best_loss, best_vae = float("inf"), None
start = time.time()
i = 0
np_temp=BASE_TEMP

In [16]:
loss = {
        3: loss_Jang
}[LOSS]
print("loss = %s" % loss)

loss = <function loss_Jang at 0x7fb081d0b3a0>


In [17]:
for e in range(NUM_ITERS):

    losses, true_losses = [], []
    for np_x, labels in train_ds:
        
        with tf.GradientTape() as tape:        
            p_x, q_y, y, logits_y = vae(np_x, temperature=np_temp)
            l = loss(np_x, p_x, q_y, y)
        losses.append( float(l) ); 
        tl = true_loss(np_x, p_x, q_y, logits_y)
        true_losses.append( float(tl) )
        g = tape.gradient(l, vae.trainable_variables)
        optimizer.apply_gradients(zip(g, vae.trainable_variables))
        
        if i % 1000 == 1: # following Jang's code
            np_temp = np.maximum(BASE_TEMP*np.exp(-ANNEAL_RATE*i), MIN_TEMP)
            
        i += 1
        if i%100==0 or i<10: 
            print("[%.1fs] epoch=%i/iteration=%i loss=%.2f ELBO=%.2f" % \
                  (time.time()-start, e, i, l, -tl))    
    ########################################################################
    
    if np.mean(losses)<best_loss:
        best_loss = np.mean(losses)
        try: 
            best_vae = copy.deepcopy(vae)
        except Exception as exc:
            if best_vae is None:
                    print("[ERROR] Failed to copy VAE object as best_vae: %s" % exc)
            best_vae = vae                
            
    print ("[%.1fs] %d. l=%.2f (best=%.2f) ELBO=%.2f t=%.2f" % (time.time()-start, 
            e, np.mean(losses), best_loss, -np.mean(true_losses), np_temp))                        
    results.append(CFG+[time.time()-start, 
            e, np.mean(losses), best_loss, np.mean(true_losses), np_temp])
    sys.stdout.flush()
    
    print("Saving to %s" % OUT)
    pd.DataFrame(results).to_csv(OUT, header=False, index=False)    

[1.2s] epoch=0/iteration=1 loss=543.39 ELBO=-543.48
[1.8s] epoch=0/iteration=2 loss=537.70 ELBO=-537.79
[2.4s] epoch=0/iteration=3 loss=531.24 ELBO=-531.33
[3.1s] epoch=0/iteration=4 loss=522.39 ELBO=-522.48
[3.7s] epoch=0/iteration=5 loss=510.15 ELBO=-510.24
[4.3s] epoch=0/iteration=6 loss=494.42 ELBO=-494.51
[5.0s] epoch=0/iteration=7 loss=470.15 ELBO=-470.25
[5.6s] epoch=0/iteration=8 loss=443.06 ELBO=-443.16
[6.3s] epoch=0/iteration=9 loss=409.12 ELBO=-409.21
[62.1s] epoch=0/iteration=100 loss=196.48 ELBO=-196.58
[124.9s] epoch=0/iteration=200 loss=188.83 ELBO=-188.93
[151.6s] 0. l=213.82 (best=213.82) ELBO=-213.91 t=1.00
Saving to VAEConcrete.csv
[201.5s] epoch=1/iteration=300 loss=166.84 ELBO=-166.94
[281.0s] epoch=1/iteration=400 loss=159.60 ELBO=-159.69
[336.0s] 1. l=169.31 (best=169.31) ELBO=-169.40 t=1.00
Saving to VAEConcrete.csv
[360.1s] epoch=2/iteration=500 loss=146.04 ELBO=-146.14
[441.7s] epoch=2/iteration=600 loss=145.43 ELBO=-145.52
[509.6s] epoch=2/iteration=700 loss