# Training/Testing Code

Modify the dataset name to be 'DRIVE' or 'STARE' to perform training or testing on the respective dataset. 

If 'is_test' is False, then this code will perform training. If 'is_test' is True, it will produce predictions on the testing set.

The directory already has a model pretrained by us. If you want to train yourself:
<br>
<br>
  1) please set your desired dataset <br>
  2) then set is_test=False. <br>
  3) After execution is completed, there might be an error that says "An exception has occurred, use %tb to see the full traceback. SystemExit". Ignore this; it is just a consequence of transferring python code to a jupyter notebook. After you see this message, the model will be saved in a respective folder. <br>
  4) To test, set is_test=True. IMPORTANT: be sure to to restart the kernel and run all. This will produce predictions in the DRIVE/seg_result_image_10000_1 or STARE/seg_result_image_10000_1 folders. <br>

 <br>

In [None]:
#dataset = 'DRIVE'
dataset = 'STARE'

is_test = True
#is_test = False

In [1]:
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.python.training import moving_averages

import os
import sys

import pickle
import matplotlib
import numpy as np
import matplotlib.pyplot as plt

from PIL import Image, ImageEnhance
from skimage import filters
from scipy.ndimage import rotate
from sklearn.metrics import roc_curve, roc_auc_score, precision_recall_curve, auc, confusion_matrix

import random
from datetime import datetime

import time
import collections

import matplotlib.gridspec as gridspec


def linear(input_, output_size, stddev=0.02, bias_start=0.0, with_w=False, name='fc'):
    shape = input_.get_shape().as_list()
    # print('shape: ', shape)

    with tf.variable_scope(name) as scope:
        matrix = tf.get_variable(name="matrix", shape=[shape[1], output_size],
                                 dtype=tf.float32, initializer=tf.contrib.layers.xavier_initializer())
        bias = tf.get_variable(name="bias", shape=[output_size],
                               initializer=tf.constant_initializer(bias_start))
        if with_w:
            return tf.matmul(input_, matrix) + bias, matrix, bias
        else:
            return tf.matmul(input_, matrix) + bias


def batch_norm(x, name, _ops, is_train=True):
    """Batch normalization."""
    with tf.variable_scope(name):
        params_shape = [x.get_shape()[-1]]

        beta = tf.get_variable('beta', params_shape, tf.float32,
                               initializer=tf.constant_initializer(0.0, tf.float32))
        gamma = tf.get_variable('gamma', params_shape, tf.float32,
                                initializer=tf.constant_initializer(1.0, tf.float32))

        if is_train is True:
            mean, variance = tf.nn.moments(x, [0, 1, 2], name='moments')

            moving_mean = tf.get_variable('moving_mean', params_shape, tf.float32,
                                          initializer=tf.constant_initializer(0.0, tf.float32),
                                          trainable=False)
            moving_variance = tf.get_variable('moving_variance', params_shape, tf.float32,
                                              initializer=tf.constant_initializer(1.0, tf.float32),
                                              trainable=False)

            _ops.append(moving_averages.assign_moving_average(moving_mean, mean, 0.9))
            _ops.append(moving_averages.assign_moving_average(moving_variance, variance, 0.9))
        else:
            mean = tf.get_variable('moving_mean', params_shape, tf.float32,
                                   initializer=tf.constant_initializer(0.0, tf.float32), trainable=False)
            variance = tf.get_variable('moving_variance', params_shape, tf.float32, trainable=False)

        # epsilon used to be 1e-5. Maybe 0.001 solves NaN problem in deeper net.
        y = tf.nn.batch_normalization(x, mean, variance, beta, gamma, 1e-5)
        y.set_shape(x.get_shape())

        return y


def conv2d(input_, output_dim, k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, name='conv2d'):
    with tf.variable_scope(name):
        w = tf.get_variable('w', [k_h, k_w, input_.get_shape()[-1], output_dim],
                            initializer=tf.truncated_normal_initializer(stddev=stddev))
        conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding='SAME')

        biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0))
        # conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape())
        conv = tf.nn.bias_add(conv, biases)

        return conv


def deconv2d(input_, output_shape, k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, name='deconv2d', with_w=False):
    with tf.variable_scope(name):
        # filter : [height, width, output_channels, in_channels]
        w = tf.get_variable('w', [k_h, k_w, output_shape[-1], input_.get_shape()[-1]],
                            initializer=tf.random_normal_initializer(stddev=stddev))
        deconv = tf.nn.conv2d_transpose(input_, w, output_shape=output_shape, strides=[1, d_h, d_w, 1])

        biases = tf.get_variable('biases', [output_shape[-1]], initializer=tf.constant_initializer(0.0))

        # deconv = tf.reshape(tf.nn.bias_add(deconv, biases), deconv.get_shape())
        deconv = tf.nn.bias_add(deconv, biases)

        if with_w:
            return deconv, w, biases
        else:
            return deconv


def upsampling2d(input_, size=(2, 2), name='upsampling2d'):
    with tf.name_scope(name):
        shape = input_.get_shape().as_list()
        return tf.image.resize_nearest_neighbor(input_, size=(size[0] * shape[1], size[1] * shape[2]))


def max_pool_2x2(x, name='max_pool'):
    with tf.name_scope(name):
        return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')


def lrelu(x, leak=0.2, name='lrelu'):
    return tf.maximum(x, leak*x, name=name)


def xavier_init(in_dim):
    print('in_dim: ', in_dim)
    xavier_stddev = 1. / tf.sqrt(in_dim / 2.)
    return xavier_stddev


def print_activations(t):
    print(t.op.name, ' ', t.get_shape().as_list())


def show_all_variables():
    model_vars = tf.trainable_variables()
    slim.model_analyzer.analyze_vars(model_vars, print_info=True)

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
def get_img_path(target_dir, dataset):
    img_files, vessel_files, mask_files = None, None, None
    if dataset == 'DRIVE':
        img_files, vessel_files, mask_files = DRIVE_files(target_dir)
    elif dataset == 'STARE':
        img_files, vessel_files, mask_files = STARE_files(target_dir)

    return img_files, vessel_files, mask_files


# noinspection PyPep8Naming
def STARE_files(data_path):
    img_dir = os.path.join(data_path, "images")
    vessel_dir = os.path.join(data_path, "1st_manual")
    mask_dir = os.path.join(data_path, "mask")

    img_files = all_files_under(img_dir, extension=".ppm")
    vessel_files = all_files_under(vessel_dir, extension=".ppm")
    mask_files = all_files_under(mask_dir, extension=".ppm")

    return img_files, vessel_files, mask_files


# noinspection PyPep8Naming
def DRIVE_files(data_path):
    img_dir = os.path.join(data_path, "images")
    vessel_dir = os.path.join(data_path, "1st_manual")
    mask_dir = os.path.join(data_path, "mask")

    img_files = all_files_under(img_dir, extension=".tif")
    vessel_files = all_files_under(vessel_dir, extension=".gif")
    mask_files = all_files_under(mask_dir, extension=".gif")

    return img_files, vessel_files, mask_files


def load_images_under_dir(path_dir):
    files = all_files_under(path_dir)
    return imagefiles2arrs(files)


def all_files_under(path, extension=None, append_path=True, sort=True):
    if append_path:
        if extension is None:
            filenames = [os.path.join(path, fname) for fname in os.listdir(path)]
        else:
            filenames = [os.path.join(path, fname)
                         for fname in os.listdir(path) if fname.endswith(extension)]
    else:
        if extension is None:
            filenames = [os.path.basename(fname) for fname in os.listdir(path)]
        else:
            filenames = [os.path.basename(fname)
                         for fname in os.listdir(path) if fname.endswith(extension)]

    if sort:
        filenames = sorted(filenames)

    return filenames


def imagefiles2arrs(filenames):
    img_shape = image_shape(filenames[0])
    images_arr = None

    if len(img_shape) == 3:
        images_arr = np.zeros((len(filenames), img_shape[0], img_shape[1], img_shape[2]), dtype=np.float32)
    elif len(img_shape) == 2:
        images_arr = np.zeros((len(filenames), img_shape[0], img_shape[1]), dtype=np.float32)

    for file_index in range(len(filenames)):
        img = Image.open(filenames[file_index])
        images_arr[file_index] = np.asarray(img).astype(np.float32)

    return images_arr


def get_train_batch(train_img_files, train_vessel_files, train_indices, img_size):
    batch_size = len(train_indices)
    batch_img_files, batch_vessel_files = [], []
    for _, idx in enumerate(train_indices):
        batch_img_files.append(train_img_files[idx])
        batch_vessel_files.append(train_vessel_files[idx])

    # load images
    fundus_imgs = imagefiles2arrs(batch_img_files)
    vessel_imgs = imagefiles2arrs(batch_vessel_files) / 255
    fundus_imgs = pad_imgs(fundus_imgs, img_size)
    vessel_imgs = pad_imgs(vessel_imgs, img_size)
    assert (np.min(vessel_imgs) == 0 and np.max(vessel_imgs) == 1)

    # random mirror flipping
    for idx in range(batch_size):
        if np.random.random() > 0.5:
            fundus_imgs[idx] = fundus_imgs[idx, :, ::-1, :]  # flipped imgs
            vessel_imgs[idx] = vessel_imgs[idx, :, ::-1]  # flipped vessel

    # flip_index = np.random.choice(batch_size, int(np.ceil(0.5 * batch_size)), replace=False)
    # fundus_imgs[flip_index] = fundus_imgs[flip_index, :, ::-1, :]  # flipped imgs
    # vessel_imgs[flip_index] = vessel_imgs[flip_index, :, ::-1]  # flipped vessel

    # random rotation
    for idx in range(batch_size):
        angle = np.random.randint(360)
        fundus_imgs[idx] = random_perturbation(rotate(input=fundus_imgs[idx], angle=angle, axes=(0, 1),
                                                      reshape=False, order=1))
        vessel_imgs[idx] = rotate(input=vessel_imgs[idx], angle=angle, axes=(0, 1), reshape=False, order=1)

    # z score with mean, std of each image
    for idx in range(batch_size):
        mean = np.mean(fundus_imgs[idx, ...][fundus_imgs[idx, ..., 0] > 40.0], axis=0)
        std = np.std(fundus_imgs[idx, ...][fundus_imgs[idx, ..., 0] > 40.0], axis=0)

        assert len(mean) == 3 and len(std) == 3
        fundus_imgs[idx, ...] = (fundus_imgs[idx, ...] - mean) / std

    return fundus_imgs, np.round(vessel_imgs)


def get_val_imgs(img_files, vessel_files, mask_files, img_size):
    # load images
    fundus_imgs = imagefiles2arrs(img_files)
    vessel_imgs = imagefiles2arrs(vessel_files) / 255
    mask_imgs = imagefiles2arrs(mask_files) / 255

    # padding
    fundus_imgs = pad_imgs(fundus_imgs, img_size)
    vessel_imgs = pad_imgs(vessel_imgs, img_size)
    mask_imgs = pad_imgs(mask_imgs, img_size)

    assert (np.min(vessel_imgs) == 0 and np.max(vessel_imgs) == 1)
    assert (np.min(mask_imgs) == 0 and np.max(mask_imgs) == 1)

    # augmentation
    # augment the original image (flip, rotate)
    all_fundus_imgs = [fundus_imgs]
    all_vessel_imgs = [vessel_imgs]
    all_mask_imgs = [mask_imgs]

    flipped_imgs = fundus_imgs[:, :, ::-1, :]  # flipped imgs
    flipped_vessels = vessel_imgs[:, :, ::-1]
    flipped_masks = mask_imgs[:, :, ::-1]

    all_fundus_imgs.append(flipped_imgs)
    all_vessel_imgs.append(flipped_vessels)
    all_mask_imgs.append(flipped_masks)

    for angle in range(3, 360, 3):  # rotated imgs (3, 360, 3)
        print("Val data augmentation {} degree...".format(angle))
        all_fundus_imgs.append(random_perturbation(rotate(fundus_imgs, angle, axes=(1, 2), reshape=False,
                                                          order=1)))
        all_fundus_imgs.append(random_perturbation(rotate(flipped_imgs, angle, axes=(1, 2), reshape=False,
                                                          order=1)))

        all_vessel_imgs.append(rotate(vessel_imgs, angle, axes=(1, 2), reshape=False, order=1))
        all_vessel_imgs.append(rotate(flipped_vessels, angle, axes=(1, 2), reshape=False, order=1))

        all_mask_imgs.append(rotate(mask_imgs, angle, axes=(1, 2), reshape=False, order=1))
        all_mask_imgs.append(rotate(flipped_masks, angle, axes=(1, 2), reshape=False, order=1))

    fundus_imgs = np.concatenate(all_fundus_imgs, axis=0)
    vessel_imgs = np.concatenate(all_vessel_imgs, axis=0)
    mask_imgs = np.concatenate(all_mask_imgs, axis=0)

    # z score with mean, std of each image
    mean_std = []
    n_all_imgs = fundus_imgs.shape[0]
    for index in range(n_all_imgs):
        mean = np.mean(fundus_imgs[index, ...][fundus_imgs[index, ..., 0] > 40.0], axis=0)
        std = np.std(fundus_imgs[index, ...][fundus_imgs[index, ..., 0] > 40.0], axis=0)

        assert len(mean) == 3 and len(std) == 3
        fundus_imgs[index, ...] = (fundus_imgs[index, ...] - mean) / std

        mean_std.append({'mean': mean, 'std': std})

    return fundus_imgs, np.round(vessel_imgs), np.round(mask_imgs), mean_std


def get_test_imgs(target_dir, img_size, dataset):
    img_files, vessel_files, mask_files, mask_imgs = None, None, None, None
    if dataset == 'DRIVE':
        img_files, vessel_files, mask_files = DRIVE_files(target_dir)
    elif dataset == 'STARE':
        img_files, vessel_files, mask_files = STARE_files(target_dir)

    # load images
    fundus_imgs = imagefiles2arrs(img_files)
    vessel_imgs = imagefiles2arrs(vessel_files) / 255
    fundus_imgs = pad_imgs(fundus_imgs, img_size)
    vessel_imgs = pad_imgs(vessel_imgs, img_size)
    assert (np.min(vessel_imgs) == 0 and np.max(vessel_imgs) == 1)

    mask_imgs = imagefiles2arrs(mask_files) / 255
    mask_imgs = pad_imgs(mask_imgs, img_size)
    assert (np.min(mask_imgs) == 0 and np.max(mask_imgs) == 1)

    # z score with mean, std of each image
    mean_std = []
    n_all_imgs = fundus_imgs.shape[0]
    for index in range(n_all_imgs):
        mean = np.mean(fundus_imgs[index, ...][fundus_imgs[index, ..., 0] > 40.0], axis=0)
        std = np.std(fundus_imgs[index, ...][fundus_imgs[index, ..., 0] > 40.0], axis=0)

        assert len(mean) == 3 and len(std) == 3
        fundus_imgs[index, ...] = (fundus_imgs[index, ...] - mean) / std

        mean_std.append({'mean': mean, 'std': std})

    return fundus_imgs, np.round(vessel_imgs), mask_imgs, mean_std


def image_shape(filename):
    img = Image.open(filename)
    img_arr = np.asarray(img)
    img_shape = img_arr.shape
    return img_shape


def pad_imgs(imgs, img_size):
    padded = None
    img_h, img_w = imgs.shape[1], imgs.shape[2]
    target_h, target_w = img_size[0], img_size[1]
    if len(imgs.shape) == 4:
        d = imgs.shape[3]
        padded = np.zeros((imgs.shape[0], target_h, target_w, d))
    elif len(imgs.shape) == 3:
        padded = np.zeros((imgs.shape[0], img_size[0], img_size[1]))

    start_h, start_w = (target_h - img_h) // 2, (target_w - img_w) // 2
    end_h, end_w = start_h + img_h, start_w + img_w
    padded[:, start_h:end_h, start_w:end_w, ...] = imgs

    return padded


def crop_to_original(imgs, ori_shape):
    # imgs: (N, 640, 640, 3 or None)
    # ori_shape: (584, 565)
    pred_shape = imgs.shape
    assert len(pred_shape) > 2

    if ori_shape == pred_shape:
        return imgs
    else:
        if len(imgs.shape) > 3:  # images (N, 640, 640, 3)
            ori_h, ori_w = ori_shape[0], ori_shape[1]
            pred_h, pred_w = pred_shape[1], pred_shape[2]

            start_h, start_w = (pred_h - ori_h) // 2, (pred_w - ori_w) // 2
            end_h, end_w = start_h + ori_h, start_w + ori_w

            return imgs[:, start_h:end_h, start_w:end_w, :]
        else:  # vesels
            ori_h, ori_w = ori_shape[0], ori_shape[1]
            pred_h, pred_w = pred_shape[1], pred_shape[2]

            start_h, start_w = (pred_h - ori_h) // 2, (pred_w - ori_w) // 2
            end_h, end_w = start_h + ori_h, start_w + ori_w

            return imgs[:, start_h:end_h, start_w:end_w]


def random_perturbation(imgs):
    for i in range(imgs.shape[0]):
        im = Image.fromarray(imgs[i, ...].astype(np.uint8))
        en = ImageEnhance.Color(im)
        im = en.enhance(np.random.uniform(0.8, 1.2))
        imgs[i, ...] = np.asarray(im).astype(np.float32)

    return imgs


def pixel_values_in_mask(true_vessels, pred_vessels, masks, split_by_img=False):
    assert np.max(pred_vessels) <= 1.0 and np.min(pred_vessels) >= 0.0
    assert np.max(true_vessels) == 1.0 and np.min(true_vessels) == 0.0
    assert np.max(masks) == 1.0 and np.min(masks) == 0.0
    assert pred_vessels.shape[0] == true_vessels.shape[0] and masks.shape[0] == true_vessels.shape[0]
    assert pred_vessels.shape[1] == true_vessels.shape[1] and masks.shape[1] == true_vessels.shape[1]
    assert pred_vessels.shape[2] == true_vessels.shape[2] and masks.shape[2] == true_vessels.shape[2]

    if split_by_img:
        n = pred_vessels.shape[0]
        return (np.array([true_vessels[i, ...][masks[i, ...] == 1].flatten() for i in range(n)]),
                np.array([pred_vessels[i, ...][masks[i, ...] == 1].flatten() for i in range(n)]))
    else:
        return true_vessels[masks == 1].flatten(), pred_vessels[masks == 1].flatten()


def remain_in_mask(imgs, masks):
    imgs[masks == 0] = 0
    return imgs


# noinspection PyPep8Naming
def AUC_ROC(true_vessel_arr, pred_vessel_arr):
    """
    Area under the ROC curve with x axis flipped
    ROC: Receiver operating characteristic
    """
    # roc_auc_score: sklearn function
    AUC_ROC_ = roc_auc_score(true_vessel_arr.flatten(), pred_vessel_arr.flatten())
    return AUC_ROC_


# noinspection PyPep8Naming
def AUC_PR(true_vessel_arr, pred_vessel_arr):
    """
    Precision-recall curve: sklearn function
    auc: Area Under Curve, sklearn function
    """
    precision, recall, _ = precision_recall_curve(true_vessel_arr.flatten(),
                                                  pred_vessel_arr.flatten(), pos_label=1)
    AUC_prec_rec = auc(recall, precision)
    return AUC_prec_rec


def threshold_by_f1(true_vessels, generated, masks, flatten=True, f1_score=False):
    vessels_in_mask, generated_in_mask = pixel_values_in_mask(true_vessels, generated, masks)
    precision, recall, thresholds = precision_recall_curve(
        vessels_in_mask.flatten(), generated_in_mask.flatten(), pos_label=1)
    best_f1, best_threshold = best_f1_threshold(precision, recall, thresholds)

    pred_vessels_bin = np.zeros(generated.shape)
    pred_vessels_bin[generated >= best_threshold] = 1

    if flatten:
        if f1_score:
            return pred_vessels_bin[masks == 1].flatten(), best_f1
        else:
            return pred_vessels_bin[masks == 1].flatten()
    else:
        if f1_score:
            return pred_vessels_bin, best_f1
        else:
            return pred_vessels_bin


def best_f1_threshold(precision, recall, thresholds):
    best_f1, best_threshold = -1., None
    for index in range(len(precision)):
        curr_f1 = 2. * precision[index] * recall[index] / (precision[index] + recall[index])
        if best_f1 < curr_f1:
            best_f1 = curr_f1
            best_threshold = thresholds[index]

    return best_f1, best_threshold


def threshold_by_otsu(pred_vessels, masks, flatten=True):
    # cut by otsu threshold
    threshold = filters.threshold_otsu(pred_vessels[masks == 1])
    pred_vessels_bin = np.zeros(pred_vessels.shape)
    pred_vessels_bin[pred_vessels >= threshold] = 1

    if flatten:
        return pred_vessels_bin[masks == 1].flatten()
    else:
        return pred_vessels_bin


def dice_coefficient_in_train(true_vessel_arr, pred_vessel_arr):
    true_vessel_arr = true_vessel_arr.astype(np.bool)
    pred_vessel_arr = pred_vessel_arr.astype(np.bool)

    intersection = np.count_nonzero(true_vessel_arr & pred_vessel_arr)

    size1 = np.count_nonzero(true_vessel_arr)
    size2 = np.count_nonzero(pred_vessel_arr)

    try:
        dc = 2. * intersection / float(size1 + size2)
    except ZeroDivisionError:
        dc = 0.0

    return dc


def misc_measures(true_vessel_arr, pred_vessel_arr):
    cm = confusion_matrix(true_vessel_arr, pred_vessel_arr)
    acc = 1. * (cm[0, 0] + cm[1, 1]) / np.sum(cm)
    sensitivity = 1. * cm[1, 1] / (cm[1, 0] + cm[1, 1])
    specificity = 1. * cm[0, 0] / (cm[0, 1] + cm[0, 0])
    return acc, sensitivity, specificity


def difference_map(ori_vessel, pred_vessel, mask):
    # ori_vessel : an RGB image
    thresholded_vessel = threshold_by_f1(np.expand_dims(ori_vessel, axis=0),
                                         np.expand_dims(pred_vessel, axis=0),
                                         np.expand_dims(mask, axis=0), flatten=False)

    thresholded_vessel = np.squeeze(thresholded_vessel, axis=0)
    diff_map = np.zeros((ori_vessel.shape[0], ori_vessel.shape[1], 3))

    # Green (overlapping)
    diff_map[(ori_vessel == 1) & (thresholded_vessel == 1)] = (0, 255, 0)
    # Red (false negative, missing in pred)
    diff_map[(ori_vessel == 1) & (thresholded_vessel != 1)] = (255, 0, 0)
    # Blue (false positive)
    diff_map[(ori_vessel != 1) & (thresholded_vessel == 1)] = (0, 0, 255)

    # compute dice coefficient for a given image
    overlap = len(diff_map[(ori_vessel == 1) & (thresholded_vessel == 1)])
    fn = len(diff_map[(ori_vessel == 1) & (thresholded_vessel != 1)])
    fp = len(diff_map[(ori_vessel != 1) & (thresholded_vessel == 1)])

    return diff_map, 2. * overlap / (2 * overlap + fn + fp)


def operating_pts_human_experts(gt_vessels, pred_vessels, masks):
    gt_vessels_in_mask, pred_vessels_in_mask = pixel_values_in_mask(
        gt_vessels, pred_vessels, masks, split_by_img=True)

    n = gt_vessels_in_mask.shape[0]
    op_pts_roc, op_pts_pr = [], []
    for i in range(n):
        cm = confusion_matrix(gt_vessels_in_mask[i], pred_vessels_in_mask[i])
        fpr = 1 - 1. * cm[0, 0] / (cm[0, 1] + cm[0, 0])
        tpr = 1. * cm[1, 1] / (cm[1, 0] + cm[1, 1])
        prec = 1. * cm[1, 1] / (cm[0, 1] + cm[1, 1])
        recall = tpr
        op_pts_roc.append((fpr, tpr))
        op_pts_pr.append((recall, prec))

    return op_pts_roc, op_pts_pr


def misc_measures_evaluation(true_vessels, pred_vessels, masks):
    thresholded_vessel_arr, f1_score = threshold_by_f1(true_vessels, pred_vessels, masks, f1_score=True)
    true_vessel_arr = true_vessels[masks == 1].flatten()

    cm = confusion_matrix(true_vessel_arr, thresholded_vessel_arr)
    acc = 1. * (cm[0, 0] + cm[1, 1]) / np.sum(cm)
    sensitivity = 1. * cm[1, 1] / (cm[1, 0] + cm[1, 1])
    specificity = 1. * cm[0, 0] / (cm[0, 1] + cm[0, 0])
    return f1_score, acc, sensitivity, specificity


def dice_coefficient(true_vessels, pred_vessels, masks):
    thresholded_vessels = threshold_by_f1(true_vessels, pred_vessels, masks, flatten=False)

    true_vessels = true_vessels.astype(np.bool)
    thresholded_vessels = thresholded_vessels.astype(np.bool)

    intersection = np.count_nonzero(true_vessels & thresholded_vessels)

    size1 = np.count_nonzero(true_vessels)
    size2 = np.count_nonzero(thresholded_vessels)

    try:
        dc = 2. * intersection / float(size1 + size2)
    except ZeroDivisionError:
        dc = 0.0

    return dc


def save_obj(true_vessel_arr, pred_vessel_arr, auc_roc_file_name, auc_pr_file_name):
    fpr, tpr, _ = roc_curve(true_vessel_arr, pred_vessel_arr)  # roc_curve: sklearn function

    precision, recall, _ = precision_recall_curve(true_vessel_arr.flatten(),
                                                  pred_vessel_arr.flatten(), pos_label=1)

    with open(auc_roc_file_name, 'wb') as f:
        pickle.dump({"fpr": fpr, "tpr": tpr}, f, pickle.HIGHEST_PROTOCOL)

    with open(auc_pr_file_name, 'wb') as f:
        pickle.dump({"precision": precision, "recall": recall}, f, pickle.HIGHEST_PROTOCOL)


def print_metrics(itr, kargs):
    print("*** Iteration {}  ====> ".format(itr))
    for name, value in kargs.items():
        print("{} : {:.6}, ".format(name, value))
    print("")
    sys.stdout.flush()


# noinspection PyPep8Naming
def plot_AUC_ROC(fprs, tprs, method_names, fig_dir, op_pts):
    # set font style
    font = {'family': 'serif'}
    matplotlib.rc('font', **font)

    # sort the order of plots manually for eye-pleasing plots
    colors = ['r', 'b', 'y', 'g', '#7e7e7e', 'm', 'c', 'k'] if len(fprs) == 8 \
        else ['r', 'y', 'm', 'g', 'k']
    indices = [7, 2, 5, 3, 4, 6, 1, 0] if len(fprs) == 8 else [4, 1, 2, 3, 0]

    # print auc
    print("****** ROC AUC ******")
    print("CAVEAT : AUC of V-GAN with 8bit images might be lower than the floating point array "
          "(check <home>/pretrained/auc_roc*.npy)")

    for index in indices:
        if method_names[index] != 'CRFs' and method_names[index] != '2nd_manual':
            print("{} : {:.4}".format(method_names[index], auc(fprs[index], tprs[index])))

    # plot results
    for index in indices:
        if method_names[index] == 'CRFs':
            plt.plot(fprs[index], tprs[index], colors[index] + '*', label=method_names[index].replace("_", " "))
        elif method_names[index] == '2nd_manual':
            plt.plot(fprs[index], tprs[index], colors[index] + '*', label='Human')
        else:
            plt.step(fprs[index], tprs[index], colors[index], where='post',
                     label=method_names[index].replace("_", " "), linewidth=1.5)

    # plot individual operation points
    for op_pt in op_pts:
        plt.plot(op_pt[0], op_pt[1], 'r.')

    plt.title('ROC Curve')
    plt.xlabel("1-Specificity")
    plt.ylabel("Sensitivity")
    plt.xlim(0, 0.3)
    plt.ylim(0.7, 1.0)
    plt.legend(loc="lower right")
    plt.savefig(os.path.join(fig_dir, "ROC.png"))
    plt.close()


# noinspection PyPep8Naming
def plot_AUC_PR(precisions, recalls, method_names, fig_dir, op_pts):
    # set font style
    font = {'family': 'serif'}
    matplotlib.rc('font', **font)

    # sort the order of plots manually for eye-pleasing plots
    colors = ['r', 'b', 'y', 'g', '#7e7e7e', 'm', 'c', 'k'] if len(precisions) == 8 \
        else ['r', 'y', 'm', 'g', 'k']
    indices = [7, 2, 5, 3, 4, 6, 1, 0] if len(precisions) == 8 else [4, 1, 2, 3, 0]

    # print auc
    print("****** Precision Recall AUC ******")
    print("CAVEAT : AUC of V-GAN with 8bit images might be lower than the floating point array "
          "(check <home>/pretrained/auc_pr*.npy)")

    for index in indices:
        if method_names[index] != 'CRFs' and method_names[index] != '2nd_manual':
            print("{} : {:.4}".format(method_names[index], auc(recalls[index], precisions[index])))

    # plot results
    for index in indices:
        if method_names[index] == 'CRFs':
            plt.plot(recalls[index], precisions[index], colors[index] + '*',
                     label=method_names[index].replace("_", " "))
        elif method_names[index] == '2nd_manual':
            plt.plot(recalls[index], precisions[index], colors[index] + '*', label='Human')
        else:
            plt.step(recalls[index], precisions[index], colors[index], where='post',
                     label=method_names[index].replace("_", " "), linewidth=1.5)

    # plot individual operation points
    for op_pt in op_pts:
        plt.plot(op_pt[0], op_pt[1], 'r.')

    plt.title('Precision Recall Curve')
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.xlim(0.5, 1.0)
    plt.ylim(0.5, 1.0)
    plt.legend(loc="lower left")
    plt.savefig(os.path.join(fig_dir, "Precision_recall.png"))
    plt.close()

In [3]:
class Dataset(object):
    def __init__(self, dataset, flags):
        self.dataset = dataset
        self.flags = flags

        self.image_size = (640, 640) if self.dataset == 'DRIVE' else (720, 720)
        self.ori_shape = (584, 565) if self.dataset == 'DRIVE' else (605, 700)
        self.val_ratio = 0.1  # 10% of the training data are used as validation data
        self.train_dir = "data/{}/training/".format(self.dataset)
        self.test_dir = "data/{}/test/".format(self.dataset)

        self.num_train, self.num_val, self.num_test = 0, 0, 0

        self._read_data()  # read training, validation, and test data)
        print('num of training images: {}'.format(self.num_train))
        print('num of validation images: {}'.format(self.num_val))
        print('num of test images: {}'.format(self.num_test))

    def _read_data(self):
        if self.flags.is_test:
            # real test images and vessels in the memory
            self.test_imgs, self.test_vessels, self.test_masks, self.test_mean_std = get_test_imgs(
                target_dir=self.test_dir, img_size=self.image_size, dataset=self.dataset)
            self.test_img_files = all_files_under(os.path.join(self.test_dir, 'images'))

            self.num_test = self.test_imgs.shape[0]

        elif not self.flags.is_test:
            random.seed(datetime.now())  # set random seed
            self.train_img_files, self.train_vessel_files, mask_files = get_img_path(
                self.train_dir, self.dataset)

            self.num_train = int(len(self.train_img_files))
            self.num_val = int(np.floor(self.val_ratio * int(len(self.train_img_files))))
            self.num_train -= self.num_val

            self.val_img_files = self.train_img_files[-self.num_val:]
            self.val_vessel_files = self.train_vessel_files[-self.num_val:]
            val_mask_files = mask_files[-self.num_val:]
            self.train_img_files = self.train_img_files[:-self.num_val]
            self.train_vessel_files = self.train_vessel_files[:-self.num_val]

            # read val images and vessels in the memory
            self.val_imgs, self.val_vessels, self.val_masks, self.val_mean_std = get_val_imgs(
                self.val_img_files, self.val_vessel_files, val_mask_files, img_size=self.image_size)

            self.num_val = self.val_imgs.shape[0]

    def train_next_batch(self, batch_size):
        train_indices = np.random.choice(self.num_train, batch_size, replace=True)
        train_imgs, train_vessels = get_train_batch(
            self.train_img_files, self.train_vessel_files, train_indices.astype(np.int32),
            img_size=self.image_size)
        train_vessels = np.expand_dims(train_vessels, axis=3)

        return train_imgs, train_vessels

In [4]:
class Solver(object):
    def __init__(self, flags):
        run_config = tf.ConfigProto()
        run_config.gpu_options.allow_growth = True
        self.sess = tf.Session(config=run_config)

        self.flags = flags
        self.dataset = Dataset(self.flags.dataset, self.flags)
        self.model = CGAN(self.sess, self.flags, self.dataset.image_size)

        self.best_auc_sum = 0.
        self._make_folders()

        self.saver = tf.train.Saver()
        self.sess.run(tf.global_variables_initializer())

        show_all_variables()

    def _make_folders(self):
        self.model_out_dir = "{}/model_{}_{}_{}".format(self.flags.dataset, self.flags.discriminator,
                                                        self.flags.train_interval, self.flags.batch_size)
        if not os.path.isdir(self.model_out_dir):
            os.makedirs(self.model_out_dir)

        if self.flags.is_test:
            self.img_out_dir = "{}/seg_result_{}_{}_{}".format(self.flags.dataset,
                                                               self.flags.discriminator,
                                                               self.flags.train_interval,
                                                               self.flags.batch_size)
            self.auc_out_dir = "{}/auc_{}_{}_{}".format(self.flags.dataset, self.flags.discriminator,
                                                        self.flags.train_interval, self.flags.batch_size)

            if not os.path.isdir(self.img_out_dir):
                os.makedirs(self.img_out_dir)
            if not os.path.isdir(self.auc_out_dir):
                os.makedirs(self.auc_out_dir)

        elif not self.flags.is_test:
            self.sample_out_dir = "{}/sample_{}_{}_{}".format(self.flags.dataset, self.flags.discriminator,
                                                              self.flags.train_interval, self.flags.batch_size)
            if not os.path.isdir(self.sample_out_dir):
                os.makedirs(self.sample_out_dir)

    def train(self):
        for iter_time in range(0, self.flags.iters+1, self.flags.train_interval):
            self.sample(iter_time)  # sampling images and save them

            # train discrminator
            for iter_ in range(1, self.flags.train_interval+1):
                x_imgs, y_imgs = self.dataset.train_next_batch(batch_size=self.flags.batch_size)
                d_loss = self.model.train_dis(x_imgs, y_imgs)
                self.print_info(iter_time + iter_, 'd_loss', d_loss)

            # train generator
            for iter_ in range(1, self.flags.train_interval+1):
                x_imgs, y_imgs = self.dataset.train_next_batch(batch_size=self.flags.batch_size)
                g_loss = self.model.train_gen(x_imgs, y_imgs)
                self.print_info(iter_time + iter_, 'g_loss', g_loss)

            auc_sum = self.eval(iter_time, phase='train')

            if self.best_auc_sum < auc_sum:
                self.best_auc_sum = auc_sum
                self.save_model(iter_time)

    def test(self):
        if self.load_model():
            print(' [*] Load Success!\n')
            self.eval(phase='test')
        else:
            print(' [!] Load Failed!\n')

    def sample(self, iter_time):
        if np.mod(iter_time, self.flags.sample_freq) == 0:
            idx = np.random.choice(self.dataset.num_val, 2, replace=False)
            x_imgs, y_imgs = self.dataset.val_imgs[idx], self.dataset.val_vessels[idx]
            samples = self.model.sample_imgs(x_imgs)

            # masking
            seg_samples = remain_in_mask(samples, self.dataset.val_masks[idx])

            # crop to original image shape
            x_imgs_ = crop_to_original(x_imgs, self.dataset.ori_shape)
            seg_samples_ = crop_to_original(seg_samples, self.dataset.ori_shape)
            y_imgs_ = crop_to_original(y_imgs, self.dataset.ori_shape)

            # sampling
            self.plot(x_imgs_, seg_samples_, y_imgs_, iter_time, idx=idx, save_file=self.sample_out_dir,
                      phase='train')

    def plot(self, x_imgs, samples, y_imgs, iter_time, idx=None, save_file=None, phase='train'):
        # initialize grid size
        cell_size_h, cell_size_w = self.dataset.ori_shape[0] / 100, self.dataset.ori_shape[1] / 100
        num_columns, margin = 3, 0.05
        width = cell_size_w * num_columns
        height = cell_size_h * x_imgs.shape[0]
        fig = plt.figure(figsize=(width, height))  # (column, row)
        gs = gridspec.GridSpec(x_imgs.shape[0], num_columns)  # (row, column)
        gs.update(wspace=margin, hspace=margin)

        # convert from normalized to original image
        x_imgs_norm = np.zeros_like(x_imgs)
        std, mean = 0., 0.
        for _ in range(x_imgs.shape[0]):
            if phase == 'train':
                std = self.dataset.val_mean_std[idx[_]]['std']
                mean = self.dataset.val_mean_std[idx[_]]['mean']
            elif phase == 'test':
                std = self.dataset.test_mean_std[idx[_]]['std']
                mean = self.dataset.test_mean_std[idx[_]]['mean']
            x_imgs_norm[_] = np.expand_dims(x_imgs[_], axis=0) * std + mean
        x_imgs_norm = x_imgs_norm.astype(np.uint8)

        # 1 channel to 3 channels
        samples_3 = np.stack((samples, samples, samples), axis=3)
        y_imgs_3 = np.stack((y_imgs, y_imgs, y_imgs), axis=3)

        imgs = [x_imgs_norm, samples_3, y_imgs_3]
        for col_index in range(len(imgs)):
            for row_index in range(x_imgs.shape[0]):
                ax = plt.subplot(gs[row_index * num_columns + col_index])
                plt.axis('off')
                ax.set_xticklabels([])
                ax.set_yticklabels([])
                ax.set_aspect('equal')
                plt.imshow(imgs[col_index][row_index].reshape(
                    self.dataset.ori_shape[0], self.dataset.ori_shape[1], 3), cmap='Greys_r')

        if phase == 'train':
            plt.savefig(save_file + '/{}_{}.png'.format(str(iter_time), idx[0]), bbox_inches='tight')
            plt.close(fig)
        else:
            # save compared image
            plt.savefig(os.path.join(save_file, 'compared_{}.png'.format(os.path.basename(
                self.dataset.test_img_files[idx[0]])[:-4])), bbox_inches='tight')
            plt.close(fig)

            # save vessel alone, vessel should be uint8 type
            Image.fromarray(np.squeeze(samples*255).astype(np.uint8)).save(os.path.join(
                save_file, '{}.png'.format(os.path.basename(self.dataset.test_img_files[idx[0]][:-4]))))

    def print_info(self, iter_time, name, loss):
        if np.mod(iter_time, self.flags.print_freq) == 0:
            ord_output = collections.OrderedDict([(name, loss), ('dataset', self.flags.dataset),
                                                  ('discriminator', self.flags.discriminator),
                                                  ('train_interval', np.float32(self.flags.train_interval)),
                                                  ('gpu_index', self.flags.gpu_index)])
            print_metrics(iter_time, ord_output)

    def eval(self, iter_time=0, phase='train'):
        total_time, auc_sum = 0., 0.
        if np.mod(iter_time, self.flags.eval_freq) == 0:
            num_data, imgs, vessels, masks = None, None, None, None
            if phase == 'train':
                num_data = self.dataset.num_val
                imgs = self.dataset.val_imgs
                vessels = self.dataset.val_vessels
                masks = self.dataset.val_masks
            elif phase == 'test':
                num_data = self.dataset.num_test
                imgs = self.dataset.test_imgs
                vessels = self.dataset.test_vessels
                masks = self.dataset.test_masks

            generated = []
            for iter_ in range(num_data):
                x_img = imgs[iter_]
                x_img = np.expand_dims(x_img, axis=0)  # (H, W, C) to (1, H, W, C)

                # measure inference time
                start_time = time.time()
                generated_vessel = self.model.sample_imgs(x_img)
                total_time += (time.time() - start_time)

                generated.append(np.squeeze(generated_vessel, axis=(0, 3)))  # (1, H, W, 1) to (H, W)

            generated = np.asarray(generated)
            # calculate measurements
            auc_sum = self.measure(generated, vessels, masks, num_data, iter_time, phase, total_time)

            if phase == 'test':
                # save test images
                segmented_vessel = remain_in_mask(generated, masks)

                # crop to original image shape
                imgs_ = crop_to_original(imgs, self.dataset.ori_shape)
                cropped_vessel = crop_to_original(segmented_vessel, self.dataset.ori_shape)
                vessels_ = crop_to_original(vessels, self.dataset.ori_shape)

                for idx in range(num_data):
                    self.plot(np.expand_dims(imgs_[idx], axis=0),
                              np.expand_dims(cropped_vessel[idx], axis=0),
                              np.expand_dims(vessels_[idx], axis=0),
                              'test', idx=[idx], save_file=self.img_out_dir, phase='test')

        return auc_sum

    def measure(self, generated, vessels, masks, num_data, iter_time, phase, total_time):
        # masking
        vessels_in_mask, generated_in_mask = pixel_values_in_mask(
            vessels, generated, masks)

        # averaging processing time
        avg_pt = (total_time / num_data) * 1000  # average processing tiem

        # evaluate Area Under the Curve of ROC and Precision-Recall
        auc_roc = AUC_ROC(vessels_in_mask, generated_in_mask)
        auc_pr = AUC_PR(vessels_in_mask, generated_in_mask)

        # binarize to calculate Dice Coeffient
        binarys_in_mask = threshold_by_otsu(generated, masks)
        dice_coeff = dice_coefficient_in_train(vessels_in_mask, binarys_in_mask)
        acc, sensitivity, specificity = misc_measures(vessels_in_mask, binarys_in_mask)
        score = auc_pr + auc_roc + dice_coeff + acc + sensitivity + specificity

        # auc_sum for saving best model in training
        auc_sum = auc_roc + auc_pr

        # print information
        ord_output = collections.OrderedDict([('auc_pr', auc_pr), ('auc_roc', auc_roc),
                                              ('dice_coeff', dice_coeff), ('acc', acc),
                                              ('sensitivity', sensitivity), ('specificity', specificity),
                                              ('score', score), ('auc_sum', auc_sum),
                                              ('best_auc_sum', self.best_auc_sum), ('avg_pt', avg_pt)])
        if phase != 'test':
            print_metrics(iter_time, ord_output)

        # write in tensorboard when in train mode only
        if phase == 'train':
            self.model.measure_assign(
                auc_pr, auc_roc, dice_coeff, acc, sensitivity, specificity, score, iter_time)
        elif phase == 'test':
            # write in npy format for evaluation
            save_obj(vessels_in_mask, generated_in_mask,
                           os.path.join(self.auc_out_dir, "auc_roc.npy"),
                           os.path.join(self.auc_out_dir, "auc_pr.npy"))

        return auc_sum

    def save_model(self, iter_time):
        self.model.best_auc_sum_assign(self.best_auc_sum)

        model_name = "iter_{}_auc_sum_{:.3}".format(iter_time, self.best_auc_sum)
        self.saver.save(self.sess, os.path.join(self.model_out_dir, model_name))

        print('===================================================')
        print('                     Model saved!                  ')
        print(' Best auc_sum: {:.3}'.format(self.best_auc_sum))
        print('===================================================\n')

    def load_model(self):
        print(' [*] Reading checkpoint...')

        ckpt = tf.train.get_checkpoint_state(self.model_out_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.saver.restore(self.sess, os.path.join(self.model_out_dir, ckpt_name))

            self.best_auc_sum = self.sess.run(self.model.best_auc_sum)
            print('====================================================')
            print('                     Model saved!                   ')
            #print(' Best auc_sum: {:.3}'.format(self.best_auc_sum))
            print('====================================================')

            return True
        else:
            return False

In [5]:
class CGAN(object):
    def __init__(self, sess, flags, image_size):
        self.sess = sess
        self.flags = flags
        self.image_size = image_size

        self.alpha_recip = 1. / self.flags.ratio_gan2seg if self.flags.ratio_gan2seg > 0 else 0
        self._gen_train_ops, self._dis_train_ops = [], []
        self.gen_c, self.dis_c = 32, 32

        self._build_net()
        self._init_assign_op()  # initialize assign operations

        print('Initialized CGAN SUCCESS!\n')

    def _build_net(self):
        self.X = tf.placeholder(tf.float32, shape=[None, *self.image_size, 3], name='image')
        self.Y = tf.placeholder(tf.float32, shape=[None, *self.image_size, 1], name='vessel')

        self.g_samples = self.generator(self.X)
        self.real_pair = tf.concat([self.X, self.Y], axis=3)
        self.fake_pair = tf.concat([self.X, self.g_samples], axis=3)

        d_real, d_logit_real = self.discriminator(self.real_pair)
        d_fake, d_logit_fake = self.discriminator(self.fake_pair, is_reuse=True)

        # discrminator loss
        self.d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            logits=d_logit_real, labels=tf.ones_like(d_real)))
        self.d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            logits=d_logit_fake, labels=tf.zeros_like(d_logit_fake)))
        self.d_loss = self.d_loss_real + self.d_loss_fake

        # generator loss
        gan_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            logits=d_logit_fake, labels=tf.ones_like(d_logit_fake)))
        seg_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.g_samples, labels=self.Y))
        self.g_loss = self.alpha_recip * gan_loss + seg_loss

        t_vars = tf.trainable_variables()
        d_vars = [var for var in t_vars if 'd_' in var.name]
        g_vars = [var for var in t_vars if 'g_' in var.name]

        dis_op = tf.train.AdamOptimizer(learning_rate=self.flags.learning_rate, beta1=self.flags.beta1)\
            .minimize(self.d_loss, var_list=d_vars)
        dis_ops = [dis_op] + self._dis_train_ops
        self.dis_optim = tf.group(*dis_ops)

        gen_op = tf.train.AdamOptimizer(learning_rate=self.flags.learning_rate, beta1=self.flags.beta1)\
            .minimize(self.g_loss, var_list=g_vars)
        gen_ops = [gen_op] + self._gen_train_ops
        self.gen_optim = tf.group(*gen_ops)

    def _init_assign_op(self):
        self.best_auc_sum_placeholder = tf.placeholder(tf.float32, name='best_auc_sum_placeholder')
        self.auc_pr_placeholder = tf.placeholder(tf.float32, name='auc_pr_placeholder')
        self.auc_roc_placeholder = tf.placeholder(tf.float32, name='auc_roc_placeholder')
        self.dice_coeff_placeholder = tf.placeholder(tf.float32, name='dice_coeff_placeholder')
        self.acc_placeholder = tf.placeholder(tf.float32, name='acc_placeholder')
        self.sensitivity_placeholder = tf.placeholder(tf.float32, name='sensitivity_placeholder')
        self.specificity_placeholder = tf.placeholder(tf.float32, name='specificity_placeholder')
        self.score_placeholder = tf.placeholder(tf.float32, name='score_placeholder')

        self.best_auc_sum = tf.Variable(0., trainable=False, dtype=tf.float32, name='best_auc_sum')
        auc_pr = tf.Variable(0., trainable=False, dtype=tf.float32, name='auc_pr')
        auc_roc = tf.Variable(0., trainable=False, dtype=tf.float32, name='auc_roc')
        dice_coeff = tf.Variable(0., trainable=False, dtype=tf.float32, name='dice_coeff')
        acc = tf.Variable(0., trainable=False, dtype=tf.float32, name='acc')
        sensitivity = tf.Variable(0., trainable=False, dtype=tf.float32, name='sensitivity')
        specificity = tf.Variable(0., trainable=False, dtype=tf.float32, name='specificity')
        score = tf.Variable(0., trainable=False, dtype=tf.float32, name='score')

        self.best_auc_sum_assign_op = self.best_auc_sum.assign(self.best_auc_sum_placeholder)
        auc_pr_assign_op = auc_pr.assign(self.auc_pr_placeholder)
        auc_roc_assign_op = auc_roc.assign(self.auc_roc_placeholder)
        dice_coeff_assign_op = dice_coeff.assign(self.dice_coeff_placeholder)
        acc_assign_op = acc.assign(self.acc_placeholder)
        sensitivity_assign_op = sensitivity.assign(self.sensitivity_placeholder)
        specificity_assign_op = specificity.assign(self.specificity_placeholder)
        score_assign_op = score.assign(self.score_placeholder)

        self.measure_assign_op = tf.group(auc_pr_assign_op, auc_roc_assign_op, dice_coeff_assign_op,
                                          acc_assign_op, sensitivity_assign_op, specificity_assign_op,
                                          score_assign_op)

        # for tensorboard
        if not self.flags.is_test:
            self.writer = tf.summary.FileWriter("{}/logs/{}_{}_{}".format(
                self.flags.dataset, self.flags.discriminator, self.flags.train_interval, self.flags.batch_size))

        auc_pr_summ = tf.summary.scalar("auc_pr_summary", auc_pr)
        auc_roc_summ = tf.summary.scalar("auc_roc_summary", auc_roc)
        dice_coeff_summ = tf.summary.scalar("dice_coeff_summary", dice_coeff)
        acc_summ = tf.summary.scalar("acc_summary", acc)
        sensitivity_summ = tf.summary.scalar("sensitivity_summary", sensitivity)
        specificity_summ = tf.summary.scalar("specificity_summary", specificity)
        score_summ = tf.summary.scalar("score_summary", score)

        self.measure_summary = tf.summary.merge([auc_pr_summ, auc_roc_summ, dice_coeff_summ, acc_summ,
                                                 sensitivity_summ, specificity_summ, score_summ])

    def generator(self, data, name='g_'):
        
        with tf.variable_scope(name):
            # conv1: (N, 640, 640, 1) -> (N, 320, 320, 32)
            conv1 = conv2d(data, self.gen_c, k_h=3, k_w=3, d_h=1, d_w=1, name='conv1_conv1')
            conv1 = batch_norm(conv1, name='conv1_batch1', _ops=self._gen_train_ops)
            conv1 = tf.nn.relu(conv1, name='conv1_relu1')
            conv1 = conv2d(conv1, self.gen_c, k_h=3, k_w=3, d_h=1, d_w=1, name='conv1_conv2')
            conv1 = batch_norm(conv1, name='conv1_batch2', _ops=self._gen_train_ops)
            conv1 = tf.nn.relu(conv1, name='conv1_relu2')
            pool1 = max_pool_2x2(conv1, name='maxpool1')

            # conv2: (N, 320, 320, 32) -> (N, 160, 160, 64)
            conv2 = conv2d(pool1, 2*self.gen_c, k_h=3, k_w=3, d_h=1, d_w=1, name='conv2_conv1')
            conv2 = batch_norm(conv2, name='conv2_batch1', _ops=self._gen_train_ops)
            conv2 = tf.nn.relu(conv2, name='conv2_relu1')
            conv2 = conv2d(conv2, 2*self.gen_c, k_h=3, k_w=3, d_h=1, d_w=1, name='conv2_conv2')
            conv2 = batch_norm(conv2, name='conv2-batch2', _ops=self._gen_train_ops)
            conv2 = tf.nn.relu(conv2, name='conv2_relu2')
            pool2 = max_pool_2x2(conv2, name='maxpool2')

            # conv3: (N, 160, 160, 64) -> (N, 80, 80, 128)
            conv3 = conv2d(pool2, 4*self.gen_c, k_h=3, k_w=3, d_h=1, d_w=1, name='conv3_conv1')
            conv3 = batch_norm(conv3, name='conv3_batch1', _ops=self._gen_train_ops)
            conv3 = tf.nn.relu(conv3, name='conv3_relu1')
            conv3 = conv2d(conv3, 4*self.gen_c, k_h=3, k_w=3, d_h=1, d_w=1, name='conv3_conv2')
            conv3 = batch_norm(conv3, name='conv3_batch2', _ops=self._gen_train_ops)
            conv3 = tf.nn.relu(conv3, name='conv3_relu2')
            pool3 = max_pool_2x2(conv3, name='maxpool3')

            # conv4: (N, 80, 80, 128) -> (N, 40, 40, 256)
            conv4 = conv2d(pool3, 8*self.gen_c, k_h=3, k_w=3, d_h=1, d_w=1, name='conv4_conv1')
            conv4 = batch_norm(conv4, name='conv4_batch1', _ops=self._gen_train_ops)
            conv4 = tf.nn.relu(conv4, name='conv4_relu1')
            conv4 = conv2d(conv4, 8*self.gen_c, k_h=3, k_w=3, d_h=1, d_w=1, name='conv4_conv2')
            conv4 = batch_norm(conv4, name='conv4_batch2', _ops=self._gen_train_ops)
            conv4 = tf.nn.relu(conv4, name='conv4_relu2')
            pool4 = max_pool_2x2(conv4, name='maxpool4')

            # conv5: (N, 40, 40, 256) -> (N, 40, 40, 512)
            conv5 = conv2d(pool4, 16*self.gen_c, k_h=3, k_w=3, d_h=1, d_w=1, name='conv5_conv1')
            conv5 = batch_norm(conv5, name='conv5_batch1', _ops=self._gen_train_ops)
            conv5 = tf.nn.relu(conv5, name='conv5_relu1')
            conv5 = conv2d(conv5, 16*self.gen_c, k_h=3, k_w=3, d_h=1, d_w=1, name='conv5_conv2')
            conv5 = batch_norm(conv5, name='conv5_batch2', _ops=self._gen_train_ops)
            conv5 = tf.nn.relu(conv5, name='conv5_relu2')

            # conv6: (N, 40, 40, 512) -> (N, 80, 80, 256)
            up1 = upsampling2d(conv5, size=(2, 2), name='conv6_up')
            conv6 = tf.concat([up1, conv4], axis=3, name='conv6_concat')
            conv6 = conv2d(conv6, 8*self.gen_c, k_h=3, k_w=3, d_h=1, d_w=1, name='conv6_conv1')
            conv6 = batch_norm(conv6, name='conv6_batch1', _ops=self._gen_train_ops)
            conv6 = tf.nn.relu(conv6, name='conv6_relu1')
            conv6 = conv2d(conv6, 8*self.gen_c, k_h=3, k_w=3, d_h=1, d_w=1, name='conv6_conv2')
            conv6 = batch_norm(conv6, name='conv6_batch2', _ops=self._gen_train_ops)
            conv6 = tf.nn.relu(conv6, name='conv6_relu2')

            # conv7: (N, 80, 80, 256) -> (N, 160, 160, 128)
            up2 = upsampling2d(conv6, size=(2, 2), name='conv7_up')
            conv7 = tf.concat([up2, conv3], axis=3, name='conv7_concat')
            conv7 = conv2d(conv7, 4*self.gen_c, k_h=3, k_w=3, d_h=1, d_w=1, name='conv7_conv1')
            conv7 = batch_norm(conv7, name='conv7_batch1', _ops=self._gen_train_ops)
            conv7 = tf.nn.relu(conv7, name='conv7_relu1')
            conv7 = conv2d(conv7, 4*self.gen_c, k_h=3, k_w=3, d_h=1, d_w=1, name='conv7_conv2')
            conv7 = batch_norm(conv7, name='conv7_batch2', _ops=self._gen_train_ops)
            conv7 = tf.nn.relu(conv7, name='conv7_relu2')

            # conv8: (N, 160, 160, 128) -> (N, 320, 320, 64)
            up3 = upsampling2d(conv7, size=(2, 2), name='conv8_up')
            conv8 = tf.concat([up3, conv2], axis=3, name='conv8_concat')
            conv8 = conv2d(conv8, 2*self.gen_c, k_h=3, k_w=3, d_h=1, d_w=1, name='conv8_conv1')
            conv8 = batch_norm(conv8, name='conv8_batch1', _ops=self._gen_train_ops)
            conv8 = tf.nn.relu(conv8, name='conv8_relu1')
            conv8 = conv2d(conv8, 2*self.gen_c, k_h=3, k_w=3, d_h=1, d_w=1, name='conv8_conv2')
            conv8 = batch_norm(conv8, name='conv8_batch2', _ops=self._gen_train_ops)
            conv8 = tf.nn.relu(conv8, name='conv8_relu2')

            # conv9: (N, 320, 320, 64) -> (N, 640, 640, 32)
            up4 = upsampling2d(conv8, size=(2, 2), name='conv9_up')
            conv9 = tf.concat([up4, conv1], axis=3, name='conv9_concat')
            conv9 = conv2d(conv9, self.gen_c, k_h=3, k_w=3, d_h=1, d_w=1, name='conv9_conv1')
            conv9 = batch_norm(conv9, name='conv9_batch1', _ops=self._gen_train_ops)
            conv9 = tf.nn.relu(conv9, name='conv9_relu1')
            conv9 = conv2d(conv9, self.gen_c, k_h=3, k_w=3, d_h=1, d_w=1, name='conv9_conv2')
            conv9 = batch_norm(conv9, name='conv9_batch2', _ops=self._gen_train_ops)
            conv9 = tf.nn.relu(conv9, name='conv9_relu2')

            # output layer: (N, 640, 640, 32) -> (N, 640, 640, 1)
            output = conv2d(conv9, 1, k_h=1, k_w=1, d_h=1, d_w=1, name='conv_output')

            return tf.nn.sigmoid(output)

    def discriminator(self, data, is_reuse=False):
        if self.flags.discriminator == 'image':
            return self.discriminator_image(data, is_reuse=is_reuse)
        else:
            raise NotImplementedError






    def discriminator_image(self, data, name='d_', is_reuse=False):
        with tf.variable_scope(name) as scope:
            if is_reuse is True:
                scope.reuse_variables()

            # conv1: (N, 640, 640, 4) -> (N,, 160, 160, 32)
            conv1 = conv2d(data, self.dis_c, k_h=3, k_w=3, d_h=2, d_w=2, name='conv1_conv1')
            conv1 = batch_norm(conv1, name='conv1_batch1', _ops=self._dis_train_ops)
            conv1 = tf.nn.relu(conv1, name='conv1_relu1')
            conv1 = conv2d(conv1, self.dis_c, k_h=3, k_w=3, d_h=1, d_w=1, name='conv1_conv2')
            conv1 = batch_norm(conv1, name='conv1_batch2', _ops=self._dis_train_ops)
            conv1 = tf.nn.relu(conv1, name='conv1_relu2')
            pool1 = max_pool_2x2(conv1, name='maxpool1')

            # conv2: (N, 160, 160, 32) -> (N, 40, 40, 64)
            conv2 = conv2d(pool1, 2*self.dis_c, k_h=3, k_w=3, d_h=2, d_w=2, name='conv2_conv1')
            conv2 = batch_norm(conv2, name='conv2_batch1', _ops=self._dis_train_ops)
            conv2 = tf.nn.relu(conv2, name='conv2_relu1')
            conv2 = conv2d(conv2, 2*self.dis_c, k_h=3, k_w=3, d_h=1, d_w=1, name='conv2_conv2')
            conv2 = batch_norm(conv2, name='conv2_batch2', _ops=self._dis_train_ops)
            conv2 = tf.nn.relu(conv2, name='conv2_relu2')
            pool2 = max_pool_2x2(conv2, name='maxpool2')

            # conv3: (N, 40, 40, 64) -> (N, 20, 20, 128)
            conv3 = conv2d(pool2, 4*self.dis_c, k_h=3, k_w=3, d_h=1, d_w=1, name='conv3_conv1')
            conv3 = batch_norm(conv3, name='conv3_batch1', _ops=self._dis_train_ops)
            conv3 = tf.nn.relu(conv3, name='conv3_relu1')
            conv3 = conv2d(conv3, 4*self.dis_c, k_h=3, k_w=3, d_h=1, d_w=1, name='conv3_conv2')
            conv3 = batch_norm(conv3, name='conv3_batch2', _ops=self._dis_train_ops)
            conv3 = tf.nn.relu(conv3, name='conv3_relu2')
            pool3 = max_pool_2x2(conv3, name='maxpool3')

            # conv4: (N, 20, 20, 128) -> (N, 10, 10, 256)
            conv4 = conv2d(pool3, 8*self.dis_c, k_h=3, k_w=3, d_h=1, d_w=1, name='conv4_conv1')
            conv4 = batch_norm(conv4, name='conv4_batch1', _ops=self._dis_train_ops)
            conv4 = tf.nn.relu(conv4, name='conv4_relu1')
            conv4 = conv2d(conv4, 8*self.dis_c, k_h=3, k_w=3, d_h=1, d_w=1, name='conv4_conv2')
            conv4 = batch_norm(conv4, name='conv4_batch2', _ops=self._dis_train_ops)
            conv4 = tf.nn.relu(conv4, name='conv4_relu2')
            pool4 = max_pool_2x2(conv4, name='maxpool4')

            # conv5: (N, 10, 10, 256) -> (N, 10, 10, 512)
            conv5 = conv2d(pool4, 16*self.dis_c, k_h=3, k_w=3, d_h=1, d_w=1, name='conv5_conv1')
            conv5 = batch_norm(conv5, name='conv5_batch1', _ops=self._dis_train_ops)
            conv5 = tf.nn.relu(conv5, name='conv5_relu1')
            conv5 = conv2d(conv5, 16*self.dis_c, k_h=3, k_w=3, d_h=1, d_w=1, name='conv5_conv2')
            conv5 = batch_norm(conv5, name='conv5_batch2', _ops=self._dis_train_ops)
            conv5 = tf.nn.relu(conv5, name='conv5_relu2')

            # output layer: (N, 10, 10, 512) -> (N, 1, 1, 512) -> (N, 1)
            shape = conv5.get_shape().as_list()
            gap = tf.layers.average_pooling2d(inputs=conv5, pool_size=shape[1], strides=1, padding='VALID',
                                              name='global_vaerage_pool')
            gap_flatten = tf.reshape(gap, [-1, 16*self.dis_c])
            output = linear(gap_flatten, 1, name='linear_output')

            return tf.nn.sigmoid(output), output

    def train_dis(self, x_data, y_data):
        feed_dict = {self.X: x_data, self.Y: y_data}
        # run discriminator
        _, d_loss = self.sess.run([self.dis_optim, self.d_loss], feed_dict=feed_dict)

        return d_loss

    def train_gen(self, x_data, y_data):
        feed_dict = {self.X: x_data, self.Y: y_data}
        # run generator
        _, g_loss = self.sess.run([self.gen_optim, self.g_loss], feed_dict=feed_dict)

        return g_loss

    def measure_assign(self, auc_pr, auc_roc, dice_coeff, acc, sensitivity, specificity, score, iter_time):
        feed_dict = {self.auc_pr_placeholder: auc_pr,
                     self.auc_roc_placeholder: auc_roc,
                     self.dice_coeff_placeholder: dice_coeff,
                     self.acc_placeholder: acc,
                     self.sensitivity_placeholder: sensitivity,
                     self.specificity_placeholder: specificity,
                     self.score_placeholder: score}

        self.sess.run(self.measure_assign_op, feed_dict=feed_dict)

        summary = self.sess.run(self.measure_summary)
        self.writer.add_summary(summary, iter_time)

    def best_auc_sum_assign(self, auc_sum):
        self.sess.run(self.best_auc_sum_assign_op, feed_dict={self.best_auc_sum_placeholder: auc_sum})

    def sample_imgs(self, x_data):
        return self.sess.run(self.g_samples, feed_dict={self.X: x_data})

In [7]:
FLAGS = tf.flags.FLAGS

tf.flags.DEFINE_integer('train_interval', 10000, 'training interval between discriminator and generator, default: 1')
tf.flags.DEFINE_integer('ratio_gan2seg', 10, 'ratio of gan loss to seg loss, default: 10')
tf.flags.DEFINE_string('gpu_index', '0', 'gpu index, default: 0')
tf.flags.DEFINE_string('discriminator', 'image', 'type of discriminator [pixel|patch1|patch2|image], '
                                                 'default: image')

tf.flags.DEFINE_integer('batch_size', 1, 'batch size, default: 1')
tf.flags.DEFINE_string('dataset', dataset, 'dataset name [DRIVE|STARE], default: STARE')
tf.flags.DEFINE_bool('is_test', is_test, 'default: False (train)')

tf.flags.DEFINE_float('learning_rate', 2e-4, 'initial learning rate for Adam, default: 2e-4')
tf.flags.DEFINE_float('beta1', 0.5, 'momentum term of adam, default: 0.5')
tf.flags.DEFINE_integer('iters', 50000, 'number of iteratons, default: 50000')
tf.flags.DEFINE_integer('print_freq', 100, 'print frequency, default: 100')
tf.flags.DEFINE_integer('eval_freq', 500, 'evaluation frequency, default: 500')
tf.flags.DEFINE_integer('sample_freq', 200, 'sample frequency, default: 200')

tf.flags.DEFINE_string('checkpoint_dir', './checkpoints', 'models are saved here')
tf.flags.DEFINE_string('sample_dir', './sample', 'sample are saved here')
tf.flags.DEFINE_string('test_dir', './test', 'test images are saved here')


def main(_):
    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_index

    solver = Solver(FLAGS)
    if FLAGS.is_test:
        solver.test()
    if not FLAGS.is_test:
        solver.train()


if __name__ == '__main__':
    tf.app.run()


Val data augmentation 3 degree...
Val data augmentation 6 degree...
Val data augmentation 9 degree...
Val data augmentation 12 degree...
Val data augmentation 15 degree...
Val data augmentation 18 degree...
Val data augmentation 21 degree...
Val data augmentation 24 degree...
Val data augmentation 27 degree...
Val data augmentation 30 degree...
Val data augmentation 33 degree...
Val data augmentation 36 degree...
Val data augmentation 39 degree...
Val data augmentation 42 degree...
Val data augmentation 45 degree...
Val data augmentation 48 degree...
Val data augmentation 51 degree...
Val data augmentation 54 degree...
Val data augmentation 57 degree...
Val data augmentation 60 degree...
Val data augmentation 63 degree...
Val data augmentation 66 degree...
Val data augmentation 69 degree...
Val data augmentation 72 degree...
Val data augmentation 75 degree...
Val data augmentation 78 degree...
Val data augmentation 81 degree...
Val data augmentation 84 degree...
Val data augmentation 8

W0603 10:37:13.986487 140267753264960 deprecation.py:506] From <ipython-input-1-4b36038312c1>:51: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor


Instructions for updating:
Use keras.layers.AveragePooling2D instead.


W0603 10:37:15.625997 140267753264960 deprecation.py:323] From <ipython-input-5-54453e6564db>:257: average_pooling2d (from tensorflow.python.layers.pooling) is deprecated and will be removed in a future version.
Instructions for updating:
Use keras.layers.AveragePooling2D instead.




W0603 10:37:15.665317 140267753264960 ag_logging.py:145] Entity <bound method Pooling2D.call of <tensorflow.python.layers.pooling.AveragePooling2D object at 0x7f90f4677e10>> could not be transformed and will be executed as-is. Please report this to the AutgoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: converting <bound method Pooling2D.call of <tensorflow.python.layers.pooling.AveragePooling2D object at 0x7f90f4677e10>>: AssertionError: Bad argument number for Name: 3, expecting 4




W0603 10:37:16.060556 140267753264960 ag_logging.py:145] Entity <bound method Pooling2D.call of <tensorflow.python.layers.pooling.AveragePooling2D object at 0x7f90f40796d8>> could not be transformed and will be executed as-is. Please report this to the AutgoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: converting <bound method Pooling2D.call of <tensorflow.python.layers.pooling.AveragePooling2D object at 0x7f90f40796d8>>: AssertionError: Bad argument number for Name: 3, expecting 4


Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


W0603 10:37:16.072754 140267753264960 deprecation.py:323] From /opt/conda/lib/python3.7/site-packages/tensorflow/python/ops/nn_impl.py:180: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


Initialized CGAN SUCCESS!

---------
Variables: name (type shape) [size]
---------
g_/conv1_conv1/w:0 (float32_ref 3x3x3x32) [864, bytes: 3456]
g_/conv1_conv1/biases:0 (float32_ref 32) [32, bytes: 128]
g_/conv1_batch1/beta:0 (float32_ref 32) [32, bytes: 128]
g_/conv1_batch1/gamma:0 (float32_ref 32) [32, bytes: 128]
g_/conv1_conv2/w:0 (float32_ref 3x3x32x32) [9216, bytes: 36864]
g_/conv1_conv2/biases:0 (float32_ref 32) [32, bytes: 128]
g_/conv1_batch2/beta:0 (float32_ref 32) [32, bytes: 128]
g_/conv1_batch2/gamma:0 (float32_ref 32) [32, bytes: 128]
g_/conv2_conv1/w:0 (float32_ref 3x3x32x64) [18432, bytes: 73728]
g_/conv2_conv1/biases:0 (float32_ref 64) [64, bytes: 256]
g_/conv2_batch1/beta:0 (float32_ref 64) [64, bytes: 256]
g_/conv2_batch1/gamma:0 (float32_ref 64) [64, bytes: 256]
g_/conv2_conv2/w:0 (float32_ref 3x3x64x64) [36864, bytes: 147456]
g_/conv2_conv2/biases:0 (float32_ref 64) [64, bytes: 256]
g_/conv2-batch2/beta:0 (float32_ref 64) [64, bytes: 256]
g_/conv2-batch2/gamma:0 (fl

*** Iteration 900  ====> 
d_loss : 0.00568102, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 1000  ====> 
d_loss : 0.00491966, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 1100  ====> 
d_loss : 0.00397274, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 1200  ====> 
d_loss : 0.00221968, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 1300  ====> 
d_loss : 0.00197458, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 1400  ====> 
d_loss : 0.0026956, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 1500  ====> 
d_loss : 0.00162501, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 1600  ====> 
d_loss : 0.00139975, 
dataset 

*** Iteration 7000  ====> 
d_loss : 2.85976e-05, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 7100  ====> 
d_loss : 2.57716e-05, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 7200  ====> 
d_loss : 2.98513e-05, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 7300  ====> 
d_loss : 2.807e-05, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 7400  ====> 
d_loss : 4.53784e-05, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 7500  ====> 
d_loss : 2.63127e-05, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 7600  ====> 
d_loss : 2.25886e-05, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 7700  ====> 
d_loss : 2.11208e-05, 


*** Iteration 3200  ====> 
g_loss : 0.683178, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 3300  ====> 
g_loss : 0.684151, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 3400  ====> 
g_loss : 0.683363, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 3500  ====> 
g_loss : 0.687087, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 3600  ====> 
g_loss : 0.687259, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 3700  ====> 
g_loss : 0.682311, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 3800  ====> 
g_loss : 0.683271, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 3900  ====> 
g_loss : 0.682605, 
dataset : STARE, 
disc

*** Iteration 9400  ====> 
g_loss : 0.671236, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 9500  ====> 
g_loss : 0.674561, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 9600  ====> 
g_loss : 0.675314, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 9700  ====> 
g_loss : 0.674317, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 9800  ====> 
g_loss : 0.674025, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 9900  ====> 
g_loss : 0.677705, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 10000  ====> 
g_loss : 0.677874, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 0  ====> 
auc_pr : 0.835318, 
auc_roc : 0.916335, 
dic

*** Iteration 15300  ====> 
d_loss : 1.29218, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 15400  ====> 
d_loss : 0.485908, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 15500  ====> 
d_loss : 1.33312, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 15600  ====> 
d_loss : 1.20328, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 15700  ====> 
d_loss : 1.26376, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 15800  ====> 
d_loss : 0.372778, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 15900  ====> 
d_loss : 1.38576, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 16000  ====> 
d_loss : 0.700514, 
dataset : STARE, 
d

*** Iteration 11500  ====> 
g_loss : 0.958285, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 11600  ====> 
g_loss : 0.756222, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 11700  ====> 
g_loss : 0.741216, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 11800  ====> 
g_loss : 0.734734, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 11900  ====> 
g_loss : 0.735041, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 12000  ====> 
g_loss : 0.734903, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 12100  ====> 
g_loss : 0.731986, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 12200  ====> 
g_loss : 0.731907, 
dataset : STAR

*** Iteration 17700  ====> 
g_loss : 0.727762, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 17800  ====> 
g_loss : 0.73304, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 17900  ====> 
g_loss : 0.736391, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 18000  ====> 
g_loss : 0.74523, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 18100  ====> 
g_loss : 0.732914, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 18200  ====> 
g_loss : 0.738308, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 18300  ====> 
g_loss : 0.720488, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 18400  ====> 
g_loss : 0.733294, 
dataset : STARE,

*** Iteration 23600  ====> 
d_loss : 0.00151901, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 23700  ====> 
d_loss : 0.000991777, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 23800  ====> 
d_loss : 0.0012143, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 23900  ====> 
d_loss : 0.00192199, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 24000  ====> 
d_loss : 0.00109075, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 24100  ====> 
d_loss : 0.00185757, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 24200  ====> 
d_loss : 0.000529213, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 24300  ====> 
d_loss : 0.00106041

*** Iteration 29600  ====> 
d_loss : 8.15312e-05, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 29700  ====> 
d_loss : 5.55476e-05, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 29800  ====> 
d_loss : 4.65137e-05, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 29900  ====> 
d_loss : 4.9744e-05, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 30000  ====> 
d_loss : 4.219e-05, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 20100  ====> 
g_loss : 0.726075, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 20200  ====> 
g_loss : 0.712968, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 20300  ====> 
g_loss : 0.707517, 
da

*** Iteration 25800  ====> 
g_loss : 0.694025, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 25900  ====> 
g_loss : 0.702159, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 26000  ====> 
g_loss : 0.713539, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 26100  ====> 
g_loss : 0.707968, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 26200  ====> 
g_loss : 0.708112, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 26300  ====> 
g_loss : 0.701051, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 26400  ====> 
g_loss : 0.701077, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 26500  ====> 
g_loss : 0.694497, 
dataset : STAR

*** Iteration 31800  ====> 
d_loss : 0.00218849, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 31900  ====> 
d_loss : 0.00256903, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 32000  ====> 
d_loss : 0.00325117, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 32100  ====> 
d_loss : 0.00159226, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 32200  ====> 
d_loss : 0.00152302, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 32300  ====> 
d_loss : 0.00539388, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 32400  ====> 
d_loss : 0.00171799, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 32500  ====> 
d_loss : 0.00122729,

*** Iteration 37800  ====> 
d_loss : 4.6756e-05, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 37900  ====> 
d_loss : 1.7731e-05, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 38000  ====> 
d_loss : 4.22761e-05, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 38100  ====> 
d_loss : 2.27272e-05, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 38200  ====> 
d_loss : 9.28904e-05, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 38300  ====> 
d_loss : 6.06826e-05, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 38400  ====> 
d_loss : 0.00015412, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 38500  ====> 
d_loss : 1.96876

*** Iteration 33900  ====> 
g_loss : 0.704257, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 34000  ====> 
g_loss : 0.700578, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 34100  ====> 
g_loss : 0.693829, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 34200  ====> 
g_loss : 0.696902, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 34300  ====> 
g_loss : 0.700348, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 34400  ====> 
g_loss : 0.69448, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 34500  ====> 
g_loss : 0.696056, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 34600  ====> 
g_loss : 0.691415, 
dataset : STARE

*** Iteration 30000  ====> 
auc_pr : 0.584026, 
auc_roc : 0.832512, 
dice_coeff : 0.578868, 
acc : 0.919456, 
sensitivity : 0.507643, 
specificity : 0.969858, 
score : 4.39236, 
auc_sum : 1.41654, 
best_auc_sum : 1.75165, 
avg_pt : 43.5879, 

*** Iteration 40100  ====> 
d_loss : 0.66683, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 40200  ====> 
d_loss : 0.622874, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 40300  ====> 
d_loss : 0.0732093, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 40400  ====> 
d_loss : 0.00393473, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 40500  ====> 
d_loss : 0.00313979, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 40600  ====> 
d_loss : 0.00385574, 
dataset : STARE, 
discriminator :

*** Iteration 46000  ====> 
d_loss : 8.40614e-05, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 46100  ====> 
d_loss : 5.5981e-05, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 46200  ====> 
d_loss : 3.9321e-05, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 46300  ====> 
d_loss : 0.000155133, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 46400  ====> 
d_loss : 4.80622e-05, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 46500  ====> 
d_loss : 6.37388e-05, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 46600  ====> 
d_loss : 3.23031e-05, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 46700  ====> 
d_loss : 3.6245

*** Iteration 42100  ====> 
g_loss : 1.2834, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 42200  ====> 
g_loss : 0.711402, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 42300  ====> 
g_loss : 0.706684, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 42400  ====> 
g_loss : 0.704427, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 42500  ====> 
g_loss : 0.709223, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 42600  ====> 
g_loss : 0.724349, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 42700  ====> 
g_loss : 0.70056, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 42800  ====> 
g_loss : 0.721865, 
dataset : STARE, 

*** Iteration 48300  ====> 
g_loss : 0.715494, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 48400  ====> 
g_loss : 0.720145, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 48500  ====> 
g_loss : 0.709734, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 48600  ====> 
g_loss : 0.695943, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 48700  ====> 
g_loss : 0.695986, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 48800  ====> 
g_loss : 0.694277, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 48900  ====> 
g_loss : 0.70364, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 49000  ====> 
g_loss : 0.702102, 
dataset : STARE

*** Iteration 54200  ====> 
d_loss : 0.000201574, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 54300  ====> 
d_loss : 0.000378497, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 54400  ====> 
d_loss : 0.00072596, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 54500  ====> 
d_loss : 9.61395e-05, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 54600  ====> 
d_loss : 0.000117264, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 54700  ====> 
d_loss : 0.000146216, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 54800  ====> 
d_loss : 7.80247e-05, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 54900  ====> 
d_loss : 5.609

*** Iteration 50200  ====> 
g_loss : 0.701057, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 50300  ====> 
g_loss : 0.699848, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 50400  ====> 
g_loss : 0.693487, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 50500  ====> 
g_loss : 0.690765, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 50600  ====> 
g_loss : 0.69271, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 50700  ====> 
g_loss : 0.689859, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 50800  ====> 
g_loss : 0.69323, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 50900  ====> 
g_loss : 0.695207, 
dataset : STARE,

*** Iteration 56400  ====> 
g_loss : 0.689877, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 56500  ====> 
g_loss : 0.693627, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 56600  ====> 
g_loss : 0.692974, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 56700  ====> 
g_loss : 0.689122, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 56800  ====> 
g_loss : 0.689042, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 56900  ====> 
g_loss : 0.689458, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 57000  ====> 
g_loss : 0.693197, 
dataset : STARE, 
discriminator : image, 
train_interval : 10000.0, 
gpu_index : 0, 

*** Iteration 57100  ====> 
g_loss : 0.69085, 
dataset : STARE

SystemExit: 

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
