# GAN Training

<table>
    <tr>
        <th>Discriminator Learning</th>
        <th>Generator Learning</th>
    </tr>
    <tr>
        <td><img src="./images/DLearning.png" align="left"></td>
        <td><img src="./images/GLearning.png" align="left"></td>
    </tr>
</table>

<img src="./images/training_alg.png" align="left">

In [None]:
# Import notebooks
import import_ipynb
from gan_auxfuncs import *
from gan_nets import GAN_Model 

In [None]:
# Import modules
import os
import time
from collections import deque
from keras.optimizers import Adam
import numpy as np

In [None]:
"""
Setup 
"""
# Directory parameters
input_dir = 'training_data_ANIMEFACES/'
output_dir = 'output_data/'
model_output_dir = 'output_models/'  # where intermediate G/D checkpoints are written

# Shape parameters
noise_shape = (1, 1, 100)
image_shape = (64, 64, 3)

# Training parameters
num_steps = 10000
batch_size = 64
label_noise = 0.2        # noise amount passed to get_noisy_binary_labels
image_save_every = 10    # save the current batch of fake images every N steps
checkpoint_every = 500   # save models, print average losses and rebuild the GIF every N steps
loss_window = 250        # number of most recent steps kept for loss reporting

# GAN parameters
loss = 'binary_crossentropy'
# Optimizer handed to GAN_Model (the original note says G uses the same settings).
# NOTE(review): `lr` is the keyword of older standalone Keras; newer Keras versions
# expect `learning_rate` instead.
optimizer_type = Adam(lr = 0.00015, beta_1 = 0.5)
metrics = ['accuracy']

"""
Initialization 
"""
# Rolling windows of the most recent losses; deque(maxlen=...) drops old values
# automatically. Started empty so no artificial 0 skews the reported averages.
RD_loss = deque(maxlen = loss_window)   # D loss on real batches
FD_loss = deque(maxlen = loss_window)   # D loss on fake batches
GAN_loss = deque(maxlen = loss_window)  # loss of the combined GAN (generator step)

# Resolve absolute paths and make sure the output folders exist before writing
data_dir = os.path.abspath(input_dir) + '/*'  # glob-style pattern handed to sample_from_dataset
img_save_dir = os.path.abspath(output_dir) + '/'
save_model_dir = os.path.abspath(model_output_dir) + '/'  # was referenced but never defined
os.makedirs(img_save_dir, exist_ok = True)
os.makedirs(save_model_dir, exist_ok = True)

# Getting GAN model: the combined model, the generator and the discriminator
GAN, G, D = GAN_Model(noise_shape, image_shape, loss, optimizer_type, metrics)

"""
Running the training
"""
for step in range(num_steps):

    print("Begin step: ", step)
    step_begin_time = time.time()

    """
    Discriminator Learning 
    """
    # NOTE(review): in Keras a change of `trainable` is only picked up when a model
    # is (re)compiled; these toggles mirror the reference algorithm but may be
    # no-ops for the already-compiled D / GAN models -- confirm against GAN_Model.
    # (1) Freeze G weights
    G.trainable = False

    # (2) Unfreeze D weights
    D.trainable = True

    # (3) Train D
    # (3.1) Train D with the real samples (set R_X)
    R_X = sample_from_dataset(batch_size, image_shape, data_dir = data_dir)  # Set of real images
    R_Y = get_noisy_binary_labels(batch_size, True, label_noise)  # (values very close to 1.0)
    R_dis_metrics = D.train_on_batch(R_X, R_Y)

    # (3.2) Train D with the fake samples (set F_X)
    F_X = G.predict(get_normal_noise_vector(batch_size, noise_shape))  # Set of fake images
    F_Y = get_noisy_binary_labels(batch_size, False, label_noise)  # (values very close to 0.0)
    F_dis_metrics = D.train_on_batch(F_X, F_Y)

    """
    Generator Learning
    """
    # (4) Freeze D
    D.trainable = False

    # (5) Unfreeze G
    G.trainable = True

    # (6) Train G through the combined GAN: noise inputs are paired with the
    # "real" targets (R_Y is reused) so the update pushes G towards fooling D.
    X = get_normal_noise_vector(batch_size, noise_shape)  # Gaussian noise (mu = 0, sigma = 1)
    Y = R_Y  # (values very close to 1.0)
    # Separate name so the compile-time `metrics` list above is not overwritten
    gan_metrics = GAN.train_on_batch(X, Y)

    """
    Intermediate Fake Image Saving
    """
    if step % image_save_every == 0:
        step_num = str(step).zfill(4)
        save_img_batch(F_X, img_save_dir + step_num + "_image.png")

    """
    Loss Function Tracking and Intermediate Models Saving
    """
    # Loss tracking: element 0 of each train_on_batch result is the loss value
    # (Keras returns [loss, *metrics] when metrics were compiled in)
    RD_loss.append(R_dis_metrics[0])
    FD_loss.append(F_dis_metrics[0])
    GAN_loss.append(gan_metrics[0])

    # Time tracking (whole seconds)
    diff_time = int(time.time() - step_begin_time)
    print("Step %d completed. Time took: %s secs." % (step, diff_time))

    # Periodic checkpoint: save intermediate models, report losses, refresh the GIF
    if (step + 1) % checkpoint_every == 0:

        # Set both sub-models trainable before saving (presumably so the saved
        # models are not stored in a frozen state -- confirm)
        D.trainable = True
        G.trainable = True
        G.save(save_model_dir + str(step) + "_G_weights_and_arch.hdf5")
        D.save(save_model_dir + str(step) + "_D_weights_and_arch.hdf5")

        # Print Average Loss
        print_average_loss(RD_loss, FD_loss, GAN_loss)

        # Intermediate GIF Movie Image Saving
        save_GIF_movie_from_PNG(img_save_dir, str(step))
    