From ebf36143f432d7cf2a1d0407f15680b37e7fe600 Mon Sep 17 00:00:00 2001
From: Kai Arulkumaran
Date: Thu, 1 Feb 2018 23:19:42 -0500
Subject: [PATCH] Fix VAE losses (sum over everything)

Closes https://github.com/pytorch/examples/issues/234 and
https://github.com/pytorch/examples/issues/290
---
 vae/main.py | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/vae/main.py b/vae/main.py
index 46ab26cfe5..9c641e6470 100644
--- a/vae/main.py
+++ b/vae/main.py
@@ -77,25 +77,22 @@ def forward(self, x):
 
 model = VAE()
 if args.cuda:
     model.cuda()
+optimizer = optim.Adam(model.parameters(), lr=1e-3)
 
 
+# Reconstruction + KL divergence losses summed over all elements and batch
 def loss_function(recon_x, x, mu, logvar):
-    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784))
+    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), size_average=False)
 
     # see Appendix B from VAE paper:
     # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
     # https://arxiv.org/abs/1312.6114
     # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
     KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
-    # Normalise by same number of elements as in reconstruction
-    KLD /= args.batch_size * 784
 
     return BCE + KLD
 
-optimizer = optim.Adam(model.parameters(), lr=1e-3)
-
-
 def train(epoch):
     model.train()
     train_loss = 0
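
Note for readers applying this patch against newer PyTorch: the
size_average=False argument used above was later deprecated in favour of
reduction='sum'. Below is a minimal, self-contained sketch of the corrected
loss with the modern argument; the dummy tensor shapes (batch of 8, latent
dimension 20) are illustrative assumptions, not part of the patch.

import torch
import torch.nn.functional as F

def loss_function(recon_x, x, mu, logvar):
    # Reconstruction term: binary cross-entropy summed over all pixels and
    # the whole batch, matching size_average=False in the patch above.
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    # KL divergence of q(z|x) = N(mu, sigma^2) from the prior N(0, I):
    # -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2), also summed over the
    # batch, so both terms are on the same scale.
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

# Quick check with dummy tensors (sigmoid keeps reconstructions in (0, 1)).
x = torch.rand(8, 784)
recon_x = torch.sigmoid(torch.randn(8, 784))
mu, logvar = torch.zeros(8, 20), torch.zeros(8, 20)
print(loss_function(recon_x, x, mu, logvar))

Summing (rather than averaging) both terms keeps the reconstruction and KL
losses on the scale implied by the evidence lower bound, which is why the
patch also drops the KLD /= args.batch_size * 784 normalisation.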