diff --git a/vae/main.py b/vae/main.py index c45382d3c1..46ab26cfe5 100644 --- a/vae/main.py +++ b/vae/main.py @@ -58,11 +58,11 @@ def encode(self, x): def reparameterize(self, mu, logvar): if self.training: - std = logvar.mul(0.5).exp_() - eps = Variable(std.data.new(std.size()).normal_()) - return eps.mul(std).add_(mu) + std = logvar.mul(0.5).exp_() + eps = Variable(std.data.new(std.size()).normal_()) + return eps.mul(std).add_(mu) else: - return mu + return mu def decode(self, z): h3 = self.relu(self.fc3(z)) @@ -129,10 +129,10 @@ def test(epoch): recon_batch, mu, logvar = model(data) test_loss += loss_function(recon_batch, data, mu, logvar).data[0] if i == 0: - n = min(data.size(0), 8) - comparison = torch.cat([data[:n], + n = min(data.size(0), 8) + comparison = torch.cat([data[:n], recon_batch.view(args.batch_size, 1, 28, 28)[:n]]) - save_image(comparison.data.cpu(), + save_image(comparison.data.cpu(), 'results/reconstruction_' + str(epoch) + '.png', nrow=n) test_loss /= len(test_loader.dataset) @@ -144,7 +144,7 @@ def test(epoch): test(epoch) sample = Variable(torch.randn(64, 20)) if args.cuda: - sample = sample.cuda() + sample = sample.cuda() sample = model.decode(sample).cpu() save_image(sample.data.view(64, 1, 28, 28), 'results/sample_' + str(epoch) + '.png')