In [None]:
import os

import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt

import keras
import keras.backend as K
from keras.callbacks import CSVLogger
import tensorflow as tf
from tensorflow.keras.utils import plot_model
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
from tensorflow.keras.optimizers import *
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping, TensorBoard
from tensorflow.keras.layers.experimental import preprocessing
import cv2

In [None]:
import pandas as pd

In [None]:
# Side length, in pixels, that slices are resized to (used as both width and height).
IMG_SIZE = 128

# Segmentation label index -> human-readable class name.
SEGMENT_CLASSES = dict([
    (0, 'NOT tumor'),
    (1, 'NECROTIC/CORE'),
    (2, 'EDEMA'),
    (3, 'ENHANCING'),
])

# Index of the first slice taken from each volume (along the last axis),
# and how many consecutive slices are taken from there.
VOLUME_START_AT = 22
VOLUME_SLICES = 128

In [None]:
# Root folder that contains one sub-directory per case.
# The BRATS_TRAIN_DATASET_PATH environment variable overrides the location so the
# notebook can run on another machine; the previous hard-coded path stays the default.
TRAIN_DATASET_PATH = os.environ.get(
    'BRATS_TRAIN_DATASET_PATH',
    '/Users/sanjaydilli/Downloads/archive/BraTS2021_Training_Data',
)


In [None]:
import os
# Paths of every sub-directory directly under the dataset root.
# `with` closes the scandir iterator deterministically, and sorting removes the
# dependence on the arbitrary order in which scandir yields entries.
with os.scandir(TRAIN_DATASET_PATH) as dir_entries:
    train_and_val_directories = sorted(entry.path for entry in dir_entries if entry.is_dir())

In [None]:
def pathListIntoIds(dirList):
    """Return the last path component of every entry in ``dirList``.

    Args:
        dirList: Iterable of directory paths.

    Returns:
        list[str]: The final component of each path, in input order. A path
        without any separator is returned unchanged.
    """
    # os.path.basename follows the platform's separator rules, whereas searching
    # for a literal '/' returns the whole path when another separator is in use.
    return [os.path.basename(path) for path in dirList]

# Ids (directory names) of every case found under the dataset root.
train_and_test_ids = pathListIntoIds(train_and_val_directories)

# Hold out 20% of all cases for validation, then 15% of the remainder for testing.
# A fixed random_state makes the split reproducible for a given input order, so the
# same cases land in train/val/test after a kernel restart.
train_test_ids, val_ids = train_test_split(train_and_test_ids, test_size=0.2, random_state=42)
train_ids, test_ids = train_test_split(train_test_ids, test_size=0.15, random_state=42)

In [None]:

def dice_coef(y_true, y_pred, smooth=1.0):
    """Mean smoothed Dice coefficient over channels 0-3 of the last axis.

    For each channel the tensors are flattened, the overlap is the sum of their
    element-wise product, and ``smooth`` is added to numerator and denominator.
    The four per-channel scores are averaged.
    """
    n_channels = 4

    def channel_score(channel):
        truth = K.flatten(y_true[:, :, :, channel])
        guess = K.flatten(y_pred[:, :, :, channel])
        overlap = K.sum(truth * guess)
        return (2. * overlap + smooth) / (K.sum(truth) + K.sum(guess) + smooth)

    running_sum = channel_score(0)
    for channel in range(1, n_channels):
        running_sum = running_sum + channel_score(channel)
    return running_sum / n_channels

 

def _dice_coef_for_channel(y_true, y_pred, channel, epsilon):
    """Dice coefficient of one channel of the last axis, with squared sums in the denominator."""
    truth = y_true[:, :, :, channel]
    guess = y_pred[:, :, :, channel]
    intersection = K.sum(K.abs(truth * guess))
    # epsilon keeps the ratio finite when the channel is empty in both tensors.
    return (2. * intersection) / (K.sum(K.square(truth)) + K.sum(K.square(guess)) + epsilon)


# The three metrics below stay plain `def`s (not partials/lambdas) so that their
# function names -- which Keras uses as metric names -- remain unchanged.

def dice_coef_necrotic(y_true, y_pred, epsilon=1e-6):
    """Dice coefficient of channel 1 ('NECROTIC/CORE' in SEGMENT_CLASSES)."""
    return _dice_coef_for_channel(y_true, y_pred, 1, epsilon)

def dice_coef_edema(y_true, y_pred, epsilon=1e-6):
    """Dice coefficient of channel 2 ('EDEMA' in SEGMENT_CLASSES)."""
    return _dice_coef_for_channel(y_true, y_pred, 2, epsilon)

def dice_coef_enhancing(y_true, y_pred, epsilon=1e-6):
    """Dice coefficient of channel 3 ('ENHANCING' in SEGMENT_CLASSES)."""
    return _dice_coef_for_channel(y_true, y_pred, 3, epsilon)


def precision(y_true, y_pred):
    """Ratio of rounded true positives to rounded predicted positives.

    The element-wise product ``y_true * y_pred`` and ``y_pred`` itself are each
    clipped to [0, 1] and rounded before summing; ``K.epsilon()`` is added to
    the denominator.
    """
    hits = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    flagged = K.sum(K.round(K.clip(y_pred, 0, 1)))
    return hits / (flagged + K.epsilon())
     
def sensitivity(y_true, y_pred):
    """Ratio of rounded true positives to rounded actual positives.

    The element-wise product ``y_true * y_pred`` and ``y_true`` itself are each
    clipped to [0, 1] and rounded before summing; ``K.epsilon()`` is added to
    the denominator.
    """
    hits = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    actual = K.sum(K.round(K.clip(y_true, 0, 1)))
    return hits / (actual + K.epsilon())



def specificity(y_true, y_pred):
    """Ratio of rounded true negatives to rounded actual negatives.

    Negatives are formed as ``1 - tensor``; products and counts are clipped to
    [0, 1] and rounded before summing; ``K.epsilon()`` is added to the denominator.
    """
    negative_truth = 1 - y_true
    negative_guess = 1 - y_pred
    correct_rejections = K.sum(K.round(K.clip(negative_truth * negative_guess, 0, 1)))
    all_negatives = K.sum(K.round(K.clip(negative_truth, 0, 1)))
    return correct_rejections / (all_negatives + K.epsilon())

In [None]:
# Rebinds the module-level name `keras` (bound earlier by `import keras`) to
# TensorFlow's compat.v1 Keras namespace; later `keras.` references resolve to it.
keras = tf.compat.v1.keras
# Base class of the DataGenerator defined below.
Sequence = keras.utils.Sequence

class DataGenerator(Sequence):
    """Sequence that yields (images, one-hot masks) batches of 2-D slices.

    For every case id in a batch, VOLUME_SLICES consecutive slices (starting at
    index VOLUME_START_AT along the last volume axis) are read from the case's
    FLAIR and T1ce volumes and stored as input channels 0 and 1; the matching
    segmentation slices become a 4-class one-hot target.
    """

    def __init__(self, list_IDs, dim=(IMG_SIZE,IMG_SIZE), batch_size = 1, n_channels = 2, shuffle=True):
        """
        Args:
            list_IDs: Case ids; each names a sub-directory of TRAIN_DATASET_PATH.
            dim: (height, width) every slice is resized to.
            batch_size: Number of cases (whole volumes) per batch.
            n_channels: Number of input channels allocated per slice; only
                channels 0 (FLAIR) and 1 (T1ce) are filled.
            shuffle: Whether the case order is reshuffled at the end of each epoch.
        """
        self.dim = dim
        self.batch_size = batch_size
        self.list_IDs = list_IDs
        self.n_channels = n_channels
        self.shuffle = shuffle
        self.on_epoch_end()

    def __len__(self):
        """Number of full batches per epoch (a trailing partial batch is dropped)."""
        return int(np.floor(len(self.list_IDs) / self.batch_size))

    def __getitem__(self, index):
        """Build and return batch number ``index`` as an ``(X, y)`` pair."""
        indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
        Batch_ids = [self.list_IDs[k] for k in indexes]
        return self.__data_generation(Batch_ids)

    def on_epoch_end(self):
        'Updates indexes after each epoch'
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __data_generation(self, Batch_ids):
        """Load, resize and normalise the volumes of ``Batch_ids``.

        Returns:
            tuple: ``X`` with shape (batch_size*VOLUME_SLICES, height, width,
            n_channels), divided by its maximum, and ``y``, the result of
            ``tf.one_hot`` with depth 4 on the label slices.
        """
        n_slices = self.batch_size * VOLUME_SLICES
        height, width = self.dim
        target_size = (width, height)  # cv2.resize takes (width, height)

        X = np.zeros((n_slices, height, width, self.n_channels))
        # Integer labels sized from self.dim (previously hard-coded to 128x128).
        y = np.zeros((n_slices, height, width), dtype=np.uint8)

        for c, case_id in enumerate(Batch_ids):
            case_path = os.path.join(TRAIN_DATASET_PATH, case_id)
            flair = nib.load(os.path.join(case_path, f'{case_id}_flair.nii.gz')).get_fdata()
            ce = nib.load(os.path.join(case_path, f'{case_id}_t1ce.nii.gz')).get_fdata()
            seg = nib.load(os.path.join(case_path, f'{case_id}_seg.nii.gz')).get_fdata()

            for j in range(VOLUME_SLICES):
                row = j + VOLUME_SLICES * c
                src = j + VOLUME_START_AT
                X[row, :, :, 0] = cv2.resize(flair[:, :, src], target_size)
                X[row, :, :, 1] = cv2.resize(ce[:, :, src], target_size)
                # Labels are categorical: nearest-neighbour resizing keeps them
                # valid, whereas the default bilinear interpolation blends
                # neighbouring label values into classes that are not present.
                y[row] = cv2.resize(seg[:, :, src], target_size,
                                    interpolation=cv2.INTER_NEAREST)

        # Remap label value 4 to index 3 so the labels fit the 4-class one-hot encoding.
        y[y == 4] = 3
        y = tf.one_hot(y, 4)

        # Scale to [0, 1] by the batch maximum; skip when the batch is all
        # zeros so the division cannot produce NaNs.
        peak = np.max(X)
        if peak > 0:
            X = X / peak
        return X, y

# One batch generator per data split, each with the constructor's default settings.
training_generator = DataGenerator(list_IDs=train_ids)
valid_generator = DataGenerator(list_IDs=val_ids)
test_generator = DataGenerator(list_IDs=test_ids)

In [None]:
# Writes the metrics of every epoch to 'training.log' (append=False).
csv_logger = CSVLogger('training.log', separator=',', append=False)

# Stops training after 10 epochs without improvement of the monitored training loss.
# NOTE(review): this monitors 'loss' while the LR schedule below monitors 'val_loss' -- confirm intended.
early_stopping = keras.callbacks.EarlyStopping(
    monitor='loss', min_delta=0, patience=10, verbose=1, mode='auto')

# Multiplies the learning rate by 0.2 (never below 0.0001) after 10 epochs
# without improvement of the validation loss.
reduce_lr = keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss', factor=0.2, patience=10, min_lr=0.0001, verbose=1)

callbacks = [early_stopping, reduce_lr, csv_logger]

![Screenshot%202024-03-15%20at%2011.08.56%E2%80%AFAM.png](attachment:Screenshot%202024-03-15%20at%2011.08.56%E2%80%AFAM.png)

In [None]:
def build_unet(inputs, ker_init, dropout):
    """Assemble a 2-D U-Net on top of ``inputs`` and return it as a Model.

    The encoder has four levels (32, 64, 128, 256 filters), each made of two
    3x3 ReLU convolutions followed by 2x2 max pooling. The bottleneck uses 512
    filters and is followed by dropout. Every decoder level upsamples by 2,
    applies a 2x2 convolution, concatenates the encoder output of the same
    level on axis 3 and applies two 3x3 convolutions. A final 1x1 convolution
    with softmax yields 4 output channels.

    Args:
        inputs: Input tensor of the network.
        ker_init: Kernel initializer passed to every ReLU convolution.
        dropout: Dropout rate applied after the bottleneck.
    """
    def relu_conv(filters, kernel_size):
        return Conv2D(filters, kernel_size, activation='relu', padding='same',
                      kernel_initializer=ker_init)

    def double_conv(tensor, filters):
        tensor = relu_conv(filters, 3)(tensor)
        return relu_conv(filters, 3)(tensor)

    def decoder_level(tensor, skip, filters):
        upsampled = relu_conv(filters, 2)(UpSampling2D(size=(2, 2))(tensor))
        merged = concatenate([skip, upsampled], axis=3)
        return double_conv(merged, filters)

    # Contracting path.
    enc32 = double_conv(inputs, 32)
    enc64 = double_conv(MaxPooling2D(pool_size=(2, 2))(enc32), 64)
    enc128 = double_conv(MaxPooling2D(pool_size=(2, 2))(enc64), 128)
    enc256 = double_conv(MaxPooling2D(pool_size=(2, 2))(enc128), 256)

    # Bottleneck.
    bottleneck = double_conv(MaxPooling2D(pool_size=(2, 2))(enc256), 512)
    bottleneck = Dropout(dropout)(bottleneck)

    # Expanding path, deepest skip connection first.
    dec256 = decoder_level(bottleneck, enc256, 256)
    dec128 = decoder_level(dec256, enc128, 128)
    dec64 = decoder_level(dec128, enc64, 64)
    dec32 = decoder_level(dec64, enc32, 32)

    class_probabilities = Conv2D(4, (1, 1), activation='softmax')(dec32)

    return Model(inputs=inputs, outputs=class_probabilities)

# Symbolic network input: one IMG_SIZE x IMG_SIZE slice with 2 channels.
input_layer = Input((IMG_SIZE, IMG_SIZE, 2))

In [None]:


model = build_unet(input_layer, 'he_normal', 0.2)
model.compile(
    loss="categorical_crossentropy",
    optimizer=tf.keras.optimizers.legacy.Adam(learning_rate=0.001),
    metrics=[
        'accuracy',
        # Targets are one-hot and predictions are softmax probabilities.
        # MeanIoU expects integer class ids, so the one-hot variant (which
        # takes the argmax of both tensors) is used. The explicit name keeps
        # the logged metric key unchanged.
        tf.keras.metrics.OneHotMeanIoU(num_classes=4, name='mean_io_u'),
        dice_coef, precision, sensitivity, specificity,
        dice_coef_necrotic, dice_coef_edema, dice_coef_enhancing,
    ],
)

model.summary()

In [None]:
# Train for up to 30 epochs of 300 batches each, validating on the validation
# generator and running the callbacks configured earlier.
history = model.fit(
    training_generator,
    validation_data=valid_generator,
    epochs=30,
    steps_per_epoch=300,
    callbacks=callbacks,
)


In [None]:
# Persist the model under the relative path 'UNET_2'.
model.save('UNET_2')

In [None]:
# NOTE(review): this rebinds `model` to a network loaded from UNET.h5, which is a
# different path from the 'UNET_2' artifact saved above -- confirm which is intended.
# compile=False skips restoring the saved compile configuration; the next cell
# compiles the model again. The former 'accuracy' entry, which mapped that name
# to a MeanIoU instance, was removed because it pointed at the wrong metric.
model = keras.models.load_model(
    '/Users/sanjaydilli/Downloads/UNET.h5',
    custom_objects={
        "dice_coef": dice_coef,
        "precision": precision,
        "sensitivity": sensitivity,
        "specificity": specificity,
        "dice_coef_necrotic": dice_coef_necrotic,
        "dice_coef_edema": dice_coef_edema,
        "dice_coef_enhancing": dice_coef_enhancing,
    },
    compile=False,
)


In [None]:
# Compile the loaded model with the loss and metrics used for training.
model.compile(
    loss="categorical_crossentropy",
    optimizer=keras.optimizers.Adam(learning_rate=0.001),
    metrics=[
        'accuracy',
        # One-hot targets and softmax outputs need the one-hot IoU variant;
        # MeanIoU expects integer class ids. The explicit name keeps the
        # metric key unchanged.
        tf.keras.metrics.OneHotMeanIoU(num_classes=4, name='mean_io_u'),
        dice_coef, precision, sensitivity, specificity,
        dice_coef_necrotic, dice_coef_edema, dice_coef_enhancing,
    ],
)


In [None]:
# Rebinds `history` to the table read from the CSV training log.
# NOTE(review): this absolute path differs from the relative 'training.log' given to
# CSVLogger earlier -- confirm both refer to the same file.
history = pd.read_csv('/Users/sanjaydilli/Documents/UNET/training.log', sep=',', engine='python')


In [None]:
# Alias of the training-log table; the plotting cell below reads its columns.
hist=history

In [None]:
# Per-epoch series read from the training log.
acc = hist['accuracy']
val_acc = hist['val_accuracy']

epoch = range(len(acc))

loss = hist['loss']
val_loss = hist['val_loss']

train_dice = hist['dice_coef']
val_dice = hist['val_dice_coef']

f, ax = plt.subplots(1, 3, figsize=(16, 8))

# (panel title, training series, training label, validation series, validation label)
panels = [
    ('Accuracy', acc, 'Training Accuracy', val_acc, 'Validation Accuracy'),
    ('Loss', loss, 'Training Loss', val_loss, 'Validation Loss'),
    ('Dice coefficient', train_dice, 'Training dice coef', val_dice, 'Validation dice coef'),
]

# Each panel gets a title and an x-axis label so the figure can be read on its own.
for axis, (title, train_series, train_label, val_series, val_label) in zip(ax, panels):
    axis.plot(epoch, train_series, 'b', label=train_label)
    axis.plot(epoch, val_series, 'r', label=val_label)
    axis.set_title(title)
    axis.set_xlabel('Epoch')
    axis.legend()

plt.show()

In [None]:
print("Evaluate on test data")
# The training callbacks (early stopping, LR schedule, CSV logging) are not passed
# here: they are configured for training epochs and are not needed for evaluation.
results = model.evaluate(test_generator)
# `results` holds the loss followed by one value per compiled metric, so each
# value is printed next to its metric name instead of a two-item label.
print("test loss and metrics:", dict(zip(model.metrics_names, results)))

In [None]:
def predictByPath(case_path,case):
    """Run the global ``model`` on VOLUME_SLICES slices of one case.

    Args:
        case_path: Directory containing the case's NIfTI volumes.
        case: Suffix appended to 'BraTS2021' to form the volume file names.

    Returns:
        The output of ``model.predict`` for an input of shape
        (VOLUME_SLICES, IMG_SIZE, IMG_SIZE, 2): FLAIR in channel 0, T1ce in
        channel 1, scaled by the maximum of the whole input.
    """
    X = np.empty((VOLUME_SLICES, IMG_SIZE, IMG_SIZE, 2))

    flair = nib.load(os.path.join(case_path, f'BraTS2021{case}_flair.nii.gz')).get_fdata()
    ce = nib.load(os.path.join(case_path, f'BraTS2021{case}_t1ce.nii.gz')).get_fdata()

    for j in range(VOLUME_SLICES):
        X[j,:,:,0] = cv2.resize(flair[:,:,j+VOLUME_START_AT], (IMG_SIZE,IMG_SIZE))
        X[j,:,:,1] = cv2.resize(ce[:,:,j+VOLUME_START_AT], (IMG_SIZE,IMG_SIZE))

    # Scale by the input maximum; skip when everything is zero so the division
    # cannot feed NaNs into the model.
    peak = np.max(X)
    if peak > 0:
        X = X / peak
    return model.predict(X, verbose=1)


In [None]:
def showPredictsById(case, start_slice = 60):
    """Plot one slice of a case: FLAIR, ground truth and the model's predictions.

    Six panels are drawn on top of the resized FLAIR slice: the FLAIR image
    alone, the ground-truth mask, prediction channels 1-3 as one image, and
    then each of the channels 1, 2 and 3 separately.

    Args:
        case: Suffix appended to 'BraTS2021' to form the case directory and file names.
        start_slice: Index into the predicted slices; the matching volume slice
            is ``start_slice + VOLUME_START_AT``.
    """
    # Built from the shared dataset root instead of repeating the absolute path.
    path = os.path.join(TRAIN_DATASET_PATH, f'BraTS2021{case}')
    gt = nib.load(os.path.join(path, f'BraTS2021{case}_seg.nii.gz')).get_fdata()
    origImage = nib.load(os.path.join(path,f'BraTS2021{case}_flair.nii.gz')).get_fdata()
    p = predictByPath(path,case)

    core = p[:,:,:,1]
    edema = p[:,:,:,2]
    enhancing = p[:,:,:,3]

    volume_index = start_slice + VOLUME_START_AT
    # Resized once and reused as the background of every panel.
    flair_slice = cv2.resize(origImage[:,:,volume_index], (IMG_SIZE, IMG_SIZE))

    # plt.subplots creates the figure; a separate plt.figure call would only
    # leave an additional empty figure behind on every invocation.
    f, axarr = plt.subplots(1,6, figsize = (18, 50))

    for i in range(6):
        axarr[i].imshow(flair_slice, cmap="gray", interpolation='none')

    axarr[0].title.set_text('Original image flair')
    curr_gt = cv2.resize(gt[:,:,volume_index], (IMG_SIZE, IMG_SIZE), interpolation = cv2.INTER_NEAREST)
    axarr[1].imshow(curr_gt, cmap="Reds", interpolation='none', alpha=0.3)
    axarr[1].title.set_text('Ground truth')
    axarr[2].imshow(p[start_slice,:,:,1:4], cmap="Reds", interpolation='none', alpha=0.3)
    axarr[2].title.set_text('all classes')
    # Each panel shows the channel named in its title (core = 1, edema = 2,
    # enhancing = 3); previously the core and edema images were swapped
    # relative to their titles.
    axarr[3].imshow(core[start_slice,:,:], cmap="OrRd", interpolation='none', alpha=0.3)
    axarr[3].title.set_text(f'{SEGMENT_CLASSES[1]} predicted')
    axarr[4].imshow(edema[start_slice,:,:], cmap="OrRd", interpolation='none', alpha=0.3)
    axarr[4].title.set_text(f'{SEGMENT_CLASSES[2]} predicted')
    axarr[5].imshow(enhancing[start_slice,:,:], cmap="OrRd", interpolation='none', alpha=0.3)
    axarr[5].title.set_text(f'{SEGMENT_CLASSES[3]} predicted')
    plt.show()
    

In [None]:

# Show predictions for the test cases at positions 1-48 and 55-66 of test_ids.
# The last six characters of each id are the suffix showPredictsById appends to 'BraTS2021'.
for case_index in [*range(1, 49), *range(55, 67)]:
    showPredictsById(case=test_ids[case_index][-6:])




In [None]:
# Show predictions for the test cases at positions 7-12 of test_ids.
for case_index in range(7, 13):
    showPredictsById(case=test_ids[case_index][-6:])


In [None]:
# Show predictions for the test cases at positions 13-18 of test_ids.
for case_index in range(13, 19):
    showPredictsById(case=test_ids[case_index][-6:])


In [None]:
# Show predictions for the test cases at positions 19-24 of test_ids.
for case_index in range(19, 25):
    showPredictsById(case=test_ids[case_index][-6:])


In [None]:
# Show predictions for the test cases at positions 25-30 of test_ids.
for case_index in range(25, 31):
    showPredictsById(case=test_ids[case_index][-6:])


In [None]:
# Show predictions for the test cases at positions 31-36 of test_ids.
for case_index in range(31, 37):
    showPredictsById(case=test_ids[case_index][-6:])


In [None]:
# Show predictions for the test cases at positions 37-42 of test_ids.
for case_index in range(37, 43):
    showPredictsById(case=test_ids[case_index][-6:])


In [None]:
# Show predictions for the test cases at positions 43-48 of test_ids.
for case_index in range(43, 49):
    showPredictsById(case=test_ids[case_index][-6:])


In [None]:
# Show predictions for the test cases at positions 49-54 of test_ids.
for case_index in range(49, 55):
    showPredictsById(case=test_ids[case_index][-6:])


In [None]:
# Show predictions for the test cases at positions 55-60 of test_ids.
for case_index in range(55, 61):
    showPredictsById(case=test_ids[case_index][-6:])


In [None]:
# Show predictions for the test cases at positions 61-66 of test_ids.
for case_index in range(61, 67):
    showPredictsById(case=test_ids[case_index][-6:])
