# Generate dog images (64x64 RGB, read from ../input/all-dogs/) using a GAN

In [2]:
from keras import backend as K
from keras import layers as L
from keras.models import Model, Sequential
from keras.optimizers import Adam
from keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm_notebook

# Print the GPU devices reported by the TensorFlow backend.
# NOTE: `_get_available_gpus` is an underscore-prefixed (private) Keras helper.
print(K.tensorflow_backend._get_available_gpus())

['/job:localhost/replica:0/task:0/device:GPU:0']


In [4]:
# Image geometry constants: height, width and channel count, plus the
# combined (height, width, channels) shape tuple.
img_height, img_width, img_channels = 64, 64, 3
img_shape = (img_height, img_width, img_channels)

In [5]:
# Loss functions
from keras.applications.vgg16 import VGG16
from keras.models import Model

def perceptual_loss(y_true, y_pred):
    """Return the mean squared difference of VGG16 ``block3_conv3`` features.

    Both tensors are passed through a frozen, ImageNet-pretrained VGG16
    truncated at ``block3_conv3``; the loss is the mean of the squared
    difference between the two feature maps.

    Args:
        y_true: Reference image tensor.
        y_pred: Predicted image tensor.

    Returns:
        A scalar backend tensor.

    Note:
        A new VGG16 instance is built on every call.
    """
    # The module-level constant is `img_shape`; the previous `image_shape`
    # was undefined and raised NameError on first use.
    vgg = VGG16(include_top=False, weights='imagenet', input_shape=img_shape)
    loss_model = Model(inputs=vgg.input, outputs=vgg.get_layer('block3_conv3').output)
    # Freeze the feature extractor so only the model being trained is updated.
    loss_model.trainable = False
    # NOTE(review): VGG16 ImageNet weights presumably expect inputs prepared
    # with keras.applications.vgg16.preprocess_input - confirm the scaling of
    # the tensors passed in here.
    return K.mean(K.square(loss_model(y_true) - loss_model(y_pred)))


def wasserstein_loss(y_true, y_pred):
    """Return the mean of the element-wise product of labels and predictions."""
    signed_scores = y_true * y_pred
    return K.mean(signed_scores)

In [None]:
class GAN:
    """Fully-connected GAN trained on images read from ``../input/all-dogs/``.

    The generator maps a ``latent_dim``-dimensional noise vector to an image
    of shape ``img_shape`` (tanh output, i.e. values in [-1, 1]).  The
    discriminator flattens an image and outputs a single sigmoid score.
    ``combined`` stacks generator and (frozen) discriminator and is used to
    train the generator.
    """

    def __init__(self):
        self.img_rows = 64
        self.img_cols = 64
        self.channels = 3
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 200
        self.batch_size = 64

        # featurewise_center / featurewise_std_normalization are intentionally
        # not enabled: they need ImageDataGenerator.fit() (never called here)
        # to take effect, and train() rescales pixels to [-1, 1] itself.
        self.datagen = ImageDataGenerator(
            rotation_range=20,
            width_shift_range=0.2,
            height_shift_range=0.2,
            horizontal_flip=True,
            shear_range=0.2,
        ).flow_from_directory(
            '../input/all-dogs/',
            target_size=(self.img_rows, self.img_cols),
            batch_size=self.batch_size,
        )

        optimizer = Adam(lr=0.0002, beta_1=0.9, beta_2=0.9, epsilon=1e-08)

        # Build and compile the discriminator.
        # NOTE(review): wasserstein_loss is mean(y_true * y_pred); with the
        # 0/1 labels used in train() the fake batch contributes zero loss, and
        # the output layer is a sigmoid.  Consider +1/-1 labels with a linear
        # output, or 'binary_crossentropy', if training does not progress.
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss=wasserstein_loss,
                                   optimizer=optimizer,
                                   metrics=['accuracy'])

        # Build the generator
        self.generator = self.build_generator()

        # The generator takes noise as input and generates imgs
        z = L.Input(shape=(self.latent_dim,))
        img = self.generator(z)

        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # The discriminator takes generated images as input and determines validity
        validity = self.discriminator(img)

        # The combined model (stacked generator and discriminator) trains the
        # generator to fool the discriminator.  It exposes only `validity`
        # because train() supplies a single target (`valid`); the previous
        # two-output model with a perceptual loss had no image target to
        # compare against and could not be trained with that call.
        self.combined = Model(z, validity)
        self.combined.compile(loss=wasserstein_loss, optimizer=optimizer)

    def build_generator(self):
        """Build the generator: three Dense/LeakyReLU/BatchNorm stages, then a
        tanh Dense layer reshaped to ``self.img_shape``.

        Returns:
            A Keras ``Model`` mapping a ``(latent_dim,)`` input to an image.
        """
        model = Sequential()

        model.add(L.Dense(256, input_dim=self.latent_dim))
        model.add(L.LeakyReLU(alpha=0.2))
        model.add(L.BatchNormalization(momentum=0.8))
        for units in (512, 1024):
            model.add(L.Dense(units))
            model.add(L.LeakyReLU(alpha=0.2))
            model.add(L.BatchNormalization(momentum=0.8))
        model.add(L.Dense(np.prod(self.img_shape), activation='tanh'))
        model.add(L.Reshape(self.img_shape))

        model.summary()

        noise = L.Input(shape=(self.latent_dim,))
        img = model(noise)

        return Model(noise, img)

    def build_discriminator(self):
        """Build the discriminator: Flatten, two Dense/LeakyReLU stages, and a
        single sigmoid output unit.

        Returns:
            A Keras ``Model`` mapping an image of ``self.img_shape`` to a score.
        """
        model = Sequential()

        model.add(L.Flatten(input_shape=self.img_shape))
        for units in (512, 256):
            model.add(L.Dense(units))
            model.add(L.LeakyReLU(alpha=0.2))
        model.add(L.Dense(1, activation='sigmoid'))
        model.summary()

        img = L.Input(shape=self.img_shape)
        validity = model(img)

        return Model(img, validity)

    def train(self, epochs, batch_size=128, sample_interval=50):
        """Alternate discriminator and generator updates over the image iterator.

        Args:
            epochs: Number of passes over ``self.datagen``.
            batch_size: Kept for backward compatibility and ignored; the size
                of each step is the number of images the iterator returns for
                that batch.
            sample_interval: Call ``sample_images`` every this many epochs.
        """
        n_batches = len(self.datagen)

        for epoch in range(epochs):
            d_loss = g_loss = None

            for idx in tqdm_notebook(range(n_batches), total=n_batches):
                # Fetch the batch once: every index access loads and augments
                # the images again.
                imgs = self.datagen[idx][0]
                n_samples = len(imgs)

                # Adversarial ground truths, sized to the actual batch.
                valid = np.ones((n_samples, 1))
                fake = np.zeros((n_samples, 1))

                # ---------------------
                #  Train Discriminator
                # ---------------------

                # Scale pixels to [-1, 1] to match the generator's tanh output.
                imgs = (imgs - 127.5) / 127.5

                noise = np.random.normal(0, 1, (n_samples, self.latent_dim))

                # Generate a batch of new images
                gen_imgs = self.generator.predict(noise)

                # Train the discriminator
                d_loss_real = self.discriminator.train_on_batch(imgs, valid)
                d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
                d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

                # ---------------------
                #  Train Generator
                # ---------------------

                noise = np.random.normal(0, 1, (n_samples, self.latent_dim))

                # Train the generator (to have the discriminator label samples as valid)
                g_loss = self.combined.train_on_batch(noise, valid)

            # Report the losses of the last batch; skipped when the iterator
            # produced no batches so the names are never read unset.
            if d_loss is not None:
                print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

            # If at save interval => show generated image samples
            if epoch % sample_interval == 0:
                self.sample_images(epoch)

    def sample_images(self, epoch):
        """Generate a 2x2 grid of images from random noise and display it.

        Args:
            epoch: Current epoch number; accepted for the caller's convenience
                but not used (the figure is shown, not saved).
        """
        r, c = 2, 2
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        gen_imgs = self.generator.predict(noise)

        # Rescale images from the tanh range [-1, 1] to 0..255
        gen_imgs = ((gen_imgs+1)*127.5).astype(np.uint8)

        _, axs = plt.subplots(r, c)
        for ax, img in zip(axs.flat, gen_imgs):
            # Show every channel of colour images; only a single-channel image
            # is reduced to its one 2-D plane.
            ax.imshow(img[:, :, 0] if self.channels == 1 else img)
            ax.axis('off')
        plt.show()
        plt.close()

In [None]:
# Build the GAN (creates the image iterator, generator, discriminator and combined model).
gan = GAN()
# Pull one batch from the iterator and print its array shape.
imgs, _ = next(gan.datagen)
print(imgs.shape)
from keras.preprocessing import image
# Display the image at index 3 of that batch.
plt.imshow(image.array_to_img(imgs[3]))
plt.show()
# Train for 250 epochs, sampling generated images after every epoch.
gan.train(epochs=250, batch_size=32, sample_interval=1)