In [18]:
import matplotlib.pyplot as plt
import numpy as np

from keras.datasets import mnist
from keras.layers import (
    Input, Activation, BatchNormalization, Dense, Dropout, Flatten, Reshape)
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.models import Sequential, Model
from keras.optimizers import Adam

In [3]:
img_rows = 28
img_cols = 28
channels = 1
img_shape = (img_rows, img_cols, channels)
z_dim = 100

In [8]:
def build_generator(img_shape, z_dim):
    generator_input = Input(shape=(z_dim,), name = 'generator_input')
    x = generator_input
    x = Dense(256*7*7)(x)
    x = Reshape((7, 7, 256))(x)
    x = Conv2DTranspose(128, kernel_size=3, strides=2, padding='same')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(alpha=0.01)(x)
    x = Conv2DTranspose(64, kernel_size=3, strides=1, padding='same')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(alpha=0.01)(x)
    x = Conv2DTranspose(1, kernel_size=3, strides=2, padding='same')(x)
    x = Activation('tanh', name='generator_output')(x)
    generator_output = x
    generator = Model(generator_input, generator_output, name = 'generator')
    return generator

In [9]:
generator = build_generator(img_shape, z_dim)

In [10]:
generator.summary()

Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
generator_input (InputLayer) (None, 100)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 12544)             1266944   
_________________________________________________________________
reshape_1 (Reshape)          (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_transpose_1 (Conv2DTr (None, 14, 14, 128)       295040    
_________________________________________________________________
batch_normalization_1 (Batch (None, 14, 14, 128)       512       
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_transpose_2 (Conv2DTr (None, 14, 14, 64)        7379

In [11]:
def build_discriminator(img_shape):
    discriminator_input = Input(shape=img_shape, name='discriminator_input')
    x = discriminator_input
    x = Conv2D(32, kernel_size=3, strides=2, padding='same')(x)
    x = LeakyReLU(alpha=0.01)(x)
    x = Conv2D(64, kernel_size=3, strides=2, padding='same')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(alpha=0.01)(x)
    x = Conv2D(128, kernel_size=3, strides=2, padding='same')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(alpha=0.01)(x)
    x = Flatten()(x)
    x = Dense(1, activation='sigmoid', name='discriminator_output')(x)
    discriminator_output = x
    discriminator = Model(discriminator_input, discriminator_output, name='discriminator')
    return discriminator

In [12]:
discriminator = build_discriminator(img_shape)

In [13]:
discriminator.summary()

Model: "discriminator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
discriminator_input (InputLa (None, 28, 28, 1)         0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 14, 14, 32)        320       
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU)    (None, 14, 14, 32)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 7, 7, 64)          18496     
_________________________________________________________________
batch_normalization_3 (Batch (None, 7, 7, 64)          256       
_________________________________________________________________
leaky_re_lu_4 (LeakyReLU)    (None, 7, 7, 64)          0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 4, 4, 128)       

In [14]:
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(), metrics=['accuracy'])
discriminator.trainable = False
model_input = Input(shape=(z_dim,), name='model_input')
model_output = discriminator(generator(model_input))
GAN = Model(model_input, model_output)
GAN.compile(loss='binary_crossentropy', optimizer=Adam(), metrics=['accuracy'])
discriminator.trainable = True

In [22]:
losses = []
accuracies = []
iteration_checkpoints = []
def train_discriminator(x_train, batch_size):
    real = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))
    
    #реальные изображения
    idx = np.random.randint(0, x_train.shape[0], batch_size)
    imgs = x_train[idx]
    d_loss_real = discriminator.train_on_batch(imgs, real)
    
    #сгенерированные изображение
    z = np.random.normal(0,1,(batch_size, z_dim))
    gen_imgs = generator.predict(z)
    d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
     
    return d_loss_real, d_loss_fake
    
def train_generator(batch_size):
    real = np.ones((batch_size, 1))
    z = np.random.normal(0,1,(batch_size, z_dim))
    g_loss = GAN.train_on_batch(z,real)
    
    return g_loss
    
def train(iterations, batch_size, sample_interval):
    (X_train, _), (_, _) = mnist.load_data()
    X_train = X_train / 127.5 - 1.0
    X_train = np.expand_dims(X_train, axis=3)
    for iteration in range(iterations):
        
        #обучение
        d_loss_real, d_loss_fake = train_discriminator(X_train, batch_size)
        g_loss = train_generator(batch_size)
        
        d_loss, accuracy = 0.5 * np.add(d_loss_real, d_loss_fake)
        
        #сохранение и вывод результатов через интервал
        if (iteration + 1) % sample_interval == 0:
            losses.append((d_loss, g_loss))
            accuracies.append(100.0 * accuracy)
            iteration_checkpoints.append(iteration+1)
            print('{0} [D loss: {1}, acc.: {2}] [G loss: {3}]'.format(iteration + 1, d_loss, 100.0 * accuracy, g_loss))
            sample_images(generator)

In [25]:
def sample_images(generator, image_grid_rows = 4, image_grid_columns = 4):
    z = np.random.normal(0, 1,(image_grid_rows * image_grid_columns, z_dim))
    gen_imgs = generator.predict(z)
    gen_imgs = 0.5 * gen_imgs + 0.5
    fig, axes = plt.subplots(image_grid_rows, image_grid_columns, figsize=(8,8), sharey= True, sharex = True)
    cnt = 0
    for i in range(image_grid_rows):
        for j in range(image_grid_columns):
            axes[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
            axes[i, j].axis('off')
            cnt += 1

In [None]:
iterations = 5000
batch_size = 128
sample_interval = 1000
train(iterations, batch_size, sample_interval)

1000 [D loss: 0.0701909214258194, acc.: 98.4375] [G loss: [4.833536, 0.0]]
2000 [D loss: 0.08059462159872055, acc.: 96.875] [G loss: [3.1433506, 0.046875]]
3000 [D loss: 0.08377835154533386, acc.: 98.4375] [G loss: [4.5514746, 0.0]]
4000 [D loss: 0.05281153321266174, acc.: 99.609375] [G loss: [3.8152757, 0.03125]]


In [None]:
GAN.summary()