In [1]:
from IPython.display import display, HTML
# Widen the notebook container to the full browser width
display(HTML("<style>.container { width:100% !important; }</style>"))

# Unpacking the paper - CATEGORICAL REPARAMETERIZATION WITH GUMBEL-SOFTMAX

* https://arxiv.org/pdf/1611.01144.pdf


## Introduction
* Not much to unpack here.



## The Gumbel-Softmax Distribution

* The Gumbel-Softmax distribution is a continuous distribution over the simplex that can approximate samples from a categorical distribution. The (probability) simplex is the set of vectors whose entries all lie between 0 and 1 and sum to 1, for example [0.2, 0.2, 0.2, 0.4].

* $z$ is a categorical variable with class probabilities $\pi_1, \pi_2, \dots, \pi_k$.
* $k$ is the number of classes.
* Samples of $z$ are encoded as $k$-dimensional one-hot vectors. With five classes, for example, a sample might be [0, 0, 0, 1, 0].
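A minimal pure-Python sketch of these definitions, assuming only the standard library (the helper names `is_on_simplex` and `sample_one_hot` are hypothetical, not from the paper): it checks simplex membership and draws one categorical sample as a one-hot vector. One-hot vectors are the corners of the simplex, so a sample also passes the simplex check.

```python
import math
import random


def is_on_simplex(v, tol=1e-9):
    """True if every entry lies in [0, 1] and the entries sum to 1 (within tol)."""
    return all(0.0 <= x <= 1.0 for x in v) and math.isclose(sum(v), 1.0, abs_tol=tol)


def sample_one_hot(probs, rng=random):
    """Draw one categorical sample with class probabilities `probs`, as a one-hot list."""
    k = len(probs)
    # random.choices picks an index in 0..k-1 with probability proportional to its weight
    idx = rng.choices(range(k), weights=probs, k=1)[0]
    one_hot = [0] * k
    one_hot[idx] = 1
    return one_hot


rng = random.Random(0)
sample = sample_one_hot([0.1, 0.2, 0.3, 0.3, 0.1], rng)
print(sample)                                # a 5-dimensional one-hot vector
print(is_on_simplex([0.2, 0.2, 0.2, 0.4]))   # True
print(is_on_simplex(sample))                 # True: one-hot vectors are corners of the simplex
```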


### You can draw samples of $z$ efficiently (the Gumbel-Max trick) by:

* drawing $k$ samples $g_1, \dots, g_k$ i.i.d. from a Gumbel distribution with location $\mu = 0$ and scale $\beta = 1$ (see footnote 1 of the paper),
* computing $\arg\max_i \, (g_i + \log \pi_i)$ over the $k$ classes, with $\pi_i$ being the class probability, and
* creating a one-hot encoding of that argmax.
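The three steps can be sketched in pure Python. This is an illustration under assumptions, not the paper's code: the Gumbel noise is drawn by inverse transform sampling, every $\pi_i$ is assumed strictly positive so that $\log \pi_i$ is defined, and the function names are hypothetical. The loop at the end counts how often each class wins, which should roughly match $\pi$.

```python
import math
import random


def draw_gumbel(rng):
    """One Gumbel(mu=0, beta=1) sample via inverse transform sampling."""
    u = rng.random()          # uniform in [0.0, 1.0)
    while u == 0.0:           # log(0) is undefined, so redraw in that (rare) case
        u = rng.random()
    return -math.log(-math.log(u))


def gumbel_max_sample(probs, rng):
    """Return a one-hot list whose hot index is argmax_i (g_i + log pi_i)."""
    k = len(probs)
    scores = [draw_gumbel(rng) + math.log(p) for p in probs]
    hot = max(range(k), key=lambda i: scores[i])
    return [1 if i == hot else 0 for i in range(k)]


rng = random.Random(42)
probs = [0.1, 0.6, 0.3]
n = 20000
counts = [0] * len(probs)
for _ in range(n):
    counts[gumbel_max_sample(probs, rng).index(1)] += 1
freqs = [c / n for c in counts]
print(freqs)  # empirical class frequencies, close to probs
```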

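### Using softmax as a differentiable approximation of argmax gives Gumbel-Softmax
* Additionally, the authors add a temperature parameter $\tau$ to the softmax.
* With $x_i = \log \pi_i + g_i$, each component of the sample is $y_i = \frac{\exp(x_i / \tau)}{\sum_{n=1}^{k} \exp(x_n / \tau)}$ for $i = 1, \dots, k$.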
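To see what $\tau$ does, here is a small pure-Python sketch of the temperature softmax applied to a fixed vector `x`, standing in for $\log \pi_i + g_i$ after the noise has been drawn (the values are made up for illustration). A small $\tau$ pushes the output toward the one-hot vector of $\arg\max_i x_i$, a large $\tau$ flattens it toward uniform, and the argmax never changes because dividing by a positive $\tau$ and applying softmax both preserve order.

```python
import math


def softmax_with_temperature(x, tau):
    """y_i = exp(x_i / tau) / sum_n exp(x_n / tau), computed in a numerically stable way."""
    if tau <= 0:
        raise ValueError("temperature must be positive")
    scaled = [v / tau for v in x]
    m = max(scaled)                          # subtract the max to avoid overflow in exp
    exps = [math.exp(v - m) for v in scaled]
    total = sum(exps)
    return [e / total for e in exps]


x = [0.7, 1.1, 0.8]   # hypothetical log(pi_i) + g_i values
for tau in (0.05, 1.0, 100.0):
    y = softmax_with_temperature(x, tau)
    print(tau, [round(v, 3) for v in y])
```

At $\tau = 0.05$ almost all of the mass sits on index 1; at $\tau = 100$ the three entries are nearly equal.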

* Gumbel distribution background: https://en.wikipedia.org/wiki/Gumbel_distribution


## Sampling from a Gumbel distribution
* GIST: https://gist.github.com/ericjang/1001afd374c2c3b7752545ce6d9ed349

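* Footnote 1 on page 2 of the paper: Gumbel(0, 1) samples can be drawn by inverse transform sampling from a Uniform(0, 1) variable, which is what `sample_gumbel` below does.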
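Why inverse transform sampling yields $g = -\log(-\log u)$, as a short derivation assuming the standard Gumbel(0, 1) CDF $F(g) = \exp(-\exp(-g))$: set a uniform draw $u$ equal to the CDF and solve for $g$.

$$u = F(g) = \exp\big(-\exp(-g)\big), \qquad u \sim \text{Uniform}(0, 1)$$

$$\log u = -\exp(-g) \;\Rightarrow\; -\log u = \exp(-g) \;\Rightarrow\; g = -\log(-\log u)$$

The `eps` terms in the code only guard against taking $\log 0$; they do not change the distribution in any meaningful way.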

In [None]:
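import tensorflow as tf

def sample_gumbel(shape, eps=1e-20):
    # Gumbel(0, 1) noise by inverse transform sampling; eps guards against log(0)
    U = tf.random.uniform(shape, minval=0, maxval=1)
    return -tf.math.log(-tf.math.log(U + eps) + eps)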
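
A quick sanity check of the same formula without TensorFlow, written as a pure-Python mirror of the cell above for illustration (`sample_gumbel_py` is a hypothetical name). Gumbel(0, 1) has mean equal to the Euler-Mascheroni constant $\gamma \approx 0.5772$ and variance $\pi^2/6 \approx 1.645$, so the sample moments should land near those values.

```python
import math
import random


def sample_gumbel_py(n, eps=1e-20, seed=0):
    """Draw n Gumbel(0, 1) samples with the same formula as sample_gumbel."""
    rng = random.Random(seed)
    # -log(-log(U + eps) + eps) with U ~ Uniform[0, 1); eps guards against log(0)
    return [-math.log(-math.log(rng.random() + eps) + eps) for _ in range(n)]


samples = sample_gumbel_py(100_000)
mean = sum(samples) / len(samples)
var = sum((s - mean) ** 2 for s in samples) / (len(samples) - 1)
EULER_GAMMA = 0.5772156649015329
print(mean, EULER_GAMMA)        # sample mean vs. theoretical mean
print(var, math.pi ** 2 / 6)    # sample variance vs. theoretical variance
```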

In [None]:
def gumbel_softmax_sample(logits, temperature):
    # Perturb the logits with Gumbel noise, then apply a temperature-scaled softmax
    y = logits + sample_gumbel(tf.shape(logits))
    return tf.nn.softmax(y / temperature)


In [None]:
def gumbel_softmax(logits, temperature, hard=False):
  """Sample from the Gumbel-Softmax distribution and optionally discretize.
  Args:
    logits: [batch_size, n_class] unnormalized log-probs
    temperature: positive scalar
    hard: if True, take argmax, but differentiate w.r.t. soft sample y
  Returns:
    [batch_size, n_class] sample from the Gumbel-Softmax distribution.
    If hard=True, then the returned sample will be one-hot, otherwise it will
    be a probability distribution that sums to 1 across classes
  """
  y = gumbel_softmax_sample(logits, temperature)
  if hard:
    k = tf.shape(logits)[-1]
    #y_hard = tf.cast(tf.one_hot(tf.argmax(y,1),k), y.dtype)
    # 1 where y equals its row maximum (a tie would produce more than one 1)
    y_hard = tf.cast(tf.equal(y, tf.reduce_max(y, 1, keepdims=True)), y.dtype)
    # Straight-through: the forward value is y_hard, gradients flow through y
    y = tf.stop_gradient(y_hard - y) + y
  return y
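
The two ways of building `y_hard` in the cell above (the active `tf.equal(y, tf.reduce_max(...))` line and the commented-out `tf.one_hot(tf.argmax(...))` line) agree except on ties. A minimal pure-Python mirror of the forward pass, with hypothetical helper names and hand-picked vectors, shows the difference and checks that the straight-through expression `(y_hard - y) + y` evaluates to `y_hard` up to floating-point rounding. It cannot show the gradient side, since there is no autodiff here: in TensorFlow, `tf.stop_gradient` treats `(y_hard - y)` as a constant, so the gradient is taken through the soft sample `y`.

```python
import math


def hard_by_equality(y):
    """Mirror of tf.equal(y, reduce_max(y)): 1.0 wherever y equals its maximum."""
    m = max(y)
    return [1.0 if v == m else 0.0 for v in y]


def hard_by_argmax(y):
    """Mirror of one_hot(argmax(y)): exactly one 1.0 (this sketch picks the first maximum)."""
    hot = max(range(len(y)), key=lambda i: y[i])
    return [1.0 if i == hot else 0.0 for i in range(len(y))]


def straight_through_forward(y, y_hard):
    """Forward value of stop_gradient(y_hard - y) + y; stop_gradient is the identity here."""
    return [(h - v) + v for h, v in zip(y_hard, y)]


y_soft = [0.15, 0.6, 0.25]   # typical case: a unique maximum
y_tie = [0.4, 0.4, 0.2]      # contrived tie, improbable with continuous Gumbel noise

print(hard_by_equality(y_soft), hard_by_argmax(y_soft))  # identical one-hot vectors
print(hard_by_equality(y_tie), hard_by_argmax(y_tie))    # [1.0, 1.0, 0.0] [1.0, 0.0, 0.0]
y_st = straight_through_forward(y_soft, hard_by_argmax(y_soft))
print(all(math.isclose(a, b, abs_tol=1e-12) for a, b in zip(y_st, hard_by_argmax(y_soft))))  # True
```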