# Cycle GAN in TensorFlow


Version: https://github.com/leehomyc/cyclegan-1

Data: cxr8

### Import helpers and the data-loader class

In [1]:
from helpers import *
from data_loader import Data_loader

### CycleGAN main code and training

In [2]:
from tensorflow.core.protobuf import config_pb2

"""Code for training CycleGAN."""
from datetime import datetime
import json
import numpy as np
import os
import random
from scipy.misc import imsave
import tensorflow as tf

from losses import *
from model import *
from parameters import *

# Short alias for TF-Slim; the model code uses it for
# slim.get_or_create_global_step().
slim = tf.contrib.slim

In [None]:
class CycleGAN:
    """The CycleGAN module.

    Bundles graph construction (``model_setup``), loss and optimizer
    definition (``compute_losses``), data loading (``input_read``), the
    training loop (``train``), inference on the test split (``test``) and
    sample export (``save_images``).

    Hyper-parameters (``POOL_SIZE``, ``LAMBDA_A``, ``BASE_LR``, ...) are read
    from module-level constants that are expected to be in scope, presumably
    via ``from parameters import *``.
    """

    def __init__(self, to_restore, output_root_dir,
                 dataset_name, checkpoint_dir, single_input_dir):
        """
        Stores the run configuration and allocates the fake-image pools.

        :param to_restore: If true, ``train`` restores the latest checkpoint
         found in ``checkpoint_dir`` before training.
        :param output_root_dir: Root directory; outputs are written to a
         sub-directory named after the current timestamp.
        :param dataset_name: Stored on the instance; not otherwise used by
         this class.
        :param checkpoint_dir: Directory searched for the latest checkpoint.
        :param single_input_dir: Accepted for interface compatibility; not
         used by this class.
        """
        current_time = datetime.now().strftime("%Y%m%d-%H%M%S")

        self._pool_size = POOL_SIZE
        self._lambda_a = LAMBDA_A
        self._lambda_b = LAMBDA_B
        self._output_dir = os.path.join(output_root_dir, current_time)

        self._images_train = os.path.join(self._output_dir, 'imgs_train')
        self._images_val = os.path.join(self._output_dir, 'imgs_val')
        self._images_test = os.path.join(self._output_dir, 'imgs_test')

        self._num_imgs_to_save = SAVE_IMAGES
        self._to_restore = to_restore
        self._base_lr = BASE_LR
        self._max_step = MAX_STEP
        self._network_version = NET_VERSION
        self._dataset_name = dataset_name
        self._checkpoint_dir = checkpoint_dir
        self._do_flipping = DO_FLIPPING
        self._skip = SKIP

        self.batch_size = BATCH_SIZE

        # One pool slot holds one generated batch, so the slot shape mirrors
        # the image placeholders built in model_setup
        # (batch, width, height, channels). The previous hard-coded
        # (1, height, width, channels) slot only worked for batch size 1 and
        # square images.
        pool_shape = (self._pool_size, self.batch_size,
                      IMG_WIDTH, IMG_HEIGHT, IMG_CHANNELS)
        self.fake_images_A = np.zeros(pool_shape)
        self.fake_images_B = np.zeros(pool_shape)

    def model_setup(self):
        """
        This function sets up the model to train.

        self.input_a/self.input_b -> Placeholders for a batch of images from
        each domain.
        self.fake_pool_A/self.fake_pool_B -> Placeholders for images drawn
        from the fake-image pools (fed to the discriminators).
        self.logit_a -> Placeholder with 6 values per image, consumed by the
        classifier loss in compute_losses (presumably class targets for the
        A images -- TODO confirm).
        self.learning_rate -> Scalar learning-rate placeholder.
        self.fake_images_a/self.fake_images_b, self.cycle_images_a/
        self.cycle_images_b and the prob_* tensors -> Outputs returned by
        get_outputs for the configured network version.
        """
        image_shape = [BATCH_SIZE, IMG_WIDTH, IMG_HEIGHT, IMG_CHANNELS]
        pool_shape = [None, IMG_WIDTH, IMG_HEIGHT, IMG_CHANNELS]

        self.input_a = tf.placeholder(
            tf.float32, image_shape, name="input_A")
        self.input_b = tf.placeholder(
            tf.float32, image_shape, name="input_B")

        self.fake_pool_A = tf.placeholder(
            tf.float32, pool_shape, name="fake_pool_A")
        self.fake_pool_B = tf.placeholder(
            tf.float32, pool_shape, name="fake_pool_B")

        self.logit_a = tf.placeholder(
            tf.float32, [BATCH_SIZE, 6], name="logit_A")

        self.global_step = slim.get_or_create_global_step()

        # Number of batches pushed through the fake-image pools so far.
        self.num_fake_inputs = 0

        self.learning_rate = tf.placeholder(tf.float32, shape=[], name="lr")

        inputs = {
            'images_a': self.input_a,
            'images_b': self.input_b,
            'fake_pool_a': self.fake_pool_A,
            'fake_pool_b': self.fake_pool_B,
        }

        outputs = get_outputs(
            inputs, network=self._network_version, skip=self._skip)

        self.prob_real_a_is_real = outputs['prob_real_a_is_real']
        self.prob_real_b_is_real = outputs['prob_real_b_is_real']
        self.fake_images_a = outputs['fake_images_a']
        self.fake_images_b = outputs['fake_images_b']
        self.prob_fake_a_is_real = outputs['prob_fake_a_is_real']
        self.prob_fake_b_is_real = outputs['prob_fake_b_is_real']

        self.cycle_images_a = outputs['cycle_images_a']
        self.cycle_images_b = outputs['cycle_images_b']

        self.prob_fake_pool_a_is_real = outputs['prob_fake_pool_a_is_real']
        self.prob_fake_pool_b_is_real = outputs['prob_fake_pool_b_is_real']

        self.prob_b_from_class_i = outputs['prob_b_from_class_i']

    def compute_losses(self):
        """
        In this function we are defining the variables for loss calculations
        and training model. Must be called after model_setup.

        d_loss_A/d_loss_B -> loss for discriminator A/B
        g_loss_A/g_loss_B -> loss for generator A/B
        c_loss_B -> loss computed by origin_discrmn_loss from
        prob_b_from_class_i and the logit_a placeholder
        *_trainer -> Adam minimizers, one per loss, each restricted to the
        variables whose name contains the matching scope tag
        *_summ -> Scalar summaries for the losses above
        """
        cycle_consistency_loss_a = \
            self._lambda_a * cycle_consistency_loss(
                real_images=self.input_a, generated_images=self.cycle_images_a,
            )
        cycle_consistency_loss_b = \
            self._lambda_b * cycle_consistency_loss(
                real_images=self.input_b, generated_images=self.cycle_images_b,
            )

        lsgan_loss_a = lsgan_loss_generator(self.prob_fake_a_is_real)
        lsgan_loss_b = lsgan_loss_generator(self.prob_fake_b_is_real)

        # Both generators share the two cycle terms; each adds the
        # adversarial term of the domain it generates into.
        g_loss_A = \
            cycle_consistency_loss_a + cycle_consistency_loss_b + lsgan_loss_b
        g_loss_B = \
            cycle_consistency_loss_b + cycle_consistency_loss_a + lsgan_loss_a

        d_loss_A = lsgan_loss_discriminator(
            prob_real_is_real=self.prob_real_a_is_real,
            prob_fake_is_real=self.prob_fake_pool_a_is_real,
        )
        d_loss_B = lsgan_loss_discriminator(
            prob_real_is_real=self.prob_real_b_is_real,
            prob_fake_is_real=self.prob_fake_pool_b_is_real,
        )

        c_loss_B = origin_discrmn_loss(self.prob_b_from_class_i, self.logit_a)

        optimizer = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5)

        self.model_vars = tf.trainable_variables()

        d_A_vars = [var for var in self.model_vars if 'd_A' in var.name]
        g_A_vars = [var for var in self.model_vars if 'g_A' in var.name]
        d_B_vars = [var for var in self.model_vars if 'd_B' in var.name]
        g_B_vars = [var for var in self.model_vars if 'g_B' in var.name]
        c_B_vars = [var for var in self.model_vars if 'c_B' in var.name]

        self.d_A_trainer = optimizer.minimize(d_loss_A, var_list=d_A_vars)
        self.d_B_trainer = optimizer.minimize(d_loss_B, var_list=d_B_vars)
        self.g_A_trainer = optimizer.minimize(g_loss_A, var_list=g_A_vars)
        self.g_B_trainer = optimizer.minimize(g_loss_B, var_list=g_B_vars)
        self.c_B_trainer = optimizer.minimize(c_loss_B, var_list=c_B_vars)

        for var in self.model_vars:
            print(var.name)

        # Summary variables for tensorboard
        self.g_A_loss_summ = tf.summary.scalar("g_A_loss", g_loss_A)
        self.g_B_loss_summ = tf.summary.scalar("g_B_loss", g_loss_B)
        self.d_A_loss_summ = tf.summary.scalar("d_A_loss", d_loss_A)
        self.d_B_loss_summ = tf.summary.scalar("d_B_loss", d_loss_B)
        self.c_B_loss_summ = tf.summary.scalar("c_B_loss", c_loss_B)

    def save_images(self, sess, epoch, inpt_a, inpt_b, imgs_dir):
        """
        Saves generated (and, at epoch 0, input) images plus an HTML index.

        For each of the first ``SAVE_IMAGES`` batches, the first image of
        ``fake_images_b`` (file prefix ``fakeA_``) and of ``cycle_images_a``
        (prefix ``cycA_``) is written to ``imgs_dir``. The HTML page is
        written to the run's output directory, one page per epoch and per
        image directory, so that pages of different splits do not overwrite
        each other.

        :param sess: The session.
        :param epoch: Current epoch; used in file names.
        :param inpt_a: Images of domain A, sliced into batches along axis 0.
        :param inpt_b: Images of domain B, sliced into batches along axis 0.
        :param imgs_dir: Directory the image files are written to.
        """
        if not os.path.exists(imgs_dir):
            os.makedirs(imgs_dir)
        # test() reaches this method without train() having created the
        # output directory, so make sure it exists before opening the page.
        if not os.path.exists(self._output_dir):
            os.makedirs(self._output_dir)

        split_name = os.path.basename(os.path.normpath(imgs_dir))
        html_path = os.path.join(
            self._output_dir,
            'epoch_' + str(epoch) + '_' + split_name + '.html')
        # Image links must point at the directory the files are really
        # saved in, expressed relative to the page's own location.
        rel_imgs_dir = os.path.relpath(imgs_dir, self._output_dir)

        with open(html_path, 'w') as v_html:

            for i in range(0, self._num_imgs_to_save):

                start = i * self.batch_size
                stop = start + self.batch_size
                batch_a = inpt_a[start:stop]
                batch_b = inpt_b[start:stop]

                # The placeholders have a fixed batch dimension; stop once
                # either split cannot fill a complete batch.
                if (len(batch_a) < self.batch_size or
                        len(batch_b) < self.batch_size):
                    break

                print("Saving image {}/{}".format(i, self._num_imgs_to_save))

                fake_B_temp, cyc_A_temp = sess.run([
                    self.fake_images_b,
                    self.cycle_images_a,
                ], feed_dict={
                    self.input_a: batch_a,
                    self.input_b: batch_b
                })

                to_save = [('fakeA_', fake_B_temp), ('cycA_', cyc_A_temp)]
                if epoch == 0:
                    # Inputs do not change between epochs; save them once.
                    to_save += [('inputA_', batch_a), ('inputB_', batch_b)]

                for name, tensor in to_save:
                    image_name = name + str(epoch) + "_" + str(i) + ".jpg"

                    img = output_to_img(sess, tensor[0], IMG_CHANNELS)
                    imsave(os.path.join(imgs_dir, image_name), img)

                    v_html.write("<img src=\"" +
                                 os.path.join(rel_imgs_dir, image_name) +
                                 "\">")

                v_html.write("<br>")

    def fake_image_pool(self, num_fakes, fake, fake_pool):
        """
        This function saves the generated image to corresponding
        pool of images.

        It keeps on filling the pool till it is full. After that, with
        probability 0.5 it replaces a randomly chosen stored image with the
        new one and returns the image that was stored before; otherwise it
        returns the new image unchanged.

        :param num_fakes: Number of images pushed through the pool so far.
        :param fake: Newly generated image batch.
        :param fake_pool: Array of stored batches, modified in place.
        :return: The batch to feed to the discriminator.
        """
        if num_fakes < self._pool_size:
            fake_pool[num_fakes] = fake
            return fake

        if random.random() > 0.5:
            random_id = random.randint(0, self._pool_size - 1)
            # Indexing a NumPy array yields a view, so the stored image has
            # to be copied before its slot is overwritten; otherwise the
            # returned "old" image would silently become the new one.
            previous = np.array(fake_pool[random_id], copy=True)
            fake_pool[random_id] = fake
            return previous

        return fake

    def input_read(self):
        """
        Loads the train/val/test splits of both domains via Data_loader.

        Sets ``a_train``, ``logit_a_train``, ``a_val``, ``a_test``,
        ``b_train``, ``b_val``, ``b_test`` and derives ``n_batches`` from the
        smaller of the two training sets.
        """
        loader = Data_loader()
        loader.build_data()

        self.a_train = loader.a_train
        self.logit_a_train = loader.logit_a_train
        self.a_val = loader.a_val
        self.a_test = loader.a_test

        self.b_train = loader.b_train
        self.b_val = loader.b_val
        self.b_test = loader.b_test

        # Batches are drawn from A and B in lockstep, so the smaller
        # training set bounds the number of full batches per epoch.
        self.limitant_n = min((self.a_train.shape[0], self.b_train.shape[0]))
        self.n_batches = self.limitant_n // self.batch_size

        print('WARNING: training images are limited to size: ', self.n_batches*self.batch_size)

        print(self.a_train.shape)

    def _learning_rate_for_epoch(self, epoch):
        """Returns the base rate for 100 epochs, then decays linearly to 0."""
        if epoch < 100:
            return self._base_lr
        decayed = self._base_lr - self._base_lr * (epoch - 100) / 100
        # Past epoch 200 the linear formula turns negative, which would make
        # the optimizer ascend the loss; hold the rate at zero instead.
        return max(0.0, decayed)

    def train(self):
        """
        Training Function.

        Builds the graph and losses, loads the data and runs epochs from the
        stored global step up to ``MAX_STEP``. Every epoch a checkpoint is
        written to the output directory; every fifth epoch sample images of
        the train and val splits are saved. Within a batch the networks are
        updated in the order G_A, D_B, G_B, D_A, C_B.
        """
        # Build the network
        self.model_setup()

        # Loss function calculations
        self.compute_losses()

        # Load data
        self.input_read()

        # Initializing the global variables
        init = (tf.global_variables_initializer(),
                tf.local_variables_initializer())
        saver = tf.train.Saver()

        with tf.Session() as sess:

            sess.run(init)

            # Creates output folder
            if not os.path.exists(self._output_dir):
                os.makedirs(self._output_dir)

            writer = tf.summary.FileWriter(self._output_dir)

            # Restore the model to run the model from last checkpoint
            if self._to_restore:
                chkpt_fname = tf.train.latest_checkpoint(self._checkpoint_dir)
                saver.restore(sess, chkpt_fname)

            # EPOCH LOOP FOR TRAINING
            for epoch in range(sess.run(self.global_step), self._max_step):

                print("In the epoch ", epoch)
                saver.save(sess, os.path.join(
                    self._output_dir, "cyclegan"), global_step=epoch)

                # save images
                if epoch % 5 == 0:
                    self.save_images(sess, epoch, self.a_train, self.b_train, self._images_train)
                    self.save_images(sess, epoch, self.a_val, self.b_val, self._images_val)

                # Dealing with the learning rate as per the epoch number
                curr_lr = self._learning_rate_for_epoch(epoch)

                # RUNNING THE TRAINING ON IMAGES
                for i in range(0, self.n_batches):

                    print("Processing batch {}/{}".format(i, self.n_batches))

                    start = i * self.batch_size
                    stop = start + self.batch_size
                    batch_a = self.a_train[start:stop]
                    batch_b = self.b_train[start:stop]
                    batch_log_a = self.logit_a_train[start:stop]

                    summary_step = epoch * self.n_batches + i

                    # Optimizing the G_A network
                    _, fake_B_temp, summary_str = sess.run(
                        [self.g_A_trainer,
                         self.fake_images_b,
                         self.g_A_loss_summ],
                        feed_dict={
                            self.input_a: batch_a,
                            self.input_b: batch_b,
                            self.learning_rate: curr_lr
                        }
                    )
                    writer.add_summary(summary_str, summary_step)

                    fake_B_temp1 = self.fake_image_pool(
                        self.num_fake_inputs, fake_B_temp, self.fake_images_B)

                    # Optimizing the D_B network
                    _, summary_str = sess.run(
                        [self.d_B_trainer, self.d_B_loss_summ],
                        feed_dict={
                            self.input_a: batch_a,
                            self.input_b: batch_b,
                            self.learning_rate: curr_lr,
                            self.fake_pool_B: fake_B_temp1
                        }
                    )
                    writer.add_summary(summary_str, summary_step)

                    # Optimizing the G_B network
                    _, fake_A_temp, summary_str = sess.run(
                        [self.g_B_trainer,
                         self.fake_images_a,
                         self.g_B_loss_summ],
                        feed_dict={
                            self.input_a: batch_a,
                            self.input_b: batch_b,
                            self.learning_rate: curr_lr
                        }
                    )
                    writer.add_summary(summary_str, summary_step)

                    fake_A_temp1 = self.fake_image_pool(
                        self.num_fake_inputs, fake_A_temp, self.fake_images_A)

                    # Optimizing the D_A network
                    _, summary_str = sess.run(
                        [self.d_A_trainer, self.d_A_loss_summ],
                        feed_dict={
                            self.input_a: batch_a,
                            self.input_b: batch_b,
                            self.learning_rate: curr_lr,
                            self.fake_pool_A: fake_A_temp1
                        }
                    )
                    writer.add_summary(summary_str, summary_step)

                    # Optimizing the C_B network
                    _, summary_str = sess.run(
                        [self.c_B_trainer, self.c_B_loss_summ],
                        feed_dict={
                            self.input_a: batch_a,
                            self.input_b: batch_b,
                            self.logit_a: batch_log_a,
                            self.learning_rate: curr_lr,
                        }
                    )
                    writer.add_summary(summary_str, summary_step)

                    writer.flush()
                    self.num_fake_inputs += 1

                sess.run(tf.assign(self.global_step, epoch + 1))

            writer.add_graph(sess.graph)
            # Flush pending events and release the event file.
            writer.close()

    def test(self):
        """
        Test Function.

        Builds the graph, loads the data, restores the latest checkpoint
        from ``checkpoint_dir`` and saves sample images of the test split.
        """
        print("Testing the results")

        self.model_setup()
        saver = tf.train.Saver()
        init = tf.global_variables_initializer()

        with tf.Session() as sess:

            sess.run(init)

            # Loading data: input_read is the loader defined on this class
            # and is what populates self.a_test / self.b_test.
            self.input_read()

            chkpt_fname = tf.train.latest_checkpoint(self._checkpoint_dir)
            saver.restore(sess, chkpt_fname)

            self.save_images(sess, 0, self.a_test, self.b_test, self._images_test)

            
def main(to_train, log_dir, dataset_name, checkpoint_dir, single_input_dir):
    """
    Creates the log directory if needed, builds a CycleGAN and runs it.

    :param to_train: Specify whether it is training or testing. 1: training;
     2: resuming from latest checkpoint; 0: testing.
    :param log_dir: The root dir to save checkpoints and imgs. It is passed
     to CycleGAN as its output root directory.
    :param dataset_name: Dataset name, forwarded to CycleGAN.
    :param checkpoint_dir: The directory that holds the latest checkpoint,
     forwarded to CycleGAN.
    :param single_input_dir: Forwarded to CycleGAN unchanged.
    """
    log_dir_missing = not os.path.isdir(log_dir)
    if log_dir_missing:
        os.makedirs(log_dir)

    # Mode 2 means "resume": the model is asked to restore a checkpoint.
    cyclegan_model = CycleGAN(
        to_train == 2, log_dir, dataset_name, checkpoint_dir,
        single_input_dir)

    # Any positive mode trains; everything else runs the test path.
    run = cyclegan_model.train if to_train > 0 else cyclegan_model.test
    run()

In [None]:
# Run mode understood by main(): 1 = train, 2 = resume from the latest
# checkpoint, 0 = test.
to_train = 1

dataset_name = 'hugxgan'
checkpoint_dir = ""
single_input_dir = ""

# NOTE: output_root_dir is defined here but not passed to main(); the
# outputs go below log_dir.
output_root_dir = "./output"
log_dir = "./output/cyclegan_from_fluorospot_to_digital_diagnost/base_parameters_from_git/exp_01"

main(to_train, log_dir, dataset_name, checkpoint_dir, single_input_dir)

Model/d_A/c1/Conv/weights:0
Model/d_A/c1/Conv/biases:0
Model/d_A/c2/Conv/weights:0
Model/d_A/c2/Conv/biases:0
Model/d_A/c2/instance_norm/scale:0
Model/d_A/c2/instance_norm/offset:0
Model/d_A/c3/Conv/weights:0
Model/d_A/c3/Conv/biases:0
Model/d_A/c3/instance_norm/scale:0
Model/d_A/c3/instance_norm/offset:0
Model/d_A/c4/Conv/weights:0
Model/d_A/c4/Conv/biases:0
Model/d_A/c4/instance_norm/scale:0
Model/d_A/c4/instance_norm/offset:0
Model/d_A/c5/Conv/weights:0
Model/d_A/c5/Conv/biases:0
Model/d_B/c1/Conv/weights:0
Model/d_B/c1/Conv/biases:0
Model/d_B/c2/Conv/weights:0
Model/d_B/c2/Conv/biases:0
Model/d_B/c2/instance_norm/scale:0
Model/d_B/c2/instance_norm/offset:0
Model/d_B/c3/Conv/weights:0
Model/d_B/c3/Conv/biases:0
Model/d_B/c3/instance_norm/scale:0
Model/d_B/c3/instance_norm/offset:0
Model/d_B/c4/Conv/weights:0
Model/d_B/c4/Conv/biases:0
Model/d_B/c4/instance_norm/scale:0
Model/d_B/c4/instance_norm/offset:0
Model/d_B/c5/Conv/weights:0
Model/d_B/c5/Conv/biases:0
Model/g_A/c1/Conv/weight

Processing batch 9/2938
Processing batch 10/2938
Processing batch 11/2938
Processing batch 12/2938
Processing batch 13/2938
Processing batch 14/2938
Processing batch 15/2938
Processing batch 16/2938
Processing batch 17/2938
Processing batch 18/2938
Processing batch 19/2938
Processing batch 20/2938
Processing batch 21/2938
Processing batch 22/2938
Processing batch 23/2938
Processing batch 24/2938
Processing batch 25/2938
Processing batch 26/2938
Processing batch 27/2938
Processing batch 28/2938
Processing batch 29/2938
Processing batch 30/2938
Processing batch 31/2938
Processing batch 32/2938
Processing batch 33/2938
Processing batch 34/2938
Processing batch 35/2938
Processing batch 36/2938
Processing batch 37/2938
Processing batch 38/2938
Processing batch 39/2938
Processing batch 40/2938
Processing batch 41/2938
Processing batch 42/2938
Processing batch 43/2938
Processing batch 44/2938
Processing batch 45/2938
Processing batch 46/2938
Processing batch 47/2938
Processing batch 48/2938
P

Processing batch 328/2938
Processing batch 329/2938
Processing batch 330/2938
Processing batch 331/2938
Processing batch 332/2938
Processing batch 333/2938
Processing batch 334/2938
Processing batch 335/2938
Processing batch 336/2938
Processing batch 337/2938
Processing batch 338/2938
Processing batch 339/2938
Processing batch 340/2938
Processing batch 341/2938
Processing batch 342/2938
Processing batch 343/2938
Processing batch 344/2938
Processing batch 345/2938
Processing batch 346/2938
Processing batch 347/2938
Processing batch 348/2938
Processing batch 349/2938
Processing batch 350/2938
Processing batch 351/2938
Processing batch 352/2938
Processing batch 353/2938
Processing batch 354/2938
Processing batch 355/2938
Processing batch 356/2938
Processing batch 357/2938
Processing batch 358/2938
Processing batch 359/2938
Processing batch 360/2938
Processing batch 361/2938
Processing batch 362/2938
Processing batch 363/2938
Processing batch 364/2938
Processing batch 365/2938
Processing b

Processing batch 644/2938
Processing batch 645/2938
Processing batch 646/2938
Processing batch 647/2938
Processing batch 648/2938
Processing batch 649/2938
Processing batch 650/2938
Processing batch 651/2938
Processing batch 652/2938
Processing batch 653/2938
Processing batch 654/2938
Processing batch 655/2938
Processing batch 656/2938
Processing batch 657/2938
Processing batch 658/2938
Processing batch 659/2938
Processing batch 660/2938
Processing batch 661/2938
Processing batch 662/2938
Processing batch 663/2938
Processing batch 664/2938
Processing batch 665/2938
Processing batch 666/2938
Processing batch 667/2938
Processing batch 668/2938
Processing batch 669/2938
Processing batch 670/2938
Processing batch 671/2938
Processing batch 672/2938
Processing batch 673/2938
Processing batch 674/2938
Processing batch 675/2938
Processing batch 676/2938
Processing batch 677/2938
Processing batch 678/2938
Processing batch 679/2938
Processing batch 680/2938
Processing batch 681/2938
Processing b

Processing batch 960/2938
Processing batch 961/2938
Processing batch 962/2938
Processing batch 963/2938
Processing batch 964/2938
Processing batch 965/2938
Processing batch 966/2938
Processing batch 967/2938
Processing batch 968/2938
Processing batch 969/2938
Processing batch 970/2938
Processing batch 971/2938
Processing batch 972/2938
Processing batch 973/2938
Processing batch 974/2938
Processing batch 975/2938
Processing batch 976/2938
Processing batch 977/2938
Processing batch 978/2938
Processing batch 979/2938
Processing batch 980/2938
Processing batch 981/2938
Processing batch 982/2938
Processing batch 983/2938
Processing batch 984/2938
Processing batch 985/2938
Processing batch 986/2938
Processing batch 987/2938
Processing batch 988/2938
Processing batch 989/2938
Processing batch 990/2938
Processing batch 991/2938
Processing batch 992/2938
Processing batch 993/2938
Processing batch 994/2938
Processing batch 995/2938
Processing batch 996/2938
Processing batch 997/2938
Processing b

Processing batch 1265/2938
Processing batch 1266/2938
Processing batch 1267/2938
Processing batch 1268/2938
Processing batch 1269/2938
Processing batch 1270/2938
Processing batch 1271/2938
Processing batch 1272/2938
Processing batch 1273/2938
Processing batch 1274/2938
Processing batch 1275/2938
Processing batch 1276/2938
Processing batch 1277/2938
Processing batch 1278/2938
Processing batch 1279/2938
Processing batch 1280/2938
Processing batch 1281/2938
Processing batch 1282/2938
Processing batch 1283/2938
Processing batch 1284/2938
Processing batch 1285/2938
Processing batch 1286/2938
Processing batch 1287/2938
Processing batch 1288/2938
Processing batch 1289/2938
Processing batch 1290/2938
Processing batch 1291/2938
Processing batch 1292/2938
Processing batch 1293/2938
Processing batch 1294/2938
Processing batch 1295/2938
Processing batch 1296/2938
Processing batch 1297/2938
Processing batch 1298/2938
Processing batch 1299/2938
Processing batch 1300/2938
Processing batch 1301/2938
P