In [None]:
import numpy as np
from tensorflow import keras
from keras.layers import Conv2D, BatchNormalization, LeakyReLU, Dense, Input, Reshape, Conv2DTranspose, Flatten
from keras import backend as k
from keras.constraints import Constraint
from keras.initializers import RandomNormal
from keras.optimizers import RMSprop
import matplotlib.pyplot as plt
from keras.utils import plot_model

In [None]:
def load_data(n_samples=6000):
    """Load the first ``n_samples`` MNIST training images, scaled to [-1, 1].

    Labels and the test split are discarded; only training images are kept.

    Args:
        n_samples: Number of leading training images to keep. The default
            (6000) matches the subset this notebook always used; ``None``
            keeps every training image.

    Returns:
        float32 array of shape ``(n, H, W, 1)`` with values in ``[-1, 1]``,
        the same range as the generator's ``tanh`` output.
    """
    (X, _), (_, _) = keras.datasets.mnist.load_data()
    X = X[:n_samples]  # X[:None] keeps the whole split
    # Add a trailing channel axis so images match Conv2D's (H, W, C) layout.
    X = np.expand_dims(X, axis=-1)
    # float32 rather than float64: half the memory, and it is Keras' default dtype.
    X = X.astype('float32')
    return (X - 127.5) / 127.5


In [None]:
def get_real_samples(dataset, n_samples):
    """Draw ``n_samples`` rows of ``dataset`` at random (with replacement).

    Returns:
        Tuple ``(X, y)``: the sampled rows and a ``(n_samples, 1)`` float
        array filled with -1, the label this notebook uses for real images.
    """
    population = dataset.shape[0]
    picks = np.random.randint(0, population, n_samples)
    real_labels = np.full((n_samples, 1), -1.0)
    return dataset[picks], real_labels

In [None]:
def generate_latent_points(latent_dim, n_samples):
    """Sample ``n_samples`` latent vectors from a standard normal distribution.

    Returns:
        float array of shape ``(n_samples, latent_dim)``.
    """
    # randn(rows, cols) fills row-major, so this consumes the RNG exactly
    # like drawing a flat vector and reshaping it.
    return np.random.randn(n_samples, latent_dim)

In [None]:
def generate_fake_samples(generator, latent_dim, n_samples):
    """Generate ``n_samples`` images with ``generator`` and label them fake.

    Args:
        generator: Model mapping latent vectors of length ``latent_dim`` to images.
        latent_dim: Length of each latent vector.
        n_samples: Number of images to generate.

    Returns:
        Tuple ``(X, y)``: the generated images and a ``(n_samples, 1)`` array
        of ones (the "fake" label; ``get_real_samples`` labels real images -1).
    """
    latent_points = generate_latent_points(latent_dim, n_samples)
    # verbose=0: this runs several times per training step, and the default
    # per-call progress bar would flood the notebook output.
    X = generator.predict(latent_points, verbose=0)
    y = np.ones((n_samples, 1))
    return X, y

In [None]:
class ClipConstraint(Constraint):
    """Keras weight constraint that clips every weight to [-clip_value, clip_value].

    Used as ``kernel_constraint`` on the critic's layers to implement WGAN
    weight clipping.

    Args:
        clip_value: Half-width of the allowed interval.
    """

    def __init__(self, clip_value):
        self.clip_value = clip_value

    def __call__(self, weights):
        # Keras applies kernel constraints to the variable after each gradient update.
        return k.clip(weights, -self.clip_value, self.clip_value)

    def get_config(self):
        # Needed for Keras serialization: keys mirror __init__'s arguments so
        # the constraint can be rebuilt as ClipConstraint(**config).
        return {'clip_value': self.clip_value}

In [None]:
def wasserstein_loss(y_true, y_pred):
    """Wasserstein loss: the mean of label times critic score.

    With real samples labelled -1 and fake samples +1, minimizing this pushes
    the critic's scores up on real data and down on generated data.
    """
    signed_scores = y_pred * y_true
    return k.mean(signed_scores)

In [None]:
def build_generator(latent_dim: int):
    """Build the (uncompiled) MNIST generator.

    Dense projection to a 7x7x128 feature map, two stride-2 transposed
    convolutions with 'same' padding (7 -> 14 -> 28), then a 7x7 ``tanh``
    convolution down to a single channel.

    Args:
        latent_dim: Length of the latent input vector.

    Returns:
        Functional ``keras.Model`` named ``'mnist_generator'``.
    """
    weight_init = RandomNormal(stddev = 0.02)

    latent_in = Input(shape = (latent_dim,))

    x = Dense(128 * 7 * 7, kernel_initializer=weight_init)(latent_in)
    x = LeakyReLU(alpha=0.2)(x)
    x = Reshape((7, 7, 128))(x)

    # Two identical upsampling stages; each doubles height and width.
    for _ in range(2):
        x = Conv2DTranspose(128, (4, 4), (2, 2), padding='same', kernel_initializer=weight_init)(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(alpha=0.2)(x)

    image_out = Conv2D(1, (7, 7), activation='tanh', padding='same', kernel_initializer=weight_init)(x)

    return keras.Model(inputs=latent_in, outputs=image_out, name='mnist_generator')


In [None]:
def build_critic(image_shape = (28,28,1)):
    """Build and compile the WGAN critic.

    Two stride-2 blocks (Conv2D -> BatchNormalization -> LeakyReLU), then a
    flatten and a linear Dense(1) score. Every Conv2D/Dense kernel is clipped
    to [-0.01, 0.01] with ``ClipConstraint``.

    Args:
        image_shape: Shape of one input image, (H, W, C).

    Returns:
        ``keras.Model`` named ``'critic'``, compiled with ``wasserstein_loss``
        and RMSprop(learning_rate=0.0004).
    """
    # stddev must be passed by keyword: RandomNormal's first positional
    # parameter is `mean`, so RandomNormal(0.02) meant mean=0.02 with the
    # default stddev (unlike the generator's RandomNormal(stddev=0.02)).
    init = RandomNormal(stddev=0.02)
    const = ClipConstraint(0.01)

    inputs = Input(shape=image_shape)

    conv_0 = Conv2D(64, (4,4), (2,2), padding='same', kernel_initializer=init, kernel_constraint=const)(inputs)
    norm_0 = BatchNormalization()(conv_0)
    relu_0 = LeakyReLU(alpha=0.2)(norm_0)

    conv_1 = Conv2D(64, (4,4), (2,2), padding='same', kernel_initializer=init, kernel_constraint=const)(relu_0)
    norm_1 = BatchNormalization()(conv_1)
    relu_1 = LeakyReLU(alpha=0.2)(norm_1)

    flatten = Flatten()(relu_1)
    # Linear score (no activation). WGAN weight clipping is meant to bound
    # every critic weight, so the output kernel is clipped too.
    dense = Dense(1, kernel_constraint=const)(flatten)

    model = keras.Model(inputs=inputs, outputs=dense, name='critic')
    model.compile(loss=wasserstein_loss, optimizer=RMSprop(learning_rate=0.0004))
    return model

In [None]:
def build_gan(generator, critic):
    """Stack ``generator`` -> ``critic`` into the model used for generator updates.

    Before stacking, every critic layer except BatchNormalization is marked
    non-trainable (this mutates ``critic`` in place). The stacked model is
    compiled with ``wasserstein_loss`` and RMSprop(learning_rate=0.0005).

    NOTE(review): ``critic`` was compiled before these ``trainable`` flags
    change; confirm that the installed Keras version still updates those
    layers in ``critic.train_on_batch``.
    """
    frozen_layers = [
        layer for layer in critic.layers
        if not isinstance(layer, BatchNormalization)
    ]
    for layer in frozen_layers:
        layer.trainable = False

    combined = keras.Sequential([generator, critic])
    combined.compile(loss=wasserstein_loss, optimizer=RMSprop(learning_rate=0.0005))
    return combined

In [None]:
def show_performance(epoch, generator, latent_dim, n_samples = 36):
    """Save a grid of generated images and a snapshot of the generator.

    Writes ``generated_epoch_{epoch}.png`` and ``generator_{epoch}.h5`` to the
    current working directory.

    Args:
        epoch: Index used in the output filenames.
        generator: Model to sample from and save.
        latent_dim: Length of the generator's latent input.
        n_samples: Number of images to plot; the grid is sized to fit them.
    """
    X, _ = generate_fake_samples(generator, latent_dim, n_samples)
    X = (X + 1) / 2.0  # tanh output [-1, 1] -> [0, 1] for imshow

    # Smallest square grid that holds n_samples; a fixed 10x10 grid would
    # fail for more than 100 samples and leave most cells empty for fewer.
    n_side = max(1, int(np.ceil(np.sqrt(n_samples))))
    fig = plt.figure()
    for i in range(n_samples):
        ax = fig.add_subplot(n_side, n_side, i + 1)
        ax.axis('off')
        ax.imshow(X[i, :, :, 0], cmap='gray_r')
    fig.savefig(f'generated_epoch_{epoch}.png')
    plt.close(fig)

    generator.save(f'generator_{epoch}.h5')

In [None]:
def plot_history(critic_loss_real, critic_loss_fake, gan_loss):
    """Plot the loss histories and save them to ``plot_line_plot_loss.png``.

    ``train`` appends one value per training step to each history, so the
    x-axis is in steps, not epochs.

    Args:
        critic_loss_real: Per-step critic losses on real batches.
        critic_loss_fake: Per-step critic losses on generated batches.
        gan_loss: Per-step generator (combined-model) losses.
    """
    fig, ax = plt.subplots()
    ax.plot(critic_loss_real, label='Discriminator Real')
    ax.plot(critic_loss_fake, label='Discriminator Fake')
    ax.plot(gan_loss, label='Generator')

    ax.set_xlabel('Training step')
    ax.set_ylabel('Loss')
    ax.set_title('GAN Training History')
    ax.legend()

    fig.savefig('plot_line_plot_loss.png')
    plt.close(fig)

In [None]:
def train(generator, critic, gan, dataset, latent_dim, epochs, batch_size, n_critics):
    """Train the WGAN, alternating ``n_critics`` critic updates with one generator update.

    An "epoch" is ``dataset.shape[0] // batch_size`` steps (at least one).
    Real batches are drawn with replacement, so an epoch is a step count, not
    a guaranteed pass over every image. At the end of each epoch this calls
    ``show_performance`` (sample grid + generator checkpoint) and
    ``plot_history`` (loss curves so far).

    Args:
        generator: Generator model; also the first stage of ``gan``.
        critic: Compiled critic model.
        gan: Compiled generator -> critic model used for generator updates.
        dataset: Array of real training images.
        latent_dim: Length of the generator's latent input.
        epochs: Number of epochs to run.
        batch_size: Batch size for both critic and generator updates.
        n_critics: Critic updates per generator update.
    """
    # At least one step per epoch: with a dataset smaller than batch_size the
    # integer division gives 0 steps and would silently train nothing.
    batch_per_epoch = max(1, dataset.shape[0] // batch_size)
    n_steps = batch_per_epoch * epochs
    critic_loss_real, critic_loss_fake, gan_loss_hist = [], [], []
    epoch = 0

    for step in range(n_steps):
        critic_temp_real, critic_temp_fake = [], []
        for _ in range(n_critics):
            X_real, y_real = get_real_samples(dataset, batch_size)
            critic_temp_real.append(critic.train_on_batch(X_real, y_real))

            X_fake, y_fake = generate_fake_samples(generator, latent_dim, batch_size)
            critic_temp_fake.append(critic.train_on_batch(X_fake, y_fake))

        # Average over all n_critics updates of this step, not just the last one.
        critic_loss_real.append(np.mean(critic_temp_real))
        critic_loss_fake.append(np.mean(critic_temp_fake))

        # Generator update: label generated images with the "real" label (-1)
        # so the combined model pushes the critic's score toward real.
        latent_points = generate_latent_points(latent_dim, batch_size)
        y_gan = -np.ones((batch_size, 1))
        gan_loss_hist.append(gan.train_on_batch(latent_points, y_gan))

        if (step + 1) % batch_per_epoch == 0:
            show_performance(epoch, generator, latent_dim, batch_size)
            # Plot the whole generator-loss history, not just the latest scalar.
            plot_history(critic_loss_real, critic_loss_fake, gan_loss_hist)
            epoch += 1
            if epoch < epochs:
                print(f'\nstarting epoch {epoch}\n')


In [None]:
LATENT_DIM = 50
BATCH_SIZE = 64
EPOCHS = 30
N_CRITICS = 5
SEED = 42

# Seed Python, NumPy and the backend RNGs before building the models, so
# weight initialisation and sampling are repeatable across Restart & Run All
# (GPU op-level nondeterminism aside).
keras.utils.set_random_seed(SEED)

critic = build_critic()
generator = build_generator(LATENT_DIM)
gan = build_gan(generator, critic)

dataset = load_data()


In [None]:
# Render the critic's layer graph, with tensor shapes, to critic.png.
plot_model(critic, to_file='critic.png', show_shapes=True)

In [None]:
# Render the generator's layer graph, with tensor shapes, to generator.png.
plot_model(generator, to_file='generator.png', show_shapes=True)

In [None]:
# Epoch count comes from the config cell's EPOCHS constant instead of a hard-coded literal.
train(generator, critic, gan, dataset, LATENT_DIM, EPOCHS, BATCH_SIZE, N_CRITICS)