In [1]:
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

In [2]:
import os
# Pin GPU selection before TensorFlow initialises CUDA: enumerate devices by
# PCI bus id instead of CUDA's default ordering (see issue #152), then expose
# only device 1 under that ordering to this process.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "1",
})

# Sanity check: list the devices TensorFlow can see (the selected GPU is
# renumbered and shows up as /gpu:0).
from tensorflow.python.client import device_lib
print(device_lib.list_local_devices())

[name: "/cpu:0"
device_type: "CPU"
memory_limit: 268435456
locality {
}
incarnation: 16329700523200128798
, name: "/gpu:0"
device_type: "GPU"
memory_limit: 3775004672
locality {
  bus_id: 1
}
incarnation: 11438657818061364883
physical_device_desc: "device: 0, name: Quadro K2200, pci bus id: 0000:03:00.0"
]


In [3]:
import numpy as np
from keras.layers import Input
from keras.models import Model
import tensorflow as tf
import scipy
import matplotlib.pyplot as plt
from os import listdir
from os.path import isfile, join
import scipy.misc

from models import *
from losses import *
import random
from keras import backend as K
# Create a single TensorFlow session and register it as the Keras backend
# session, so the Keras-built model and the raw TF ops/optimizer run later in
# this notebook execute in the same session.
sess = tf.Session()
K.set_session(sess)

Using TensorFlow backend.


In [4]:
class Batch():
    """Paired image/contour loader with a shuffled train/test split.

    ``dirname/images/`` and ``dirname/contours_<mode>/`` must contain the same
    number of files; after sorting, the k-th image is paired with the k-th
    contour. Batches are served sequentially, and the training list is
    reshuffled after every full pass.
    """

    # Crop window (rows, columns) applied to every raw file before resizing.
    crop = (slice(10, 200), slice(10, 200))
    # Side length of the square arrays returned by read_images; must agree
    # with the input size the network was built for.
    out_size = 256

    def __init__(self, dirname, mode='i', train_fraction=0.80, seed=None):
        """Collect all image/contour pairs and split them into train and test.

        Args:
            dirname: dataset root containing ``images/`` and ``contours_<mode>/``.
            mode: suffix of the contour folder (default ``'i'``).
            train_fraction: fraction of pairs assigned to the training split.
            seed: optional int for a private, reproducible shuffle generator;
                ``None`` uses the global ``random`` module (original behavior).
        """
        all_files = self.get_files(dirname, mode)

        self.num_all = len(all_files)
        self.num_train = int(self.num_all * train_fraction)
        self.num_test = self.num_all - self.num_train

        # The module itself exposes .shuffle, so it can stand in for a Random.
        self._rng = random if seed is None else random.Random(seed)
        self._rng.shuffle(all_files)

        self.train_files = all_files[:self.num_train]
        self.test_files = all_files[self.num_train:]
        self.train_start = 0
        self.test_start = 0

    @staticmethod
    def _list_files(directory):
        """Sorted full paths of the regular files directly inside ``directory``."""
        return sorted(join(directory, f) for f in listdir(directory)
                      if isfile(join(directory, f)))

    def get_files(self, dirname, mode='i'):
        """Return sorted ``(image_path, contour_path)`` pairs.

        Raises:
            ValueError: if the two folders hold different numbers of files.
                Pairing is positional, so a stray or missing file would
                otherwise silently misalign every later pair (``zip`` just
                truncates to the shorter list).
        """
        img_names = self._list_files(dirname + '/images/')
        contour_names = self._list_files(dirname + '/contours_' + mode + '/')
        if len(img_names) != len(contour_names):
            raise ValueError(
                'found %d images but %d contours under %r; cannot pair them'
                % (len(img_names), len(contour_names), dirname))
        return list(zip(img_names, contour_names))

    def image_augment(self, img1, img2):
        """Apply the same random flips to both arrays (each flip with p=0.5).

        Flipping image and contour together keeps them spatially aligned.
        """
        if np.random.rand() < 0.50:
            img1, img2 = np.fliplr(img1), np.fliplr(img2)
        if np.random.rand() < 0.50:
            img1, img2 = np.flipud(img1), np.flipud(img2)
        return img1, img2

    def read_images(self, files):
        """Load, crop, augment and resize a list of (image, contour) pairs.

        Each file must decode to a 2-D (single-channel) array; the channel
        axis is added here.

        Returns:
            imgs: float array (N, S, S, 1), image / 255.
            contours: float array (N, S, S, 1), contour / 255.
            img_contours: float array (N, S, S, 2); channel 0 = contour,
                channel 1 = image.
            where N = len(files) and S = self.out_size.
        """
        size = self.out_size
        n = len(files)
        imgs = np.ones((n, size, size, 1))
        contours = np.ones((n, size, size, 1))
        img_contours = np.ones((n, size, size, 2))
        for i, (img_path, contour_path) in enumerate(files):
            img = scipy.misc.imread(img_path)[self.crop]
            contour = scipy.misc.imread(contour_path)[self.crop]
            img, contour = self.image_augment(img, contour)
            img = scipy.misc.imresize(img, (size, size)) / 255.0
            # Nearest-neighbour resizing keeps mask values from the original
            # label set; bilinear would invent intermediate values along edges.
            # NOTE(review): assumes masks are stored as 0/255 — confirm.
            contour = scipy.misc.imresize(contour, (size, size), interp='nearest') / 255.0
            imgs[i, :, :, 0] = img
            contours[i, :, :, 0] = contour
            img_contours[i] = np.stack([contour, img], axis=-1)
        return imgs, contours, img_contours

    @staticmethod
    def _take(file_list, start, total, batch_size):
        """Slice the next batch; return ``(files, next_start, is_running)``.

        ``is_running`` is False on the batch that reaches the end of the
        list, and ``next_start`` then wraps back to 0.
        """
        if batch_size < 1:
            # A non-positive size would never advance and loop forever.
            raise ValueError('batch_size must be >= 1, got %r' % (batch_size,))
        end = min(start + batch_size, total)
        files = file_list[start:end]
        if end == total:
            return files, 0, False
        return files, end, True

    def get_next_batch(self, batch_size):
        """Return ``(imgs, contours, img_contours, is_running)`` for training.

        When the epoch's last batch is served, ``is_running`` is False and
        the training list is reshuffled for the next epoch.
        """
        files, self.train_start, is_running = self._take(
            self.train_files, self.train_start, self.num_train, batch_size)
        if not is_running:
            self._rng.shuffle(self.train_files)
        imgs, contours, img_contours = self.read_images(files)
        return imgs, contours, img_contours, is_running

    def get_next_testbatch(self, batch_size):
        """Return ``(imgs, contours, img_contours, is_running)`` for testing.

        The test list is never reshuffled, so every pass visits it in the
        same order.
        """
        files, self.test_start, is_running = self._take(
            self.test_files, self.test_start, self.num_test, batch_size)
        imgs, contours, img_contours = self.read_images(files)
        return imgs, contours, img_contours, is_running
    

In [5]:
# Seed every RNG the pipeline draws from, so that Restart & Run All reproduces
# the same runs. This covers the train/test split and epoch reshuffles
# (random.shuffle inside Batch), the flip augmentation (np.random) and random
# ops in the TF graph built in later cells. GPU kernels may still be
# non-deterministic.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.set_random_seed(SEED)

# Dataset root; Batch reads its images/ and contours_i/ subfolders.
dirname = './Train_Set'
batch = Batch(dirname)

In [6]:
img_size = 256
trainable = True

# Feeds for a batch of single-channel input images and their target masks.
enet_inputs_pl = tf.placeholder(tf.float32, shape=[None, img_size, img_size, 1])
target_y = tf.placeholder(tf.float32, shape=[None, img_size, img_size, 1])

# Learning rate is multiplied by 0.96 every 100 optimizer steps (staircase).
global_step = tf.Variable(0, trainable=False)
starter_learning_rate = 5e-3
lr = tf.train.exponential_decay(starter_learning_rate, global_step,
                                100, 0.96, staircase=True)

enet_input = Input(shape=(img_size, img_size, 1))
enet_output = enet(enet_input, 1, activation='sigmoid', trainable=trainable)

enet_model = Model(inputs=enet_input, outputs=enet_output)

pred_mask = enet_model(enet_inputs_pl)

dice_loss_enet = dice_coef_loss(target_y, pred_mask)

# Keras keeps layer update ops (e.g. BatchNormalization moving statistics) on
# the model, not in a TF collection, so a raw tf.train optimizer never runs
# them. Make the train op depend on the updates tied to this call on the
# placeholder, plus any unconditional ones. NOTE(review): `enet` comes from
# `models` (not visible here); if it has no such layers, both lists are empty
# and this is a no-op.
keras_update_ops = (enet_model.get_updates_for(enet_inputs_pl)
                    + enet_model.get_updates_for(None))
with tf.control_dependencies(keras_update_ops):
    train_op_enet = tf.train.AdamOptimizer(lr).minimize(dice_loss_enet, global_step=global_step)

# Created after minimize() so that Adam's slot variables are included.
init = tf.global_variables_initializer()

tf.summary.scalar("Dice_loss", dice_loss_enet)
tf.summary.scalar("Learning_rate", lr)

tf.summary.image("Generated_Mask", pred_mask, max_outputs=1)

# target_y already has exactly one channel, so it can be logged directly.
tf.summary.image("Real_Mask", target_y, max_outputs=1)

merged_summary_op = tf.summary.merge_all()

In [None]:
num_epochs = 100
logs_path = './logs'
saver = tf.train.Saver()

batch_size = 10
with sess.as_default():
    sess.run(init)
    summary_writer_train = tf.summary.FileWriter(logs_path + '/train', graph=tf.get_default_graph())
    summary_writer_test = tf.summary.FileWriter(logs_path + '/test')
    count = 0  # total training batches seen; used as the summary step
    try:
        for epoch in range(num_epochs):
            # One pass over the training split (learning phase 1 = training).
            is_running = True
            while is_running:
                imgs, contours, _, is_running = batch.get_next_batch(batch_size)
                _, summary = sess.run([train_op_enet, merged_summary_op],
                                      feed_dict={enet_inputs_pl: imgs, target_y: contours,
                                                 K.learning_phase(): 1})
                summary_writer_train.add_summary(summary, count)
                count += 1
            saver.save(sess, 'RV_Segmentation_Dice.ckpt')

            # One pass over the test split (learning phase 0 = inference).
            # Writing every test batch at the same step scrambles TensorBoard,
            # so write the last batch's summaries (images) and one mean loss.
            test_loss_total = 0.0
            test_examples = 0
            summary = None
            is_running = True
            while is_running:
                imgs, contours, _, is_running = batch.get_next_testbatch(batch_size)
                if len(imgs) == 0:
                    continue  # empty test split: nothing to evaluate
                loss_value, summary = sess.run([dice_loss_enet, merged_summary_op],
                                               feed_dict={enet_inputs_pl: imgs, target_y: contours,
                                                          K.learning_phase(): 0})
                # Weighted by batch size; only approximates a whole-set dice
                # if dice_coef_loss reduces over the batch (not visible here).
                test_loss_total += float(loss_value) * len(imgs)
                test_examples += len(imgs)
            if summary is not None:
                summary_writer_test.add_summary(summary, count)
            if test_examples:
                mean_summary = tf.Summary(value=[tf.Summary.Value(
                    tag='Mean_Dice_loss', simple_value=test_loss_total / test_examples)])
                summary_writer_test.add_summary(mean_summary, count)
    finally:
        # Flush buffered events and release the event files even if training
        # is interrupted.
        summary_writer_train.close()
        summary_writer_test.close()