In [22]:
import numpy as np
from tensorflow.keras.optimizers import Adam, RMSprop
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Reshape, Flatten, Conv2D, Conv2DTranspose, LeakyReLU, Dropout, BatchNormalization, ReLU
from matplotlib import pyplot as plt
import tensorflow as tf
##WGAN-additional
from keras import backend
from keras.constraints import Constraint


In [28]:
def loaddata(datasettype):
    if datasettype == 'mnist':
        (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
        x_train = (x_train[(y_train == 5) | (y_train == 7),]- 127.5) / 127.5
        x_train = np.expand_dims(x_train, axis=3)
        y_train = y_train[(y_train == 5) | (y_train == 7)]
        y_train = np.where(y_train == 5, 0, 1).reshape((-1,1))

        x_test = (x_test[(y_test == 5) | (y_test == 7),]- 127.5) / 127.5
        x_test = np.expand_dims(x_test, axis=3)
        y_test = y_test[(y_test == 5) | (y_test == 7)]
        y_test = np.where(y_test == 5, 0, 1).reshape((-1,1))

    elif datasettype == 'cifar10':
        (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
        y_train = np.squeeze(y_train)
        y_test = np.squeeze(y_test)
        x_train = (x_train[(y_train == 1) | (y_train == 7),]- 127.5) / 127.5
        y_train = y_train[(y_train == 1) | (y_train == 7)]
        y_train = np.where(y_train == 1, 0, 1).reshape((-1,1))

        x_test = (x_test[(y_test == 1) | (y_test == 7),]- 127.5) / 127.5
        y_test = y_test[(y_test == 1) | (y_test == 7)]
        y_test = np.where(y_test == 1, 0, 1).reshape((-1,1))

    return x_train, y_train, x_test, y_test

### WGAN - Wasserstein loss
def wasserstein_loss(y_true, y_pred):
    return backend.mean(y_true * y_pred)

### WGAN - Critic Weight Clipping
# clip model weights to a given hypercube
class ClipConstraint(Constraint):
    # set clip value when initialized
    def __init__(self, clip_value):
        self.clip_value = clip_value
 
    # clip model weights to hypercube
    def __call__(self, weights):
        return backend.clip(weights, -self.clip_value, self.clip_value)
 
    # get the config
    def get_config(self):
        return {'clip_value': self.clip_value}

def makediscriminator(input_shape):
    # define weight clipping constraint
    const = ClipConstraint(0.01)
    
    model = Sequential()
    ### WGAN-add clipping kernel_contraint
    model.add(Conv2D(64, (5,5), strides=(2, 2), padding='same', input_shape = input_shape, kernel_constraint=const))
    model.add(LeakyReLU())
    model.add(Dropout(0.3))
    ### WGAN-add clipping kernel_contraint
    model.add(Conv2D(128, (5,5), strides=(2, 2), padding='same', kernel_constraint=const))
    model.add(LeakyReLU(alpha=0.3))
    model.add(Dropout(0.3))

    model.add(Flatten())
    ### WGAN-Change to linear activation
    model.add(Dense(1, activation='linear'))
    ###WGAN-Use RMSPROP
    opt = RMSprop(lr=0.00005)
    ### WGAN-Wasserstein Loss
    model.compile(loss=wasserstein_loss, optimizer=opt, metrics=['accuracy'])
    return model

def makegenerator(datasettype, latent_dim):
    if datasettype == 'mnist':
        channels = 1
        finalsize = 28
    elif datasettype == 'cifar10':
        channels = 3
        finalsize = 32

    model = Sequential()
    if datasettype == 'mnist':

        model.add(Dense(7*7*256, use_bias=False, input_shape=(100,)))
        model.add(BatchNormalization())
        model.add(LeakyReLU())

        model.add(Reshape((7, 7, 256)))

        model.add(Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
        model.add(BatchNormalization())
        model.add(LeakyReLU())

        model.add(Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
        model.add(BatchNormalization())
        model.add(LeakyReLU())

        model.add(Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))

    elif datasettype == 'cifar10':
        n_nodes = 256 * 4 * 4
        model.add(Dense(n_nodes, input_dim=latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Reshape((4, 4, 256)))
        # upsample to 8x8
        model.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same'))
        model.add(LeakyReLU(alpha=0.2))
        # upsample to 16x16
        model.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same'))
        model.add(LeakyReLU(alpha=0.2))
        # upsample to 32x32
        model.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same'))
        model.add(LeakyReLU(alpha=0.2))
        # output layer
        model.add(Conv2D(3, (3,3), activation='tanh', padding='same'))
    return model

def makegan(g_model, d_model):
    # set discriminator to not trainable
    d_model.trainable = False
    # setup gan
    model = Sequential()
    model.add(g_model)
    model.add(d_model)
    ###WGAN-Use RMSPROP
    opt = RMSprop(lr=0.00005)
    model.compile(loss=wasserstein_loss, optimizer=opt)
    return model

def gen_real(dataset, n_samples):
    # generate random indices for subsampling
    idx = np.random.randint(0, dataset.shape[0], n_samples)
    x = dataset[idx]
    ### for WGAN generate class labels == -1 for real
    y = -np.ones((n_samples, 1))
    return x, y

def gen_fake(g_model, latent_dim, n_samples):
    # generate points in latent space
    x_input = get_latent(latent_dim, n_samples)
    # predict outputs
    X = g_model.predict(x_input)
    ### for WGAN generate class labels == 1 for fake
    y = np.ones((n_samples, 1))
    return X, y

def get_latent(latent_dim, n_samples):
    # generate latent input points
    x_input = np.random.randn(latent_dim * n_samples)
    # reshape to fit network
    x_input = x_input.reshape(n_samples, latent_dim)
    return x_input

def save_plot(savepath, examples, epoch, n=4):
    # plot images
    print(examples.shape)
    print(np.min(examples))
    print(np.max(examples))
    for i in range(n * n):
        # define subplot
        plt.subplot(n, n, 1 + i)
        # turn off axis
        plt.axis('off')
        # plot raw pixel data
        plt.imshow(examples[i].astype('uint8'))
    # save plot to file
    filename = 'Generated_plot_e%03d.png' % (epoch+1)
    plt.savefig(savepath + filename)
    plt.close()

def performance(savepath, epoch, g_model, d_model, dataset, latent_dim, n_samples=100):
    # prepare real samples
    X_real, y_real = gen_real(dataset, n_samples)
    # evaluate discriminator on real examples
    _, acc_real = d_model.evaluate(X_real, y_real, verbose=0)
    # prepare fake examples
    x_fake, y_fake = gen_fake(g_model, latent_dim, n_samples)
    # evaluate discriminator on fake examples
    _, acc_fake = d_model.evaluate(x_fake, y_fake, verbose=0)
    # summarize discriminator performance
    print('Performance test> Accuracy real: %.0f%%, fake: %.0f%%' % (acc_real*100, acc_fake*100))
    # save plot
    fakeplot = (x_fake * 127.5) + 127.5

    save_plot(savepath, fakeplot, epoch)
    # save the generator model tile file
    g_filename = 'generator_model_%03d.h5' % (epoch + 1)
    g_model.save(savepath + g_filename)
    d_filename = 'discriminator_model_%03d.h5' % (epoch + 1)
    d_model.save(savepath + d_filename)

def train(datasettype, savepath, x_train, latent_dim, n_critic=5, n_epochs=200, batchsize=256, retries = 5, max_loss_increase_epochs = 10):
    batch_per_epoch = int(x_train.shape[0] / batchsize)
    half_batch = int(batchsize / 2)
    # manually enumerate epochs
    
    for r in range(retries):
        print(f'Attempt:{r+1}')
        d_model = makediscriminator(x_train.shape[1:])
        g_model = makegenerator(datasettype, latent_dim)
        gan_model = makegan(g_model, d_model)

        d_loss_real = []
        d_loss_fake = []
        d_acc_real = []
        d_acc_fake = []

        g_loss = []
        g_loss_epoch = []
        g_loss_inc_counter = 0
        for i in range(n_epochs):
            print(f'Epoch: {i+1}')
            # enumerate batches over the training set
            for j in range(batch_per_epoch):
                
                ### WGAN-update critic more times than generator
                for _ in range(n_critic):
                    # get randomly selected 'real' samples
                    X_real, y_real = gen_real(x_train, half_batch)
                    # generate 'fake' examples
                    X_fake, y_fake = gen_fake(g_model, latent_dim, half_batch)
                    # update discriminator model weights
                    d_l_real, d_a_real = d_model.train_on_batch(X_real, y_real)
                    d_l_fake, d_a_fake = d_model.train_on_batch(X_fake, y_fake)
                # prepare points in latent space as input for the generator
                X_gan = get_latent(latent_dim, batchsize)
                ###WGAN- minust labels
                y_gan = -np.ones((batchsize, 1))
                # update the generator via the discriminator's error
                g_l = gan_model.train_on_batch(X_gan, y_gan)
                # summarize loss on this batch
                d_loss_real.append(d_l_real)
                d_loss_fake.append(d_l_fake)
                d_acc_real.append(d_a_real)
                d_acc_fake.append(d_a_fake)
                g_loss.append(g_l)
            print('Try:%d, Epoch:%d, D_Loss_Real=%.3f, D_Loss_Fake=%.3f, D_Acc_Real=%.3f, D_Acc_Fake=%.3f, GAN_Loss=%.3f\n' % (r+1, i+1, d_l_real, d_l_fake, d_a_real, d_a_fake, g_l))
            g_loss_epoch.append(g_l)
            if len(g_loss_epoch) >= 2:
                if g_loss_epoch[-1] >  g_loss_epoch[-2]:
                    g_loss_inc_counter += 1
                    print(f'Generator loss increased for {g_loss_inc_counter} epochs.')
                elif g_loss_epoch[-1] <  g_loss_epoch[-2]:
                    g_loss_inc_counter = 0

            if g_loss_inc_counter == max_loss_increase_epochs:
                print(f'Try: {r+1} failed to train. Restarting training')
                break
    #         evaluate the model performance, sometimes
            if (i+1) % 20 == 0:
                performance(savepath, i, g_model, d_model, x_train, latent_dim)

            if i == n_epochs - 1:
                return g_model, d_model, gan_model, d_loss_real, d_loss_fake, d_acc_real, d_acc_fake, g_loss

    # reach here if fail to train
    print("Training Stopped...")
    return g_model, d_model, gan_model, d_loss_real, d_loss_fake, d_acc_real, d_acc_fake, g_loss


In [24]:
x_train, y_train, x_test, y_test = loaddata('mnist')

In [25]:
savepath = './testwgan/'

In [30]:
g_model, d_model, gan_model, d_loss_real, d_loss_fake, d_acc_real, d_acc_fake, g_loss = \
train('mnist', savepath, x_train, latent_dim=100, n_critic=5, n_epochs=200, batchsize=256, retries = 5, max_loss_increase_epochs = 10000)

Attempt:1
Epoch: 1
Try:1, Epoch:1, D_Loss_Real=-87.470, D_Loss_Fake=9.400, D_Acc_Real=0.000, D_Acc_Fake=1.000, GAN_Loss=-83.938

Epoch: 2
Try:1, Epoch:2, D_Loss_Real=-236.885, D_Loss_Fake=63.586, D_Acc_Real=0.000, D_Acc_Fake=1.000, GAN_Loss=-275.404

Epoch: 3
Try:1, Epoch:3, D_Loss_Real=-301.134, D_Loss_Fake=208.418, D_Acc_Real=0.000, D_Acc_Fake=1.000, GAN_Loss=-365.851

Epoch: 4
Try:1, Epoch:4, D_Loss_Real=-294.337, D_Loss_Fake=297.251, D_Acc_Real=0.000, D_Acc_Fake=1.000, GAN_Loss=-355.912

Generator loss increased for 1 epochs.
Epoch: 5
Try:1, Epoch:5, D_Loss_Real=-218.614, D_Loss_Fake=253.253, D_Acc_Real=0.000, D_Acc_Fake=1.000, GAN_Loss=-267.006

Generator loss increased for 2 epochs.
Epoch: 6
Try:1, Epoch:6, D_Loss_Real=-129.149, D_Loss_Fake=154.281, D_Acc_Real=0.000, D_Acc_Fake=1.000, GAN_Loss=-154.167

Generator loss increased for 3 epochs.
Epoch: 7
Try:1, Epoch:7, D_Loss_Real=-62.671, D_Loss_Fake=71.844, D_Acc_Real=0.000, D_Acc_Fake=1.000, GAN_Loss=-70.409

Generator loss incre

Try:1, Epoch:53, D_Loss_Real=-1.253, D_Loss_Fake=1.231, D_Acc_Real=0.000, D_Acc_Fake=1.000, GAN_Loss=-1.192

Generator loss increased for 4 epochs.
Epoch: 54
Try:1, Epoch:54, D_Loss_Real=-1.222, D_Loss_Fake=1.219, D_Acc_Real=0.000, D_Acc_Fake=1.000, GAN_Loss=-1.179

Generator loss increased for 5 epochs.
Epoch: 55
Try:1, Epoch:55, D_Loss_Real=-1.244, D_Loss_Fake=1.174, D_Acc_Real=0.000, D_Acc_Fake=1.000, GAN_Loss=-1.189

Epoch: 56
Try:1, Epoch:56, D_Loss_Real=-1.200, D_Loss_Fake=1.184, D_Acc_Real=0.000, D_Acc_Fake=1.000, GAN_Loss=-1.124

Generator loss increased for 1 epochs.
Epoch: 57
Try:1, Epoch:57, D_Loss_Real=-1.175, D_Loss_Fake=1.194, D_Acc_Real=0.000, D_Acc_Fake=1.000, GAN_Loss=-1.092

Generator loss increased for 2 epochs.
Epoch: 58
Try:1, Epoch:58, D_Loss_Real=-1.160, D_Loss_Fake=1.153, D_Acc_Real=0.000, D_Acc_Fake=1.000, GAN_Loss=-1.075

Generator loss increased for 3 epochs.
Epoch: 59
Try:1, Epoch:59, D_Loss_Real=-1.145, D_Loss_Fake=1.130, D_Acc_Real=0.000, D_Acc_Fake=1.000,

Try:1, Epoch:105, D_Loss_Real=-0.687, D_Loss_Fake=0.683, D_Acc_Real=0.000, D_Acc_Fake=0.977, GAN_Loss=-0.641

Generator loss increased for 1 epochs.
Epoch: 106
Try:1, Epoch:106, D_Loss_Real=-0.675, D_Loss_Fake=0.665, D_Acc_Real=0.000, D_Acc_Fake=0.961, GAN_Loss=-0.632

Generator loss increased for 2 epochs.
Epoch: 107
Try:1, Epoch:107, D_Loss_Real=-0.680, D_Loss_Fake=0.670, D_Acc_Real=0.000, D_Acc_Fake=0.953, GAN_Loss=-0.634

Epoch: 108
Try:1, Epoch:108, D_Loss_Real=-0.685, D_Loss_Fake=0.655, D_Acc_Real=0.000, D_Acc_Fake=0.938, GAN_Loss=-0.611

Generator loss increased for 1 epochs.
Epoch: 109
Try:1, Epoch:109, D_Loss_Real=-0.689, D_Loss_Fake=0.648, D_Acc_Real=0.000, D_Acc_Fake=0.875, GAN_Loss=-0.600

Generator loss increased for 2 epochs.
Epoch: 110
Try:1, Epoch:110, D_Loss_Real=-0.673, D_Loss_Fake=0.670, D_Acc_Real=0.000, D_Acc_Fake=0.961, GAN_Loss=-0.630

Epoch: 111
Try:1, Epoch:111, D_Loss_Real=-0.662, D_Loss_Fake=0.672, D_Acc_Real=0.000, D_Acc_Fake=0.992, GAN_Loss=-0.621

Generato

Try:1, Epoch:158, D_Loss_Real=-0.527, D_Loss_Fake=0.524, D_Acc_Real=0.000, D_Acc_Fake=0.641, GAN_Loss=-0.485

Generator loss increased for 3 epochs.
Epoch: 159
Try:1, Epoch:159, D_Loss_Real=-0.542, D_Loss_Fake=0.523, D_Acc_Real=0.000, D_Acc_Fake=0.633, GAN_Loss=-0.501

Epoch: 160
Try:1, Epoch:160, D_Loss_Real=-0.526, D_Loss_Fake=0.521, D_Acc_Real=0.000, D_Acc_Fake=0.594, GAN_Loss=-0.492

Generator loss increased for 1 epochs.
Performance test> Accuracy real: 0%, fake: 38%
(100, 28, 28, 1)
0.0
254.99977
Epoch: 161
Try:1, Epoch:161, D_Loss_Real=-0.545, D_Loss_Fake=0.535, D_Acc_Real=0.000, D_Acc_Fake=0.680, GAN_Loss=-0.477

Generator loss increased for 2 epochs.
Epoch: 162
Try:1, Epoch:162, D_Loss_Real=-0.519, D_Loss_Fake=0.520, D_Acc_Real=0.000, D_Acc_Fake=0.594, GAN_Loss=-0.489

Epoch: 163
Try:1, Epoch:163, D_Loss_Real=-0.520, D_Loss_Fake=0.515, D_Acc_Real=0.000, D_Acc_Fake=0.547, GAN_Loss=-0.486

Generator loss increased for 1 epochs.
Epoch: 164
Try:1, Epoch:164, D_Loss_Real=-0.530, D_