In [1]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import pandas as pd
import os.path
import os
from data_loader import Dataset
from utils import local_clock, preprocess, postprocess, show_images, save_image
from evaluate import evaluate_model
%matplotlib inline

  from ._conv import register_converters as _register_converters


In [2]:
# Generator
def generator(x, is_training, output_channels=4, filters = [128,128,128,256,512,512,512,512], kernel_size = 4, stride = 2):
    """U-Net style generator: conv encoder/decoder with skip connections (gray image -> color image).

    Inputs:
    - x: input image batch; the graph cells feed a [batch, H, W, 1] gray placeholder (NHWC).
    - is_training: bool tensor/flag forwarded to every batch-normalization layer.
    - output_channels: channels of the returned image (default 4, matching the color placeholder).
    - filters: channel count per encoder level. filters[0] is a 1x1, stride-1 stem at full
      resolution; every later entry adds one strided conv in the encoder and, mirrored, one
      transposed conv in the decoder.
    - kernel_size, stride: shared by all down/up-sampling convolutions ('same' padding).

    Returns:
    - Tensor with output_channels channels at the input resolution, squashed to [-1, 1] by tanh.

    H and W must be divisible by stride ** (len(filters) - 1) (128 with the defaults); otherwise the
    transposed-conv outputs and the stored encoder activations differ in size and tf.concat fails.
    """
    def _block(inputs, layer_fn, n_filters, k, s, activation):
        # conv (or transposed conv) -> batch norm -> activation, Xavier-initialized kernels
        out = layer_fn(inputs=inputs,
                       filters=n_filters,
                       kernel_size=k,
                       strides=s,
                       padding='same',
                       kernel_initializer=tf.contrib.layers.xavier_initializer())
        out = tf.layers.batch_normalization(out, training=is_training)
        return activation(out)

    with tf.variable_scope('generator'):
        # Encoder: keep every level's activation so the decoder can concatenate it (skip connection).
        skips = [_block(x, tf.layers.conv2d, filters[0], 1, 1, tf.nn.leaky_relu)]
        for n_filters in filters[1:]:
            skips.append(_block(skips[-1], tf.layers.conv2d, n_filters, kernel_size, stride, tf.nn.leaky_relu))

        # Decoder: walk back up the levels; after upsampling, concat the encoder output of the same
        # resolution along the channel axis (axis 3, NHWC).
        net = skips[-1]
        for depth in range(len(filters) - 2, -1, -1):
            net = _block(net, tf.layers.conv2d_transpose, filters[depth], kernel_size, stride, tf.nn.relu)
            net = tf.concat([skips[depth], net], axis=3)

        # 1x1 projection to the requested number of output channels, tanh-bounded.
        return tf.layers.conv2d(inputs=net,
                                filters=output_channels,
                                kernel_size=1,
                                strides=1,
                                padding='same',
                                activation=tf.nn.tanh,
                                kernel_initializer=tf.contrib.layers.xavier_initializer())

In [3]:
# Discriminator
def discriminator(x, is_training, filters = [64,128,256,512] , kernel_size = 4, stride = 2): # conditional GAN
    """Convolutional discriminator; returns one unnormalized real/fake logit per example.

    Inputs:
    - x: input batch. In this notebook it is the channel-wise concatenation of the gray image with
      either the real or the generated color image (a conditional GAN).
    - is_training: bool tensor/flag forwarded to batch normalization.
    - filters: output channels of each conv layer, one strided conv per entry.
    - kernel_size: kernel size shared by every conv layer (int or pair).
    - stride: stride shared by every conv layer ('same' padding), so each layer downsamples H and W.

    Architecture: each conv is followed by leaky ReLU; every conv except the first also gets batch
    normalization before the activation. The result is flattened and passed through a dense layer
    with a single unit and no activation.

    Returns:
    - logit: Tensor of shape [batch, 1]. No sigmoid is applied here, so the loss must use a
      with-logits formulation.
    """
    with tf.variable_scope("discriminator"):
        net = x
        for depth, n_filters in enumerate(filters):
            net = tf.layers.conv2d(inputs=net,
                                   filters=n_filters,
                                   kernel_size=kernel_size,
                                   strides=stride,
                                   padding='same',
                                   kernel_initializer=tf.contrib.layers.xavier_initializer())
            # The first layer is deliberately left without batch normalization.
            if depth > 0:
                net = tf.layers.batch_normalization(net, training=is_training)
            net = tf.nn.leaky_relu(net)
        features = tf.contrib.layers.flatten(net)
        return tf.layers.dense(inputs=features, units=1, kernel_initializer=tf.contrib.layers.xavier_initializer())

In [4]:
def gan_loss(logits_real, logits_fake):
    """Compute the standard (non-saturating) GAN losses from discriminator logits.

    Inputs:
    - logits_real: Tensor, shape [batch_size, 1]; discriminator logits for real pairs.
    - logits_fake: Tensor, shape [batch_size, 1]; discriminator logits for generated pairs.

    Returns:
    - D_loss: scalar; mean BCE(real -> 1) + mean BCE(fake -> 0). The two terms are averaged
      separately and then summed, rather than averaged once over the concatenation.
    - G_loss: scalar; mean BCE(fake -> 1).
    """
    def mean_xent(logits, target_like):
        # Sigmoid cross-entropy against a constant target shaped like `logits`, averaged over all entries.
        return tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(labels=target_like(logits), logits=logits))

    # The generator is rewarded when the discriminator labels its samples as real.
    generator_loss = mean_xent(logits_fake, tf.ones_like)
    # The discriminator should output 1 for real pairs and 0 for generated ones.
    discriminator_loss = mean_xent(logits_real, tf.ones_like) + mean_xent(logits_fake, tf.zeros_like)
    return discriminator_loss, generator_loss

In [5]:
def l1_loss(fake_imgs, real_imgs, reg=127.5):
    """Weighted mean absolute error between generated and real images.

    Inputs:
    - fake_imgs: Tensor [batch_size, H, W, C], output of the generator.
    - real_imgs: Tensor [batch_size, H, W, C], the fed ground-truth images.
    - reg: multiplier applied to the mean absolute error. NOTE(review): the 127.5 default
      presumably rescales errors on [-1, 1]-normalized images back to 0-255 pixel units;
      confirm against utils.preprocess.

    Returns:
    - scalar tensor: reg * mean(|fake - real|) over every element of the batch.
    """
    # Both batches are flattened per example before differencing, so their per-example sizes must agree.
    diff = tf.contrib.layers.flatten(fake_imgs) - tf.contrib.layers.flatten(real_imgs)
    mean_abs_error = tf.reduce_mean(tf.abs(diff))
    return reg * mean_abs_error

In [6]:
def get_solvers(D_lr=2e-4, G_lr=2e-4, beta1=0.5):
    """Build one Adam optimizer per network.

    Inputs:
    - D_lr: learning rate for the discriminator optimizer.
    - G_lr: learning rate for the generator optimizer.
    - beta1: Adam first-moment decay, shared by both optimizers.

    Returns:
    - (D_solver, G_solver): tf.train.AdamOptimizer instances; beta2/epsilon keep TF defaults.
    """
    D_solver, G_solver = (tf.train.AdamOptimizer(learning_rate=rate, beta1=beta1)
                          for rate in (D_lr, G_lr))
    return D_solver, G_solver

In [7]:
# Construct computational graph
# Hyperparameters are declared next to the graph that consumes them and passed explicitly, so
# changing them here actually changes training. The hyperparameter cell further down re-declares
# lr / beta1 / reg to label its log files; keep both sets of values in sync.
lr = 2e-4
beta1 = 0.5
reg = 127.5

device = '/gpu:0'
tf.reset_default_graph() # reset the graph
with tf.device(device):
    is_training = tf.placeholder(tf.bool, name='is_training')
    gray_img = tf.placeholder(tf.float32, [None, 256, 256, 1])
    color_img = tf.placeholder(tf.float32, [None, 256, 256, 4])

    # Conditional GAN: the discriminator scores (gray, color) pairs stacked on the channel axis.
    pair_real = tf.concat([gray_img, color_img], axis=3)
    G_sample = generator(gray_img, is_training)
    pair_fake = tf.concat([gray_img, G_sample], axis=3)

    # The enclosing scope lets reuse_variables() be switched on, so the second discriminator
    # call shares the weights created by the first.
    with tf.variable_scope('') as scope:
        logits_real = discriminator(pair_real, is_training)
        scope.reuse_variables()
        logits_fake = discriminator(pair_fake, is_training)

    # Get the list of trainable variables for the discriminator and generator
    D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'discriminator')
    G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'generator')

    # Get solvers (explicit hyperparameters instead of get_solvers' defaults)
    D_solver, G_solver = get_solvers(D_lr=lr, G_lr=lr, beta1=beta1)

    # Compute the losses
    D_loss, G_loss = gan_loss(logits_real, logits_fake)
    img_loss = l1_loss(G_sample, color_img, reg=reg)

    # Set up the training operations. Batch-norm moving statistics are updated by the ops in
    # UPDATE_OPS, so each train op depends on its own network's update ops.
    D_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, 'discriminator')
    with tf.control_dependencies(D_update_ops):
        D_train_op = D_solver.minimize(D_loss, var_list=D_vars)

    G_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, 'generator')
    with tf.control_dependencies(G_update_ops):
        # Generator objective: fool the discriminator + stay close to the target image (L1).
        G_train_op = G_solver.minimize(G_loss + img_loss, var_list=G_vars)

Instructions for updating:
Use the retry module or similar alternatives.


In [8]:
# Get the data set. Positional Dataset arguments are (gray_dir, color_dir, batch_size, img_dim),
# in the same order train_gan passes them.
# NOTE(review): val_data reads the same directories as train_data, so the per-epoch "validation"
# metrics are computed on training images; point it at a held-out split before trusting them.
gray_examples_dir = 'data/gray_examples_256/'
color_examples_dir = 'data/color_examples_256/'
train_data = Dataset(gray_examples_dir, color_examples_dir, 16, 256, shuffle=True)
# Unshuffled copy so the example batch shown after each epoch stays the same.
train_example_data = Dataset(gray_examples_dir, color_examples_dir, 16, 256, shuffle=False)
val_data = Dataset(gray_examples_dir, color_examples_dir, 16, 256, shuffle=False)

In [9]:
# Hyperparameters, used below to label the output and log files.
# NOTE(review): the graph is built in an earlier cell, so these values must match the ones used
# there; changing them only here relabels the logs without changing training.
lr = 2e-4
beta1 = 0.5
reg = 127.5

# Master output directories (absolute path on the shared training machine)
output_dir = '/home/shared/chris/Histopathology-Imaging/gan_l1_0528/'
val_dir = output_dir + 'val_results/'
val_img_dir = val_dir + 'imgs/'
train_dir = output_dir + 'train_results/'
trained_sess_dir = output_dir + 'trained_sess/'
for new_dir in (val_dir, val_img_dir, train_dir, trained_sess_dir):
    if not os.path.exists(new_dir):
        os.makedirs(new_dir)

# Output file paths (hyperparameters are embedded in the file names)
train_log_file = train_dir + 'train_log_lr={}_beta1={}_reg={}.txt'.format(lr, beta1, reg)
train_img_file = train_dir + 'train_gen_examples_epoch_'
val_log_file = val_dir + 'val_log_lr={}_beta1={}_reg={}.txt'.format(lr, beta1, reg)
val_csv_file = val_dir + 'val_metrics_lr={}_beta1={}_reg={}'.format(lr, beta1, reg)

# Initialize (truncate) both log files with the same header line
start_msg = local_clock() + '  Started training model with learning rate={}, beta1={}, reg={}\n'.format(lr, beta1, reg)
for log_path in (train_log_file, val_log_file):
    with open(log_path, 'w') as handle:
        handle.write(start_msg)

In [10]:
# Training loop
num_epochs = 10
# Build the Saver once: every tf.train.Saver() adds save/restore ops to the graph, so creating one
# per epoch keeps growing it. max_to_keep=None keeps every epoch's checkpoint on disk, as the
# old per-epoch Savers did (each saved only once, so none ever pruned older files).
saver = tf.train.Saver(max_to_keep=None)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(num_epochs):
        print(local_clock() + '  Started epoch %d' % (epoch))
        num_batches = 0
        for gray_img_np, color_img_np in train_data:
            gray_processed_np = preprocess(gray_img_np)
            color_processed_np = preprocess(color_img_np)
            feed_dict = {gray_img: gray_processed_np, color_img: color_processed_np, is_training: True}
            # One discriminator step, then one generator step, on the same batch.
            _, D_loss_np = sess.run([D_train_op, D_loss], feed_dict=feed_dict)
            _, G_loss_np, img_loss_np = sess.run([G_train_op, G_loss, img_loss], feed_dict=feed_dict)
            num_batches += 1
        if num_batches == 0:
            # Otherwise the log line below fails with an unhelpful NameError on D_loss_np.
            raise RuntimeError('train_data yielded no batches; check the data directories')

        # Save the results to the train log file (losses are those of the epoch's last batch)
        epoch_train_time = local_clock() + '\n'
        epoch_train_msg = 'Epoch %d  D loss: %0.4f  G loss: %0.4f  img loss: %0.4f' % (epoch, D_loss_np, G_loss_np, img_loss_np)
        print(local_clock() + '  ' + epoch_train_msg)
        epoch_train_msg += '\n'
        with open(train_log_file, 'a') as handle:
            handle.write('\n')
            handle.write(epoch_train_time)
            handle.write(epoch_train_msg)

        # Save examples of generated images from the first (unshuffled) training batch
        gray_example_np, color_example_np = next(iter(train_example_data))
        gray_example_processed_np = preprocess(gray_example_np)
        color_example_processed_np = preprocess(color_example_np)
        example_feed_dict = {gray_img: gray_example_processed_np,
                             color_img: color_example_processed_np,
                             is_training: False}
        gen_example_np = sess.run(G_sample, feed_dict=example_feed_dict)
        gen_example_np = postprocess(gen_example_np)
        show_images(gen_example_np, post_process=False, save=True, filepath=train_img_file + str(epoch) + '.png')

        # Evaluate on the validation data set
        val_log_note = 'Epoch ' + str(epoch)
        epoch_val_img_dir = val_img_dir + 'epoch' + str(epoch) + '/'
        if not os.path.exists(epoch_val_img_dir):
            os.makedirs(epoch_val_img_dir)
        epoch_val_csv = val_csv_file + '_epoch' + str(epoch) + '.csv'
        evaluate_model(sess=sess,
                       graph_gray=gray_img,
                       graph_color=color_img,
                       graph_training=is_training,
                       graph_D_loss=D_loss,
                       graph_G_loss=G_loss,
                       graph_img_loss=img_loss,
                       graph_G_sample=G_sample,
                       dataset=val_data,
                       log_filename=val_log_file,
                       log_note=val_log_note,
                       csv_filename=epoch_val_csv,
                       output_imgs=True,
                       img_dir=epoch_val_img_dir)

        # Save the session when the epoch is done
        sess_name = trained_sess_dir + 'gan_epoch' + str(epoch)
        saver.save(sess, sess_name)

        print(local_clock() + '  Finished epoch %d' % (epoch))
        print('')

Mon May 28 22:07:44 2018  Started epoch 0
Mon May 28 22:07:51 2018  Epoch 0  D loss: 0.1589  G loss: 4.6346  img loss: 74.4301
Mon May 28 22:08:07 2018  Finished epoch 0

Mon May 28 22:08:07 2018  Started epoch 1
Mon May 28 22:08:10 2018  Epoch 1  D loss: 0.0974  G loss: 9.2042  img loss: 62.1175
Mon May 28 22:08:25 2018  Finished epoch 1

Mon May 28 22:08:25 2018  Started epoch 2
Mon May 28 22:08:28 2018  Epoch 2  D loss: 0.5330  G loss: 17.8175  img loss: 51.6276
Mon May 28 22:08:44 2018  Finished epoch 2

Mon May 28 22:08:44 2018  Started epoch 3
Mon May 28 22:08:46 2018  Epoch 3  D loss: 0.0134  G loss: 7.3397  img loss: 48.1348
Mon May 28 22:09:02 2018  Finished epoch 3

Mon May 28 22:09:02 2018  Started epoch 4
Mon May 28 22:09:05 2018  Epoch 4  D loss: 0.0735  G loss: 7.2856  img loss: 46.2591
Mon May 28 22:09:21 2018  Finished epoch 4

Mon May 28 22:09:21 2018  Started epoch 5
Mon May 28 22:09:24 2018  Epoch 5  D loss: 1.5592  G loss: 20.0963  img loss: 39.0119
Mon May 28 22:09

In [7]:
def _prepare_run_outputs(output_dir, D_lr, G_lr, beta1, reg):
    """Create the output directory tree for a train_gan run and start fresh log files.

    output_dir is extended by plain string concatenation, so it should end with '/'.
    Both log files are truncated and receive the same "Started training" header line.

    Returns (val_img_dir, trained_sess_dir, train_log_file, train_img_file, val_log_file, val_csv_file).
    """
    val_dir = output_dir + 'val_results/'
    val_img_dir = val_dir + 'imgs/'
    train_dir = output_dir + 'train_results/'
    trained_sess_dir = output_dir + 'trained_sess/'
    for directory in (val_dir, val_img_dir, train_dir, trained_sess_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    # Output file paths (hyperparameters are embedded in the file names)
    train_log_file = train_dir + 'train_log_Dlr={}_Glr={}_beta1={}_reg={}.txt'.format(D_lr, G_lr, beta1, reg)
    train_img_file = train_dir + 'train_gen_examples_epoch_'
    val_log_file = val_dir + 'val_log_Dlr={}_Glr={}_beta1={}_reg={}.txt'.format(D_lr, G_lr, beta1, reg)
    val_csv_file = val_dir + 'val_metrics_Dlr={}_Glr={}_beta1={}_reg={}'.format(D_lr, G_lr, beta1, reg)

    # Initialize the log files
    start_msg = local_clock() + '  Started training model with D_lr={}, G_lr={}, beta1={}, reg={}\n'.format(D_lr, G_lr, beta1, reg)
    for log_file in (train_log_file, val_log_file):
        with open(log_file, 'w') as handle:
            handle.write(start_msg)

    return val_img_dir, trained_sess_dir, train_log_file, train_img_file, val_log_file, val_csv_file


def _build_gan_graph(device, img_dim, D_lr, G_lr, beta1, reg):
    """Reset the default TF graph and build the conditional-GAN training graph on `device`.

    Returns a dict with the placeholders ('is_training', 'gray_img', 'color_img'), the generator
    output ('G_sample'), the losses ('D_loss', 'G_loss', 'img_loss') and the train ops
    ('D_train_op', 'G_train_op').
    """
    tf.reset_default_graph()
    with tf.device(device):
        is_training = tf.placeholder(tf.bool, name='is_training')
        gray_img = tf.placeholder(tf.float32, [None, img_dim, img_dim, 1])
        color_img = tf.placeholder(tf.float32, [None, img_dim, img_dim, 4])

        # Conditional GAN: the discriminator scores (gray, color) pairs stacked on the channel axis.
        pair_real = tf.concat([gray_img, color_img], axis=3)
        G_sample = generator(gray_img, is_training)
        pair_fake = tf.concat([gray_img, G_sample], axis=3)

        # The second discriminator call reuses the variables created by the first.
        with tf.variable_scope('') as scope:
            logits_real = discriminator(pair_real, is_training)
            scope.reuse_variables()
            logits_fake = discriminator(pair_fake, is_training)

        D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'discriminator')
        G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'generator')

        D_solver, G_solver = get_solvers(D_lr=D_lr, G_lr=G_lr, beta1=beta1)
        D_loss, G_loss = gan_loss(logits_real, logits_fake)
        img_loss = l1_loss(G_sample, color_img, reg=reg)

        # Batch-norm moving statistics are updated by the ops in UPDATE_OPS; tie them to each train op.
        D_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, 'discriminator')
        with tf.control_dependencies(D_update_ops):
            D_train_op = D_solver.minimize(D_loss, var_list=D_vars)

        G_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, 'generator')
        with tf.control_dependencies(G_update_ops):
            G_train_op = G_solver.minimize(G_loss + img_loss, var_list=G_vars)

    return {'is_training': is_training, 'gray_img': gray_img, 'color_img': color_img,
            'G_sample': G_sample, 'D_loss': D_loss, 'G_loss': G_loss, 'img_loss': img_loss,
            'D_train_op': D_train_op, 'G_train_op': G_train_op}


def train_gan(train_data_dir, val_data_dir, output_dir, D_lr, G_lr, beta1, reg, num_epochs, batch_size=16, eval_val=True, save_eval_img=True, device='/gpu:0', img_dim=256):
    """Train the conditional GAN and checkpoint it after every epoch.

    Inputs:
    - train_data_dir, val_data_dir: directories that contain 'gray/' and 'color/' sub-directories
      (joined by string concatenation, so they should end with '/').
    - output_dir: root for logs, example images, validation results and checkpoints (end with '/').
    - D_lr, G_lr, beta1: Adam settings for the discriminator / generator.
    - reg: weight of the L1 image loss added to the generator's adversarial loss.
    - num_epochs: number of passes over the training set.
    - batch_size, img_dim: forwarded to Dataset; img_dim also fixes the placeholder height/width.
    - eval_val: if True, run evaluate_model on the validation set after every epoch.
    - save_eval_img: forwarded to evaluate_model as output_imgs.
    - device: TensorFlow device string the graph is placed on.

    Returns None. Raises RuntimeError if the training set yields no batches.
    """
    (val_img_dir, trained_sess_dir, train_log_file, train_img_file,
     val_log_file, val_csv_file) = _prepare_run_outputs(output_dir, D_lr, G_lr, beta1, reg)

    # Get the data set
    train_gray_dir = train_data_dir + 'gray/'
    train_color_dir = train_data_dir + 'color/'
    val_gray_dir = val_data_dir + 'gray/'
    val_color_dir = val_data_dir + 'color/'
    train_data = Dataset(train_gray_dir, train_color_dir, batch_size, img_dim, shuffle=True)
    # Unshuffled copy of the training set; its first batch is rendered after every epoch.
    train_example_data = Dataset(train_gray_dir, train_color_dir, batch_size, img_dim, shuffle=False)
    val_data = Dataset(val_gray_dir, val_color_dir, batch_size, img_dim, shuffle=False)

    # Construct computational graph
    gan = _build_gan_graph(device, img_dim, D_lr, G_lr, beta1, reg)
    gray_img = gan['gray_img']
    color_img = gan['color_img']
    is_training = gan['is_training']

    # One Saver for the whole run: each tf.train.Saver() adds save/restore ops to the graph, so
    # building one per epoch keeps growing it. max_to_keep=None keeps every epoch's checkpoint,
    # matching the old per-epoch Savers, which never pruned earlier files.
    saver = tf.train.Saver(max_to_keep=None)

    # Training loop
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for epoch in range(num_epochs):
            print(local_clock() + '  Started epoch %d' % (epoch))
            num_batches = 0
            for gray_img_np, color_img_np in train_data:
                feed_dict = {gray_img: preprocess(gray_img_np),
                             color_img: preprocess(color_img_np),
                             is_training: True}
                # One discriminator step, then one generator step, on the same batch
                _, D_loss_np = sess.run([gan['D_train_op'], gan['D_loss']], feed_dict=feed_dict)
                _, G_loss_np, img_loss_np = sess.run(
                    [gan['G_train_op'], gan['G_loss'], gan['img_loss']], feed_dict=feed_dict)
                num_batches += 1
            if num_batches == 0:
                raise RuntimeError('No training batches found under ' + train_data_dir)

            # Save the results to the train log file (losses are those of the epoch's last batch)
            epoch_train_time = local_clock() + '\n'
            epoch_train_msg = 'Epoch %d  D loss: %0.4f  G loss: %0.4f  img loss: %0.4f' % (epoch, D_loss_np, G_loss_np, img_loss_np)
            print(local_clock() + '  ' + epoch_train_msg)
            epoch_train_msg += '\n'
            with open(train_log_file, 'a') as handle:
                handle.write('\n')
                handle.write(epoch_train_time)
                handle.write(epoch_train_msg)

            # Save examples of generated images (first batch of the unshuffled training set)
            gray_example_np, color_example_np = next(iter(train_example_data))
            example_feed_dict = {gray_img: preprocess(gray_example_np),
                                 color_img: preprocess(color_example_np),
                                 is_training: False}
            gen_example_np = sess.run(gan['G_sample'], feed_dict=example_feed_dict)
            gen_example_np = postprocess(gen_example_np)
            show_images(gen_example_np, post_process=False, save=True, filepath=train_img_file + str(epoch) + '.png')

            # If true, evaluate on the validation data set
            if eval_val:
                val_log_note = 'Epoch ' + str(epoch)
                epoch_val_img_dir = val_img_dir + 'epoch' + str(epoch) + '/'
                if not os.path.exists(epoch_val_img_dir):
                    os.makedirs(epoch_val_img_dir)
                epoch_val_csv = val_csv_file + '_epoch' + str(epoch) + '.csv'
                evaluate_model(sess=sess,
                               graph_gray=gray_img,
                               graph_color=color_img,
                               graph_training=is_training,
                               graph_D_loss=gan['D_loss'],
                               graph_G_loss=gan['G_loss'],
                               graph_img_loss=gan['img_loss'],
                               graph_G_sample=gan['G_sample'],
                               dataset=val_data,
                               log_filename=val_log_file,
                               log_note=val_log_note,
                               csv_filename=epoch_val_csv,
                               output_imgs=save_eval_img,
                               img_dir=epoch_val_img_dir)

            # Save the session when the epoch is done
            sess_name = trained_sess_dir + 'gan_epoch' + str(epoch)
            saver.save(sess, sess_name)

            print(local_clock() + '  Finished epoch %d' % (epoch))
            print('')

In [8]:
# Testing: short run on the 'testing/' directory.
# NOTE(review): the training and validation directories are identical, so the per-epoch
# "validation" metrics are measured on training images.
test_data_dir = '/home/shared/chris/Histopathology-Imaging/testing/'
run_output_dir = '/home/shared/chris/Histopathology-Imaging/gan_l1_0529/'
train_gan(train_data_dir=test_data_dir,
          val_data_dir=test_data_dir,
          output_dir=run_output_dir,
          D_lr=2e-4,
          G_lr=2e-4,
          beta1=0.5,
          reg=1000,
          num_epochs=15)

Instructions for updating:
Use the retry module or similar alternatives.
Tue May 29 17:07:49 2018  Started epoch 0
Tue May 29 17:07:58 2018  Epoch 0  D loss: 0.2878  G loss: 12.4007  img loss: 281.0776
Tue May 29 17:08:15 2018  Finished epoch 0

Tue May 29 17:08:15 2018  Started epoch 1
Tue May 29 17:08:19 2018  Epoch 1  D loss: 0.1007  G loss: 5.4090  img loss: 221.9584
Tue May 29 17:08:36 2018  Finished epoch 1

Tue May 29 17:08:36 2018  Started epoch 2
Tue May 29 17:08:40 2018  Epoch 2  D loss: 0.1460  G loss: 3.7714  img loss: 198.5429
Tue May 29 17:08:57 2018  Finished epoch 2

Tue May 29 17:08:57 2018  Started epoch 3
Tue May 29 17:09:01 2018  Epoch 3  D loss: 0.0007  G loss: 11.1351  img loss: 146.2451
Tue May 29 17:09:18 2018  Finished epoch 3

Tue May 29 17:09:18 2018  Started epoch 4
Tue May 29 17:09:22 2018  Epoch 4  D loss: 0.0018  G loss: 7.1131  img loss: 132.8643
Tue May 29 17:09:40 2018  Finished epoch 4

Tue May 29 17:09:40 2018  Started epoch 5
Tue May 29 17:09:44 201