In [None]:
import os
import random

import functools
import tensorflow as tf
from keras.models import Model, load_model
from keras.layers import Input, BatchNormalization, Activation, Dropout, Lambda, SpatialDropout2D
from keras.layers.convolutional import Conv2D, UpSampling2D
from keras.layers.pooling import MaxPooling2D
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras import optimizers
from keras import backend as K
from keras.layers.merge import concatenate

from keras.losses import BinaryCrossentropy
import matplotlib.pyplot as plt

In [None]:
# List the contents of the input dataset directory (IPython shell escape).
# NOTE(review): the training cell reads from ../input/imdata/ and ../input/segdata/,
# not from this directory — confirm this is the intended path.
!ls ../input/imagedata/

In [None]:
import nibabel as nib
import math
import cv2

In [None]:
import numpy as np
import keras
#from datagenerator import SequenceData
from tensorflow.keras.utils import Sequence
from keras.callbacks import EarlyStopping, ModelCheckpoint

In [None]:
# Data-loading cell: defines the Keras Sequence used to stream image/mask batches.
# NOTE(review): a "coding=" declaration only takes effect on the first two lines of a
# .py source file; inside a notebook cell the line below is an ordinary comment.
# coding=Big5
#trainset_path = 'trainset/C2_TrainDev/Train/'
print("***csv_dev :2***")  # prints a marker string when this cell runs
class SequenceData(keras.utils.Sequence):
    """Keras ``Sequence`` that serves batches of (image, mask) pairs read from disk.

    Every entry of ``im_list`` is a file name expected to exist in both ``im_dir``
    and ``label_dir``. Images are read as grayscale and scaled to [0, 1]; masks are
    read as grayscale and binarized (any nonzero pixel becomes 1).

    Args:
        model: Free-form tag stored on the instance (callers pass 'train' / 'val');
            not used for loading.
        im_dir: Directory holding the input images.
        label_dir: Directory holding the masks, stored under the same file names.
        im_list: File names to serve.
        target_size: Expected image shape; only the first two entries (H, W) are
            used. Images and masks must already have that size (no resizing is done).
        batch_size: Samples per batch; the final batch may be smaller.
        shuffle: Reshuffle the sample order at the end of each epoch. The first
            epoch uses the order of ``im_list`` as given.
    """

    def __init__(self, model, im_dir, label_dir, im_list, target_size, batch_size, shuffle=True):
        # Harmless on Keras 2; Keras 3's PyDataset expects the base initializer to run.
        super().__init__()
        self.model = model
        self.im_dataset_path = im_dir
        self.label_dataset_path = label_dir
        self.datasets = im_list
        self.image_size = target_size[0:2]
        self.batch_size = batch_size
        self.indexes = np.arange(len(self.datasets))
        self.shuffle = shuffle

    def __len__(self):
        """Number of batches per epoch (the last batch may be partial)."""
        return math.ceil(len(self.datasets) / float(self.batch_size))

    def __getitem__(self, idx):
        """Return batch ``idx`` as ``(images, masks)``, each shaped (n, H, W, 1)."""
        batch_indexs = self.indexes[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch = [self.datasets[k] for k in batch_indexs]
        return self.data_generation(batch)

    def on_epoch_end(self):
        """Reshuffle the sample order between epochs when ``shuffle`` is enabled."""
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def data_generation(self, batch_datasets):
        """Load the arrays for the given list of file names."""
        return self.get_data(batch_datasets)

    @staticmethod
    def _read_gray(directory, name):
        """Read one grayscale image, raising instead of propagating cv2's ``None``."""
        path = os.path.join(directory, name)
        image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        return image

    def get_data(self, imgs):
        """Read and normalize the images and masks named in ``imgs``.

        Returns:
            ``(images, masks)``: float arrays of shape (len(imgs), H, W, 1), where
            (H, W) is ``self.image_size``. Images are in [0, 1]; masks are in {0, 1}.

        Raises:
            FileNotFoundError: if an image or mask file cannot be read.
            ValueError: if a file's size does not match ``self.image_size``.
        """
        height, width = self.image_size
        Img_list = np.zeros((len(imgs), height, width, 1))
        Gt_list = np.zeros((len(imgs), height, width, 1))

        for im_index, img in enumerate(imgs):
            gray_img = self._read_gray(self.im_dataset_path, img)
            label = self._read_gray(self.label_dataset_path, img)

            Img_list[im_index] = (gray_img / 255.0)[..., np.newaxis]
            # ceil maps every nonzero (possibly anti-aliased) mask pixel to 1.
            Gt_list[im_index] = np.ceil(label / 255.0)[..., np.newaxis]

        return Img_list, Gt_list


In [None]:
def _plot_history_curve(hist, key, title, ylabel, path):
    """Plot ``hist[key]`` (plus ``hist['val_' + key]`` when present) and save to ``path``."""
    fig, ax = plt.subplots()
    ax.plot(hist[key])
    legend = ['Train']
    val_key = 'val_' + key
    if val_key in hist:
        ax.plot(hist[val_key])
        legend.append('Val')
    ax.set_title(title)
    ax.set_ylabel(ylabel)
    ax.set_xlabel('Epoch')
    ax.legend(legend, loc='upper left')
    fig.savefig(path)
    return fig


def draw(history):
    """Plot per-epoch curves for loss, dice_coef, precision and recall and save them.

    Creates one figure per quantity and saves it as a PNG in the current directory.
    The training curve is always drawn; the validation curve is drawn only when the
    history contains the matching ``val_`` key (e.g. it is absent when ``fit`` ran
    without validation data).

    Args:
        history: The ``History`` object returned by ``model.fit``.

    Raises:
        KeyError: if one of the training keys is missing from ``history.history``.
    """
    print(history.history.keys())
    # (history key, figure title, y-axis label, output file)
    panels = [
        ('loss', 'Model loss', 'Loss', "./Loss_dice_s.png"),
        ('dice_coef', 'Model Dice_coef', 'Dice_coef', "./Dice_dice_s.png"),
        ('precision', 'Model precision', 'precision', "./precision_dice_s.png"),
        ('recall', 'Model recall', 'recall', "./recall_dice_s.png"),
    ]
    for key, title, ylabel, path in panels:
        _plot_history_curve(history.history, key, title, ylabel, path)

In [None]:
def Convolution(input_tensor, filters, drop = 0.0):
    """Apply one U-Net conv block: 3x3 'same' Conv2D -> BatchNorm -> ReLU -> SpatialDropout2D.

    Args:
        input_tensor: Keras tensor to transform.
        filters: Number of output channels of the convolution.
        drop: SpatialDropout2D rate (a dropout layer is added even when it is 0.0).

    Returns:
        The output Keras tensor.
    """
    conv = Conv2D(filters=filters, kernel_size=(3, 3), strides=(1, 1), padding='same')
    out = conv(input_tensor)
    out = BatchNormalization()(out)
    out = Activation('relu')(out)
    return SpatialDropout2D(drop)(out)

def unet(input_shape):
    """Build the segmentation U-Net used in this notebook.

    Architecture:
    - Encoder: four levels of two 3x3 conv blocks (16, 32, 64 and 128 filters), each
      level followed by 2x2 max-pooling.
    - Bottleneck: two 256-filter conv blocks.
    - Decoder: four levels, each doing 2x upsampling -> conv block -> concatenation
      with the matching encoder output -> two conv blocks.
    - Head: a 1x1 conv with softmax produces 2 output channels.

    Every conv block uses SpatialDropout2D(0.5), except the four blocks of the two
    shallowest encoder levels, which use rate 0.0.

    NOTE(review): weights are loaded from .h5 files elsewhere in this notebook; keep
    the order in which layers are created stable so saved weights still line up.

    Args:
        input_shape: Shape of one input sample, e.g. (240, 240, 1). With four
            poolings, H and W presumably need to be divisible by 16 for the skip
            concatenations to match — TODO confirm for other sizes.

    Returns:
        An uncompiled ``keras.Model``.
    """
    inputs = Input(input_shape)

    encoder_filters = (16, 32, 64, 128)
    encoder_drops = (0.0, 0.0, 0.5, 0.5)

    skips = []
    x = inputs
    for n_filters, rate in zip(encoder_filters, encoder_drops):
        x = Convolution(x, n_filters, rate)
        x = Convolution(x, n_filters, rate)
        skips.append(x)
        x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')(x)

    # Bottleneck.
    x = Convolution(x, 256, 0.5)
    x = Convolution(x, 256, 0.5)

    # Decoder walks the encoder levels from deepest to shallowest.
    for n_filters, skip in zip(reversed(encoder_filters), reversed(skips)):
        x = UpSampling2D((2, 2))(x)
        x = Convolution(x, n_filters, 0.5)
        x = concatenate([x, skip])
        x = Convolution(x, n_filters, 0.5)
        x = Convolution(x, n_filters, 0.5)

    outputs = Conv2D(2, (1, 1), activation='softmax')(x)
    return Model(inputs=[inputs], outputs=[outputs])


def dice_coef(y_true, y_pred, smooth=1.0):
    """Hard Dice coefficient on channel 0 of 4-D (batch, H, W, C) tensors.

    Both tensors are rounded to the nearest integer before scoring (values are
    presumably in [0, 1], giving {0, 1} masks). ``smooth`` is accepted but unused;
    ``K.epsilon()`` guards the division instead, so two empty masks score 1.
    """
    truth = K.round(K.flatten(y_true[:, :, :, 0]))
    pred = K.round(K.flatten(y_pred[:, :, :, 0]))
    overlap = K.sum(truth * pred)
    denom = K.sum(truth) + K.sum(pred)
    return (2. * overlap + K.epsilon()) / (denom + K.epsilon())

# Computing Precision
def precision(y_true, y_pred):
    """Pixel precision on channel 0: true positives / (predicted positives + eps).

    Both counts are computed after clipping to [0, 1] and rounding, so a pixel counts
    as predicted positive when its channel-0 value rounds to 1.
    """
    truth = y_true[:, :, :, 0]
    pred = y_pred[:, :, :, 0]
    tp = K.sum(K.round(K.clip(truth * pred, 0, 1)))
    predicted = K.sum(K.round(K.clip(pred, 0, 1)))
    return tp / (predicted + K.epsilon())

# Computing Sensitivity
def recall(y_true, y_pred):
    """Pixel recall (sensitivity) on channel 0: true positives / (actual positives + eps).

    Both counts are computed after clipping to [0, 1] and rounding.
    """
    truth = y_true[:, :, :, 0]
    pred = y_pred[:, :, :, 0]
    tp = K.sum(K.round(K.clip(truth * pred, 0, 1)))
    positives = K.sum(K.round(K.clip(truth, 0, 1)))
    return tp / (positives + K.epsilon())
    
def dice_loss(y_true, y_pred, smooth=1.0):
    """Soft Dice loss ``1 - Dice`` on channel 0.

    Only the target is rounded; the prediction is kept continuous so the loss stays
    differentiable. ``smooth`` is accepted but unused; ``K.epsilon()`` guards the
    division instead (two empty masks give a loss of 0).
    """
    target = K.round(K.flatten(y_true[:, :, :, 0]))
    prob = K.flatten(y_pred[:, :, :, 0])
    overlap = K.sum(target * prob)
    dice = (2. * overlap + K.epsilon()) / (K.sum(target) + K.sum(prob) + K.epsilon())
    return 1 - dice

    
def dice_bin_loss(y_true, y_pred):
    """Binary cross-entropy plus soft Dice loss, both computed on channel 0.

    ``unet`` ends in a softmax, so ``y_pred`` holds probabilities rather than logits
    and BCE must use ``from_logits=False``. Channel 0 is selected from both tensors
    (keeping the channel axis) so that a 1-channel target mask and the 2-channel
    prediction have matching shapes; comparing them whole makes the cross-entropy
    fail on the shape mismatch.
    """
    bce = BinaryCrossentropy(from_logits=False)(y_true[:, :, :, 0:1], y_pred[:, :, :, 0:1])

    return bce + dice_loss(y_true, y_pred)

def focal_loss(y_true, y_pred, gamma=2., alpha=.25):
    """Alpha-balanced binary focal loss on channel 0.

    Like the other losses in this notebook, channel 0 is selected before flattening.
    Flattening every channel would produce tensors of different lengths for a
    1-channel mask and the 2-channel softmax output of ``unet``, which ``tf.where``
    cannot combine.

    NOTE(review): by operator precedence, ``/ 200000`` scales only the negative-class
    term, not the whole loss. This is kept as-is; confirm whether the entire sum was
    meant to be normalized.
    """
    y_true = K.flatten(y_true[:, :, :, 0])
    y_pred = K.flatten(y_pred[:, :, :, 0])

    # Non-positive pixels get pt_1 = 1 and non-negative pixels get pt_0 = 0; the
    # (1 - pt_1)**gamma and pt_0**gamma factors then zero out their contributions.
    pt_1 = tf.where(tf.equal(y_true, 1), y_pred, tf.ones_like(y_pred))
    pt_0 = tf.where(tf.equal(y_true, 0), y_pred, tf.zeros_like(y_pred))
    pos_term = K.sum(alpha * K.pow(1. - pt_1, gamma) * K.log(pt_1 + K.epsilon()))
    neg_term = K.sum((1 - alpha) * K.pow(pt_0, gamma) * K.log(1. - pt_0 + K.epsilon()))
    return -pos_term - neg_term / 200000

def gdice(y_true, y_pred):
    """Generalized Dice loss for a single foreground class, on channel 0.

    Channel 0 is selected before flattening so a 1-channel mask and the 2-channel
    softmax output of ``unet`` produce equal-length vectors; otherwise the
    element-wise product below would have mismatched shapes. With a single class,
    the ``1 / (sum y_true)^2`` weight multiplies numerator and denominator alike,
    so it mainly changes how much the epsilon terms matter.
    """
    y_true_f = K.flatten(y_true[:, :, :, 0])
    y_pred_f = K.flatten(y_pred[:, :, :, 0])
    weights = 1. / K.square(K.sum(y_true_f) + 1e-6)
    num = weights * K.sum(y_true_f * y_pred_f)
    den = weights * K.sum(y_true_f + y_pred_f)
    return 1 - 2. * (num + K.epsilon()) / (den + K.epsilon())
    


In [None]:
    
    # Training cell: build the U-Net, attach L2 weight decay, train on the
    # train/validation split, then save the model and the learning-curve plots.
    EPOCH = 40
    im_path = '../input/imdata/im/'
    label_path = '../input/segdata/seg/'
    # Only referenced by the commented-out load_weights call below.
    pretrained_weight = '../input/smoothdice40/Agdice100_s.h5'

    BATCH_SIZE = 16

    input_shape = (240, 240, 1)

    model = unet(input_shape=input_shape)

    # lr = 1e-4 * 0.2 ** (step / decay_steps), decayed continuously (staircase=False).
    # NOTE(review): 2496 looks like the number of training batches per epoch, which
    # would make decay_steps equal to 15 epochs — confirm against len(train_generator).
    lr_schedule = optimizers.schedules.ExponentialDecay(
        initial_learning_rate=1e-4,
        decay_steps=2496 * 15,
        decay_rate=0.2)

    Adam = optimizers.Adam(learning_rate=lr_schedule)

    alpha = 0.0001  # L2 weight-decay coefficient

    # Assigning a regularizer to a layer that is already built does not register any
    # loss for its existing weights; it only changes what the layer reports in its
    # config. Rebuilding the model from that config creates fresh layers that do apply
    # the penalties. (Weights are re-initialized, which is fine because nothing has
    # been trained or loaded yet.)
    for layer in model.layers:
        if isinstance(layer, keras.layers.Conv2D):
            layer.kernel_regularizer = keras.regularizers.l2(alpha)
        if hasattr(layer, 'bias_regularizer') and layer.use_bias:
            layer.bias_regularizer = keras.regularizers.l2(alpha)
    model = keras.models.model_from_json(model.to_json())

    model.compile(optimizer=Adam, loss=dice_loss, metrics=[dice_coef, precision, recall])

    # Split the name-sorted file list: first 20% -> validation, next ~70% -> training.
    # The final test_len files (10%) are left out for testing.
    dataset = sorted(os.listdir(im_path))

    val_len = int(len(dataset) * 0.2)
    test_len = int(len(dataset) * 0.1)
    train_len = len(dataset) - val_len - test_len

    val_dataset = dataset[:val_len]
    train_dataset = dataset[val_len:val_len + train_len]
    random.shuffle(train_dataset)  # unseeded: the training order differs between runs

    print(len(train_dataset), len(val_dataset))

    train_generator = SequenceData('train', im_path, label_path, train_dataset, input_shape, BATCH_SIZE)
    val_generator = SequenceData('val', im_path, label_path, val_dataset, input_shape, BATCH_SIZE)

    #model.load_weights(pretrained_weight)
    model.summary()

    es = EarlyStopping(monitor='val_loss',
                       mode='min',
                       patience=20,
                       restore_best_weights=True)

    history = model.fit(train_generator,
                        epochs=EPOCH,
                        steps_per_epoch=len(train_generator),
                        validation_data=val_generator,
                        validation_steps=len(val_generator),
                        callbacks=[es])
    weight_name = 'final_s.h5'
    model.save(weight_name)
    print(f'Model saved! {weight_name}')
    print(history.history['loss'])
    draw(history)
    
    

In [None]:

# Path to the BraTS2020 validation data (per the directory name).
# NOTE(review): neither this path nor output_path is referenced elsewhere in the
# visible notebook — confirm whether they are still needed.
m_dataset_path = '../input/brats20-dataset-training-validation/BraTS2020_ValidationData/MICCAI_BraTS2020_ValidationData/'

output_path = './output/Test/'

def resize_img(img):
    """Convert a 2-D intensity array into a (1, H, W, 1) model input scaled to [0, 1].

    Values are rounded to the nearest integer and clipped to [0, 255] before the
    uint8 cast. Without the clip, out-of-range values would not saturate: NumPy's
    float-to-uint8 conversion wraps or is undefined for values outside [0, 255].
    Despite the name, no resizing is performed.

    Args:
        img: 2-D array of intensities.

    Returns:
        float array of shape (1, H, W, 1) with values in [0, 1].
    """
    img = np.clip(np.rint(img), 0, 255).astype(np.uint8)
    img = img / 255.0
    return img.reshape((1, img.shape[0], img.shape[1], 1))

def test_img(img_name, model, im_path, label_path):
    
    
    #print("Predict ", img_name)
    #####   Read image #####
    img = cv2.imread(im_path + img_name, cv2.IMREAD_GRAYSCALE)
    label = cv2.imread(label_path + img_name, cv2.IMREAD_GRAYSCALE)
    img = img / 255.0
    label = np.ceil(label / 255)
    
    
    img = img.reshape((1, img.shape[0], img.shape[1], 1))
    print(img.shape)
    label = label.reshape((1, label.shape[0], label.shape[1], 1))
    #label = np.ceil(label)
    #model.evaluate(img) 
    mask = model.predict(img)
    
    mask = np.round(mask)
    
    TP = np.sum(mask * label)
    AP = np.sum(label)
    PP = np.sum(mask)
    
    print('Predict ', img_name, ' dice: ', TP,' ', AP,' ', PP,' ', 2 * (TP + 1e-8) / (AP + PP + 1e-8))
   
    return TP, AP, PP
    


In [None]:
    # Evaluation cell: per-volume (3D) Dice / precision / recall on the held-out test slices.
    # Disabled by default (it needs trained weights); set RUN_EVALUATION = True to run it.
    RUN_EVALUATION = False
    # Assumes each volume is stored as this many consecutive name-sorted slice images — TODO confirm.
    SLICES_PER_VOLUME = 155


    def summarize_volumes(slice_stats, slices_per_volume=155):
        """Aggregate per-slice counts into per-volume scores and average them.

        Consecutive runs of ``slices_per_volume`` entries form one volume (a shorter
        trailing run is scored as its own volume). For each volume, the TP / AP / PP
        counts of all its slices are summed, and Dice, precision and recall are
        computed from those sums and printed. The per-volume scores are then averaged.

        Args:
            slice_stats: Iterable of ``(TP, AP, PP)`` tuples in slice order, e.g. as
                returned by ``test_img``.
            slices_per_volume: Number of slices that make up one volume.

        Returns:
            ``(mean_dice, mean_precision, mean_recall)``; all NaN when there are no slices.
        """
        stats = list(slice_stats)
        dices, precisions, recalls = [], [], []
        for start in range(0, len(stats), slices_per_volume):
            chunk = stats[start:start + slices_per_volume]
            tp = sum(s[0] for s in chunk)
            ap = sum(s[1] for s in chunk)
            pp = sum(s[2] for s in chunk)

            dice3D = (2 * tp + 1e-8) / (ap + pp + 1e-8)
            precision3D = tp / (pp + 1e-8)
            recall3D = tp / (ap + 1e-8)
            print('dice3D: ', dice3D, ' precision3D: ', precision3D, ' recall3D: ', recall3D)

            dices.append(dice3D)
            precisions.append(precision3D)
            recalls.append(recall3D)

        if not dices:
            return math.nan, math.nan, math.nan
        n_volumes = len(dices)
        return sum(dices) / n_volumes, sum(precisions) / n_volumes, sum(recalls) / n_volumes


    if RUN_EVALUATION:
        im_path = '../input/imdata/im/'
        label_path = '../input/segdata/seg/'
        weights_path = '../input/smoothdice40/Agdice100_s.h5'

        dataset = sorted(os.listdir(im_path))
        # NOTE(review): hard-coded offset — slices from index 332 * 155 onward form the
        # test set; keep this in sync with the train/val/test split of the training cell.
        test_dataset = dataset[332 * 155:]
        print(len(test_dataset))

        model = unet(input_shape=(240, 240, 1))
        model.compile(optimizer=optimizers.Adam(learning_rate=1e-4), loss=dice_loss,
                      metrics=[dice_coef, precision, recall])
        model.load_weights(weights_path)

        slice_stats = [test_img(name, model, im_path, label_path) for name in test_dataset]
        meandice, meanprecision, meanrecall = summarize_volumes(slice_stats, SLICES_PER_VOLUME)
        print("result: ", meandice, ' meanprecision: ', meanprecision, ' meanrecall: ', meanrecall)
    

In [None]:
    # Visualization cell: overlay the predicted and ground-truth masks for one slice.
    # Disabled by default (it needs trained weights); set RUN_VISUALIZATION = True to run it.
    RUN_VISUALIZATION = False


    def visualize_prediction(model, img_name, im_path='../input/imdata/im/',
                             label_path='../input/segdata/seg/', out_prefix='./TEST'):
        """Draw a 2x2 panel of input, prediction, ground truth and overlap, then save it.

        The prediction is channel 0 of ``model.predict``, rounded to a {0, 255} mask,
        consistent with how the training loss and metrics score the output. The figure
        is saved to ``out_prefix + img_name + '.jpg'`` (so 'x.jpg' becomes
        './TESTx.jpg.jpg' with the default prefix).

        Args:
            model: Trained Keras model taking (1, H, W, 1) inputs.
            img_name: File name present in both ``im_path`` and ``label_path``.
            im_path: Directory holding the input images.
            label_path: Directory holding the masks.
            out_prefix: Prefix of the saved figure path.

        Returns:
            The matplotlib ``Figure``.

        Raises:
            FileNotFoundError: if the image or the mask cannot be read.
        """
        img_file = os.path.join(im_path, img_name)
        label_file = os.path.join(label_path, img_name)
        gray_img = cv2.imread(img_file, cv2.IMREAD_GRAYSCALE)
        label = cv2.imread(label_file, cv2.IMREAD_GRAYSCALE)
        if gray_img is None or label is None:
            missing = img_file if gray_img is None else label_file
            raise FileNotFoundError(f"Could not read image: {missing}")

        batch = (gray_img / 255.0).reshape((1, gray_img.shape[0], gray_img.shape[1], 1))
        predict_mask = np.round(model.predict(batch)[0, :, :, 0]) * 255.0

        # (panel title, overlay the prediction?, overlay the ground truth?)
        panels = [
            ('Origin', False, False),
            ('Predicted Mask', True, False),
            ('Ground Truth', False, True),
            ('Overlap', True, True),
        ]
        fig, axes = plt.subplots(2, 2)
        for ax, (title, show_pred, show_gt) in zip(axes.flat, panels):
            ax.set_title(title)
            ax.axis('off')
            ax.imshow(gray_img, cmap='gray')
            if show_pred:
                ax.imshow(predict_mask, alpha=0.3, cmap='Reds')
            if show_gt:
                ax.imshow(label, alpha=0.3, cmap='Greens')

        fig.savefig(out_prefix + img_name + '.jpg')
        return fig


    if RUN_VISUALIZATION:
        vis_model = unet(input_shape=(240, 240, 1))
        vis_model.load_weights('../input/dice40/Abin_s2.h5')
        visualize_prediction(vis_model, '057011.jpg')