<a href="https://colab.research.google.com/github/seismosmsr/machine_learning/blob/main/conditional_gan.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Conditional GAN
**Author:** [Aron Boettcher](spectral.online)<br>
**Adapted from:** [Sayak Paul](https://twitter.com/RisingSayak)<br>
**Date created:** 2021/07/13<br>
**Last modified:** 2021/07/15<br>
**Description:** Training a GAN conditioned on class labels to generate handwritten digits.

Generative Adversarial Networks (GANs) let us generate novel image data, video data,
or audio data from a random input. Typically, the random input is sampled
from a normal distribution, before going through a series of transformations that turn
it into something plausible (image, video, audio, etc.).

However, a simple [DCGAN](https://arxiv.org/abs/1511.06434) doesn't let us control
the appearance (e.g. class) of the samples we're generating. For instance,
with a GAN that generates MNIST handwritten digits, a simple DCGAN wouldn't let us
choose the class of digits we're generating.
To be able to control what we generate, we need to _condition_ the GAN output
on a semantic input, such as the class of an image.
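
As a rough illustration of what conditioning means in practice, the sketch below builds a generator input by concatenating random noise with a one-hot class vector. It uses NumPy only, and the sizes, the seed, and the helper name `conditioned_input` are assumptions made for the illustration rather than part of this notebook's model.

```python
import numpy as np

rng = np.random.default_rng(0)
latent_dim, num_classes = 8, 10  # hypothetical small sizes for readability


def conditioned_input(class_id, batch):
    """Concatenate random noise with a one-hot encoding of `class_id`."""
    noise = rng.normal(size=(batch, latent_dim))
    one_hot = np.zeros((batch, num_classes), dtype=noise.dtype)
    one_hot[:, class_id] = 1.0
    return np.concatenate([noise, one_hot], axis=1)


z = conditioned_input(class_id=3, batch=4)
print(z.shape)  # (4, 18): 8 noise values followed by 10 class values
```

Changing `class_id` while keeping the noise fixed is what lets a conditional generator produce different classes on demand.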

In this example, we'll build a **Conditional GAN** that can generate MNIST handwritten
digits conditioned on a given class. Such a model can have various useful applications:

* Let's say you are dealing with an
[imbalanced image dataset](https://developers.google.com/machine-learning/data-prep/construct/sampling-splitting/imbalanced-data),
and you'd like to gather more examples for the skewed class to balance the dataset.
Data collection can be a costly process on its own. You could instead train a Conditional GAN and use
it to generate novel images for the class that needs balancing.
* Since the generator learns to associate the generated samples with the class labels,
its representations can also be used for [other downstream tasks](https://arxiv.org/abs/1809.11096).

The following references were used to develop this example:

* [Conditional Generative Adversarial Nets](https://arxiv.org/abs/1411.1784)
* [Lecture on Conditional Generation from Coursera](https://www.coursera.org/lecture/build-basic-generative-adversarial-networks-gans/conditional-generation-inputs-2OPrG)

If you need a refresher on GANs, you can refer to the "Generative adversarial networks"
section of
[this resource](https://livebook.manning.com/book/deep-learning-with-python-second-edition/chapter-12/r-3/232).

This example requires TensorFlow 2.5 or higher, as well as TensorFlow Docs, which can be
installed using the following command:

In [157]:
!pip install -q git+https://github.com/tensorflow/docs

## Imports

In [158]:
from tensorflow import keras
from tensorflow.keras import layers

from tensorflow_docs.vis import embed
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
import imageio

## Constants and hyperparameters

In [159]:
batch_size = 256
num_channels = 1
num_classes = 101
image_size = 512
latent_dim = 512

## Loading the dataset and preprocessing it

In [160]:
!pip install gdown
import gdown
import zipfile
import os

# I switched to PNGs and JPGs to try to use TensorFlow's native vectorization.
# TODO: get GDAL working so you can just use GeoTIFF.
url = 'https://drive.google.com/uc?id=1pxhi4mvgxzODrXgYJbSSUUQ4LPWNOr8E'
output = 'sample.csv'

# TODO: update with proper validation data:
# https://drive.google.com/file/d/1unLB1XCJHoul3gqGYGzqSS8QMr5ZvJ9h/view?usp=sharing
gdown.download(url, output, quiet=False)


# cwd = os.getcwd()
# with zipfile.ZipFile(cwd+'/sample.zip', 'r') as zip_ref:
#     zip_ref.extractall(cwd+'/sample')

# PATH = os.path.join(os.path.dirname(cwd+'/sample/'), 'sample/')
# print(PATH)



Downloading...
From: https://drive.google.com/uc?id=1pxhi4mvgxzODrXgYJbSSUUQ4LPWNOr8E
To: /content/sample.csv
100%|██████████| 468M/468M [00:02<00:00, 218MB/s]


'sample.csv'

In [161]:
# Generate s real samples with class labels.
def generate_real_samples(s, filename="/content/sample.csv"):
    import random
    import pandas
    import numpy

    # Number of records in the file (excludes the header).
    with open(filename) as f:
        n = sum(1 for _ in f) - 1

    # Randomly pick n - s data rows to skip so that s rows remain.
    # Row 0 (the header) is never part of the skip list.
    skip = sorted(random.sample(range(1, n + 1), n - s))
    df = pandas.read_csv(filename, skiprows=skip)

    def parse_column(column):
        # Each cell is a brace-wrapped, comma-separated list of numbers.
        rows = []
        for cell in df[column]:
            values = str(cell).replace("{", "").replace("}", "").split(",")
            rows.append([float(v) for v in values])
        return numpy.array(rows)

    X = parse_column("rh")
    # Class labels.
    y = parse_column("ls")

    X = (X.astype("float32") / 65455.0).astype("float32")
    y = (y.astype("float32") + 100) / 255
    return X, y
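
The two ideas used in `generate_real_samples`, choosing which rows to skip so that `s` rows remain and parsing brace-wrapped number lists, can be tried in isolation. The sketch below uses only the standard library; the helper names `rows_to_skip` and `parse_braced` and the `seed` argument are assumptions added for the illustration.

```python
import random


def rows_to_skip(n_rows, s, seed=None):
    """Return sorted 1-based data-row indices to skip so that s of n_rows remain."""
    rng = random.Random(seed)
    return sorted(rng.sample(range(1, n_rows + 1), n_rows - s))


def parse_braced(cell):
    """Parse a string such as "{1.5,2,3}" into a list of floats."""
    return [float(v) for v in str(cell).strip("{}").split(",")]


skip = rows_to_skip(n_rows=10, s=3, seed=0)
kept = [i for i in range(1, 11) if i not in skip]
print(len(skip), len(kept))       # 7 3
print(parse_braced("{1.5,2,3}"))  # [1.5, 2.0, 3.0]
```

Index 0 never appears in the skip list, which is why the header row always survives when the list is passed to `skiprows`.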

In [162]:
# generate_real_samples(2)

In [163]:
import numpy as np
import tensorflow as tf

# Draw one larger and one smaller random sample from the CSV; both are used
# for training. generate_real_samples returns (rh, ls): the `ls` values serve
# as the conditioning "pixels" and the `rh` values as the "labels".
(y_train, x_train) = generate_real_samples(2000)
(y_test, x_test) = generate_real_samples(200)

all_pixels = np.concatenate([x_train, x_test])
all_labels = np.concatenate([y_train, y_test])

# Scale based on reflectance values:
# https://developers.google.com/earth-engine/datasets/catalog/LANDSAT_LC08_C02_T1_L2?hl=en#bands
# Note: generate_real_samples has already rescaled both arrays, so the two
# lines below rescale them a second time.
all_pixels = (all_pixels.astype("float32") / 65455.0).astype("float32")
all_labels = (all_labels.astype("float32") + 100) / 255
# Create tf.data.Dataset.
dataset = tf.data.Dataset.from_tensor_slices((all_pixels, all_labels))
dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)
print(f"Shape of training images: {all_labels[1]}")
print(f"Shape of training images: {all_pixels.shape}")
# print(f"Shape of training labels: {all_labels.shape}")

Shape of training images: [0.3921566  0.39215663 0.39215666 0.3921567  0.3921567  0.39215672
 0.39215672 0.39215675 0.39215675 0.39215675 0.39215678 0.39215678
 0.39215678 0.39215678 0.3921568  0.3921568  0.3921568  0.3921568
 0.3921568  0.39215684 0.39215684 0.39215684 0.39215684 0.39215684
 0.39215684 0.39215684 0.39215687 0.39215687 0.39215687 0.39215687
 0.39215687 0.39215687 0.3921569  0.3921569  0.3921569  0.3921569
 0.3921569  0.3921569  0.39215693 0.39215693 0.39215693 0.39215693
 0.39215693 0.39215696 0.39215696 0.39215696 0.39215696 0.392157
 0.392157   0.392157   0.39215702 0.39215705 0.39215708 0.3921571
 0.39215714 0.39215717 0.3921572  0.39215723 0.39215726 0.3921573
 0.39215732 0.39215732 0.39215735 0.39215738 0.39215738 0.3921574
 0.3921574  0.39215744 0.39215747 0.39215747 0.3921575  0.3921575
 0.39215752 0.39215752 0.39215755 0.39215755 0.39215758 0.3921576
 0.3921576  0.39215764 0.39215764 0.39215767 0.39215767 0.3921577
 0.39215773 0.39215773 0.39215776 0.3921578  0

## Calculating the number of input channels for the generator and discriminator

In a regular (unconditional) GAN, we start by sampling noise (of some fixed
dimension) from a normal distribution. In our case, we also need to account
for the conditioning input. The generator receives the noise vector concatenated
with the conditioning vector (`latent_dim + image_size` values), and the
discriminator receives the conditioning vector concatenated with a real or
generated output vector (`image_size + num_classes` values).

In [164]:
generator_in_channels = latent_dim + image_size
discriminator_in_channels = image_size + num_classes
print(generator_in_channels, discriminator_in_channels)

1024 613
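
The two widths printed above can be checked by concatenating placeholder arrays of the same sizes. This is a NumPy-only sketch: the constants are copied from the hyperparameter cell, while the batch size of 4 and the zero-filled arrays are assumptions made for the illustration.

```python
import numpy as np

latent_dim, image_size, num_classes = 512, 512, 101  # values from the constants cell
batch = 4  # hypothetical small batch

noise = np.zeros((batch, latent_dim))    # stands in for the sampled noise
pixels = np.zeros((batch, image_size))   # stands in for the conditioning vector
labels = np.zeros((batch, num_classes))  # stands in for a real or generated output

generator_input = np.concatenate([noise, pixels], axis=1)
discriminator_input = np.concatenate([labels, pixels], axis=-1)
print(generator_input.shape, discriminator_input.shape)  # (4, 1024) (4, 613)
```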


## Creating the discriminator and generator

The model definitions (`discriminator`, `generator`, and `ConditionalGAN`) have been
adapted from [this example](https://keras.io/guides/customizing_what_happens_in_fit/).
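
Both networks place a `LeakyReLU` with a slope of 0.2 after each hidden dense layer. As a reminder of what that activation computes, here is a plain-Python sketch for a single value; the function name `leaky_relu` is an assumption for the illustration and is not the Keras layer itself.

```python
def leaky_relu(x, alpha=0.2):
    """Pass positive values through unchanged; scale everything else by alpha."""
    return x if x > 0 else alpha * x


print([leaky_relu(v) for v in (-2.0, 0.0, 3.0)])  # [-0.4, 0.0, 3.0]
```

Unlike a plain ReLU, negative inputs keep a small gradient, which is a common choice for GAN discriminators and generators.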

In [165]:
# Create the discriminator.
discriminator = keras.Sequential(
    [
        keras.layers.InputLayer((discriminator_in_channels,)),
        layers.Dense(512),
        layers.LeakyReLU(alpha=0.2),
        layers.Dense(256),
        layers.LeakyReLU(alpha=0.2),
        layers.Dense(128),
        layers.LeakyReLU(alpha=0.2),
        layers.Dense(1),
    ],
    name="discriminator",
)

# Create the generator.
generator = keras.Sequential(
    [
        keras.layers.InputLayer((generator_in_channels,)),
        layers.Dense(1024),
        layers.LeakyReLU(alpha=0.2),
        layers.Dense(512),
        layers.LeakyReLU(alpha=0.2),
        layers.Dense(256),
        layers.LeakyReLU(alpha=0.2),
        layers.Dense(128),
        layers.LeakyReLU(alpha=0.2),
        layers.Dense(num_classes,  activation="linear"),
    ],
    name="generator",
)

## Creating a `ConditionalGAN` model

In [166]:

class ConditionalGAN(keras.Model):
    def __init__(self, discriminator, generator, latent_dim):
        super(ConditionalGAN, self).__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.gen_loss_tracker = keras.metrics.Mean(name="generator_loss")
        self.disc_loss_tracker = keras.metrics.Mean(name="discriminator_loss")

    @property
    def metrics(self):
        return [self.gen_loss_tracker, self.disc_loss_tracker]

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        super(ConditionalGAN, self).compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def train_step(self, data):
        # Unpack the data.
        real_pixels, real_labels = data
        # print(real_pixels[0])
        # Sample random points in the latent space and concatenate the pixels.
        # This is for the generator.
        batch_size = tf.shape(real_pixels)[0]
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        # print(random_latent_vectors[0])
        random_vector_pixels = tf.concat(
            [random_latent_vectors, real_pixels], axis=1
        )
        # print(random_vector_pixels[0])
        # Decode the noise (guided by the pixels) into fake labels.
        generated_labels = self.generator(random_vector_pixels)

        # Pair the generated labels and the real labels with the same real pixels,
        # then stack the fake pairs on top of the real pairs.
        fake_pixel_and_labels = tf.concat([generated_labels, real_pixels], -1)
        real_pixel_and_labels = tf.concat([real_labels, real_pixels], -1)
        combined_images = tf.concat(
            [fake_pixel_and_labels, real_pixel_and_labels], axis=0
        )

        # Assemble targets for the discriminator: 1 for fake pairs, 0 for real pairs.
        labels = tf.concat(
            [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0
        )

        # Train the discriminator.
        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images)
            d_loss = self.loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_weights)
        )

        # Sample random points in the latent space and concatenate the pixels.
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        random_vector_labels = tf.concat(
            [random_latent_vectors, real_pixels], axis=1
        )

        # Assemble targets that say "all real" (0 marks a real pair).
        misleading_labels = tf.zeros((batch_size, 1))

        # Train the generator (note that we should *not* update the weights
        # of the discriminator)!
        with tf.GradientTape() as tape:
            fake_labels = self.generator(random_vector_labels)
            # Use the same [labels, pixels] order the discriminator is trained on.
            fake_pixels_and_labels = tf.concat([fake_labels, real_pixels], -1)
            predictions = self.discriminator(fake_pixels_and_labels)
            g_loss = self.loss_fn(misleading_labels, predictions)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))

        # Monitor loss.
        self.gen_loss_tracker.update_state(g_loss)
        self.disc_loss_tracker.update_state(d_loss)
        return {
            "g_loss": self.gen_loss_tracker.result(),
            "d_loss": self.disc_loss_tracker.result(),
        }
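
To make the target convention in `train_step` concrete, the sketch below rebuilds the discriminator and generator targets with NumPy and evaluates a binary cross-entropy computed from logits. It is an independent re-implementation for illustration, not the Keras `BinaryCrossentropy` class; the function name `bce_from_logits` and the example logits are assumptions.

```python
import numpy as np


def bce_from_logits(targets, logits):
    """Mean binary cross-entropy computed from raw logits in a stable form."""
    targets = np.asarray(targets, dtype=np.float64)
    logits = np.asarray(logits, dtype=np.float64)
    per_example = (
        np.maximum(logits, 0) - logits * targets + np.log1p(np.exp(-np.abs(logits)))
    )
    return per_example.mean()


batch = 4
# Discriminator targets as in train_step: 1 = fake pair, 0 = real pair.
disc_targets = np.concatenate([np.ones((batch, 1)), np.zeros((batch, 1))], axis=0)
# Generator targets: claim that every fake pair is real (0).
misleading_targets = np.zeros((batch, 1))

undecided = np.zeros((2 * batch, 1))  # a logit of 0 means a probability of 0.5
confident = np.concatenate([np.full((batch, 1), 10.0), np.full((batch, 1), -10.0)])

loss_undecided = bce_from_logits(disc_targets, undecided)  # equals log(2)
loss_confident = bce_from_logits(disc_targets, confident)  # close to zero
print(loss_undecided, loss_confident)
```

The generator is rewarded when the discriminator outputs low logits for fake pairs, which is exactly what `misleading_targets` encodes.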


## Training the Conditional GAN

In [169]:
cond_gan = ConditionalGAN(
    discriminator=discriminator, generator=generator, latent_dim=latent_dim
)
cond_gan.compile(
    d_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    g_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),
)

cond_gan.fit(dataset, epochs=5000)

Streaming output truncated to the last 5000 lines.
Epoch 2502/5000
Epoch 2503/5000
Epoch 2504/5000
Epoch 2505/5000
Epoch 2506/5000
Epoch 2507/5000
Epoch 2508/5000
Epoch 2509/5000
Epoch 2510/5000
Epoch 2511/5000
Epoch 2512/5000
Epoch 2513/5000
Epoch 2514/5000
Epoch 2515/5000
Epoch 2516/5000
Epoch 2517/5000
Epoch 2518/5000
Epoch 2519/5000
Epoch 2520/5000
Epoch 2521/5000
Epoch 2522/5000
Epoch 2523/5000
Epoch 2524/5000
Epoch 2525/5000
Epoch 2526/5000
Epoch 2527/5000
Epoch 2528/5000
Epoch 2529/5000
Epoch 2530/5000
Epoch 2531/5000
Epoch 2532/5000
Epoch 2533/5000
Epoch 2534/5000
Epoch 2535/5000
Epoch 2536/5000
Epoch 2537/5000
Epoch 2538/5000
Epoch 2539/5000
Epoch 2540/5000
Epoch 2541/5000
Epoch 2542/5000
Epoch 2543/5000
Epoch 2544/5000
Epoch 2545/5000
Epoch 2546/5000
Epoch 2547/5000
Epoch 2548/5000
Epoch 2549/5000
Epoch 2550/5000
Epoch 2551/5000
Epoch 2552/5000
Epoch 2553/5000
Epoch 2554/5000
Epoch 2555/5000
Epoch 2556/5000
Epoch 2557/5000
Epoch 2558/5000
Epoch 2559/5000
Epoch 2

<keras.callbacks.History at 0x7f83794bce10>

## Interpolating between classes with the trained generator

In [171]:
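# We first extract the trained generator from our Conditional GAN.
trained_gen = cond_gan.generator
# Note: unlike the training data, these samples are scaled only once
# (inside generate_real_samples).
(real_labels, real_pixels) = generate_real_samples(batch_size)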

# batch_size = tf.shape(real_pixels)[0]
random_latent_vectors = tf.random.normal(shape=(batch_size, latent_dim))
# print(random_latent_vectors[0])
random_vector_pixels = tf.concat(
    [random_latent_vectors, real_pixels], axis=1
)
print(real_labels[0])

generated_waveform = trained_gen(random_vector_pixels)

print(generated_waveform[1])
# # Choose the number of intermediate images that would be generated in
# # between the interpolation + 2 (start and last images).
# num_interpolation = 25  # @param {type:"integer"}

# # Sample noise for the interpolation.
# interpolation_noise = tf.random.normal(shape=(1, latent_dim))
# interpolation_noise = tf.repeat(interpolation_noise, repeats=num_interpolation)
# interpolation_noise = tf.reshape(interpolation_noise, (num_interpolation, latent_dim))


# def interpolate_class(first_number, second_number):
#     # Convert the start and end labels to one-hot encoded vectors.
#     first_label = keras.utils.to_categorical([first_number], num_classes)
#     second_label = keras.utils.to_categorical([second_number], num_classes)
#     first_label = tf.cast(first_label, tf.float32)
#     second_label = tf.cast(second_label, tf.float32)

#     # Calculate the interpolation vector between the two labels.
#     percent_second_label = tf.linspace(0, 1, num_interpolation)[:, None]
#     percent_second_label = tf.cast(percent_second_label, tf.float32)
#     interpolation_labels = (
#         first_label * (1 - percent_second_label) + second_label * percent_second_label
#     )

#     # Combine the noise and the labels and run inference with the generator.
#     noise_and_labels = tf.concat([interpolation_noise, interpolation_labels], 1)
#     fake = trained_gen.predict(noise_and_labels)
#     return fake


# start_class = 1  # @param {type:"slider", min:0, max:9, step:1}
# end_class = 2  # @param {type:"slider", min:0, max:9, step:1}

# fake_images = interpolate_class(start_class, end_class)

[-6.04995803e-05 -5.59162800e-05 -5.19440837e-05 -4.84302218e-05
 -4.55274603e-05 -4.27774794e-05 -4.10969369e-05 -3.88052867e-05
 -3.69719673e-05 -3.52914212e-05 -3.36108787e-05 -3.25414439e-05
 -3.07081209e-05 -2.96386843e-05 -2.84164689e-05 -2.73470323e-05
 -2.61248188e-05 -2.50553821e-05 -2.44442745e-05 -2.33748378e-05
 -2.21526243e-05 -2.16942935e-05 -2.04720800e-05 -1.98609723e-05
 -1.87915357e-05 -1.81804298e-05 -1.71109932e-05 -1.64998855e-05
 -1.58887779e-05 -1.48193421e-05 -1.42082345e-05 -1.35971277e-05
 -1.25276911e-05 -1.19165834e-05 -1.13054775e-05 -1.08471468e-05
 -1.02360400e-05 -9.01382555e-06 -8.55549661e-06 -7.94438893e-06
 -7.33328216e-06 -6.72217539e-06 -6.26384553e-06 -5.65273876e-06
 -4.43052477e-06 -3.97219446e-06 -3.36108769e-06 -2.74998092e-06
 -2.13887415e-06 -1.68054385e-06 -1.06943708e-06 -4.58330135e-07
  0.00000000e+00  4.58330135e-07  1.68054385e-06  2.13887415e-06
  2.74998092e-06  3.36108769e-06  3.97219446e-06  4.43052477e-06
  5.04163199e-06  5.65273

Here, we first sample noise from a normal distribution, repeat it
`num_interpolation` times, and reshape the result accordingly.
We then blend the two class labels in evenly spaced proportions across the
`num_interpolation` steps, so that each step mixes the two label identities in a different ratio.
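
The label blending described here can be sketched with NumPy alone, without a trained generator. The sizes below are assumptions chosen for readability (the commented-out cell uses 25 steps), and `interpolate_labels` is a hypothetical helper that mirrors the commented-out `interpolate_class` up to the point where the generator would be called.

```python
import numpy as np

num_classes = 10       # hypothetical: ten classes, as in the commented-out sliders
num_interpolation = 5  # hypothetical: fewer steps than the notebook's 25
latent_dim = 4         # hypothetical: small for readability

rng = np.random.default_rng(0)

# One noise vector, repeated so that every interpolation step shares it.
noise = np.repeat(rng.normal(size=(1, latent_dim)), num_interpolation, axis=0)


def interpolate_labels(first_class, second_class):
    """Blend two one-hot labels in evenly spaced proportions."""
    first = np.eye(num_classes)[[first_class]]
    second = np.eye(num_classes)[[second_class]]
    weight = np.linspace(0.0, 1.0, num_interpolation)[:, None]
    return first * (1.0 - weight) + second * weight


labels = interpolate_labels(1, 2)
noise_and_labels = np.concatenate([noise, labels], axis=1)
print(noise_and_labels.shape)  # (5, 14)
```

The first row is purely class 1, the last row is purely class 2, and every row in between still sums to one.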

In [None]:
fake_images *= 255.0
converted_images = fake_images.astype(np.uint8)
converted_images = tf.image.resize(converted_images, (96, 96)).numpy().astype(np.uint8)
imageio.mimsave("animation.gif", converted_images, fps=1)
embed.embed_file("animation.gif")

We can further improve the performance of this model with recipes like
[WGAN-GP](https://keras.io/examples/generative/wgan_gp).
Conditional generation is also widely used in many modern image generation architectures like
[VQ-GANs](https://arxiv.org/abs/2012.09841), [DALL-E](https://openai.com/blog/dall-e/),
etc.

You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/conditional-gan) and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/conditional-GAN).