# Plan
## This notebook
1.   Walk through the boilerplate code
2.   You write your first 1D GAN

## Teaser for the next seminar
1. Adapt the 1D GAN for 5D conditional BaBar DIRC generation
2. Make it into a Wasserstein GAN

__This notebook__ will guide you through a very simple case of generative adversarial networks. Like... veeery simple. Generative adversarial networks learn to sample from distributions, and here we will solve the easiest imaginable task: learning to sample from a 1D normal distribution.

This notebook features a lot of useful visualizations that will help you both get acquainted with the behavior of the two networks and debug common errors without spending hours of GPU time.

Inherited from https://github.com/yandexdataschool/mlhep2019/blob/master/notebooks/day-6/06_1D_GAN.ipynb

# Task 1 (difficulty: the easiest GAN we managed to invent)

In [0]:
%matplotlib inline
import matplotlib.pyplot as plt
from IPython.display import clear_output
import numpy as np
from scipy.stats import kstest, norm
import torch
import torch.nn as nn
import torch.nn.functional as F

In [0]:
LATENT_DIM = 4
MEAN = 5
STD = 1.5
DATA_DISTRIBUTION = norm(loc=MEAN, scale=STD)
def sample_noise(batch_size):
    """Returns uniform noise of shape [batch_size, LATENT_DIM] in range [0, 1)"""
    return torch.rand(batch_size, LATENT_DIM, device="cuda")

def sample_real_data(batch_size):
    """
    Returns a sample of the target distribution:
    Gaussian(mu=MEAN, std=STD) of shape [batch_size, 1]
    """
    return torch.randn(batch_size, 1, device="cuda") * STD + MEAN
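
A quick CPU sanity check of the two samplers, assuming only that `torch` is installed. The cell above places tensors on `"cuda"`, but this sketch does not. It reuses the recipes of `sample_noise` and `sample_real_data` and checks two things. Shifting and scaling standard normal draws should give a mean close to `MEAN` and a standard deviation close to `STD`. `torch.rand` should stay in [0, 1).

```python
import torch

# Assumption: CPU tensors and a fixed seed, so the check runs without a GPU.
torch.manual_seed(0)
MEAN, STD, LATENT_DIM = 5, 1.5, 4

# Same recipe as sample_real_data: shift and scale standard normal draws.
real = torch.randn(100_000, 1) * STD + MEAN
# Same recipe as sample_noise: uniform draws in [0, 1).
noise = torch.rand(100_000, LATENT_DIM)

print(f"real:  mean={real.mean().item():.3f}, std={real.std().item():.3f}")
print(f"noise: min={noise.min().item():.3f}, max={noise.max().item():.3f}")
```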

In [0]:
# Generator converts noise into 1D data
gen = nn.Sequential(nn.Linear(LATENT_DIM, 16), nn.ELU(),
                    nn.Linear(16, 16), nn.ELU(),
                    nn.Linear(16, 1)).cuda()
gen_opt = torch.optim.SGD(gen.parameters(), lr=1e-2)

# Discriminator converts data into a single number (a logit), whose
# sigmoid is the probability of the example being real.
# It is deliberately made stronger than the generator to make sure the
# discriminator is slightly "ahead in the game".
disc = nn.Sequential(nn.Linear(1, 32), nn.ELU(),
                     nn.Linear(32, 32), nn.ELU(),
                     nn.Linear(32, 1)).cuda()
disc_opt = torch.optim.SGD(disc.parameters(), lr=1e-2)
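
One rough way to see what "stronger" means here is to count trainable parameters. The sketch below rebuilds both architectures on the CPU and counts weights and biases. The names `gen_cpu`, `disc_cpu` and `n_params` are illustrative and do not appear elsewhere in the notebook. Parameter count is only a crude proxy for capacity.

```python
import torch.nn as nn

LATENT_DIM = 4

# CPU copies of the two architectures above (no .cuda()), used only for counting.
gen_cpu = nn.Sequential(nn.Linear(LATENT_DIM, 16), nn.ELU(),
                        nn.Linear(16, 16), nn.ELU(),
                        nn.Linear(16, 1))
disc_cpu = nn.Sequential(nn.Linear(1, 32), nn.ELU(),
                         nn.Linear(32, 32), nn.ELU(),
                         nn.Linear(32, 1))

def n_params(model):
    """Total number of trainable parameters (weights and biases)."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(n_params(gen_cpu), n_params(disc_cpu))  # → 369 1153
```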

OK, now that our models are defined, we need our loss functions. Historically, the first loss used in GANs was the binary cross-entropy, which we have already used many times:
$$\mathscr{L}^{\text{discr}} =
-\mathbb{E}\log\left[P(\text{real is real})\right]
-\mathbb{E}\log\left[P(\text{fake is fake})\right]
=
-\mathbb{E}\left[\log D(x_\text{real})\right]
-\mathbb{E}\left[\log (1 - D(x_\text{fake}))\right] 
$$
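
The discriminator loss above is the standard binary cross-entropy, with label 1 for real examples and 0 for fake ones. The sketch below checks this on made-up logits, with no networks involved. `logits_real` and `logits_fake` are hypothetical stand-ins for the raw discriminator outputs. It shows that writing the loss with `F.logsigmoid` agrees with PyTorch's `F.binary_cross_entropy_with_logits`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Hypothetical raw discriminator outputs (logits) for real and fake batches.
logits_real = torch.randn(8, 1)
logits_fake = torch.randn(8, 1)

# L_discr = -E[log D(x_real)] - E[log(1 - D(x_fake))], with D = sigmoid(logit)
# and log(1 - sigmoid(l)) = logsigmoid(-l).
loss_manual = -F.logsigmoid(logits_real).mean() - F.logsigmoid(-logits_fake).mean()

# Same quantity via the built-in BCE on logits: label 1 for real, 0 for fake,
# each term averaged over its own batch (default reduction="mean").
loss_bce = (F.binary_cross_entropy_with_logits(logits_real, torch.ones_like(logits_real))
            + F.binary_cross_entropy_with_logits(logits_fake, torch.zeros_like(logits_fake)))

print(loss_manual.item(), loss_bce.item())
```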

The generator loss presents a choice. We can either maximize the likelihood that generated examples are labeled as real, or minimize the likelihood that they are labeled as fake.

$$\mathscr{L}^{\text{gen}} =
-\mathbb{E}\log\left[P(\text{fake is real})\right] =
-\mathbb{E}\left[\log (D(x_\text{fake}))\right]$$

OR

$$\mathscr{L}^{\text{gen}} =
\mathbb{E}\log\left[P(\text{fake is fake})\right] =
\mathbb{E}\left[\log (1 - D(x_\text{fake}))\right]$$
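
You can use a small experiment to probe the question below. It evaluates both candidate generator losses on a few hypothetical discriminator logits $l$, with $D = \sigma(l)$, and prints their gradients with respect to $l$. The generator only receives gradient through this quantity. Look in particular at the first entry, where the discriminator confidently rejects the sample.

```python
import torch
import torch.nn.functional as F

# Hypothetical discriminator logits for generated samples, from
# "confidently fake" (very negative) to "confidently real" (positive).
logits = torch.tensor([-6.0, -2.0, 0.0, 2.0], requires_grad=True)

# Option 1: minimize -E[log D(x_fake)] = -E[logsigmoid(l)]
loss_1 = -F.logsigmoid(logits).sum()
grad_1, = torch.autograd.grad(loss_1, logits)

# Option 2: minimize E[log(1 - D(x_fake))] = E[logsigmoid(-l)]
loss_2 = F.logsigmoid(-logits).sum()
grad_2, = torch.autograd.grad(loss_2, logits)

for l, g1, g2 in zip(logits.detach().tolist(), grad_1.tolist(), grad_2.tolist()):
    print(f"logit={l:+.1f}  dL1/dl={g1:+.4f}  dL2/dl={g2:+.4f}")
```

`.sum()` is used instead of `.mean()` only so that each printed gradient belongs to a single logit.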

**Question to you.** Which way do you think is better?

Note that here $D(x)$ is the probability that the discriminator assigns to $x$ being from the real dataset, so don't forget to apply a sigmoid to the network's output: $D(x) = \sigma(\text{discriminator}(x))$. Try implementing these loss functions below. Use `F.logsigmoid` as a numerically stable implementation of $\log\left[\sigma(x)\right]$, and note that $1 - \sigma(x) = \sigma(-x)$.
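
Here is why `F.logsigmoid` is preferable to `torch.log(torch.sigmoid(...))`. For a large negative logit, `sigmoid` underflows in float32 and its log is no longer accurate. `F.logsigmoid` returns the correct finite value. The sketch also checks the identity $\log(1 - \sigma(x)) = \log\sigma(-x)$ numerically for moderate $x$.

```python
import torch
import torch.nn.functional as F

x = torch.tensor([-200.0, -5.0, 0.0, 5.0])

naive = torch.log(torch.sigmoid(x))  # sigmoid(-200) underflows in float32, so its log is wrong
stable = F.logsigmoid(x)             # computed without forming sigmoid(x) first
print(naive)
print(stable)

# 1 - sigmoid(x) = sigmoid(-x), hence log(1 - D) = logsigmoid(-logit)
print(torch.allclose(torch.log(1 - torch.sigmoid(x[1:])), F.logsigmoid(-x[1:]), atol=1e-5))
```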

In [0]:
def train_disc(batch_size):
    """
    Trains the discriminator for one step. Please note
    this is not a pure function: it captures most of its variables
    from the enclosing scope.
    """

    # compute log[p(real | x)]
    real_data = sample_real_data(batch_size)
    logp_real_is_real = F.logsigmoid(disc(real_data))

    # there is a function at the beginning of the notebook...
    # noise = <sample noise>
    # YOUR CODE HERE
    raise NotImplementedError()

    # supply the generated noise sample as input to the
    # generator network
    # gen_data = <generate data given noise>
    # YOUR CODE HERE
    raise NotImplementedError()

    # logp_gen_is_fake = <compute log[p(example is fake)]>
    # YOUR CODE HERE
    raise NotImplementedError()
    # Note that the sigmoid of the discriminator output is
    # the probability that the example is real. The definition of sigmoid
    # https://en.wikipedia.org/wiki/Sigmoid_function
    # and the text above this cell
    # are your friends here for an elegant expression.

    # You want the discriminator to maximize the probabilities
    # of real examples being labeled real and fake examples being labeled fake
    # L_D = -mean(log[p(real is real)]) - mean(log[p(generated is fake)])
    # disc_loss = <disc_loss>
    # YOUR CODE HERE
    raise NotImplementedError()

    # SGD step. We zero_grad first to clear any gradients left from generator training
    disc_opt.zero_grad()
    disc_loss.backward()
    disc_opt.step()


In [0]:
def train_gen(batch_size):
    """Trains generator for one step"""

    # sample the generator
    # noise = <sample noise>
    # gen_data = <generate data given noise>
    # YOUR CODE HERE
    raise NotImplementedError()

    # compute log[p(generated data is real)]
    # logp_gen_is_real = <compute log(p(generated example is REAL))>
    # YOUR CODE HERE
    raise NotImplementedError()

    # The generator training wants to maximize the probability that
    # the generated examples are labeled real by the discriminator
    # gen_loss = <generator loss>
    # YOUR CODE HERE
    raise NotImplementedError()

    gen_opt.zero_grad()
    gen_loss.backward()
    gen_opt.step()

In [0]:
DISCRIMINATOR_ITERATIONS_PER_GENERATOR = 5
TRAIN_BATCH_SIZE = 128
VALIDATION_BATCH_SIZE = 2048
VALIDATION_INTERVAL = 256
HIST_BINS = 30
DATA_HIST_RANGE = [0, 10]

# Grid of x values on which the discriminator's P(x is real) is plotted
data_linspace_np = np.linspace(0, 10, num=256, dtype=np.float32)
data_linspace_torch = torch.from_numpy(data_linspace_np)[:, None].cuda()

for i in range(10000):
    for _ in range(DISCRIMINATOR_ITERATIONS_PER_GENERATOR):
        train_disc(TRAIN_BATCH_SIZE)

    train_gen(TRAIN_BATCH_SIZE)

    if i % VALIDATION_INTERVAL == 0:
        clear_output(True)
        validation_noise = sample_noise(VALIDATION_BATCH_SIZE)
        validation_data = sample_real_data(VALIDATION_BATCH_SIZE)
        validation_generated = gen(validation_noise)

        validation_generated_np = validation_generated.data.cpu().numpy()
        validation_data_np = validation_data.data.cpu().numpy()

        fig, (ax_data, ax_losses) = plt.subplots(ncols=2, figsize=[14, 6])
        ax_data.set_title("Data distributions")
        ax_data_plots = []
        ax_data_plots.append(ax_data.hist(validation_generated_np, range=DATA_HIST_RANGE,
                             alpha=0.5, density=True, label='Generated', bins=HIST_BINS)[2][0])
        ax_data_plots.append(ax_data.hist(validation_data_np, range=DATA_HIST_RANGE,
                             alpha=0.5, density=True, label='Real', bins=HIST_BINS)[2][0])
        ax_data.set_xlabel("x")
        ax_data.set_ylabel("data & generated distributions")
        disc_preal_np = torch.sigmoid(disc(data_linspace_torch)).data.cpu().numpy()
        ax_disc = ax_data.twinx()
        ax_data_plots.extend(ax_disc.plot(data_linspace_np, disc_preal_np, label="discriminator's P(x is real)"))
        ax_disc.set_ylabel("discriminator's P(x is real)")
        ax_data.legend(ax_data_plots, [o.get_label() for o in ax_data_plots], loc="best")

        ax_losses.set_title("Discriminator readout distribution")
        ax_losses.hist(torch.sigmoid(disc(validation_generated)).data.cpu().numpy(),
                       range=[0, 1], alpha=0.5, label='D(G(z)) // generated data', bins=HIST_BINS, density=True)
        ax_losses.hist(torch.sigmoid(disc(validation_data)).data.cpu().numpy(),
                       range=[0, 1], alpha=0.5, label='D(x) // real data', bins=HIST_BINS, density=True)
        ax_losses.legend()
        ax_losses.set_xlabel("Discriminator output")
        ks_result = kstest(validation_generated_np.ravel(), DATA_DISTRIBUTION.cdf)
        fig.suptitle(f"Iteration {i}, Kolmogorov-Smirnov statistic {ks_result.statistic:.3}, p-value {ks_result.pvalue:.3}")
        plt.show()

In [0]:
# Note: the test uses a 100x larger sample than the monitoring in the training loop
with torch.no_grad():
    validation_noise = sample_noise(100 * VALIDATION_BATCH_SIZE)
    validation_generated = gen(validation_noise)
validation_generated_np = validation_generated.cpu().numpy()
ks_result = kstest(validation_generated_np.ravel(), DATA_DISTRIBUTION.cdf)
assert ks_result.statistic < 0.05
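
The monitoring and the final test rely on the one-sample Kolmogorov-Smirnov statistic. It is the largest gap between the empirical CDF of the generated samples and the CDF of `DATA_DISTRIBUTION`. To get a feel for the 0.05 threshold, here is a sketch that uses NumPy samples instead of generator output. The sample size and the 0.5 shift are arbitrary choices. Samples drawn from the target itself stay well below the threshold. A mean shifted by a third of a standard deviation does not.

```python
import numpy as np
from scipy.stats import kstest, norm

MEAN, STD = 5, 1.5
target = norm(loc=MEAN, scale=STD)
rng = np.random.default_rng(0)

good = rng.normal(MEAN, STD, size=20_000)           # drawn from the target itself
shifted = rng.normal(MEAN + 0.5, STD, size=20_000)  # mean off by a third of a std

d_good = kstest(good, target.cdf).statistic
d_shifted = kstest(shifted, target.cdf).statistic
print(f"KS statistic, true samples:    {d_good:.4f}")
print(f"KS statistic, shifted samples: {d_shifted:.4f}")
```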

**Question to you.** Do you think adding dropout will improve the model?

__What to expect:__
* __left:__ the two distributions will start out different, but the generated distribution should eventually match the real data _almost_ everywhere. The curve shows the discriminator's opinion on all possible values of x. It should slowly approach 0.5 over areas where real data is dense.
* __right:__ this chart shows how often the discriminator assigns a given probability to samples from the real and generated data (shown in different colors). The first several iterations will vary, but eventually both will have nearly all of their probability mass around 0.5 as the generator gets better at its job.
  * If instead it converges to two delta functions at 0 (generated) and 1 (real), your discriminator has won. _Check the generator loss function_. As a last resort, try decreasing the discriminator learning rate. This can also happen if you replace the mean over the batch with a sum or similar.
  * If it converges to 0.5 and stays there for several iterations but the generator hasn't learned to generate plausible data yet, the generator is winning the game. _Double-check the discriminator loss function_.
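
Why 0.5? Take a fixed generator and equally many real and generated samples. The discriminator that minimizes the cross-entropy loss is then $D^*(x) = p_\text{data}(x) / (p_\text{data}(x) + p_\text{gen}(x))$, which equals 0.5 wherever the two densities agree. The sketch below evaluates this formula for the target $\mathcal N(5, 1.5^2)$ and two hypothetical generator distributions. A trained network only approximates $D^*$, and only where it has seen data.

```python
import numpy as np
from scipy.stats import norm

def optimal_disc(p_data, p_gen):
    """Cross-entropy-optimal P(x is real) for balanced real/generated batches."""
    return p_data / (p_data + p_gen)

x = np.linspace(0, 10, 5)               # a few points of the plotted range
p_data = norm(loc=5, scale=1.5).pdf(x)  # target density

# Hypothetical generator that is still shifted to the left of the data
d_off = optimal_disc(p_data, norm(loc=3, scale=1.5).pdf(x))
# Hypothetical generator that matches the data exactly
d_matched = optimal_disc(p_data, norm(loc=5, scale=1.5).pdf(x))

print(np.round(d_off, 3))
print(np.round(d_matched, 3))
```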

__Reference plots for a trained GAN__:
![Reference plots for a trained GAN](https://github.com/yandexdataschool/mlhep2019/raw/master/notebooks/day-6/trained_1D_GAN.png)

# Task 1.X Bonus (difficulty: we don't guarantee it's possible)
The GAN above transforms a 4D latent space into a 1D Gaussian.
Make a GAN that transforms the 1D $\mathbb U[0, 1]$ space into the same 1D Gaussian.
P.S. This is a trivial task for, say, a quantile transform. But we speak GAN here.
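
For reference, here is the non-GAN baseline mentioned in the P.S. Pushing $\mathbb U[0, 1]$ samples through the inverse CDF of the target (`scipy.stats.norm.ppf`) gives exactly the desired Gaussian. This is one monotone map that a 1D generator could aim to learn. The sketch uses only NumPy and SciPy, with a fixed seed.

```python
import numpy as np
from scipy.stats import kstest, norm

MEAN, STD = 5, 1.5
rng = np.random.default_rng(0)

u = rng.uniform(0.0, 1.0, size=20_000)  # 1D U[0, 1) "latent" samples
x = norm.ppf(u, loc=MEAN, scale=STD)    # inverse CDF of N(MEAN, STD^2) applied to each sample

ks_stat = kstest(x, norm(loc=MEAN, scale=STD).cdf).statistic
print(f"KS statistic: {ks_stat:.4f}")
```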

# P.S.
If you feel stuck, feel free to take a look at a [completed implementation of JS GAN](https://github.com/kazeevn/mnist-gan-demo/blob/master/GAN%20training.ipynb), which we used to produce the [demo](https://en.pelican.study/classroom/213/dialogs/2619/run/). Even if you finish the assignment, you might still want to take a look: most GAN applications involve images, and the MNIST GAN is an image GAN.