In [None]:
import io
import math
import os
import shutil
import time
import zipfile
from typing import *

from keras.models import Sequential, load_model, Model
from keras.layers import Dense, Conv2D, Conv2DTranspose, Flatten, Reshape
from keras.layers import UpSampling2D, Activation, BatchNormalization
from keras.layers import LeakyReLU, Dropout, MaxPooling2D, Input, ZeroPadding2D
from keras.initializers import TruncatedNormal, RandomNormal
from keras.optimizers import Adam, RMSprop
import keras.backend as K
import tensorflow as tf

import xml.etree.ElementTree as ET
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from PIL import Image
import cv2
import imgaug as ia
from imgaug import augmenters as iaa


In [None]:
# --- Paths -------------------------------------------------------------------
project_dir = './'
data_dir = '/kaggle/input/'
image_dir = data_dir + '/all-dogs/all-dogs/'
annotations_dir = data_dir + '/annotation/Annotation/'
crop_dir = project_dir + '/crop/'
model_dir = project_dir + '/saved_models/'
generated_dir = project_dir + '/generated_images/'

# --- Training hyper-parameters -------------------------------------------------
img_shape = (64, 64, 3)     # (height, width, channels) of training / generated images
batch_size = 64
train_steps = 8 * 32_000    # earlier runs used 25_000 and 7_200 steps
save_interval = 8 * 1_000   # save models + a sample grid every this many steps

## Utilities: data preparation and project structure

In [None]:
def make_dirs(dirs):
    """
    Create every directory in ``dirs``, including missing parent directories.

    A directory that already exists is reported (its FileExistsError is printed)
    instead of aborting the loop.

    :param List[str] dirs: list with directories to make
        like ['folder/new_folder', 'another_new_folder/']
    """
    for directory in dirs:
        try:
            # makedirs (not mkdir) so nested paths such as 'folder/new_folder'
            # work even when 'folder' does not exist yet.
            os.makedirs(directory)
            print(f'{directory} created!')
        except FileExistsError as err:
            print(err)
        
        
def delete_directory(directory: str):
    """
    Recursively remove ``directory`` and report what happened.

    A missing directory is not an error; a message is printed instead.
    """
    if not os.path.exists(directory):
        print(f"{directory} doesn't exist")
        return
    shutil.rmtree(directory)
    print(f'{directory} deleted!')
        
            
def get_image(image_path: str, mode: str = 'RGB',
              img_size: Tuple[int, ...] = (64, 64, 3)) -> np.ndarray:
    """
    Read an image from disk and scale its pixel values to [-1, 1].

    :param image_path: Path of image
    :param mode: PIL mode the image is converted to (e.g. 'RGB')
    :param img_size: currently unused (no resizing is done); kept for
        interface compatibility with callers that pass it.
    :return: Image data as a float numpy.ndarray with values in [-1, 1]
    """
    # The context manager closes the underlying file; a bare Image.open keeps the
    # handle open until garbage collection, which leaks descriptors in big loops.
    with Image.open(image_path) as image:
        pixels = np.array(image.convert(mode))
    return pixels / 127.5 - 1.


def get_batch(image_files: List[str], image_dir: str,
              mode: str = 'RGB', img_size: Tuple[int, ...] = (64, 64, 3)) -> np.ndarray:
    """
    Load a batch of pictures from a directory.

    :param image_files: file names relative to ``image_dir``.
    :param image_dir: directory containing the files.
    :param mode: PIL mode the images are converted to.
    :param img_size: forwarded to ``get_image``.
    :return: array of shape (n_images, *image_shape) when all images share one
        shape; otherwise a 1-D object array with one image per entry.
    """
    images = [get_image(os.path.join(image_dir, name), mode, img_size=img_size)
              for name in image_files]
    if len({img.shape for img in images}) <= 1:
        return np.array(images)
    # Variable-sized images (e.g. raw, uncropped photos): build the object array
    # explicitly. np.array on such a list raises on NumPy >= 1.24, and on older
    # versions it also fails when the images happen to share their first dimension.
    batch = np.empty(len(images), dtype=object)
    for k, img in enumerate(images):
        batch[k] = img
    return batch


def train_generator_drive(image_dir: str, batch_size: int = 128,
                    shuffle: bool = True,
                    img_size: Tuple[int, ...] = (64, 64, 3)) -> Iterator[np.ndarray]:
    """
    Endless generator of training batches, reading the images from disk per batch.

    Only full batches are yielded; leftover files (fewer than ``batch_size``) at
    the end of an epoch are skipped.

    :param image_dir: directory with the images.
    :param batch_size: number of images per batch.
    :param shuffle: reshuffle the file order at the start of every epoch.
    :param img_size: forwarded to ``get_batch``.
    :raises ValueError: (on the first ``next``) if the directory holds fewer than
        ``batch_size`` files, i.e. no full batch can be formed. Without this check
        the generator would loop forever without yielding.
    """
    image_files = os.listdir(image_dir)
    n_steps = len(image_files) // batch_size  # all batches must be full
    print(f'Founded {len(image_files)} pictures')
    if n_steps == 0:
        raise ValueError(f'{image_dir} has {len(image_files)} files, '
                         f'fewer than batch_size={batch_size}')

    while True:
        if shuffle:
            np.random.shuffle(image_files)
        for i in range(n_steps):
            batch_files = image_files[i * batch_size: (i + 1) * batch_size]
            yield get_batch(batch_files, image_dir, img_size=img_size)
            

def train_generator(image_dir: str, batch_size: int = 128,
                    shuffle: bool = True,
                    img_size: Tuple[int, ...] = (64, 64, 3)) -> Iterator[np.ndarray]:
    """
    Endless generator of training batches. Loads all images into RAM once.

    Only full batches are yielded; leftover images (fewer than ``batch_size``) at
    the end of an epoch are skipped.

    :param image_dir: directory with the images.
    :param batch_size: number of images per batch.
    :param shuffle: draw a new random order at the start of every epoch.
    :param img_size: forwarded to ``get_batch``.
    :raises ValueError: (on the first ``next``) if the directory holds fewer than
        ``batch_size`` files. Without this check the generator would loop forever
        without yielding.
    """
    image_files = os.listdir(image_dir)
    n_steps = len(image_files) // batch_size  # all batches must be full
    if n_steps == 0:
        raise ValueError(f'{image_dir} has {len(image_files)} files, '
                         f'fewer than batch_size={batch_size}')
    images = get_batch(image_files, image_dir, img_size=img_size)
    print(f'Founded {len(images)} pictures')

    while True:
        # Index through a permutation instead of shuffling `images` in place, so
        # batches handed out earlier are never modified behind the caller's back.
        order = np.random.permutation(len(images)) if shuffle else np.arange(len(images))
        for i in range(n_steps):
            yield images[order[i * batch_size: (i + 1) * batch_size]]

            
def pad_square(image, desired_size=None):
    """
    Pad an image with black borders to a square, optionally rescaling it first.

    :param np.ndarray image: shape = (height, width) or (height, width, n_channels)
    :param int desired_size: side of the output square. Defaults to the longer
        side of ``image``, in which case the image is only padded, not resized.
    :return: array of shape (desired_size, desired_size[, n_channels]) with the
        image centred; when the padding is odd the extra row/column goes to the
        bottom/right.
    :rtype: np.ndarray
    """
    old_size = image.shape[:2]  # (height, width)
    if desired_size is None:
        desired_size = max(old_size)

    ratio = float(desired_size) / max(old_size)
    new_size = tuple(max(1, int(x * ratio)) for x in old_size)  # (height, width)
    if new_size != tuple(old_size):
        # cv2 takes the target size as (width, height)
        image = cv2.resize(image, (new_size[1], new_size[0]),
                           interpolation=cv2.INTER_AREA)

    delta_h = desired_size - new_size[0]
    delta_w = desired_size - new_size[1]
    top, left = delta_h // 2, delta_w // 2
    # bottom/right take the remainder so odd deltas still yield an exact square
    pad_width = [(top, delta_h - top), (left, delta_w - left)] + [(0, 0)] * (image.ndim - 2)
    return np.pad(image, pad_width, mode='constant', constant_values=0)


def crop_square(image):
    """
    Crop the largest centred square out of an image.

    :param np.ndarray image: shape = (height, width) or (height, width, n_channels)
    :return: square crop of side min(height, width); channels are kept.
    :rtype: np.ndarray
    """
    # shape[:2] (not shape[:-1]) so 2-D grayscale images work as well
    height, width = image.shape[:2]
    min_dim = min(width, height)

    x1 = (width - min_dim) // 2
    y1 = (height - min_dim) // 2
    x2 = (width + min_dim) // 2
    y2 = (height + min_dim) // 2

    return image[y1: y2, x1: x2]


def show_images(images, n_rows=5, n_cols=8):
    """
    Show a batch of images in an ``n_rows`` x ``n_cols`` grid.

    Images beyond ``n_rows * n_cols`` are not shown. Previously they made
    ``plt.subplot`` raise a ValueError.

    :param np.ndarray images: shape = (n_images, height, width, channels)
        or (height, width, channels) if one image.
    """
    if images.ndim == 3:
        images = np.expand_dims(images, axis=0)
    n_show = min(images.shape[0], n_rows * n_cols)

    plt.close('all')
    plt.figure(figsize=(18, 10))
    for i in range(n_show):
        plt.subplot(n_rows, n_cols, i + 1)
        plt.imshow(images[i])
    plt.tight_layout()
    plt.show()
    plt.close('all')
    

def augmentation(images, n_transform=3):
    """
    Append ``n_transform`` augmented copies of a batch to the batch itself.

    Each copy is produced by an imgaug ``SomeOf((1, 1), ...)`` pipeline that
    applies exactly one augmenter from its list. Currently the list holds only
    ``Fliplr(1)`` (always flip horizontally). Affine, blur, noise, dropout and
    similar augmenters were tried earlier and are disabled.

    :param np.ndarray images: batch of shape (n_images, height, width, channels)
    :param int n_transform: number of augmented copies of the batch to append
    :return: np.ndarray with the original images first, followed by each
        augmented copy of the batch in turn.
    """
    source = images.copy()
    pipelines = [
        iaa.SomeOf((1, 1), [iaa.Fliplr(1)], random_order=True)
        for _ in range(n_transform)
    ]

    augmented = []
    for pipeline in pipelines:
        augmented.extend(pipeline.augment_images(source))

    return np.array(list(images) + augmented)

In [None]:
# Start from an empty crop directory; existing model/image output folders are kept.
delete_directory(crop_dir)
make_dirs(dirs=[model_dir, generated_dir, crop_dir])

### Crop images

In [None]:
breeds = os.listdir(annotations_dir)
skip_count = 0       # images with more than one annotated dog
skip_small_dogs = 0  # dogs whose bounding box is smaller than the crop size
crop_size = img_shape[0]  # side of the square crops; assumes img_shape is square


def _crop_dog(image, xmin, ymin, xmax, ymax, size):
    """
    Cut a bounding box out of ``image`` and turn it into a ``size`` x ``size`` square.

    Boxes taller than wide are resized to width ``size`` and their top ``size``
    rows are kept. All other boxes are center-cropped to a square and resized
    to ``size`` x ``size``.

    :return: the square crop, or None if the box does not overlap the image.
    """
    dog = image[ymin: ymax, xmin: xmax]
    height, width = dog.shape[:2]
    if height == 0 or width == 0:
        return None
    if width < height:
        new_height = int((size / width) * height)
        dog = cv2.resize(dog, (size, new_height), interpolation=cv2.INTER_AREA)
        return dog[0: size, 0: size]
    dog = crop_square(dog)
    return cv2.resize(dog, (size, size), interpolation=cv2.INTER_AREA)


for breed in tqdm(breeds):
    for img_name in os.listdir(f'{annotations_dir}/{breed}'):
        image = cv2.imread(os.path.join(image_dir, img_name + '.jpg'))
        if image is None:
            continue
        objects = ET.parse(annotations_dir + breed + '/' + img_name).getroot().findall('object')

        if len(objects) > 1:
            skip_count += 1
            continue

        for obj in objects:
            bndbox = obj.find('bndbox')
            xmin, ymin, xmax, ymax = (int(bndbox.find(tag).text)
                                      for tag in ('xmin', 'ymin', 'xmax', 'ymax'))
            # Skip boxes that would need upscaling. The same strict `<` is used for
            # height and width, so a box exactly crop_size wide is kept.
            if (ymax - ymin) < crop_size or (xmax - xmin) < crop_size:
                skip_small_dogs += 1
                continue

            dog = _crop_dog(image, xmin, ymin, xmax, ymax, crop_size)
            if dog is None:  # annotation lies outside the image
                continue

            augmented = augmentation(np.expand_dims(dog, axis=0), n_transform=1)
            assert augmented.shape == (2, crop_size, crop_size, 3)
            for j, aug_image in enumerate(augmented):
                cv2.imwrite(os.path.join(crop_dir, img_name + f'_{j}.jpg'), aug_image)
print('Skipped with several dogs =', skip_count)
print('Skipped with too small dogs =', skip_small_dogs)
print('Prepared =', len(os.listdir(crop_dir)))

### Check the training generator

In [None]:
# Preview one batch of cropped training images, mapped from [-1, 1] back to [0, 1].
preview_gen = train_generator(crop_dir, batch_size=48)
preview = next(preview_gen)
print(preview.shape)
show_images(preview * 0.5 + 0.5, n_rows=6, n_cols=8)
del preview_gen

In [None]:
# Check the augmentation on raw (uncropped) images. Raw images differ in size, so
# each one is mapped to [0, 1], padded to a square and resized on its own. Stacking
# the variable-sized padded squares with np.array fails on NumPy >= 1.24.
preview_gen = train_generator_drive(image_dir, batch_size=8)
raw_batch = next(preview_gen)
print(raw_batch.shape)
resize_to = (img_shape[1], img_shape[0])  # cv2.resize takes (width, height)
preview = np.array([
    cv2.resize(pad_square(img * 0.5 + 0.5), resize_to, interpolation=cv2.INTER_AREA)
    for img in raw_batch
])
preview = augmentation(preview, n_transform=2)

print(preview.shape)
show_images(preview)
del preview_gen

## Deep convolutional GAN

In [None]:
# adapted from keras.optimizers.Adam
class AdamWithWeightnorm(Adam):
    """
    Adam optimizer that updates weight tensors in a weight-normalized parameterization.

    Parameters with rank > 1 (e.g. conv/dense kernels) are optimized as
    W = g * V / ||V||. The norm is taken over every axis except the last, so ``g``
    holds one scale per last-axis channel. Parameters with rank <= 1 (e.g. biases)
    get a plain Adam step.

    NOTE(review): relies on Keras 2.x optimizer internals (``self.lr``,
    ``self.initial_decay``, ``get_updates(loss, params)``, ``K.get_variable_shape``).
    Confirm against the installed Keras version.
    """
    def get_updates(self, loss, params):
        """
        Build the symbolic update ops for one training step.

        :param loss: scalar loss tensor to minimize.
        :param params: list of trainable variables.
        :return: list of update ops (also stored in ``self.updates``).

        NOTE(review): the Adam moments for the ``g`` parameters (``m_g``/``v_g``)
        and the ``V_scaler`` variables are created here but never added to
        ``self.weights``. They are therefore not part of the optimizer's tracked
        weights (presumably not saved/restored with the model).
        """
        grads = self.get_gradients(loss, params)
        self.updates = [K.update_add(self.iterations, 1)]

        # optional time-based learning-rate decay, as in stock Keras Adam
        lr = self.lr
        if self.initial_decay > 0:
            lr *= (1. / (1. + self.decay * K.cast(self.iterations, K.floatx())))

        # bias-corrected step size for step t (1-based)
        t = K.cast(self.iterations + 1, K.floatx())
        lr_t = lr * K.sqrt(1. - K.pow(self.beta_2, t)) / (1. - K.pow(self.beta_1, t))

        # first (ms) and second (vs) moment buffers, one per parameter
        shapes = [K.get_variable_shape(p) for p in params]
        ms = [K.zeros(shape) for shape in shapes]
        vs = [K.zeros(shape) for shape in shapes]
        self.weights = [self.iterations] + ms + vs

        for p, g, m, v in zip(params, grads, ms, vs):

            # if a weight tensor (len > 1) use weight normalized parameterization
            # this is the only part changed w.r.t. keras.optimizers.Adam
            ps = K.get_variable_shape(p)
            if len(ps)>1:

                # get weight normalization parameters
                V, V_norm, V_scaler, g_param, grad_g, grad_V = get_weightnorm_params_and_grads(p, g)

                # Adam containers for the 'g' parameter (fresh zero buffers, one
                # entry per last-axis channel)
                V_scaler_shape = K.get_variable_shape(V_scaler)
                m_g = K.zeros(V_scaler_shape)
                v_g = K.zeros(V_scaler_shape)

                # update g parameters
                m_g_t = (self.beta_1 * m_g) + (1. - self.beta_1) * grad_g
                v_g_t = (self.beta_2 * v_g) + (1. - self.beta_2) * K.square(grad_g)
                new_g_param = g_param - lr_t * m_g_t / (K.sqrt(v_g_t) + self.epsilon)
                self.updates.append(K.update(m_g, m_g_t))
                self.updates.append(K.update(v_g, v_g_t))

                # update V parameters (reusing this parameter's regular Adam buffers)
                m_t = (self.beta_1 * m) + (1. - self.beta_1) * grad_V
                v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(grad_V)
                new_V_param = V - lr_t * m_t / (K.sqrt(v_t) + self.epsilon)
                self.updates.append(K.update(m, m_t))
                self.updates.append(K.update(v, v_t))

                # if there are constraints we apply them to V, not W
                if getattr(p, 'constraint', None) is not None:
                    new_V_param = p.constraint(new_V_param)

                # wn param updates --> W updates
                add_weightnorm_param_updates(self.updates, new_V_param, new_g_param, p, V_scaler)

            else: # do optimization normally
                m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
                v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(g)
                p_t = p - lr_t * m_t / (K.sqrt(v_t) + self.epsilon)

                self.updates.append(K.update(m, m_t))
                self.updates.append(K.update(v, v_t))

                new_p = p_t
                # apply constraints
                if getattr(p, 'constraint', None) is not None:
                    new_p = p.constraint(new_p)
                self.updates.append(K.update(p, new_p))
        return self.updates

def get_weightnorm_params_and_grads(p, g):
    """
    Re-express weight ``p`` and its gradient ``g`` in weight-norm form W = g * V / ||V||.

    Norms are taken over every axis except the last, giving one scale per
    last-axis channel (channels-last / TensorFlow layout is assumed).

    :return: tuple (V, V_norm, V_scaler, g_param, grad_g, grad_V) where
        ``V_scaler`` is a new variable holding g / ||V`` (initialised to ones,
        so V equals W initially), ``g_param`` = V_scaler * ||V||, and
        ``grad_g`` / ``grad_V`` are the gradients w.r.t. g and V.
    """
    shape = K.get_variable_shape(p)
    reduce_axes = list(range(len(shape) - 1))  # every axis except the last
    bcast = [1] * len(reduce_axes) + [-1]      # broadcast a per-channel vector over p

    # ones at init so the effective parameters do not change
    V_scaler = K.ones((shape[-1],))
    V = p / tf.reshape(V_scaler, bcast)

    V_norm = tf.sqrt(tf.reduce_sum(tf.square(V), reduce_axes))
    g_param = V_scaler * V_norm

    grad_g = tf.reduce_sum(g * V, reduce_axes) / V_norm
    grad_V = tf.reshape(V_scaler, bcast) * (g - tf.reshape(grad_g / V_norm, bcast) * V)

    return V, V_norm, V_scaler, g_param, grad_g, grad_V

def add_weightnorm_param_updates(updates, new_V_param, new_g_param, W, V_scaler):
    """
    Append the ops that write the new weight-norm parameters back into ``W``.

    Computes V_scaler = g / ||V|| (norm over all but the last axis) and
    W = V_scaler * V. Appends the ``W`` update and then the ``V_scaler`` update
    to ``updates`` (the list is mutated in place).
    """
    rank = len(K.get_variable_shape(new_V_param))
    reduce_axes = list(range(rank - 1))
    new_V_scaler = new_g_param / tf.sqrt(tf.reduce_sum(tf.square(new_V_param), reduce_axes))
    new_W = tf.reshape(new_V_scaler, [1] * len(reduce_axes) + [-1]) * new_V_param
    updates.extend([K.update(W, new_W), K.update(V_scaler, new_V_scaler)])

def data_based_init(model, input):
    """
    Data-dependent initialization of a Keras model (TF1 session API).

    For every layer with exactly two trainable weights (assumed kernel and bias),
    layer by layer and in model order, the mean ``m`` and variance ``v`` of the
    layer output are measured on ``input`` (over all but the last axis). Then
    W <- W / sqrt(v) (per last-axis channel) and b <- (b - m) / sqrt(v).

    :param model: built Keras model.
    :param input: dict mapping model input tensors to arrays, a list of arrays
        (one per model input), or a single array for the first model input.
        A dict passed by the caller is not modified.
    """
    # input can be dict, numpy array, or list of numpy arrays
    if isinstance(input, dict):
        # copy: the learning-phase entry added below must not leak into the caller's dict
        feed_dict = dict(input)
    elif isinstance(input, list):
        feed_dict = {tf_inp: np_inp for tf_inp, np_inp in zip(model.inputs, input)}
    else:
        feed_dict = {model.inputs[0]: input}

    # add learning phase if required
    if model.uses_learning_phase and K.learning_phase() not in feed_dict:
        feed_dict.update({K.learning_phase(): 1})

    # get all layer name, output, weight, bias tuples
    layer_output_weight_bias = []
    for l in model.layers:
        trainable_weights = l.trainable_weights
        if len(trainable_weights) == 2:
            W, b = trainable_weights
            assert(l.built)
            # if more than one node, only use the first
            layer_output_weight_bias.append((l.name, l.get_output_at(0), W, b))

    # iterate over our list and do data dependent init; each run sees the
    # already-updated weights of earlier layers
    sess = K.get_session()
    for l, o, W, b in layer_output_weight_bias:
        print('Performing data dependent initialization for layer ' + l)
        m, v = tf.nn.moments(o, [i for i in range(len(o.get_shape()) - 1)])
        s = tf.sqrt(v + 1e-10)
        updates = tf.group(W.assign(W / tf.reshape(s, [1] * (len(W.get_shape()) - 1) + [-1])),
                           b.assign((b - m) / s))
        sess.run(updates, feed_dict)

In [None]:
class DCGAN:
    """
    Deep convolutional GAN.

    Construction builds four Keras models:

    * ``discriminator``: image -> sigmoid probability that the image is real;
    * ``generator``: ``noise_size``-dim noise -> image in [-1, 1] (tanh). Its
      spatial size is always 64 x 64 (4 x 4 upsampled four times); only the
      channel count follows ``img_shape``;
    * ``discriminator_model``: the discriminator compiled for its own updates;
    * ``adversarial_model``: generator followed by the frozen discriminator,
      compiled so that training it updates the generator.
    """

    def __init__(self, img_shape):
        """:param Tuple[int, int, int] img_shape: (height, width, channels) of the images."""
        self.img_shape = img_shape
        self.noise_size = 100

        self.discriminator = self.build_discriminator()
        self.generator = self.build_generator()
        self.discriminator_model = self.build_discriminator_model()
        self.adversarial_model = self.build_adversarial_model()

    def build_discriminator(self):
        """
        Build the (uncompiled) discriminator: conv blocks with LeakyReLU, dropout
        and batch norm, ending in a single sigmoid unit.
        """
        filters = 32
        dropout = 0.25
        model = Sequential(name='Discriminator')

        # In: 64 x 64 x 3
        # Out: 32 x 32 x 32
        model.add(Conv2D(filters, kernel_size=3, strides=2,
                  input_shape=self.img_shape, padding='same'))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(dropout))

        # In: 32 x 32 x 32
        # Out: 16 x 16 x 64, zero-padded on the bottom/right to 17 x 17 x 64
        model.add(Conv2D(filters * 2, kernel_size=3, strides=2, padding='same'))
        model.add(ZeroPadding2D(padding=((0, 1), (0, 1))))
        model.add(BatchNormalization(momentum=0.8))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(dropout))

        # In: 17 x 17 x 64
        # Out: 9 x 9 x 128
        model.add(Conv2D(filters * 4, 3, strides=2, padding='same'))
        model.add(BatchNormalization(momentum=0.8))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(dropout))

        # In: 9 x 9 x 128
        # Out: 9 x 9 x 256
        model.add(Conv2D(filters * 8, 3, strides=1, padding='same'))
        model.add(BatchNormalization(momentum=0.8))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(dropout))

        # In: 9 x 9 x 256
        # Out: 9 x 9 x 512
        model.add(Conv2D(filters * 16, 3, strides=1, padding='same'))
        model.add(BatchNormalization(momentum=0.8))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(dropout))

        # In: 9 x 9 x 512
        # Out: 1-dim probability
        model.add(Flatten())
        model.add(Dense(1))
        model.add(Activation('sigmoid'))

        print('\nDiscriminator summary (non compiled):')
        model.summary()

        return model

    def build_generator(self):
        """
        Build the (uncompiled) generator: dense projection to 4 x 4, then four
        upsample + conv blocks, and a tanh output conv.
        """
        # Note: adding dropout (0.4-0.5) to the generator broke training.
        filters = 64 * 4
        dim = 4
        model = Sequential(name='Generator')

        # In: 100
        # Out: dim x dim x filters
        model.add(Dense(dim * dim * filters, input_dim=self.noise_size))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Reshape((dim, dim, filters)))

        # In: dim x dim x filters
        # Out: 2*dim x 2*dim x filters
        model.add(UpSampling2D())
        model.add(Conv2D(filters, 3, padding='same'))
        model.add(BatchNormalization(momentum=0.8))
        model.add(LeakyReLU(alpha=0.2))

        # In: 2*dim x 2*dim x filters
        # Out: 4*dim x 4*dim x filters
        model.add(UpSampling2D())
        model.add(Conv2D(filters, 3, padding='same'))
        model.add(BatchNormalization(momentum=0.8))
        model.add(LeakyReLU(alpha=0.2))

        # In: 4*dim x 4*dim x filters
        # Out: 8*dim x 8*dim x filters/2
        model.add(UpSampling2D())
        model.add(Conv2D(int(filters / 2), 3, padding='same'))
        model.add(BatchNormalization(momentum=0.8))
        model.add(LeakyReLU(alpha=0.2))

        # In: 8*dim x 8*dim x filters/2
        # Out: 16*dim x 16*dim x filters/2
        model.add(UpSampling2D())
        model.add(Conv2D(int(filters / 2), 3, padding='same'))
        model.add(BatchNormalization(momentum=0.8))
        model.add(LeakyReLU(alpha=0.2))

        # In: 16*dim x 16*dim x filters/2
        # Out: 16*dim x 16*dim x img_shape[-1] (64 x 64 x 3 RGB), values in [-1.0, 1.0]
        model.add(Conv2D(self.img_shape[-1], 3, padding='same'))
        model.add(Activation('tanh'))

        print('\nGenerator summary (non compiled):')
        model.summary()

        return model

    def build_discriminator_model(self):
        """Wrap the discriminator in a Sequential model compiled with binary cross-entropy."""
        # Alternatives tried: RMSprop(lr=0.0002, decay=6e-8), Adam(1.5e-4, 0.5)
        optimizer = AdamWithWeightnorm(lr=0.0002, beta_1=0.5)

        model = Sequential(name='Discriminator_model')
        model.add(self.discriminator)
        model.compile(loss='binary_crossentropy', optimizer=optimizer,
                      metrics=['accuracy'])

        print('\nDiscriminator model summary (compiled):')
        model.summary()
        return model

    def build_adversarial_model(self):
        """Stack generator + frozen discriminator and compile it to train the generator."""
        # Alternatives tried: RMSprop(lr=0.0001, decay=3e-8), Adam(1.5e-4, 0.5)
        optimizer = AdamWithWeightnorm(lr=0.0002, beta_1=0.5)

        model = Sequential(name='Adversatial_model')
        model.add(self.generator)
        # Freeze the discriminator inside this model only. Keras 2 reads the flag at
        # compile time, so the already compiled discriminator_model still trains it.
        self.discriminator_model.trainable = False
        model.add(self.discriminator_model)
        model.compile(loss='binary_crossentropy', optimizer=optimizer,
                      metrics=['accuracy'])

        print('\nAdversarial model summary (compiled):')
        model.summary()
        return model

    def train(self, model_dir, images_dir, train_generator,
              train_steps=1000, batch_size=256, save_interval=0):
        """
        Alternate one discriminator update and one generator update per step.

        :param str model_dir: directory for generator.h5 / discriminator.h5.
        :param str images_dir: directory for sample grids.
        :param train_generator: iterator of real-image batches. Each batch must
            contain ``batch_size`` images, since the label arrays are sized by it.
        :param int train_steps: number of training iterations.
        :param int batch_size: number of real and of fake images per step.
        :param int save_interval: save models, log losses and write a sample grid
            every this many steps; 0 disables it.
        """
        # Noisy/smoothed targets: real in [0.7, 1.2), fake in [0, 0.3).
        # Drawn once and reused for every step.
        y_real = np.random.random((batch_size, 1)) * (1.2 - 0.7) + 0.7
        y_fake = np.random.random((batch_size, 1)) * 0.3

        for i in tqdm(range(1, train_steps + 1), desc='Training'):
            x_real = next(train_generator)
            noise = np.random.normal(0, 1, size=(batch_size, self.noise_size))
            x_fake = self.generator.predict(noise)

            # Train D on the real and the fake batch separately, in random order.
            if np.random.random() >= 0.5:
                d_loss_real = self.discriminator_model.train_on_batch(x_real, y_real)
                d_loss_fake = self.discriminator_model.train_on_batch(x_fake, y_fake)
            else:
                d_loss_fake = self.discriminator_model.train_on_batch(x_fake, y_fake)
                d_loss_real = self.discriminator_model.train_on_batch(x_real, y_real)

            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # Train G: make the frozen discriminator call the fakes real.
            a_loss = self.adversarial_model.train_on_batch(noise, y_real)

            if (save_interval > 0) and (i % save_interval == 0):
                self.generator.save(os.path.join(model_dir, 'generator.h5'))
                self.discriminator.save(os.path.join(model_dir,
                                                     'discriminator.h5'))
                log_mesg = "%d: [D loss: %f, acc: %f]" % (i, d_loss[0], d_loss[1])
                log_mesg = "%s  [A loss: %f, acc: %f]" % (log_mesg, a_loss[0], a_loss[1])
                print(log_mesg)
                img_path = os.path.join(images_dir, f'mnist_step_{i}.png')
                gen_images = self.generate_batch_images(24)
                self.save_images(gen_images, img_path, show=True)

    def generate_image(self, noise=None):
        """
        :param np.ndarray noise: shape = [1, noise_size]; random if None.

        :return: image array of shape (height, width, n_channels), values in [0, 1].
        :rtype: np.ndarray
        """
        return self.generate_batch_images(n_images=1, noise=noise)[0]

    def generate_batch_images(self, n_images, noise=None):
        """
        :param int n_images: not used if noise is given.
        :param np.ndarray noise: shape = [n_images, noise_size]; standard normal
            noise is drawn if None.

        :return: images array of shape (n_images, height, width, n_channels),
            rescaled from the generator's [-1, 1] to [0, 1].
        :rtype: np.ndarray
        """
        # `is None`, not `not noise`: truth-testing a multi-element array raises ValueError
        if noise is None:
            noise = np.random.normal(0, 1, size=(n_images, self.noise_size))
        images = self.generator.predict(noise)
        return 0.5 * images + 0.5

    def save_images(self, images, file_path, n_rows=4, n_cols=6, show=False):
        """
        Save a batch of images as one grid figure (at most n_rows * n_cols images).

        :param np.ndarray images: shape = (n_images, height, width, channels)
            or (height, width, channels) if one image. Not modified.
        :param str file_path: where the figure is written.
        :param bool show: also display the figure.
        """
        images = np.asarray(images)
        if images.ndim == 3:
            # a new view instead of ndarray.resize, which mutates the caller's array
            # and raises if the array is referenced elsewhere
            images = np.expand_dims(images, axis=0)
        n_show = min(images.shape[0], n_rows * n_cols)

        plt.close('all')
        plt.figure(figsize=(10, 10))
        for i in range(n_show):
            plt.subplot(n_rows, n_cols, i + 1)
            plt.imshow(np.reshape(images[i], self.img_shape))
            plt.axis('off')
        plt.tight_layout()
        plt.savefig(file_path)

        if show:
            plt.show()
        plt.close('all')

    def load_generator(self, generator_path):
        """
        Load the generator from a *.h5 file.

        NOTE: only ``self.generator`` is replaced; ``adversarial_model`` still wraps
        the previously built generator, so this is meant for generation, not for
        continuing training.
        """
        self.generator = load_model(generator_path)
        

## Train GAN

In [None]:
# Building the GAN prints the summaries of all four Keras models.
gan = DCGAN(img_shape)

In [None]:
# Drop a generator left over from an earlier run (it holds every crop in RAM)
# before creating a fresh one.
if 'train_gen' in globals():
    del train_gen
else:
    print("Doesn't exist")
train_gen = train_generator(crop_dir, batch_size=batch_size)

In [None]:
%%time

gan.train(model_dir=model_dir, images_dir=generated_dir, train_generator=train_gen,
          train_steps=train_steps, batch_size=batch_size, save_interval=save_interval)

## Inspect the final generator

In [None]:
# Sample 30 images from the trained generator (values in [0, 1]).
images = gan.generate_batch_images(n_images=30)

In [None]:
gan.save_images(images, file_path=f'{generated_dir}/test_generator.png',
                n_rows=5, n_cols=6, show=True)

## Save submission

In [None]:
%%time
# Write 10k generated PNGs straight into images.zip.
# - Images are generated in chunks (one predict call per chunk, not per image).
# - They are encoded in memory, so no temporary files are written.
# - The `with` statement closes the archive even if generation fails midway.
# Entry names ('0.png', '1.png', ...) match the previous file-based version.
n_submission_images = 10_000
submission_chunk = 500

with zipfile.ZipFile('images.zip', mode='w') as z, \
        tqdm(total=n_submission_images, desc='Generate submission') as progress:
    for start in range(0, n_submission_images, submission_chunk):
        n_chunk = min(submission_chunk, n_submission_images - start)
        for offset, img in enumerate(gan.generate_batch_images(n_chunk)):
            buffer = io.BytesIO()
            Image.fromarray((img * 255).astype(np.uint8)).save(buffer, 'PNG')
            z.writestr(str(start + offset) + '.png', buffer.getvalue())
        progress.update(n_chunk)

In [None]:
# Clean up the intermediate cropped training images.
delete_directory(directory=crop_dir)