VAE loss function #294

@alexturn

Description

The loss function in examples/vae/main.py seems to be incorrect. The current implementation is:

def loss_function(recon_x, x, mu, logvar):
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784))

    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # Normalise by same number of elements as in reconstruction
    KLD /= args.batch_size * 784

    return BCE + KLD 

The KL-divergence term can be computed analytically when the variational posterior is a fully factorised Gaussian and the prior is a fully factorised standard Gaussian; this is what the code does with KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()). The data term is estimated from the mini-batch: BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784)).
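
For reference, here is a quick sanity check (not part of the original report) that the closed-form expression above matches the KL divergence computed via torch.distributions for a factorised Gaussian posterior and a standard normal prior; the batch size and latent dimension below are arbitrary placeholders:

    import torch
    from torch.distributions import Normal, kl_divergence

    # Hypothetical posterior parameters: batch of 4, latent dimension 20
    mu = torch.randn(4, 20)
    logvar = torch.randn(4, 20)

    # Closed-form KL from Appendix B of Kingma & Welling (2014):
    # KL(q || p) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    kld_closed_form = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    # The same quantity computed via torch.distributions
    q = Normal(mu, torch.exp(0.5 * logvar))
    p = Normal(torch.zeros_like(mu), torch.ones_like(mu))
    kld_reference = kl_divergence(q, p).sum()

    assert torch.allclose(kld_closed_form, kld_reference, atol=1e-4)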

However, the normalisation KLD /= args.batch_size * 784 looks incorrect: the KL term should be divided by the training-set size. As written, the implementation effectively places a weighting coefficient (an alpha term) in front of the KL divergence, so the objective is not the exact lower bound.
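
For illustration only (this is not the repository's code, and not necessarily the exact change proposed above): one common way to obtain an unweighted per-mini-batch bound is to sum both terms over the batch, e.g. with reduction='sum' in current PyTorch, so neither term picks up an extra 1/784 scaling relative to the other:

    import torch
    import torch.nn.functional as F

    def loss_function_summed(recon_x, x, mu, logvar):
        # Reconstruction term summed over pixels and batch elements
        BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
        # Analytic KL term, summed as in the snippet above, with no rescaling
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return BCE + KLD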
