# GAN (3)

## 1. Conditional GAN 을 이용한 MNIST 숫자 생성

- (ref: https://github.com/eriklindernoren/Keras-GAN/blob/master/cgan/cgan.py)

### 전체 학습코드

```python
from __future__ import print_function, division

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply
from keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam

import matplotlib.pyplot as plt

import numpy as np

class CGAN():
    """Conditional GAN (CGAN) that generates MNIST digits of a requested class.

    Both networks are MLPs. The class label is injected by multiplying a
    learned label embedding element-wise with the network input (the noise
    vector for the generator, the flattened image for the discriminator).
    """

    def __init__(self):
        # Input shape
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.num_classes = 10
        self.latent_dim = 100  # dimension of the generator's input noise vector

        # lr=0.0002, beta_1=0.5 (Adam's default beta_1 is 0.9)
        optimizer = Adam(0.0002, 0.5)

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss=['binary_crossentropy'],
            optimizer=optimizer,
            metrics=['accuracy'])

        # Build the generator (trained only through the combined model below)
        self.generator = self.build_generator()

        # The generator takes noise and the target label as input
        # and generates the corresponding digit of that label.
        # The label is int32, matching the label inputs of both sub-models.
        noise = Input(shape=(self.latent_dim,))
        label = Input(shape=(1,), dtype='int32')
        img = self.generator([noise, label])

        # For the combined model we will only train the generator.
        # Keras captures `trainable` at compile time, so the discriminator
        # compiled above still learns through its own train_on_batch calls.
        self.discriminator.trainable = False

        # The discriminator takes the generated image and its label
        # and determines validity
        valid = self.discriminator([img, label])

        # The combined model (stacked generator and discriminator)
        # trains the generator to fool the discriminator
        self.combined = Model([noise, label], valid)
        self.combined.compile(loss=['binary_crossentropy'],
            optimizer=optimizer)

    def build_generator(self):
        """Build G: Model([noise, label]) -> image of shape self.img_shape, values in [-1, 1]."""
        model = Sequential()

        # MLP: latent_dim -> 256 -> 512 -> 1024 -> prod(img_shape) (= 784), reshaped to img_shape
        model.add(Dense(256, input_dim=self.latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(np.prod(self.img_shape), activation='tanh'))
        model.add(Reshape(self.img_shape))

        model.summary()

        noise = Input(shape=(self.latent_dim,))
        label = Input(shape=(1,), dtype='int32')
        # (batch, 1) label -> (batch, 1, latent_dim) embedding -> (batch, latent_dim)
        label_embedding = Flatten()(Embedding(self.num_classes, self.latent_dim)(label))

        # Condition the noise on the label by element-wise multiplication
        model_input = multiply([noise, label_embedding])
        img = model(model_input)

        return Model([noise, label], img)

    def build_discriminator(self):
        """Build D: Model([img, label]) -> sigmoid score that img is a real image of that label."""
        model = Sequential()

        # MLP on the flattened image, with dropout after the 2nd and 3rd hidden layers
        model.add(Dense(512, input_dim=np.prod(self.img_shape)))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.4))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.4))
        model.add(Dense(1, activation='sigmoid'))
        model.summary()

        img = Input(shape=self.img_shape)
        label = Input(shape=(1,), dtype='int32')

        # (batch, 1) label -> embedding with as many entries as the flattened image
        label_embedding = Flatten()(Embedding(self.num_classes, np.prod(self.img_shape))(label))
        flat_img = Flatten()(img)

        # Condition the image on the label by element-wise multiplication
        model_input = multiply([flat_img, label_embedding])

        validity = model(model_input)

        return Model([img, label], validity)

    def train(self, epochs, batch_size=128, sample_interval=50):
        """Train D and G alternately, one random mini-batch per iteration.

        Args:
            epochs: number of training iterations (each draws one random
                mini-batch; this is not a full pass over the dataset).
            batch_size: mini-batch size for both the D and the G update.
            sample_interval: save a grid of generated digits every this many iterations.
        """
        # Load the dataset
        (X_train, y_train), (_, _) = mnist.load_data()

        # Configure input: scale pixels [0, 255] -> [-1, 1] to match the generator's tanh
        X_train = (X_train.astype(np.float32) - 127.5) / 127.5
        X_train = np.expand_dims(X_train, axis=3)
        y_train = y_train.reshape(-1, 1)

        # Adversarial ground truths: 1 = real (valid), 0 = fake
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Select a random batch of real images and their labels
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs, labels = X_train[idx], y_train[idx]

            # Sample noise as generator input (same size as the generator's noise Input)
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            # Generate a batch of fake images conditioned on the real labels
            gen_imgs = self.generator.predict([noise, labels])

            # Train the discriminator; d_loss is the mean [loss, accuracy] of both batches
            d_loss_real = self.discriminator.train_on_batch([imgs, labels], valid)
            d_loss_fake = self.discriminator.train_on_batch([gen_imgs, labels], fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train Generator
            # ---------------------

            # Condition on random labels
            sampled_labels = np.random.randint(0, self.num_classes, batch_size).reshape(-1, 1)

            # Train the generator (via the combined model) to make D output "valid"
            g_loss = self.combined.train_on_batch([noise, sampled_labels], valid)

            # Plot the progress
            print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

            # If at save interval => save generated image samples
            if epoch % sample_interval == 0:
                self.sample_images(epoch)

    def sample_images(self, epoch):
        """Save a 2 x 5 grid with one generated image per digit to images/<epoch>.png."""
        import os

        r, c = 2, 5  # r * c must equal the number of labels below (one panel per digit)
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        sampled_labels = np.arange(0, 10).reshape(-1, 1)

        gen_imgs = self.generator.predict([noise, sampled_labels])

        # Rescale images from [-1, 1] (tanh) to [0, 1]
        gen_imgs = 0.5 * gen_imgs + 0.5

        # axs.flat walks the grid row by row, matching the label order
        fig, axs = plt.subplots(r, c)
        for cnt, ax in enumerate(axs.flat):
            ax.imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
            # [cnt, 0] yields a scalar; "%d" on a 1-element array relies on a deprecated NumPy conversion
            ax.set_title("Digit: %d" % sampled_labels[cnt, 0])
            ax.axis('off')

        # savefig does not create missing directories
        if not os.path.isdir("images"):
            os.makedirs("images")
        fig.savefig("images/%d.png" % epoch)
        plt.close(fig)


if __name__ == '__main__':
    # 20000 training iterations on mini-batches of 32, saving a sample grid every 200
    train_config = dict(epochs=20000, batch_size=32, sample_interval=200)
    cgan = CGAN()
    cgan.train(**train_config)
```

### 1) library import

`keras.layers`

- keras 의 하위 모듈인 `layers` 에서 여러가지 layer 클래스 (및 `multiply` 같은 함수) 를 import 해서 사용 가능

- `Input`: keras tensor 객체 생성

- `Dense`: 모든 입력 unit 이 모든 출력 unit 과 연결되는 layer (Fully-Connected)

- `Reshape`: tensor 의 차원을 변형

- `Flatten`: 다차원 tensor 를 1 차원으로 변환

- `Dropout`: 학습 시 입력 unit 의 일부를 무작위로 0 으로 만드는 dropout 기법 적용


In [4]:
# from __future__ import print_function, division

from keras.utils import plot_model # model visualization
from IPython.display import Image, display # model visualization

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply
from keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam

import matplotlib.pyplot as plt

import numpy as np

Using TensorFlow backend.


### 1-1) declaring class

```python
class CGAN():
    """Conditional GAN whose constructor delegates to four ordered setup steps."""

    def __init__(self): # constructor
        # Order matters: hyper-parameters (including the optimizer) first, then the
        # compiled discriminator, then the generator, and finally the stacked model
        # that wires the generator and discriminator together.
        setup_steps = (
            self.init_hyperparams,
            self.init_discriminator,
            self.init_generator,
            self.init_model,
        )
        for step in setup_steps:
            step()
```

### 2) Hyper-parameter settings & Generator / Discriminator details

#### Hyper-parameter settings

- MNIST 이미지는 28 Pixel $\times$ 28 Pixel 이미지
    
- 따라서 img_rows 와 img_cols 모두 28 이며, 흑백 (grayscale) 이미지이므로 채널은 1
    
- 조건 (condition) 으로 사용할 클래스 (Y label) 는 0 ~ 9 총 10개
    
- Generator 에 입력되는 latent (noise) vector 의 차원 (`latent_dim`) 은 100 이며, Optimizer 는 Adam (learning rate 0.0002, beta_1 0.5) 사용

```python
    def init_hyperparams(self):
        """Set the MNIST input geometry, the label/latent sizes and the shared optimizer."""
        # input data reference
        self.img_rows = 28 # input image row size (height)
        self.img_cols = 28 # input image column size (width)
        self.channels = 1 # grayscale
        self.img_shape = (self.img_rows, self.img_cols, self.channels) # shape for creating image
        self.num_classes = 10 # Y label for 0 to 9
        self.latent_dim = 100 # dimension of the generator's input noise vector

        # Adam's default beta_1 is 0.9. Per the original note, lowering it keeps the
        # discriminator from learning too fast while the generator fails to learn,
        # which would drive the discriminator loss quickly to 0.
        # Stored on the instance (not a local) because init_discriminator compiles
        # with self.optimizer.
        self.optimizer = Adam(0.0002, 0.5)
```

#### Discriminator details

`Model.compile`: keras 에서 모델의 학습과정을 설정할때 (loss, optimizer, performance measure 등) 사용


```python
    def init_discriminator(self):
        """Create the discriminator and compile it as a real/fake binary classifier."""
        disc = self.build_discriminator()
        self.discriminator = disc
        disc.compile(
            loss=['binary_crossentropy'],
            optimizer=self.optimizer,
            metrics=['accuracy'],
        )
```

#### Generator details

- random noise $z$ 와 조건 label $y$ 를 입력 받음

- fake image $G(z \mid y)$ 를 생성

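- `Input(shape=...)` 의 shape 에는 batch 차원을 적지 않음. batch_size 를 정의하지 않으면 batch 차원은 `None` (가변) 으로 처리되어, 입력 데이터의 크기에 맞춰 자동으로 결정됨

- batch_size 를 정의한 경우 (예: `Input(batch_shape=...)`), 그 값이 첫 번째 (batch) 차원으로 고정됨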

```python
    def init_generator(self):
        """Build the generator network (left uncompiled) and keep it on the instance."""
        gen = self.build_generator()
        self.generator = gen
```

**Generator 학습 과정**

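- Generator 가 noise 와 label 을 입력받아 fake image 생성

- Discriminator 가 real 인지 fake 인지 판별

- 판별 결과를 정답 "real (valid = 1)" 과 비교한 binary cross-entropy 로 Generator loss 계산

- Back Propagation 으로 Generator 의 weight 만 갱신

> 이때 Discriminator 의 weight 가 고정된 상태에서 학습을 진행해야 하므로 `discriminator.trainable = False` 로 설정
>
> 이 설정은 combined model 을 compile 하기 전에 해야 함. Discriminator 자체는 그 전에 (trainable 상태로) 이미 compile 되었으므로, `discriminator.train_on_batch` 를 호출하면 여전히 학습됨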

```python
    def init_model(self):
        """Build and compile the combined (generator -> frozen discriminator) model."""
        # The generator takes noise and the target label as input
        # and generates the corresponding digit of that label.
        # The label is int32, like the label inputs of the generator and discriminator.
        noise = Input(shape=(self.latent_dim,))
        label = Input(shape=(1,), dtype='int32')
        img = self.generator([noise, label])

        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # The discriminator takes generated image as input and determines validity
        # and the label of that image
        valid = self.discriminator([img, label])

        # The combined model  (stacked generator and discriminator)
        # Trains generator to fool discriminator
        self.combined = Model([noise, label], valid)
        # Use the optimizer stored on the instance, as init_discriminator does;
        # a bare `optimizer` is not defined in this method.
        self.combined.compile(loss=['binary_crossentropy'],
            optimizer=self.optimizer)
```

### 3) Building Generator

- `BatchNormalization` 을 통한 정규화 수행

- `np.prod()`: 주어진 axis 에 속한 element 간의 곱셈 (axis 를 지정하지 않으면 전체 element 의 곱. 예: `np.prod((28, 28, 1))` = 784 → 이미지를 1차원으로 펼쳤을 때의 크기)

- `Flatten()(Embedding(self.num_classes, self.latent_dim)(label))`: keras 함수형 API 사용

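    - label (0 ~ 9 정수) 을 `latent_dim` (100) 차원의 dense vector 로 Embedding 한 뒤 (출력 shape: (batch, 1, 100)), Flatten 으로 (batch, 100) 형태로 변환
    - 이 vector 를 noise 와 element-wise 로 곱해서 (`multiply`) label 정보를 generator 입력에 반영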
    
    
- noise 와 label 을 입력받고, img 를 출력 

    - label = [0 or 1 or 2 ... or 9]

```python
    def build_generator(self):
        """Build G: Model([noise, label]) -> fake image of shape self.img_shape.

        The label is embedded into a latent_dim-sized vector and multiplied
        element-wise with the noise vector before it enters the MLP.
        """
        import os

        model = Sequential()

        # MLP mapping a latent_dim (100) noise vector to an MNIST-shaped (28, 28, 1) image
        model.add(Dense(256, input_dim=self.latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(np.prod(self.img_shape), activation='tanh')) # np.prod(self.img_shape) = 784 units, tanh -> [-1, 1]
        model.add(Reshape(self.img_shape))

        model.summary()

        # model visualization (the ./images directory may not exist yet)
        if not os.path.isdir('./images'):
            os.makedirs('./images')
        plot_model(model, show_shapes=True, to_file='./images/generator_model.png')
        display(Image(filename='./images/generator_model.png'))

        # Inputs: a noise vector and an integer class label
        noise = Input(shape=(self.latent_dim,))
        label = Input(shape=(1,), dtype='int32')
        # (batch, 1) label -> (batch, 1, latent_dim) embedding -> (batch, latent_dim)
        label_embedding = Flatten()(Embedding(self.num_classes, self.latent_dim)(label))

        # Condition the noise on the label, then generate the fake image
        model_input = multiply([noise, label_embedding])
        img = model(model_input)

        # input: [noise, label] , output: img
        return Model([noise, label], img)
```

### 4) Building Discriminator

- 구조는 Generator 와 대부분 비슷하지만, img 와 label 을 입력받아서 validity (real 일 확률, sigmoid 출력) 를 출력

- input 으로는 training (real) image 또는 generated (fake) image $+$ label (conditional vector) 를 받음

 ```python   
    def build_discriminator(self):
        """Build D: Model([img, label]) -> validity, a sigmoid score that img is real.

        The label is embedded into a vector with as many entries as the
        flattened image and multiplied element-wise with it before the MLP.
        """
        import os

        model = Sequential()

        # MLP on the flattened image, with dropout after the 2nd and 3rd hidden layers
        model.add(Dense(512, input_dim=np.prod(self.img_shape)))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.4))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.4))
        model.add(Dense(1, activation='sigmoid'))

        model.summary()

        # model visualization (the ./images directory may not exist yet)
        if not os.path.isdir('./images'):
            os.makedirs('./images')
        plot_model(model, show_shapes=True, to_file='./images/discriminator_model.png')
        display(Image(filename='./images/discriminator_model.png'))

        # decide whether the image is real or fake for the given label
        img = Input(shape=self.img_shape)
        label = Input(shape=(1,), dtype='int32')

        label_embedding = Flatten()(Embedding(self.num_classes, np.prod(self.img_shape))(label))
        flat_img = Flatten()(img)

        # Condition the flattened image on the label by element-wise multiplication
        model_input = multiply([flat_img, label_embedding])

        validity = model(model_input)

        # input: [img, label] , output: validity
        return Model([img, label], validity)
```

### 5) model training

```python
    def train(self, epochs, batch_size=128, sample_interval=50):
        """Train D and G alternately, one random mini-batch per iteration.

        Args:
            epochs: number of training iterations (each draws one random
                mini-batch; this is not a full pass over the dataset).
            batch_size: mini-batch size for both the D and the G update.
            sample_interval: save a grid of generated digits every this many iterations.
        """
        # Load the dataset
        (X_train, y_train), (_, _) = mnist.load_data()

        # Pre-processing
        # Configure input
        X_train = (X_train.astype(np.float32) - 127.5) / 127.5 # normalize pixel data between -1 and 1
        X_train = np.expand_dims(X_train, axis=3)  # (N, 28, 28) -> (N, 28, 28, 1)
        y_train = y_train.reshape(-1, 1)            # (N,) -> (N, 1), like the label Input(shape=(1,))

        # Adversarial ground truths: 1 = real (valid), 0 = fake
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Select a random batch of real images and their labels
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs, labels = X_train[idx], y_train[idx]

            # Sample noise as generator input (same size as the generator's noise Input)
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            # Generate a batch of fake images conditioned on the real labels
            gen_imgs = self.generator.predict([noise, labels])

            # Train the discriminator: real images -> valid, generated images -> fake
            d_loss_real = self.discriminator.train_on_batch([imgs, labels], valid)
            d_loss_fake = self.discriminator.train_on_batch([gen_imgs, labels], fake)

            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake) # mean of the two [loss, accuracy] results

            # ---------------------
            #  Train Generator
            # ---------------------

            # Condition on random labels
            sampled_labels = np.random.randint(0, self.num_classes, batch_size).reshape(-1, 1)

            # Train the generator (via the combined model) to make D output "valid"
            g_loss = self.combined.train_on_batch([noise, sampled_labels], valid)

            # Plot the progress
            print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

            # If at save interval => save generated image samples
            if epoch % sample_interval == 0:
                self.sample_images(epoch)
```

#### Load data and Normalize

In [5]:
(X_train, y_train), (_, _) = mnist.load_data()
# Map uint8 pixel values [0, 255] onto [-1, 1], the range of the generator's tanh output
half_range = 127.5
X_train = (X_train.astype(np.float32) - half_range) / half_range

print(X_train.shape)

Downloading data from https://s3.amazonaws.com/img-datasets/mnist.npz
(60000, 28, 28)


In [6]:
# Reshape for the model inputs: add a trailing channel axis to the images
# and turn the labels into a column vector
X_train = X_train[..., np.newaxis]
y_train = y_train.reshape((-1, 1))

print(X_train.shape)
print(y_train.shape)

(60000, 28, 28, 1)
(60000, 1)


In [7]:
# Target vectors for one 128-sample mini-batch: 1 marks real (valid), 0 marks fake
valid, fake = np.ones((128, 1)), np.zeros((128, 1))

for target in (valid, fake):
    print(target.shape)

(128, 1)
(128, 1)


#### Training

**Discriminator**

In [8]:
# Draw one mini-batch worth (128) of random row indices in [0, number of training images)
n_samples = X_train.shape[0]
idx = np.random.randint(low=0, high=n_samples, size=128)

print(idx)

[47606 23704 22004 58514  2759 45937 53635 40637 10890  1643 27369 36976
 20783 56204 10624 23077 10969 12498 35894 37360 36684 59066 39258 27719
 20928 19963 16991 20697 26261 54622 44144 43708 10116 48971 34223 14438
  1004 18172 39704 36804 43185 47806 17056  2300 29179 40211   840 31402
  2414 39777 53600 52049 19028 37086 52974 34341  7643 52411 52936 16425
 22656 15821 43182 36955   776 16546 31590 21221 34399 31536  8521 46538
  1868 40474 27866   678 25677 22642 53394 10458 15719 21675 45518 49350
 36285 13818 18586 44380 20374 59559 47932 57224 46634 42396 50876 36339
 54984  6554 23552 24525 51947 32276 20192 24873 10639 55824  6780 25867
 59366  6310 36435 58554 48632  1107  7178 30231  3380 38638 46458 32376
 40521 43419 36078 36248 15656  5953 43859 38985]


In [9]:
# Gather the sampled training images and their matching labels (128 of each)
imgs = X_train[idx]
labels = y_train[idx]

In [10]:
# Random noise for the generator input, drawn from a standard normal N(0, 1)
# shape (128, 100) <- (batch_size, latent_dim): 100 is the noise dimension, not an image size
noise = np.random.normal(0, 1, (128, 100))

print(noise)
print(noise.shape)

[[ 0.58379212  1.19624084 -1.20898531 ...,  0.76103456 -1.82739536
   2.88477641]
 [-1.03133211  0.98240722 -0.83367039 ..., -1.30040165 -1.21974368
  -1.12502016]
 [ 1.55537176 -0.54177673 -0.23494343 ...,  1.27145743 -0.95070008
  -1.20083224]
 ..., 
 [ 0.71016121  1.65677287 -0.29513291 ..., -0.03858801  1.6586306
   0.32300721]
 [ 0.00393512  0.40952971  0.03341764 ..., -1.47282468 -1.84854556
   0.69204356]
 [ 0.15155601  0.47499252  1.09881113 ..., -0.91022374  0.14220695
  -0.85010822]]
(128, 100)


In [None]:
# NOTE(review): excerpt of CGAN.train(); `self` is the CGAN instance there, so this
# cell does not run on its own (no `self` exists at notebook top level).
# The generator takes noise and labels (conditions) as input and produces fake images
gen_imgs = self.generator.predict([noise, labels])

# Train the discriminator
# Random-guess accuracy is 0.5 for each of the two batches
d_loss_real = self.discriminator.train_on_batch([imgs, labels], valid) # real images are labeled valid
d_loss_fake = self.discriminator.train_on_batch([gen_imgs, labels], fake) # fake images are labeled fake

d_loss = 0.5 * np.add(d_loss_real, d_loss_fake) # element-wise mean of the two [loss, accuracy] results

**Generator**

In [None]:
# Draw 128 sample labels (conditions) in 0 ~ 9.
# The earlier cells hard-code the mini-batch size 128 and never define `batch_size`,
# so define it here instead of raising a NameError.
batch_size = 128
sampled_labels = np.random.randint(0, 10, batch_size).reshape(-1, 1) # -1: NumPy infers this dimension

print(sampled_labels[:5])

In [None]:
# NOTE(review): excerpt of CGAN.train(); `self` is the CGAN instance, so this cell does not run on its own.
# Train on noise + sampled labels against the "valid" target, so the generator learns to get a "real" verdict from the discriminator (deceiving the discriminator)
g_loss = self.combined.train_on_batch([noise, sampled_labels], valid)

In [None]:
# NOTE(review): excerpt of the CGAN.train() loop; `epoch`, `sample_interval` and `self` are only defined there.
# If at save interval => save generated image samples
if epoch % sample_interval == 0:
    self.sample_images(epoch)

### 6) generator 가 생성한 이미지 출력

```python
    def sample_images(self, epoch):
        """Save a 2 x 5 grid with one generated image per digit to ./images/<epoch>.png."""
        import os

        r, c = 2, 5  # r * c must equal the number of labels below (one panel per digit)
        noise = np.random.normal(0, 1, (r * c, self.latent_dim)) # one noise vector per panel
        sampled_labels = np.arange(0, 10).reshape(-1, 1) # 0 ~ 9 y label

        # generate images with the generator
        gen_imgs = self.generator.predict([noise, sampled_labels])

        # Rescale images from [-1, 1] (tanh) to [0, 1]
        gen_imgs = 0.5 * gen_imgs + 0.5

        # Draw one generated image per subplot; axs.flat walks the grid row by row
        fig, axs = plt.subplots(r, c)
        for cnt, ax in enumerate(axs.flat):
            ax.imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
            # [cnt, 0] yields a scalar; "%d" on a 1-element array relies on a deprecated NumPy conversion
            ax.set_title("Digit: %d" % sampled_labels[cnt, 0])
            ax.axis('off')

        # savefig does not create missing directories
        if not os.path.isdir("./images"):
            os.makedirs("./images")
        fig.savefig("./images/%d.png" % epoch)
        plt.close(fig)
```

### 7) 실행

In [None]:
# Create the CGAN (builds and compiles all models), then train it.
# With epochs=20001 the last iteration index is 20000, a multiple of the default
# sample_interval (50), so a sample grid is also saved at that final iteration.
cgan = CGAN()
cgan.train(epochs=20001, batch_size=32)

#### Google colab

- google drive directory mount 필수

```python
from google.colab import drive
DRIVE_MOUNT_POINT = '/content/gdrive'  # Google Drive contents appear under this directory
drive.mount(DRIVE_MOUNT_POINT)
```
...

- 이미지를 저장할 directory 경로를 잘 확인하고 설정할 것: drive 를 `/content/gdrive` 에 mount 했으므로, 저장 경로는 `/content/gdrive/My Drive/...` 로 시작해야 함

```python
def sample_images(self, epoch):
    # Drive is mounted at '/content/gdrive' (drive.mount above), so the save path
    # must start with that mount point; '/gdrive/...' is not under it.
    # Replace '...' with your own folder.
    # NOTE: newer Colab versions may show the folder as 'MyDrive' - check the mounted tree.
    plt.savefig("/content/gdrive/My Drive/Colab Notebooks/.../images/%d.png" % epoch)
```

# Let's Do it

In [None]:
from __future__ import print_function, division

from keras.utils import plot_model # model visualization
from IPython.display import Image, display # model visualization

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply
from keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam

import matplotlib.pyplot as plt

import numpy as np

class CGAN():
    """Conditional GAN (CGAN) for MNIST with construction split into init_* steps.

    Construction order: hyper-parameters/optimizer -> compiled discriminator
    -> generator -> combined (generator + frozen discriminator) model.
    """

    def __init__(self):
        self.init_hyperparams()
        self.init_discriminator()
        self.init_generator()
        self.init_model()

    def init_hyperparams(self):
        """Set the MNIST input geometry, label/latent sizes and the shared Adam optimizer."""
        # Input shape
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.num_classes = 10
        self.latent_dim = 100  # dimension of the generator's input noise vector

        # lr=0.0002, beta_1=0.5 (Adam's default beta_1 is 0.9)
        self.optimizer = Adam(0.0002, 0.5)

    def init_discriminator(self):
        """Build the discriminator and compile it as a real/fake binary classifier."""
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss=['binary_crossentropy'], optimizer=self.optimizer, metrics=['accuracy'])

    def init_generator(self):
        """Build the generator; it is trained only through the combined model."""
        self.generator = self.build_generator()

    def init_model(self):
        """Build and compile the combined model that trains only the generator."""
        # The generator takes noise and the target label as input
        # and generates the corresponding digit of that label.
        # The label is int32, matching the label inputs of both sub-models.
        noise = Input(shape=(self.latent_dim,))
        label = Input(shape=(1,), dtype='int32')
        img = self.generator([noise, label])

        # For the combined model we will only train the generator: the models are
        # trained alternately (one frozen while the other learns). Keras captures
        # `trainable` at compile time, so the discriminator compiled in
        # init_discriminator still learns through its own train_on_batch calls.
        self.discriminator.trainable = False

        # The discriminator takes the generated image and its label
        # and determines validity
        valid = self.discriminator([img, label])

        # The combined model (stacked generator and discriminator)
        # trains the generator to fool the discriminator
        self.combined = Model([noise, label], valid)

        self.combined.summary()

        self.combined.compile(loss=['binary_crossentropy'], optimizer=self.optimizer)

    def build_generator(self):
        """Build G: Model([noise, label]) -> image of shape self.img_shape, values in [-1, 1]."""
        model = Sequential()

        # MLP: latent_dim -> 256 -> 512 -> 1024 -> prod(img_shape) (= 784), reshaped to img_shape
        model.add(Dense(256, input_dim=self.latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(np.prod(self.img_shape), activation='tanh'))
        model.add(Reshape(self.img_shape))

        model.summary()

        plot_model(model, show_shapes=True, to_file='generator_model.png')
        display(Image(filename='generator_model.png'))

        noise = Input(shape=(self.latent_dim,))
        label = Input(shape=(1,), dtype='int32')
        # (batch, 1) label -> (batch, 1, latent_dim) embedding -> (batch, latent_dim)
        label_embedding = Flatten()(Embedding(self.num_classes, self.latent_dim)(label))

        # Condition the noise on the label by element-wise multiplication
        model_input = multiply([noise, label_embedding])
        img = model(model_input)

        return Model([noise, label], img)

    def build_discriminator(self):
        """Build D: Model([img, label]) -> sigmoid score that img is a real image of that label."""
        model = Sequential()

        # MLP on the flattened image, with dropout after the 2nd and 3rd hidden layers
        model.add(Dense(512, input_dim=np.prod(self.img_shape)))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.4))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.4))
        model.add(Dense(1, activation='sigmoid'))

        model.summary()

        plot_model(model, show_shapes=True, to_file='discriminator_model.png')
        display(Image(filename='discriminator_model.png'))

        img = Input(shape=self.img_shape)
        label = Input(shape=(1,), dtype='int32')

        # (batch, 1) label -> embedding with as many entries as the flattened image
        label_embedding = Flatten()(Embedding(self.num_classes, np.prod(self.img_shape))(label))
        flat_img = Flatten()(img)

        # Condition the image on the label by element-wise multiplication
        model_input = multiply([flat_img, label_embedding])

        validity = model(model_input)

        return Model([img, label], validity)

    def train(self, epochs, batch_size=128, sample_interval=50):
        """Train D and G alternately, one random mini-batch per iteration.

        Progress is printed and a sample grid is saved every `sample_interval`
        iterations.

        Args:
            epochs: number of training iterations (each draws one random
                mini-batch; this is not a full pass over the dataset).
            batch_size: mini-batch size for both the D and the G update.
            sample_interval: reporting / sampling period in iterations.
        """
        # Load the dataset
        (X_train, y_train), (_, _) = mnist.load_data()

        # Configure input: scale pixels [0, 255] -> [-1, 1] to match the generator's tanh
        X_train = (X_train.astype(np.float32) - 127.5) / 127.5
        X_train = np.expand_dims(X_train, axis=3)
        y_train = y_train.reshape(-1, 1)

        # Adversarial ground truths: 1 = real (valid), 0 = fake
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Select a random batch of real images and their labels
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs, labels = X_train[idx], y_train[idx]

            # Sample noise as generator input (same size as the generator's noise Input)
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            # Generate a batch of fake images conditioned on the real labels
            gen_imgs = self.generator.predict([noise, labels])

            # Train the discriminator; d_loss is the mean [loss, accuracy] of both batches
            d_loss_real = self.discriminator.train_on_batch([imgs, labels], valid)
            d_loss_fake = self.discriminator.train_on_batch([gen_imgs, labels], fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train Generator
            # ---------------------

            # Condition on random labels
            sampled_labels = np.random.randint(0, self.num_classes, batch_size).reshape(-1, 1)

            # Train the generator (via the combined model) to make D output "valid"
            g_loss = self.combined.train_on_batch([noise, sampled_labels], valid)

            # If at save interval => report progress and save generated image samples
            if epoch % sample_interval == 0:
                print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))
                self.sample_images(epoch)

    def sample_images(self, epoch):
        """Save one generated image per digit to ./<epoch>.png; also show it every 4000 iterations."""
        r, c = 2, 5  # r * c must equal the number of labels below (one panel per digit)
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        sampled_labels = np.arange(0, 10).reshape(-1, 1)

        gen_imgs = self.generator.predict([noise, sampled_labels])

        # Rescale images from [-1, 1] (tanh) to [0, 1]
        gen_imgs = 0.5 * gen_imgs + 0.5

        # axs.flat walks the grid row by row, matching the label order
        fig, axs = plt.subplots(r, c)
        for cnt, ax in enumerate(axs.flat):
            ax.imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
            # [cnt, 0] yields a scalar; "%d" on a 1-element array relies on a deprecated NumPy conversion
            ax.set_title("Digit: %d" % sampled_labels[cnt, 0])
            ax.axis('off')

        # Save BEFORE showing: plt.show() (e.g. with the notebook inline backend)
        # can close the figure, after which plt.savefig would write a blank image.
        fig.savefig("./%d.png" % epoch)
        if epoch % 4000 == 0:
            plt.show()
        plt.close(fig)


if __name__ == '__main__':
    # 20000 training iterations on mini-batches of 32, reporting/sampling every 500
    train_config = dict(epochs=20000, batch_size=32, sample_interval=500)
    cgan = CGAN()
    cgan.train(**train_config)