### Kullback-Leibler Divergence

For two distributions q (the one I have) and p (the one I want to match), the KL divergence measures how different q is from p: it is the average of the log-ratio log(q_i/p_i), weighted by q.  

$$\mathrm{KL}(q\,\|\,p) = \sum_i q_i \log\frac{q_i}{p_i}$$
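A small worked example, with made-up numbers and the natural log (so the result is in nats). With q = (0.5, 0.5) and p = (0.9, 0.1):

$$
\mathrm{KL}(q\,\|\,p) = 0.5\log\frac{0.5}{0.9} + 0.5\log\frac{0.5}{0.1} \approx 0.5(-0.588) + 0.5(1.609) \approx 0.511
$$

One of the terms is negative, but the q-weighted sum still comes out positive.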

It's always >= 0, and equals 0 only when q = p.  
It's asymmetric: in general KL(q||p) != KL(p||q), so the order of the arguments matters.
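Both properties are easy to check numerically. Below is a minimal plain-Python sketch; the helper `kl`, the `random_dist` generator, and the example distributions are illustrative and not part of the original notebook. It uses natural logs and assumes strictly positive probabilities that each sum to 1.

```python
import math
import random

def kl(q, p):
    """Discrete KL(q || p) in nats; assumes q and p are strictly positive and each sums to 1."""
    return sum(qi * math.log(qi / pi) for qi, pi in zip(q, p))

def random_dist(n):
    """Random strictly positive distribution over n outcomes (illustrative helper)."""
    w = [random.random() + 0.01 for _ in range(n)]
    s = sum(w)
    return [x / s for x in w]

q = [0.5, 0.5]
p = [0.9, 0.1]

print(round(kl(q, p), 4))  # KL(q||p) -> 0.5108
print(round(kl(p, q), 4))  # KL(p||q) -> 0.3681, a different value: KL is asymmetric
print(kl(q, q))            # identical distributions -> 0.0

# non-negativity: many random pairs of distributions, none gives a negative KL
random.seed(0)
samples = [kl(random_dist(4), random_dist(4)) for _ in range(1000)]
print(min(samples) >= 0)   # -> True
```

Swapping the arguments changes the answer (about 0.511 vs 0.368 nats here), which is why it matters which distribution goes first.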

This term is what trains the prior network whenever an observation is available. The posterior network sees the observation and produces a distribution over latent states conditioned on it, while the prior network predicts the latent distribution purely from its own recurrent state (i.e., without seeing the observation). The KL divergence between these two distributions measures how far the prior's prediction is from the observation-conditioned posterior.

During training, minimizing the KL updates the prior network's weights so that its predicted distribution better matches the posterior's. In effect, this teaches the prior to anticipate what the posterior will infer once the next observation arrives. (Unless gradients are stopped on the posterior side, the same loss also pulls the posterior toward the prior.)
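To see why minimizing the KL pulls the prior toward the posterior, here is a stripped-down plain-Python sketch. As a simplifying assumption, the prior network is replaced by a single vector of free logits `z` (a hypothetical stand-in, not an actual network), and the posterior's distribution `q` is held fixed with made-up values. For p = softmax(z), the gradient of KL(q||p) with respect to z is simply p - q, so each gradient-descent step moves z by -lr * (p - q).

```python
import math

def softmax(z):
    m = max(z)                                   # subtract the max for numerical stability
    e = [math.exp(x - m) for x in z]
    s = sum(e)
    return [x / s for x in e]

def kl(q, p):
    # terms with q_i == 0 contribute 0 (0 * log 0 is taken as 0)
    return sum(qi * math.log(qi / pi) for qi, pi in zip(q, p) if qi > 0)

# fixed target: the posterior's distribution for one observation (made-up values)
q = [0.7, 0.2, 0.05, 0.05]
# hypothetical stand-in for the prior network's output: free logits, starting uniform
z = [0.0, 0.0, 0.0, 0.0]
lr = 1.0

history = []
for step in range(200):
    p = softmax(z)
    history.append(kl(q, p))
    grad = [pi - qi for pi, qi in zip(p, q)]     # d KL(q || softmax(z)) / dz = p - q
    z = [zi - lr * gi for zi, gi in zip(z, grad)]

print(history[0], history[-1])                   # the KL shrinks from ~0.52 toward 0
print([round(x, 3) for x in softmax(z)])         # the prior now matches q
```

In the real model the gradient flows further back, through the prior network's layers into its weights, but the direction of the update is the same: the prior's probabilities move toward the posterior's.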

import torch

torch.manual_seed(44)

# (1) start from the raw logits produced by the posterior and prior networks
# mock data with shape (B, G, C) = (batch, latent groups, classes per group)
postr = torch.randn(2, 4, 4)
prior = torch.randn(2, 4, 4)

# (2) apply log-softmax to both over the class dimension C

# manual implementation; subtracting the per-row max first keeps exp() from overflowing
# (the result is unchanged because softmax is invariant to a constant shift)
def log_softmax(x, dim=-1):
    x = x - x.max(dim=dim, keepdim=True).values
    return x - torch.log(torch.sum(torch.exp(x), dim=dim, keepdim=True))

q_log = log_softmax(postr)
p_log = log_softmax(prior)

# (3) exponentiate the posterior's log-probabilities to recover q itself
q = q_log.exp()

# (4) compute the KL
# we want a value that represents how *surprised* the prior network is by the posterior's distribution for that specific observation
# sum over classes C, then over groups G: (B, G, C) -> (B)
# kl ~ 0: the posterior is close to the prior
# kl > 0 (moderate): the posterior is somewhat different
# kl > 0 (large): the posterior is VERY different -> a big update is required
kl = (q * (q_log - p_log)).sum(dim=-1).sum(dim=-1)
kl

tensor([1.1031, 3.5609])

# the same computation as a reusable function, using PyTorch's built-in F.log_softmax
import torch.nn.functional as F

def _kl(posterior_logits, prior_logits):
    q_log = F.log_softmax(posterior_logits, dim=-1)       # (B, G, C)
    p_log = F.log_softmax(prior_logits, dim=-1)           # (B, G, C)
    q     = q_log.exp()
    kl    = (q * (q_log - p_log)).sum(dim=-1).sum(dim=-1) # (B)
    return kl
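As a sanity check, the same quantity can be cross-checked against two utilities that ship with PyTorch: `torch.nn.functional.kl_div` and `torch.distributions.kl_divergence`. This sketch regenerates the same mock `(B, G, C)` logits and computes the reference value inline, summing over both trailing dimensions at once. Note the argument order of `F.kl_div`: it takes the prior's log-probabilities as `input` and the posterior's as `target`, the reverse of how KL(q||p) is written. It also relies on the `log_target` argument, which older PyTorch releases may not have.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Categorical, kl_divergence

torch.manual_seed(44)
postr = torch.randn(2, 4, 4)   # (B, G, C) mock posterior logits
prior = torch.randn(2, 4, 4)   # (B, G, C) mock prior logits

# reference: KL(q||p) summed over C and G in one call, (B, G, C) -> (B)
q_log = F.log_softmax(postr, dim=-1)
p_log = F.log_softmax(prior, dim=-1)
ours = (q_log.exp() * (q_log - p_log)).sum(dim=(-2, -1))

# (a) F.kl_div(input, target) computes exp(target) * (target - input) when log_target=True
via_kl_div = F.kl_div(p_log, q_log, reduction="none", log_target=True).sum(dim=(-2, -1))

# (b) one Categorical per (B, G) slot; kl_divergence returns shape (B, G), summed over G
via_dist = kl_divergence(Categorical(logits=postr), Categorical(logits=prior)).sum(dim=-1)

print(torch.allclose(ours, via_kl_div, atol=1e-6), torch.allclose(ours, via_dist, atol=1e-6))
```

If all three agree, the hand-written version is computing KL(q||p) in the intended direction.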