The loss function in examples/vae/main.py seems to be incorrect. The current implementation is:
```python
def loss_function(recon_x, x, mu, logvar):
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784))

    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # Normalise by same number of elements as in reconstruction
    KLD /= args.batch_size * 784

    return BCE + KLD
```
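For reference, the closed-form KL term that the comment points to (Appendix B of Kingma and Welling), for a factorised Gaussian posterior $q(z|x) = \mathcal{N}(\mu, \sigma^2 I)$ and a standard normal prior, is

$$
D_{\mathrm{KL}}\big(\mathcal{N}(\mu, \sigma^2 I)\,\|\,\mathcal{N}(0, I)\big)
  = -\frac{1}{2} \sum_{j=1}^{J} \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right),
$$

where `logvar` stores $\log \sigma_j^2$; the `KLD` line above is this quantity summed over the minibatch.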
The KL-divergence term can be computed analytically when the variational posterior is a fully factorised Gaussian and the prior is a fully factorised standard Gaussian, which is what `KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())` does, while the data term is estimated from the mini-batch via `BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784))`.

However, the normalisation `KLD /= args.batch_size * 784` looks incorrect: the KL term should be divided by the training-set size. The current implementation effectively introduces an extra alpha coefficient in front of the KL term, so the objective is no longer the exact lower bound.
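For comparison, here is a minimal sketch of one consistent choice, where both terms are summed over pixels/latent dimensions and over the batch, so no implicit weighting appears between them. This is only an illustration, not necessarily the exact fix requested above; the `reduction='sum'` argument assumes a reasonably recent PyTorch version, and the function name is hypothetical.

```python
import torch
import torch.nn.functional as F

def elbo_loss(recon_x, x, mu, logvar):
    # Reconstruction term summed over pixels and over the batch
    # (no per-element averaging), so it is on the same scale as KLD.
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    # Analytic KL( N(mu, sigma^2) || N(0, 1) ), summed over latent
    # dimensions and over the batch.
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # Per-minibatch negative ELBO; divide by the batch size for a
    # per-example bound, or multiply by (train_set_size / batch_size)
    # to estimate the bound over the whole training set.
    return BCE + KLD
```

Scaling this per-minibatch value by `train_set_size / batch_size` (hypothetical names) would give an unbiased estimate of the bound over the full training set, in line with dividing the KL term by the training-set size as suggested above.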