A **GAN (Generative Adversarial Network)** is a machine learning framework for generating new data that resembles a given dataset. It consists of two neural networks, a **Generator** and a **Discriminator**, that are trained together in an adversarial process.

1. **Generator**:
   - The Generator creates synthetic data (e.g., images or text) from random noise.
   - Its goal is to produce data that is indistinguishable from real data.
   - Example:
     - Input: A vector of random numbers (noise).
     - Output: A generated image resembling real images from the dataset.

2. **Discriminator**:
   - The Discriminator is a classifier that distinguishes between real data (from the dataset) and fake data (generated by the Generator).
   - Its goal is to correctly identify whether the input is real or fake.

### How GANs Work
The Generator and Discriminator are trained simultaneously but with opposing goals:
- The **Generator** tries to "fool" the Discriminator into believing that its outputs are real.
- The **Discriminator** tries to correctly identify real versus fake inputs.
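
These opposing goals are commonly written as a single minimax objective. The notation here is an assumption for illustration and is not used elsewhere in this notebook: $D(x)$ is the Discriminator's estimated probability that $x$ is real, $G(z)$ is the Generator's output for a noise vector $z$, and $p_{\text{data}}$ and $p_z$ are the data and noise distributions.

$$
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]
$$

The Discriminator increases $V$ by scoring real samples near 1 and generated samples near 0, while the Generator decreases $V$ by making $D(G(z))$ large.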

### The Training Process
1. **Train the Discriminator**:
   - Use real data labeled as "real" and fake data (from the Generator) labeled as "fake."
   - The Discriminator minimizes its classification error.
2. **Train the Generator**:
   - Generate fake data and pass it to the Discriminator.
   - Update the Generator to maximize the Discriminator's error on these samples, i.e., to make the Discriminator label them as "real." The Discriminator's weights are held fixed during this step.

Over time, the Generator gets better at producing realistic data while the Discriminator gets better at telling real from fake. Training ideally continues until the Discriminator can do no better than chance on generated samples, at which point the generated data is hard to tell apart from real data.
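
The two training steps above can be sketched as a pair of loss functions. This is a minimal illustration in plain NumPy, not the implementation used later in this notebook: the helper names (`bce`, `discriminator_loss`, `generator_loss`) and the sample Discriminator outputs are hypothetical. It assumes the Discriminator outputs a single probability that a sample is real, and it trains the Generator against "real" targets, which is one common way to express "maximize the Discriminator's error."

```python
import numpy as np


def bce(probs, targets, eps=1e-12):
    """Mean binary cross-entropy between predicted probabilities and 0/1 targets."""
    probs = np.clip(probs, eps, 1.0 - eps)  # avoid log(0)
    return float(-np.mean(targets * np.log(probs) + (1 - targets) * np.log(1 - probs)))


def discriminator_loss(d_real, d_fake):
    """The Discriminator is trained to output 1 on real samples and 0 on generated ones."""
    return 0.5 * (bce(d_real, np.ones_like(d_real)) + bce(d_fake, np.zeros_like(d_fake)))


def generator_loss(d_fake):
    """The Generator is trained so that the Discriminator outputs 1 ("real") on its samples."""
    return bce(d_fake, np.ones_like(d_fake))


# Hypothetical Discriminator outputs: estimated probability that each sample is real.
d_real = np.array([0.9, 0.8, 0.95])  # confident that the real samples are real
d_fake = np.array([0.1, 0.2, 0.05])  # confident that the generated samples are fake

d_loss = discriminator_loss(d_real, d_fake)
g_loss = generator_loss(d_fake)
print("D loss: %.3f, G loss: %.3f" % (d_loss, g_loss))
```

With these made-up outputs the Discriminator is winning, so its loss is small and the Generator's loss is large. As the Generator improves, `d_fake` drifts toward 0.5 and the two losses move closer together.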

### Applications of GANs
1. **Image Generation**: Generating realistic images (e.g., faces, landscapes, objects).
2. **Data Augmentation**: Creating synthetic data for training models in low-data scenarios.
3. **Style Transfer**: Applying artistic styles to images (e.g., turning a photo into a painting style).
4. **Super-Resolution**: Enhancing the resolution of images.
5. **Text-to-Image Generation**: Generating images based on textual descriptions.
6. **Video Generation**: Creating realistic videos or animations.

### Example
For a GAN trained on the **MNIST dataset** (handwritten digits):
- **Input to Generator**: Random noise vector (e.g., 100 dimensions).
- **Output from Generator**: An image resembling a handwritten digit.
- **Input to Discriminator**: Either a real image (from the MNIST dataset) or a fake image (from the Generator).
- **Output from Discriminator**: The estimated probability that the input is real rather than fake (the implementation below outputs a two-class softmax over [real, fake]).
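
The shapes involved in this MNIST setup can be checked with a small NumPy sketch. The `toy_generator` and `toy_discriminator` functions below are hypothetical, untrained stand-ins (a single random linear layer each), and the Discriminator here returns one sigmoid probability per image for simplicity. They only illustrate how data flows between the two networks.

```python
import numpy as np

rng = np.random.default_rng(0)
latent_dim, img_rows, img_cols = 100, 28, 28
img_dim = img_rows * img_cols  # 784 pixels per flattened image
batch_size = 16

# Untrained stand-in weights, for illustration only.
W_g = rng.normal(0, 0.02, (latent_dim, img_dim))
W_d = rng.normal(0, 0.02, (img_dim, 1))


def toy_generator(z):
    """Map noise of shape (batch, latent_dim) to flattened images in [-1, 1]."""
    return np.tanh(z @ W_g)


def toy_discriminator(x):
    """Map flattened images of shape (batch, img_dim) to P(real) via a sigmoid."""
    return 1.0 / (1.0 + np.exp(-(x @ W_d)))


noise = rng.normal(0, 1, (batch_size, latent_dim))
fake_imgs = toy_generator(noise)
p_real = toy_discriminator(fake_imgs)

print(noise.shape)                                      # (16, 100)
print(fake_imgs.shape)                                  # (16, 784)
print(fake_imgs.reshape(-1, img_rows, img_cols).shape)  # (16, 28, 28)
print(p_real.shape)                                     # (16, 1)
```

The `tanh` output range of [-1, 1] is why real images are usually rescaled to the same range before being shown to the Discriminator.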

GANs are widely used for generative tasks, especially in computer vision and creative applications.

In [None]:
from __future__ import print_function, division
from sklearn import datasets
import math
import matplotlib.pyplot as plt
import numpy as np
import progressbar

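from sklearn.datasets import fetch_openml


def fetch_mldata(name):
    """Stand-in for sklearn's fetch_mldata, which was removed in scikit-learn 0.22.

    Only MNIST is supported; the data is fetched from OpenML instead.
    """
    if name != 'MNIST original':
        raise ValueError("Unsupported dataset: %r" % name)
    # as_frame=False returns NumPy arrays (784 pixel values in [0, 255] per image)
    return fetch_openml('mnist_784', version=1, as_frame=False)
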
import nbimporter
from optimizers import Adam
from loss_functions import CrossEntropy
from layers import Dense, Dropout, Flatten, Activation, Reshape, BatchNormalization
from neural_network import NeuralNetwork


class GAN():
    """A Generative Adversarial Network with deep fully-connected neural nets as
    Generator and Discriminator.

    Training Data: MNIST Handwritten Digits (28x28 images)
    """
    def __init__(self):
        self.img_rows = 28 
        self.img_cols = 28
        self.img_dim = self.img_rows * self.img_cols
        self.latent_dim = 100

        optimizer = Adam(learning_rate=0.0002, b1=0.5)
        loss_function = CrossEntropy

        # Build the discriminator
        self.discriminator = self.build_discriminator(optimizer, loss_function)

        # Build the generator
        self.generator = self.build_generator(optimizer, loss_function)

        # Build the combined model
        self.combined = NeuralNetwork(optimizer=optimizer, loss=loss_function)
        self.combined.layers.extend(self.generator.layers)
        self.combined.layers.extend(self.discriminator.layers)

        print ()
        self.generator.summary(name="Generator")
        self.discriminator.summary(name="Discriminator")

    def build_generator(self, optimizer, loss_function):
        
        model = NeuralNetwork(optimizer=optimizer, loss=loss_function)

        model.add(Dense(256, input_shape=(self.latent_dim,)))
        model.add(Activation('leaky_relu'))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(Activation('leaky_relu'))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(Activation('leaky_relu'))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(self.img_dim))
        model.add(Activation('tanh'))

        return model

    def build_discriminator(self, optimizer, loss_function):
        
        model = NeuralNetwork(optimizer=optimizer, loss=loss_function)

        model.add(Dense(512, input_shape=(self.img_dim,)))
        model.add(Activation('leaky_relu'))
        model.add(Dropout(0.5))
        model.add(Dense(256))
        model.add(Activation('leaky_relu'))
        model.add(Dropout(0.5))
        model.add(Dense(2))
        model.add(Activation('softmax'))

        return model

    def train(self, n_epochs, batch_size=128, save_interval=50):

        mnist = fetch_mldata('MNIST original')

        X = mnist.data
        y = mnist.target

        # Rescale pixel values from [0, 255] to [-1, 1] to match the generator's tanh output range
        X = (X.astype(np.float32) - 127.5) / 127.5

        half_batch = batch_size // 2

        # Note: each "epoch" below is a single batch update, not a full pass over the dataset
        for epoch in range(n_epochs):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            self.discriminator.set_trainable(True)

            # Select a random half batch of images
            idx = np.random.randint(0, X.shape[0], half_batch)
            imgs = X[idx]

            # Sample noise to use as generator input
            noise = np.random.normal(0, 1, (half_batch, self.latent_dim))

            # Generate a half batch of images
            gen_imgs = self.generator.predict(noise)

            # Valid = [1, 0], Fake = [0, 1]
            valid = np.concatenate((np.ones((half_batch, 1)), np.zeros((half_batch, 1))), axis=1)
            fake = np.concatenate((np.zeros((half_batch, 1)), np.ones((half_batch, 1))), axis=1)

            # Train the discriminator
            d_loss_real, d_acc_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake, d_acc_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            d_loss = 0.5 * (d_loss_real + d_loss_fake)
            d_acc = 0.5 * (d_acc_real + d_acc_fake)


            # ---------------------
            #  Train Generator
            # ---------------------

            # Freeze the discriminator so that only the generator's weights are updated
            # when training the combined model
            self.discriminator.set_trainable(False)

            # Sample noise and use as generator input
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            # The generator wants the discriminator to label the generated samples as valid
            valid = np.concatenate((np.ones((batch_size, 1)), np.zeros((batch_size, 1))), axis=1)

            # Train the generator
            g_loss, g_acc = self.combined.train_on_batch(noise, valid)

            # Display the progress
            print ("%d [D loss: %f, acc: %.2f%%] [G loss: %f, acc: %.2f%%]" % (epoch, d_loss, 100*d_acc, g_loss, 100*g_acc))

            # If at save interval => save generated image samples
            if epoch % save_interval == 0:
                self.save_imgs(epoch)

    def save_imgs(self, epoch):
        r, c = 5, 5 # Grid size
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        # Generate images and reshape to image shape
        gen_imgs = self.generator.predict(noise).reshape((-1, self.img_rows, self.img_cols))

        # Rescale images from [-1, 1] to [0, 1] for plotting
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        plt.suptitle("Generative Adversarial Network")
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt,:,:], cmap='gray')
                axs[i,j].axis('off')
                cnt += 1
        fig.savefig("mnist_%d.png" % epoch)
        plt.close()


if __name__ == '__main__':
    gan = GAN()
    gan.train(n_epochs=200000, batch_size=64, save_interval=400)

