In [1]:
import logging
logging.getLogger().setLevel(logging.INFO)

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
%matplotlib inline

import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
import mxnet as mx

In [2]:
N_GPUS = 8
mb_size = 60#*N_GPUS

In [3]:
mnist = input_data.read_data_sets("MNIST_data/", one_hot=False)

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz


In [4]:
def to4d(img):
    """
    reshape to 4D arrays
    """
    return img.reshape(img.shape[0], 1, 28, 28).astype(np.float32)/255

def get_mnist_iter():
    """
    create data iterator with NDArrayIter
    """
    (train_lbl, train_img) = (mnist.train.labels, mnist.train.images)
    (val_lbl, val_img) = (mnist.test.labels, mnist.test.images)
    train_img = (train_img-0.5)*2
    val_img = (val_img-0.5)*2
    
    train = mx.io.NDArrayIter(
        to4d(train_img), train_lbl, mb_size, shuffle=True)
    val = mx.io.NDArrayIter(
        to4d(val_img), val_lbl, mb_size)
    return (train, val)

In [5]:
class RandIter(mx.io.DataIter):
    def __init__(self, batch_size, ndim):
        self.batch_size = batch_size
        self.ndim = ndim
        self.provide_data = [('z', (batch_size, ndim))]
        self.provide_label = []

    def iter_next(self):
        return True

    def getdata(self):
        #Returns random numbers from a gaussian (normal) distribution 
        #with mean=0 and standard deviation = 1
        return [mx.random.normal(0, 1.0, shape=(self.batch_size, self.ndim))]

In [6]:
def fc_simple():
    data = mx.sym.Variable('data')
    fc1 = mx.sym.FullyConnected(data, name='fc1', num_hidden=128)
    act1 = mx.sym.Activation(fc1, name='relu1', act_type="relu")
    fc2 = mx.sym.FullyConnected(act1, name='fc2', num_hidden=10)
    fc_do = mx.sym.Dropout(fc2, p=0.5)
    out = mx.sym.SoftmaxOutput(fc_do, name = 'softmax')
    return out

In [7]:
def conv():
    data = mx.sym.Variable('data')
    conv1 = mx.sym.Convolution(
        data,
        name='conv1',
        kernel=(5,5),
        stride=(1,1),
        num_filter=128,
    )
    act1 = mx.sym.LeakyReLU(conv1, name='lrelu1')
    
    conv2 = mx.sym.Convolution(
        mx.sym.BatchNorm(act1),
        name='conv2',
        kernel=(5,5),
        stride=(2,2),
        num_filter=256,
    )
    act2 = mx.sym.LeakyReLU(conv2, name='lrelu2')
    
    conv3 = mx.sym.Convolution(
        mx.sym.BatchNorm(act2),
        name='conv3',
        kernel=(5,5),
        stride=(2,2),
        num_filter=512,
    )
    act3 = mx.sym.LeakyReLU(conv3, name='lrelu3')
    
    fc = mx.sym.FullyConnected(act3, name='fc', num_hidden=10)
    
    fc_do = mx.sym.Dropout(fc, p=0.5)
    
    out = mx.sym.SoftmaxOutput(fc_do, name = 'softmax')
    
    return out

In [8]:
def get_discriminator():
    data = mx.sym.Variable('data')
    conv1 = mx.sym.Convolution(
        data,
        name='conv1',
        kernel=(5,5),
        stride=(1,1),
        num_filter=128,
    )
    conv1_bn = mx.sym.BatchNorm(conv1, name='conv1_bn')
    act1 = mx.sym.LeakyReLU(conv1_bn, name='lrelu1')
    
    conv2 = mx.sym.Convolution(
        act1,
        name='conv2',
        kernel=(5,5),
        stride=(2,2),
        pad=(2,2),
        num_filter=256,
    )
    conv2_bn = mx.sym.BatchNorm(conv2, name='conv2_bn')
    act2 = mx.sym.LeakyReLU(conv2_bn, name='lrelu2')
    
    conv3 = mx.sym.Convolution(
        act2,
        name='conv3',
        kernel=(5,5),
        stride=(2,2),
        pad=(2,2),
        num_filter=512,
    )
    conv3_bn = mx.sym.BatchNorm(conv3, name='conv3_bn')
    act3 = mx.sym.LeakyReLU(conv3_bn, name='lrelu3')
    
    conv4 = mx.sym.Convolution(
        act3,
        name='conv4',
        kernel=(5,5),
        num_filter=1,
    )
    flat_bn = mx.sym.flatten(mx.sym.BatchNorm(conv4), name='conv4_flat_bn')
    
    fc = mx.sym.FullyConnected(flat_bn, num_hidden=1, name='fc_output')
    label = mx.sym.Variable('label')
    out = mx.sym.LogisticRegressionOutput(data=fc, label=label, name='LogRegOut')
    
    return out

In [9]:
def get_generator():
    z = mx.sym.Variable('z')
    g_fc = mx.sym.FullyConnected(
        z,
        num_hidden=7*7*512,
    )
    g_fc_reshape = mx.symbol.reshape(
        g_fc,
        shape=(-1,512,7,7)
    )
    gbn1 = mx.sym.BatchNorm(g_fc_reshape, name='gbn1')
    gact1 = mx.sym.Activation(gbn1, name='gact1', act_type='relu')
    
    g2 = mx.sym.Deconvolution(
        gact1, 
        name='g2', 
        kernel=(5,5), 
        stride=(2,2), 
        target_shape=(14,14),
        num_filter=256,
    )
    gbn2 = mx.sym.BatchNorm(g2, name='gbn2')
    gact2 = mx.sym.Activation(gbn2, name='gact2', act_type='relu')

    g3 = mx.sym.Deconvolution(
        gact2, 
        name='g3', 
        kernel=(5,5), 
        stride=(2,2), 
        target_shape=(28,28),
        num_filter=128
    )
    gbn3 = mx.sym.BatchNorm(g3, name='gbn3')
    gact3 = mx.sym.Activation(gbn3, name='gact3', act_type='relu')

    g4 = mx.sym.Deconvolution(
        gact3, 
        name='g4', 
        kernel=(5,5), 
        stride=(1,1), 
        target_shape=(28,28), 
        num_filter=1
    )
    gbn4 = mx.sym.BatchNorm(g4, name='gbn4')
    gact4 = mx.sym.Activation(gbn4, name='gan_output', act_type='tanh')
    
    return gact4

In [10]:
discriminatorSymbol = get_discriminator()
generatorSymbol = get_generator()

In [11]:
#mx.viz.plot_network(generator)

In [12]:
arg, d_out_shape, aux = discriminatorSymbol.infer_shape(data=(1000,1,28,28), label=(1000,1))
d_out_shape

[(1000, 1)]

In [13]:
arg, g_out_shape, aux = generatorSymbol.infer_shape(z=(64,100))
g_out_shape

[(64, 1, 28, 28)]

In [14]:
z_dim = 100
noise_iter = RandIter(mb_size, z_dim)
train, val = get_mnist_iter()

# Build Modules

In [15]:
ctx = [mx.gpu(i) for i in range(N_GPUS)]

## Generator

In [22]:
#lr = 0.0002
lr = 0.0005
beta1 = 0.5

generator = mx.module.Module(generatorSymbol, data_names=('z',), label_names=None, context=ctx)
generator.bind(data_shapes=noise_iter.provide_data)
generator.init_params(initializer=mx.init.Xavier())
generator.init_optimizer(
    optimizer='adam',
    optimizer_params={
        'learning_rate': lr,
        'beta1': beta1,
    })
mods = [generator]

## Discriminator

In [23]:
discriminator = mx.module.Module(discriminatorSymbol, data_names=('data',), label_names=('label',), context=ctx)
discriminator.bind(
    data_shapes=train.provide_data,
    label_shapes=[('label', (mb_size,))],
    inputs_need_grad=True
)
discriminator.init_params(initializer=mx.init.Xavier())
discriminator.init_optimizer(
    optimizer='adam',
    optimizer_params={
        'learning_rate': lr,
        'beta1': beta1,
    })
mods.append(discriminator)

In [24]:
def plot(samples):
    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(4, 4)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

    #return fig

In [27]:
print('Training...')
for epoch in range(10):
    train.reset()
    for i, batch in enumerate(train):
        #Get a batch of random numbers to generate an image from the generator
        zbatch = noise_iter.next()
        #Forward pass on training batch
        generator.forward(zbatch, is_train=True)
        #Output of training batch is the 64x64x3 image
        outG = generator.get_outputs()
        
        #Pass the generated (fake) image through the discriminator, and save the gradient
        #Label (for logistic regression) is an array of 0's since this image is fake
        label = mx.nd.zeros((mb_size,), ctx=mx.gpu(0))
        #Forward pass on the output of the discriminator network
        discriminator.forward(mx.io.DataBatch(outG, [label]), is_train=True)
        #Do the backwards pass and save the gradient
        discriminator.backward()
        gradD = [[grad.copyto(grad.context) for grad in grads] for grads in discriminator._exec_group.grad_arrays]
        
        #Pass a batch of real images from MNIST through the discriminator
        #Set the label to be an array of 1's because these are the real images
        label[:] = 1
        batch.label = [label]
        #Forward pass on a batch of MNIST images
        discriminator.forward(batch, is_train=True)
        #Do the backwards pass and add the saved gradient from the fake images to the gradient 
        #generated by this backwards pass on the real images
        discriminator.backward()
        for gradsr, gradsf in zip(discriminator._exec_group.grad_arrays, gradD):
            for gradr, gradf in zip(gradsr, gradsf):
                gradr += gradf
        #Update gradient on the discriminator 
        discriminator.update()

        #Now that we've updated the discriminator, let's update the generator
        #First do a forward pass and backwards pass on the newly updated discriminator
        #With the current batch
        discriminator.forward(mx.io.DataBatch(outG, [label]), is_train=True)
        discriminator.backward()
        #Get the input gradient from the backwards pass on the discriminator,
        #and use it to do the backwards pass on the generator
        diffD = discriminator.get_input_grads()
        generator.backward(diffD)
        #Update the gradients on the generator
        generator.update()
        
        #Increment to the next batch, printing every 50 batches
        #i += 1
        if i % 1000 == 0:
            print('epoch:', epoch, 'iter:', i)
            #plot(outG[0][:16].asnumpy())

Training...
epoch: 0 iter: 0
epoch: 1 iter: 0
epoch: 2 iter: 0
epoch: 3 iter: 0
epoch: 4 iter: 0
epoch: 5 iter: 0
epoch: 6 iter: 0
epoch: 7 iter: 0
epoch: 8 iter: 0
epoch: 9 iter: 0


In [None]:
plot(outG[0][:16].asnumpy())

In [None]:
#score = mod.score(val, ['acc'])
#print("Accuracy score is %f" % (score[0][1]))