# Learning Gumbel-Softmax

DreamerV3 compresses images into discrete representations. This means that the image is encoded into a number of categorical groups, each with a certain number of possible values that it can take. For example, a category can be color, shape, and/or size. If we take the color category, then the possible values are red, green, blue, yellow. A one-hot vector is used to represent the encoded choice, e.g., [1,0,0,0] for red. 

An issue arises when sampling a categorical variable, though. The natural approach is to take the argmax of the softmaxed probabilities and represent that choice as a one-hot vector, but torch.argmax isn't differentiable, so backprop can't run through it during training. Gumbel-Softmax resolves this issue. 
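To make the problem concrete, here is a minimal sketch (the logits are arbitrary example values, not taken from DreamerV3) showing that a one-hot vector built from torch.argmax is cut off from the autograd graph, while the softmax probabilities are not:

```python
import torch
import torch.nn.functional as F

# arbitrary example logits that require gradients
logits = torch.tensor([2.0, 0.5, -1.0], requires_grad=True)
probs = torch.softmax(logits, dim=-1)

# argmax returns an integer index, which carries no gradient history
idx = torch.argmax(probs, dim=-1)
one_hot = F.one_hot(idx, num_classes=logits.shape[-1]).float()

print(probs.requires_grad)    # True: softmax is differentiable
print(one_hot.requires_grad)  # False: the graph is cut at argmax
```

Any loss computed from `one_hot` alone therefore cannot update whatever produced the logits.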

How does it work?

Taking argmax always gives the most likely class, which is deterministic. That's not what we're after, because the state is stochastic and should carry uncertainty. So we inject *Gumbel noise* into the logits before the argmax; taking the argmax of the noisy logits is equivalent to sampling from the categorical distribution (the Gumbel-max trick). But we still can't backprop through argmax, so softmax is used in its place: its formula contains only exponentials, additions, and divisions, all differentiable operations. The combination is called the Gumbel-Softmax. Compared with a plain softmax, it has the added noise and a parameter called *temperature*. The noise adds randomness, while the temperature controls how close the output gets to the one-hot argmax behavior. As the temperature approaches 0, the output approaches a one-hot vector. And as the temperature approaches positive infinity, the output becomes close to uniform. 
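The effect of the temperature can be checked numerically. The sketch below draws one set of Gumbel noise and reuses it at three temperatures; the logits, the seed, and the specific temperatures are arbitrary choices for illustration:

```python
import torch

torch.manual_seed(0)
logits = torch.tensor([2.0, 0.5, -1.0])

# one fixed draw of Gumbel noise, reused across temperatures
u = torch.rand(logits.shape).clamp(min=1e-10)  # avoid log(0)
g = -torch.log(-torch.log(u))

results = {}
for tau in (0.1, 1.0, 100.0):
    # Gumbel-Softmax: softmax of the noisy logits divided by the temperature
    results[tau] = torch.softmax((logits + g) / tau, dim=-1)
    print(f"tau={tau}: {results[tau]}")
```

With the noise held fixed, the largest entry only shrinks as tau grows: a small tau pushes the output toward a one-hot vector, and a large tau flattens it toward uniform.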

How gumbel_softmax works in PyTorch:  
Pass the logits, the temperature (tau), and hard=True or False. If hard is set to True, it returns a one-hot vector (like argmax) in the forward pass but uses the gradients of the soft version in the backward pass. If hard is set to False, it returns the soft sample instead: the softmax of the noisy logits, a probability vector that sums to 1 but is not one-hot. 
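A short sketch of both modes using `torch.nn.functional.gumbel_softmax`. The logits and the `weights` vector used to build a dummy loss are made up for illustration:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.tensor([2.0, 0.5, -1.0], requires_grad=True)

# hard=False: a soft sample (sums to 1, not one-hot)
soft_sample = F.gumbel_softmax(logits, tau=0.5, hard=False, dim=-1)

# hard=True: a one-hot sample in the forward pass
hard_sample = F.gumbel_softmax(logits, tau=0.5, hard=True, dim=-1)

print(soft_sample)
print(hard_sample)

# dummy loss: gradients still reach the logits through the hard sample
weights = torch.tensor([1.0, 2.0, 3.0])
(hard_sample * weights).sum().backward()
print(logits.grad)
```

Each call draws fresh noise, so the two samples are independent and may pick different classes.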

In [28]:
# manual implementation of gumbel_softmax

import torch

torch.manual_seed(44)

# sample logits
logits = torch.tensor([2.0, 0.5, -1.0])

# temperature (tau)
tau = 0.5

shape = logits.shape

# Gumbel noise sampler
# g = -log(-log(u)) with u ~ Uniform(0, 1)
eps = 1e-20 # guards against log(0) -> -inf when u is exactly 0; small enough not to noticeably change any other value
u = torch.rand(shape) # rand samples from Uniform[0, 1)
g = -torch.log(-torch.log(u + eps) + eps)

# softmax_x_i = exp(x_i) / sum_j(exp(x_j))
soft = torch.softmax((logits + g) / tau, dim=-1)

# hard
idx = torch.argmax(soft, dim=-1, keepdim=True) # gets indices from the last dim and keeps the dimension
hard = torch.zeros_like(soft).scatter_(dim=-1, index=idx, value=1.0) # create a tensor of same size and modify in place

# THE important step for how gumbel-softmax works in pytorch
# when passing gradients through, only 'soft' has a gradient
# soft.detach() is treated as a constant because of detach()
# hard doesn't have a gradient because of argmax
# so only 'soft' passes its gradient through
# resolves the issue by giving a one hot vector in the forward pass for the discrete representation and training the model through the soft calculation
y = hard - soft.detach() + soft
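The straight-through trick above can be verified directly. This sketch repeats the manual steps with `requires_grad=True` on the logits (the cell above does not track gradients) and compares the gradient through `y` with the gradient through `soft`; the `weights` vector is a made-up stand-in for a downstream loss:

```python
import torch

torch.manual_seed(44)
logits = torch.tensor([2.0, 0.5, -1.0], requires_grad=True)
tau = 0.5

# Gumbel noise, clamped to avoid log(0)
u = torch.rand(logits.shape).clamp(min=1e-10)
g = -torch.log(-torch.log(u))

soft = torch.softmax((logits + g) / tau, dim=-1)
idx = soft.argmax(dim=-1, keepdim=True)
hard = torch.zeros_like(soft).scatter_(-1, idx, 1.0)
y = hard - soft.detach() + soft

weights = torch.tensor([1.0, 2.0, 3.0])
grad_y, = torch.autograd.grad((y * weights).sum(), logits, retain_graph=True)
grad_soft, = torch.autograd.grad((soft * weights).sum(), logits)

print(torch.allclose(y, hard))            # forward value matches the one-hot
print(torch.allclose(grad_y, grad_soft))  # backward matches the soft path
```

The forward value is the one-hot vector (up to floating-point rounding), and the gradient is exactly the one the soft sample would have produced.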


In [29]:
import torch.nn as nn
import torch.nn.functional as F
# how it works in the RSSM:

# dummy settings
B = 8
num_groups  = 4
num_classes = 4
size_deterministic      = 64
size_obs_embed          = 32
size_stochastic_flat    = num_groups * num_classes
# dummy inputs
deterministic_state_t   = torch.randn(B, size_deterministic)
observation_embedding_t = torch.randn(B, size_obs_embed)
# posterior head in RSSM
# takes in obs embedding and the hidden state and produces logits of the size num_groups * num_classes
posterior_head = nn.Linear(size_deterministic + size_obs_embed, size_stochastic_flat)
joined = torch.cat([deterministic_state_t, observation_embedding_t], dim=-1)
group_class_logits = posterior_head(joined)
# F.gumbel_softmax works on tensors of any rank along `dim`, so this reshape is optional
# flattening to (B * G, C) just makes the B * G independent categorical distributions explicit, each with C classes
# gumbel_softmax treats each (B, G) pair independently and samples a one-hot vector for each
x = group_class_logits.view(B*num_groups, num_classes)
onehots = F.gumbel_softmax(x, tau=0.5, hard=True, dim=-1)
out = onehots.view(B, num_groups, num_classes)
out.shape, out[0, 0]

(torch.Size([8, 4, 4]), tensor([0., 0., 1., 0.], grad_fn=<SelectBackward0>))
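As a variation, `F.gumbel_softmax` can also be applied directly to a `(B, G, C)` tensor with `dim=-1`, without flattening to 2D first. In this sketch `flat_logits` is a random stand-in for the posterior head output, and flattening the result back into `stochastic_state` is an assumption about how a downstream layer might consume it:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, num_groups, num_classes = 8, 4, 4

# random stand-in for the posterior head output
flat_logits = torch.randn(B, num_groups * num_classes)

# group the logits, then sample along the class dimension
logits = flat_logits.view(B, num_groups, num_classes)
onehots = F.gumbel_softmax(logits, tau=0.5, hard=True, dim=-1)

print(onehots.shape)        # torch.Size([8, 4, 4])
print(onehots.sum(dim=-1))  # every (batch, group) entry sums to 1

# hypothetical: flatten back to (B, G * C) for a downstream layer
stochastic_state = onehots.view(B, num_groups * num_classes)
print(stochastic_state.shape)  # torch.Size([8, 16])
```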