In [1]:
import numpy as np
from tensorflow.keras.optimizers import Adam, RMSprop
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Reshape, Flatten, Conv2D, Conv2DTranspose, LeakyReLU, Dropout, BatchNormalization, ReLU
from matplotlib import pyplot as plt
import tensorflow as tf
##WGAN-additional
from keras import backend
from keras.constraints import Constraint
from tqdm import tqdm


In [35]:
def loaddata(datasettype):
    if datasettype == 'mnist':
        (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
        x_train = (x_train[(y_train == 5) | (y_train == 7),]- 127.5) / 127.5
        x_train = np.expand_dims(x_train, axis=3)
        y_train = y_train[(y_train == 5) | (y_train == 7)]
        y_train = np.where(y_train == 5, 0, 1).reshape((-1,1))

        x_test = (x_test[(y_test == 5) | (y_test == 7),]- 127.5) / 127.5
        x_test = np.expand_dims(x_test, axis=3)
        y_test = y_test[(y_test == 5) | (y_test == 7)]
        y_test = np.where(y_test == 5, 0, 1).reshape((-1,1))

    elif datasettype == 'cifar10':
        (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
        y_train = np.squeeze(y_train)
        y_test = np.squeeze(y_test)
        x_train = (x_train[(y_train == 1) | (y_train == 7),]- 127.5) / 127.5
        y_train = y_train[(y_train == 1) | (y_train == 7)]
        y_train = np.where(y_train == 1, 0, 1).reshape((-1,1))

        x_test = (x_test[(y_test == 1) | (y_test == 7),]- 127.5) / 127.5
        y_test = y_test[(y_test == 1) | (y_test == 7)]
        y_test = np.where(y_test == 1, 0, 1).reshape((-1,1))

    return x_train, y_train, x_test, y_test

def makediscriminator(input_shape):
    model = Sequential()

    model.add(Conv2D(64, (5,5), strides=(2, 2), padding='same', input_shape = input_shape))
    model.add(LeakyReLU())
    model.add(Dropout(0.3))

    model.add(Conv2D(128, (5,5), strides=(2, 2), padding='same'))
    model.add(LeakyReLU(alpha=0.3))
    model.add(Dropout(0.3))

    model.add(Flatten())
    ### WGAN-Change to linear activation
    model.add(Dense(1, activation='linear'))
    ###WGAN-Use RMSPROP
    # opt = RMSprop(lr=0.00005)
    ### WGAN-Wasserstein Loss
    # model.compile(loss=wasserstein_loss, optimizer=opt, metrics=['accuracy'])
    return model

def makegenerator(datasettype, latent_dim):
    if datasettype == 'mnist':
        channels = 1
        finalsize = 28
    elif datasettype == 'cifar10':
        channels = 3
        finalsize = 32

    model = Sequential()
    if datasettype == 'mnist':

        model.add(Dense(7*7*256, use_bias=False, input_shape=(100,)))
        model.add(BatchNormalization())
        model.add(LeakyReLU())

        model.add(Reshape((7, 7, 256)))

        model.add(Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
        model.add(BatchNormalization())
        model.add(LeakyReLU())

        model.add(Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
        model.add(BatchNormalization())
        model.add(LeakyReLU())

        model.add(Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))

    elif datasettype == 'cifar10':
        n_nodes = 256 * 4 * 4
        model.add(Dense(n_nodes, input_dim=latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Reshape((4, 4, 256)))
        # upsample to 8x8
        model.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same'))
        model.add(LeakyReLU(alpha=0.2))
        # upsample to 16x16
        model.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same'))
        model.add(LeakyReLU(alpha=0.2))
        # upsample to 32x32
        model.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same'))
        model.add(LeakyReLU(alpha=0.2))
        # output layer
        model.add(Conv2D(3, (3,3), activation='tanh', padding='same'))
    return model

def gen_real(dataset, n_samples):
    # generate random indices for subsampling
    idx = np.random.randint(0, dataset.shape[0], n_samples)
    x = dataset[idx]
    ### for WGAN generate class labels == -1 for real
    y = -np.ones((n_samples, 1))
    return x, y

def gen_fake(g_model, latent_dim, n_samples):
    # generate points in latent space
    x_input = get_latent(latent_dim, n_samples)
    # predict outputs
    X = g_model.predict(x_input)
    ### for WGAN generate class labels == 1 for fake
    y = np.ones((n_samples, 1))
    return X, y

def get_latent(latent_dim, n_samples):
    # generate latent input points
    x_input = np.random.randn(latent_dim * n_samples)
    # reshape to fit network
    x_input = x_input.reshape(n_samples, latent_dim)
    return x_input

def save_plot(savepath, examples, epoch, n=4):
    # plot images
    print(examples.shape)
    print(np.min(examples))
    print(np.max(examples))
    for i in range(n * n):
        # define subplot
        plt.subplot(n, n, 1 + i)
        # turn off axis
        plt.axis('off')
        # plot raw pixel data
        plt.imshow(examples[i].astype('uint8'))
    # save plot to file
    filename = 'Generated_plot_e%03d.png' % (epoch+1)
    plt.savefig(savepath + filename)
    plt.close()

def performance(savepath, epoch, g_model, d_model, dataset, latent_dim, n_samples=100):
    # prepare real samples
    # X_real, y_real = gen_real(dataset, n_samples)
    # evaluate discriminator on real examples
    # _, acc_real = d_model.evaluate(X_real, y_real, verbose=0)
    # prepare fake examples
    x_fake, _ = gen_fake(g_model, latent_dim, n_samples)
    # evaluate discriminator on fake examples
    # _, acc_fake = d_model.evaluate(x_fake, y_fake, verbose=0)
    # summarize discriminator performance
    # print('Performance test> Accuracy real: %.0f%%, fake: %.0f%%' % (acc_real*100, acc_fake*100))
    # save plot
    fakeplot = (x_fake * 127.5) + 127.5

    save_plot(savepath, fakeplot, epoch)
    # save the generator model tile file
    g_filename = 'generator_model_%03d.h5' % (epoch + 1)
    g_model.save(savepath + g_filename)
    d_filename = 'discriminator_model_%03d.h5' % (epoch + 1)
    d_model.save(savepath + d_filename)

def generator_loss(fake_output):
    gen_loss = tf.math.reduce_mean(fake_output)
    return gen_loss

def discriminator_loss(real_output, fake_output):
    loss = tf.math.reduce_mean(real_output) - tf.math.reduce_mean(fake_output)
    return loss

def train(datasettype, savepath, x_train, latent_dim, n_critic=5, n_epochs=200, batchsize=256, retries = 5, max_loss_increase_epochs = 10):
    batch_per_epoch = int(x_train.shape[0] / batchsize)
     # manually enumerate epochs
    
    for r in range(retries):
        print(f'Attempt:{r+1}')
        d_model = makediscriminator(x_train.shape[1:])
        g_model = makegenerator(datasettype, latent_dim)

        d_losses = []
        g_losses = []
        g_loss_epoch = []
        
        for i in range(n_epochs):
            print(f'Epoch: {i+1}')
            # enumerate batches over the training set
            for j in tqdm(range(batch_per_epoch)):
                
                ### WGAN-update critic more times than generator
                for _ in range(n_critic):
                    # get randomly selected 'real' samples
                    X_real, y_real = gen_real(x_train, batchsize)
                    # generate 'fake' examples
                    X_fake, y_fake = gen_fake(g_model, latent_dim, batchsize)
                    # update discriminator model weights

                    with tf.GradientTape() as tape:
                        # tape.watch(d_model.trainable_variables)
                        d_real_output = d_model(X_real, training = True)
                        d_fake_output = d_model(X_fake, training = True)

                        d_loss = discriminator_loss(d_real_output, d_fake_output)

                    gradients_d_model = tape.gradient(d_loss, d_model.trainable_variables)
                    
                    opt = RMSprop(learning_rate=0.00005)

                    opt.apply_gradients(zip(gradients_d_model, d_model.trainable_variables))
                
                # clip weights
                for w in d_model.trainable_variables:
                    w.assign(tf.clip_by_value(w, -0.01, 0.01))
                  
                # prepare points in latent space as input for the generator
                X_gan = get_latent(latent_dim, batchsize)
                ###WGAN- minust labels
                y_gan = -np.ones((batchsize, 1))

                ### WGAN-update generator
                with tf.GradientTape() as gen_tape:
                    # tape.watch(g_model.trainable_variables)
                    generated_images = g_model(X_gan, training=True)
                    g_fake_output = d_model(generated_images, training=False)
                    g_loss = generator_loss(g_fake_output)
                
                gradients_g_model = gen_tape.gradient(g_loss, g_model.trainable_variables)
                opt = RMSprop(learning_rate=0.00005)

                opt.apply_gradients(zip(gradients_g_model, g_model.trainable_variables))

                # summarize loss on this batch
                g_losses.append(g_loss)
                d_losses.append(d_loss)

            print('Try:%d, Epoch:%d, D_Loss=%.3f, G_Loss=%.3f\n' % (r+1, i+1, d_loss, g_loss))
            g_loss_epoch.append(g_loss)
            # if len(g_loss_epoch) >= 2:
            #     if g_loss_epoch[-1] >  g_loss_epoch[-2]:
            #         g_loss_inc_counter += 1
            #         print(f'Generator loss increased for {g_loss_inc_counter} epochs.')
            #     elif g_loss_epoch[-1] <  g_loss_epoch[-2]:
            #         g_loss_inc_counter = 0

            # if g_loss_inc_counter == max_loss_increase_epochs:
            #     print(f'Try: {r+1} failed to train. Restarting training')
            #     break
    #         evaluate the model performance, sometimes
            if (i+1) % 20 == 0:
                performance(savepath, i, g_model, d_model, x_train, latent_dim)

            if i == n_epochs - 1:
                return g_model, d_model, d_losses, g_losses

    # reach here if fail to train
    print("Training Stopped...")
    return g_model, d_model, d_losses, g_losses


In [3]:
x_train, y_train, x_test, y_test = loaddata('mnist')

In [4]:
savepath = './testwgan/'

In [36]:
g_model, d_model, d_loss, g_loss = \
train('mnist', savepath, x_train, latent_dim=100, n_critic=5, n_epochs=200, batchsize=256, retries = 5, max_loss_increase_epochs = 10000)

Attempt:1
Epoch: 1


100%|██████████| 182/182 [00:33<00:00,  5.38it/s]


Try:1, Epoch:1, D_Loss=-39.833, G_Loss=-18.450

Epoch: 2


100%|██████████| 182/182 [00:34<00:00,  5.29it/s]


Try:1, Epoch:2, D_Loss=-28.498, G_Loss=-0.081

Epoch: 3


100%|██████████| 182/182 [00:36<00:00,  5.00it/s]


Try:1, Epoch:3, D_Loss=-0.434, G_Loss=-4.158

Epoch: 4


100%|██████████| 182/182 [00:36<00:00,  4.99it/s]


Try:1, Epoch:4, D_Loss=-0.096, G_Loss=-1.137

Epoch: 5


100%|██████████| 182/182 [00:36<00:00,  5.02it/s]


Try:1, Epoch:5, D_Loss=-0.073, G_Loss=-1.006

Epoch: 6


100%|██████████| 182/182 [00:36<00:00,  4.96it/s]


Try:1, Epoch:6, D_Loss=-0.073, G_Loss=-0.825

Epoch: 7


100%|██████████| 182/182 [00:36<00:00,  4.95it/s]


Try:1, Epoch:7, D_Loss=-0.092, G_Loss=-1.079

Epoch: 8


100%|██████████| 182/182 [00:36<00:00,  5.00it/s]


Try:1, Epoch:8, D_Loss=-0.077, G_Loss=-0.996

Epoch: 9


100%|██████████| 182/182 [00:37<00:00,  4.89it/s]


Try:1, Epoch:9, D_Loss=-0.019, G_Loss=-0.745

Epoch: 10


100%|██████████| 182/182 [00:35<00:00,  5.06it/s]


Try:1, Epoch:10, D_Loss=-0.054, G_Loss=-0.044

Epoch: 11


100%|██████████| 182/182 [00:32<00:00,  5.53it/s]


Try:1, Epoch:11, D_Loss=-0.091, G_Loss=-1.331

Epoch: 12


100%|██████████| 182/182 [00:32<00:00,  5.56it/s]


Try:1, Epoch:12, D_Loss=-0.054, G_Loss=-0.185

Epoch: 13


100%|██████████| 182/182 [00:32<00:00,  5.63it/s]


Try:1, Epoch:13, D_Loss=-0.091, G_Loss=-1.638

Epoch: 14


100%|██████████| 182/182 [00:32<00:00,  5.65it/s]


Try:1, Epoch:14, D_Loss=0.007, G_Loss=-0.644

Epoch: 15


100%|██████████| 182/182 [00:32<00:00,  5.54it/s]


Try:1, Epoch:15, D_Loss=-0.035, G_Loss=0.026

Epoch: 16


100%|██████████| 182/182 [00:34<00:00,  5.22it/s]


Try:1, Epoch:16, D_Loss=-0.106, G_Loss=0.213

Epoch: 17


100%|██████████| 182/182 [00:33<00:00,  5.36it/s]


Try:1, Epoch:17, D_Loss=-0.076, G_Loss=0.110

Epoch: 18


100%|██████████| 182/182 [00:33<00:00,  5.36it/s]


Try:1, Epoch:18, D_Loss=-0.120, G_Loss=0.094

Epoch: 19


100%|██████████| 182/182 [00:33<00:00,  5.43it/s]


Try:1, Epoch:19, D_Loss=-0.039, G_Loss=-0.229

Epoch: 20


100%|██████████| 182/182 [00:35<00:00,  5.11it/s]


Try:1, Epoch:20, D_Loss=-0.024, G_Loss=-0.044

(100, 28, 28, 1)
0.0
255.0
Epoch: 21


 59%|█████▉    | 107/182 [00:21<00:14,  5.04it/s]


KeyboardInterrupt: 