In [None]:
#------------------------------------------------------------
# Import necessary libraries
#------------------------------------------------------------
from tensorflow.keras.layers import Input, MaxPooling2D, concatenate, Conv2D, UpSampling2D, SpatialDropout2D, BatchNormalization, Activation
from tensorflow.keras import backend as K
from tensorflow.keras import Model
from tensorflow.keras.optimizers import SGD, RMSprop, Adam
from tensorflow.keras.preprocessing import image
from tensorflow.keras.callbacks import LearningRateScheduler, ModelCheckpoint
from google.cloud import storage
import numpy as np
import os
import tensorflow
import csv
import time
import tarfile
import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow, show
import glob
from IPython.display import Image, display, clear_output
import cv2


In [None]:
#------------------------------------------------------------
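# Generate credentials: paste the code block from
# "credentials for BE 547 lab" into this cell and run it. The download
# cell below reads the credentials from auth.json. Do not commit or share
# this cell once it contains credentials.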
#------------------------------------------------------------

In [None]:
#------------------------------------------------------------
# Download the data
#------------------------------------------------------------
client = storage.Client.from_service_account_json("auth.json")
bucket = client.get_bucket("pennclassdata")
blob = bucket.blob("class_data.tar")
blob.download_to_filename("class_data.tar")
# Untar it into ./data. The context manager closes the archive even if
# extraction fails part-way.
with tarfile.open('class_data.tar') as archive:
    if hasattr(tarfile, 'data_filter'):
        # Python versions with extraction filters: the 'data' filter refuses
        # members that would land outside 'data/' (absolute paths, '..',
        # links pointing outside) - a safeguard for downloaded archives.
        archive.extractall('data', filter='data')
    else:
        archive.extractall('data')
print("Download complete")

In [None]:
#------------------------------------------------------------
# Load helper functions
#------------------------------------------------------------
def conv_block_simple(prevlayer, filters, prefix, strides=(1, 1)):
    """Apply a 3x3 convolution -> batch normalization -> ReLU stack.

    Args:
        prevlayer: Keras tensor fed into the block.
        filters: number of convolution filters (output channels).
        prefix: layer-name prefix; the layers are named <prefix>_conv,
            <prefix>_bn and <prefix>_activation.
        strides: convolution strides; "same" padding is always used.

    Returns:
        The Keras tensor produced by the ReLU activation.
    """
    stack = (
        Conv2D(filters, (3, 3), strides=strides, padding="same",
               kernel_initializer="he_normal", name=prefix + "_conv"),
        BatchNormalization(name=prefix + "_bn"),
        Activation('relu', name=prefix + "_activation"),
    )
    tensor = prevlayer
    for layer in stack:
        tensor = layer(tensor)
    return tensor

def get_simple_unet(input_shape):
    """Build a small single-channel U-Net with a sigmoid per-pixel output.

    Encoder levels 1-3 have 32/64/128 filters, each followed by 2x2
    max-pooling. The bottleneck has three 256-filter blocks. Decoder levels
    5-7 (128/64/32 filters) upsample 2x and concatenate the matching encoder
    output as a skip connection. Layers are created in a fixed order with
    fixed names (conv1_1 ... conv7_2, pool1-3, prediction), which matters
    when loading saved weights.

    Args:
        input_shape: (height, width) tuple; a channel axis of 1 is appended.
            NOTE(review): height/width presumably must be divisible by 8 so
            the upsampled tensors line up with the skip connections.

    Returns:
        An uncompiled tf.keras Model mapping (H, W, 1) inputs to (H, W, 1)
        sigmoid probabilities.
    """
    img_input = Input(input_shape + (1,))

    # Encoder: two conv blocks per level; keep each level's output as a skip.
    tensor = img_input
    skips = []
    for level, width in enumerate((32, 64, 128), start=1):
        tensor = conv_block_simple(tensor, width, "conv%d_1" % level)
        tensor = conv_block_simple(tensor, width, "conv%d_2" % level)
        skips.append(tensor)
        tensor = MaxPooling2D((2, 2), strides=(2, 2), padding="same",
                              name="pool%d" % level)(tensor)

    # Bottleneck.
    for sub in (1, 2, 3):
        tensor = conv_block_simple(tensor, 256, "conv4_%d" % sub)

    # Decoder: upsample, concatenate the mirrored skip, two conv blocks.
    for level, width, skip in zip((5, 6, 7), (128, 64, 32), reversed(skips)):
        tensor = concatenate([UpSampling2D()(tensor), skip], axis=3)
        tensor = conv_block_simple(tensor, width, "conv%d_1" % level)
        tensor = conv_block_simple(tensor, width, "conv%d_2" % level)

    tensor = SpatialDropout2D(0.2)(tensor)

    prediction = Conv2D(1, (1, 1), activation="sigmoid", name="prediction")(tensor)
    return Model(img_input, prediction)

def dice_coef(y_true, y_pred, smooth=1.0):
    """Soft Dice coefficient between ground-truth and predicted masks.

    Both tensors are flattened, so a single score is computed over the whole
    batch rather than averaged per image.

    Args:
        y_true: ground-truth mask tensor.
        y_pred: predicted probability tensor (same number of elements).
        smooth: additive smoothing term in numerator and denominator; it
            avoids division by zero and makes two empty masks score 1.
            Defaults to 1, the value the training cell assigns to its
            global `smooth`, so the metric no longer depends on that cell
            having been run first.

    Returns:
        Scalar tensor: (2*|A.B| + smooth) / (|A| + |B| + smooth).
    """
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)

def dice_coef_loss(y_true, y_pred):
    """Loss that is minimized when the Dice coefficient is maximized.

    Returns the negated dice_coef. The training cell compiles with binary
    cross-entropy; this function is an alternative loss.
    """
    score = dice_coef(y_true, y_pred)
    return -score

In [None]:
#------------------------------------------------------------
# Start training
#------------------------------------------------------------
# Additive smoothing constant for the Dice coefficient metric.
smooth = 1

# Random augmentation applied to training images AND masks. Both training
# flows use identical parameters and the same seed, relying on Keras drawing
# the same random transform for an image and its mask.
augmentation_args = dict(zoom_range=0.1,
                         rotation_range=5,
                         width_shift_range=0.1,
                         height_shift_range=0.1)

# Images keep their raw 0-255 intensities; masks are rescaled to 0-1 so they
# can serve as binary cross-entropy targets.
image_datagen = image.ImageDataGenerator(**augmentation_args)
mask_datagen  = image.ImageDataGenerator(rescale=1./255., **augmentation_args)

# Validation data is NOT augmented, so validation loss/Dice (and the
# checkpoint file names built from them) measure the model on unperturbed
# data. Preprocessing (no rescale for images, 1/255 for masks) matches training.
valid_image_datagen = image.ImageDataGenerator()
valid_mask_datagen  = image.ImageDataGenerator(rescale=1./255.)

seed = 1
input_shape = (64, 64)
output_shape = (64, 64)
batch_size = 32

def flow_grayscale(datagen, directory, target_size):
    """Stream shuffled, label-free grayscale batches from `directory`.

    Every flow uses the same `seed`, so an image flow and a mask flow over
    folders with the same number of files shuffle in the same order.
    """
    return datagen.flow_from_directory(
        directory,
        class_mode=None,
        seed=seed,
        shuffle=True,
        batch_size=batch_size,
        target_size=target_size,
        color_mode='grayscale')

def check_paired(image_flow, mask_flow):
    """Raise if two flows cannot be paired element-by-element."""
    if image_flow.samples != mask_flow.samples:
        raise ValueError("Found %d images but %d masks; zip() pairing would be wrong."
                         % (image_flow.samples, mask_flow.samples))

image_generator = flow_grayscale(image_datagen, 'data/train/image', input_shape)
mask_generator = flow_grayscale(mask_datagen, 'data/train/mask', output_shape)
check_paired(image_generator, mask_generator)
train_generator = zip(image_generator, mask_generator)

valid_image_generator = flow_grayscale(valid_image_datagen, 'data/valid/image', input_shape)
valid_mask_generator = flow_grayscale(valid_mask_datagen, 'data/valid/mask', output_shape)
check_paired(valid_image_generator, valid_mask_generator)
valid_generator = zip(valid_image_generator, valid_mask_generator)

class WeightsRecorder(tensorflow.keras.callbacks.Callback):
    """Keras callback that appends one CSV row of metrics per epoch.

    Despite the name, it records the epoch number (1-based), the wall-clock
    duration of the epoch, training/validation loss and training/validation
    Dice. It does not record model weights. Rows are appended; the header
    is expected to be written by the caller.
    """

    # Column order of each appended row.
    FIELDNAMES = ["epoch", "time(s)", "training_loss", "validation_loss",
                  "training_dice", "validation_dice"]

    def __init__(self, progressFilePath):
        """Args: progressFilePath: path of the CSV file to append rows to."""
        super(WeightsRecorder, self).__init__()
        self.progressFilePath = progressFilePath
        self.lastTimePoint = time.time()

    def on_train_begin(self, logs=None):
        # Restart the clock when fit() starts, so the first epoch's duration
        # excludes model construction/compile time and idle time before fit().
        self.lastTimePoint = time.time()

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        now = time.time()
        row = {
            "epoch": epoch + 1,  # Keras passes a 0-based epoch index
            "time(s)": "%0.1f" % (now - self.lastTimePoint),
            # .get(): a missing metric (e.g. no validation run) leaves an
            # empty cell instead of aborting training from a logging callback.
            "training_loss": logs.get("loss"),
            "validation_loss": logs.get("val_loss"),
            "training_dice": logs.get("dice_coef"),
            "validation_dice": logs.get("val_dice_coef"),
        }
        self.lastTimePoint = now
        with open(self.progressFilePath, "a") as outputFile:
            writer = csv.DictWriter(outputFile, lineterminator='\n', fieldnames=self.FIELDNAMES)
            writer.writerow(row)


outputDirPath = "training_output"
os.makedirs(outputDirPath, exist_ok=True)

# Column order shared by the header written here and the rows appended per epoch.
progressFieldnames = ["epoch","time(s)","training_loss","validation_loss", "training_dice","validation_dice"]
progressFilePath = os.path.join(outputDirPath, "liver_2D_training_progress.csv")
if not os.path.isfile(progressFilePath):
    # Header is written only once; rows from repeated runs accumulate.
    with open(progressFilePath, "w") as outputFile:
        writer = csv.DictWriter(outputFile, lineterminator='\n', fieldnames=progressFieldnames)
        writer.writeheader()
recorder = WeightsRecorder(progressFilePath)

# Save the full model after every epoch; the file name records the epoch and
# that epoch's validation loss / Dice.
weight_saver = ModelCheckpoint(
    os.path.join(outputDirPath, 'liver_model.{epoch:02d}-{val_loss:.2f}-{val_dice_coef:.2f}.h5'),
    save_best_only=False, save_weights_only=False)
callbackList = [recorder, weight_saver]

model = get_simple_unet(input_shape)
model.compile(optimizer=Adam(learning_rate=2e-5),
              loss=tensorflow.keras.losses.binary_crossentropy,
              metrics=[dice_coef])

# Model.fit accepts Python generators directly; fit_generator is deprecated.
# The generators loop forever, so steps_per_epoch / validation_steps define
# how many batches make up an epoch / a validation pass.
hist = model.fit(train_generator,
                 validation_data=valid_generator,
                 validation_steps=20,
                 steps_per_epoch=100,
                 epochs=100,
                 callbacks=callbackList)

In [None]:
#------------------------------------------------------------
# Plot history: cross-entropy
#------------------------------------------------------------
fig, ax = plt.subplots()
ax.plot(hist.history['loss'], label='loss (training data)')
ax.plot(hist.history['val_loss'], label='loss (validation data)')
ax.set(title='Binary cross-entropy per epoch', xlabel='No. epoch', ylabel='loss value')
ax.legend(loc="upper right")
plt.show()

In [None]:
#------------------------------------------------------------
# Plot history: dice coeff
#------------------------------------------------------------
fig, ax = plt.subplots()
ax.plot(hist.history['dice_coef'], label='dice (training data)')
ax.plot(hist.history['val_dice_coef'], label='dice (validation data)')
ax.set(title='Dice coefficient per epoch', xlabel='No. epoch', ylabel='dice value')
ax.legend(loc="upper left")
plt.show()

In [None]:
#------------------------------------------------------------
# Helper functions for visualizing segmentations
#------------------------------------------------------------
def getBGRWithOverlay(image, mask, alpha=0.5, color=(0,255,0)):
    """Blend a solid color over the pixels of a 3-channel image where mask != 0.

    Args:
        image: array of shape (H, W, C), C >= 3; it is not modified.
        mask: (H, W) array; every non-zero (or True) pixel is tinted.
        alpha: weight of the tinted copy in the blend (1 - alpha for the original).
        color: per-channel values for the first three channels (BGR order
            when the image is BGR).

    Returns:
        A new array: alpha * tinted + (1 - alpha) * image, computed by cv2.addWeighted.
    """
    output = image.copy()
    tinted = image.copy()
    selected = mask != 0
    for channel in range(3):
        tinted[:, :, channel][selected] = color[channel]
    # Writes the blend into `output` in place (dst argument).
    cv2.addWeighted(tinted, alpha, output, 1-alpha, 0, output)
    return output
def getGrayWithOverlay(image, mask, alpha=0.5, color=(0,255,0)):
    """Replicate a single-channel image into 3 channels and tint it where mask != 0.

    Arguments have the same meaning as in getBGRWithOverlay; returns the blended 3-channel uint8 image.
    """
    return getBGRWithOverlay(grayToBGR(image), mask, alpha=alpha, color=color)

def grayToBGR(gray):
    """Stack a 2-D grayscale array into a (H, W, 3) uint8 image.

    All three channels receive the same values; non-uint8 input is cast to
    uint8 on assignment (unsafe cast, e.g. floats are truncated).
    """
    rows, cols = gray.shape[0], gray.shape[1]
    bgr = np.zeros((rows, cols, 3), np.uint8)
    # Broadcast the (H, W, 1) view across the three channels in one assignment.
    bgr[...] = gray[:, :, np.newaxis]
    return bgr

In [None]:
#------------------------------------------------------------
# Load training set for visualization
#------------------------------------------------------------
# The PNGs sit in a single "dummy_class" subfolder (presumably because the
# training cell reads them with flow_from_directory, which expects class folders).
imgDirPath_train = "data/train/image/dummy_class"
imgFilePathList_train = sorted(glob.glob(os.path.join(imgDirPath_train, "*.png")))

maskDirPath_train = "data/train/mask/dummy_class"
maskFilePathList_train = sorted(glob.glob(os.path.join(maskDirPath_train, "*.png")))

# Images and masks are paired purely by position in these sorted lists, so
# an empty folder or a count mismatch would silently pair the wrong files.
if not imgFilePathList_train:
    raise FileNotFoundError("No training images found in %s" % imgDirPath_train)
if len(imgFilePathList_train) != len(maskFilePathList_train):
    raise ValueError("Found %d images but %d masks"
                     % (len(imgFilePathList_train), len(maskFilePathList_train)))

In [None]:
#------------------------------------------------------------
# Load the weights from an early epoch
#------------------------------------------------------------
# Checkpoint names embed that run's val_loss / val_dice, which differ on
# every run, so find the file by epoch number instead of hardcoding it.
# The path is relative, like the training cell's "training_output" directory.
liverWeightsPattern = os.path.join("training_output", "liver_model.05-*.h5")
liverWeightsCandidates = glob.glob(liverWeightsPattern)
if not liverWeightsCandidates:
    raise FileNotFoundError("No checkpoint matches %s - run the training cell first." % liverWeightsPattern)
# Several runs can each leave an epoch-05 file; use the most recently written one.
liverWeightsFilePath = max(liverWeightsCandidates, key=os.path.getmtime)
# NOTE(review): training uses 64x64 inputs (input_shape in the training cell)
# but this model is built for 256x256. The network is fully convolutional, so
# the weights load, but anatomy spans 4x more pixels than during training -
# confirm this is intended.
modelLiver = get_simple_unet((256,256))
modelLiver.load_weights(liverWeightsFilePath)

In [None]:
#------------------------------------------------------------
# Show an example segmentation
#------------------------------------------------------------
img_index = 250
imgFilePath = imgFilePathList_train[img_index]
# Read as single-channel 8-bit (flag 0 = grayscale) and resize to the model's
# input size. Intensities stay in 0-255, matching the training image
# generator, which applies no rescale.
img = cv2.imread(imgFilePath, 0)
if img is None:  # cv2.imread signals failure by returning None, not raising
    raise FileNotFoundError("Could not read image %s" % imgFilePath)
img = cv2.resize(img, (256,256))
imgOrig = img.copy()
img = img[np.newaxis, :, :, np.newaxis]  # add batch and channel axes: (1, 256, 256, 1)

# predicted mask: threshold the sigmoid output at 0.5
resultLiver = modelLiver.predict(img)[0,:,:,0]
predictedmaskLiver = resultLiver > 0.5

# ground truth mask
maskFilePath = maskFilePathList_train[img_index]
mask_orig = cv2.imread(maskFilePath, 0)
if mask_orig is None:
    raise FileNotFoundError("Could not read mask %s" % maskFilePath)
# Nearest-neighbour keeps label values intact; the default (bilinear) blends
# boundary pixels into intermediate values that the overlay's `mask != 0`
# test would count as foreground, enlarging the ground-truth region.
mask_orig = cv2.resize(mask_orig, (256,256), interpolation=cv2.INTER_NEAREST)

overlay_predicted = getGrayWithOverlay(imgOrig, predictedmaskLiver, alpha=0.5, color=(0,255,0))
overlay_groundtruth = getGrayWithOverlay(imgOrig, mask_orig, alpha=0.5, color=(0,255,0))

# Side by side with titles so the two overlays cannot be confused. The
# overlays are BGR, but gray + pure green renders the same when matplotlib
# interprets them as RGB.
fig, (ax_pred, ax_truth) = plt.subplots(1, 2, figsize=(10, 5))
ax_pred.imshow(overlay_predicted)
ax_pred.set_title("Predicted mask")
ax_pred.axis("off")
ax_truth.imshow(overlay_groundtruth)
ax_truth.set_title("Ground-truth mask")
ax_truth.axis("off")
plt.show()
print(imgFilePath)

In [None]:
#------------------------------------------------------------
# Load the weights from a late epoch
#------------------------------------------------------------
# Checkpoint names embed that run's val_loss / val_dice, which differ on
# every run, so find the file by epoch number instead of hardcoding it.
# The path is relative, like the training cell's "training_output" directory.
liverWeightsPattern = os.path.join("training_output", "liver_model.99-*.h5")
liverWeightsCandidates = glob.glob(liverWeightsPattern)
if not liverWeightsCandidates:
    raise FileNotFoundError("No checkpoint matches %s - run the training cell first." % liverWeightsPattern)
# Several runs can each leave an epoch-99 file; use the most recently written one.
liverWeightsFilePath = max(liverWeightsCandidates, key=os.path.getmtime)
# NOTE(review): training uses 64x64 inputs (input_shape in the training cell)
# but this model is built for 256x256. The network is fully convolutional, so
# the weights load, but anatomy spans 4x more pixels than during training -
# confirm this is intended.
modelLiver = get_simple_unet((256,256))
modelLiver.load_weights(liverWeightsFilePath)

In [None]:
#------------------------------------------------------------
# Show an example segmentation
#------------------------------------------------------------
img_index = 250
imgFilePath = imgFilePathList_train[img_index]
# Read as single-channel 8-bit (flag 0 = grayscale) and resize to the model's
# input size. Intensities stay in 0-255, matching the training image
# generator, which applies no rescale.
img = cv2.imread(imgFilePath, 0)
if img is None:  # cv2.imread signals failure by returning None, not raising
    raise FileNotFoundError("Could not read image %s" % imgFilePath)
img = cv2.resize(img, (256,256))
imgOrig = img.copy()
img = img[np.newaxis, :, :, np.newaxis]  # add batch and channel axes: (1, 256, 256, 1)

# predicted mask: threshold the sigmoid output at 0.5
resultLiver = modelLiver.predict(img)[0,:,:,0]
predictedmaskLiver = resultLiver > 0.5

# ground truth mask
maskFilePath = maskFilePathList_train[img_index]
mask_orig = cv2.imread(maskFilePath, 0)
if mask_orig is None:
    raise FileNotFoundError("Could not read mask %s" % maskFilePath)
# Nearest-neighbour keeps label values intact; the default (bilinear) blends
# boundary pixels into intermediate values that the overlay's `mask != 0`
# test would count as foreground, enlarging the ground-truth region.
mask_orig = cv2.resize(mask_orig, (256,256), interpolation=cv2.INTER_NEAREST)

overlay_predicted = getGrayWithOverlay(imgOrig, predictedmaskLiver, alpha=0.5, color=(0,255,0))
overlay_groundtruth = getGrayWithOverlay(imgOrig, mask_orig, alpha=0.5, color=(0,255,0))

# Side by side with titles so the two overlays cannot be confused. The
# overlays are BGR, but gray + pure green renders the same when matplotlib
# interprets them as RGB.
fig, (ax_pred, ax_truth) = plt.subplots(1, 2, figsize=(10, 5))
ax_pred.imshow(overlay_predicted)
ax_pred.set_title("Predicted mask")
ax_pred.axis("off")
ax_truth.imshow(overlay_groundtruth)
ax_truth.set_title("Ground-truth mask")
ax_truth.axis("off")
plt.show()
print(imgFilePath)