In [None]:
import tensorflow as tf
import os

from tensorflow.keras import datasets, layers, models
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

In [None]:
# Keras returns CIFAR-10 as ((x_train, y_train), (x_test, y_test)) NumPy arrays;
# unpack each split into its own images/labels pair.
cifar10_train, cifar10_test = datasets.cifar10.load_data()
train_images, train_labels = cifar10_train
test_images, test_labels = cifar10_test

In [None]:
# Per-image (height, width, channels) shape; gen_noise draws noise of this shape.
IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS = 32, 32, 3
input_shape = (IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS)

In [None]:
def gen_noise(batch_size):
    """Draw a batch of standard-normal noise shaped like one input image per sample.

    Args:
        batch_size: Number of samples; a Python int or a scalar int tensor
            (``train_step`` passes ``tf.shape(...)[0]``).

    Returns:
        A tensor of shape ``(batch_size, input_shape[0], input_shape[1],
        input_shape[2])`` sampled from N(0, 1) (``tf.random.normal``'s default
        dtype, float32).
    """
    height, width, channels = input_shape[0], input_shape[1], input_shape[2]
    return tf.random.normal(
        shape=[batch_size, height, width, channels], mean=0, stddev=1
    )

In [None]:
class GAN(tf.keras.Model):
    """Generator/discriminator pair trained adversarially through ``Model.fit``.

    Each ``train_step`` performs one discriminator update and one generator
    update using the optimizers and loss passed to ``compile``.

    Args:
        discriminator: Model expected to map a batch of images to one logit per
            image, i.e. shape ``(batch, 1)``; ``loss_fn`` is applied to its raw
            outputs.
        generator: Model expected to map a ``gen_noise`` batch to a batch of
            images with the same per-image shape as the real training images.
    """

    def __init__(self, discriminator, generator):
        super(GAN, self).__init__()
        self.disc = discriminator
        self.gen = generator
        # Running-mean loss trackers; exposed via `metrics` so Keras resets them
        # at the start of each epoch.
        self.gen_loss_tracker = tf.keras.metrics.Mean(name="generator_loss")
        self.disc_loss_tracker = tf.keras.metrics.Mean(name="discriminator_loss")
        self.total_loss_tracker = tf.keras.metrics.Mean(name="total_loss")

    @property
    def metrics(self):
        """Metrics Keras should reset between epochs."""
        return [self.gen_loss_tracker, self.disc_loss_tracker, self.total_loss_tracker]

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        """Store the per-network optimizers and the shared adversarial loss.

        Args:
            d_optimizer: Optimizer applied to the discriminator's variables.
            g_optimizer: Optimizer applied to the generator's variables.
            loss_fn: Callable ``loss_fn(labels, predictions)`` used for both the
                discriminator and generator objectives.
        """
        super(GAN, self).compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def train_step(self, data, y=None):
        """Run one adversarial update on a batch of real images.

        ``Model.fit`` calls ``train_step(data)`` with a single argument: the
        batch itself, or an ``(x, y)`` tuple when ``fit`` is given labels. The
        optional ``y`` keeps direct ``train_step(X, y)`` calls working. Labels
        are never used in the losses.

        Args:
            data: A batch of real images, or a tuple/list whose first element is
                that batch. Integer batches are treated as 8-bit pixels and
                rescaled from 0..255 to [-1, 1].
            y: Ignored.

        Returns:
            Dict of running-mean losses: ``g_loss``, ``d_loss`` and
            ``d_total_loss`` (the sum of generator and discriminator losses).
        """
        real_images = data[0] if isinstance(data, (tuple, list)) else data
        if real_images.dtype.is_integer:
            # NOTE(review): [-1, 1] assumes the generator ends in a tanh (or
            # similarly bounded) activation; confirm against the generator used.
            real_images = tf.cast(real_images, tf.float32) / 127.5 - 1.0
        batch_size = tf.shape(real_images)[0]

        # Discriminator step: real images are labelled 1, generated images 0.
        with tf.GradientTape() as disc_tape:
            fake_images = self.gen(gen_noise(batch_size), training=True)
            # Match dtypes so tf.concat accepts both halves.
            combined_images = tf.concat(
                [tf.cast(real_images, fake_images.dtype), fake_images], axis=0
            )
            # (2 * batch, 1) to line up with one logit per image.
            combined_labels = tf.concat(
                [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0
            )
            combined_logits = self.disc(combined_images, training=True)
            disc_loss = self.loss_fn(combined_labels, combined_logits)

        gradients_of_discriminator = disc_tape.gradient(
            disc_loss, self.disc.trainable_variables
        )

        # Generator step: reward fakes the discriminator scores as real (label 1).
        with tf.GradientTape() as gen_tape:
            fake_images = self.gen(gen_noise(batch_size), training=True)
            fake_logits = self.disc(fake_images, training=True)
            gen_loss = self.loss_fn(tf.ones((batch_size, 1)), fake_logits)

        gradients_of_generator = gen_tape.gradient(
            gen_loss, self.gen.trainable_variables
        )

        # Both gradient sets were computed against the pre-update weights.
        self.g_optimizer.apply_gradients(
            zip(gradients_of_generator, self.gen.trainable_variables)
        )
        self.d_optimizer.apply_gradients(
            zip(gradients_of_discriminator, self.disc.trainable_variables)
        )

        # Monitor loss.
        self.gen_loss_tracker.update_state(gen_loss)
        self.disc_loss_tracker.update_state(disc_loss)
        self.total_loss_tracker.update_state(gen_loss + disc_loss)
        return {
            "g_loss": self.gen_loss_tracker.result(),
            "d_loss": self.disc_loss_tracker.result(),
            "d_total_loss": self.total_loss_tracker.result(),
        }

# Number of scalar values in one image (height * width * channels).
IMAGE_UNITS = input_shape[0] * input_shape[1] * input_shape[2]

# Generator: image-shaped noise (see gen_noise) -> one image of shape input_shape.
# Flatten + Reshape make the Dense stack act on the whole image. Without them,
# Dense layers act per pixel, and a final Dense(2) yields (batch, 32, 32, 2)
# tensors that cannot be concatenated with real (batch, 32, 32, 3) images.
gen = tf.keras.Sequential([
    tf.keras.layers.Input(shape=input_shape),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(units=128, activation=tf.nn.leaky_relu),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(units=64, activation=tf.nn.leaky_relu),
    # tanh bounds outputs to [-1, 1]; real images should be scaled to that same
    # range before they reach the discriminator.
    tf.keras.layers.Dense(units=IMAGE_UNITS, activation="tanh"),
    tf.keras.layers.Reshape(input_shape),
], name="generator")

# Discriminator: image -> a single real/fake logit per image, shape (batch, 1).
disc = tf.keras.Sequential([
    tf.keras.layers.Input(shape=input_shape),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(units=128, activation=tf.nn.leaky_relu),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(units=64, activation=tf.nn.leaky_relu),
    # Raw logit (no activation): paired with BinaryCrossentropy(from_logits=True).
    tf.keras.layers.Dense(units=1),
], name="discriminator")

cond_gan = GAN(
    discriminator=disc, generator=gen
)

cond_gan.compile(
    d_optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.001),
    g_optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.001),
    loss_fn=tf.keras.losses.BinaryCrossentropy(from_logits=True),
)

# train_labels are forwarded to train_step but play no part in its losses.
# Keep the History object so the per-epoch losses can be plotted afterwards.
history = cond_gan.fit(train_images, train_labels, epochs=30, batch_size=64)
