In [None]:
# Install dependencies
%pip install tensorflow
%pip install matplotlib
%pip install numpy
%pip install wandb
%pip install os

In [None]:
# Set up WandB
# Authenticate this kernel with Weights & Biases before any metrics are logged.
import wandb

wandb.login()

In [None]:
# import dependencies
import tensorflow as tf
from matplotlib import pyplot as plt
import numpy as np
import os

In [None]:
# Set Variables and Metrics
model_version = 'v4'
NOISE_SIZE = 100
BATCH_SIZE = 128
N_EPOCHS = 500
LEARNING_RATE = 0.0002
SEED = 42

# Seed Python, NumPy and TensorFlow together. NumPy drives the noise and the
# dataset shuffle, and TensorFlow/Keras drives weight initialisation, so a
# Restart & Run All starts from the same random state. GPU kernels may still
# introduce some nondeterminism.
tf.keras.utils.set_random_seed(SEED)

# Running means of the per-step losses, reset at the end of every epoch.
generator_loss_metric = tf.keras.metrics.Mean()
discriminator_loss_metric = tf.keras.metrics.Mean()
# Discriminator accuracy on real samples (target 1) and fake samples (target 0).
real_sample_accuracy = tf.keras.metrics.BinaryAccuracy()
fake_sample_accuracy = tf.keras.metrics.BinaryAccuracy()

In [None]:
# Instantiate Networks
# Length of a flattened MNIST image: 28 x 28 pixels = 784 values.
IMAGE_SIZE = 28 * 28

# Create the Generator Network
# Maps a NOISE_SIZE noise vector to a flattened IMAGE_SIZE image.
# Input shapes are given as tuples: Keras 3's `Input` rejects a bare int.
# `alpha` is kept on LeakyReLU because Keras 2 requires it and Keras 3 still
# accepts it as a deprecated alias of `negative_slope`.
generator = tf.keras.Sequential([
    tf.keras.Input(shape=(NOISE_SIZE,)),
    tf.keras.layers.Dense(256),
    tf.keras.layers.LeakyReLU(alpha=0.2),
    tf.keras.layers.Dense(512),
    tf.keras.layers.LeakyReLU(alpha=0.2),
    tf.keras.layers.Dense(1024),
    tf.keras.layers.LeakyReLU(alpha=0.2),
    tf.keras.layers.Dense(IMAGE_SIZE, activation='sigmoid')  # sigmoid keeps pixel values between 0 and 1, like the normalized real images
], name='Generator')

# Create the Discriminator Network
# Takes a flattened image (real or generated, IMAGE_SIZE values) and outputs one
# sigmoid score per sample. It is trained towards 1 for real and 0 for fake.
discriminator = tf.keras.Sequential([
    tf.keras.Input(shape=(IMAGE_SIZE,)),
    tf.keras.layers.Dense(512),
    tf.keras.layers.LeakyReLU(alpha=0.2),
    tf.keras.layers.Dense(256),
    tf.keras.layers.LeakyReLU(alpha=0.2),
    tf.keras.layers.Dense(1, activation='sigmoid')  # one score in (0, 1) per sample
], name='Discriminator')

generator.summary()
discriminator.summary()

In [None]:
# Create Loss Functions and Optimizers
# The discriminator ends in a sigmoid, so the BCE loss receives probabilities
# rather than logits (from_logits=False).
loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=False)

# One Adam instance per network, so each keeps its own optimizer state.
adam_learning_rate = LEARNING_RATE
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=adam_learning_rate)
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=adam_learning_rate)
del adam_learning_rate

def calculate_discriminator_loss(fake_preds, real_preds):
    """Binary cross-entropy loss for the discriminator.

    Fake predictions are scored against an all-zeros target and real
    predictions against an all-ones target.

    Returns:
        The sum of the fake-sample and real-sample BCE terms.
    """
    fake_targets = tf.zeros(fake_preds.shape)
    real_targets = tf.ones(real_preds.shape)
    loss_on_fakes = loss_fn(fake_targets, fake_preds)
    loss_on_reals = loss_fn(real_targets, real_preds)
    return loss_on_fakes + loss_on_reals

def calculate_generator_loss(fake_preds):
    """Binary cross-entropy loss for the generator.

    The discriminator's predictions on generated samples are scored against an
    all-ones ("real") target, so the loss falls as the fakes are judged real.
    """
    wanted_labels = tf.ones(fake_preds.shape)
    return loss_fn(wanted_labels, fake_preds)

    

In [2]:
# Various functions for generating random noise and fake data, and for loading the dataset
def generateFakeData(model_reference, batch_size=64, noise_size=100):
    """Run a generator on a fresh batch of standard-normal noise.

    Args:
        model_reference: the generator model (any callable that accepts a
            (batch_size, noise_size) tensor).
        batch_size: number of noise vectors, and therefore of fake samples.
        noise_size: length of each noise vector.

    Returns:
        Whatever ``model_reference`` returns for that noise batch.
    """
    # Reuse generateRandomNoise instead of duplicating the sampling code.
    return model_reference(generateRandomNoise(batch_size, noise_size))

def generateRandomNoise(batch_size=BATCH_SIZE, noise_size=NOISE_SIZE):
    """Sample a (batch_size, noise_size) batch of standard-normal noise.

    The defaults are read from BATCH_SIZE / NOISE_SIZE when this cell runs, so
    the configuration cell must be executed first. Otherwise defining this
    function raises NameError.

    Returns:
        A float32 ``tf.Tensor`` of shape (batch_size, noise_size).
    """
    # Sample with NumPy's global RNG and cast to float32, Keras' default float
    # dtype, rather than handing the models float64 values. np.random.normal
    # already returns the requested shape, so no reshape is needed.
    noise = np.random.normal(size=(batch_size, noise_size)).astype(np.float32)
    return tf.convert_to_tensor(noise)

def loadMnistDataset(batch_size=64, input_shape=(28, 28)):
    """Load the MNIST training images as a batched ``tf.data.Dataset``.

    The images are shuffled once with NumPy's global RNG, divided by 255, and
    flattened to vectors of ``input_shape[0] * input_shape[1]`` values. Labels
    and the test split are discarded.

    Args:
        batch_size: images per batch. The final batch may be smaller.
        input_shape: (height, width) of each image, used for flattening.

    Returns:
        A ``tf.data.Dataset`` yielding float32 batches of flattened images.
    """
    # Load in MNIST; only the training images are used.
    (train_X, _), _ = tf.keras.datasets.mnist.load_data()

    # Shuffle the image order in place.
    np.random.shuffle(train_X)

    # Normalize in float32. Dividing the integer array by 255.0 directly would
    # give float64, doubling memory for no benefit to float32 models.
    normalized = train_X.astype(np.float32) / 255.0
    flattened = np.reshape(normalized, [len(train_X), input_shape[0] * input_shape[1]])

    # Wrap the array in a Dataset and batch it.
    return tf.data.Dataset.from_tensor_slices(flattened).batch(batch_size)

NameError: name 'BATCH_SIZE' is not defined

In [1]:
# Define Checkpoint Function
def checkpointModel(model_reference, epoch, dir='images/v4', batch_size=64, noise_size=100, prefix='v4'):
    """Save a batch of generated digits as individual PNG images.

    Images are written to ``<dir>/<prefix>-<epoch>/<i>.png``, one file per
    generated sample. The folder (and any missing parents) is created if needed.

    Args:
        model_reference: the generator model passed to generateFakeData.
        epoch: epoch label used in the output folder name (int or str).
        dir: parent directory for checkpoint image folders.
        batch_size: number of images to generate and save.
        noise_size: length of each noise vector fed to the generator.
        prefix: text placed before the epoch in the folder name.
    """
    # '%s' formatting accepts int or str epochs; '+' concatenation with an int
    # raised TypeError.
    path = os.path.join(dir, '%s-%s' % (prefix, epoch))
    # makedirs creates missing parents and tolerates re-running the same epoch.
    os.makedirs(path, exist_ok=True)

    generated_imgs_batched = generateFakeData(model_reference, batch_size, noise_size).numpy()

    # Reuse one explicit figure and close it at the end. Repeated plt.imshow
    # calls on the global figure would otherwise keep stacking images.
    fig, ax = plt.subplots()
    try:
        for i, vector in enumerate(generated_imgs_batched):
            ax.clear()
            ax.imshow(np.reshape(vector, (28, 28)), cmap='gray', vmin=0, vmax=1)
            fig.savefig(os.path.join(path, '%i.png' % i))
    finally:
        plt.close(fig)

In [None]:
# Set Training Data
# Batch real images with BATCH_SIZE, the batch size used for the noise.
# loadMnistDataset's own default is 64, which made real and fake minibatches
# different sizes.
training_ds_batched = loadMnistDataset(batch_size=BATCH_SIZE)


In [None]:
# Training Loop
# wandb.log below needs an active run. Start one that records the
# hyperparameters, unless a run is already open in this kernel.
if wandb.run is None:
    wandb.init(config={
        'model_version': model_version,
        'noise_size': NOISE_SIZE,
        'batch_size': BATCH_SIZE,
        'n_epochs': N_EPOCHS,
        'learning_rate': LEARNING_RATE,
    })

for epoch in range(N_EPOCHS):

    # sample minibatch size m of real data
    for minibatch_real in training_ds_batched:
        # m is the size of this real minibatch (the last batch of an epoch can
        # be smaller). Every fake minibatch below uses the same m.
        m = int(minibatch_real.shape[0])

        # ---- Discriminator update ----
        # Sample minibatch size m of noise samples. The generator runs outside
        # the tape because only the discriminator is updated in this half-step.
        minibatch_fake = generator(generateRandomNoise(m, NOISE_SIZE))

        with tf.GradientTape() as tape:
            fake_preds = discriminator(minibatch_fake)
            real_preds = discriminator(minibatch_real)
            discriminator_loss = calculate_discriminator_loss(fake_preds, real_preds)

        discriminator_gradients = tape.gradient(discriminator_loss, discriminator.trainable_variables)
        discriminator_optimizer.apply_gradients(zip(discriminator_gradients, discriminator.trainable_variables))

        discriminator_loss_metric.update_state(discriminator_loss)
        real_sample_accuracy.update_state(tf.ones_like(real_preds), real_preds)
        fake_sample_accuracy.update_state(tf.zeros_like(fake_preds), fake_preds)

        # ---- Generator update ----
        # Use fresh noise. The loss targets "real" labels for the fakes, so the
        # generator improves when the discriminator is fooled.
        with tf.GradientTape() as tape:
            minibatch_fake = generator(generateRandomNoise(m, NOISE_SIZE))
            fake_preds = discriminator(minibatch_fake)
            generator_loss = calculate_generator_loss(fake_preds)

        generator_gradients = tape.gradient(generator_loss, generator.trainable_variables)
        generator_optimizer.apply_gradients(zip(generator_gradients, generator.trainable_variables))
        generator_loss_metric.update_state(generator_loss)

    # Read each epoch-level metric once per epoch instead of on every step.
    generator_loss_value = generator_loss_metric.result().numpy()
    discriminator_loss_value = discriminator_loss_metric.result().numpy()
    real_accuracy_value = real_sample_accuracy.result().numpy()
    fake_accuracy_value = fake_sample_accuracy.result().numpy()
    total_sample_accuracy = (real_accuracy_value + fake_accuracy_value) / 2

    print('Epoch: %i Generator Loss: %.4f Discriminator Loss: %.4f' % (epoch, generator_loss_value, discriminator_loss_value))
    print('Real Sample Accuracy: %.4f Fake Sample Accuracy: %.4f' % (real_accuracy_value, fake_accuracy_value))
    wandb.log({
        "generator_loss": generator_loss_value,
        "discriminator_loss": discriminator_loss_value,
        "real_sample_accuracy": real_accuracy_value,
        "fake_sample_accuracy": fake_accuracy_value,
        "total_sample_accuracy": total_sample_accuracy,
    })

    # reset_state() works in Keras 2 and 3; the old reset_states() alias was
    # removed in Keras 3.
    for metric in (generator_loss_metric, discriminator_loss_metric,
                   real_sample_accuracy, fake_sample_accuracy):
        metric.reset_state()

    # Snapshot both networks every 50 epochs, skipping epoch 0.
    # NOTE(review): Keras 3's Model.save expects a `.keras`/`.h5` filename; these
    # extension-less paths rely on Keras 2's SavedModel format. Confirm the
    # installed version.
    if epoch > 0 and epoch % 50 == 0:
        generator.save('savedModels/generator-' + model_version + '/epoch-%i' % epoch)
        discriminator.save('savedModels/discriminator-' + model_version + '/epoch-%i' % epoch)

    # Save sample images from the current generator. epoch is a required
    # argument and was previously omitted; it is passed as text because it
    # becomes part of the output folder name.
    checkpointModel(generator, str(epoch))

In [None]:
# Run only once training has finished: this closes the current W&B run.
# It does not stop training by itself.
wandb.finish()

In [None]:
# Final save of generator and discriminator
# Each network goes to savedModels/<network>-<model_version>/final, generator first.
for network_name, network in (('generator', generator), ('discriminator', discriminator)):
    network.save('savedModels/' + network_name + '-' + model_version + '/final')