In [1]:
import jax 
import jax.numpy as jnp
from jax.random import gumbel
import optax
import haiku as hk
import numpy as np

Given class probabilities $p_1, \dots, p_n$, we can approximate the categorical distribution over these probabilities using a Gumbel softmax distribution as described in the paper, and this is a distribution over the probability n-simplex. Taking argmax allows us to obtain discrete value and we can use the straight through gumbel softmax estimator for gradient updates.

We first code up the gumbel softmax distribution in haiku and visualizes it at different temperatures.

In [2]:
class GumbelSoftmax(hk.Module):
    def __init__(self, probas, name=None):
        super().__init__(name)
        assert jnp.sum(probabs) == 1.
        self.log_probas = jnp.log(probas)
        self.N = len(probabs)
        
    def __call__(self):
        g = gumbel(hk.next_rng_key(), shape=self.N)
        pass

In [3]:
def f(key, logits):
    return ((jax.random.categorical(key, logits))**2).astype('float32')

In [4]:
logits = jnp.log(jnp.array([1/5, 1/5, 1/5, 1/5, 1/5]))
logits

DeviceArray([-1.609438, -1.609438, -1.609438, -1.609438, -1.609438], dtype=float32)

In [5]:
f_grad = jax.grad(f, argnums=1)
rng_key = jax.random.PRNGKey(48)
f_grad(rng_key, logits)

DeviceArray([0., 0., 0., 0., 0.], dtype=float32)

In [6]:
def ArgmaxGumbelSoftmax(key, logits, tau=1):
    g = jax.random.gumbel(key, shape=logits.shape)
    z = (g + logits) / tau
    y = jax.nn.softmax(z)
    zero = y - jax.lax.stop_gradient(y)
    return (zero + jax.lax.stop_gradient(jnp.argmax(y)))[0]

In [7]:
counts = np.zeros([5])
for i in range(1000):
    y_ = ArgmaxGumbelSoftmax(jax.random.PRNGKey(i), logits, tau=0.001)
    counts[int(y_)] += 1   
counts

array([201., 199., 198., 190., 212.])

In [9]:
def f(key, logits):
    return (ArgmaxGumbelSoftmax(key, logits))**2

In [10]:
jax.grad(f, argnums=1)(rng_key, logits)

DeviceArray([ 0.46325743, -0.09352899, -0.34142467, -0.00644364,
             -0.02186022], dtype=float32)

In [11]:
jax.grad(ArgmaxGumbelSoftmax, argnums=1)(rng_key, logits)

DeviceArray([ 0.11581436, -0.02338225, -0.08535617, -0.00161091,
             -0.00546505], dtype=float32)