In [1]:
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

In [2]:
import os

# Number CUDA devices in PCI bus order and expose only device 0 to TensorFlow.
# These must be in the environment before TensorFlow touches the GPU, which is
# why they are set ahead of the TensorFlow import below.
# NOTE (original author): use "1" instead of "0" when running as a script.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",      # see issue #152
    "CUDA_VISIBLE_DEVICES": "0",
})

from tensorflow.python.client import device_lib
print(device_lib.list_local_devices())

[name: "/cpu:0"
device_type: "CPU"
memory_limit: 268435456
locality {
}
incarnation: 12997310710113250134
, name: "/gpu:0"
device_type: "GPU"
memory_limit: 3326738432
locality {
  bus_id: 1
}
incarnation: 6011853531831098513
physical_device_desc: "device: 0, name: GeForce GTX 970, pci bus id: 0000:02:00.0"
]


In [3]:
from tensorflow.python.client import device_lib

def get_available_gpus():
    """Return the names of the local devices TensorFlow reports as GPUs."""
    gpu_names = []
    for device in device_lib.list_local_devices():
        if device.device_type == 'GPU':
            gpu_names.append(device.name)
    return gpu_names
get_available_gpus()

[u'/gpu:0']

In [4]:
import numpy as np
from keras.layers import Input
from keras.models import Model
import tensorflow as tf
import scipy
import matplotlib.pyplot as plt
from os import listdir
from os.path import isfile, join
import scipy.misc

from models import *
import random
from keras import backend as K
# One TensorFlow session, bound to the default graph, shared by the raw
# TensorFlow ops and the Keras layers (registered with the Keras backend).
sess = tf.Session(graph=tf.get_default_graph())
K.set_session(sess)

Using TensorFlow backend.


In [5]:
class Batch(object):
    """Serve shuffled mini-batches of paired (image, contour) files.

    ``dirname`` must contain an ``images/`` and a ``contours_<mode>/``
    sub-directory. Files are paired by sorted path order, shuffled once with the
    global ``random`` state, and split into a training and a test set.
    """

    # Side length (pixels) images and contours are resized to. Kept as a class
    # attribute so it also has a value on instances not built via __init__.
    img_size = 256

    def __init__(self, dirname, train_frac=0.80, mode='i', img_size=256):
        """Index ``dirname`` and split the file pairs into train/test sets.

        Args:
            dirname: root directory holding ``images/`` and ``contours_<mode>/``.
            train_frac: fraction of pairs used for training (default 0.80).
            mode: suffix selecting the contour directory (default ``'i'``).
            img_size: side length images/contours are resized to (default 256).

        Raises:
            ValueError: if ``train_frac`` is outside [0, 1], or the image and
                contour directories contain different numbers of files.
        """
        if not 0.0 <= train_frac <= 1.0:
            raise ValueError("train_frac must be in [0, 1], got %r" % (train_frac,))
        self.img_size = img_size
        all_files = self.get_files(dirname, mode=mode)

        self.num_all = len(all_files)
        self.num_train = int(self.num_all * train_frac)
        self.num_test = self.num_all - self.num_train

        # Uses the global `random` state: seed it beforehand for a reproducible split.
        random.shuffle(all_files)

        self.train_files = all_files[:self.num_train]
        self.test_files = all_files[self.num_train:]
        self.train_start = 0
        self.test_start = 0

    @staticmethod
    def _list_files(directory):
        """Return the sorted full paths of the regular files directly in ``directory``."""
        paths = (join(directory, name) for name in listdir(directory))
        return sorted(p for p in paths if isfile(p))

    def get_files(self, dirname, mode='i'):
        """Return (image_path, contour_path) pairs, matched by sorted path order.

        Raises:
            ValueError: if the two directories hold different numbers of files.
                Zipping them anyway would silently drop files and shift every
                later image onto the wrong contour.
        """
        img_dir = join(dirname, 'images')
        contour_dir = join(dirname, 'contours_' + mode)
        img_names = self._list_files(img_dir)
        contour_names = self._list_files(contour_dir)
        if len(img_names) != len(contour_names):
            raise ValueError("found %d files in %s but %d in %s; cannot pair them"
                             % (len(img_names), img_dir,
                                len(contour_names), contour_dir))
        return list(zip(img_names, contour_names))

    def image_augment(self, img1, img2):
        """Randomly flip both arrays left-right and/or up-down (p=0.5 each).

        The same flips are applied to both arrays so an image and its contour
        stay aligned. Draws from numpy's global RNG.
        """
        if np.random.rand() < 0.50:
            img1 = np.fliplr(img1)
            img2 = np.fliplr(img2)
        if np.random.rand() < 0.50:
            img1 = np.flipud(img1)
            img2 = np.flipud(img2)
        return img1, img2

    def read_images(self, files, augment=True):
        """Load, optionally augment, resize and rescale a list of file pairs.

        Args:
            files: sequence of (image_path, contour_path) pairs.
            augment: apply ``image_augment`` to each pair (default True).

        Returns:
            (imgs, contours, img_contours): float arrays of shape (N, S, S, 1),
            (N, S, S, 1) and (N, S, S, 2), with S = ``self.img_size``, holding
            the resized data divided by 255. ``img_contours`` has the contour in
            channel 0 and the image in channel 1.
        """
        n = len(files)
        size = self.img_size
        imgs = np.ones((n, size, size, 1))
        contours = np.ones((n, size, size, 1))
        img_contours = np.ones((n, size, size, 2))
        # NOTE: scipy.misc.imread/imresize were deprecated in SciPy 1.0 and later
        # removed; this method needs an older SciPy (with Pillow) to run.
        for i, (img_path, contour_path) in enumerate(files):
            img = scipy.misc.imread(img_path)
            contour = scipy.misc.imread(contour_path)
            if augment:
                img, contour = self.image_augment(img, contour)
            img = scipy.misc.imresize(img, (size, size)) / 255.0
            contour = scipy.misc.imresize(contour, (size, size)) / 255.0
            imgs[i, :, :, :] = np.expand_dims(img, -1)
            contours[i, :, :, :] = np.expand_dims(contour, -1)
            img_contours[i, :, :, :] = np.stack([contour, img], axis=-1)
        return imgs, contours, img_contours

    @staticmethod
    def _take(files, start, batch_size, reshuffle):
        """Slice the next batch of ``files`` beginning at index ``start``.

        Returns (batch_files, next_start, is_running). When the slice reaches the
        end of ``files`` the cursor wraps to 0, ``is_running`` is False and, if
        ``reshuffle`` is set, ``files`` is shuffled in place for the next epoch.

        Raises:
            ValueError: if ``batch_size`` is not positive (the cursor would never
                advance, so an epoch loop driven by ``is_running`` never ends).
        """
        if batch_size <= 0:
            raise ValueError("batch_size must be positive, got %r" % (batch_size,))
        end = min(start + batch_size, len(files))
        batch_files = files[start:end]  # slice copy: unaffected by the shuffle below
        if end == len(files):
            if reshuffle:
                random.shuffle(files)
            return batch_files, 0, False
        return batch_files, start + batch_size, True

    def get_next_batch(self, batch_size):
        """Return (imgs, contours, img_contours, is_running) for the next training batch.

        Training images are augmented. ``is_running`` is False for the batch
        that finishes the epoch; the training files are then reshuffled.
        """
        files, self.train_start, is_running = self._take(
            self.train_files, self.train_start, batch_size, reshuffle=True)
        imgs, contours, img_contours = self.read_images(files)
        return imgs, contours, img_contours, is_running

    def get_next_testbatch(self, batch_size):
        """Return (imgs, contours, img_contours, is_running) for the next test batch.

        Test images are not augmented and the test order is never reshuffled.
        ``is_running`` is False for the batch that finishes a pass over the test set.
        """
        files, self.test_start, is_running = self._take(
            self.test_files, self.test_start, batch_size, reshuffle=False)
        imgs, contours, img_contours = self.read_images(files, augment=False)
        return imgs, contours, img_contours, is_running
    

In [6]:
dirname = './Train_Set'

# Batch shuffles its one-off train/test split with the global `random` state.
# Seeding it makes the split identical across kernel restarts; without a seed,
# a model restored from a checkpoint could be "tested" on images it trained on.
# numpy is seeded too because Batch.image_augment draws its flips from it.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)

batch = Batch(dirname)

In [7]:
img_size = 256
trainable = True
lr = 1e-4
c = 1e-2        # NOTE(review): not used below; presumably a leftover weight-clipping constant
Lambda = 10.0   # weight of the gradient-penalty term in the critic loss
batch_size = 10

enet_inputs_pl = tf.placeholder(tf.float32, shape=[None, img_size, img_size, 1])
enet_inputs_test = tf.placeholder(tf.float32, shape=[None, img_size, img_size, 1])

# Critic inputs: channel 0 = mask, channel 1 = image (same order as fake_inputs below).
d_inputs_pl = tf.placeholder(tf.float32, shape=[None, img_size, img_size, 2])
target_mask_test = tf.placeholder(tf.float32, shape=[None, img_size, img_size, 2])

enet_input = Input(shape=(img_size, img_size, 1))
enet_output = enet_skip(enet_input, 1, activation='sigmoid',trainable=trainable)

d_input = Input(shape=(img_size, img_size, 2))
d_out = discriminator_noBN(d_input, trainable=trainable)

# Compiling the models for the E-Net and the Discriminator
enet_model = Model(inputs=enet_input, outputs=enet_output)
d_model = Model(inputs=d_input, outputs=d_out)

# Variables are selected by name substring; this relies on the scope names used
# inside `models` (enet_skip / discriminator_noBN) -- TODO confirm.
enet_weights = [w for w in tf.global_variables() if 'ENet' in w.name]
d_weights = [w for w in tf.global_variables() if 'discriminator' in w.name]

# Calling the same Keras Model again reuses its weights.
pred_mask = enet_model(enet_inputs_pl)
fake_inputs = tf.concat([pred_mask, enet_inputs_pl], axis=-1)

pred_mask_test = enet_model(enet_inputs_test)

# Gradient penalty (WGAN-GP): penalise the critic's gradient norm at random
# points on the straight lines between real and generated (mask, image) pairs.
alpha = tf.random_uniform(shape=[tf.shape(enet_inputs_pl)[0],1,1,1], minval=0, maxval=1)
differences = fake_inputs - d_inputs_pl
interpolates = d_inputs_pl + alpha*differences
# tf.gradients returns a *list* with one tensor per entry of `xs`; take it.
# Squaring the raw list would stack it into a leading axis of size 1, so
# reduce_sum over axes [1, 2, 3] would sum over batch/H/W instead of producing
# one gradient norm per sample.
grads = tf.gradients(d_model(interpolates), [interpolates])[0]
# The epsilon keeps d(sqrt)/dx finite when a sample's gradient is exactly zero.
slopes = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]) + 1e-12)
grad_penalty = tf.reduce_mean((slopes - 1.0)**2)

# Expectation of the fake and real probability distribution
d_real = d_model(d_inputs_pl)
d_fake = d_model(fake_inputs)

loss_enet = -1*tf.reduce_mean(d_fake)
# The Wasserstein distance estimate and the critic loss (with gradient penalty)
w_dist = tf.reduce_mean(d_real) - tf.reduce_mean(d_fake)
loss_d = tf.reduce_mean(d_fake) - tf.reduce_mean(d_real) + Lambda*grad_penalty

opt_enet = tf.train.AdamOptimizer(lr, beta1=0.5)
opt_d = tf.train.AdamOptimizer(lr, beta1=0.5)

# Each optimizer only updates its own network's variables.
grad_enet = opt_enet.compute_gradients(loss_enet, enet_weights)
grad_d = opt_d.compute_gradients(loss_d, d_weights)

train_op_enet = opt_enet.apply_gradients(grad_enet)
train_op_d = opt_d.apply_gradients(grad_d)

# Created after the optimizers so Adam's slot variables are initialised too.
init = tf.global_variables_initializer()

# Creating summaries of the Wasserstein Distance, Generated and Real Mask
# (the "Wasserstien_Distance" tag is kept unchanged for existing TensorBoard logs)
w_summary = tf.summary.scalar("Wasserstien_Distance", w_dist)

g_mask_sumry = tf.summary.image("Generated_Mask", 
                                tf.expand_dims(pred_mask[:,:,:,0], -1), 
                                max_outputs=1)

r_mask_sumry = tf.summary.image("Real_Mask", 
                                tf.expand_dims(d_inputs_pl[:,:,:,0], -1), 
                                max_outputs=1)

g_mask_sumry_t = tf.summary.image("Generated_Mask_Test", 
                                 tf.expand_dims(pred_mask_test[:,:,:,0], -1), 
                                 max_outputs=1)

r_mask_sumry_t = tf.summary.image("Real_Mask_Test", 
                                 tf.expand_dims(target_mask_test[:,:,:,0], -1), 
                                 max_outputs=1)

merged_summary_op = tf.summary.merge([w_summary, g_mask_sumry, r_mask_sumry])
merged_summary_op_t = tf.summary.merge([g_mask_sumry_t, r_mask_sumry_t])

In [None]:
num_epochs = 3000
logs_path = './enet_logs_wgangp'
ckpt_path = 'RV_SegmentationBN_Skip_NoCrop_WGANGP.ckpt'

saver = tf.train.Saver()

batch_size_test = 1
# Critic updates per generator update: normally `ep_critic_default`, but on
# every `burst_every`-th generator step (including the first) the critic gets
# `ep_critic_burst` updates instead.
ep_critic_default = 10  # was 5
ep_critic_burst = 100
burst_every = 100

with sess.as_default():
    sess.run(init)
    summary_writer = tf.summary.FileWriter(logs_path, graph=tf.get_default_graph())
    try:
        count = 0
        for epoch in range(num_epochs):
            is_running = True
            while is_running:
                if count % burst_every == 0:
                    ep_critic = ep_critic_burst
                else:
                    ep_critic = ep_critic_default
                for _critic_step in range(ep_critic):
                    imgs, _, img_contours, is_running = batch.get_next_batch(batch_size)
                    sess.run([train_op_d],
                             feed_dict={enet_inputs_pl: imgs, d_inputs_pl: img_contours,
                                        K.learning_phase(): 1})
                    if not is_running:  # training set exhausted: end of epoch
                        break
                imgs_test, _, img_contours_test, _ = batch.get_next_testbatch(batch_size_test)
                # The generator step reuses the last batch the critic saw.
                _, summary = sess.run([train_op_enet, merged_summary_op],
                                      feed_dict={enet_inputs_pl: imgs,
                                                 d_inputs_pl: img_contours,
                                                 K.learning_phase(): 1})
                summary_test = sess.run(merged_summary_op_t,
                                        feed_dict={enet_inputs_test: imgs_test,
                                                   target_mask_test: img_contours_test,
                                                   K.learning_phase(): 0})
                summary_writer.add_summary(summary, count)
                summary_writer.add_summary(summary_test, count)
                count += 1
            saver.save(sess, ckpt_path)
    finally:
        # Flush pending events to disk even when training is interrupted.
        summary_writer.close()