In [1]:
# Numerics / plotting.
# NOTE(review): numpy and matplotlib are not referenced in the cells visible
# here -- confirm whether later cells need them.
import numpy
import matplotlib.pyplot as plt

import tensorflow as tf

# Short aliases for the TF 1.x contrib modules used by this notebook.
tfe = tf.contrib.eager  # NOTE(review): alias is not referenced in the visible cells
tfs = tf.contrib.summary
# Summary-recording context manager factory: called below as
# `with tfs_logger(log_frequency):` around each training step.
tfs_logger = tfs.record_summaries_every_n_global_steps

import sonnet as snt

# Progress bar; the notebook variant is bound to the plain `tqdm` name.
from tqdm import tqdm_notebook as tqdm

# Switch TensorFlow 1.x to eager execution for the rest of the session.
tf.enable_eager_execution()


For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
If you depend on functionality not listed there, please file an issue.



In [2]:
class Discriminator(snt.AbstractModule):
    """Convolutional critic that scores a batch of images with a single logit.

    The network is two 5x5 VALID convolutions (64 channels each, ReLU),
    followed by a flatten and a linear layer with one output unit. No
    sigmoid is applied, so the result is a raw logit per example.
    """

    def __init__(self, name='mnist_discriminator'):
        """Create the module.

        Args:
            name: Sonnet module name, used as the variable scope.
        """
        super(Discriminator, self).__init__(name=name)

    def _build(self, inputs):
        """Connect the discriminator to `inputs`.

        Args:
            inputs: image batch fed to the first convolution.
                Presumably shaped (batch, 28, 28, 1) -- TODO confirm
                against the caller.

        Returns:
            Tensor with one un-normalised logit per example (last
            dimension of size 1).
        """
        # Both convolutions share the same hyper-parameters.
        conv_settings = dict(output_channels=64,
                             kernel_shape=(5, 5),
                             stride=(1, 1),
                             padding='VALID')

        # Sub-modules are instantiated in a fixed order (conv, conv,
        # flatten, linear) so their auto-generated scope names are stable.
        first_conv = snt.Conv2D(**conv_settings)
        second_conv = snt.Conv2D(**conv_settings)
        to_vector = snt.BatchFlatten()
        logit_layer = snt.Linear(output_size=1)

        # Feature extractor: conv -> ReLU, twice.
        hidden = inputs
        for conv in (first_conv, second_conv):
            hidden = tf.nn.relu(conv(hidden))

        # Classifier head: flatten the feature maps and project to a logit.
        return logit_layer(to_vector(hidden))
    
    

class Generator(snt.AbstractModule):
    """Generator network that turns latent vectors into 28x28x1 images.

    Two 196-unit linear layers (leaky ReLU, then batch norm) are reshaped to
    a 14x14x1 map, upsampled to 28x28 by a stride-2 transposed convolution
    (32 channels, leaky ReLU, batch norm), and reduced to one channel by a
    final transposed convolution followed by a sigmoid.
    """

    def __init__(self, num_inputs, name='mnist_generator'):
        """Create the module.

        Args:
            num_inputs: size of the latent vector. It is stored on the
                instance but not read anywhere in this class.
            name: Sonnet module name, used as the variable scope.
        """
        super(Generator, self).__init__(name=name)

        # Batch norm is always run in training mode; nothing in this class
        # ever sets the flag to False.
        self._is_training = True
        self._num_inputs = num_inputs

    def _build(self, inputs):
        """Connect the generator to `inputs`.

        Args:
            inputs: batch of latent vectors fed to the first linear layer.

        Returns:
            Tensor of sigmoid activations from the last transposed
            convolution (spatial output shape 28x28, one channel).
        """
        # Sub-modules are instantiated in the order they are applied, so
        # their auto-generated scope names are stable.
        dense_a = snt.Linear(output_size=196)
        norm_a = snt.BatchNorm()

        dense_b = snt.Linear(output_size=196)
        norm_b = snt.BatchNorm()

        to_feature_map = snt.BatchReshape((14, 14, 1))

        upsample = snt.Conv2DTranspose(output_channels=32,
                                       output_shape=(28, 28),
                                       kernel_shape=(5, 5),
                                       stride=(2, 2),
                                       padding='SAME')
        norm_c = snt.BatchNorm()

        to_pixels = snt.Conv2DTranspose(output_channels=1,
                                        output_shape=(28, 28),
                                        kernel_shape=(5, 5),
                                        stride=(1, 1),
                                        padding='SAME')

        training = self._is_training

        # Fully connected stem: linear -> leaky ReLU -> batch norm, twice.
        hidden = inputs
        for dense, norm in ((dense_a, norm_a), (dense_b, norm_b)):
            hidden = norm(tf.nn.leaky_relu(dense(hidden)), is_training=training)

        # 196 units become a 14x14 single-channel map, then 28x28.
        hidden = to_feature_map(hidden)
        hidden = norm_c(tf.nn.leaky_relu(upsample(hidden)), is_training=training)

        # Sigmoid squashes the output pixels into (0, 1).
        return tf.nn.sigmoid(to_pixels(hidden))

In [3]:
def mnist_input_fn(data, batch_size, shuffle_buffer=5000):
    """Build the input pipeline: slice, preprocess, shuffle, batch.

    Args:
        data: array-like whose leading axis indexes individual examples.
        batch_size: number of examples per emitted batch.
        shuffle_buffer: size of the shuffle buffer.

    Returns:
        The resulting dataset. It is not repeated, and no remainder is
        dropped here, so the final batch may be smaller than `batch_size`.
    """
    pipeline = (
        tf.data.Dataset.from_tensor_slices(data)
        .map(mnist_process_data)   # per-example preprocessing
        .shuffle(shuffle_buffer)
        .batch(batch_size)
    )

    return pipeline

    
def mnist_process_data(image):
    """Convert one image to float32, divide by 255 and add a trailing axis.

    Args:
        image: a single image tensor. Presumably uint8 pixel values in
            0..255 -- TODO confirm against the data source.

    Returns:
        float32 tensor with the same leading dimensions as `image` plus a
        trailing axis of size 1.
    """
    scaled = tf.cast(image, tf.float32) / 255

    # Append the channel axis as the last dimension.
    return tf.expand_dims(scaled, axis=-1)

In [4]:
# Define constants
num_epochs = 50
num_discriminator_steps = 1  # discriminator updates before each generator update
log_frequency = 10           # passed to the summary logger (every n global steps)

learn_rate = 1e-4

num_inputs = 100             # size of the latent vector fed to the generator
batch_size = 256

# Seed TensorFlow's random ops (latent noise, dataset shuffling) so that a
# fresh "Restart & Run All" starts from the same random state every time.
random_seed = 42
tf.set_random_seed(random_seed)

# Load data (labels are discarded; only the images are used)
((train_data, _),
 (test_data, _)) = tf.keras.datasets.mnist.load_data()

train_dataset = mnist_input_fn(train_data, batch_size=batch_size)
# Number of *full* batches per epoch; the training loop takes exactly this
# many, so a smaller trailing batch is never consumed.
num_batches = len(train_data) // batch_size

# Create networks
disc = Discriminator()
gen = Generator(num_inputs=num_inputs)

# Define optimiser
optimizer = tf.train.AdamOptimizer(learn_rate)

# Define tensorflow summary stuff
global_step = tf.train.get_or_create_global_step()
log_dir = '/tmp/mnist_gan/log'
writer = tfs.create_file_writer(log_dir)
writer.set_as_default()

In [5]:
# Alternating GAN training: every batch performs exactly one update, either of
# the discriminator or of the generator.
#
# NOTE(review): the real `batch` is only used on discriminator steps; on
# generator steps it is drawn from the dataset and ignored.
# NOTE(review): one optimizer instance updates both networks -- confirm that
# sharing its state between them is intended.
for epoch in range(num_epochs):
    # Discriminator updates done since the last generator update
    # (reset at the start of every epoch).
    k = 0
    for batch in tqdm(train_dataset.take(num_batches), total=num_batches):

        # Increment global step
        global_step.assign_add(1)

        with tfs_logger(log_frequency):

            # Both kinds of step need one fresh batch of latent vectors.
            z = tf.random.normal(mean=0, stddev=1, shape=(batch_size, num_inputs))

            # Update discriminator
            if k < num_discriminator_steps:

                # Only discriminator gradients are taken in this branch, so
                # the generator forward pass runs outside the tape and is
                # not recorded.
                gen_imgs = gen(z)

                with tf.GradientTape() as tape:
                    gen_pred = disc(gen_imgs)
                    true_pred = disc(batch)

                    # Generated images are labelled 0, real images 1.
                    gen_loss = tf.nn.sigmoid_cross_entropy_with_logits(
                        labels=tf.zeros_like(gen_pred), logits=gen_pred)
                    true_loss = tf.nn.sigmoid_cross_entropy_with_logits(
                        labels=tf.ones_like(true_pred), logits=true_pred)

                    loss = tf.reduce_mean(gen_loss) + tf.reduce_mean(true_loss)

                # Fetched after the forward pass: the module must have been
                # connected before its variables can be listed.
                disc_vars = disc.get_all_variables()
                grads = tape.gradient(loss, disc_vars)
                optimizer.apply_gradients(zip(grads, disc_vars))

                tfs.scalar('Discriminator_Loss', loss)

                k += 1

            # Update generator
            else:

                with tf.GradientTape() as tape:
                    gen_imgs = gen(z)

                    gen_pred = disc(gen_imgs)

                    # The generator is rewarded when its images are scored
                    # as real (label 1).
                    gen_loss = tf.nn.sigmoid_cross_entropy_with_logits(
                        labels=tf.ones_like(gen_pred), logits=gen_pred)

                    loss = tf.reduce_mean(gen_loss)

                gen_vars = gen.get_all_variables()
                grads = tape.gradient(loss, gen_vars)
                optimizer.apply_gradients(zip(grads, gen_vars))

                tfs.scalar('Generator_Loss', loss)

                k = 0

            # Log the images produced in this step, whichever branch ran.
            tfs.image('Generated_images', gen_imgs)

HBox(children=(IntProgress(value=0, max=234), HTML(value='')))

Instructions for updating:
Colocations handled automatically by placer.



HBox(children=(IntProgress(value=0, max=234), HTML(value='')))




HBox(children=(IntProgress(value=0, max=234), HTML(value='')))




HBox(children=(IntProgress(value=0, max=234), HTML(value='')))




HBox(children=(IntProgress(value=0, max=234), HTML(value='')))

KeyboardInterrupt: 