# CycleGAN train

In [1]:
import os
import matplotlib.pyplot as plt

from models.cycleGAN import CycleGAN
from utils.loaders import DataLoader

# Workaround for a crash seen on macOS (per the original note).
# NOTE(review): presumably the "duplicate OpenMP runtime" abort -- confirm it is still needed.
os.environ['KMP_DUPLICATE_LIB_OK']='True'

Using TensorFlow backend.


In [2]:

# run params
SECTION = 'paint'
RUN_ID = '0001'
DATA_NAME = 'apple2orange'
RUN_FOLDER = 'run/{}/'.format(SECTION)
RUN_FOLDER += '_'.join([RUN_ID, DATA_NAME])

# Create the run folder and its output sub-folders.
# os.makedirs also creates the missing 'run/<SECTION>' parents (os.mkdir would
# raise FileNotFoundError on a fresh checkout), and exist_ok=True keeps the cell
# safe to re-run even when an earlier run created only some of the folders.
for subfolder in ('viz', 'images', 'weights'):
    os.makedirs(os.path.join(RUN_FOLDER, subfolder), exist_ok=True)

# 'build' saves a freshly constructed model into RUN_FOLDER; any other value
# makes the architecture cell load existing weights from RUN_FOLDER instead.
mode =  'build'

# data

In [3]:
# Side length used for both image dimensions: passed as img_res to the
# DataLoader and as the first two entries of the CycleGAN input_dim.
IMAGE_SIZE = 128

In [4]:

# Loader for the DATA_NAME dataset, configured with a square IMAGE_SIZE resolution.
data_loader = DataLoader(
    dataset_name=DATA_NAME,
    img_res=(IMAGE_SIZE, IMAGE_SIZE),
)


# architecture

In [5]:
# Construct the CycleGAN for square, 3-channel inputs of side IMAGE_SIZE.
# NOTE(review): the lambda_* values look like weights for the adversarial,
# reconstruction and identity loss terms -- confirm against models.cycleGAN.
gan = CycleGAN(
    input_dim=(IMAGE_SIZE, IMAGE_SIZE, 3),
    learning_rate=0.0002,
    buffer_max_length=50,
    lambda_validation=1,
    lambda_reconstr=10,
    lambda_id=2,
    generator_type='unet',
    gen_n_filters=32,
    disc_n_filters=32,
)

# Anything other than 'build' resumes from the weights stored in the run
# folder; 'build' saves the freshly constructed model there instead.
if mode != 'build':
    gan.load_weights(os.path.join(RUN_FOLDER, 'weights/weights.h5'))
else:
    gan.save(RUN_FOLDER)
    


In [6]:
# Print the layer table of gan.g_BA (presumably the B -> A generator -- confirm in models.cycleGAN).
gan.g_BA.summary()

Model: "model_4"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_4 (InputLayer)            (None, 128, 128, 3)  0                                            
__________________________________________________________________________________________________
conv2d_19 (Conv2D)              (None, 64, 64, 32)   1568        input_4[0][0]                    
__________________________________________________________________________________________________
instance_normalization_14 (Inst (None, 64, 64, 32)   0           conv2d_19[0][0]                  
__________________________________________________________________________________________________
activation_8 (Activation)       (None, 64, 64, 32)   0           instance_normalization_14[0][0]  
____________________________________________________________________________________________

In [7]:
# Print the layer table of gan.g_AB (presumably the A -> B generator -- confirm in models.cycleGAN).
gan.g_AB.summary()

Model: "model_3"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_3 (InputLayer)            (None, 128, 128, 3)  0                                            
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 64, 64, 32)   1568        input_3[0][0]                    
__________________________________________________________________________________________________
instance_normalization_7 (Insta (None, 64, 64, 32)   0           conv2d_11[0][0]                  
__________________________________________________________________________________________________
activation_1 (Activation)       (None, 64, 64, 32)   0           instance_normalization_7[0][0]   
____________________________________________________________________________________________

In [8]:
# Print the layer table of gan.d_A (presumably the discriminator for domain A -- confirm in models.cycleGAN).
gan.d_A.summary()

Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, 128, 128, 3)       0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 64, 64, 32)        1568      
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 64, 64, 32)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 32, 32, 64)        32832     
_________________________________________________________________
instance_normalization_1 (In (None, 32, 32, 64)        0         
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 32, 32, 64)        0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 16, 16, 128)       1312

In [9]:
# Print the layer table of gan.d_B (presumably the discriminator for domain B -- confirm in models.cycleGAN).
gan.d_B.summary()

Model: "model_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         (None, 128, 128, 3)       0         
_________________________________________________________________
conv2d_6 (Conv2D)            (None, 64, 64, 32)        1568      
_________________________________________________________________
leaky_re_lu_5 (LeakyReLU)    (None, 64, 64, 32)        0         
_________________________________________________________________
conv2d_7 (Conv2D)            (None, 32, 32, 64)        32832     
_________________________________________________________________
instance_normalization_4 (In (None, 32, 32, 64)        0         
_________________________________________________________________
leaky_re_lu_6 (LeakyReLU)    (None, 32, 32, 64)        0         
_________________________________________________________________
conv2d_8 (Conv2D)            (None, 16, 16, 128)       1312

# train

In [10]:
# Training settings handed to gan.train below.
BATCH_SIZE = 1   # passed as batch_size
EPOCHS = 200     # passed as epochs
# Passed as sample_interval. The captured log prints a line for every batch,
# so despite the name this does not throttle printing.
# NOTE(review): presumably it controls how often sample images are written -- confirm in CycleGAN.train.
PRINT_EVERY_N_BATCHES = 10

# File names passed as test_A_file / test_B_file.
# NOTE(review): presumably fixed images from each domain used for sample outputs -- confirm.
TEST_A_FILE = 'n07740461_14740.jpg'
TEST_B_FILE = 'n07749192_4241.jpg'

In [11]:
# Start training with the settings defined in the previous cell; outputs go under RUN_FOLDER.
gan.train(
    data_loader,
    run_folder=RUN_FOLDER,
    epochs=EPOCHS,
    test_A_file=TEST_A_FILE,
    test_B_file=TEST_B_FILE,
    batch_size=BATCH_SIZE,
    sample_interval=PRINT_EVERY_N_BATCHES,
)
        

  'Discrepancy between trainable weights and collected trainable'


[Epoch 0/200] [Batch 0/995] [D loss: 1.335864, acc:  32%] [G loss: 27.144220, adv: 2.424437, recon: 2.078396, id: 1.967913] time: 0:00:57.491947 
[Epoch 0/200] [Batch 1/995] [D loss: 1.119941, acc:  37%] [G loss: 19.346275, adv: 1.531295, recon: 1.494560, id: 1.434687] time: 0:01:04.167604 
[Epoch 0/200] [Batch 2/995] [D loss: 0.952480, acc:  37%] [G loss: 19.007957, adv: 1.716501, recon: 1.443557, id: 1.427946] time: 0:01:06.472727 
[Epoch 0/200] [Batch 3/995] [D loss: 0.945272, acc:  42%] [G loss: 17.361084, adv: 1.434583, recon: 1.321089, id: 1.357804] time: 0:01:09.160271 
[Epoch 0/200] [Batch 4/995] [D loss: 0.736059, acc:  44%] [G loss: 16.129923, adv: 1.709264, recon: 1.220059, id: 1.110035] time: 0:01:11.502483 
[Epoch 0/200] [Batch 5/995] [D loss: 0.730646, acc:  43%] [G loss: 15.360565, adv: 1.538663, recon: 1.165901, id: 1.081448] time: 0:01:15.680109 
[Epoch 0/200] [Batch 6/995] [D loss: 0.776774, acc:  40%] [G loss: 19.356518, adv: 1.277091, recon: 1.531268, id: 1.383376] 

KeyboardInterrupt: 

# loss

In [None]:
# Plot selected components of the generator loss history (one point per
# recorded entry of gan.g_losses) with a legend so the curves are identifiable.
fig = plt.figure(figsize=(20,10))

# (index into each g_losses entry, legend label, colour, line width), in draw
# order: the index-0 curve is drawn last so it sits on top of the others.
# NOTE(review): the labels follow the original inline comments
# (1 = discriminator loss, 3 = cycle loss, 5 = ID loss); index 0 is presumably
# the combined generator loss -- confirm against CycleGAN.train.
# Indices 2, 4 and 6 of g_losses, and gan.d_losses, are intentionally not
# plotted (they were commented out in the original cell).
loss_series = [
    (1, 'discriminator', 'green', 0.1),
    (3, 'cycle', 'blue', 0.1),
    (5, 'identity', 'red', 0.25),
    (0, 'total', 'black', 0.25),
]

for idx, label, color, width in loss_series:
    plt.plot([x[idx] for x in gan.g_losses], color=color, linewidth=width, label=label)

plt.xlabel('batch', fontsize=18)
plt.ylabel('loss', fontsize=16)

plt.ylim(0, 5)

# The curves are drawn with very thin lines; thicken only the legend handles
# so the colour key stays readable.
legend = plt.legend(fontsize=14)
for handle in legend.get_lines():
    handle.set_linewidth(2.0)

plt.show()