In [1]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
%matplotlib inline
from data_loader import Dataset
from utils import local_clock, preprocess, postprocess

  from ._conv import register_converters as _register_converters


In [2]:
# Generator
def generator(x, is_training, output_channels=4, filters = [64,64,128,256,512,512,512,512], kernel_size = 4, stride = 2):
    with tf.variable_scope('generator'):
        layers = []
        # Encoder:
        x = tf.layers.conv2d(inputs = x,
                             filters = filters[0],
                             kernel_size = 1,
                             strides = 1,
                             padding = 'same',
                             kernel_initializer = tf.contrib.layers.xavier_initializer()) 
        x = tf.layers.batch_normalization(x, training=is_training)
        x =  tf.nn.leaky_relu(x)
        layers.append(x)
        for i in range(1, len(filters)):
            x = tf.layers.conv2d(inputs = x,
                                 filters = filters[i],
                                 kernel_size = kernel_size,
                                 strides = stride,
                                 padding = 'same',
                                 kernel_initializer = tf.contrib.layers.xavier_initializer())  
            x = tf.layers.batch_normalization(x, training=is_training)
            x =  tf.nn.leaky_relu(x)
        # save contracting path layers to be used for skip connections
            layers.append(x)
            
        
        # Decoder:
        for i in reversed(range(len(filters)-1)):
            x = tf.layers.conv2d_transpose(inputs = x,
                                           filters = filters[i],
                                           kernel_size = kernel_size,
                                           strides = stride,
                                           padding = 'same',
                                           kernel_initializer = tf.contrib.layers.xavier_initializer())
            x = tf.layers.batch_normalization(x, training=is_training)
            x =  tf.nn.relu(x)
        # concat the layer from the contracting path with the output of the current layer
        # concat only the channels (axis=3)
            x = tf.concat([layers[i], x], axis=3)
            # layers.append(x)
        x = tf.layers.conv2d(inputs = x,
                             filters = output_channels,
                             kernel_size = 1,
                             strides = 1,
                             padding = 'same',
                             activation = tf.nn.tanh,
                             kernel_initializer = tf.contrib.layers.xavier_initializer())   
        # layers.append(x)
        # return layers
        return x

In [3]:
# Discriminator
def discriminator(x, is_training, filters = [64,128,256,512] , kernel_size = 4, stride = 2): # conditional GAN
    """
    filters: Integer, the dimensionality of the output space (i.e. the number of filters in the convolution).
    kernel_size: An integer or tuple/list of 2 integers, specifying the height and width of the 2D convolution window. 
                 Can be a single integer to specify the same value for all spatial dimensions.
    strides: An integer or tuple/list of 2 integers, specifying the strides of the convolution along the height and width. 
             Can be a single integer to specify the same value for all spatial dimensions. 
             Specifying any stride value != 1 is incompatible with specifying any dilation_rate value != 1.
    
    filters: a series of 4x4 convolutional layers with stride 2 with the number of channels being doubled after each downsampling.
    All convolution layers are followed by batch normalization, leaky ReLU activation. 
    After the last layer, a convolution is applied to map to a 1 dimensional output, 
        followed by a sigmoid function to return a probability value of the input being real or fake
    """
    
    with tf.variable_scope("discriminator"): 
        # layers = []
        for i in range(len(filters)):
            x = tf.layers.conv2d(inputs = x,
                                 filters = filters[i],
                                 kernel_size = kernel_size,
                                 strides = stride,
                                 padding = 'same',
                                 kernel_initializer = tf.contrib.layers.xavier_initializer())           
            if i != 0: # Do not use batch-norm in the first layer
                x = tf.layers.batch_normalization(x, training=is_training)
            x =  tf.nn.leaky_relu(x)
            # layers.append(x)
        x = tf.contrib.layers.flatten(x)
        logit = tf.layers.dense(inputs = x, units=1, kernel_initializer = tf.contrib.layers.xavier_initializer())
        # layers.append(logit)
        # return layers
        return logit

In [4]:
def gan_loss(logits_real, logits_fake):
    """Compute the GAN loss.
    
    Inputs:
    - logits_real: Tensor, shape [batch_size, 1], output of discriminator
        Unnormalized score that the image is real for each real image
    - logits_fake: Tensor, shape[batch_size, 1], output of discriminator
        Unnormalized score that the image is real for each fake image
    
    Returns:
    - D_loss: discriminator loss scalar
    - G_loss: generator loss scalar
    
    HINT: for the discriminator loss, you'll want to do the averaging separately for
    its two components, and then add them together (instead of averaging once at the very end).
    """
    # TODO: compute D_loss and G_loss
    G_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(logits_fake), logits=logits_fake)
    G_loss = tf.reduce_mean(G_loss)
    
    D_real_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(logits_real), logits=logits_real)
    D_real_loss = tf.reduce_mean(D_real_loss)
    D_fake_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(logits_fake), logits=logits_fake)
    D_fake_loss = tf.reduce_mean(D_fake_loss)
    D_loss = D_real_loss + D_fake_loss
    return D_loss, G_loss

In [5]:
def get_solvers(learning_rate=1e-3, beta1=0.5):
    """Create solvers for GAN training.
    
    Inputs:
    - learning_rate: learning rate to use for both solvers
    - beta1: beta1 parameter for both solvers (first moment decay)
    
    Returns:
    - D_solver: instance of tf.train.AdamOptimizer with correct learning_rate and beta1
    - G_solver: instance of tf.train.AdamOptimizer with correct learning_rate and beta1
    """
    D_solver = tf.train.AdamOptimizer(learning_rate, beta1)
    G_solver = tf.train.AdamOptimizer(learning_rate, beta1)
    return D_solver, G_solver

In [6]:
# Construct computational graph
device = '/gpu:0'
tf.reset_default_graph() # reset the graph
with tf.device(device):
    is_training = tf.placeholder(tf.bool, name='is_training')
    gray_img = tf.placeholder(tf.float32, [None, 256, 256, 1])
    color_img = tf.placeholder(tf.float32, [None, 256, 256, 4])
    
    pair_real = tf.concat([gray_img, color_img], axis=3)
    G_sample = generator(gray_img, is_training)
    pair_fake = tf.concat([gray_img, G_sample], axis=3)

    with tf.variable_scope('') as scope:
        logits_real = discriminator(pair_real, is_training)
        scope.reuse_variables()
        logits_fake = discriminator(pair_fake, is_training)
    
    # Get the list of trainable variables for the discriminator and generator
    D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'discriminator')
    G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'generator')
    
    # Get solvers
    D_solver, G_solver = get_solvers()
    
    # Compute the losses
    D_loss, G_loss = gan_loss(logits_real, logits_fake)
    
    # Set up the training operations
    D_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, 'discriminator')
    with tf.control_dependencies(D_update_ops):
        D_train_op = D_solver.minimize(D_loss, var_list=D_vars)
    
    G_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, 'generator')
    with tf.control_dependencies(G_update_ops):
        G_train_op = G_solver.minimize(G_loss, var_list=G_vars)

Instructions for updating:
Use the retry module or similar alternatives.


In [7]:
# Get the data set
train_data = Dataset('data/gray_examples_256/', 'data/color_examples_256/', 16, 256)

In [8]:
# Training loop
num_epochs = 10
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(num_epochs):
        for t, (gray_img_np, color_img_np) in enumerate(train_data):
            gray_processed_np = preprocess(gray_img_np)
            color_processed_np = preprocess(color_img_np)
            _, D_loss_np = sess.run([D_train_op, D_loss], feed_dict={gray_img: gray_processed_np, 
                                                                     color_img: color_processed_np, 
                                                                     is_training: True})
            _, G_loss_np = sess.run([G_train_op, G_loss], feed_dict={gray_img: gray_processed_np, 
                                                                     color_img: color_processed_np, 
                                                                     is_training: True})
            print('Batch %d  D: %0.4f  G: %0.4f' % (t, D_loss_np, G_loss_np))
        print('Epoch %d  D: %0.4f  G: %0.4f' % (epoch, D_loss_np, G_loss_np))

Batch 0  D: 1.9758  G: 14.0985
Batch 1  D: 7.6710  G: 49.1197
Batch 2  D: 20.8722  G: 20.7676
Batch 3  D: 0.0000  G: 8.9930
Batch 4  D: 1.3719  G: 22.0266
Batch 5  D: 0.0100  G: 22.8914
Batch 6  D: 1.7281  G: 13.4394
Epoch 0  D: 1.7281  G: 13.4394
Batch 0  D: 0.1459  G: 15.8613
Batch 1  D: 1.5284  G: 30.8282
Batch 2  D: 0.0000  G: 26.0007
Batch 3  D: 0.0001  G: 28.1834
Batch 4  D: 0.0000  G: 21.7020
Batch 5  D: 0.6013  G: 51.2231
Batch 6  D: 15.5023  G: 5.1886
Epoch 1  D: 15.5023  G: 5.1886
Batch 0  D: 5.8843  G: 33.8772
Batch 1  D: 0.0000  G: 35.8727
Batch 2  D: 0.2627  G: 26.1515
Batch 3  D: 0.0000  G: 20.4975
Batch 4  D: 0.0000  G: 18.0373
Batch 5  D: 7.7809  G: 54.3373
Batch 6  D: 0.0000  G: 49.9602
Epoch 2  D: 0.0000  G: 49.9602
Batch 0  D: 0.1684  G: 39.0267
Batch 1  D: 0.0000  G: 44.3203
Batch 2  D: 0.0000  G: 33.0927
Batch 3  D: 0.0003  G: 17.4479
Batch 4  D: 0.1210  G: 36.5859
Batch 5  D: 0.0000  G: 39.8272
Batch 6  D: 0.0000  G: 31.4331
Epoch 3  D: 0.0000  G: 31.4331
Batch 0 

In [18]:
test_function = discriminator
device = '/gpu:0'
tf.reset_default_graph()
B, W, H, C = 16, 256, 256, 5
output_channels = 4
with tf.device(device):
    x = tf.zeros((B, H, W, C))
    layers = test_function(x)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    layers_list = sess.run([layers])
    layers_np_list = layers_list[0]
    for i in range(len(layers_np_list)):
        print(i)
        print(layers_np_list[i].shape)

0
(16, 128, 128, 64)
1
(16, 64, 64, 128)
2
(16, 32, 32, 256)
3
(16, 16, 16, 512)
4
(16, 1)


In [22]:
test_function = generator
device = '/gpu:0'
tf.reset_default_graph()
B, W, H, C = 16, 256, 256, 5
output_channels = 4
with tf.device(device):
    x = tf.zeros((B, H, W, C))
    output = test_function(x)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    output_np = sess.run(output)
    print(output_np.shape)

(16, 256, 256, 4)


In [60]:
from os import listdir
from shutil import copyfile
def copy_file(org_path_color, org_path_gray, dest_path_color, dest_path_gray, num_files = 100):
    files = [f for f in listdir(org_path_color)]
    for i in range(num_files):
        color_file = files[i]
        gray_file = 'gray_' + color_file
        print('Copying ' + color_file)
        copyfile(org_path_color + color_file, dest_path_color + color_file)
        print('Copying ' + gray_file)
        copyfile(org_path_gray + gray_file, dest_path_gray + gray_file)

In [61]:
copy_file('/home/shared/data/center_4/patches_256/color/', 
          '/home/shared/data/center_4/patches_256/gray/', 
          '/home/shared/chris/Histopathology-Imaging/data/color_examples_256/', 
          '/home/shared/chris/Histopathology-Imaging/data/gray_examples_256/')

Copying 080_4_1006_46464_27136.png
Copying gray_080_4_1006_46464_27136.png
Copying 088_3_3738_70656_29056.png
Copying gray_088_3_3738_70656_29056.png
Copying 090_1_1760_62592_47872.png
Copying gray_090_1_1760_62592_47872.png
Copying 093_1_3994_72064_30080.png
Copying gray_093_1_3994_72064_30080.png
Copying 086_3_3508_62336_48512.png
Copying gray_086_3_3508_62336_48512.png
Copying 097_0_3338_44288_26624.png
Copying gray_097_0_3338_44288_26624.png
Copying 099_0_1317_22144_44544.png
Copying gray_099_0_1317_22144_44544.png
Copying 086_3_4045_64640_49536.png
Copying gray_086_3_4045_64640_49536.png
Copying 095_1_4780_47232_43392.png
Copying gray_095_1_4780_47232_43392.png
Copying 088_3_1213_39168_49024.png
Copying gray_088_3_1213_39168_49024.png
Copying 091_4_929_33280_39424.png
Copying gray_091_4_929_33280_39424.png
Copying 086_3_1964_56448_44160.png
Copying gray_086_3_1964_56448_44160.png
Copying 086_3_1768_55680_53632.png
Copying gray_086_3_1768_55680_53632.png
Copying 080_4_198_37248_529