# CycleGAN training

In [1]:
import os
import matplotlib.pyplot as plt

from models.cycleGAN import CycleGAN
from utils.loaders import DataLoader

Using TensorFlow backend.


In [2]:

# run params
SECTION = 'paint'
RUN_ID = '0001'
DATA_NAME = 't2t'

# Output location for this run: run/<SECTION>/<RUN_ID>_<DATA_NAME>
RUN_FOLDER = 'run/{}/'.format(SECTION)
RUN_FOLDER += '_'.join([RUN_ID, DATA_NAME])

# makedirs (rather than mkdir) creates the missing parents 'run/' and
# 'run/<SECTION>/' as well, and exist_ok=True makes the cell safe to re-run:
# a sub-folder that is missing from an already existing run folder is still
# created instead of being skipped.
for subfolder in ('viz', 'images', 'weights'):
    os.makedirs(os.path.join(RUN_FOLDER, subfolder), exist_ok=True)

# 'build' makes the architecture cell call gan.save(RUN_FOLDER); any other
# value makes it load RUN_FOLDER/weights/weights.h5 instead.
mode = 'build'

# data

In [3]:
# The two image dimensions, in the order they are handed to DataLoader(img_res=...)
# and CycleGAN(input_dim=...).
IMAGE_SIZE_768, IMAGE_SIZE_512 = 768, 512

In [4]:

# Loader for the DATA_NAME dataset at the configured image resolution.
image_resolution = (IMAGE_SIZE_768, IMAGE_SIZE_512)
data_loader = DataLoader(dataset_name=DATA_NAME, img_res=image_resolution)


# architecture

In [5]:
# Instantiate the CycleGAN. input_dim is the loader's image resolution with a
# trailing 3 appended (presumably the colour channels — confirm in models.cycleGAN).
gan = CycleGAN(
    input_dim=(IMAGE_SIZE_768, IMAGE_SIZE_512, 3),
    learning_rate=0.0002,
    buffer_max_length=50,
    lambda_validation=1,
    lambda_reconstr=10,
    lambda_id=2,
    generator_type='unet',
    gen_n_filters=32,
    disc_n_filters=32,
)

# Any mode other than 'build' restores previously stored weights; 'build'
# calls gan.save on the run folder.
# NOTE(review): the recorded output of this cell ends in a pydot/GraphViz
# OSError — presumably raised inside gan.save; confirm GraphViz is installed
# and on PATH before relying on 'build'.
if mode != 'build':
    gan.load_weights(os.path.join(RUN_FOLDER, 'weights/weights.h5'))
else:
    gan.save(RUN_FOLDER)
    


W0323 21:28:04.088479 10452 deprecation_wrapper.py:119] From C:\ProgramData\Anaconda3\envs\py36-cycleGAN\lib\site-packages\keras\backend\tensorflow_backend.py:74: The name tf.get_default_graph is deprecated. Please use tf.compat.v1.get_default_graph instead.

W0323 21:28:04.115380 10452 deprecation_wrapper.py:119] From C:\ProgramData\Anaconda3\envs\py36-cycleGAN\lib\site-packages\keras\backend\tensorflow_backend.py:517: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.

W0323 21:28:04.120067 10452 deprecation_wrapper.py:119] From C:\ProgramData\Anaconda3\envs\py36-cycleGAN\lib\site-packages\keras\backend\tensorflow_backend.py:4115: The name tf.random_normal is deprecated. Please use tf.random.normal instead.

W0323 21:28:04.329533 10452 deprecation_wrapper.py:119] From C:\ProgramData\Anaconda3\envs\py36-cycleGAN\lib\site-packages\keras\optimizers.py:790: The name tf.train.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead.

W0323 

OSError: `pydot` failed to call GraphViz.Please install GraphViz (https://www.graphviz.org/) and ensure that its executables are in the $PATH.

In [None]:
# Show the summary of gan.g_BA (presumably the B-to-A generator — confirm in models.cycleGAN).
gan.g_BA.summary()

In [None]:
# Show the summary of gan.g_AB (presumably the A-to-B generator — confirm in models.cycleGAN).
gan.g_AB.summary()

In [None]:
# Show the summary of gan.d_A (presumably the domain-A discriminator — confirm in models.cycleGAN).
gan.d_A.summary()

In [None]:
# Show the summary of gan.d_B (presumably the domain-B discriminator — confirm in models.cycleGAN).
gan.d_B.summary()

# train

In [None]:
# Two candidate epoch counts; EPOCHS selects which one is used for this run.
epoch1, epoch200 = 1, 200

BATCH_SIZE = 1
EPOCHS = epoch1  # currently the 1-epoch setting
PRINT_EVERY_N_BATCHES = 10  # handed to gan.train as sample_interval

# File names handed to gan.train as test_A_file / test_B_file.
TEST_A_FILE, TEST_B_FILE = 'c3r1e0n1.tif', 'c3r1e1n1.tif'

In [None]:
# Collect the keyword options first so the training call itself stays short.
train_options = dict(
    run_folder=RUN_FOLDER,
    epochs=EPOCHS,
    test_A_file=TEST_A_FILE,
    test_B_file=TEST_B_FILE,
    batch_size=BATCH_SIZE,
    sample_interval=PRINT_EVERY_N_BATCHES,
)
gan.train(data_loader, **train_options)
        

# loss

In [None]:
# Curves drawn from gan.g_losses, in drawing order:
# (position inside each g_losses record, colour, line width).
# Positions 2, 4 and 6 and gan.d_losses are not plotted.
curve_specs = [
    (1, 'green', 0.1),   # DISCRIM LOSS
    (3, 'blue', 0.1),    # CYCLE LOSS
    (5, 'red', 0.25),    # ID LOSS
    (0, 'black', 0.25),  # first value of each record — presumably the total; TODO confirm
]

fig, ax = plt.subplots(figsize=(20, 10))

for position, colour, width in curve_specs:
    series = [record[position] for record in gan.g_losses]
    ax.plot(series, color=colour, linewidth=width)

ax.set_xlabel('batch', fontsize=18)
ax.set_ylabel('loss', fontsize=16)

# Fixed y-range: values above 5 are cut off.
ax.set_ylim(0, 5)

plt.show()