In [None]:
# Install the DICOM readers. The loading code in the next cell does `import dicom`,
# i.e. it relies on the legacy `dicom` package name.
!pip install pydicom
!pip install dicom
 
# Mount Google Drive so the dataset stored there is reachable from this Colab runtime.
from google.colab import drive
drive.mount('/content/drive')
 
# Work from the training-set folder on Drive.
%cd "/content/drive/My Drive/TrainingSet"
 
# One-time setup (left disabled): fetch the reference implementation.
#! git clone https://github.com/chuckyee/cardiac-segmentation.git
 
# Enter the cloned repository; this relative path depends on the %cd above.
%cd cardiac-segmentation
# Optional (left disabled): install the cloned package into the environment.
#!pip install .

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/.shortcut-targets-by-id/15_bncqiTE7-hHuSb3jG-DtFgbMx_k45P/TrainingSet
/content/drive/.shortcut-targets-by-id/15_bncqiTE7-hHuSb3jG-DtFgbMx_k45P/TrainingSet/cardiac-segmentation


In [None]:
#!/usr/bin/env python

from __future__ import division, print_function

from math import ceil
from scipy.ndimage.interpolation import map_coordinates
from scipy.ndimage.filters import gaussian_filter

from keras import utils
from keras.preprocessing import image as keras_image
from keras.preprocessing.image import ImageDataGenerator

import os, glob, re
import dicom
import numpy as np
from PIL import Image, ImageDraw

from keras.layers import Input, Conv2D, Conv2DTranspose
from keras.layers import MaxPooling2D, Cropping2D, Concatenate
from keras.layers import Lambda, Activation, BatchNormalization, Dropout
from keras.models import Model
from keras import backend as K

import argparse
import logging
from keras import losses, optimizers, utils
from keras.optimizers import SGD, RMSprop, Adagrad, Adadelta, Adam, Adamax, Nadam
from keras.callbacks import ModelCheckpoint


def maybe_rotate(image):
    """Return `image` in landscape orientation.

    A 2-D array that is taller than it is wide is rotated by 90 degrees with
    `np.rot90`; any other array is returned unchanged (the same object).
    """
    rows, cols = image.shape
    if cols < rows:
        return np.rot90(image)
    return image

class PatientData(object):
    """All DICOM frames and manual contours of one patient, loaded into memory.

    Data directory structure (for patient 01):
    directory/
      P01list.txt            (contour listing file, located via "P*list.txt")
      P01dicom/
        P01-0000.dcm
        P01-0001.dcm
        ...
      P01contours-manual/
        P01-0080-icontour-manual.txt
        P01-0120-ocontour-manual.txt
        ...

    Attributes set by the constructor:
      index               - patient number parsed from the listing file name
      all_images          - every frame as a 2-d array (landscape orientation)
      all_dicoms          - the DICOM datasets the frames were read from
      labeled             - frame numbers that have contours
      endocardium_*       - inner contours / masks, one entry per labeled frame
      epicardium_*        - outer contours / masks, one entry per labeled frame

    When contour data is missing, `labeled` and all contour/mask lists are
    empty, so `images`, `dicoms` and the mask lists always have equal length.
    """
    def __init__(self, directory):
        self.directory = os.path.normpath(directory)

        # get patient index from contour listing file
        glob_search = os.path.join(directory, "P*list.txt")
        files = glob.glob(glob_search)
        if len(files) == 0:
            raise Exception("Couldn't find contour listing file in {}. "
                            "Wrong directory?".format(directory))
        self.contour_list_file = files[0]
        # Match against the file name only, so that directory names can never
        # satisfy the pattern by accident.
        match = re.search("P(..)list.txt",
                          os.path.basename(self.contour_list_file))
        if match is None:
            raise Exception("Couldn't parse patient index from {}".format(
                self.contour_list_file))
        self.index = int(match.group(1))

        # load all data into memory
        self.load_images()

        # some patients do not have contour data, and that's ok
        self._reset_labels()
        try:
            self.load_masks()
        except FileNotFoundError:
            # Discard anything loaded before the failure; a half-filled state
            # would pair images with the wrong (or no) masks.
            self._reset_labels()

    def _reset_labels(self):
        """Mark the patient as having no labeled frames."""
        self.labeled = []
        self.endocardium_contours = []
        self.epicardium_contours = []
        self.endocardium_masks = []
        self.epicardium_masks = []

    @property
    def images(self):
        """Frames that have contours, in the order of `labeled`."""
        return [self.all_images[i] for i in self.labeled]

    @property
    def dicoms(self):
        """DICOM datasets of the frames that have contours."""
        return [self.all_dicoms[i] for i in self.labeled]

    @property
    def dicom_path(self):
        """Folder holding this patient's .dcm files."""
        return os.path.join(self.directory, "P{:02d}dicom".format(self.index))

    def load_images(self):
        """Read every .dcm file (sorted by name) into `all_images`/`all_dicoms`.

        Raises:
          Exception if the DICOM folder contains no .dcm files.
        """
        glob_search = os.path.join(self.dicom_path, "*.dcm")
        dicom_files = sorted(glob.glob(glob_search))
        if not dicom_files:
            raise Exception("No DICOM files found in {}".format(
                self.dicom_path))
        self.all_images = []
        self.all_dicoms = []
        for dicom_file in dicom_files:
            plan = dicom.read_file(dicom_file)
            image = maybe_rotate(plan.pixel_array)
            self.all_images.append(image)
            self.all_dicoms.append(plan)
        # Geometry is taken from the last frame read.
        self.image_height, self.image_width = image.shape
        self.rotated = (plan.pixel_array.shape != image.shape)

    def load_contour(self, filename):
        """Load one contour file and return its (x, y) coordinate arrays.

        `filename` is a listing-file entry of the form "patientXX/<relative
        path>"; the relative part is resolved against `self.directory`. If the
        frames were rotated to landscape, the contour is transformed to match.
        """
        # strip out path head "patientXX/"
        match = re.search("patient../(.*)", filename)
        path = os.path.join(self.directory, match.group(1))
        x, y = np.loadtxt(path).T
        if self.rotated:
            x, y = y, self.image_height - x
        return x, y

    def contour_to_mask(self, x, y, norm=255):
        """Rasterise a closed contour into a uint8 mask.

        Pixels on or inside the polygon get the value `norm`, all others 0.
        """
        BW_8BIT = 'L'
        polygon = list(zip(x, y))
        image_dims = (self.image_width, self.image_height)
        img = Image.new(BW_8BIT, image_dims, color=0)
        ImageDraw.Draw(img).polygon(polygon, outline=1, fill=1)
        return norm * np.array(img, dtype='uint8')

    def load_masks(self):
        """Load all contours named in the listing file and build 0/1 masks.

        The listing file alternates inner (endocardium) and outer
        (epicardium) contour paths, one per line.
        """
        with open(self.contour_list_file, 'r') as f:
            # Skip blank lines (e.g. a trailing newline) so they are not
            # mistaken for contour paths.
            files = [line.strip() for line in f if line.strip()]

        inner_files = [path.replace("\\", "/") for path in files[0::2]]
        outer_files = [path.replace("\\", "/") for path in files[1::2]]

        # get list of frames which have contours
        self.labeled = []
        for inner_file in inner_files:
            match = re.search("P..-(....)-.contour", inner_file)
            frame_number = int(match.group(1))
            self.labeled.append(frame_number)

        self.endocardium_contours = []
        self.epicardium_contours = []
        self.endocardium_masks = []
        self.epicardium_masks = []
        for inner_file, outer_file in zip(inner_files, outer_files):
            inner_x, inner_y = self.load_contour(inner_file)
            self.endocardium_contours.append((inner_x, inner_y))
            outer_x, outer_y = self.load_contour(outer_file)
            self.epicardium_contours.append((outer_x, outer_y))

            inner_mask = self.contour_to_mask(inner_x, inner_y, norm=1)
            self.endocardium_masks.append(inner_mask)
            outer_mask = self.contour_to_mask(outer_x, outer_y, norm=1)
            self.epicardium_masks.append(outer_mask)

    def write_video(self, outfile, FPS=24):
        """Write all frames to `outfile` as a video, each scaled to 8 bits."""
        import cv2
        image_dims = (self.image_width, self.image_height)
        video = cv2.VideoWriter(outfile, -1, FPS, image_dims)
        for image in self.all_images:
            peak = image.max()
            # An all-black frame has peak 0; keep it black instead of
            # dividing by zero.
            scale = 255 / peak if peak > 0 else 0
            grayscale = np.asarray(image * scale, dtype='uint8')
            video.write(cv2.cvtColor(grayscale, cv2.COLOR_GRAY2BGR))
        video.release()

def load_images(data_dir, mask='both'):
    """Load all patient images and contours from TrainingSet, Test1Set or
    Test2Set directory. The directories and images are read in sorted order.

    Arguments:
      data_dir - path to data directory (TrainingSet, Test1Set or Test2Set)
      mask     - 'inner' (endocardium), 'outer' (epicardium) or 'both'

    Output:
      tuple of (images, masks), both of which are 4-d tensors of shape
      (batchsize, height, width, channels). Images keep the pixel dtype of
      the DICOM data; masks are one-hot encoded with 2 classes for 'inner'
      and 'outer' and 3 classes for 'both'.

    Raises:
      Exception if no patient directory or no labeled frame is found.
    """
    assert mask in ['inner', 'outer', 'both']

    glob_search = os.path.join(data_dir, "patient*")
    patient_dirs = sorted(glob.glob(glob_search))
    if len(patient_dirs) == 0:
        raise Exception("No patient directors found in {}".format(data_dir))

    # load all images into memory (dataset is small)
    images = []
    inner_masks = []
    outer_masks = []
    for patient_dir in patient_dirs:
        p = PatientData(patient_dir)
        images += p.images
        inner_masks += p.endocardium_masks
        outer_masks += p.epicardium_masks

    if len(images) == 0:
        raise Exception("No labeled images found in {}".format(data_dir))

    # reshape to account for channel dimension
    images = np.asarray(images)[:,:,:,None]
    if mask == 'inner':
        masks = np.asarray(inner_masks)
    elif mask == 'outer':
        masks = np.asarray(outer_masks)
    elif mask == 'both':
        # mask = 2 for endocardium, 1 for cardiac wall, 0 elsewhere
        masks = np.asarray(inner_masks) + np.asarray(outer_masks)

    # one-hot encode masks
    # The class count follows from the mask mode (individual masks are 0/1,
    # their sum is 0/1/2). Deriving it from the values present in the first
    # image breaks the reshape whenever that image lacks one of the classes.
    dims = masks.shape
    classes = 3 if mask == 'both' else 2
    new_shape = dims + (classes,)
    masks = utils.to_categorical(masks, num_classes=classes).reshape(new_shape)

    return images, masks

def random_elastic_deformation(image, alpha, sigma, mode='nearest',
                               random_state=None):
    """Apply a random elastic warp to a (height, width, channels) array.

    Follows the elastic deformation described in [Simard2003]_: a uniform
    random displacement field in [-1, 1] is smoothed with a Gaussian of width
    `sigma`, scaled by `alpha`, and used to resample the image with linear
    interpolation. The same displacement is applied to every channel.

    Arguments:
      image        - 3-d array (height, width, channels)
      alpha        - displacement magnitude
      sigma        - smoothing length scale of the displacement field
      mode         - boundary mode passed to `map_coordinates`
      random_state - optional `np.random.RandomState`; a fresh one is created
                     when omitted

    .. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for
       Convolutional Neural Networks applied to Visual Document Analysis", in
       Proc. of the International Conference on Document Analysis and
       Recognition, 2003.
    """
    assert len(image.shape) == 3

    rng = np.random.RandomState(None) if random_state is None else random_state

    height, width, channels = image.shape

    def smoothed_displacement():
        # uniform noise in [-1, 1], blurred and scaled
        noise = 2*rng.rand(height, width) - 1
        return gaussian_filter(noise, sigma, mode="constant", cval=0) * alpha

    # draw the row displacement first, then the column displacement
    dx = smoothed_displacement()
    dy = smoothed_displacement()

    rows, cols = np.meshgrid(np.arange(height), np.arange(width), indexing='ij')

    # one sampling coordinate per output element, channel index varying fastest
    row_coords = np.repeat(np.ravel(rows + dx), channels)
    col_coords = np.repeat(np.ravel(cols + dy), channels)
    channel_coords = np.tile(np.arange(channels), height*width)

    warped = map_coordinates(image, (row_coords, col_coords, channel_coords),
                             order=1, mode=mode)

    return warped.reshape((height, width, channels))

class Iterator(object):
    """Endless batch iterator that augments images and masks together.

    Each image is stacked with its mask along the channel axis so that both
    receive the same random affine transform (Keras `ImageDataGenerator`) and,
    unless `alpha` or `sigma` is 0, the same elastic deformation.

    Arguments:
      images, masks - indexable collections of (height, width, channels) arrays
      batch_size    - number of samples per batch (the last batch of an epoch
                      may be smaller)
      shuffle       - reshuffle the sample order at construction and after
                      every pass over the data
      rotation_range ... fill_mode - forwarded to `ImageDataGenerator`
      alpha, sigma  - forwarded to `random_elastic_deformation`
    """
    def __init__(self, images, masks, batch_size,
                 shuffle=True,
                 rotation_range=180,
                 width_shift_range=0.1,
                 height_shift_range=0.1,
                 shear_range=0.1,
                 zoom_range=0.01,
                 fill_mode='nearest',
                 alpha=500,
                 sigma=20):
        self.images = images
        self.masks = masks
        self.batch_size = batch_size
        self.shuffle = shuffle
        augment_options = {
            'rotation_range': rotation_range,
            'width_shift_range': width_shift_range,
            'height_shift_range': height_shift_range,
            'shear_range': shear_range,
            'zoom_range': zoom_range,
            'fill_mode': fill_mode,
        }
        self.idg = ImageDataGenerator(**augment_options)
        self.alpha = alpha
        self.sigma = sigma
        self.fill_mode = fill_mode
        self.i = 0
        self.index = np.arange(len(images))
        if shuffle:
            np.random.shuffle(self.index)

    def __iter__(self):
        # The iterator protocol requires __iter__ alongside __next__; without
        # it `iter(obj)` and `for batch in obj` raise TypeError.
        return self

    def __next__(self):
        return self.next()

    def next(self):
        """Return the next (images, masks) batch as two arrays."""
        # compute how many images to output in this batch
        start = self.i
        end = min(start + self.batch_size, len(self.images))

        augmented_images = []
        augmented_masks = []
        for n in self.index[start:end]:
            image = self.images[n]
            mask = self.masks[n]

            _, _, channels = image.shape

            # stack image + mask together to simultaneously augment
            stacked = np.concatenate((image, mask), axis=2)

            # apply simple affine transforms first using Keras
            augmented = self.idg.random_transform(stacked)

            # maybe apply elastic deformation
            if self.alpha != 0 and self.sigma != 0:
                augmented = random_elastic_deformation(
                    augmented, self.alpha, self.sigma, self.fill_mode)

            # split image and mask back apart
            augmented_image = augmented[:,:,:channels]
            augmented_images.append(augmented_image)
            # interpolation leaves fractional mask values; snap back to 0/1
            augmented_mask = np.round(augmented[:,:,channels:])
            augmented_masks.append(augmented_mask)

        # advance; wrap around (and reshuffle) once every sample was served
        self.i += self.batch_size
        if self.i >= len(self.images):
            self.i = 0
            if self.shuffle:
                np.random.shuffle(self.index)

        return np.asarray(augmented_images), np.asarray(augmented_masks)

def normalize(x, epsilon=1e-7, axis=(1,2)):
    """Standardise `x` in place to zero mean and unit spread along `axis`.

    The array is modified directly and nothing is returned. `epsilon` is
    added to the standard deviation so constant slices do not divide by zero.
    """
    center = np.mean(x, axis=axis, keepdims=True)
    np.subtract(x, center, out=x)
    spread = np.std(x, axis=axis, keepdims=True) + epsilon
    np.divide(x, spread, out=x)

def create_generators(data_dir, batch_size, validation_split=0.0, mask='both',
                      shuffle_train_val=True, shuffle=True, seed=None,
                      normalize_images=True, augment_training=False,
                      augment_validation=False, augmentation_args={}):
    """Load a dataset and build training / validation batch generators.

    Arguments:
      data_dir           - directory passed to `load_images`
      batch_size         - samples per batch
      validation_split   - fraction of samples (taken from the end) held out
      mask               - mask selection passed to `load_images`
      shuffle_train_val  - shuffle images and masks (in parallel) before
                           splitting
      shuffle            - shuffle within each generator
      seed               - if given, seeds NumPy's global RNG
      normalize_images   - standardise each image with `normalize`
      augment_training / augment_validation - use the augmenting `Iterator`
                           instead of a plain `ImageDataGenerator().flow`
      augmentation_args  - keyword arguments for `Iterator`

    Output:
      (train_generator, train_steps_per_epoch,
       val_generator, val_steps_per_epoch); `val_generator` is None when
      `validation_split` is 0.
    """
    images, masks = load_images(data_dir, mask)

    # masks stay as loaded; images are promoted to double precision
    images = images.astype('float64')

    if normalize_images:
        normalize(images, axis=(1,2))

    if seed is not None:
        np.random.seed(seed)

    if shuffle_train_val:
        # replaying the same RNG state gives both arrays the same permutation
        rng_state = np.random.get_state()
        np.random.shuffle(images)
        np.random.set_state(rng_state)
        np.random.shuffle(masks)

    # the trailing `validation_split` fraction becomes the validation set
    sample_count = len(images)
    split_index = int((1-validation_split) * sample_count)

    def build_generator(subset_images, subset_masks, augment):
        # augmenting iterator or plain Keras flow, depending on `augment`
        if augment:
            return Iterator(subset_images, subset_masks,
                            batch_size, shuffle=shuffle, **augmentation_args)
        return ImageDataGenerator().flow(subset_images, subset_masks,
                                         batch_size=batch_size, shuffle=shuffle)

    train_generator = build_generator(
        images[:split_index], masks[:split_index], augment_training)
    train_steps_per_epoch = ceil(split_index / batch_size)

    val_generator = None
    if validation_split > 0.0:
        val_generator = build_generator(
            images[split_index:], masks[split_index:], augment_validation)
    val_steps_per_epoch = ceil((sample_count - split_index) / batch_size)

    return (train_generator, train_steps_per_epoch,
            val_generator, val_steps_per_epoch)



This code is using an older version of pydicom, which is no longer 
maintained as of Jan 2017.  You can access the new pydicom features and API 
by installing `pydicom` from PyPI.
See 'Transitioning to pydicom 1.x' section at pydicom.readthedocs.org 
for more information.



In [None]:
def downsampling_block(input_tensor, filters, padding='valid',
                       batchnorm=False, dropout=0.0):
    """Two 3x3 conv stages followed by 2x2 max pooling.

    Each stage is Conv2D -> (BatchNormalization) -> ReLU -> (Dropout); the
    optional layers are added only when `batchnorm` is true / `dropout` > 0.
    The input height and width must be even.

    Returns:
      (pooled, features): the max-pooled tensor and the pre-pooling tensor,
      the latter intended as the skip connection.
    """
    _, height, width, _ = K.int_shape(input_tensor)

    assert height % 2 == 0
    assert width % 2 == 0

    features = input_tensor
    for _stage in range(2):
        features = Conv2D(filters, kernel_size=(3,3), padding=padding)(features)
        if batchnorm:
            features = BatchNormalization()(features)
        features = Activation('relu')(features)
        if dropout > 0:
            features = Dropout(dropout)(features)

    pooled = MaxPooling2D(pool_size=(2,2))(features)
    return pooled, features

def upsampling_block(input_tensor, skip_tensor, filters, padding='valid',
                     batchnorm=False, dropout=0.0):
    """Upsample by 2, merge with the skip tensor, then two 3x3 conv stages.

    `input_tensor` is upsampled with a stride-2 transposed convolution. The
    skip tensor is centre-cropped to the upsampled size when it is larger
    (it must not be smaller), concatenated, and passed through two
    Conv2D -> (BatchNormalization) -> ReLU -> (Dropout) stages.
    """
    upsampled = Conv2DTranspose(filters, kernel_size=(2,2), strides=(2,2))(input_tensor)

    # how much larger is the skip tensor than the upsampled one?
    _, up_height, up_width, _ = K.int_shape(upsampled)
    _, skip_height, skip_width, _ = K.int_shape(skip_tensor)
    extra_rows = skip_height - up_height
    extra_cols = skip_width - up_width
    assert extra_rows >= 0
    assert extra_cols >= 0
    if extra_rows or extra_cols:
        # split the excess as evenly as possible; the odd pixel goes last
        top, left = extra_rows//2, extra_cols//2
        cropping = ((top, extra_rows - top), (left, extra_cols - left))
        skip = Cropping2D(cropping=cropping)(skip_tensor)
    else:
        skip = skip_tensor

    merged = Concatenate()([upsampled, skip])

    for _stage in range(2):
        merged = Conv2D(filters, kernel_size=(3,3), padding=padding)(merged)
        if batchnorm:
            merged = BatchNormalization()(merged)
        merged = Activation('relu')(merged)
        if dropout > 0:
            merged = Dropout(dropout)(merged)

    return merged

def unet(height, width, channels, classes, features=64, depth=3,
         temperature=1.0, padding='valid', batchnorm=False, dropout=0.0):
    """Build the U-Net model introduced in
      "U-Net: Convolutional Networks for Biomedical Image Segmentation"
      O. Ronneberger, P. Fischer, T. Brox (2015)
    Any number of input channels and output classes is supported.

    Arguments:
      height  - input image height (pixels)
      width   - input image width  (pixels)
      channels - input image features (1 for grayscale, 3 for RGB)
      classes - number of output classes (2 in paper)
      features - number of output features for first convolution (64 in paper)
          The number of features doubles after each downsampling block
      depth  - number of downsampling operations (4 in paper)
      temperature - divisor applied to the logits before the softmax
      padding - 'valid' (used in paper) or 'same'
      batchnorm - include batch normalization layers before activations
      dropout - fraction of units to dropout, 0 to keep all units

    Output:
      U-Net model expecting input shape (height, width, channels) and
      generating output with shape (output_height, output_width, classes).
      If padding is 'same', then output_height = height and
      output_width = width.
    """
    inputs = Input(shape=(height, width, channels))
    x = inputs

    # contracting path: remember each pre-pooling tensor for the skip links
    skips = []
    for _level in range(depth):
        x, skip = downsampling_block(x, features, padding,
                                     batchnorm, dropout)
        skips.append(skip)
        features *= 2

    # bottleneck: two conv stages at the lowest resolution
    for _stage in range(2):
        x = Conv2D(filters=features, kernel_size=(3,3), padding=padding)(x)
        if batchnorm:
            x = BatchNormalization()(x)
        x = Activation('relu')(x)
        if dropout > 0:
            x = Dropout(dropout)(x)

    # expanding path: consume the skip tensors deepest-first
    for skip in reversed(skips):
        features //= 2
        x = upsampling_block(x, skip, features, padding,
                             batchnorm, dropout)

    x = Conv2D(filters=classes, kernel_size=(1,1))(x)

    logits = Lambda(lambda z: z/temperature)(x)
    probabilities = Activation('softmax')(logits)

    return Model(inputs=inputs, outputs=probabilities)

def soft_sorensen_dice(y_true, y_pred, axis=None, smooth=1):
    """Smoothed Sorensen-Dice coefficient computed with the Keras backend.

    Returns (2*|A.B| + smooth) / (|A| + |B| + smooth), with sums taken over
    `axis` (all axes when None). No thresholding is applied to the inputs.
    """
    overlap = K.sum(y_true * y_pred, axis=axis)
    combined_area = K.sum(y_true, axis=axis) + K.sum(y_pred, axis=axis)
    return (2 * overlap + smooth) / (combined_area + smooth)
    
def hard_sorensen_dice(y_true, y_pred, axis=None, smooth=1):
    """Sorensen-Dice coefficient of the rounded (binarised) inputs."""
    return soft_sorensen_dice(K.round(y_true), K.round(y_pred), axis, smooth)

# `sorensen_dice` refers to the hard (rounded) variant.
sorensen_dice = hard_sorensen_dice

def sorensen_dice_loss(y_true, y_pred, weights):
    """Class-weighted soft Dice loss: 1 - sum_c w_c * dice_c.

    Input tensors have shape (batch_size, height, width, classes). `weights`
    is a list with one entry per class; it is normalised to sum to 1.

    Ex: for simple binary classification, with the 0th mask corresponding to
    the background and the 1st mask corresponding to the object of interest,
    use weights = [0, 1].
    """
    # Dice per sample and class, then averaged over the batch
    per_sample = soft_sorensen_dice(y_true, y_pred, axis=[1, 2])
    per_class = K.mean(per_sample, axis=0)
    normalized_weights = K.constant(weights) / sum(weights)
    weighted_dice = K.sum(normalized_weights * per_class)
    return 1 - weighted_dice

def soft_jaccard(y_true, y_pred, axis=None, smooth=1):
    """Smoothed Jaccard index (intersection over union) via the Keras backend.

    Returns (|A.B| + smooth) / (|A| + |B| - |A.B| + smooth), with sums taken
    over `axis` (all axes when None). No thresholding is applied.
    """
    overlap = K.sum(y_true * y_pred, axis=axis)
    union = K.sum(y_true, axis=axis) + K.sum(y_pred, axis=axis) - overlap
    return (overlap + smooth) / (union + smooth)

def hard_jaccard(y_true, y_pred, axis=None, smooth=1):
    """Jaccard index of the rounded (binarised) inputs."""
    return soft_jaccard(K.round(y_true), K.round(y_pred), axis, smooth)

# `jaccard` refers to the hard (rounded) variant.
jaccard = hard_jaccard

def jaccard_loss(y_true, y_pred, weights):
    """Class-weighted soft Jaccard loss: 1 - sum_c w_c * jaccard_c.

    The index is computed per sample and class over axes 1 and 2, averaged
    over the batch, and combined with `weights` normalised to sum to 1.
    """
    per_sample = soft_jaccard(y_true, y_pred, axis=[1, 2])
    per_class = K.mean(per_sample, axis=0)
    normalized_weights = K.constant(weights) / sum(weights)
    weighted_jaccard = K.sum(normalized_weights * per_class)
    return 1 - weighted_jaccard

def weighted_categorical_crossentropy(y_true, y_pred, weights, epsilon=1e-8):
    """Cross entropy with one weight per class.

    Predictions are renormalised over the last (class) axis and clipped to
    [epsilon, 1 - epsilon]. The cross entropy is averaged over every axis
    except the class axis, and the per-class values are combined using
    `weights` rescaled so that they sum to the number of classes.
    """
    ndim = K.ndim(y_pred)
    ncategory = K.int_shape(y_pred)[-1]
    class_axis = ndim - 1
    # make the class probabilities of each pixel sum to 1
    y_pred /= K.sum(y_pred, axis=class_axis, keepdims=True)
    y_pred = K.clip(y_pred, epsilon, 1-epsilon)
    class_weights = K.constant(weights) * (ncategory / sum(weights))
    # mean over batch and spatial axes leaves one value per class
    leading_axes = tuple(range(ndim-1))
    per_class_entropy = -K.mean(y_true * K.log(y_pred), axis=leading_axes)
    return K.sum(class_weights * per_class_entropy)



In [None]:

datadir = "/content/drive/My Drive/TrainingSet/cardiac-segmentation" # Directory containing list of patientXX/ subdirectories
outdir = "/content/drive/My Drive/TrainingSet/cardiac-segmentation"  # Where to write weight files
outfile = 'weights-final.hdf5'                                       # File to write final model weights
testdir =  "/content/drive/My Drive/Test1Set"

augmentation_args = {
        'rotation_range': 180,       # Rotation range (0-180 degrees)
        'width_shift_range': 0.1,    # Width shift range, as a float fraction of the width
        'height_shift_range': 0.1,   # Height shift range, as a float fraction of the height
        'shear_range': 0.1,          # Shear intensity (in radians)
        'zoom_range': 0.05,          # Amount of zoom. If a scalar z, zoom in [1-z, 1+z]. Can also pass a pair of floats as the zoom range.
        'fill_mode' :'nearest',      # Points outside boundaries are filled according to mode: constant, nearest, reflect, or wrap
        'alpha': 500,                # Random elastic distortion: magnitude of distortion
        'sigma': 20,                 # Random elastic distortion: length scale
    }
batch_size = 32               # Mini-batch size for training
validation_split = 0.2        # Fraction of training data to hold out for validation
shuffle_train_val = False
mask_type = 'inner'           # One of `inner', `outer', or `both' for endocardium, epicardium, or both
shuffle = False
seed = 1
# Deliberately NOT named `normalize`: that name is the normalisation function
# called inside create_generators, and rebinding it to a bool would make
# normalize_images=True fail with "'bool' object is not callable".
normalize_images = False
augment_training= False       # Whether to apply image augmentation to training set
augment_validation = False    # Whether to apply image augmentation to validation set

train_generator, train_steps_per_epoch, \
        val_generator, val_steps_per_epoch = create_generators(
            datadir, batch_size,
            validation_split=validation_split,
            mask=mask_type,
            shuffle_train_val=shuffle_train_val,
            shuffle=shuffle,
            seed=seed,
            normalize_images=normalize_images,
            augment_training=augment_training,
            augment_validation=augment_validation,
            augmentation_args=augmentation_args)

# Peek at one batch to learn the tensor shapes the model must be built for.
# `classes` is the number of one-hot mask channels (an int); the mask
# selection string lives in `mask_type`, so re-running this cell is safe.
images, masks = next(train_generator)
_, height, width, channels = images.shape
_, _, _, classes = masks.shape

# Test set: no validation split, no shuffling, no augmentation; same mask
# selection as the training data.
test_generator, test_steps_per_epoch, \
        test_val_generator, test_val_steps_per_epoch = create_generators(
            testdir, 64,
            validation_split=0.0,
            mask=mask_type,
            shuffle_train_val=False,
            shuffle=False,
            seed=None,
            normalize_images=False,
            augment_training=False,
            augment_validation=False,
            augmentation_args=augmentation_args)

In [None]:
def dice_coef(y_true, y_pred):
    """Hard Dice coefficient over all elements of the batch.

    Predictions are thresholded at 0.5 before comparison. A tiny epsilon in
    numerator and denominator makes two empty masks score 1 (perfect
    agreement) instead of 0/0 = NaN; other results change by at most ~1e-7.
    """
    # cast both inputs so the element-wise product never mixes dtypes
    y_true_f = K.cast(K.flatten(y_true), 'float32')
    y_pred = K.cast(y_pred, 'float32')
    y_pred_f = K.cast(K.greater(K.flatten(y_pred), 0.5), 'float32')
    intersection = K.sum(y_true_f * y_pred_f)
    total = K.sum(y_true_f) + K.sum(y_pred_f)
    eps = K.epsilon()
    return (2. * intersection + eps) / (total + eps)

def dice_loss(y_true, y_pred):
    """Soft Dice loss over the flattened batch: 1 - smoothed Dice score.

    Uses a smoothing constant of 1 in numerator and denominator and applies
    no thresholding to the predictions.
    """
    smooth = 1.
    truth = K.flatten(y_true)
    prediction = K.flatten(y_pred)
    overlap = K.sum(truth * prediction)
    denominator = K.sum(truth) + K.sum(prediction) + smooth
    return 1. - (2. * overlap + smooth) / denominator

def bce_dice_loss(y_true, y_pred):
    """Binary cross entropy plus soft Dice loss."""
    # `binary_crossentropy` was never imported as a bare name; reach it
    # through the `losses` module that the notebook imports from Keras.
    return losses.binary_crossentropy(y_true, y_pred) + dice_loss(y_true, y_pred)

def bce_logdice_loss(y_true, y_pred):
    """Binary cross entropy minus the log of the soft Dice score.

    `1 - dice_loss` is the smoothed Dice score, so the second term is
    -log(dice).
    """
    # `binary_crossentropy` was never imported as a bare name; reach it
    # through the `losses` module that the notebook imports from Keras.
    return losses.binary_crossentropy(y_true, y_pred) - K.log(1. - dice_loss(y_true, y_pred))

def weighted_bce_loss(y_true, y_pred, weight):
    """Element-weighted binary cross entropy, normalised by the weight sum.

    Probabilities are clipped to [1e-7, 1 - 1e-7], converted to logits, and
    the cross entropy is evaluated in its logit form
    x*(1 - y) + log(1 + exp(-|x|)) + max(-x, 0).
    """
    epsilon = 1e-7
    clipped = K.clip(y_pred, epsilon, 1. - epsilon)
    logits = K.log(clipped / (1. - clipped))
    linear_term = logits * (1. - y_true)
    log_term = K.log(1. + K.exp(-K.abs(logits)))
    relu_term = K.maximum(-logits, 0.)
    per_element = weight * (linear_term + log_term + relu_term)
    return K.sum(per_element) / K.sum(weight)


def weighted_dice_loss(y_true, y_pred, weight):
    """Soft Dice loss in which every element contributes with `weight`.

    Uses a smoothing constant of 1 in numerator and denominator.
    """
    smooth = 1.
    weighted_overlap = K.sum(weight * (y_true * y_pred))
    weighted_area = K.sum(weight * y_true) + K.sum(weight * y_pred)
    score = (2. * weighted_overlap + smooth) / (weighted_area + smooth)
    return 1. - K.sum(score)

def weighted_bce_dice_loss(y_true, y_pred):
    """Locally weighted binary cross entropy plus (unweighted) soft Dice loss.

    The weight map is 5*exp(-5*|avg - 0.5|), where `avg` is the 50x50 average
    pooling of the ground truth, so it is largest where the local mask
    average is 0.5 (presumably near mask boundaries). It is rescaled so that
    its sum equals the number of elements.
    """
    y_true = K.cast(y_true, 'float32')
    y_pred = K.cast(y_pred, 'float32')
    # 'same' padding with stride 1 keeps the pooled map the size of the input
    local_average = K.pool2d(
            y_true, pool_size=(50, 50), strides=(1, 1), padding='same', pool_mode='avg')
    uniform_total = K.sum(K.ones_like(local_average))
    weight = 5. * K.exp(-5. * K.abs(local_average - 0.5))
    weight_total = K.sum(weight)
    weight = weight * (uniform_total / weight_total)
    return weighted_bce_loss(y_true, y_pred, weight) + dice_loss(y_true, y_pred)


In [None]:
import tensorflow as tf
import keras

def scheduler(epoch, lr):
  """Learning-rate schedule: constant for 10 epochs, then exponential decay.

  Arguments:
    epoch - epoch index (0-based)
    lr    - current learning rate

  Returns:
    `lr` unchanged while epoch < 10, otherwise lr * exp(-0.1) as a plain
    Python float. A tensor (as produced by tf.math.exp) is not accepted by
    every Keras version, which expects the schedule to return a float.
  """
  import math
  if epoch < 10:
    return lr
  return float(lr) * math.exp(-0.1)

# Wrap `scheduler` in a Keras LearningRateScheduler callback; it is handed to
# the training call further down.
callback = keras.callbacks.LearningRateScheduler(scheduler)


In [None]:
def select_optimizer(optimizer_name, optimizer_args):
    """Instantiate a Keras optimizer by name.

    Arguments:
      optimizer_name - one of 'sgd', 'rmsprop', 'adagrad', 'adadelta',
                       'adam', 'adamax', 'nadam'
      optimizer_args - dict of keyword arguments for the optimizer class

    Raises:
      Exception if `optimizer_name` is not one of the supported names.
    """
    # named so it does not shadow the imported `optimizers` module
    optimizer_classes = {
        'sgd': SGD,
        'rmsprop': RMSprop,
        'adagrad': Adagrad,
        'adadelta': Adadelta,
        'adam': Adam,
        'adamax': Adamax,
        'nadam': Nadam,
    }
    if optimizer_name not in optimizer_classes:
        raise Exception("Unknown optimizer ({}).".format(optimizer_name))
    return optimizer_classes[optimizer_name](**optimizer_args)
      
    # get image dimensions from first batch

features = 64          # Number of features maps after first convolutional layer
depth = 3              # Number of downsampled convolutional blocks
temperature = 1.0      # Temperature of final softmax layer in model
padding = 'same'       # Padding in convolutional layers. Either `same' or `valid'
dropout = 0.02         # Rate for dropout of activation units (set to zero to omit)
batchnorm = False      # Whether to apply batch normalization before activation layers

# height, width, channels and classes were read from the first training batch
# in the data-loading cell.
m = unet(height=height, width=width, channels=channels, classes=classes,
              features=features, depth=depth, padding=padding,
              temperature=temperature, batchnorm=batchnorm,
              dropout=dropout)

learning_rate = 0.1
momentum = None
decay = None
optimizer_name = 'adam'          # Optimizer: sgd, rmsprop, adagrad, adadelta, adam, adamax, or nadam
optimizer_args = {
        'lr':       learning_rate,   # Optimizer learning rate
        'momentum': momentum,        # Momentum for SGD optimizer
        'decay':    decay            # Learning rate decay (for all optimizers except nadam)
    }
# Options left as None are dropped so the optimizer's own defaults apply.
optimizer_args = {key: value for key, value in optimizer_args.items()
                  if value is not None}
optimizer = select_optimizer(optimizer_name, optimizer_args)

loss = 'dice'               # Loss function: `pixel' for pixel-wise cross entropy,
                                 # `dice' for sorensen-dice coefficient,
                                 # `jaccard' for intersection over union
loss_weights = [0.1, 0.9]    # Class weights for the `pixel' and `jaccard' losses (unused by `dice')

if loss == 'pixel':
    def lossfunc(y_true, y_pred):
        return weighted_categorical_crossentropy(
                y_true, y_pred, loss_weights)
elif loss == 'dice':
    def lossfunc(y_true, y_pred):
        return dice_loss(y_true, y_pred)
elif loss == 'jaccard':
    def lossfunc(y_true, y_pred):
        return jaccard_loss(y_true, y_pred, loss_weights)
else:
    raise Exception("Unknown loss ({})".format(loss))

def dice(y_true, y_pred):
    # Call the backend implementation by its own name: the `sorensen_dice`
    # alias is rebound to a NumPy function in a later cell.
    batch_dice_coefs = hard_sorensen_dice(y_true, y_pred, axis=[1, 2])
    dice_coefs = K.mean(batch_dice_coefs, axis=0)
    return dice_coefs[1]    # HACK for 2-class case

def jaccard(y_true, y_pred):
    # This definition replaces the module-level `jaccard` alias, so calling
    # `jaccard(...)` here would recurse into this two-argument function and
    # fail on the `axis` keyword. Use the backend implementation directly.
    batch_jaccard_coefs = hard_jaccard(y_true, y_pred, axis=[1, 2])
    jaccard_coefs = K.mean(batch_jaccard_coefs, axis=0)
    return jaccard_coefs[1] # HACK for 2-class case

import keras

# MeanIoU gets an explicit name so its log key ("val_mean_iou") is stable.
metrics = ['accuracy',
           keras.metrics.MeanIoU(num_classes=2, name='mean_iou'),
           dice_coef]
m.compile(optimizer=optimizer, loss=lossfunc, metrics=metrics)

# automatic saving of model during training
callbacks = [callback]       # learning-rate schedule is always active
checkpoint = False
if checkpoint:
    # The monitored quantity must be a key Keras actually logs, i.e. "val_" +
    # the name of one of the compiled metrics; the weights file name embeds
    # the same key.
    # NOTE(review): 'val_accuracy' assumes the installed Keras logs the
    # 'accuracy' metric under that name - confirm.
    monitor = {
        'pixel': 'val_accuracy',
        'dice': 'val_dice_coef',
        'jaccard': 'val_mean_iou',
    }[loss]
    mode = 'max'
    filepath = os.path.join(
            outdir, "weights-{epoch:02d}-{" + monitor + ":.4f}.hdf5")
    checkpoint = ModelCheckpoint(
            filepath, monitor=monitor, verbose=1,
            save_best_only=True, mode=mode)
    callbacks.append(checkpoint)

# train
epochs = 3  # Number of epochs to train

# NOTE(review): fit_generator is deprecated in recent Keras releases in favour
# of fit - confirm against the installed version before switching.
m.fit_generator(train_generator,
                    epochs=epochs,
                    steps_per_epoch=train_steps_per_epoch,
                    validation_data=val_generator,
                    validation_steps=val_steps_per_epoch,
                    callbacks=callbacks,
                    verbose=2)
# m.save(os.path.join(outdir, outfile))

Epoch 1/3


ResourceExhaustedError: ignored

In [None]:
# Evaluate on the test generator. The model was compiled with one loss and the
# metrics accuracy, MeanIoU and dice_coef, which corresponds to the four
# numbers shown in the output.
m.evaluate(test_generator)



[0.049127332866191864,
 0.9508726596832275,
 0.9063462615013123,
 0.9558917284011841]

In [None]:
def sorensen_dice(y_true, y_pred):
    """Sorensen-Dice coefficient of two NumPy masks: 2*|A.B| / (|A| + |B|).

    Two empty masks agree perfectly and score 1.0 (instead of 0/0 = NaN),
    matching the smoothed backend versions defined earlier.

    NOTE: this rebinds the module-level name `sorensen_dice`, which earlier
    in the notebook refers to the Keras-backend implementation.
    """
    intersection = np.sum(y_true * y_pred)
    total = np.sum(y_true) + np.sum(y_pred)
    if total == 0:
        return 1.0
    return 2*intersection / total

def jaccard(y_true, y_pred):
    """Jaccard index (intersection over union) of two NumPy masks.

    Inputs are interpreted as boolean masks (non-zero = foreground), so
    float 0/1 arrays work as well as bool/int ones. Two empty masks score
    1.0 instead of 0/0 = NaN.

    NOTE: this rebinds the module-level name `jaccard`, which earlier in the
    notebook refers to Keras-backend functions.
    """
    true_mask = np.asarray(y_true).astype(bool)
    pred_mask = np.asarray(y_pred).astype(bool)
    union = np.sum(true_mask | pred_mask)
    if union == 0:
        return 1.0
    intersection = np.sum(true_mask & pred_mask)
    return intersection / union