# Generative Adversarial Networks in Keras

In [6]:
%matplotlib inline
import importlib
import utils; importlib.reload(utils)
from utils import *

from tqdm import tqdm

# Dataset

In [25]:
from keras.datasets import mnist
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train.shape

Downloading data from https://s3.amazonaws.com/img-datasets/mnist.npz


(60000, 28, 28)

In [29]:
X_train = X_train.reshape(-1, 28, 28, 1).astype(np.float32)
X_test = X_test.reshape(-1, 28, 28, 1).astype(np.float32)

In [30]:
X_train /= 255; X_test /= 255

In [31]:
X_train.shape, X_test.shape

((60000, 28, 28, 1), (10000, 28, 28, 1))

# Train

In [59]:
nz = 100
nc = 1

In [20]:
leak = 0.2
batch_size = 1

In [72]:
def noise(bs): return np.random.rand(bs,100)

In [73]:
def data_D(bs, G):
    real_img = X_train[np.random.randint(0,len(X_train),size=bs)]
    fake_img = G.predict(noise(bs))
    X = np.concatenate((real_img, fake_img))
    return X, [0]*bs + [1]*bs

In [74]:
def make_trainable(net, val):
    net.trainable = val
    for l in net.layers: l.trainable = val

In [135]:
def train(D, G, m, nb_epoch=5000, bs=64):
    dl,gl=[],[]
    for e in tqdm(range(nb_epoch)):
        X,y = data_D(bs//2, G)
        dl.append(D.train_on_batch(X,y))
        make_trainable(D, False)
        gl.append(m.train_on_batch(noise(bs), np.zeros([bs])))
        make_trainable(D, True)
    return dl,gl

# GAN

In [144]:
def generator(nz=100, nc=3, leak=0.1):
    return Sequential([    
        Dense(7*7*512, input_dim=nz),
        Reshape((7, 7, 512)),
        BatchNormalization(),
        LeakyReLU(leak),
    
        Conv2DTranspose(256, 5, strides=1, padding='same'),
        BatchNormalization(),
        LeakyReLU(leak),
    
        Conv2DTranspose(128, 5, strides=1, padding='same'),
        BatchNormalization(),
        LeakyReLU(leak),
    
        Conv2DTranspose(64, 5, strides=2, padding='same'),
        BatchNormalization(),
        LeakyReLU(leak),
    
        Conv2DTranspose(nc, 5, strides=2, padding='same'),
        Activation('tanh'),
    ])

In [145]:
def discriminator(nc=3, leak=0.1):
    return Sequential([
        Convolution2D(64, 5, strides=2, padding='same', input_shape=(28,28,nc)),
        LeakyReLU(leak),
    
        Convolution2D(128, 5, strides=2, padding='same'),
        BatchNormalization(),
        LeakyReLU(leak),

        Convolution2D(256, 5, strides=2, padding='same'),
        BatchNormalization(),
        LeakyReLU(leak),
    
        Reshape((4*4*256,)),
        Dense(1)
    ])

In [146]:
G = generator(nc=nc)
D = discriminator(nc=nc)

In [147]:
D.compile(Adam(1e-3), "binary_crossentropy")

In [149]:
sz = len(X_train)//200
x1 = np.concatenate([np.random.permutation(X_train)[:sz], G.predict(noise(sz))])
D.fit(x1, [0]*sz + [1]*sz, batch_size=128, epochs=1, verbose=2)

Epoch 1/1
3s - loss: 0.0296


<keras.callbacks.History at 0x7f96473a6da0>

In [150]:
m = Sequential([G, D])
m.compile(Adam(1e-4), "binary_crossentropy")

In [151]:
K.set_value(D.optimizer.lr, 1e-3)
K.set_value(m.optimizer.lr, 1e-3)

In [152]:
dl,gl = train(D, G, m, 1, 2)

100%|██████████| 1/1 [00:03<00:00,  3.77s/it]
