Multi-GPU-enabled image generative models (GAN and DCGAN)
Python
Latest commit ca06392 Feb 8, 2017 @tqchen committed on GitHub Merge pull request #4 from tornadomeet/master
fix viz.py with python3
Permalink
Failed to load latest commit information.
example fix viz Oct 16, 2016
mxgan use python3 Feb 8, 2017
.gitignore checkin gan Jun 17, 2016
LICENSE Initial commit Jun 17, 2016
README.md Add minibatch discrimnation layer Jun 23, 2016

README.md

MXNet GAN

MXNet module implementation of multi-GPU-compatible generative models.

List of Methods

  • Unsupervised Training
  • Semisupervised Training
  • Minibatch discrimination

Usage

import logging
import numpy as np
import mxnet as mx

from mxgan import module, generator, encoder, viz

def ferr(label, pred):
    """Return the mean absolute gap between labels and thresholded predictions.

    Both arrays are flattened, the predictions are binarised with a strict
    ``> 0.5`` threshold, and the summed absolute difference is divided by the
    number of label entries. For 0/1 labels this is the misclassification
    rate.

    Args:
        label: NumPy array of target values.
        pred: NumPy array of prediction scores with as many entries as
            ``label``.

    Returns:
        The summed absolute difference divided by the label count.
    """
    flat_label = label.ravel()
    # Strict comparison: a score of exactly 0.5 counts as the negative class.
    predicted_positive = pred.ravel() > 0.5
    mismatch = np.abs(flat_label - predicted_positive)
    return mismatch.sum() / flat_label.shape[0]

# --- Batch and shape settings -------------------------------------------
batch_size = 100
num_epoch = 100
# Shape handed to the generator as `oshape`; its trailing three dimensions
# are also used as the MNIST iterator's input shape.
# NOTE(review): presumably (batch, channels, height, width) -- confirm.
data_shape = (batch_size, 1, 28, 28)
# Shape passed to the GAN module as `code_shape`.
rand_shape = (batch_size, 100)

# --- Adam optimizer settings --------------------------------------------
lr = 0.0005
beta1 = 0.5

# --- Device ---------------------------------------------------------------
context = mx.gpu()

# Configure logging before any symbols are built so that later messages are
# timestamped.
logging.basicConfig(level=logging.DEBUG, format='%(asctime)-15s %(message)s')

# Generator symbol emitting outputs of `data_shape` with a sigmoid final
# activation.
sym_gen = generator.dcgan28x28(oshape=data_shape, ngf=32, final_act="sigmoid")

# Combine the generator symbol with a LeNet encoder symbol in a single GAN
# module bound to the chosen device context.
gmod = module.GANModule(
    sym_gen,
    symbol_encoder=encoder.lenet(),
    context=context,
    data_shape=data_shape,
    code_shape=rand_shape,
)

# Parameters must be initialised before the optimizer is attached.
gmod.init_params(mx.init.Xavier(factor_type="in", magnitude=2.34))

# Adam with no weight decay; learning rate and beta1 come from the settings
# defined earlier in the script.
gmod.init_optimizer(
    optimizer="adam",
    optimizer_params=dict(learning_rate=lr, wd=0., beta1=beta1),
)

# Directory holding the raw MNIST idx files, relative to the working
# directory the script is launched from.
data_dir = './../../mxnet/example/image-classification/mnist/'

# Shuffled training iterator; the per-sample input shape is `data_shape`
# without its leading (batch) dimension.
train = mx.io.MNISTIter(
    image=data_dir + "train-images-idx3-ubyte",
    label=data_dir + "train-labels-idx1-ubyte",
    input_shape=data_shape[1:],
    batch_size=batch_size,
    shuffle=True,
)

# Running metric built on `ferr`.
metric_acc = mx.metric.CustomMetric(ferr)

for epoch in range(num_epoch):
    # Rewind the data iterator and clear the metric at the start of each epoch.
    train.reset()
    metric_acc.reset()
    for t, batch in enumerate(train):
        gmod.update(batch)

        # Score `outputs_fake` against an all-zero target, then
        # `outputs_real` against an all-one target. The same label buffer is
        # reused, so each fill must happen immediately before its update.
        # NOTE(review): presumably these are the discriminator outputs for
        # generated and real samples -- confirm against mxgan.module.
        gmod.temp_label[:] = 0.0
        metric_acc.update([gmod.temp_label], gmod.outputs_fake)
        gmod.temp_label[:] = 1.0
        metric_acc.update([gmod.temp_label], gmod.outputs_real)

        # Only report and visualise every 100th iteration.
        if t % 100 != 0:
            continue

        logging.info("epoch: %d, iter %d, metric=%s", epoch, t, metric_acc.get())
        viz.imshow("gout", gmod.temp_outG[0].asnumpy(), 2)
        # Standardise `temp_diffD` (zero mean, unit std) and shift by 0.5
        # before display.
        # NOTE(review): presumably the gradient flowing back from the
        # discriminator -- confirm.
        diff = gmod.temp_diffD[0].asnumpy()
        diff = (diff - diff.mean()) / diff.std() + 0.5
        viz.imshow("diff", diff)
        viz.imshow("data", batch.data[0].asnumpy(), 2)