# Introduction to Generative Adversarial Networks (GANs)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
def show_img(path, title, figsize=(12,6)):
    plt.figure(figsize=figsize)
    img = mpimg.imread(path)
    imgplot = plt.imshow(img)
    plt.axis('off')
    plt.title(title)
    plt.show()

In this notebook, we will become familiar with an advanced deep learning technique called Generative Adversarial Networks (GANs). The notebook covers the background, applications, concept, algorithm, architecture, evaluation, and a code demo of typical GANs. We focus on only two typical GANs: the original <a href='http://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf'>GANs</a> and the Deep Convolutional GAN (<a href='https://arxiv.org/pdf/1511.06434.pdf'>DCGAN</a>).

Generative Adversarial Networks (GANs) are a cutting-edge deep neural network technique, first proposed by Ian Goodfellow et al. in 2014. In 2016, Yann LeCun (director of Facebook AI Research) described GANs as
> “the most interesting idea in the last 10 years in Machine Learning.”

GANs are a very young technique with a promising future. Especially in the last two years, the number of GAN papers has grown exponentially (Fig.1). Although the technique is still in its infancy, many models with the suffix “GAN” have been proposed, such as ACGAN, DCGAN, WGAN, BEGAN, CycleGAN, and StackGAN. A website called [“The GAN Zoo”](https://github.com/hindupuravinash/the-gan-zoo) lists hundreds of GAN variants.

In [None]:
show_img('../input/introductiongan/introduction-gan/cumulative_gans.jpg', 'Fig.1 The GAN papers counts by The GAN Zoo')

# Application
### Image Generation
Most of the deep learning techniques we learned previously are used for recognition. GANs let the machine learn to generate new data. Note that a GAN does not simply memorize the given dataset. Generation is generally harder than recognition. For example, we first learned to recognize the digits 0-9, and only then tried to mimic their shapes and write digits in our own style; the latter is generation.
As discussed in Chapter 20 of *Deep Learning*, there are many other generative models, such as the Restricted Boltzmann Machine (RBM), Generative Stochastic Networks (GSNs), and the Variational Autoencoder (VAE). Compared with these generative models, GANs achieve state-of-the-art performance in the field of image generation. In other words, among these models, GANs produce the most realistic images.

In [None]:
show_img('../input/introductiongan/introduction-gan/began_face.jpg', 'Fig.2 Generated facial images by BEGAN', (8,4))

### Image-to-Image Translation
A GAN can learn the features of two image collections and translate images from one domain to the other. CycleGAN is a representative unsupervised GAN with this functionality. Note that CycleGAN does not require the two datasets to be paired. For example, a paired dataset for Summer<->Winter translation would require photos of the same scene taken in different seasons, and such paired datasets are expensive to collect and often unavailable. Without paired datasets, CycleGAN can translate landscape photos into a particular painting style, such as Monet or Van Gogh. It can also translate zebras into horses.

In [None]:
show_img('../input/introductiongan/introduction-gan/cyclegan.jpg', 'Image-to-Image translation by CycleGAN')

### Text-to-Image Translation 
If we use text information as the label (condition) of the corresponding images, a GAN can learn the mapping between text embedding vectors and image features. StackGAN is a conditional GAN that generates images from text descriptions.

In [None]:
show_img('../input/introductiongan/introduction-gan/stackgan.jpg', 'Text-to-Image translation by StackGAN')

### Image Manipulation
There are many other applications in image manipulation, for example:
- Super resolution: recovering photo-realistic texture for a low-resolution image.
- Photo inpainting: filling in the missing area of a given image.

In [None]:
show_img('../input/introductiongan/introduction-gan/srgan.png', 'Super resolution by SRGAN')

In [None]:
show_img('../input/introductiongan/introduction-gan/inpainting.png', 'Photo inpainting with GAN')

### Section Summary
The underlying functionality of a GAN is transforming a data distribution from domain X to domain Y. If domain X is a normal (noise) distribution, the GAN generates images. If domain X is another image distribution, e.g. images in a different style, lower-resolution images, or masked images, the GAN transforms input images X -> target images Y. If domain X is a text data distribution, the GAN generates images based on the input text query. After learning the GAN technique, you are highly encouraged to come up with creative ideas for further applications.

# Concept
### What is a GAN?
It is like a zero-sum game in game theory (Example 1): one player's gain is the opponent's loss. In a typical GAN, the “Criminal” is the Generator G and the “Investigator” is the Discriminator D. In this game, G tries to generate realistic images to fool D, while D tries to identify all fake images. The two models compete with each other and, ideally, eventually reach a Nash equilibrium, where neither G nor D can improve any further by changing only its own strategy.

---
**Example 1**. How to be a master of producing fake dollars?  
*Criminal*: Produced the 1st version of fake dollars, and the *investigator* could not detect them.  
*Investigator*: Examined the new fake dollars and learned to detect all 1st version fake dollars.  
*Criminal*: Produced the 2nd version of fake dollars, and the *investigator* could not detect them.  
*Investigator*: Examined the new fake dollars and learned to detect all 2nd version fake dollars.  
...  
n-th version fake dollars  
...  
⇒ *Criminal*: An expert in making fake dollars.  
⇒ *Investigator*: An expert in detecting fake dollars.  

---

### Why does a GAN work?
If a model only learns the individual data points rather than the whole distribution, it can only memorize the dataset. A model can generate new data samples only when it learns the whole distribution. A GAN uses a zero-sum game to help G approximate the real distribution. Then, we can sample random points from the learned distribution as generated data.

### Objective function
- **Discriminator**  

Discriminator $D$ tries to identify all fake images. Intuitively, we can sketch an objective function for $D$:
$$\max~J^{(D)} = \mathbb{E}_{x_r \sim X_r}[\textrm{score}(x_r)] - \mathbb{E}_{x_f \sim X_f}[\textrm{score}(x_f)] \tag{1}$$
where $x_r$ is sampled from the real distribution $X_r$, $x_f$ is sampled from the fake distribution $X_f$, and $\textrm{score}()$ denotes an evaluation function that gives high scores to real images and low scores to fake images. In effect, Discriminator $D$ is a binary classifier that distinguishes real from fake.  
Let's replace the $\textrm{score}()$ function with $D$. If the output $D(x)$ is an unbounded real number, maximizing $J^{(D)}$ pushes $D(x_r) \rightarrow \infty$ and $D(x_f) \rightarrow -\infty$. However, an objective that diverges, $J^{(D)} \rightarrow \infty$, cannot be optimized meaningfully. Thus, we prefer to map the output range to a probability, $(-\infty, \infty) \mapsto (0, 1)$, by adding a sigmoid transfer function to the output layer.
$$\textrm{sigmoid}(x) \equiv \dfrac{1}{1+e^{-x}}$$  
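As a quick illustration (our own NumPy sketch, not part of the original notebook; the helper names `sigmoid` and `softmax` are ours), the cell below shows the sigmoid squashing arbitrary logits into $(0, 1)$, and numerically checks the equivalence stated below: a two-class softmax over logits $(l_1, l_2)$ gives the same probability as $\textrm{sigmoid}(l_1 - l_2)$.

```python
import numpy as np

def sigmoid(x):
    """Map any real logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + np.exp(-x))

def softmax(logits):
    """Softmax over the last axis, shifted by the max for numerical stability."""
    z = logits - np.max(logits, axis=-1, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=-1, keepdims=True)

# Large negative logits map close to 0, large positive logits close to 1
x = np.array([-10.0, -1.0, 0.0, 1.0, 10.0])
print(sigmoid(x))

# Two-class softmax over logits (l_real, l_fake) equals sigmoid(l_real - l_fake)
rng = np.random.default_rng(0)
pair_logits = rng.normal(size=(5, 2))
p_softmax = softmax(pair_logits)[:, 0]
p_sigmoid = sigmoid(pair_logits[:, 0] - pair_logits[:, 1])
print(np.allclose(p_softmax, p_sigmoid))  # True
```

This is why a single sigmoid unit is enough for $D$: the second softmax logit adds no expressive power for two classes.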

In [None]:
show_img('../input/introductiongan/introduction-gan/d_simple.png', 'Discriminator with softmax/sigmoid output')

Discriminator with softmax/sigmoid output. In binary classification, these two choices are equivalent: a two-class softmax over logits $(l_1, l_2)$ assigns class 1 the probability $\textrm{sigmoid}(l_1 - l_2)$. For the following discussion, we use a sigmoid output in $D$ by default.  
With the sigmoid function, the output denotes the probability of being real, $D(x) = p(\textrm{real})$. We want the output probability for real images to be close to 1, $D(x_r) \rightarrow 1$, and the probability for fake images to be close to 0, $D(x_f) \rightarrow 0$. The latter can also be rewritten as $(1 - D(x_f)) \rightarrow 1$. Then, we modify the objective function (Eq.1) sketched at the beginning. Besides, maximizing the objective function (Eq.2a) is equivalent to minimizing its negative, which is called the loss function (Eq.2b).  
$$\max~J^{(D)} = \mathbb{E}_{x_r \sim X_r}[D(x_r)] + \mathbb{E}_{x_f \sim X_f}[1 - D(x_f)] \tag{2.a}$$
$$\Leftrightarrow \min~L^{(D)} =- \mathbb{E}_{x_r \sim X_r}[D(x_r)] - \mathbb{E}_{x_f \sim X_f}[1 - D(x_f)] \tag{2.b}$$  
The loss function (Eq.2b) is still based on mean absolute error. For classification, we usually prefer cross-entropy (CE) loss over mean absolute error (MAE, L1 loss) and mean squared error (MSE, L2 loss). As shown in Fig.3, CE loss mitigates the vanishing gradient problem when a sigmoid is applied to the output layer.

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Single sigmoid neuron: n = w * p + b, a = sigmoid(n), target t = 1
# error e = t - a = 1 - a, so a = 1 - e and da/dn = a * (1 - a) = e * (1 - e)
# The common factor p is omitted from all three gradients below
error = np.linspace(0, 1, 100)

# MSE: L = e^2
# dL/dw = 2 * e * de/dw = -2 * e * da/dw = -2 * e * a * (1 - a) * dn/dw
#       = -2 * e * e * (1 - e) * p            ... when t = 1
gradient_mse = -2 * error ** 2 * (1 - error)

# CE: L = -t * log(a)
# dL/dw = -t * (1 / a) * da/dw = -t * (1 / a) * a * (1 - a) * dn/dw
#       = -t * (1 - a) * p = -e * p           ... when t = 1
gradient_ce = -error

# MAE: L = |t - a| = e                        ... when t = 1
# dL/dw = -da/dw = -a * (1 - a) * dn/dw
#       = -(1 - e) * e * p
gradient_mae = -error * (1 - error)

plt.plot(error, -gradient_ce, label='CE')
plt.plot(error, -gradient_mse, label='MSE')
plt.plot(error, -gradient_mae, label='MAE')
plt.xlabel("absolute error")
plt.ylabel("norm of gradient")
plt.title("Fig.3 Compare Cross-entropy with MAE, MSE")
plt.legend()
plt.show()

MAE/MSE loss: the gradient vanishes when the error is large. CE loss: the norm of the gradient $\|\nabla\|$ grows with the error, so a large error leads to fast learning.
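To double-check the derivations in the Fig.3 cell, the sketch below (our own illustration, assuming a single sigmoid neuron with input $p = 1$, so that $dn/dw = 1$, and target $t = 1$) compares the closed-form gradients with central finite differences of each loss with respect to the pre-sigmoid value $n$.

```python
import numpy as np

def sigmoid(n):
    return 1.0 / (1.0 + np.exp(-n))

# Losses for target t = 1 as functions of the pre-sigmoid value n (p = 1, so dn/dw = 1)
losses = {
    "CE": lambda n: -np.log(sigmoid(n)),
    "MSE": lambda n: (1.0 - sigmoid(n)) ** 2,
    "MAE": lambda n: 1.0 - sigmoid(n),  # |t - a| = 1 - a because a < 1
}

# Closed-form gradients from the Fig.3 cell, written in terms of e = 1 - a
closed_form = {
    "CE": lambda e: -e,
    "MSE": lambda e: -2.0 * e ** 2 * (1.0 - e),
    "MAE": lambda e: -e * (1.0 - e),
}

n = np.linspace(-4.0, 4.0, 9)
e = 1.0 - sigmoid(n)
h = 1e-5
for name, loss in losses.items():
    numeric = (loss(n + h) - loss(n - h)) / (2 * h)  # central difference
    analytic = closed_form[name](e)
    print(name, "max abs difference:", np.max(np.abs(numeric - analytic)))
```

The differences should be at the level of floating-point noise, confirming that only the CE gradient stays large when the error $e$ approaches 1.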

> **Definition 1.** The cross-entropy of the distribution $q$ relative to a distribution $p$ over a given set is defined as follows:
> $$CE(p, q) = -\mathbb{E}_p[\log q]$$

Let's apply the CE loss into the loss function for $D$. For real images, $p$ is  $[1~0]^T$, $q$ is $[D(x_r)~1-D(x_r)]^T$. For fake images, $p$ is $[0~1]^T$, $q$ is $[D(x_f)~1-D(x_f)]^T$. The following equation is a typical binary cross-entropy (BCE) loss.

$$
\begin{align}
BCE & =
	- \mathbb{E}_{x_r \sim X_r}
	\left[
		\begin{bmatrix}
           1 & 0
    		\end{bmatrix}
   		 ~
    		\begin{bmatrix}
          \log D(x_r) \\
           \log (1 - D(x_r))
    		\end{bmatrix}
	\right]
	-
	\mathbb{E}_{x_f \sim X_f}
	\left[
		\begin{bmatrix}
           0 & 1
    		\end{bmatrix}
   		 ~
    		\begin{bmatrix}
           \log D(x_f) \\
           \log (1 - D(x_f))
    		\end{bmatrix}
	\right]
\end{align}
$$

With some algebra, the loss function for $D$ can be derived as below. For notation consistency, we replace $x_f, X_f$ with $x_g, X_g$, because from $D$'s perspective the generated images $x_g$ are exactly the fake images $x_f$. $\theta_D$ denotes the parameters of the $D$ model, i.e. the weights of the neural network. Our goal is to find the optimal weights that minimize the following BCE loss function.
$$\Rightarrow \underset{\theta_D} \min~L^{(D)} =- \mathbb{E}_{x_r \sim X_r}[\log D(x_r)] - \mathbb{E}_{x_g \sim X_g}[\log (1 - D(x_g))] \tag{3}$$
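A minimal NumPy sketch of Eq.3, assuming $D$ outputs a pre-sigmoid logit $l$ (as the discriminator in the DCGAN code at the end of this notebook does). It computes the loss once from probabilities, exactly as written in Eq.3, and once with the numerically stable identities $-\log \sigma(l) = \log(1 + e^{-l})$ and $-\log(1-\sigma(l)) = \log(1 + e^{l})$, which is the same quantity that `discriminator_loss` computes with `tf.nn.sigmoid_cross_entropy_with_logits`. The function names and sample logits are hypothetical.

```python
import numpy as np

def sigmoid(l):
    return 1.0 / (1.0 + np.exp(-l))

def d_loss_from_probs(d_real, d_fake):
    """Eq.3: -E[log D(x_r)] - E[log(1 - D(x_g))], with D(.) given as probabilities."""
    return -np.mean(np.log(d_real)) - np.mean(np.log(1.0 - d_fake))

def d_loss_from_logits(real_logits, fake_logits):
    """Same loss from logits; np.logaddexp(0, z) = log(1 + exp(z)) avoids overflow."""
    real_term = np.mean(np.logaddexp(0.0, -real_logits))  # -log sigmoid(l)
    fake_term = np.mean(np.logaddexp(0.0, fake_logits))   # -log(1 - sigmoid(l))
    return real_term + fake_term

# Hypothetical discriminator logits for a batch of real and generated images
real_logits = np.array([2.0, 1.5, 3.0, 0.5])
fake_logits = np.array([-1.0, -2.5, 0.2, -0.7])

print(d_loss_from_probs(sigmoid(real_logits), sigmoid(fake_logits)))
print(d_loss_from_logits(real_logits, fake_logits))

# A D that outputs 0.5 everywhere has loss 2 * log(2); a more confident, correct D is lower
print(2 * np.log(2), d_loss_from_logits(real_logits * 5, fake_logits * 5))
```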

> Any loss consisting of a negative log-likelihood is a cross-entropy between the empirical distribution defined by the training set and the probability distribution defined by model. For example, mean squared error is the cross-entropy between the empirical distribution and a Gaussian model.
-- Page 132, *Deep Learning*, 2016

The nature of Eq.3: cross-entropy measures the difference between two distributions (it equals the KL divergence plus the entropy of the reference distribution). The first term, $- \mathbb{E}_{x_r \sim X_r}[\log D(x_r)]$, is the CE between the label “real” and the distribution predicted by $D$ on real images from the training set; minimizing it pushes $D(x_r) \rightarrow 1$. The second term, $- \mathbb{E}_{x_g \sim X_g}[\log (1 - D(x_g))]$, is the CE between the label “fake” and the distribution predicted by $D$ on generated images; minimizing it pushes $D(x_g) \rightarrow 0$, i.e. $D$ learns to separate the generated distribution $X_g$ from the real distribution $X_r$.

- **Generator**

Generator $G$ tries to generate realistic images to fool $D$.
Similarly, we can sketch an objective function for $G$:
$$\max~J^{(G)} = \mathbb{E}_{x_g \sim X_g}[\textrm{score}(x_g)] \tag{4}$$
where $\textrm{score}()$ is again the $D$ function, and $x_g$ is sampled from the generated distribution $X_g$. Here, $G$ maps a standard normal distribution to the generated distribution, $N(0,1) \overset{G} \mapsto X_g$. In other words, we feed a noise vector $z \sim N(0,1)$ into $G$ and get a generated image $x_g = G(z)$.

In [None]:
show_img('../input/introductiongan/introduction-gan/g_simple.png', 'The training process for Generator', (8,4))

Since $D$ has been trained to classify images as real or fake, the objective of $G$ is to make $D$ classify the generated images as real.

Corresponding to $L^{(D)}$ (Eq.3), we also use CE loss on the generated images. One simple option is for $L^{(G)}$ to borrow the second term of $L^{(D)}$: $D$ maximizes $\mathbb{E}_{x_g \sim X_g}[\log (1 - D(x_g))]$ (equivalently, it minimizes the negative of this term in Eq.3), and $G$ can minimize the same term to compete with $D$. This is exactly a minimax game, as proposed in the original GAN paper. We refer to it (Eq.5a) as the Saturating loss function. However, the Saturating loss function cannot provide sufficient gradient for $G$ to learn. For example, in the early learning steps, $D$ can easily distinguish generated images from real images, so $D(x_g)$ is close to 0 and $\log(1-D(x_g))$ saturates near 0: its gradient with respect to $D$'s pre-sigmoid logit equals $-D(x_g)$, which vanishes. (Fig.4)
$$ \textrm{Saturating: } \underset{\theta_G} \min~L^{(G)} = \mathbb{E}_{x_g \sim X_g}[\log (1 - D(x_g))] \tag{5.a}$$
In the same paper, Goodfellow et al. also proposed a more stable and efficient loss function for $G$. We compute the CE of the generated images on their own, with $p = [1~0]^T$ (i.e. labeled as real) and $q = [D(x_g)~1-D(x_g)]^T$. We refer to it (Eq.5b) as the Non-Saturating loss function. With this loss function, $-\log D(x_g)$ is very large and does not saturate early in training: its gradient with respect to the logit is $-(1 - D(x_g)) \approx -1$ when $D(x_g) \approx 0$. (Fig.4)
$$ \textrm{Non-Saturating: } \underset{\theta_G} \min~L^{(G)} = - \mathbb{E}_{x_g \sim X_g}[\log D(x_g)] \tag{5.b}$$

In [None]:
x = np.linspace(0,1,50)[1:-1]
y_ns = -np.log(x)
y_s = np.log(1-x)
plt.plot(x, y_ns, label='Non-saturating')
plt.plot(x, y_s, label='Saturating')
plt.xlabel('$D(G(z))$')
plt.title('Fig.4 Saturating vs Non-saturating')
plt.legend()
plt.show()

In practice, the Non-Saturating $L^{(G)}$ works better than the Saturating $L^{(G)}$: it gives large gradients early in training, when $G$ is poor, and smaller gradients as $D(x_g)$ approaches 1. Intuitively, this resembles using a large learning rate at the early stage of gradient descent and a small learning rate close to the optimum.
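The gradient argument above can be checked numerically. The sketch below (our own illustration, assuming $D(x_g) = \sigma(l)$ for a pre-sigmoid logit $l$) evaluates the derivative of each generator loss with respect to $l$: $-\sigma(l)$ for the Saturating loss (Eq.5a) and $-(1 - \sigma(l))$ for the Non-Saturating loss (Eq.5b).

```python
import numpy as np

def sigmoid(l):
    return 1.0 / (1.0 + np.exp(-l))

def grad_saturating(l):
    """d/dl of log(1 - sigmoid(l)), the Saturating loss (Eq.5a) minimized by G."""
    return -sigmoid(l)

def grad_non_saturating(l):
    """d/dl of -log(sigmoid(l)), the Non-Saturating loss (Eq.5b) minimized by G."""
    return -(1.0 - sigmoid(l))

# Early in training D confidently rejects generated images: very negative logits
for l in [-8.0, -4.0, 0.0, 4.0]:
    print(f"D(x_g)={sigmoid(l):.4f}  saturating={grad_saturating(l):+.4f}  "
          f"non-saturating={grad_non_saturating(l):+.4f}")
```

Both gradients are negative, so gradient descent raises the logit in both cases, but at $D(x_g) \approx 0$ the Saturating gradient is almost zero while the Non-Saturating one is close to $-1$.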

# Algorithm
If we combine $L^{(G)}$ of Eq.5a and $L^{(D)}$ of Eq.3 together, the aggregate objective function for GAN, $L^{(GAN)}$, can be derived as:
$$ \textrm{minimax: } \underset{\theta_G} \min~\underset{\theta_D}\max~L^{(GAN)} = \mathbb{E}_{x_r \sim X_r}[\log D(x_r)] + \mathbb{E}_{x_g \sim X_g}[\log (1 - D(x_g))] \tag{6.a}$$
This is the objective function of the original GAN, formulated as a minimax game. In practice, however, we do not optimize this combined objective by updating $G$ and $D$ simultaneously; instead, we train the two models alternately. In each training iteration, we train $D$ and $G$ sequentially (Algorithm 1).

In [None]:
show_img('../input/introductiongan/algorithm.png', 'Pseudo code of GAN')

**Training Discriminator.** First, we feed noise vectors $z$ into $G$ (with its parameters from the previous iteration) to generate images $x_g = G(z)$. Then, we feed half a batch of generated images $x_g$ and half a batch of real images $x_r$ into $D$ (also from the previous iteration). Using the loss function (Eq.3), we back-propagate gradients to update the Discriminator parameters $\theta_D$. Note that the Generator parameters $\theta_G$ are frozen. At this point, only $D$ has been updated in the current iteration.

In [None]:
show_img('../input/introductiongan/introduction-gan/d_train.png', 'Training Discriminator', (8,4))

**Training Generator.** Next, we freeze $\theta_D$ and update $\theta_G$. Similarly, we feed noise vectors $z$ into $G$ from the previous iteration to generate images $x_g = G(z)$. Then, we feed a batch of generated images $x_g$ into the updated $D$. Using the loss function (Eq.5b), we back-propagate gradients to update the Generator parameters $\theta_G$. This completes one training iteration of the GAN.

In [None]:
show_img('../input/introductiongan/introduction-gan/g_train.png', 'Training Generator', (8,4))

# Architecture
- **Discriminator**

Downsampling networks (Input: images $\mapsto$ Output: probability)

In [None]:
show_img('../input/introductiongan/introduction-gan/d_arch.png', 'An example architecture of the Discriminator in DCGAN')

- Dense / Fully connected layer: i) The number of neurons decreases from layer to layer. ii) A stack of dense layers cannot be very deep and is not good at extracting spatial features.
- Maxpooling2D. i) Each output value is the maximum of the input values within a `[height, width]` window. ii) This operation introduces no weights and no trainable parameters.

$$
\begin{align}
\texttt{Maxpooling2D([2,2])} & :
		\begin{bmatrix}
           1 & 2 & 3 & 4 \\
           5 & 6 & 7 & 8 \\
           8 & 7 & 4 & 3 \\
           6 & 5 & 2 & 1 \\
    		\end{bmatrix}
	\mapsto
    		\begin{bmatrix}
           6 & 8 \\
           8 & 4 
    		\end{bmatrix}
\end{align}
$$
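A NumPy sketch that reproduces the `Maxpooling2D([2,2])` example above, assuming non-overlapping windows (stride equal to the pool size, which is also the Keras `MaxPooling2D` default). The function name `max_pool2d` is ours.

```python
import numpy as np

def max_pool2d(x, pool=(2, 2)):
    """Non-overlapping max pooling (stride equal to the window size)."""
    ph, pw = pool
    h, w = x.shape
    # Split into (h/ph, ph, w/pw, pw) blocks and take the max inside each block
    blocks = x[: h - h % ph, : w - w % pw].reshape(h // ph, ph, w // pw, pw)
    return blocks.max(axis=(1, 3))

x = np.array([[1, 2, 3, 4],
              [5, 6, 7, 8],
              [8, 7, 4, 3],
              [6, 5, 2, 1]])
print(max_pool2d(x))
# [[6 8]
#  [8 4]]
```

Note that nothing in `max_pool2d` is learned: the output depends only on the input values.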

- Convolution2D (`stride=2`). i) Each output value is a weighted sum of the input values within a `[height, width]` window, with one weight for each cell of the kernel/filter; with `stride=2`, the window moves two cells at a time, roughly halving the spatial size. ii) These weights are trainable parameters of the model.

In [None]:
show_img('../input/introductiongan/introduction-gan/conv.png', 'Convolution operation. (Modified from indoml.com)')
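A minimal NumPy sketch of a strided 2D convolution without padding (`padding='valid'`). Like deep learning libraries, it computes a cross-correlation, i.e. the kernel is not flipped. The 4×4 input reuses the max-pooling example, and the 2×2 kernel is an arbitrary choice for illustration; in a real layer these weights are learned. (With `padding='same'`, as in most layers of the DCGAN discriminator below, the input is zero-padded so that the output size is `ceil(input / stride)`.)

```python
import numpy as np

def conv2d(x, kernel, stride=1):
    """Valid 2D cross-correlation: each output is a weighted sum over one window."""
    kh, kw = kernel.shape
    oh = (x.shape[0] - kh) // stride + 1
    ow = (x.shape[1] - kw) // stride + 1
    out = np.zeros((oh, ow))
    for i in range(oh):
        for j in range(ow):
            window = x[i * stride:i * stride + kh, j * stride:j * stride + kw]
            out[i, j] = np.sum(window * kernel)
    return out

x = np.array([[1, 2, 3, 4],
              [5, 6, 7, 8],
              [8, 7, 4, 3],
              [6, 5, 2, 1]], dtype=float)
kernel = np.array([[1.0, 0.0],
                   [0.0, 1.0]])  # hypothetical weights; a Conv2D layer learns these

print(conv2d(x, kernel, stride=2))  # expected values: [[7, 11], [13, 5]]
```

Unlike max pooling, the result depends on `kernel`, which is why a strided convolution can learn *how* to downsample.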

- **Generator**

Upsampling networks. (Input: noise vectors $\mapsto$ Output: images)

In [None]:
show_img('../input/introductiongan/introduction-gan/g_arch.png', 'An example architecture of the Generator in DCGAN')

- Dense / Fully connected layer: i) The number of neurons increases from layer to layer. ii) A stack of dense layers cannot be very deep and is not good at generating spatial features.
- Upsampling2D. i) Each input value is repeated over a `[height, width]` window of the output. ii) This operation introduces no weights and no trainable parameters.

$$
\begin{align}
\texttt{Upsampling2D([2,2])} & :
		\begin{bmatrix}
           1 & 2 \\
           3 & 4 
    		\end{bmatrix}
	\mapsto
		\begin{bmatrix}
           1 & 1 & 2 & 2 \\
           1 & 1 & 2 & 2 \\
           3 & 3 & 4 & 4 \\
           3 & 3 & 4 & 4 \\
    		\end{bmatrix}
\end{align}
$$
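The `Upsampling2D([2,2])` example above is nearest-neighbour upsampling (the default interpolation of Keras `UpSampling2D`), which can be sketched in NumPy with two `np.repeat` calls; the function name `upsample2d` is ours.

```python
import numpy as np

def upsample2d(x, size=(2, 2)):
    """Nearest-neighbour upsampling: repeat each value size[0] times down, size[1] times across."""
    return np.repeat(np.repeat(x, size[0], axis=0), size[1], axis=1)

x = np.array([[1, 2],
              [3, 4]])
print(upsample2d(x))
# [[1 1 2 2]
#  [1 1 2 2]
#  [3 3 4 4]
#  [3 3 4 4]]
```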

- ConvTranspose2D / Deconvolution2D (`stride=2`).
i) It is also called fractionally-strided convolution: the step of the equivalent convolution is the reciprocal of the stride, e.g. `stride=2` $\Rightarrow$ moving step $=\frac{1}{2}$. How can we move $\frac{1}{2}$ step? By inserting zero rows and columns between the rows and columns of the original input, then applying an ordinary convolution.
ii) As in the convolution operation, the kernel weights are trainable parameters of the model.

In [None]:
show_img('../input/introductiongan/introduction-gan/deconv.png', 'Transposed convolution operation. \n (View more visualizations at \n https://github.com/vdumoulin/conv_arithmetic \n by Vincent Dumoulin, Francesco Visin.)')
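The sketch below (our own NumPy illustration, not the Keras implementation) checks the “½ step” view numerically. It computes a transposed convolution without padding in two ways: directly, where each input value scatters a scaled copy of the kernel into the output, and via zero insertion, where we insert `stride - 1` zeros between input values, pad the border with `kernel_size - 1` zeros, and run an ordinary stride-1 convolution with the kernel rotated by 180°. The input and kernel are random placeholders; the 3×3 → 7×7 shape matches the first `Conv2DTranspose` layer of `generator_conv` below.

```python
import numpy as np

def conv2d_valid(x, kernel):
    """Stride-1 valid cross-correlation (no kernel flip), as in DL libraries."""
    kh, kw = kernel.shape
    oh, ow = x.shape[0] - kh + 1, x.shape[1] - kw + 1
    out = np.zeros((oh, ow))
    for i in range(oh):
        for j in range(ow):
            out[i, j] = np.sum(x[i:i + kh, j:j + kw] * kernel)
    return out

def conv_transpose2d_scatter(x, kernel, stride=2):
    """Direct definition: every input value adds a scaled kernel to the output."""
    kh, kw = kernel.shape
    h, w = x.shape
    out = np.zeros(((h - 1) * stride + kh, (w - 1) * stride + kw))
    for i in range(h):
        for j in range(w):
            out[i * stride:i * stride + kh, j * stride:j * stride + kw] += x[i, j] * kernel
    return out

def conv_transpose2d_zero_insert(x, kernel, stride=2):
    """Fractionally-strided view: zero insertion + border padding + stride-1 convolution."""
    kh, kw = kernel.shape
    h, w = x.shape
    dilated = np.zeros(((h - 1) * stride + 1, (w - 1) * stride + 1))
    dilated[::stride, ::stride] = x  # (stride - 1) zeros between neighbouring values
    padded = np.pad(dilated, ((kh - 1, kh - 1), (kw - 1, kw - 1)))
    return conv2d_valid(padded, kernel[::-1, ::-1])

rng = np.random.default_rng(0)
x = rng.normal(size=(3, 3))       # e.g. one channel of a 3 x 3 feature map
kernel = rng.normal(size=(3, 3))  # hypothetical learned weights

y1 = conv_transpose2d_scatter(x, kernel, stride=2)
y2 = conv_transpose2d_zero_insert(x, kernel, stride=2)
print(y1.shape)             # (7, 7)
print(np.allclose(y1, y2))  # True
```

The output size is $(h - 1) \cdot \textrm{stride} + \textrm{kernel\_size}$, i.e. $(3 - 1) \cdot 2 + 3 = 7$ here.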

# Evaluation
GANs are still a young topic, and finding a perfect evaluation of GANs or of generated images remains an open problem. Generally, we measure generated images along two dimensions: image quality and image diversity.

Apart from comparing the generated images with real images by eye, there are two common mathematical metrics for evaluating the quality of generated images: Inception Score (IS) and Fréchet Inception Distance (FID). Both are based on the Inception V3 network pretrained on the ImageNet dataset. IS is derived from the classification output, while FID is derived from a feature layer.

In [None]:
show_img('../input/introductiongan/introduction-gan/fid_is.png', 'FID vs IS')

- **Inception Score**

IS measures the expected KL divergence (closely related to cross-entropy, see Def.2) between the class distribution predicted for each generated sample and the marginal class distribution over all generated samples, both predicted by InceptionV3 over the ImageNet classes, whereas FID calculates the feature-level distance between the generated sample distribution and the real sample distribution.

$$IS(X_g) = \exp\left(\mathbb{E}_{x_g \sim X_g} \left[D_\textrm{KL}\left(~p\left(y|x_g\right) \parallel p\left(y\right)~\right)\right]\right)$$

where $x_g \sim X_g$ indicates that $x_g$ is an image sampled from $X_g$, $D_\textrm{KL}(P \parallel Q)$ is the KL divergence between the distributions P and Q, $p(y|x_g)$ is the conditional class distribution, and $p(y) = \mathbb{E}_{x_g \sim X_g} \left[p(y|x_g)\right]$ is the marginal class distribution.

The expected $D_\textrm{KL}$ is usually a small number; the $\exp$ in the expression only makes the values easier to compare. Because $\exp$ is monotonically increasing, we can equivalently work with $\ln(IS)$ without loss of generality.

**Remark.** The InceptionV3 network is trained on ImageNet and predicts the class of a given image. If the softmax output $p(y|x_g)$ is sharp ($p(\textrm{one class}) \rightarrow 1, p(\textrm{other classes}) \rightarrow 0$), the image is likely of high quality, its KL divergence from the marginal $p(y)$ is large, and the IS is high. In addition, if the generated images are diverse, their predicted classes are spread out, so $p(y)$ is close to uniform, which also increases the IS.
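A minimal NumPy sketch of the IS formula, assuming we already have the InceptionV3 softmax outputs $p(y|x_g)$ as an `(N, C)` array (obtaining them requires running the pretrained network, which is omitted here). The toy inputs are made up to illustrate the Remark: sharp and diverse predictions give a high score, close to the number of classes, while identical or flat predictions give the minimum score of about 1.

```python
import numpy as np

def inception_score(p_yx, eps=1e-12):
    """IS = exp(E_x[KL(p(y|x) || p(y))]) for an (N, C) array of class probabilities."""
    p_y = p_yx.mean(axis=0, keepdims=True)  # marginal class distribution p(y)
    kl = np.sum(p_yx * (np.log(p_yx + eps) - np.log(p_y + eps)), axis=1)
    return float(np.exp(kl.mean()))

n_classes = 10
# Sharp and diverse: each image is confidently one class, all classes equally covered
sharp_diverse = np.eye(n_classes)[np.arange(100) % n_classes]
# Sharp but not diverse: every image is confidently the same class
sharp_same = np.eye(n_classes)[np.zeros(100, dtype=int)]
# Not sharp: every prediction is uniform over the classes
flat = np.full((100, n_classes), 1.0 / n_classes)

print(inception_score(sharp_diverse))  # close to 10 (the number of classes)
print(inception_score(sharp_same))     # about 1
print(inception_score(flat))           # about 1
```

In practice, IS is usually averaged over several splits of a large set of generated images; that detail is left out of this sketch.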

> **Definition 2.** In mathematical statistics, the Kullback–Leibler (KL) divergence (also called relative entropy) is a measure of how one probability distribution is different from a second, reference probability distribution. 

$$
\begin{align*}
D_\textrm{KL}(P \parallel Q) &= \sum_{x\in\mathcal{X}} P(x) \log\left(\frac{P(x)}{Q(x)}\right)\\
&= \underbrace{\left(-\sum_{x\in\mathcal{X}} P(x) \log Q(x)\right)}_{\textrm{Cross-entropy}(P,Q)} - \underbrace{\left(-\sum_{x\in\mathcal{X}} P(x) \log P(x)\right)}_{\textrm{Cross-entropy}(P,P)}
\end{align*}
$$

This equation is the discrete case of the KL divergence, where $P$ and $Q$ are two probability distributions and $\mathcal{X}$ is the sample space. In the continuous case, the sums become integrals and $P(x), Q(x)$ are probability density functions.
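A small NumPy check of Def.2 on two made-up discrete distributions: the KL divergence equals the cross-entropy minus the entropy, it is zero for identical distributions, and it is not symmetric. The helper names are ours.

```python
import numpy as np

def cross_entropy(p, q):
    """-sum_x P(x) log Q(x) for discrete distributions given as arrays."""
    return -np.sum(p * np.log(q))

def kl_divergence(p, q):
    """sum_x P(x) log(P(x) / Q(x)); assumes Q(x) > 0 wherever P(x) > 0."""
    return np.sum(p * np.log(p / q))

p = np.array([0.7, 0.2, 0.1])
q = np.array([0.5, 0.3, 0.2])

print(kl_divergence(p, q))
print(cross_entropy(p, q) - cross_entropy(p, p))  # same value as the line above
print(kl_divergence(p, p))                         # 0.0
print(kl_divergence(q, p))                         # differs from KL(P || Q)
```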

- **Fréchet Inception Distance**

IS depends entirely on what InceptionV3 learned from ImageNet. If we feed InceptionV3 a high-quality image that does not belong to any ImageNet class, its softmax output will not be sharp. In other words, IS is only meaningful for generated images whose content resembles ImageNet classes. In comparison, FID is a more general evaluation and applies to new datasets, since it compares generated images with real images from the target dataset. Moreover, the feature layer used by FID (the 2048-dimensional pooling features of InceptionV3) carries far richer information than the 1000-class classification output.

$$FID={\Vert}\mu_r-\mu_g{\Vert}^{2}+Tr\left(\Sigma_r+\Sigma_g-2\left(\Sigma_r\Sigma_g\right)^{1/2}\right)$$

where $\mu_r $ is the mean of the real features, $\mu_g $ is the mean of the generated features, $\Sigma_r $ is the covariance matrix of the real features, $\Sigma_g $ is the covariance matrix of the generated features. 
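A minimal NumPy sketch of the FID formula, assuming the InceptionV3 features have already been extracted into `(N, d)` arrays (feature extraction is omitted, and the random "features" below are placeholders). To stay NumPy-only, it computes $Tr\left((\Sigma_r\Sigma_g)^{1/2}\right)$ through the equivalent symmetric matrix $\Sigma_r^{1/2}\Sigma_g\Sigma_r^{1/2}$, which has the same eigenvalues, instead of a general matrix square root such as `scipy.linalg.sqrtm`.

```python
import numpy as np

def _sqrtm_psd(s):
    """Square root of a symmetric positive semi-definite matrix via eigendecomposition."""
    w, v = np.linalg.eigh(s)
    return (v * np.sqrt(np.clip(w, 0.0, None))) @ v.T

def fid(feat_real, feat_gen):
    """Frechet distance between Gaussians fitted to two (N, d) feature arrays."""
    mu_r, mu_g = feat_real.mean(axis=0), feat_gen.mean(axis=0)
    sigma_r = np.cov(feat_real, rowvar=False)
    sigma_g = np.cov(feat_gen, rowvar=False)
    # Tr((Sigma_r Sigma_g)^{1/2}) = Tr((A Sigma_g A)^{1/2}) with A = Sigma_r^{1/2};
    # A Sigma_g A is symmetric PSD, so its eigenvalues are real and non-negative
    a = _sqrtm_psd(sigma_r)
    eig = np.linalg.eigvalsh(a @ sigma_g @ a)
    tr_sqrt = np.sum(np.sqrt(np.clip(eig, 0.0, None)))
    return float(np.sum((mu_r - mu_g) ** 2)
                 + np.trace(sigma_r) + np.trace(sigma_g) - 2.0 * tr_sqrt)

rng = np.random.default_rng(0)
# Placeholder 4-d "features"; real FID uses 2048-d InceptionV3 pooling features
feat_real = rng.normal(size=(500, 4))
feat_gen_shifted = feat_real + 1.0  # same covariance, mean shifted by 1 in every dim

print(fid(feat_real, feat_real))         # ~0 for identical feature sets
print(fid(feat_real, feat_gen_shifted))  # ~4 = ||mu_r - mu_g||^2
```

Lower FID is better; a value of 0 means the two fitted Gaussians are identical.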

# DCGAN Code Practice

In [None]:
# %% --------------------------------------- Load Packages -------------------------------------------------------------
import os
import random
import tensorflow as tf  # tf.__version__ >= 2.2.0
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras import Model, Sequential
from tensorflow.keras.layers import Input, Reshape, Dense, Dropout, \
    LeakyReLU, Conv2D, Conv2DTranspose, Embedding, \
    Concatenate, multiply, Flatten, BatchNormalization
from tensorflow.keras.initializers import glorot_normal
from tensorflow.keras.optimizers import Adam

In [None]:
# %% --------------------------------------- Fix Seeds -----------------------------------------------------------------
SEED = 42
os.environ['PYTHONHASHSEED'] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)
weight_init = glorot_normal(seed=SEED)

In [None]:
# # Load MNIST Fashion
from tensorflow.keras.datasets.fashion_mnist import load_data

In [None]:
# %% ---------------------------------- Data Preparation ---------------------------------------------------------------
# change as channel last (n, dim, dim, channel)
def change_image_shape(images):
    shape_tuple = images.shape
    if len(shape_tuple) == 3:
        # grayscale (n, h, w) -> (n, h, w, 1)
        images = images.reshape(-1, shape_tuple[-1], shape_tuple[-1], 1)
    elif len(shape_tuple) == 4 and shape_tuple[-1] > 3:
        # channel first (n, c, h, w) -> channel last (n, h, w, c);
        # a transpose (not a reshape) is needed to keep pixels in place
        images = np.transpose(images, (0, 2, 3, 1))
    return images

# # Load training set
(x_train, y_train), (x_test, y_test) = load_data()
x_train, x_test = change_image_shape(x_train), change_image_shape(x_test)
y_train, y_test = y_train.reshape(-1), y_test.reshape(-1)

######################## Preprocessing ##########################
# Set channel
channel = x_train.shape[-1]

# It is suggested to use [-1, 1] input for GAN training
x_train = (x_train.astype('float32') - 127.5) / 127.5
x_test = (x_test.astype('float32') - 127.5) / 127.5

# Get image size
img_size = x_train[0].shape

# Get number of classes
n_classes = len(np.unique(y_train))

In [None]:
# %% ---------------------------------- Hyperparameters ----------------------------------------------------------------

# optimizer = Adam(lr=0.0002, beta_1=0.5, beta_2=0.9)
latent_dim = 32
## trainRatio === times(Train D) / times(Train G)
# trainRatio = 5

In [None]:
# %% ---------------------------------- Models Setup -------------------------------------------------------------------
# Build Generator with convolution layer
def generator_conv():
    noise = Input(shape=(latent_dim,))
    x = Dense(3 * 3 * 128)(noise)
    x = LeakyReLU(alpha=0.2)(x)

    ## Out size: 3 x 3 x 128
    x = Reshape((3, 3, 128))(x)

    ## Size: 7 x 7 x 128
    # padding='valid' (the default) gives (3 - 1) * 2 + 3 = 7; padding='same' would give 3 * 2 = 6
    x = Conv2DTranspose(filters=128,
                        kernel_size=(3, 3),
                        strides=(2, 2),
                        # padding='same',
                        kernel_initializer=weight_init)(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(0.2)(x)

    ## Size: 14 x 14 x 64
    x = Conv2DTranspose(64, (3, 3), strides=(2, 2), padding='same', kernel_initializer=weight_init)(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(0.2)(x)

    ## Size: 28 x 28 x channel
    out = Conv2DTranspose(channel, (3, 3), activation='tanh', strides=(2, 2), padding='same',
                          kernel_initializer=weight_init)(x)

    model = Model(inputs=noise, outputs=out)
    return model

In [None]:
# Build Discriminator with convolution layer
def discriminator_conv():
    # 28 x 28 x channel
    img = Input(img_size)

    # 14 x 14 x 32
    x = Conv2D(32, kernel_size=(3, 3), strides=(2, 2), padding='same', kernel_initializer=weight_init)(img)
    x = LeakyReLU(0.2)(x)

    # 7 x 7 x 64
    x = Conv2D(64, (3, 3), strides=(2, 2), padding='same', kernel_initializer=weight_init)(x)
    x = LeakyReLU(0.2)(x)

    # 3 x 3 x 128
    x = Conv2D(128, (3, 3), strides=(2, 2), kernel_initializer=weight_init)(x)
    x = LeakyReLU(0.2)(x)

    x = Flatten()(x)
    x = Dropout(0.4)(x)
    out = Dense(1)(x)

    model = Model(inputs=img, outputs=out)
    return model

In [None]:
# %% ----------------------------------- GAN Part ----------------------------------------------------------------------
# Build our GAN
class DCGAN(Model):
    def __init__(
        self,
        discriminator,
        generator,
        latent_dim,
        train_ratio=1,
    ):
        super(DCGAN, self).__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.train_ratio = train_ratio

    def compile(self, d_optimizer, g_optimizer, d_loss_fn, g_loss_fn):
        super(DCGAN, self).compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.d_loss_fn = d_loss_fn
        self.g_loss_fn = g_loss_fn

    def train_step(self, data):
        if isinstance(data, tuple):
            real_images = data[0]
        else:
            real_images = data
        # Get the batch size
        batch_size = tf.shape(real_images)[0]

        ########################### Train the Discriminator ###########################
        # training train_ratio times on D while training once on G
        for i in range(self.train_ratio):
            # Get the latent vector
            random_latent_vectors = tf.random.normal(
                shape=(batch_size, self.latent_dim)
            )
            with tf.GradientTape() as tape:
                # Generate fake images from the latent vector
                fake_images = self.generator(random_latent_vectors, training=True)
                # Get the logits for the fake images
                fake_logits = self.discriminator(fake_images, training=True)
                # Get the logits for real images
                real_logits = self.discriminator(real_images, training=True)

                # Calculate loss of D
                d_loss = self.d_loss_fn(real_logits, fake_logits)


            # Get the gradients w.r.t the discriminator loss
            d_gradient = tape.gradient(d_loss, self.discriminator.trainable_variables)
            # Update the weights of the discriminator using the discriminator optimizer
            self.d_optimizer.apply_gradients(
                zip(d_gradient, self.discriminator.trainable_variables)
            )

        ########################### Train the Generator ###########################
        # Get the latent vector
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        with tf.GradientTape() as tape:
            # Generate fake images using the generator
            generated_images = self.generator(random_latent_vectors, training=True)
            # Get the discriminator logits for fake images
            gen_img_logits = self.discriminator(generated_images, training=True)
            # Calculate the generator loss
            g_loss = self.g_loss_fn(gen_img_logits)

        # Get the gradients w.r.t the generator loss
        gen_gradient = tape.gradient(g_loss, self.generator.trainable_variables)
        # Update the weights of the generator using the generator optimizer
        self.g_optimizer.apply_gradients(
            zip(gen_gradient, self.generator.trainable_variables)
        )
        return {"d_loss": d_loss, "g_loss": g_loss}

In [None]:
# %% ----------------------------------- Compile Models ----------------------------------------------------------------
# Optimizer for both the networks
# learning_rate=0.0002, beta_1=0.5, beta_2=0.9 are recommended
generator_optimizer = Adam(
    learning_rate=0.0002, beta_1=0.5, beta_2=0.9
)
discriminator_optimizer = Adam(
    learning_rate=0.0002, beta_1=0.5, beta_2=0.9
)

# Define the loss function to be used for the discriminator (Eq.3, computed from logits)
def discriminator_loss(real_logits, fake_logits):
    real_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=real_logits, labels=tf.ones_like(real_logits)))
    fake_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_logits, labels=tf.zeros_like(fake_logits)))

    return fake_loss + real_loss

# Define the loss functions to be used for generator
def generator_loss(fake_logits):
    fake_loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_logits, labels=tf.ones_like(fake_logits)))
    return fake_loss

d_model = discriminator_conv()
g_model = generator_conv()

dcgan = DCGAN(generator=g_model,
              discriminator=d_model,
              latent_dim=latent_dim,
              train_ratio=1)

dcgan.compile(
    d_optimizer=discriminator_optimizer,
    g_optimizer=generator_optimizer,
    g_loss_fn=generator_loss,
    d_loss_fn=discriminator_loss,
)

In [None]:
# %% ----------------------------------- Start Training ----------------------------------------------------------------
# Plot/save generated images through training
def plt_img(generator):
    np.random.seed(42)
    n = n_classes

    noise = np.random.normal(size=(n * n, latent_dim))
    decoded_imgs = generator.predict(noise)

    decoded_imgs = decoded_imgs * 0.5 + 0.5
    x_real = x_test * 0.5 + 0.5

    plt.figure(figsize=(n, n + 1))
    for i in range(n):
        # display original
        ax = plt.subplot(n + 1, n, i + 1)
        if channel == 3:
            plt.imshow(x_real[y_test == i][0].reshape(img_size[0], img_size[1], img_size[2]))
        else:
            plt.imshow(x_real[y_test == i][0].reshape(img_size[0], img_size[1]))
            plt.gray()
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

        for j in range(n):
            # display generation
            ax = plt.subplot(n + 1, n, (i + 1) * n + j + 1)
            if channel == 3:
                plt.imshow(decoded_imgs[i * n + j].reshape(img_size[0], img_size[1], img_size[2]))
            else:
                plt.imshow(decoded_imgs[i * n + j].reshape(img_size[0], img_size[1]))
                plt.gray()
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
    plt.show()
    return

############################# Start training #############################
LEARNING_STEPS = 6
for learning_step in range(LEARNING_STEPS):
    print('LEARNING STEP # ', learning_step + 1, '-' * 50)
    dcgan.fit(x_train, epochs=1, batch_size=128)
    if (learning_step+1)%2 == 0:
        plt_img(dcgan.generator)