In [39]:
import torch

# Auxiliary (approximate posterior) distribution q(z|x) and prior p(z).
q = torch.distributions.Normal(1, 2)
p = torch.distributions.Normal(0, 1)

# Draw a single sample from q using the reparameterization trick.
z = q.rsample()
print(f"z size: {z.size()}")

# Log-density of the sample under the prior and under q.
log_pz = p.log_prob(z)
log_qzx = q.log_prob(z)
print(f"log pz size: {log_pz.size()}")

# Monte Carlo estimate of KL(q || p) = E_q[log q(z|x) - log p(z)].
# z is sampled from q, so the integrand is (log q - log p); the reverse
# order would yield the negative of the KL estimate.
kl_divergence = torch.mean(log_qzx - log_pz, 0)
print(kl_divergence)

z size: torch.Size([])
log pz size: torch.Size([])
tensor(0.7183)


We make use of Variational Inference to approximate the "true" posterior of the latent space $p(z|x)$ with an auxiliary distribution $q(z|x)$ by minimizing the KL Divergence between these two distributions:
$$ D_{KL}(q(z|x)||p(z|x)) = \sum q(z|x) \log \frac{q(z|x)}{p(z|x)} $$

We use a neural network to parameterize $q(z|x)$ such that we can minimize the $D_{KL}$ by training the network. $q(z|x)$ becomes our encoder that outputs the parameter $\theta$ of our distribution $q(z|x) = q_{\theta}(z|x)$. 

Recall that the posterior is defined by Bayes' rule:

$$
p(z|x) = \frac{\text{Likelihood} \times \text{Prior}}{\text{Marginal Distribution}}=\frac{p(x|z)p(z)}{p(x)}
$$

We can derive the ELBO Loss from $D_{KL}(q(z|x)||p(z|x))$ which will be our overall loss function for our VAE:
\begin{align}
D_{KL}(q(z|x)||p(z|x)) &= \sum q(z|x) \log \frac{q(z|x)}{p(z|x)}\\
&= -\sum q(z|x) \log \frac{p(z|x)}{q(z|x)}\\
&= -\sum q(z|x)  \left[\log p(z|x) - \log q(z|x) \right]\\
&= -\sum q(z|x) \left[\log \frac{p(x|z)p(z)}{p(x)} - \log q(z|x) \right]\\
&= -\sum q(z|x) \left[\log p(x|z) + \log p(z) - \log p(x) - \log q(z|x) \right]\\
&= \sum q(z|x) \left[-\log p(x|z) - \log p(z) + \log p(x) + \log q(z|x) \right]\\
D_{KL}(q(z|x)||p(z|x)) - \sum q(z|x) \left[-\log p(x|z) - \log p(z) + \log q(z|x) \right] &= \log p(x)\\
-\sum q(z|x) \left[-\log p(x|z) + \log \frac{q(z|x)}{p(z)} \right] &= \\
-\mathbb{E}_{z \sim q(z|x)} \left[-\log p(x|z) + \log \frac{q(z|x)}{p(z)} \right] &=\\
-\mathbb{E}_{q} \left[-\log p(x|z) + \log \frac{q(z|x)}{p(z)} \right] &=\\
\mathbb{E}_{q} \left[\log p(x|z) - \log \frac{q(z|x)}{p(z)} \right] &=\\
D_{KL}(q(z|x)||p(z|x)) + \left[ \mathbb{E}_{q} \log p(x|z)  - \mathbb{E}_{q}  \log \frac{q(z|x)}{p(z)} \right] &= \log p(x)\\
D_{KL}(q(z|x)||p(z|x)) + \text{Variational Lower Bound (ELBO)} &= \log p(x) \\
ELBO &= \log p(x) - D_{KL}(q(z|x)||p(z|x)) \\
\left[ \mathbb{E}_{q} \log p(x|z)  - \mathbb{E}_{q}  \log \frac{q(z|x)}{p(z)} \right] &= \log p(x) - D_{KL}(q(z|x)||p(z|x)) \\
\left[ \mathbb{E}_{q} \log p(x|z)  - \mathbb{E}_{q}[ \log q(z|x) - \log p(z)] \right] &= \log p(x) - D_{KL}(q(z|x)||p(z|x)) \\
\left[ \mathbb{E}_{q} \log p(x|z)  - D_{KL}(q(z|x)||p(z)) \right] &= \log p(x) - D_{KL}(q(z|x)||p(z|x))   \leq \log p(x) \\
\end{align}

So maximizing the ELBO corresponds to maximizing a lower bound on the log probability of generating real data samples, $\log p(x)$; the bound falls short of $\log p(x)$ by exactly the divergence between our "true" and approximate posterior distributions. Maximizing the ELBO is the same as minimizing the negative ELBO:
$$ \max(ELBO) = \max \left( \mathbb{E}_{q} \log p(x|z)  - D_{KL}(q(z|x)||p(z)) \right) 
             = \min \left( D_{KL}(q(z|x)||p(z)) - \mathbb{E}_{q} \log p(x|z) \right) = \min(-ELBO) $$

In [None]:
def kl_divergence(z, mu, std):
    """
    Monte Carlo estimate of KL(q || p) evaluated at the sample ``z``.

    ``q`` is a Normal with mean ``mu`` and standard deviation ``std``;
    ``p`` is a standard Normal (zero mean, unit scale) with the same shape.
    The per-element log-density ratio ``log q(z) - log p(z)`` is summed
    over the last dimension and returned.
    """
    # TODO: Change p from n independent Gaussians to one multivariate
    #       Gaussian with mu and covariance matrix

    # Auxiliary distribution q and the standard-Gaussian target p.
    approx_posterior = torch.distributions.Normal(mu, std)
    prior = torch.distributions.Normal(
        torch.zeros_like(mu), torch.ones_like(std)
    )

    # Log-ratio log q(z) - log p(z), as in the ELBO derivation.
    log_ratio = approx_posterior.log_prob(z) - prior.log_prob(z)

    # Independent dimensions: summing the last axis gives the joint term.
    return torch.sum(log_ratio, dim=-1)



Reconstruction Loss ...

In [None]:
def reconstruction_loss(x_hat, logscale, x):
    """
    Gaussian log-likelihood of ``x`` under the reconstruction ``x_hat``.

    Builds a Normal centred on ``x_hat`` with scale ``exp(logscale)``,
    evaluates ``log_prob(x)`` element-wise and sums over dimensions 1, 2
    and 3, so the inputs must have at least four dimensions.

    Note: the returned value is a log-likelihood (higher is better), not a
    negated loss.
    """
    # logscale is stored in log-space; exponentiate to get the positive scale.
    likelihood = torch.distributions.Normal(x_hat, torch.exp(logscale))

    # Element-wise log p(x | z).
    log_pxz = likelihood.log_prob(x)

    # Collapse every non-leading axis (dims 1-3) into one value per sample.
    return torch.sum(log_pxz, dim=(1, 2, 3))