In [None]:
"""
good paper about losses:
https://arxiv.org/pdf/2006.14822.pdf

try freezing the encoder?? what does it mean? --> https://segmentation-models.readthedocs.io/en/latest/tutorial.html
transfer learning in encoder with tumor related neural network
how to implement dropout
number of parameters in network
resnet34 or densenet121

TODOs:

check the activation of layers --> https://awjuliani.medium.com/visualizing-neural-network-layer-activation-tensorflow-tutorial-d45f8bf7bbc4
implement shuffling in the generator --> https://github.com/qubvel/segmentation_models/blob/master/examples/multiclass%20segmentation%20(camvid).ipynb // https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly 
"""

In [None]:
from tensorflow.keras import layers
from tensorflow import keras
import segmentation_models as sm

def get_model_with_dropout(base_model_name='efficientnetb4',activation='sigmoid',dropout = 0.1):
    """Build an ``sm.Unet`` and re-head it with dropout followed by an activation.

    Args:
        base_model_name: backbone name forwarded to ``sm.Unet``.
        activation: Keras activation applied after the dropout layer; the same
            string is used as the activation layer's name.
        dropout: rate passed to ``keras.layers.Dropout``.

    Returns:
        A ``keras.models.Model`` that shares the U-Net input and ends in
        ``final_conv -> Dropout -> Activation``.
    """
    unet = sm.Unet(base_model_name, encoder_weights='imagenet')

    # Tap the graph at the layer named 'final_conv' and stack the new head on it.
    final_conv_out = unet.get_layer('final_conv').output
    dropped = keras.layers.Dropout(dropout)(final_conv_out)
    head = keras.layers.Activation(activation, name=activation)(dropped)

    return keras.models.Model(unet.input, head)

In [None]:
import os
from tensorflow.python.client import device_lib

def get_available_devices():
    """Return the ``name`` of every device listed by ``device_lib.list_local_devices()``."""
    device_names = []
    for device in device_lib.list_local_devices():
        device_names.append(device.name)
    return device_names

#Run from server
# NOTE(review): the identifiers below look like GPU UUIDs of one specific machine
# (presumably devices 4-7 of the training server) - they will not match on any
# other host; confirm before running elsewhere.

gpu4 = "GPU-71b4cdfc-e381-0b98-9b24-4fc06284b496" 
gpu5 = "GPU-99d0769a-9f86-4800-a40e-2320dddcf5d1" 
gpu6 = "GPU-7423cfb5-cff4-ec4d-7e96-ea6e1591d56f"
gpu7 = "GPU-c0a8738f-6dd0-1b78-c38f-4969fd3886a8"

# Only gpu4 is exported through CUDA_VISIBLE_DEVICES; the device list is then
# printed so the selection can be checked by eye.
os.environ["CUDA_VISIBLE_DEVICES"]= gpu4
print(get_available_devices())

In [None]:
import nibabel as nib
import matplotlib.pyplot as plt
import numpy as np
import h5py
import cv2

In [None]:
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Activation, Lambda, GlobalAveragePooling2D, concatenate
from tensorflow.keras.layers import UpSampling2D, Conv2D, Dropout, MaxPooling2D, Conv2DTranspose
from tensorflow.keras.layers import Dense, Flatten, Input
from tensorflow.keras.models import Model, Sequential, load_model
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.optimizers import RMSprop, Adam, SGD

In [None]:
# Size of one stored volume: in-plane height/width and number of slices.
IMG_HEIGHT = IMG_WIDTH = 240
IMG_DEPTH = 155

# In-plane size the slices are zero-padded to before being fed to the U-Net.
IMG_HEIGHT_UNET = IMG_WIDTH_UNET = 256

# Number of volumes and the resulting total number of 2-D slices.
N_IMG = 369
length_file = N_IMG * IMG_DEPTH

h5py_file_name = 'training.hdf5'

# Smoothing term added to numerator and denominator of the IoU / Dice scores.
smooth = 1e-4

___

Generating the data and creating data arrays

In [None]:
# Collect the FLAIR volume and segmentation file of every training case.
ImgDir = os.path.join("..","data","MICCAI_BraTS2020_TrainingData")
features_path = list()
labels_path = list()
count = 0
lim = 10

# os.listdir() returns entries in arbitrary order; sorting keeps the order of the
# volumes (and therefore every slice index used later) reproducible between runs.
for folder in sorted(os.listdir(ImgDir)):
    count +=1
    new_dir = os.path.join(ImgDir,folder)
    if 'Training' not in folder or not os.path.isdir(new_dir):
        continue
    for files in sorted(os.listdir(new_dir)):
        if 'flair' in files:
            features_path.append(os.path.join(new_dir, files))
        if 'seg' in files:
            labels_path.append(os.path.join(new_dir, files))

print(len(features_path))
print(len(labels_path))

# Features and labels are paired purely by list position, so a case with a
# missing or differently named file silently shifts every later pair.
if len(features_path) != len(labels_path):
    print("WARNING: number of feature and label files differ; image/label pairs are misaligned")


In [None]:
# Stack every feature volume along the slice axis -> (H, W, D * n_volumes).
# The final array is allocated once and each volume is written straight into its
# slot, so no second full-size copy is created by a later np.concatenate.
path = features_path
img_conc_features = np.zeros((IMG_HEIGHT, IMG_WIDTH, IMG_DEPTH * len(path)))

for i, file in enumerate(path):
    imgarr = nib.load(file).get_fdata()
    # volume i occupies slices [i * IMG_DEPTH, (i + 1) * IMG_DEPTH)
    img_conc_features[:, :, i * IMG_DEPTH:(i + 1) * IMG_DEPTH] = imgarr

print(np.shape(img_conc_features))

In [None]:
# Stack every label volume along the slice axis -> (H, W, D * n_volumes).
# The final array is allocated once and each volume is written straight into its
# slot, so no second full-size copy is created by a later np.concatenate.
path = labels_path
img_conc_labels = np.zeros((IMG_HEIGHT, IMG_WIDTH, IMG_DEPTH * len(path)))

for i, file in enumerate(path):
    imgarr = nib.load(file).get_fdata()
    # volume i occupies slices [i * IMG_DEPTH, (i + 1) * IMG_DEPTH)
    img_conc_labels[:, :, i * IMG_DEPTH:(i + 1) * IMG_DEPTH] = imgarr

print(np.shape(img_conc_labels))

In [None]:
# Persist both stacks. The file is opened in append mode, so a dataset left over
# from a previous run is removed first; create_dataset() raises if the name
# already exists, which made this cell fail on every re-run.
with h5py.File(os.path.join('..','data',h5py_file_name), 'a') as f:
    for name, array in (("features", img_conc_features), ("labels", img_conc_labels)):
        if name in f:
            del f[name]
        f.create_dataset(name, data=array, compression="gzip")

In [None]:
# Sanity check: read slice 100 of the stored features and display it.
# The context manager closes the file even if the read raises.
with h5py.File(os.path.join('..','data',h5py_file_name), "r") as hdf5_store_Brats:
    test_img = hdf5_store_Brats["features"][:, :, 100]
plt.imshow(test_img, cmap ='gray');

In [None]:
# Sanity check: read slice 100 of the stored labels and display it.
# The context manager closes the file even if the read raises.
with h5py.File(os.path.join('..','data',h5py_file_name), "r") as hdf5_store_Brats:
    test_img = hdf5_store_Brats["labels"][:, :, 100]
plt.imshow(test_img, cmap ='gray');

___

Load data and model

In [None]:
# Slice range used for this run: [0, train_stop) is the training split and
# [train_stop, n_images] the validation split.
train_stop = 9001
n_images = 10000

# n_images = 57195
# train_stop = int(np.floor(n_images * 0.9))

# images = X_nib.get_fdata()
# labels = y_nib.get_fdata()

#https://www.pythonforthelab.com/blog/how-to-use-hdf5-files-in-python/

#tf.debugging.set_log_device_placement(True)
# strategy = tf.distribute.MirroredStrategy(devices=["/device:GPU:0", "/device:GPU:1"])
# with strategy.scope():

#with tf.device("/device:GPU:0"):    

with h5py.File(os.path.join('..','data',h5py_file_name), "r") as f:
    #images_train = f["features"][()] #the whole dataset: 57195 images
    #labels_train = f["labels"][()]

    images = f["features"]
    images_train = tf.constant(images[:,:,:train_stop],dtype=tf.float32) #taking a subsample
    images_val = tf.constant(images[:,:,train_stop:n_images+1],dtype=tf.float32)

    labels = f["labels"]
    labels_train = tf.constant(labels[:,:,:train_stop],dtype=tf.float32)
    # Validation targets must come from the label stack; this line used to slice
    # `images`, which made the validation "masks" copies of the input slices.
    labels_val = tf.constant(labels[:,:,train_stop:n_images+1],dtype=tf.float32)

#batch_size = 30 

print(labels_val[:,:,1])

In [None]:
# Inspect the shape of the training image stack built in the loading cell.
np.shape(images_train)

In [None]:
# Show training slice 96 in grayscale, with the axis ticks hidden.
plt.imshow(images_train[:,:,96], cmap = 'gray', interpolation = 'bicubic')
plt.xticks([])
plt.yticks([])
plt.show()

In [None]:
# Show training label slice 100 in grayscale, with the axis ticks hidden.
plt.imshow(labels_train[:,:,100], cmap = 'gray', interpolation = 'bicubic')
plt.xticks([])
plt.yticks([])
plt.show()

____

In [None]:
import segmentation_models as sm

# Encoder backbone handed to sm.Unet (alternative tried: 'densenet121').
BACKBONE = 'resnet34'

# Class names are the label values written as strings; one output channel each.
CLASSES = [str(label) for label in range(5)]
n_classes = len(CLASSES)

# Optimisation settings.
LR = 0.0001
EPOCHS = 50
BATCH_SIZE = 30

# Decoder widths; the library's standard setting is (256, 128, 64, 32, 16).
decoder_filters = (128, 64, 32, 16, 8)

In [None]:
# helper function for data visualization
def visualize(**images):
    """Show the given images side by side in a single row.

    Every keyword argument becomes one panel. The keyword name, with underscores
    replaced by spaces and title-cased, is used as the panel title.
    """
    panel_count = len(images)
    plt.figure(figsize=(16, 5))
    for position, name in enumerate(images, start=1):
        plt.subplot(1, panel_count, position)
        plt.xticks([])
        plt.yticks([])
        plt.title(name.replace('_', ' ').title())
        plt.imshow(images[name])
    plt.show()
    
# helper function for data visualization    
def denormalize(x):
    """Scale an image to the range 0..1 for plotting.

    The 2nd and 98th percentiles of ``x`` are mapped to 0 and 1 and everything
    outside that window is clipped.

    Args:
        x: array-like image data.

    Returns:
        Array of the same shape with values in [0, 1]. If both percentiles
        coincide (e.g. a constant or completely blank slice) an all-zero array
        is returned instead of the NaNs a 0/0 division would produce.
    """
    x = np.asarray(x)
    x_max = np.percentile(x, 98)
    x_min = np.percentile(x, 2)
    shifted = x - x_min
    span = x_max - x_min
    if span == 0:
        # No contrast to stretch: avoid dividing by zero.
        return np.zeros_like(shifted)
    return np.clip(shifted / span, 0, 1)

def visualize_histories(**images):
    """Plot train/validation curves in one row, one panel per keyword.

    Each keyword value is a pair ``(train_values, val_values)``. The keyword
    name, with underscores replaced by spaces and title-cased, is the panel title.
    """
    panel_count = len(images)
    plt.figure(figsize=(16, 5))
    for position, name in enumerate(images):
        curves = images[name]
        plt.subplot(1, panel_count, position + 1)
        plt.ylim(0,1)
        plt.xlabel('epoch')
        # only the left-most panel carries the shared y-axis label
        if position == 0:
            plt.ylabel('dice_score')
        plt.title(name.replace('_', ' ').title())
        plt.plot(np.asarray(curves[0]))
        plt.plot(np.asarray(curves[1]))
        plt.legend(['train', 'val'])
    plt.show()

In [None]:
class Dataset:
    """Serve zero-padded 2-D slices and one-hot masks from stacked volumes.

    Both inputs are indexed as ``array[:, :, i]``. Slice ``i`` is copied into the
    top-left corner of an (IMG_HEIGHT_UNET, IMG_WIDTH_UNET) canvas; the image is
    replicated over three channels and the mask is expanded to one channel per
    requested class.

    Args:
        images_train: stack of image slices with the slice index on the last axis.
        images_label: stack of label slices with the same layout.
        classes: names (entries of ``CLASSES``) of the classes to extract;
            ``None`` selects every class.
    """
    
    CLASSES = ['0','1','2','3','4']
    
    def __init__(self, images_train, images_label, classes = None):
        if classes is None:
            # the declared default used to crash when iterated
            classes = self.CLASSES
        self.images_fps = images_train
        self.masks_fps = images_label
        self.class_values = [self.CLASSES.index(cls.lower()) for cls in classes]
        # number of slices = size of the last axis
        self.ids = int(images_train.shape[-1])
    
    def __getitem__(self, i):
        """Return ``(image, mask)`` for slice ``i``.

        ``image`` is float32 of shape (IMG_HEIGHT_UNET, IMG_WIDTH_UNET, 3) and
        ``mask`` is float of shape (IMG_HEIGHT_UNET, IMG_WIDTH_UNET, n_selected_classes).

        Raises:
            IndexError: if ``i`` is outside ``[-len(self), len(self))``.
        """
        if not -self.ids <= i < self.ids:
            raise IndexError(f"slice index {i} out of range for {self.ids} slices")

        # Same slice in all three channels, zero-padded to the U-Net input size.
        image = np.zeros((IMG_HEIGHT_UNET, IMG_WIDTH_UNET, 3), np.float32)
        image[:IMG_HEIGHT, :IMG_WIDTH, :] = np.asarray(self.images_fps[:, :, i])[..., np.newaxis]

        # The padded border stays 0, so it is reported as class value 0.
        padded_labels = np.zeros((IMG_HEIGHT_UNET, IMG_WIDTH_UNET), np.float32)
        padded_labels[:IMG_HEIGHT, :IMG_WIDTH] = self.masks_fps[:, :, i]

        # one binary channel per selected class value
        masks = [(padded_labels == v) for v in self.class_values]
        mask = np.stack(masks, axis=-1).astype('float')
        
        # add background if mask is not binary
#         if mask.shape[-1] != 1:
#             background = 1 - mask.sum(axis=-1, keepdims=True)
#             mask = np.concatenate((mask, background), axis=-1)
        
        return image, mask
    
    def __len__(self):
        """Number of slices available."""
        return self.ids

In [None]:
class Dataloder(tf.keras.utils.Sequence):
    """Load data from dataset and form batches
    
    Args:
        dataset: instance of Dataset class for image loading and preprocessing.
        batch_size: Integer number of images in batch.
        shuffle: Boolean, if `True` shuffle image indexes each epoch.
    """
    
    def __init__(self, dataset, batch_size=BATCH_SIZE, shuffle=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indexes = np.arange(len(dataset))

        self.on_epoch_end()

    def __getitem__(self, i):
        """Return batch ``i`` as a tuple ``(images, masks)`` stacked on axis 0."""
        start = i * self.batch_size
        stop = (i + 1) * self.batch_size

        # Look samples up through self.indexes so that the permutation drawn in
        # on_epoch_end() is actually used; indexing the dataset with the raw
        # position made shuffle=True a no-op.
        data = [self.dataset[int(j)] for j in self.indexes[start:stop]]
            
        batch = [np.stack(samples, axis=0) for samples in zip(*data)]
        
        return (batch[0],batch[1])
    
    def __len__(self):
        """Denotes the number of batches per epoch"""
        # floor division: a trailing partial batch is dropped
        return len(self.indexes) // self.batch_size
    
    def on_epoch_end(self):
        """Callback function to shuffle indexes each epoch"""
        if self.shuffle:
            self.indexes = np.random.permutation(self.indexes)  

In [None]:
# Pull one sample and show the input next to each one-hot mask channel.
dataset = Dataset(images_train, labels_train, classes=['0','1','2','3','4'])
image, mask = dataset[100]

### if putting the whole image, black and white, could it be the reason of the problem??
visualize(
    image=image[:,:,0],
    **{
        panel: mask[..., channel].squeeze()
        for channel, panel in enumerate(['no_tumor', 'one', 'two', 'three', 'four'])
    }
)

In [None]:
#image = np.expand_dims(image, axis=0)
# Shape of the sample image once a leading axis is added with np.newaxis.
np.shape(image[np.newaxis, ...])

In [None]:
# Shape of the one-hot mask taken from `dataset[100]`.
np.shape(mask)

In [None]:
# define network parameters
# One output channel per class; a single class is the binary case (sigmoid),
# several classes the multiclass case (softmax).
n_classes = len(CLASSES)
# n_classes = 1 if len(CLASSES) == 1 else (len(CLASSES) + 1)  # variant with an extra background channel
activation = 'softmax' if n_classes != 1 else 'sigmoid'

#create model
# Author's note: ImageNet encoder weights are not closely related to tumors, so
# they were meant to be dropped; other pretrained weights would be interesting.
# NOTE(review): `encoder_weights` is not passed here, so the library default is
# used - confirm against the segmentation_models docs that this matches the intent.
model = sm.Unet(BACKBONE, classes=n_classes, activation=activation, decoder_filters = decoder_filters)

In [None]:
#tf.debugging.set_log_device_placement(False)

In [None]:
def _gather_channels(x, indexes, **kwargs):
    """Select the channels listed in ``indexes`` from a 4-D tensor.

    The channel axis is moved to the front, gathered along, and moved back;
    which axis that is depends on the Keras image data format.
    """
    if K.image_data_format() == 'channels_last':
        to_front, to_back = (3, 0, 1, 2), (1, 2, 3, 0)
    else:
        # swapping axes 0 and 1 is its own inverse
        to_front, to_back = (1, 0, 2, 3), (1, 0, 2, 3)

    channels_first = K.permute_dimensions(x, to_front)
    selected = K.gather(channels_first, indexes)
    return K.permute_dimensions(selected, to_back)

def get_reduce_axes(per_image, **kwargs):
    """Return the axes to sum over when computing a score.

    These are the two spatial axes for the current image data format; the batch
    axis (0) is prepended unless ``per_image`` is true.
    """
    spatial_axes = [1, 2] if K.image_data_format() == 'channels_last' else [2, 3]
    return spatial_axes if per_image else [0] + spatial_axes

def gather_channels(*xs, indexes=None, **kwargs):
    """Apply the same channel selection to every tensor in ``xs``.

    With ``indexes=None`` the inputs are returned untouched (as the argument
    tuple). Otherwise a list with one channel-sliced tensor per input is
    returned; a single int is treated as a one-element index list.
    """
    if indexes is None:
        return xs
    index_list = [indexes] if isinstance(indexes, int) else indexes
    return [_gather_channels(tensor, indexes=index_list, **kwargs) for tensor in xs]

def average(x, per_image=False, class_weights=None, **kwargs):
    """Reduce a score tensor to a scalar mean.

    If ``per_image`` is set the tensor is first averaged over axis 0; if
    ``class_weights`` is given the scores are multiplied by it before the final
    mean over all remaining elements.
    """
    scores = K.mean(x, axis=0) if per_image else x
    if class_weights is not None:
        scores = scores * class_weights
    return K.mean(scores)

def round_if_needed(x, threshold, **kwargs):
    """Binarise ``x`` at ``threshold``; return ``x`` unchanged when ``threshold`` is None.

    Values strictly greater than the threshold become 1, all others 0, cast to
    the backend float type.
    """
    backend = K
    if threshold is None:
        return x
    above = backend.greater(x, threshold)
    return backend.cast(above, backend.floatx())

def iou_score(gt, pr, class_weights=1., class_indexes=None, smooth=smooth, per_image=False, threshold=None, **kwargs):
    """Intersection-over-union (Jaccard) score between ground truth and prediction.

    Args:
        gt: ground-truth tensor.
        pr: prediction tensor.
        class_weights: factor(s) the per-class scores are multiplied by before averaging.
        class_indexes: channel index or list of indexes to score; ``None`` uses all channels.
        smooth: value added to numerator and denominator to avoid division by zero.
        per_image: if true, score each sample separately and average the results;
            otherwise the batch axis is summed over together with the spatial axes.
        threshold: if not ``None``, predictions are binarised at this value first.

    Returns:
        Scalar tensor holding the averaged score.
    """
    backend = K
    gt, pr = gather_channels(gt, pr, indexes=class_indexes, **kwargs)
    pr = round_if_needed(pr, threshold, **kwargs)
    axes = get_reduce_axes(per_image, **kwargs)

    # score calculation
    intersection = backend.sum(gt * pr, axis=axes)
    union = backend.sum(gt + pr, axis=axes) - intersection

    # (a leftover debugging print of the score tensor was removed here)
    score = (intersection + smooth) / (union + smooth)
    score = average(score, per_image, class_weights, **kwargs)

    return score

metrics_test = [iou_score]

In [None]:
from sklearn.metrics import make_scorer

def dice_coef(y_true, y_pred):
    """Dice coefficient over the whole batch with all channels pooled together.

    Both tensors are flattened, so a single score is computed for every element
    at once: ``(2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth)``.

    The abandoned per-class experiment that used to run first (debug prints,
    ``np.asarray`` on a list of tensors, an array added to a Python list,
    five hard-coded channels) never contributed to the returned value and was
    removed.
    """
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)

    intersection = K.sum(y_true_f * y_pred_f)
    coef =  (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)

    return coef

score = make_scorer(dice_coef, greater_is_better=True)

metrics_test2 = [dice_coef]

In [None]:
def dice_metric(ground_truth, prediction):
    """Dice score per sample and per channel, computed on rounded predictions.

    Sums run over axes 1 and 2 (the code assumes channels-last tensors), so the
    result keeps the batch and channel axes.

    Returns:
        Tensor of Dice scores. The function previously had no ``return``
        statement and therefore always yielded ``None``.
    """

    #for metrics, it's good to round predictions:
    prediction = K.round(prediction)

    #intersection and totals per class per batch (considers channels last)
    intersection = K.sum(ground_truth * prediction, axis=[1,2])
    ground_truth = K.sum(ground_truth, axis=[1,2])
    prediction = K.sum(prediction, axis=[1,2])

    dice = ((2 * intersection) + K.epsilon()) / (ground_truth + prediction + K.epsilon())
    return dice
    
metrics_test3 = [dice_metric]

In [None]:
#calculates dice considering an input with a single class
def dice_single(true,pred):
    """Per-sample Dice score for single-class inputs.

    Each sample is flattened, the prediction is rounded, and
    ``(2 * overlap + eps) / (sum(true) + sum(pred) + eps)`` is returned with one
    value per sample.
    """
    flat_true = K.batch_flatten(true)
    flat_pred = K.round(K.batch_flatten(pred))

    overlap = K.sum(flat_true * flat_pred, axis=-1)
    total = K.sum(flat_true, axis=-1) + K.sum(flat_pred, axis=-1)

    return ((2*overlap) + K.epsilon()) / (total + K.epsilon())

def dice_for_class(index):
    """Build a Dice metric restricted to channel ``index``.

    Args:
        index: position on the last axis of the (4-D) tensors to score.

    Returns:
        A function ``(true, pred) -> dice`` named ``dice_inner_<index>``. The
        distinct name matters: every closure used to be called ``dice_inner``,
        so the per-class metrics could not be told apart by name.
    """
    def dice_inner(true,pred):

        #get only the desired class
        true = true[:,:,:,index]
        pred = pred[:,:,:,index]

        #return dice per class
        return dice_single(true,pred)

    dice_inner.__name__ = f'dice_inner_{index}'
    return dice_inner

metrics_test4 = [dice_for_class(i) for i in range(5)]

In [None]:
#calculates dice considering an input with a single class
# NOTE: this repeats the dice_single definition of the previous cell.
def dice_single(true,pred):
    """Per-sample Dice score on flattened inputs, with the prediction rounded."""
    flat_true = K.batch_flatten(true)
    flat_pred = K.round(K.batch_flatten(pred))

    overlap = K.sum(flat_true * flat_pred, axis=-1)
    total = K.sum(flat_true, axis=-1) + K.sum(flat_pred, axis=-1)

    return ((2*overlap) + K.epsilon()) / (total + K.epsilon())


def _dice_for_channel(true, pred, index):
    """Dice score of channel ``index`` (last axis of the 4-D tensors) only."""
    return dice_single(true[:,:,:,index], pred[:,:,:,index])


# One named function per class so each gets its own entry in the training
# history ('dice_inner_0' ... 'dice_inner_4').
def dice_inner_0(true,pred,index=0):
    """Dice score for channel 0."""
    return _dice_for_channel(true, pred, index)

def dice_inner_1(true,pred,index=1):
    """Dice score for channel 1."""
    return _dice_for_channel(true, pred, index)

def dice_inner_2(true,pred,index=2):
    """Dice score for channel 2."""
    return _dice_for_channel(true, pred, index)

def dice_inner_3(true,pred,index=3):
    """Dice score for channel 3."""
    return _dice_for_channel(true, pred, index)

def dice_inner_4(true,pred,index=4):
    """Dice score for channel 4."""
    return _dice_for_channel(true, pred, index)

metrics_test5 = [dice_inner_0, dice_inner_1, dice_inner_2, dice_inner_3, dice_inner_4, sm.metrics.FScore(threshold=0.5)]

In [None]:
# define optimizer: Adam with learning rate LR

    
optim = tf.keras.optimizers.Adam(LR)

# Segmentation models losses can be combined together by '+' and scaled by integer or float factor
# Earlier experiment, kept for reference: class-weighted dice loss plus focal loss.
# dice_loss = sm.losses.DiceLoss(class_weights=np.array([0.1, 5, 3, 1, 1])) 
# focal_loss = sm.losses.BinaryFocalLoss() if n_classes == 1 else sm.losses.CategoricalFocalLoss()
# total_loss = dice_loss + (1 * focal_loss)

# Loss actually used: dice loss with a weight of 1 for each of the five classes.
total_loss = sm.losses.DiceLoss(class_weights=np.array([1, 1, 1, 1, 1])) 

# Actually, total_loss can be imported directly from the library; the example above just shows how to combine losses
# total_loss = sm.losses.binary_focal_dice_loss # or sm.losses.categorical_focal_dice_loss 

# Library IoU and F-score metrics, both thresholded at 0.5.
metrics = [sm.metrics.IOUScore(threshold=0.5), sm.metrics.FScore(threshold=0.5)]

In [None]:
# compile keras model with the optimizer, loss and per-class dice metrics defined above
model.compile(optimizer=optim, loss=total_loss, metrics=metrics_test5)

In [None]:
# Dataset for train images
# (The statements below used to be indented under a `with strategy.scope():`
# that is commented out, which made the cell fail with an IndentationError.)

#strategy = tf.distribute.MirroredStrategy(devices=["/device:GPU:0", "/device:GPU:1"])
#with strategy.scope():

train_dataset = Dataset(
    images_train, 
    labels_train, 
    classes=CLASSES
)

# Dataset for validation images
valid_dataset = Dataset(
    images_val, 
    labels_val, 
    classes=CLASSES
)

train_dataloader = Dataloder(train_dataset, batch_size=BATCH_SIZE, shuffle=False)
valid_dataloader = Dataloder(valid_dataset, batch_size=BATCH_SIZE, shuffle=False)

# check shapes for errors
assert train_dataloader[0][0].shape == (BATCH_SIZE, IMG_HEIGHT_UNET, IMG_WIDTH_UNET, 3)
assert train_dataloader[0][1].shape == (BATCH_SIZE, IMG_HEIGHT_UNET, IMG_WIDTH_UNET, n_classes)

# define callbacks for learning rate scheduling and best checkpoints saving
# callbacks = [
#     tf.keras.callbacks.ModelCheckpoint('./best_model.h5', save_weights_only=True, save_best_only=True, mode='min'),
#     tf.keras.callbacks.ReduceLROnPlateau(),
# ]

# callbacks = [
#     tf.keras.callbacks.ReduceLROnPlateau(),
# ]

In [None]:
# Shape of the mask batch (second element) of the first training batch.
train_dataloader[0][1].shape

In [None]:
class printbatch(tf.keras.callbacks.Callback):
    """Log batch progress and save a prediction snapshot after every epoch.

    At the end of each epoch the model predicts sample 100 of the module-level
    ``train_dataset`` and the input slice, the ground truth and the prediction
    (channels 0..2 of each mask) are written as JPEG files into ``RES_DIR``.
    """

    # Folder receiving the snapshot images.
    RES_DIR = os.path.join("..","data","test_images")

    @staticmethod
    def _to_uint8(mask):
        """Map mask values from 0..1 to 0..255 for an 8-bit image file.

        JPEG stores 8-bit samples; masks written without scaling come out
        (almost) black. NOTE(review): confirm how cv2.imwrite converts float
        input against the OpenCV documentation.
        """
        return np.clip(np.asarray(mask) * 255., 0, 255).astype(np.uint8)

    def on_train_begin(self, logs=None):
        # exist_ok replaces the bare `except:` that also hid real errors
        # (e.g. missing permissions) behind an "already exist" message.
        os.makedirs(self.RES_DIR, exist_ok=True)

    def on_train_batch_begin(self, batch, logs=None):
        keys = list((logs or {}).keys())
        print("...Training: start of batch {}; got log keys: {}".format(batch, keys))

    def on_train_batch_end(self, batch, logs=None):
        keys = list((logs or {}).keys())
        print("...Training: end of batch {}; got log keys: {}".format(batch, keys))

    def on_epoch_end(self, epoch, logs=None):
        keys = list((logs or {}).keys())

        x_img = os.path.join(self.RES_DIR,"X_input.jpg")
        y_img = os.path.join(self.RES_DIR,"Y_truth.jpg")
        predicted_img = os.path.join(self.RES_DIR,f"{epoch}_Y_predicted.jpg")

        image, gt_mask = train_dataset[100]
        #image = np.expand_dims(image, axis=0)
        pr_mask = self.model.predict(image[np.newaxis, ...])

        # only the first three channels fit into a 3-channel image file
        gt_mask_vis = gt_mask[...,0:3]
        pr_mask_vis = pr_mask[...,0:3]

#         visualize(
#             image=denormalize(image.squeeze()),
#             gt_mask_vis=gt_mask_vis.squeeze(),
#             pr_mask_vis=pr_mask_vis.squeeze(),
#         )

        cv2.imwrite(x_img, image[:,:,0])
        cv2.imwrite(y_img, self._to_uint8(gt_mask_vis.squeeze()))
        cv2.imwrite(predicted_img, self._to_uint8(pr_mask_vis.squeeze()))

        print("End epoch {} of training; got log keys: {}".format(epoch, keys))

pb = printbatch()        

In [None]:
checkpoint_path =  os.path.join("..","data","checkpoints","training_2/cp-{epoch:04d}.ckpt")
checkpoint_dir = os.path.dirname(checkpoint_path)

# Create a callback that saves the model's weights every 5 epochs.
# An integer `save_freq` is not counted in epochs (current Keras counts batches),
# so "every 5 epochs" is 5 * batches-per-epoch; 5 * BATCH_SIZE saved far more often.
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path, 
    save_weights_only=True,
    save_freq=5 * len(train_dataloader))


callbacks = [pb, cp_callback]

In [None]:
# Train the model. Model.fit accepts keras.utils.Sequence inputs directly;
# fit_generator is the deprecated spelling of the same call.
history = model.fit(
    train_dataloader, 
    steps_per_epoch=len(train_dataloader), 
    epochs=EPOCHS,
    verbose = 0,
    callbacks = callbacks,
    validation_data=valid_dataloader, 
    validation_steps=len(valid_dataloader)
)

___

In [None]:
# One panel per class: training vs. validation Dice score over the epochs.
### if putting the whole image, black and white, could it be the reason of the problem??
visualize_histories(**{
    f'class_{label}': (
        history.history[f'dice_inner_{channel}'],
        history.history[f'val_dice_inner_{channel}'],
    )
    for channel, label in enumerate(['zero', 'one', 'two', 'three', 'four'])
})

In [None]:
# Rebinds checkpoint_path (previously the per-epoch pattern) to the file used for the final weights.
checkpoint_path= os.path.join("..","data","checkpoints","unet_brats.ckpt")

In [None]:
# Write the current model weights to checkpoint_path.
model.save_weights(checkpoint_path)
# Restore the weights
# model.load_weights(checkpoint_path)

In [None]:
# Plot training & validation curves: f1-score in the left panel, loss in the right one
plt.figure(figsize=(30, 5))
for position, key, axis_label in ((121, 'f1-score', 'f1-score'), (122, 'loss', 'Loss')):
    plt.subplot(position)
    plt.plot(history.history[key])
    plt.plot(history.history['val_' + key])
    plt.title('Model ' + key)
    plt.ylabel(axis_label)
    plt.xlabel('Epoch')
    plt.ylim([0,1])
    plt.legend(['Train', 'Val'], loc='upper left')
plt.show()

In [None]:
# Compare ground truth and prediction for selected validation slices.
n = 1
# ids = np.random.choice(np.arange(len(labels_train)), size=n)
ids = [50]

# Show the SAME three channels (1..3) for ground truth and prediction; the
# prediction used to be sliced 0:3 while the ground truth was sliced 1:4, so
# the two panels displayed different classes.
shown_channels = slice(1, 4)

for i in ids:
    
    image, gt_mask = valid_dataset[i]
    #image = np.expand_dims(image, axis=0)
    pr_mask = model.predict(image[np.newaxis, ...])
    
    visualize(
        image=denormalize(image.squeeze()),
        gt_mask=gt_mask[..., shown_channels].squeeze(),
        pr_mask=pr_mask[..., shown_channels].squeeze(),
    )


In [None]:
# model = sm.Unet('resnet34', classes = 5, input_shape=(IMG_HEIGHT_UNET, IMG_WIDTH_UNET, 3))
# model.compile(
#     'Adam',
#     loss=sm.losses.bce_jaccard_loss,
#     metrics=[sm.metrics.iou_score],
# )
# history = model.fit_generator(train_dataloader, epochs = 20, verbose=1, validation_data=None, class_weight=None)

Image visualization of training

In [None]:
# Load the FLAIR volume of the first training case and look at slice 100.
test_img_feat_dir = os.path.join("..","data","MICCAI_BraTS2020_TrainingData","BraTS20_Training_001","BraTS20_Training_001_flair.nii.gz")
img_feat = nib.load(test_img_feat_dir)
imgarr_feat = img_feat.get_fdata()
test_img_feat_slice = imgarr_feat[:,:,100]

# Copy the slice into all three channels of a float32 image in one broadcast assignment.
test_img = np.zeros((240, 240 ,3), np.float32)
test_img[...] = test_img_feat_slice[:, :, np.newaxis]

plt.imshow(test_img[:,:,0], cmap = 'gray', interpolation = 'bicubic')
plt.xticks([])
plt.yticks([])
plt.show()

In [None]:
# Snapshots written during training: the input, the ground truth and the
# predictions saved for epochs 3, 29, 39 and 49.
directory_predicted_3 = os.path.join("..","data","test_images","3_Y_predicted.jpg")
directory_predicted_29 = os.path.join("..","data","test_images","29_Y_predicted.jpg")
directory_predicted_39 = os.path.join("..","data","test_images","39_Y_predicted.jpg")
directory_predicted_49 = os.path.join("..","data","test_images","49_Y_predicted.jpg")
directory_y = os.path.join("..","data","test_images","Y_truth.jpg")
directory_X = os.path.join("..","data","test_images","X_input.jpg")

# Flag 0 makes cv2.imread load every file as a single-channel grayscale image.
### if putting the whole image, black and white, could it be the reason of the problem??
visualize(
    input_image=cv2.imread(directory_X,0).squeeze(),
    ground_truth=cv2.imread(directory_y,0).squeeze(),
    predicted_3=cv2.imread(directory_predicted_3,0),
    predicted_29=cv2.imread(directory_predicted_29,0),
    predicted_39=cv2.imread(directory_predicted_39,0),
    # this panel used to load the epoch-39 file a second time
    predicted_49=cv2.imread(directory_predicted_49,0),
)


In [None]:
# NOTE(review): no definition or import of `plot_metrics` is visible in this
# notebook - confirm where it comes from, otherwise this cell raises NameError
# on a fresh kernel.
plot_metrics(history)

____