## Imports

In [None]:
import os
import math
import shutil
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

from read_data import *
from model_builder import *

## Select the source (X) and target (Y) styles for transfer

In [None]:
# Available decorative styles, keyed by the integer used to pick the transfer
# direction below. The selected names are passed to `inputs()` and also form
# the checkpoint directory name (presumably dataset folder names — TODO confirm
# against read_data).
styles = {
    0 : 'Gorodets',
    1 : 'Gzhel',
    2 : 'Iznik',
    3 : 'Khokhloma',
    4 : 'Neglyubka',
    # was mis-encoded as 'Wycinanki_Å‚owickie' (UTF-8 bytes of 'ł' read as cp1252)
    5 : 'Wycinanki_łowickie',
    6 : 'Wzory_kaszubskie'
}

# source domain X and target domain Y of the style transfer
style_X = styles[1]
style_Y = styles[4]

## Configure checkpoint saving / restoring

In [None]:
# Checkpointing switches: restore a previously trained model, save the trained one.
restoring_mode = False
saving_mode = False

# Checkpoint file names; both live in models/<style_X> =|= <style_Y>/
restoring_name = 'first_model.ckpt'
saving_name = 'first_model.ckpt'

# NOTE(review): '|' is not a valid character in Windows file names, so this
# directory layout only works on POSIX file systems.
model_dir = os.path.join('models', ' =|= '.join([style_X, style_Y]))
restoring_path, saving_path = (os.path.join(model_dir, ckpt_name)
                               for ckpt_name in (restoring_name, saving_name))

## Adjust Hyperparameters

In [None]:
# Training schedule: epoch limit handed to inputs() and images per batch.
EPOCHS, BATCH_SIZE = 100, 1
# NOTE(review): IMG_SIZE is not referenced anywhere in this notebook — TODO
# confirm whether it is still needed.
IMG_SIZE = 150

# Weight of the cycle-consistency term added to both generator losses.
LAMBDA = 5
# Number of update steps per training iteration for each generator / discriminator.
GEN_STEPS, DSC_STEPS = 1, 1

## Build the model and deploy it on a device

In [None]:
with tf.device('/cpu:0'):

    #==================[ READ AND PROCESS THE INPUT ]==================#

    # One queue-fed batch tensor per style: source X first, then target Y.
    X, Y = (inputs(style, BATCH_SIZE, EPOCHS) for style in (style_X, style_Y))

    # Cast to float and divide by 255 so pixels land in [0, 1]
    # (assumes 8-bit source images — TODO confirm in read_data).
    X, Y = (tf.div(tf.cast(batch, tf.float32), 255.0) for batch in (X, Y))

### Generators

In [None]:
with tf.device('/gpu:0'):

    # G(X) -> Y is built under the name 'X', F(Y) -> X under the name 'Y'.
    # The cyclic pass calls generator() again with the same names, presumably
    # reusing these weights (TODO confirm variable reuse in model_builder).
    G_x, F_y = generator(X, 'X'), generator(Y, 'Y')

### Discriminators and cyclic reconstructions

In [None]:
with tf.device('/gpu:0'):

    # Each discriminator scores a real batch and a generated batch of its own
    # domain. The call order is kept as-is because the first call under a name
    # may be the one that creates that network's variables.

    #==================[ Dy ]==================#
    dsc_Y, dsc_Fake_Y = discriminator(Y, 'Y'), discriminator(G_x, 'Y')

    #==================[ Dx ]==================#
    dsc_X, dsc_Fake_X = discriminator(X, 'X'), discriminator(F_y, 'X')

    #================[ Cyclic ]================#
    # Round trips: F(G(X)) should reconstruct X, G(F(Y)) should reconstruct Y.
    cyc_X, cyc_Y = generator(G_x, 'Y'), generator(F_y, 'X')
    

## Losses

In [None]:
with tf.device('/gpu:0'):

    #==================[ Adversarial Loss ]==================#
    # Least-squares objectives: discriminators push real scores to 1 and fake
    # scores to 0; generators push the scores of their fakes to 1.

    def _lsgan_dsc_loss(real_score, fake_score):
        """Batch mean of (real - 1)^2 + fake^2."""
        return tf.reduce_mean(tf.squared_difference(real_score, 1) + tf.square(fake_score))

    def _lsgan_gen_loss(fake_score):
        """Batch mean of (fake - 1)^2."""
        return tf.reduce_mean(tf.squared_difference(fake_score, 1))

    # discriminators
    L_DSC_X = _lsgan_dsc_loss(dsc_X, dsc_Fake_X)
    L_DSC_Y = _lsgan_dsc_loss(dsc_Y, dsc_Fake_Y)

    # generators: G (X -> Y) is judged by Dy, F (Y -> X) by Dx
    L_GEN_X = _lsgan_gen_loss(dsc_Fake_Y)
    L_GEN_Y = _lsgan_gen_loss(dsc_Fake_X)

    #==================[ Consistency Loss ]==================#
    # mean absolute reconstruction error of both cycles
    L_CYC = tf.reduce_mean(tf.abs(X - cyc_X)) + tf.reduce_mean(tf.abs(Y - cyc_Y))

    #=====================[ Final Loss ]=====================#
    Loss_Gen_X, Loss_Gen_Y = L_GEN_X + LAMBDA * L_CYC, L_GEN_Y + LAMBDA * L_CYC
    Loss_Dsc_X, Loss_Dsc_Y = L_DSC_X, L_DSC_Y
    

## Optimization

In [None]:
with tf.device('/gpu:0'):

    # Learning rate shared by all four networks.
    LEARNING_RATE = 1e-2

    def _adam_train_op(loss, collection):
        """Return an Adam update op minimizing `loss` over the variables in `collection`.

        A fresh AdamOptimizer is built per network. A TF1 optimizer instance keeps
        a single pair of beta1/beta2 power accumulators per graph. Sharing one
        instance across several minimize() calls therefore makes every train op
        advance the same counters, which distorts Adam's bias correction for all
        networks.

        NOTE: checkpoints written with the previous shared-optimizer graph do not
        contain the additional beta*_power_N variables created here.
        """
        optimizer = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE)
        return optimizer.minimize(loss, var_list=tf.get_collection(collection))

    gen_X_op = _adam_train_op(Loss_Gen_X, "GEN_X")
    gen_Y_op = _adam_train_op(Loss_Gen_Y, "GEN_Y")
    dsc_X_op = _adam_train_op(Loss_Dsc_X, "DSC_X")
    dsc_Y_op = _adam_train_op(Loss_Dsc_Y, "DSC_Y")

## Create the session and start the threads for input queues

In [None]:
# create the session saver; it captures every variable that exists at this
# point, including the optimizer state created in the Optimization cell
saver = tf.train.Saver()

# create a session for running operations in the graph. Soft placement lets
# ops pinned to '/gpu:0' fall back to another device (e.g. the CPU) when no GPU
# or no GPU kernel is available, instead of failing when they are first run.
session_config = tf.ConfigProto(allow_soft_placement=True)
sess = tf.Session(config=session_config)

# create the variable initializers; local variables are initialized too
# (presumably needed by the epoch-limited input queues — TODO confirm in read_data)
init_op = tf.group(tf.global_variables_initializer(),
                   tf.local_variables_initializer())

# initialize the variables
sess.run(init_op)

if restoring_mode:
    # previously saved model is restored, overwriting the fresh initial values
    saver.restore(sess, restoring_path)

# start input enqueue threads.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

## Collect data for TensorBoard

In [None]:
# start every run with an empty TensorBoard log directory
try:
    shutil.rmtree('tensorboard')
except FileNotFoundError:
    pass  # nothing has been logged yet

with tf.device('/cpu:0'):

    # scalar curves for the four training objectives
    loss_gen_x, loss_gen_y, loss_dsc_x, loss_dsc_y = [
        tf.summary.scalar(tag, value)
        for tag, value in (('loss_gen_x', Loss_Gen_X),
                           ('loss_gen_y', Loss_Gen_Y),
                           ('loss_dsc_x', Loss_Dsc_X),
                           ('loss_dsc_y', Loss_Dsc_Y))
    ]

    # real batches shown next to their translations into the other style
    x_original, y_fake, y_original, x_fake = [
        tf.summary.image(tag, images)
        for tag, images in (('X_original', X),
                            ('Y_fake', G_x),
                            ('Y_original', Y),
                            ('X_fake', F_y))
    ]

    # NOTE(review): merge_all() gathers every summary in the graph, so re-running
    # this cell without rebuilding the graph logs each summary twice.
    merged = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter('tensorboard', sess.graph)

## Training loop

In [None]:
try:
    step = 0
    # Feed data until the input queues are exhausted; TF signals this with
    # OutOfRangeError once the epoch limit given to inputs() is reached.
    while not coord.should_stop():
        step += 1
        summary = None  # fetched together with the generator-Y update below

        # train generator X (G: X -> Y)
        for i in range(GEN_STEPS):
            _, loss = sess.run([gen_X_op, Loss_Gen_X])
            print("Step {0:5d} | Generator X     {1:5d} | loss = {2:6.3f}".format(
                    step, i, loss))

        # train discriminator Y (set DSC_STEPS = 0 to freeze both discriminators)
        for i in range(DSC_STEPS):
            _, loss = sess.run([dsc_Y_op, Loss_Dsc_Y])
            print("Step {0:5d} | Discriminator Y {1:5d} | loss = {2:6.3f}".format(
                    step, i, loss))

        # train generator Y (F: Y -> X)
        for i in range(GEN_STEPS):
            _, loss, summary = sess.run([gen_Y_op, Loss_Gen_Y, merged])
            print("Step {0:5d} | Generator Y     {1:5d} | loss = {2:6.3f}".format(
                    step, i, loss))

        # train discriminator X
        for i in range(DSC_STEPS):
            _, loss = sess.run([dsc_X_op, Loss_Dsc_X])
            print("Step {0:5d} | Discriminator X {1:5d} | loss = {2:6.3f}".format(
                    step, i, loss))

        # save stats to log (nothing was fetched when GEN_STEPS == 0)
        if summary is not None:
            summary_writer.add_summary(summary, step)

except tf.errors.OutOfRangeError:

    print('\nDone training -- epoch limit reached\n')

finally:
    try:
        # when done, ask the threads to stop and wait for them to finish
        coord.request_stop()
        coord.join(threads)

        # persist the trained weights when requested in the saving/restoring cell
        if saving_mode:
            os.makedirs(os.path.dirname(saving_path), exist_ok=True)
            print('Model saved to', saver.save(sess, saving_path))
    finally:
        # flush pending TensorBoard events and release the session even if
        # stopping the threads or saving failed
        summary_writer.close()
        sess.close()