# CycleGAN in TensorFlow


Based on the implementation at: https://github.com/leehomyc/cyclegan-1

Data: cxr8

### Import the data-loader class

In [1]:
from helpers import *
from data_loader import Data_loader

### CycleGAN model definition and training

In [2]:
from tensorflow.core.protobuf import config_pb2

"""Code for training CycleGAN."""
from datetime import datetime
import json
import numpy as np
import os
import random
from scipy.misc import imsave
import tensorflow as tf

from losses import *
from model import *
from parameters import *

slim = tf.contrib.slim

# Session configuration shared by the CycleGAN sessions: GPU memory growth
# is explicitly disabled.
config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=False))

In [3]:
class CycleGAN:
    """The CycleGAN module: graph construction, training and test export.

    Hyper-parameters (pool size, cycle-loss weights, base learning rate,
    number of epochs, generator update frequency, network version, image
    size, ...) are read from module-level constants such as ``POOL_SIZE``,
    ``LAMBDA_A``, ``BASE_LR``, ``MAX_STEP`` and ``GEN_FREQ``.
    """

    def __init__(self, to_restore, output_root_dir,
                 dataset_name, checkpoint_dir, single_input_dir):
        """Store the run configuration and allocate the fake-image pools.

        :param to_restore: If true, ``train`` restores the latest checkpoint
            in ``checkpoint_dir`` before training.
        :param output_root_dir: Root directory; this run writes into a
            sub-directory named after the current timestamp.
        :param dataset_name: Stored on the instance; not read by this class.
        :param checkpoint_dir: Directory searched for the latest checkpoint
            when restoring (``train``) or testing (``test``).
        :param single_input_dir: Accepted for interface compatibility; unused.
        """
        current_time = datetime.now().strftime("%Y%m%d-%H%M%S")

        self._pool_size = POOL_SIZE
        self._lambda_a = LAMBDA_A
        self._lambda_b = LAMBDA_B
        self._output_dir = os.path.join(output_root_dir, current_time)

        self._images_train = os.path.join(self._output_dir, 'imgs_train')
        self._images_val = os.path.join(self._output_dir, 'imgs_val')
        self._images_test = os.path.join(self._output_dir, 'imgs_test')

        self._num_imgs_to_save = SAVE_IMAGES
        self._to_restore = to_restore
        self._base_lr = BASE_LR
        self._max_step = MAX_STEP
        # Generator weights are only updated on epochs divisible by this.
        self.generator_frequency = GEN_FREQ

        self._network_version = NET_VERSION
        self._dataset_name = dataset_name
        self._checkpoint_dir = checkpoint_dir
        self._do_flipping = DO_FLIPPING
        self._skip = SKIP

        self.batch_size = BATCH_SIZE

        # History of generated batches (one slot per past batch) used by
        # fake_image_pool. Layout: (slot, batch, height, width, channels).
        pool_shape = (self._pool_size, self.batch_size, IMG_HEIGHT,
                      IMG_WIDTH, IMG_CHANNELS)
        self.fake_images_A = np.zeros(pool_shape)
        self.fake_images_B = np.zeros(pool_shape)

    def model_setup(self):
        """Create the input placeholders and build the CycleGAN graph.

        Attributes set:
            input_a / input_b: batches of real images from domain A / B.
            fake_pool_A / fake_pool_B: pooled fakes fed to the
                discriminators (see ``fake_image_pool``).
            learning_rate: scalar learning-rate placeholder.
            global_step: step variable; ``train`` uses it as epoch counter.
            fake_images_a / fake_images_b, cycle_images_a / cycle_images_b,
            prob_*_is_real: tensors returned by ``get_outputs``. The
                cycle images are compared against the inputs by the cycle
                consistency loss.
        """
        def image_placeholder(name):
            # Height before width, matching the layout of the numpy
            # fake-image pools allocated in __init__.
            return tf.placeholder(
                tf.float32,
                [None, IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS],
                name=name)

        self.input_a = image_placeholder("input_A")
        self.input_b = image_placeholder("input_B")
        self.fake_pool_A = image_placeholder("fake_pool_A")
        self.fake_pool_B = image_placeholder("fake_pool_B")

        self.global_step = slim.get_or_create_global_step()

        self.num_fake_inputs = 0

        self.learning_rate = tf.placeholder(tf.float32, shape=[], name="lr")

        inputs = {
            'images_a': self.input_a,
            'images_b': self.input_b,
            'fake_pool_a': self.fake_pool_A,
            'fake_pool_b': self.fake_pool_B,
        }

        outputs = get_outputs(
            inputs, network=self._network_version, skip=self._skip)

        self.prob_real_a_is_real = outputs['prob_real_a_is_real']
        self.prob_real_b_is_real = outputs['prob_real_b_is_real']
        self.fake_images_a = outputs['fake_images_a']
        self.fake_images_b = outputs['fake_images_b']
        self.prob_fake_a_is_real = outputs['prob_fake_a_is_real']
        self.prob_fake_b_is_real = outputs['prob_fake_b_is_real']

        self.cycle_images_a = outputs['cycle_images_a']
        self.cycle_images_b = outputs['cycle_images_b']

        self.prob_fake_pool_a_is_real = outputs['prob_fake_pool_a_is_real']
        self.prob_fake_pool_b_is_real = outputs['prob_fake_pool_b_is_real']

    def compute_losses(self):
        """Define the losses, the per-network trainers and loss summaries.

        d_loss_A / d_loss_B -> LSGAN losses for discriminators A / B, both
            evaluated on the *pooled* fakes (``fake_pool_A`` / ``fake_pool_B``).
        g_loss_A / g_loss_B -> both cycle-consistency terms plus the LSGAN
            generator loss against the opposite domain's discriminator.
        *_trainer -> Adam update ops restricted to each network's variables
            (selected by 'd_A' / 'g_A' / 'd_B' / 'g_B' in the variable name).
        *_summ / *_summ_val -> scalar summaries for training / validation.
        """
        cycle_consistency_loss_a = \
            self._lambda_a * cycle_consistency_loss(
                real_images=self.input_a, generated_images=self.cycle_images_a,
            )
        cycle_consistency_loss_b = \
            self._lambda_b * cycle_consistency_loss(
                real_images=self.input_b, generated_images=self.cycle_images_b,
            )

        lsgan_loss_a = lsgan_loss_generator(self.prob_fake_a_is_real)
        lsgan_loss_b = lsgan_loss_generator(self.prob_fake_b_is_real)

        g_loss_A = \
            cycle_consistency_loss_a + cycle_consistency_loss_b + lsgan_loss_b
        g_loss_B = \
            cycle_consistency_loss_b + cycle_consistency_loss_a + lsgan_loss_a

        d_loss_A = lsgan_loss_discriminator(
            prob_real_is_real=self.prob_real_a_is_real,
            prob_fake_is_real=self.prob_fake_pool_a_is_real,
        )
        # Use the pooled fakes, symmetric with d_loss_A; otherwise the
        # fake_pool_B feed in training/validation would have no effect.
        d_loss_B = lsgan_loss_discriminator(
            prob_real_is_real=self.prob_real_b_is_real,
            prob_fake_is_real=self.prob_fake_pool_b_is_real,
        )

        optimizer = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5)

        self.model_vars = tf.trainable_variables()

        d_A_vars = [var for var in self.model_vars if 'd_A' in var.name]
        g_A_vars = [var for var in self.model_vars if 'g_A' in var.name]
        d_B_vars = [var for var in self.model_vars if 'd_B' in var.name]
        g_B_vars = [var for var in self.model_vars if 'g_B' in var.name]

        self.d_A_trainer = optimizer.minimize(d_loss_A, var_list=d_A_vars)
        self.d_B_trainer = optimizer.minimize(d_loss_B, var_list=d_B_vars)
        self.g_A_trainer = optimizer.minimize(g_loss_A, var_list=g_A_vars)
        self.g_B_trainer = optimizer.minimize(g_loss_B, var_list=g_B_vars)

        for var in self.model_vars:
            print(var.name)

        # Summary variables for tensorboard training
        self.g_A_loss_summ = tf.summary.scalar("g_A_loss", g_loss_A)
        self.g_B_loss_summ = tf.summary.scalar("g_B_loss", g_loss_B)
        self.d_A_loss_summ = tf.summary.scalar("d_A_loss", d_loss_A)
        self.d_B_loss_summ = tf.summary.scalar("d_B_loss", d_loss_B)

        # Summary variables for tensorboard validating
        self.g_A_loss_summ_val = tf.summary.scalar("g_A_loss_val", g_loss_A)
        self.g_B_loss_summ_val = tf.summary.scalar("g_B_loss_val", g_loss_B)
        self.d_A_loss_summ_val = tf.summary.scalar("d_A_loss_val", d_loss_A)
        self.d_B_loss_summ_val = tf.summary.scalar("d_B_loss_val", d_loss_B)

    def save_images(self, sess, epoch, inpt_a, inpt_b, imgs_dir):
        """Save generated (and, at epoch 0, input) images plus an HTML index.

        For each of the first ``self._num_imgs_to_save`` image pairs (capped
        at the size of the smaller input set) this writes the output of
        ``fake_images_b`` (prefix ``fakeA_``) and of ``cycle_images_a``
        (prefix ``cycA_``) as JPEGs into ``imgs_dir``. At epoch 0 the raw
        inputs are saved too (``inputA_`` / ``inputB_``). An HTML page named
        ``<basename of imgs_dir>_epoch_<epoch>.html`` linking the images is
        written to the run's output directory.

        :param sess: The session.
        :param epoch: Current epoch; used in the file names.
        :param inpt_a: Domain-A images, indexed along axis 0.
        :param inpt_b: Domain-B images, indexed along axis 0.
        :param imgs_dir: Directory receiving the JPEGs (created if missing).
        """
        if not os.path.exists(imgs_dir):
            os.makedirs(imgs_dir)

        # Never slice past the end of either input set.
        n_imgs = min(self._num_imgs_to_save, len(inpt_a), len(inpt_b))

        # The HTML file lives in self._output_dir, so links must point at
        # imgs_dir relative to it (forward slashes for the browser).
        rel_dir = os.path.relpath(imgs_dir, self._output_dir).replace(
            os.sep, '/')
        # One page per image directory so that the train and validation
        # pages of the same epoch do not overwrite each other.
        html_name = '{}_epoch_{}.html'.format(
            os.path.basename(os.path.normpath(imgs_dir)), epoch)

        def save_and_link(v_html, prefix, i, image):
            # Write one image to imgs_dir and reference it from the page.
            image_name = prefix + str(epoch) + "_" + str(i) + ".jpg"
            img = output_to_img(sess, image, IMG_CHANNELS)
            imsave(os.path.join(imgs_dir, image_name), img)
            v_html.write("<img src=\"" + rel_dir + "/" + image_name + "\">")

        with open(os.path.join(self._output_dir, html_name), 'w') as v_html:
            for i in range(n_imgs):
                print("Saving image {}/{}".format(i, n_imgs))

                batch_a = inpt_a[i:i + 1]
                batch_b = inpt_b[i:i + 1]

                fake_B_temp, cyc_A_temp = sess.run(
                    [self.fake_images_b, self.cycle_images_a],
                    feed_dict={
                        self.input_a: batch_a,
                        self.input_b: batch_b,
                    })

                save_and_link(v_html, 'fakeA_', i, fake_B_temp[0])
                save_and_link(v_html, 'cycA_', i, cyc_A_temp[0])

                if epoch == 0:
                    save_and_link(v_html, 'inputA_', i, batch_a[0])
                    save_and_link(v_html, 'inputB_', i, batch_b[0])

                v_html.write("<br>")

    def fake_image_pool(self, num_fakes, fake, fake_pool):
        """Return the batch to feed a discriminator, updating the pool.

        While the pool is filling (``num_fakes < pool size``) the new batch
        is stored at slot ``num_fakes`` and returned unchanged. Once full,
        with probability 0.5 a randomly chosen stored batch is returned and
        replaced by the new one; otherwise the new batch is returned and the
        pool is left untouched.

        :param num_fakes: Number of batches generated so far.
        :param fake: Newly generated batch.
        :param fake_pool: Array of stored batches; modified in place.
        :return: The batch to show the discriminator.
        """
        if num_fakes < self._pool_size:
            fake_pool[num_fakes] = fake
            return fake

        if random.random() > 0.5:
            random_id = random.randint(0, self._pool_size - 1)
            # Copy: fake_pool[random_id] is a view, and the assignment below
            # would otherwise overwrite the old batch we are about to return.
            temp = fake_pool[random_id].copy()
            fake_pool[random_id] = fake
            return temp
        return fake

    def input_read(self):
        """Load all data splits via ``Data_loader`` and size the epoch.

        Sets ``a_train/a_val/a_test`` and ``b_train/b_val/b_test``, plus
        ``limitant_n`` (size of the smaller training set) and ``n_batches``
        (number of full batches that fit in ``limitant_n``).
        """
        loader = Data_loader()
        loader.build_data()

        self.a_train = loader.a_train
        self.a_val = loader.a_val
        self.a_test = loader.a_test

        self.b_train = loader.b_train
        self.b_val = loader.b_val
        self.b_test = loader.b_test

        self.limitant_n = min(self.a_train.shape[0], self.b_train.shape[0])
        self.n_batches = self.limitant_n // self.batch_size

        print('WARNING: training images are limited to size: ', self.n_batches*self.batch_size)

        print(self.a_train.shape)

    def save_val_loss(self, sess, writer, epoch, i):
        """Log generator and discriminator losses on one validation pair.

        The pair is chosen cyclically from the validation sets using ``i``;
        summaries are written at the same step as the training summaries
        for batch ``i`` of ``epoch``. Does nothing if a validation set is
        empty.

        :return: ``writer`` (kept for backward compatibility).
        """
        lim_val = min(self.a_val.shape[0], self.b_val.shape[0])
        if lim_val == 0:
            return writer

        index = i % lim_val
        val_feed = {
            self.input_a: self.a_val[index:index + 1],
            self.input_b: self.b_val[index:index + 1],
        }
        step = epoch * self.n_batches + i

        # G_B loss; also produces the fake A used for the D_A loss.
        fake_A_temp, summary_str = sess.run(
            [self.fake_images_a, self.g_B_loss_summ_val], feed_dict=val_feed)
        writer.add_summary(summary_str, step)

        d_a_feed = dict(val_feed)
        d_a_feed[self.fake_pool_A] = fake_A_temp
        summary_str = sess.run(self.d_A_loss_summ_val, feed_dict=d_a_feed)
        writer.add_summary(summary_str, step)

        # G_A loss; also produces the fake B used for the D_B loss.
        fake_B_temp, summary_str = sess.run(
            [self.fake_images_b, self.g_A_loss_summ_val], feed_dict=val_feed)
        writer.add_summary(summary_str, step)

        d_b_feed = dict(val_feed)
        d_b_feed[self.fake_pool_B] = fake_B_temp
        summary_str = sess.run(self.d_B_loss_summ_val, feed_dict=d_b_feed)
        writer.add_summary(summary_str, step)

        return writer

    def _learning_rate(self, epoch):
        """Return the learning rate for ``epoch``.

        Constant ``self._base_lr`` for epochs 0-99, then linear decay that
        reaches 0 at epoch 200; clamped at 0 afterwards so the rate never
        turns negative when ``MAX_STEP`` exceeds 200.
        """
        if epoch < 100:
            return self._base_lr
        return max(0.0, self._base_lr - self._base_lr * (epoch - 100) / 100)

    def _restore_latest(self, sess, saver):
        """Restore ``sess`` from the newest checkpoint in the checkpoint dir.

        :raises ValueError: If ``self._checkpoint_dir`` holds no checkpoint.
        """
        chkpt_fname = tf.train.latest_checkpoint(self._checkpoint_dir)
        if chkpt_fname is None:
            raise ValueError(
                "No checkpoint found in {!r}".format(self._checkpoint_dir))
        saver.restore(sess, chkpt_fname)

    def _generator_step(self, sess, writer, step, trainer, fake_images,
                        loss_summ, batch_a, batch_b, curr_lr, update):
        """Generate fakes and log the generator loss; train if ``update``.

        :return: The generated images (numpy array).
        """
        fetches = [fake_images, loss_summ]
        if update:
            fetches.append(trainer)
        results = sess.run(fetches, feed_dict={
            self.input_a: batch_a,
            self.input_b: batch_b,
            self.learning_rate: curr_lr,
        })
        writer.add_summary(results[1], step)
        return results[0]

    def _discriminator_step(self, sess, writer, step, trainer, loss_summ,
                            pool_placeholder, pooled_fakes, batch_a, batch_b,
                            curr_lr):
        """Run one discriminator update on pooled fakes and log its loss."""
        _, summary_str = sess.run([trainer, loss_summ], feed_dict={
            self.input_a: batch_a,
            self.input_b: batch_b,
            self.learning_rate: curr_lr,
            pool_placeholder: pooled_fakes,
        })
        writer.add_summary(summary_str, step)

    def _train_epoch(self, sess, saver, writer, epoch):
        """Run one epoch: checkpoint, sample images, then all batch updates.

        Per batch the order is G_A, D_B, G_B, D_A. Generator weights are
        only updated when ``epoch % generator_frequency == 0``; on other
        epochs the generators still produce fakes and log their losses.
        Validation losses are logged on every 10th epoch.
        """
        print("In the epoch ", epoch)
        saver.save(sess, os.path.join(self._output_dir, "cyclegan"),
                   global_step=epoch)

        # Sample images are written every epoch.
        self.save_images(sess, epoch, self.a_train, self.b_train,
                         self._images_train)
        self.save_images(sess, epoch, self.a_val, self.b_val,
                         self._images_val)

        curr_lr = self._learning_rate(epoch)
        update_generators = epoch % self.generator_frequency == 0
        log_validation = epoch % 10 == 0

        for i in range(self.n_batches):
            print("Processing batch {}/{}".format(i, self.n_batches))
            step = epoch * self.n_batches + i
            start, stop = i * self.batch_size, (i + 1) * self.batch_size
            batch_a = self.a_train[start:stop]
            batch_b = self.b_train[start:stop]

            # G_A produces fake B images; D_B is trained on a pooled batch.
            fake_B_temp = self._generator_step(
                sess, writer, step, self.g_A_trainer, self.fake_images_b,
                self.g_A_loss_summ, batch_a, batch_b, curr_lr,
                update_generators)
            fake_B_temp1 = self.fake_image_pool(
                self.num_fake_inputs, fake_B_temp, self.fake_images_B)
            self._discriminator_step(
                sess, writer, step, self.d_B_trainer, self.d_B_loss_summ,
                self.fake_pool_B, fake_B_temp1, batch_a, batch_b, curr_lr)

            # G_B produces fake A images; D_A is trained on a pooled batch.
            fake_A_temp = self._generator_step(
                sess, writer, step, self.g_B_trainer, self.fake_images_a,
                self.g_B_loss_summ, batch_a, batch_b, curr_lr,
                update_generators)
            fake_A_temp1 = self.fake_image_pool(
                self.num_fake_inputs, fake_A_temp, self.fake_images_A)
            self._discriminator_step(
                sess, writer, step, self.d_A_trainer, self.d_A_loss_summ,
                self.fake_pool_A, fake_A_temp1, batch_a, batch_b, curr_lr)

            if log_validation:
                self.save_val_loss(sess, writer, epoch, i)

            writer.flush()
            self.num_fake_inputs += 1

    def train(self):
        """Build the graph, load the data and train the CycleGAN.

        Training runs from the epoch stored in ``global_step`` up to
        ``self._max_step``; ``global_step`` is advanced after every epoch so
        a restored run resumes where it stopped.

        :raises ValueError: If there is not a single full training batch, or
            if restoring and no checkpoint is found.
        """
        tf.set_random_seed(1)

        # Build the network and the losses
        self.model_setup()
        self.compute_losses()

        # Load data
        self.input_read()
        if self.n_batches == 0:
            raise ValueError(
                "Not enough training images ({}) for batch size {}".format(
                    self.limitant_n, self.batch_size))

        # Build the epoch-counter update once instead of adding a new
        # tf.assign node to the graph on every epoch.
        epoch_ph = tf.placeholder(self.global_step.dtype.base_dtype,
                                  shape=[], name="epoch")
        set_global_step = tf.assign(self.global_step, epoch_ph)

        init = (tf.global_variables_initializer(),
                tf.local_variables_initializer())
        saver = tf.train.Saver()

        with tf.Session(config=config) as sess:
            sess.run(init)

            if not os.path.exists(self._output_dir):
                os.makedirs(self._output_dir)

            # Passing the graph here records it even if training is
            # interrupted.
            writer = tf.summary.FileWriter(self._output_dir, sess.graph)
            try:
                if self._to_restore:
                    self._restore_latest(sess, saver)

                for epoch in range(sess.run(self.global_step),
                                   self._max_step):
                    self._train_epoch(sess, saver, writer, epoch)
                    sess.run(set_global_step, feed_dict={epoch_ph: epoch + 1})
            finally:
                writer.close()

    def test(self):
        """Restore the latest checkpoint and save images for the test split.

        :raises ValueError: If no checkpoint is found.
        """
        print("Testing the results")

        self.model_setup()
        saver = tf.train.Saver()
        init = tf.global_variables_initializer()

        with tf.Session(config=config) as sess:
            sess.run(init)

            # Load all data splits (sets self.a_test / self.b_test).
            self.input_read()

            self._restore_latest(sess, saver)

            # Save every available test pair.
            self._num_imgs_to_save = min(self.a_test.shape[0],
                                         self.b_test.shape[0])
            self.save_images(sess, 0, self.a_test, self.b_test,
                             self._images_test)


            
def main(to_train, log_dir, dataset_name, checkpoint_dir, single_input_dir):
    """Create a CycleGAN and either train it or export test images.

    :param to_train: 1: train from scratch; 2: resume training from the
        latest checkpoint in ``checkpoint_dir``; 0 (or any value <= 0): test.
    :param log_dir: Root output directory (created if missing). The model
        writes checkpoints and images into a sub-directory of it named after
        the current timestamp.
    :param dataset_name: Dataset name passed through to ``CycleGAN``.
    :param checkpoint_dir: Directory holding the checkpoint to restore. Used
        when ``to_train == 2`` and when testing.
    :param single_input_dir: Passed through to ``CycleGAN``.
    """
    if not os.path.isdir(log_dir):
        os.makedirs(log_dir)

    to_restore = (to_train == 2)
    cyclegan_model = CycleGAN(to_restore, log_dir, dataset_name, checkpoint_dir, single_input_dir)

    if to_train > 0:
        cyclegan_model.train()
    else:
        cyclegan_model.test()

In [4]:
# Run mode for main(): 1 = train, 2 = resume from checkpoint, 0 = test.
to_train = 1

dataset_name = 'hugxgan'
single_input_dir = ""
output_root_dir = "./output"

log_dir = ("./output/cyclegan_from_fluorospot_to_digital_diagnost/debug_224_size/"
           "base_to_gen_freq=" + str(GEN_FREQ) + "_retrain_seed_0/exp_01")
checkpoint_dir = ("./output/cyclegan_from_fluorospot_to_digital_diagnost/"
                  "debug_224_size/base_to_gen_freq=1/exp_01/20180321-161135")

main(to_train, log_dir, dataset_name, checkpoint_dir, single_input_dir)

Model/d_A/c1/Conv/weights:0
Model/d_A/c1/Conv/biases:0
Model/d_A/c2/Conv/weights:0
Model/d_A/c2/Conv/biases:0
Model/d_A/c2/instance_norm/scale:0
Model/d_A/c2/instance_norm/offset:0
Model/d_A/c3/Conv/weights:0
Model/d_A/c3/Conv/biases:0
Model/d_A/c3/instance_norm/scale:0
Model/d_A/c3/instance_norm/offset:0
Model/d_A/c4/Conv/weights:0
Model/d_A/c4/Conv/biases:0
Model/d_A/c4/instance_norm/scale:0
Model/d_A/c4/instance_norm/offset:0
Model/d_A/c5/Conv/weights:0
Model/d_A/c5/Conv/biases:0
Model/d_B/c1/Conv/weights:0
Model/d_B/c1/Conv/biases:0
Model/d_B/c2/Conv/weights:0
Model/d_B/c2/Conv/biases:0
Model/d_B/c2/instance_norm/scale:0
Model/d_B/c2/instance_norm/offset:0
Model/d_B/c3/Conv/weights:0
Model/d_B/c3/Conv/biases:0
Model/d_B/c3/instance_norm/scale:0
Model/d_B/c3/instance_norm/offset:0
Model/d_B/c4/Conv/weights:0
Model/d_B/c4/Conv/biases:0
Model/d_B/c4/instance_norm/scale:0
Model/d_B/c4/instance_norm/offset:0
Model/d_B/c5/Conv/weights:0
Model/d_B/c5/Conv/biases:0
Model/g_A/c1/Conv/weight

KeyboardInterrupt: 