Merge pull request chainer#4009 from rezoo/update-vae
Update VAE examples
mitmul committed Dec 21, 2017
2 parents 0a8059e + 0ebdcb9 commit 37d6b69
Showing 4 changed files with 255 additions and 228 deletions.
82 changes: 0 additions & 82 deletions examples/vae/data.py

This file was deleted.
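The hand-rolled MNIST downloader is no longer needed: the rewritten train_vae.py below loads the data through chainer.datasets.get_mnist, which downloads and caches the dataset on first use. A minimal sketch of the replacement call (the shapes in the comments are the standard MNIST split):

import chainer

# withlabel=False yields plain float32 arrays of flattened 28x28
# images scaled to [0, 1] instead of (image, label) pairs.
train, test = chainer.datasets.get_mnist(withlabel=False)
print(train.shape)  # (60000, 784)
print(test.shape)   # (10000, 784)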

2 changes: 2 additions & 0 deletions examples/vae/net.py
@@ -61,5 +61,7 @@ def lf(x):
             self.rec_loss = rec_loss
             self.loss = self.rec_loss + \
                 C * gaussian_kl_divergence(mu, ln_var) / batchsize
+            chainer.report(
+                {'rec_loss': rec_loss, 'loss': self.loss}, observer=self)
             return self.loss
         return lf
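The two added chainer.report lines are what lets the Trainer-based train_vae.py below work without manual bookkeeping: values reported with observer=self are picked up as 'main/loss' and 'main/rec_loss' by LogReport/PrintReport, and under 'validation/main/...' by the Evaluator. A minimal sketch of the same pattern in an unrelated link (Classifier and its fields are hypothetical, not part of this commit):

import chainer
import chainer.functions as F
import chainer.links as L


class Classifier(chainer.Chain):
    """Hypothetical link that reports its own metrics."""

    def __init__(self, n_out):
        super(Classifier, self).__init__()
        with self.init_scope():
            self.fc = L.Linear(None, n_out)

    def __call__(self, x, t):
        loss = F.softmax_cross_entropy(self.fc(x), t)
        # Values passed to chainer.report with observer=self show up in
        # the Trainer logs under this link's name, e.g. 'main/loss'.
        chainer.report({'loss': loss}, observer=self)
        return loss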
252 changes: 106 additions & 146 deletions examples/vae/train_vae.py
@@ -3,154 +3,114 @@
"""
from __future__ import print_function
import argparse

import matplotlib.pyplot as plt
import numpy as np
import six
import os

import chainer
from chainer import computational_graph
from chainer import cuda
from chainer import optimizers
from chainer import serializers
from chainer import training
from chainer.training import extensions
import numpy as np

import data
import net

-parser = argparse.ArgumentParser(description='Chainer example: MNIST')
-parser.add_argument('--initmodel', '-m', default='',
-                    help='Initialize the model from given file')
-parser.add_argument('--resume', '-r', default='',
-                    help='Resume the optimization from snapshot')
-parser.add_argument('--gpu', '-g', default=-1, type=int,
-                    help='GPU ID (negative value indicates CPU)')
-parser.add_argument('--epoch', '-e', default=100, type=int,
-                    help='number of epochs to learn')
-parser.add_argument('--dimz', '-z', default=20, type=int,
-                    help='dimention of encoded vector')
-parser.add_argument('--batchsize', '-b', type=int, default=100,
-                    help='learning minibatch size')
-parser.add_argument('--test', action='store_true',
-                    help='Use tiny datasets for quick tests')
-args = parser.parse_args()
-
-batchsize = args.batchsize
-n_epoch = args.epoch
-n_latent = args.dimz
-
-print('GPU: {}'.format(args.gpu))
-print('# dim z: {}'.format(args.dimz))
-print('# Minibatch-size: {}'.format(args.batchsize))
-print('# epoch: {}'.format(args.epoch))
-print('')
-
-# Prepare dataset
-print('load MNIST dataset')
-mnist = data.load_mnist_data(args.test)
-mnist['data'] = mnist['data'].astype(np.float32)
-mnist['data'] /= 255
-mnist['target'] = mnist['target'].astype(np.int32)
-
-if args.test:
-    N = 30
-else:
-    N = 60000
-
-x_train, x_test = np.split(mnist['data'], [N])
-y_train, y_test = np.split(mnist['target'], [N])
-N_test = y_test.size
-
-# Prepare VAE model, defined in net.py
-model = net.VAE(784, n_latent, 500)
-if args.gpu >= 0:
-    cuda.get_device_from_id(args.gpu).use()
-    model.to_gpu()
-xp = np if args.gpu < 0 else cuda.cupy
-
-# Setup optimizer
-optimizer = optimizers.Adam()
-optimizer.setup(model)
-
-# Init/Resume
-if args.initmodel:
-    print('Load model from', args.initmodel)
-    serializers.load_npz(args.initmodel, model)
-if args.resume:
-    print('Load optimizer state from', args.resume)
-    serializers.load_npz(args.resume, optimizer)
-
-# Learning loop
-for epoch in six.moves.range(1, n_epoch + 1):
-    print('epoch', epoch)
-
-    # training
-    perm = np.random.permutation(N)
-    sum_loss = 0       # total loss
-    sum_rec_loss = 0   # reconstruction loss
-    for i in six.moves.range(0, N, batchsize):
-        x = chainer.Variable(xp.asarray(x_train[perm[i:i + batchsize]]))
-        optimizer.update(model.get_loss_func(), x)
-        if epoch == 1 and i == 0:
-            with open('graph.dot', 'w') as o:
-                g = computational_graph.build_computational_graph(
-                    (model.loss, ))
-                o.write(g.dump())
-            print('graph generated')
-
-        sum_loss += float(model.loss.data) * len(x.data)
-        sum_rec_loss += float(model.rec_loss.data) * len(x.data)
-
-    print('train mean loss={}, mean reconstruction loss={}'
-          .format(sum_loss / N, sum_rec_loss / N))
-
-    # evaluation
-    sum_loss = 0
-    sum_rec_loss = 0
-    with chainer.no_backprop_mode():
-        for i in six.moves.range(0, N_test, batchsize):
-            x = chainer.Variable(xp.asarray(x_test[i:i + batchsize]))
-            loss_func = model.get_loss_func(k=10)
-            loss_func(x)
-            sum_loss += float(model.loss.data) * len(x.data)
-            sum_rec_loss += float(model.rec_loss.data) * len(x.data)
-            del model.loss
-    print('test mean loss={}, mean reconstruction loss={}'
-          .format(sum_loss / N_test, sum_rec_loss / N_test))
-
-
-# Save the model and the optimizer
-print('save the model')
-serializers.save_npz('mlp.model', model)
-print('save the optimizer')
-serializers.save_npz('mlp.state', optimizer)
-
-model.to_cpu()
-
-
-# original images and reconstructed images
-def save_images(x, filename):
-    fig, ax = plt.subplots(3, 3, figsize=(9, 9), dpi=100)
-    for ai, xi in zip(ax.flatten(), x):
-        ai.imshow(xi.reshape(28, 28))
-    fig.savefig(filename)
-
-
-train_ind = [1, 3, 5, 10, 2, 0, 13, 15, 17]
-x = chainer.Variable(np.asarray(x_train[train_ind]))
-with chainer.no_backprop_mode():
-    x1 = model(x)
-save_images(x.data, 'train')
-save_images(x1.data, 'train_reconstructed')
-
-test_ind = [3, 2, 1, 18, 4, 8, 11, 17, 61]
-x = chainer.Variable(np.asarray(x_test[test_ind]))
-with chainer.no_backprop_mode():
-    x1 = model(x)
-save_images(x.data, 'test')
-save_images(x1.data, 'test_reconstructed')
-
-
-# draw images from randomly sampled z
-z = chainer.Variable(np.random.normal(0, 1, (9, n_latent)).astype(np.float32))
-x = model.decode(z)
-save_images(x.data, 'sampled')

+def main():
+    parser = argparse.ArgumentParser(description='Chainer example: VAE')
+    parser.add_argument('--initmodel', '-m', default='',
+                        help='Initialize the model from given file')
+    parser.add_argument('--resume', '-r', default='',
+                        help='Resume the optimization from snapshot')
+    parser.add_argument('--gpu', '-g', default=-1, type=int,
+                        help='GPU ID (negative value indicates CPU)')
+    parser.add_argument('--out', '-o', default='result',
+                        help='Directory to output the result')
+    parser.add_argument('--epoch', '-e', default=100, type=int,
+                        help='number of epochs to learn')
+    parser.add_argument('--dimz', '-z', default=20, type=int,
+                        help='dimension of encoded vector')
+    parser.add_argument('--batchsize', '-b', type=int, default=100,
+                        help='learning minibatch size')
+    parser.add_argument('--test', action='store_true',
+                        help='Use tiny datasets for quick tests')
+    args = parser.parse_args()
+
+    print('GPU: {}'.format(args.gpu))
+    print('# dim z: {}'.format(args.dimz))
+    print('# Minibatch-size: {}'.format(args.batchsize))
+    print('# epoch: {}'.format(args.epoch))
+    print('')
+
+    # Prepare VAE model, defined in net.py
+    model = net.VAE(784, args.dimz, 500)
+
+    # Setup an optimizer
+    optimizer = chainer.optimizers.Adam()
+    optimizer.setup(model)
+
+    # Initialize
+    if args.initmodel:
+        chainer.serializers.load_npz(args.initmodel, model)
+
+    # Load the MNIST dataset
+    train, test = chainer.datasets.get_mnist(withlabel=False)
+    if args.test:
+        train, _ = chainer.datasets.split_dataset(train, 100)
+        test, _ = chainer.datasets.split_dataset(test, 100)
+
+    train_iter = chainer.iterators.SerialIterator(train, args.batchsize)
+    test_iter = chainer.iterators.SerialIterator(test, args.batchsize,
+                                                 repeat=False, shuffle=False)
+
+    # Set up an updater. StandardUpdater can explicitly specify a loss
+    # function used in the training with 'loss_func' option
+    updater = training.StandardUpdater(train_iter, optimizer, device=args.gpu,
+                                       loss_func=model.get_loss_func())
+
+    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)
+    trainer.extend(extensions.Evaluator(test_iter, model, device=args.gpu,
+                                        eval_func=model.get_loss_func(k=10)))
+    trainer.extend(extensions.dump_graph('main/loss'))
+    trainer.extend(extensions.snapshot(), trigger=(args.epoch, 'epoch'))
+    trainer.extend(extensions.LogReport())
+    trainer.extend(extensions.PrintReport(
+        ['epoch', 'main/loss', 'validation/main/loss',
+         'main/rec_loss', 'validation/main/rec_loss', 'elapsed_time']))
+    trainer.extend(extensions.ProgressBar())
+
+    if args.resume:
+        chainer.serializers.load_npz(args.resume, trainer)
+
+    # Run the training
+    trainer.run()
+
+    # Visualize the results
+    def save_images(x, filename):
+        import matplotlib.pyplot as plt
+        fig, ax = plt.subplots(3, 3, figsize=(9, 9), dpi=100)
+        for ai, xi in zip(ax.flatten(), x):
+            ai.imshow(xi.reshape(28, 28))
+        fig.savefig(filename)
+
+    model.to_cpu()
+    train_ind = [1, 3, 5, 10, 2, 0, 13, 15, 17]
+    x = chainer.Variable(np.asarray(train[train_ind]))
+    with chainer.using_config('train', False), chainer.no_backprop_mode():
+        x1 = model(x)
+    save_images(x.data, os.path.join(args.out, 'train'))
+    save_images(x1.data, os.path.join(args.out, 'train_reconstructed'))
+
+    test_ind = [3, 2, 1, 18, 4, 8, 11, 17, 61]
+    x = chainer.Variable(np.asarray(test[test_ind]))
+    with chainer.using_config('train', False), chainer.no_backprop_mode():
+        x1 = model(x)
+    save_images(x.data, os.path.join(args.out, 'test'))
+    save_images(x1.data, os.path.join(args.out, 'test_reconstructed'))
+
+    # draw images from randomly sampled z
+    z = chainer.Variable(
+        np.random.normal(0, 1, (9, args.dimz)).astype(np.float32))
+    x = model.decode(z)
+    save_images(x.data, os.path.join(args.out, 'sampled'))
+
+
+if __name__ == '__main__':
+    main()
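For reference, the rewritten script can be run directly from examples/vae; the flags below simply mirror the argparse options shown above:

# Quick smoke test on CPU with 100 images per split
python train_vae.py --test --epoch 1

# Full training on GPU 0; logs, snapshots and the reconstructed and
# sampled images are written under ./result
python train_vae.py --gpu 0 --epoch 100 --out result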
