In [None]:
# !pip install git+https://github.com/miykael/gif_your_nifti # nifti to gif 


In [None]:
import os
import random
import cv2
import glob
import PIL
import shutil
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from skimage import data
from skimage.util import montage 
import skimage.transform as skTrans
from skimage.transform import rotate
from skimage.transform import resize
from PIL import Image, ImageOps 
import plotly.express as px
import pydot,graphviz

import nilearn as nl
import nibabel as nib
import nilearn.plotting as nlplt
import gif_your_nifti.core as gif2nif

import keras
import keras.backend as K
from keras.callbacks import CSVLogger
import tensorflow as tf
from tensorflow.keras.utils import plot_model
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
from tensorflow.keras.optimizers import *
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping, TensorBoard
from tensorflow.keras import preprocessing
np.set_printoptions(precision=3, suppress=True)


In [None]:
# Load the four MRI modalities and the segmentation mask of one sample case for exploration.
SAMPLE_CASE_PREFIX = ('/kaggle/input/brats20-dataset-training-validation/BraTS2020_TrainingData/'
                      'MICCAI_BraTS2020_TrainingData/BraTS20_Training_001/BraTS20_Training_001_')

test_image_flair = nib.load(SAMPLE_CASE_PREFIX + 'flair.nii').get_fdata()
test_image_t1 = nib.load(SAMPLE_CASE_PREFIX + 't1.nii').get_fdata()
test_image_t1ce = nib.load(SAMPLE_CASE_PREFIX + 't1ce.nii').get_fdata()
test_image_t2 = nib.load(SAMPLE_CASE_PREFIX + 't2.nii').get_fdata()
# Segmentation labels are kept as uint8 integers.
test_mask = nib.load(SAMPLE_CASE_PREFIX + 'seg.nii').get_fdata().astype(np.uint8)

In [None]:
# Display names for the label ids 0-3 predicted by the model.
SEGMENT_CLASSES = dict(enumerate([
    'NO TUMOUR',
    'NECROTIC/CORE',
    'EDEMA',
    'ENHANCING',
]))

In [None]:
# Pick a random slice index along the last volume axis and show it for every modality + the mask.
n_slice = random.randint(0, test_mask.shape[2] - 1)
print(n_slice)
# (assign a fixed value, e.g. n_slice = 28, to inspect a specific slice)

panels = [
    (test_image_flair, 'Image flair'),
    (test_image_t1, 'Image t1'),
    (test_image_t1ce, 'Image t1ce'),
    (test_image_t2, 'Image t2'),
    (test_mask, 'Mask'),
]

plt.figure(figsize=(12, 8))
for position, (volume, caption) in enumerate(panels, start=1):
    plt.subplot(2, 3, position)
    plt.imshow(volume[:, :, n_slice])
    plt.title(caption)
    plt.axis('off')

plt.show()

In [None]:
# Label composition of the chosen slice of the sample mask.
unique_values, counts = np.unique(test_mask[:, :, n_slice], return_counts=True)

# Named `slice_composition` rather than `data` so it does not shadow the `skimage.data` import.
slice_composition = pd.DataFrame({'Value': unique_values, 'Count': counts})
# Ids 0-3 map directly; any other id (raw masks store enhancing tumour as 4) is labelled
# ENHANCING -- the same mapping the previous nested np.where produced.
slice_composition['Type'] = slice_composition['Value'].map(
    lambda label: SEGMENT_CLASSES.get(int(label), SEGMENT_CLASSES[3])
)
slice_composition['Percentage'] = slice_composition['Count'] / slice_composition['Count'].sum() * 100

fig = px.pie(slice_composition, values='Percentage', names='Type',
             title='Percentage Composition of tumour present in the brain')
fig.show()

In [None]:
TRAIN_DATASET_PATH='/kaggle/input/brats20-dataset-training-validation/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/'
# Validation-data root; not referenced by the cells below.
TEST_DATASET_PATH='/kaggle/input/brats20-dataset-training-validation/BraTS2020_ValidationData/MICCAI_BraTS2020_ValidationData/'

# Entries of the training-data root (the listing used to be computed twice in a row).
train_img_dir = os.listdir(TRAIN_DATASET_PATH)

In [None]:
# TRAIN_DATASET_PATH='/kaggle/input/brats20-dataset-training-validation/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/'
# shutil.copy2(TRAIN_DATASET_PATH + 'BraTS20_Training_204/BraTS20_Training_204_flair.nii', './test_gif_BraTS20_Training_204_flair.nii')
# gif2nif.write_gif_normal('./test_gif_BraTS20_Training_204_flair.nii')

In [None]:
# Four views of the same sample case: three nilearn renderings of the FLAIR volume,
# then the segmentation mask drawn as ROIs over it.
sample_case = 'BraTS20_Training_001/BraTS20_Training_001_'
niimg = nl.image.load_img(TRAIN_DATASET_PATH + sample_case + 'flair.nii')
nimask = nl.image.load_img(TRAIN_DATASET_PATH + sample_case + 'seg.nii')

fig, axes = plt.subplots(nrows=4, figsize=(30, 40))

flair_renderers = [
    (nlplt.plot_anat, 'plot_anat'),
    (nlplt.plot_epi, 'plot_epi'),
    (nlplt.plot_img, 'plot_img'),
]
for ax, (render, kind) in zip(axes, flair_renderers):
    render(niimg, title=f'BraTS20_Training_001_flair.nii {kind}', axes=ax)

nlplt.plot_roi(nimask,
               title='BraTS20_Training_001_flair.nii with mask plot_roi',
               bg_img=niimg,
               axes=axes[3],
               cmap='Paired')

plt.show()

In [None]:
def dice_coef(y_true, y_pred, smooth=1):
    """Mean smoothed Dice score over the 4 channels of the last axis.

    For each channel the batch is flattened and scored as
    (2*|A.B| + smooth) / (|A| + |B| + smooth); the four scores are averaged.
    """
    backend = tf.keras.backend
    n_classes = 4
    total = 0
    for channel in range(n_classes):
        truth = backend.flatten(y_true[:, :, :, channel])
        guess = backend.flatten(y_pred[:, :, :, channel])
        overlap = backend.sum(truth * guess)
        total += (2. * overlap + smooth) / (backend.sum(truth) + backend.sum(guess) + smooth)
    return total / n_classes


 
# define per class evaluation of dice coef
# inspired by https://github.com/keras-team/keras/issues/9395
def _dice_coef_for_class(y_true, y_pred, class_index, epsilon=1e-6):
    """Dice score of one channel: 2*sum|t*p| / (sum t^2 + sum p^2 + epsilon).

    Shared implementation of the per-class metrics below (which were three copies
    of this formula differing only in the channel index).
    """
    truth = y_true[:, :, :, class_index]
    pred = y_pred[:, :, :, class_index]
    intersection = tf.keras.backend.sum(tf.keras.backend.abs(truth * pred))
    denominator = (tf.keras.backend.sum(tf.keras.backend.square(truth))
                   + tf.keras.backend.sum(tf.keras.backend.square(pred))
                   + epsilon)
    return (2. * intersection) / denominator

def dice_coef_necrotic(y_true, y_pred, epsilon=1e-6):
    """Dice score of channel 1 (NECROTIC/CORE)."""
    return _dice_coef_for_class(y_true, y_pred, 1, epsilon)

def dice_coef_edema(y_true, y_pred, epsilon=1e-6):
    """Dice score of channel 2 (EDEMA)."""
    return _dice_coef_for_class(y_true, y_pred, 2, epsilon)

def dice_coef_enhancing(y_true, y_pred, epsilon=1e-6):
    """Dice score of channel 3 (ENHANCING)."""
    return _dice_coef_for_class(y_true, y_pred, 3, epsilon)

def precision(y_true, y_pred):
    """Rounded true positives divided by rounded predicted positives.

    Computed over every element of the batch, all channels included.
    """
    kb = tf.keras.backend
    rounded_hits = kb.round(kb.clip(y_true * y_pred, 0, 1))
    rounded_predictions = kb.round(kb.clip(y_pred, 0, 1))
    return kb.sum(rounded_hits) / (kb.sum(rounded_predictions) + kb.epsilon())

def sensitivity(y_true, y_pred):
    """Recall: rounded true positives divided by rounded actual positives (all channels)."""
    kb = tf.keras.backend
    hits = kb.round(kb.clip(y_true * y_pred, 0, 1))
    actual = kb.round(kb.clip(y_true, 0, 1))
    return kb.sum(hits) / (kb.sum(actual) + kb.epsilon())

def specificity(y_true, y_pred):
    """True-negative rate: rounded (1-t)*(1-p) over rounded (1-t), all channels."""
    kb = tf.keras.backend
    true_absent = 1 - y_true
    true_negatives = kb.sum(kb.round(kb.clip(true_absent * (1 - y_pred), 0, 1)))
    possible_negatives = kb.sum(kb.round(kb.clip(true_absent, 0, 1)))
    return true_negatives / (possible_negatives + kb.epsilon())

In [None]:
def build_unet(inputs, ker_init, dropout, kernel_size=7, num_classes=4):
    """Build a 4-level 2D U-Net ending in a per-pixel softmax.

    Args:
        inputs: Keras ``Input`` tensor (channels last). Its height and width should be
            divisible by 16 so the four 2x2 poolings and four 2x upsamplings line up
            for the skip concatenations.
        ker_init: kernel initializer for every hidden convolution (e.g. 'he_normal').
        dropout: dropout rate applied to the bottleneck features.
        kernel_size: size of the stacked convolutions (previously hard-coded as 7).
        num_classes: number of output channels of the softmax head (default 4).

    Returns:
        A ``Model`` mapping ``inputs`` to a ``num_classes``-channel probability map.
    """
    def conv_pair(tensor, filters):
        # Two same-padded ReLU convolutions: one resolution stage of the U.
        tensor = Conv2D(filters, kernel_size, activation='relu', padding='same',
                        kernel_initializer=ker_init)(tensor)
        return Conv2D(filters, kernel_size, activation='relu', padding='same',
                      kernel_initializer=ker_init)(tensor)

    def pool(tensor):
        return MaxPooling2D(pool_size=(2, 2))(tensor)

    def up_and_merge(tensor, skip, filters):
        # Upsample 2x, project with a 2x2 conv, then concatenate the encoder skip on the channel axis.
        up = Conv2D(filters, 2, activation='relu', padding='same',
                    kernel_initializer=ker_init)(UpSampling2D(size=(2, 2))(tensor))
        return concatenate([skip, up], axis=3)

    # Encoder
    conv1 = conv_pair(inputs, 64)
    conv2 = conv_pair(pool(conv1), 128)
    conv3 = conv_pair(pool(conv2), 256)
    conv4 = conv_pair(pool(conv3), 512)

    # Bottleneck
    conv5 = conv_pair(pool(conv4), 1024)
    drop5 = Dropout(dropout)(conv5)

    # Decoder (each stage consumes the matching encoder output as a skip connection)
    conv6 = conv_pair(up_and_merge(drop5, conv4, 512), 512)
    conv7 = conv_pair(up_and_merge(conv6, conv3, 256), 256)
    conv8 = conv_pair(up_and_merge(conv7, conv2, 128), 128)
    conv9 = conv_pair(up_and_merge(conv8, conv1, 64), 64)

    outputs = Conv2D(num_classes, 1, activation='softmax')(conv9)
    return Model(inputs=inputs, outputs=outputs)

IMG_SIZE = 128
# Two input channels: FLAIR and T1ce slices (see DataGenerator).
input_layer = Input((IMG_SIZE, IMG_SIZE, 2))

model = build_unet(inputs=input_layer, ker_init='he_normal', dropout=0.1)
model.compile(
    loss="categorical_crossentropy",
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    metrics=[
        'accuracy',
        # MeanIoU expects integer class ids, but y_true here is one-hot and y_pred is a
        # softmax map; OneHotMeanIoU takes the argmax over the class axis of both first.
        tf.keras.metrics.OneHotMeanIoU(num_classes=4),
        dice_coef, precision, sensitivity, specificity,
        dice_coef_necrotic, dice_coef_edema, dice_coef_enhancing
    ]
)

model.summary()

In [None]:
# plot_model(model, to_file='model.png', show_shapes=True, show_layer_names=True)

import plotly.graph_objects as go

# Interactive Plotly overview of the model: one labelled marker per layer, left to right.
def plot_model_interactive(model):
    """Show every layer of ``model`` as a named marker on a single horizontal row."""
    fig = go.Figure()

    for position, layer in enumerate(model.layers):
        layer_marker = go.Scatter(
            x=[position],
            y=[0],
            mode='markers+text',
            text=[layer.name],
            textposition='bottom center',
            marker={'size': 20},
        )
        fig.add_trace(layer_marker)

    fig.update_layout(
        title='Interactive Model Schematic',
        xaxis_title='Layer',
        yaxis_title='Position',
        showlegend=False,
    )
    fig.show()


plot_model_interactive(model)

In [None]:
# Every case sub-directory of the training set. os.scandir yields entries in arbitrary,
# filesystem-dependent order, so sort them: otherwise the seeded train/val/test split
# below is not reproducible across machines.
train_and_val_directories = sorted(f.path for f in os.scandir(TRAIN_DATASET_PATH) if f.is_dir())

# BraTS20_Training_355 is excluded because its seg.nii file name is ill-formatted.
# Guarded so the cell does not raise ValueError on a dataset copy without that folder.
_malformed_case_dir = TRAIN_DATASET_PATH + 'BraTS20_Training_355'
if _malformed_case_dir in train_and_val_directories:
    train_and_val_directories.remove(_malformed_case_dir)


def pathListIntoIds(dirList):
    """Return the final '/'-separated component (the case id) of each path in ``dirList``."""
    return [path.rsplit('/', 1)[-1] for path in dirList]

train_and_test_ids = pathListIntoIds(train_and_val_directories)

# Two-stage split: hold out 20% of the cases for validation (seed 2), then 15% of the
# remainder as a test set (seed 42); what is left is the training set.
train_test_ids, val_ids = train_test_split(train_and_test_ids, random_state=2, test_size=0.2)
train_ids, test_ids = train_test_split(train_test_ids, random_state=42, test_size=0.15)

In [None]:
from bokeh.plotting import figure, show
from bokeh.io import output_notebook
from bokeh.models import ColumnDataSource
from bokeh.transform import factor_cmap
from bokeh.palettes import Spectral6

# Bar chart of how many cases ended up in each split.
split_sizes = {'train': len(train_ids), 'val': len(val_ids), 'test': len(test_ids)}
labels = list(split_sizes)
counts = list(split_sizes.values())

source = ColumnDataSource(data=dict(labels=labels, counts=counts))

p = figure(
    x_range=labels,
    height=400,
    width=600,
    title="Distribution of data for training, validation and testing",
    toolbar_location=None,
    tools="",
)
p.vbar(
    x='labels',
    top='counts',
    width=0.5,
    source=source,
    legend_field="labels",
    line_color='white',
    fill_color=factor_cmap('labels', palette=Spectral6, factors=labels),
)

# Cosmetics: no vertical grid lines, bars start at zero, legend laid out across the top.
p.xgrid.grid_line_color = None
p.y_range.start = 0
p.legend.orientation = "horizontal"
p.legend.location = "top_center"

output_notebook()
show(p)


In [None]:
# Each case contributes VOLUME_SLICES consecutive slices (last volume axis), starting at index VOLUME_START_AT.
VOLUME_SLICES, VOLUME_START_AT = 100, 22

In [None]:
class DataGenerator(keras.utils.Sequence):
    """Keras Sequence yielding 2D slice batches built from BraTS case folders.

    Each case id contributes ``volume_slices`` consecutive slices (last volume axis)
    starting at ``volume_start``. Inputs stack the FLAIR and T1ce slices as channels
    0 and 1, resized to ``dim``; targets are the segmentation slices one-hot encoded
    into 4 classes (label 4 is folded into 3) and resized to ``dim``.

    Args:
        list_IDs: case folder names under ``TRAIN_DATASET_PATH``.
        volume_slices: number of slices taken from each volume.
        volume_start: index of the first slice taken.
        dim: (height, width) of the produced slices.
        batch_size: number of cases per batch (each yields ``volume_slices`` samples).
        n_channels: channel count of X; only channels 0 (FLAIR) and 1 (T1ce) are filled.
        shuffle: reshuffle the case order at the end of every epoch.
    """

    def __init__(self, list_IDs, volume_slices=VOLUME_SLICES, volume_start=VOLUME_START_AT,
                 dim=(IMG_SIZE, IMG_SIZE), batch_size=1, n_channels=2, shuffle=True):
        'Initialization'
        # Newer Keras versions expect Sequence/PyDataset subclasses to call the base initializer.
        super().__init__()
        self.volume_slices = volume_slices
        self.volume_start = volume_start
        self.dim = dim
        self.batch_size = batch_size
        self.list_IDs = list_IDs
        self.n_channels = n_channels
        self.shuffle = shuffle
        self.on_epoch_end()

    def __len__(self):
        'Denotes the number of batches per epoch'
        # Trailing cases that do not fill a whole batch are dropped.
        return int(np.floor(len(self.list_IDs) / self.batch_size))

    def __getitem__(self, index):
        'Generate one batch of data'
        indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
        Batch_ids = [self.list_IDs[k] for k in indexes]
        X, y = self.__data_generation(Batch_ids)
        return X, y

    def on_epoch_end(self):
        'Updates indexes after each epoch'
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __data_generation(self, Batch_ids):
        'Generates data containing batch_size samples'  # X : (n_samples, *dim, n_channels)
        n_samples = self.batch_size * self.volume_slices
        X = np.zeros((n_samples, *self.dim, self.n_channels))
        # Raw label slices at full resolution (this code assumes 240x240 in-plane volumes).
        y = np.zeros((n_samples, 240, 240))
        # cv2.resize takes dsize as (width, height), the reverse of self.dim = (height, width).
        cv_size = (self.dim[1], self.dim[0])

        for c, i in enumerate(Batch_ids):
            case_path = os.path.join(TRAIN_DATASET_PATH, i)
            flair = nib.load(os.path.join(case_path, f'{i}_flair.nii')).get_fdata()
            ce = nib.load(os.path.join(case_path, f'{i}_t1ce.nii')).get_fdata()
            seg = nib.load(os.path.join(case_path, f'{i}_seg.nii')).get_fdata()

            for j in range(self.volume_slices):
                # Honour this instance's volume_start (the global VOLUME_START_AT used to be
                # read here, silently ignoring the constructor argument).
                src = j + self.volume_start
                row = j + self.volume_slices * c
                X[row, :, :, 0] = cv2.resize(flair[:, :, src], cv_size)
                X[row, :, :, 1] = cv2.resize(ce[:, :, src], cv_size)
                y[row] = seg[:, :, src]

        # Enhancing tumour is stored as label 4; fold it into class 3 for a 4-way one-hot.
        y[y == 4] = 3
        mask = tf.one_hot(y.astype(np.int32), 4)
        # tf.image.resize defaults to bilinear, so resized one-hot maps become soft targets.
        Y = tf.image.resize(mask, self.dim)

        # Scale inputs into [0, 1]; guard against an all-zero batch (division by zero).
        max_value = np.max(X)
        if max_value > 0:
            X = X / max_value
        return X, Y
        
# One generator per split, built in train / val / test order (each constructor shuffles via np.random).
training_generator, valid_generator, test_generator = (
    DataGenerator(split_ids) for split_ids in (train_ids, val_ids, test_ids)
)

In [None]:
for generator in (valid_generator, training_generator, test_generator):
    print(generator)

In [None]:
csv_logger = CSVLogger('training.log', separator=',', append=False)

callbacks = [
    # Dice is a similarity score (higher is better). With mode='auto', ReduceLROnPlateau
    # only maximises monitors whose name contains 'acc', so 'dice_coef' would have been
    # treated as a loss to minimise; mode='max' makes the plateau check correct.
    keras.callbacks.ReduceLROnPlateau(monitor='dice_coef', mode='max', factor=0.5,
                                      patience=5, min_lr=0.000001, verbose=1),
    # Optional, currently disabled:
    # keras.callbacks.EarlyStopping(monitor='loss', min_delta=0, patience=2, verbose=1, mode='auto'),
    # TODO: add checkpointing, e.g.
    # keras.callbacks.ModelCheckpoint(filepath='model_.{epoch:02d}-{val_loss:.6f}.weights.h5',
    #                                 verbose=1, save_best_only=True, save_weights_only=True),
    csv_logger,
]

In [None]:
# K.clear_session() used to be called here, i.e. after `model` had already been built and
# compiled; clearing global Keras state at that point does not give `model` a fresh start,
# so the call is dropped. Call it before build_unet(...) if a clean session is wanted.
EPOCHS = 38

history = model.fit(
    training_generator,
    epochs=EPOCHS,
    # One step per generator batch; stays correct if batch_size is changed from 1.
    steps_per_epoch=len(training_generator),
    callbacks=callbacks,
    validation_data=valid_generator,
)

In [None]:
# Quick look at the per-epoch training accuracy recorded by model.fit.
plt.plot(history.history['accuracy'])
plt.title('Training accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.show()

In [None]:
# Per-epoch metrics written by CSVLogger; rows with zero accuracy are discarded.
report = (
    pd.read_csv('training.log', sep=',', engine='python')
    .query('accuracy != 0')
)
report

In [None]:
# Training vs. validation curves read back from the CSVLogger output.
# The x-axis uses the logged 'epoch' column so it stays correct after rows were filtered out.
epoch = report['epoch']
curves = [('accuracy', 'Accuracy'), ('loss', 'Loss'), ('dice_coef', 'dice coef')]

f, ax = plt.subplots(1, 3, figsize=(16, 8))
for panel, (metric, label) in zip(ax, curves):
    panel.plot(epoch, report[metric], 'y', label=f'Training {label}')
    panel.plot(epoch, report[f'val_{metric}'], 'g', label=f'Validation {label}')
    panel.set_title(label)
    panel.set_xlabel('Epoch')
    panel.legend()

plt.show()

In [None]:
# Load the NIfTI file at `path` (any modality, or the seg mask) and return its voxel data.
def imageLoader(path):
    """Return the voxel data of the NIfTI file at ``path`` as a NumPy array.

    The previous body also tried to slice/resize the volume into a batch, but it
    referenced names that do not exist in this scope (``self``, ``c``, ``ce``, ``y``,
    ``seg``) and so raised NameError on every call. Callers (``loadDataFromDir``)
    only need the raw volume, which is what the function returned.
    """
    return np.array(nib.load(path).get_fdata())


# Build a 2D slice dataset (scan slice + mask slice) from a list of case folders.
# mriType selects the modality: one of 'flair', 't1', 't1ce', 't2'.
def loadDataFromDir(path, list_of_files, mriType, n_images):
    """Return every slice of the first ``n_images`` cases together with its mask slice.

    Args:
        path: unused; kept for backward compatibility. ``list_of_files`` entries are
            used as case-folder paths as given.
        list_of_files: case folder paths.
        mriType: modality suffix of the scan file, e.g. 'flair' or 't1ce'.
        n_images: number of cases taken from the front of ``list_of_files``.

    Returns:
        (scans, masks): float32 arrays of shape (n_slices, IMG_SIZE, IMG_SIZE, 1).

    Raises:
        FileNotFoundError: if a case folder has no matching scan or seg file.
    """
    def find_file(case_dir, suffix):
        # Anchor on "_<suffix>.nii" so 't1' no longer also matches the '..._t1ce.nii' file.
        matches = sorted(glob.glob(os.path.join(case_dir, f'*_{suffix}.nii*')))
        if not matches:
            raise FileNotFoundError(f'No *_{suffix}.nii file in {case_dir}')
        return matches[0]

    scans = []
    masks = []
    for case_dir in list_of_files[:n_images]:
        scan_volume = imageLoader(find_file(case_dir, mriType))
        mask_volume = imageLoader(find_file(case_dir, 'seg'))
        # for each slice in the 3D volume, also take its mask slice
        for j in range(scan_volume.shape[2]):
            # Keep intensities as floats: values outside 0-255 do not survive a uint8 cast.
            scan_img = cv2.resize(scan_volume[:, :, j], dsize=(IMG_SIZE, IMG_SIZE),
                                  interpolation=cv2.INTER_AREA)
            # Nearest-neighbour keeps labels discrete; area averaging would blend label ids.
            mask_img = cv2.resize(mask_volume[:, :, j], dsize=(IMG_SIZE, IMG_SIZE),
                                  interpolation=cv2.INTER_NEAREST)
            scans.append(scan_img[..., np.newaxis])
            masks.append(mask_img[..., np.newaxis])
    return np.array(scans, dtype='float32'), np.array(masks, dtype='float32')
        
#brains_list_test, masks_list_test = loadDataFromDir(VALIDATION_DATASET_PATH, test_directories, "flair", 5)

In [None]:
def predictByPath(case_path, case):
    """Run ``model`` on the FLAIR + T1ce slices of one training case.

    Preprocessing mirrors DataGenerator: VOLUME_SLICES slices starting at
    VOLUME_START_AT, resized to IMG_SIZE, scaled by the stack maximum.

    Args:
        case_path: folder of the case.
        case: case number suffix used in the file names (``BraTS20_Training_{case}_*``).

    Returns:
        ``model.predict`` output for the VOLUME_SLICES-slice stack.
    """
    X = np.empty((VOLUME_SLICES, IMG_SIZE, IMG_SIZE, 2))

    flair = nib.load(os.path.join(case_path, f'BraTS20_Training_{case}_flair.nii')).get_fdata()
    ce = nib.load(os.path.join(case_path, f'BraTS20_Training_{case}_t1ce.nii')).get_fdata()

    for j in range(VOLUME_SLICES):
        X[j, :, :, 0] = cv2.resize(flair[:, :, j + VOLUME_START_AT], (IMG_SIZE, IMG_SIZE))
        X[j, :, :, 1] = cv2.resize(ce[:, :, j + VOLUME_START_AT], (IMG_SIZE, IMG_SIZE))

    # Same [0, 1] scaling as training; guard against an all-zero stack.
    max_value = np.max(X)
    if max_value > 0:
        X = X / max_value
    return model.predict(X, verbose=1)


def showPredictsById(case, start_slice = 60):
    """Plot FLAIR, ground truth and predicted classes for one training case.

    Args:
        case: case number suffix of the folder name, e.g. '001'.
        start_slice: index into the predicted stack; the matching raw volume slice is
            ``start_slice + VOLUME_START_AT``.
    """
    path = os.path.join(TRAIN_DATASET_PATH, f"BraTS20_Training_{case}")
    gt = nib.load(os.path.join(path, f'BraTS20_Training_{case}_seg.nii')).get_fdata()
    origImage = nib.load(os.path.join(path, f'BraTS20_Training_{case}_flair.nii')).get_fdata()
    p = predictByPath(path, case)

    # Predictions start at volume slice VOLUME_START_AT, so shift to the matching raw slice.
    volume_slice = start_slice + VOLUME_START_AT

    core = p[:, :, :, 1]
    edema = p[:, :, :, 2]
    enhancing = p[:, :, :, 3]

    # Nearest-neighbour keeps label ids intact; map label 4 to 3 so GT and prediction share a colour scale.
    gt_slice = cv2.resize(gt[:, :, volume_slice], (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_NEAREST)
    gt_slice[gt_slice == 4] = 3
    predicted_labels = np.argmax(p[start_slice], axis=-1)

    f, axarr = plt.subplots(1, 6, figsize=(18, 4))

    axarr[0].imshow(cv2.resize(origImage[:, :, volume_slice], (IMG_SIZE, IMG_SIZE)))
    axarr[0].title.set_text('Original image flair')
    axarr[1].imshow(gt_slice, vmin=0, vmax=3)
    axarr[1].title.set_text('Ground truth')
    axarr[2].imshow(predicted_labels, vmin=0, vmax=3)
    axarr[2].title.set_text('all classes predicted')
    axarr[3].imshow(edema[start_slice, :, :])
    axarr[3].title.set_text('edema predicted')
    axarr[4].imshow(core[start_slice, :, :])
    axarr[4].title.set_text('core predicted')
    axarr[5].imshow(enhancing[start_slice, :, :])
    axarr[5].title.set_text('enhancing predicted')
    plt.show()
    
    
# Predictions for the first seven held-out test cases; the seventh is shown at an earlier slice.
for test_index in range(6):
    showPredictsById(case=test_ids[test_index][-3:])
showPredictsById(case=test_ids[6][-3:], start_slice=40)

In [None]:
# Save the full model (architecture, weights, optimizer state). The ".h5" suffix selects
# Keras' HDF5 format. NOTE(review): the model was compiled with custom metric functions
# (dice_coef etc.), so reloading presumably needs them via custom_objects (or compile=False).
model.save("model_1.h5")

In [None]:
# Save only the layer weights (no architecture or optimizer state); restore them into a
# model built with build_unet(...) via model.load_weights(...).
model.save_weights("model.weights.h5")