In [None]:
import os
import nibabel as nib
from tqdm import tqdm
import numpy as np
import pandas as pd

# from sklearn.preprocessing import StandardScaler
# from sklearn.decomposition import PCA

# from keras.preprocessing.image import ImageDataGenerator
# from keras.utils import to_categorical
from keras.layers.merge import concatenate
from sklearn.model_selection import train_test_split

import tensorflow as tf
import keras
from keras.models import Model, load_model
from keras.layers import Input, BatchNormalization, Activation
from keras.layers.convolutional import Conv2D, UpSampling2D
from keras.layers.pooling import MaxPooling2D
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras import optimizers
from keras import backend as K

import cv2 as cv2

import matplotlib.pyplot as plt
%matplotlib inline
# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))

In [None]:
def load_data(path, start, stop):
    """Load T1 volumes and their segmentation masks for a range of BraTS 2020 patients.

    param path: directory containing the ``BraTS20_Training_XXX`` patient folders
        (joined with ``os.path.join``, so a trailing separator is optional).
    param start: first patient number to load (inclusive).
    param stop: last patient number to load (inclusive).
    returns:
        data: a single ``int16`` array.  Per patient it holds the rounded T1 volume
            (channel 0) and the segmentation volume (channel 1).  The stacked array is
            ``(patient, channel, v0, v1, v2)`` and is returned with ALL axes reversed
            by ``np.transpose``, i.e. ``(v2, v1, v0, channel, patient)``.
    raises:
        ValueError: if ``stop < start`` (no patient to load).

    Files inside each patient folder are picked by position after sorting by name:
    index 1 is used as the segmentation and index 2 as the T1 volume.  This assumes
    the standard ``*_flair, *_seg, *_t1, *_t1ce, *_t2`` naming — TODO confirm for any
    patient folder whose file names differ.
    """
    if stop < start:
        raise ValueError('load_data: stop ({}) must be >= start ({})'.format(stop, start))

    volumes = []
    for patient_id in tqdm(range(start, stop + 1)):
        patient_dir = os.path.join(path, 'BraTS20_Training_' + str(patient_id).zfill(3))
        file_names = sorted(os.listdir(patient_dir))

        seg = nib.load(os.path.join(patient_dir, file_names[1])).get_fdata()
        t1 = nib.load(os.path.join(patient_dir, file_names[2])).get_fdata()

        # Round and cast each volume as it is loaded, so float64 data for all
        # patients is never stacked at once (same values, much lower peak memory).
        volumes.append([np.rint(t1).astype(np.int16), np.rint(seg).astype(np.int16)])

    return np.transpose(np.array(volumes))

In [None]:
# Root folder of the BraTS 2020 training patients (Kaggle input); load patients 001-020.
path = '../input/brats20-dataset-training-validation/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/'
data = load_data(path, start=1, stop=20)

In [None]:
# Quick check of the array returned by load_data (axes reversed, int16).
(data.shape, data.dtype)

In [None]:
# load_data returns (v2, v1, v0, channel, patient); move the patient axis to the front
# -> (patient, v2, v1, v0, channel).  NOTE: re-running this cell permutes the axes again.
data = data.transpose(4, 0, 1, 2, 3)
print(data.shape)

In [None]:
# Slice 100 of the T1 channel (channel 0) for the patient at index 5.
fig = plt.figure(figsize=(5, 5))
t1_slice = data[5, 100, :, :, 0]
imgplot = plt.imshow(t1_slice, cmap='gray')
plt.show()

In [None]:
def Data_Concatenate(input_data, n_channels=2):
    """Stack the slices of every volume into one 4-D array per channel.

    input_data: non-empty sequence (or array) of per-patient volumes, each indexable
        as ``vol[:, :, :, channel]``.  After the axis reorder above each volume is
        presumably ``(slices, H, W, channels)`` — confirm against ``data.shape``.
    n_channels: how many channels (starting at 0) to extract; default 2 (T1, label).

    Returns a list with one array per channel, each shaped ``(total_slices, H, W, 1)``,
    where all volumes are concatenated along axis 0 in input order.  Works for any
    number of volumes (odd counts included).

    Raises ValueError if ``input_data`` is empty.
    """
    if len(input_data) == 0:
        raise ValueError('Data_Concatenate needs at least one volume')

    output = []
    for channel in range(n_channels):
        # One concatenate over all volumes instead of repeated pairwise growth.
        stacked = np.concatenate([vol[:, :, :, channel] for vol in input_data], axis=0)
        output.append(stacked[:, :, :, np.newaxis])
    return output

In [None]:
# Per-channel slice stacks: indata[0] = T1 slices, indata[1] = label slices.
indata = Data_Concatenate(input_data=data)

In [None]:
# Merge the per-channel stacks into one (slices, H, W, 2) float32 array with plain
# NumPy (no need to route NumPy data through a Keras layer / TF tensor).
AIO = np.concatenate(indata, axis=3).astype(np.float32)

# T1 images (a copy, because later cells overwrite TR in place).
TR = AIO[:, :, :, 0].copy()

# Label channel -> binary whole-tumour mask.  The segmentation holds integer labels
# (BraTS uses several non-zero sub-region labels, e.g. 1/2/4), but the model ends in a
# single sigmoid trained with binary cross-entropy, which needs 0/1 targets.
TRL = (AIO[:, :, :, 1] > 0).astype(np.float32)

In [None]:
# NOTE(review): TR/TRL mix slices of all patients, and train_test_split shuffles them,
# so slices of the same patient land in both train and test. Test scores are likely
# optimistic; a patient-level split would avoid this leakage.
X_train, X_test, Y_train, Y_test = train_test_split(TR, TRL, test_size=0.15, random_state=32)

# Free the merged array only; TRL is still needed by the final ground-truth overlay cell.
AIO = None
print(X_train.shape, Y_train.shape, X_test.shape, Y_test.shape)

In [None]:
# Training slice 190: T1 image (left) next to its label mask (right).
sample_idx = 190
fig, (ax_image, ax_label) = plt.subplots(1, 2, figsize=(15, 8))
ax_image.imshow(X_train[sample_idx], 'gray')
ax_label.imshow(Y_train[sample_idx], 'gray')

In [None]:
# Replace every slice of X_train and X_test by its level-1 stationary wavelet
# transform approximation (cA) using the 'db1' wavelet.
# NOTE: this overwrites the arrays in place, so re-running the cell applies the
# transform a second time; re-run the split cell first.
from pywt import swt2

for images in (X_train, X_test):
    for idx in range(len(images)):
        approx, _details = swt2(data=images[idx], wavelet='db1', level=1)[0]
        images[idx] = approx

In [None]:
# Shape / dtype of one transformed training slice.
(X_train[0].shape, X_train[0].dtype)

## U-Net Model Implementation

In [None]:
def Convolution(input_tensor, filters):
    """3x3 'same' convolution (stride 1) -> batch normalisation -> ReLU.

    input_tensor: Keras tensor to transform.
    filters: number of output feature maps.
    Returns the activated Keras tensor (spatial size unchanged).
    """
    x = input_tensor
    # Layers are instantiated in the same order as before (Conv2D, BN, Activation).
    for layer in (Conv2D(filters=filters, kernel_size=(3, 3), padding='same', strides=(1, 1)),
                  BatchNormalization(),
                  Activation('relu')):
        x = layer(x)
    return x


def model(input_shape):
    """Build the light-weight U-Net-style segmentation network.

    input_shape: shape of one input image, e.g. ``(240, 240, 1)``.
    Returns an uncompiled Keras ``Model`` that maps the image to a one-channel
    sigmoid map of the same height/width.

    Encoder: Convolution blocks with 32, 64, 128, 256 filters, each followed by 2x2
    max-pooling, then a 512-filter bottleneck.  Decoder: four 2x upsampling steps,
    each followed by a Convolution block with 256, 128, 64, 32 filters.  Only two
    skip connections exist: the 128-filter encoder map feeds the 128-filter decoder
    block, and the 32-filter encoder map feeds the 32-filter decoder block.  Height
    and width divisible by 16 keep those concatenations aligned.
    """
    inputs = Input(shape=input_shape)

    # Encoder: keep each stage's pre-pooling activation, keyed by its filter count.
    x = inputs
    encoder_maps = {}
    for n_filters in (32, 64, 128, 256):
        x = Convolution(x, n_filters)
        encoder_maps[n_filters] = x
        x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')(x)

    x = Convolution(x, 512)  # bottleneck

    # Decoder: upsample, concatenate the matching encoder map where a skip exists, convolve.
    skip_levels = (128, 32)
    for n_filters in (256, 128, 64, 32):
        x = UpSampling2D((2, 2))(x)
        if n_filters in skip_levels:
            x = concatenate([x, encoder_maps[n_filters]])
        x = Convolution(x, n_filters)

    outputs = Conv2D(1, (1, 1), activation='sigmoid')(x)
    return Model(inputs=[inputs], outputs=[outputs])

In [None]:
# Build the light-weight CNN for single-channel 240x240 slices.
# The assignment below rebinds `model` from the builder function to a Keras Model,
# so remember the builder first; re-running this cell then builds a fresh model
# instead of trying to call the previous Model instance.
if not isinstance(model, Model):
    build_model = model
model = build_model(input_shape=(240, 240, 1))
# model.summary()

In [None]:
# Computing Dice_Coefficient
def dice_coef(y_true, y_pred, smooth=1.0):
    """Soft Dice coefficient over every element of the batch.

    Returns ``(2 * sum(y_true * y_pred) + smooth) / (sum(y_true) + sum(y_pred) + smooth)``
    computed on the flattened tensors.
    """
    flat_true = K.flatten(y_true)
    flat_pred = K.flatten(y_pred)
    overlap = K.sum(flat_true * flat_pred)
    numerator = 2. * overlap + smooth
    denominator = K.sum(flat_true) + K.sum(flat_pred) + smooth
    return numerator / denominator


def _rounded_count(values):
    """Count of elements that round to 1 after clipping ``values`` to [0, 1]."""
    return K.sum(K.round(K.clip(values, 0, 1)))


# Computing Precision
def precision(y_true, y_pred):
    """Batch precision: rounded true positives / rounded predicted positives."""
    return _rounded_count(y_true * y_pred) / (_rounded_count(y_pred) + K.epsilon())


# Computing Sensitivity
def sensitivity(y_true, y_pred):
    """Batch sensitivity (recall): rounded true positives / rounded actual positives."""
    return _rounded_count(y_true * y_pred) / (_rounded_count(y_true) + K.epsilon())


# Computing Specificity
def specificity(y_true, y_pred):
    """Batch specificity: rounded true negatives / rounded actual negatives."""
    negative_true = 1 - y_true
    negative_pred = 1 - y_pred
    return _rounded_count(negative_true * negative_pred) / (_rounded_count(negative_true) + K.epsilon())

In [None]:
# Compiling the model with Adam (learning rate 0.001) and binary cross-entropy.
# `learning_rate` replaces the deprecated `lr` keyword, which newer Keras
# optimizers reject.
adam = optimizers.Adam(learning_rate=0.001)
model.compile(optimizer=adam, loss='binary_crossentropy', metrics=['accuracy', dice_coef, precision, sensitivity, specificity])

In [None]:
# Train for 40 epochs; Keras holds out the last 20% of X_train / Y_train as the
# validation set (the split arrays were already shuffled by train_test_split).
history = model.fit(
    x=X_train,
    y=Y_train,
    batch_size=32,
    epochs=40,
    validation_split=0.20,
    verbose=1,
    initial_epoch=0,
)

In [None]:
# Evaluate the model on the training and testing data.  Keep both results so the
# training scores are displayed next to the test scores (one row per compiled metric).
train_scores = model.evaluate(x=X_train, y=Y_train, batch_size=32, verbose=1)
test_scores = model.evaluate(x=X_test, y=Y_test, batch_size=32, verbose=1)
pd.DataFrame({'train': train_scores, 'test': test_scores}, index=model.metrics_names)

In [None]:
def _plot_metric(history, key, ylabel, legend_loc):
    """Plot the training and validation curves of one metric from a Keras History.

    history: object returned by ``model.fit``; ``history.history`` must contain
        ``key`` and ``'val_' + key``.
    key: metric name in ``history.history`` (e.g. ``'loss'``).
    ylabel: y-axis label.
    legend_loc: matplotlib legend location.
    """
    plt.plot(history.history[key])
    plt.plot(history.history['val_' + key])
    plt.ylabel(ylabel)
    plt.xlabel('Epoch')
    plt.legend(['Train', 'Validation'], loc=legend_loc)
    plt.subplots_adjust(top=1.00, bottom=0.0, left=0.0, right=0.95, hspace=0.25,
                        wspace=0.35)
    plt.show()


# Accuracy vs Epoch
def Accuracy_Graph(history):
    """Training / validation accuracy per epoch."""
    # Older Keras versions record accuracy under 'acc' instead of 'accuracy'.
    key = 'accuracy' if 'accuracy' in history.history else 'acc'
    _plot_metric(history, key, 'Accuracy', 'lower right')


# Dice Similarity Coefficient vs Epoch
def Dice_coefficient_Graph(history):
    """Training / validation Dice coefficient per epoch."""
    _plot_metric(history, 'dice_coef', 'Dice_Coefficient', 'upper left')


# Precision vs Epoch
def Precision_Graph(history):
    """Training / validation precision per epoch."""
    _plot_metric(history, 'precision', 'Precision', 'lower left')


# Loss vs Epoch
def Loss_Graph(history):
    """Training / validation loss per epoch."""
    _plot_metric(history, 'loss', 'Loss', 'upper right')

In [None]:
# Accuracy, Dice coefficient and loss per epoch on the training data and on the
# validation split held out by model.fit (not the test set).
for plot_curve in (Accuracy_Graph, Dice_coefficient_Graph, Loss_Graph):
    plot_curve(history)

In [None]:
# Persist the trained model to an HDF5 (.h5) file in the working directory.
model.save(filepath='./BraTs2020_swt_db1_l1.h5')

In [None]:
# Replace the weights just trained with pretrained ones from the Kaggle input dataset.
# 'haar' and 'db1' name the same wavelet in PyWavelets — TODO confirm this file was
# trained with the same SWT preprocessing as above.
model.load_weights(filepath='../input/swt-db1-weights/BraTs2020_swt_haar_l1.h5')

In [None]:
# Release the train/test arrays; TR is kept for the prediction cells below.
X_train, X_test, Y_train, Y_test = (0,) * 4

In [None]:
# Slice 210 of the full T1 stack, before the wavelet transform in the next cell
# (default colormap).
fig = plt.figure(figsize=(5, 5))
raw_slice = TR[210]
imgplot = plt.imshow(raw_slice)
plt.show()

In [None]:
from pywt import swt2

# Apply the same level-1 'db1' SWT approximation used for training to every slice of TR.
# NOTE: in place — re-running this cell transforms TR a second time.
for idx in range(len(TR)):
    approx, _details = swt2(data=TR[idx], wavelet='db1', level=1)[0]
    TR[idx] = approx

In [None]:
# Sigmoid tumour-probability map for every slice of TR.
pref_tumor = model.predict(x=TR)

In [None]:
# Slice index shown here and in the next cell.
a = 94
fig, (ax_overlay, ax_mri) = plt.subplots(1, 2, figsize=(15, 10))

# Left: predicted probability map (red, 30% opacity) over the slice.
ax_overlay.set_title('Sample 1')
ax_overlay.axis('off')
ax_overlay.imshow(np.squeeze(TR[a, :, :]), cmap='gray')
ax_overlay.imshow(np.squeeze(pref_tumor[a, :, :]), alpha=0.3, cmap='Reds')

# Right: the same slice without overlay (TR holds the SWT approximation at this point).
ax_mri.set_title('Original MRI')
ax_mri.axis('off')
ax_mri.imshow(np.squeeze(TR[a, :, :]), cmap='gray')

In [None]:
# Ground-truth label (red, 30% opacity) over slice `a` of TR.
# Requires TRL to still be the label array; if an earlier cell reset it
# (e.g. `TRL=0`), the indexing below raises.
plt.figure(figsize=(10, 7))
plt.title('Original MRI with tumor highlighted')
plt.axis('off')
slice_image = np.squeeze(TR[a, :, :])
slice_label = np.squeeze(TRL[a, :, :])
plt.imshow(slice_image, cmap='gray')
plt.imshow(slice_label, alpha=0.3, cmap='Reds')
plt.show()