diff --git a/vae/main.py b/vae/main.py index 19d6e53c38..370ad579ce 100644 --- a/vae/main.py +++ b/vae/main.py @@ -9,9 +9,9 @@ parser = argparse.ArgumentParser(description='PyTorch MNIST Example') parser.add_argument('--batch-size', type=int, default=128, metavar='N', - help='input batch size for training (default: 64)') + help='input batch size for training (default: 128)') parser.add_argument('--epochs', type=int, default=10, metavar='N', - help='number of epochs to train (default: 2)') + help='number of epochs to train (default: 10)') parser.add_argument('--no-cuda', action='store_true', default=False, help='enables CUDA training') parser.add_argument('--seed', type=int, default=1, metavar='S', @@ -56,11 +56,7 @@ def encode(self, x): def reparametrize(self, mu, logvar): std = logvar.mul(0.5).exp_() - if args.cuda: - eps = torch.cuda.FloatTensor(std.size()).normal_() - else: - eps = torch.FloatTensor(std.size()).normal_() - eps = Variable(eps) + eps = Variable(std.data.new(std.size()).normal_()) return eps.mul(std).add_(mu) def decode(self, z): @@ -82,7 +78,7 @@ def forward(self, x): def loss_function(recon_x, x, mu, logvar): - BCE = reconstruction_function(recon_x, x) + BCE = reconstruction_function(recon_x, x.view(-1, 784)) # see Appendix B from VAE paper: # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014