In [None]:
import tensorflow as tf
from functools import partial

# DATASET FROM TFRECORDS DATA
# Sorted so that the positional train/validation split further down is
# reproducible, independent of the order in which the filesystem lists matches.
tfrecord_files = sorted(tf.io.gfile.glob("/content/drive/MyDrive/mri-dataset/*.tfrecords"))
if not tfrecord_files:
    # Fail here with a clear message rather than later with an empty dataset.
    raise FileNotFoundError(
        "No .tfrecords files found in /content/drive/MyDrive/mri-dataset/ "
        "(is Google Drive mounted?)"
    )

def read_tfrecord(example):
    """Parse one serialized tf.train.Example into an ``(image, label)`` pair.

    The record is expected to carry int64 ``height``/``width`` scalars and two
    byte strings, ``image`` and ``label``, each holding raw float64 values.
    Both byte strings are decoded as float64 and reshaped to the same 2-D
    shape ``[width, height]``.

    NOTE(review): row-major image arrays are conventionally ``(height, width)``;
    this reshapes to ``(width, height)``. That is harmless for square images,
    but confirm it against the code that wrote the TFRecords.
    """
    feature_spec = {
        'height': tf.io.FixedLenFeature([], tf.int64),
        'width': tf.io.FixedLenFeature([], tf.int64),
        'image': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([], tf.string),
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    target_shape = [parsed['width'], parsed['height']]

    def decode(raw_bytes):
        # Image and mask share the same encoding and the same shape.
        return tf.reshape(tf.io.decode_raw(raw_bytes, tf.float64), target_shape)

    return decode(parsed['image']), decode(parsed['label'])


def load_dataset(filenames, num_parallel_reads=None):
    """Build an unbatched dataset of ``(image, label)`` pairs from TFRecord files.

    Args:
        filenames: TFRecord path(s) accepted by ``tf.data.TFRecordDataset``.
        num_parallel_reads: forwarded to ``TFRecordDataset``. With the default
            ``None``, files are read one after another (no interleaving). Pass
            a positive int or ``tf.data.experimental.AUTOTUNE`` to interleave
            reads across files.

    Returns:
        A ``tf.data.Dataset`` whose elements are the ``(image, label)`` tuples
        produced by ``read_tfrecord``.
    """
    AUTOTUNE = tf.data.experimental.AUTOTUNE
    ignore_order = tf.data.Options()
    # Let parallel stages emit elements as soon as they are ready instead of in
    # their original order. This trades reproducible ordering for throughput.
    ignore_order.experimental_deterministic = False
    dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=num_parallel_reads)
    dataset = dataset.with_options(ignore_order)
    dataset = dataset.map(read_tfrecord, num_parallel_calls=AUTOTUNE)
    return dataset


def get_dataset(filenames, BATCH_SIZE, labeled=True, shuffle=True, shuffle_buffer=2048):
    """Build a batched, prefetched training/evaluation dataset.

    Args:
        filenames: TFRecord path(s) passed through to ``load_dataset``.
        BATCH_SIZE: number of examples per batch.
        labeled: unused; kept only for backward compatibility with existing
            callers. Every record yields an ``(image, label)`` pair.
        shuffle: shuffle examples before batching. Set to ``False`` for
            evaluation data, where ordering does not matter.
        shuffle_buffer: size of the shuffle buffer when ``shuffle`` is true.

    Returns:
        A ``tf.data.Dataset`` of ``(image_batch, label_batch)`` tuples.
    """
    AUTOTUNE = tf.data.experimental.AUTOTUNE
    dataset = load_dataset(filenames)
    if shuffle:
        dataset = dataset.shuffle(shuffle_buffer)
    dataset = dataset.batch(BATCH_SIZE)
    # Prefetch must be the final stage so that whole batches (not single
    # examples) are prepared while the model consumes the current one.
    dataset = dataset.prefetch(buffer_size=AUTOTUNE)
    return dataset

BATCH_SIZE = 64
N_TRAIN_FILES = 20
N_VAL_FILES = 5

# Contiguous, non-overlapping split: files [0, 20) train, [20, 25) validate.
# Python slice ends are exclusive, so no file index is skipped between them.
train_files = tfrecord_files[:N_TRAIN_FILES]
val_files = tfrecord_files[N_TRAIN_FILES:N_TRAIN_FILES + N_VAL_FILES]
if not val_files:
    raise ValueError(
        f"Need more than {N_TRAIN_FILES} TFRecord files for a validation split, "
        f"found {len(tfrecord_files)}"
    )

train_dataset = get_dataset(train_files, BATCH_SIZE)
validation_dataset = get_dataset(val_files, BATCH_SIZE)

# Smoke test: pull one batch to confirm the records parse.
image_batch, label_batch = next(iter(train_dataset))



In [None]:
from keras import backend as K

def dice_coef(y_true, y_pred):
    """Soft Dice coefficient, ``2|A∩B| / (|A| + |B|)``, between mask and prediction.

    Both tensors are flattened before summing, so the score is pooled over
    every element of the batch rather than averaged per sample.
    ``K.epsilon()`` is added to the numerator and the denominator; two all-zero
    inputs therefore score 1 instead of dividing by zero.
    """
    flat_true = K.flatten(y_true)
    flat_pred = K.flatten(y_pred)
    smooth = K.epsilon()
    overlap = K.sum(flat_true * flat_pred)
    denominator = K.sum(flat_true) + K.sum(flat_pred)
    return (2.0 * overlap + smooth) / (denominator + smooth)

def dice_coef_loss(y_true, y_pred):
    """Dice loss: one minus ``dice_coef``, so perfect overlap gives a loss of 0."""
    score = dice_coef(y_true, y_pred)
    return 1 - score

class KerasParas:
    """Mutable bag of settings for loading and running a Keras model.

    Each field starts at a placeholder default. Callers assign the real
    values as attributes after construction.
    """

    def __init__(self):
        # Path to the saved model file; unset until the caller assigns it.
        self.model_path: 'str | None' = None
        # NOTE(review): meaning not visible here (output index?); confirm with its users.
        self.outID: int = 0
        # NOTE(review): presumably a threshold for binarizing predictions; confirm.
        self.thd: float = 0.5
        # Keras image data-format string, e.g. 'channels_first' / 'channels_last'.
        self.img_format: str = 'channels_first'
        # Loss identifier; unset until the caller assigns it.
        self.loss: 'str | None' = None

class PreParas:
    """Container for patch-related settings (dimensions, strides, class field).

    Every field starts empty. Callers are expected to fill them in after
    construction.
    """

    def __init__(self):
        # Each instance gets its own fresh lists, so no state is shared.
        self.patch_dims: list = list()
        self.patch_label_dims: list = list()
        self.patch_strides: list = list()
        # NOTE(review): defaults to an empty string rather than an int; confirm
        # what callers actually store here.
        self.n_class = ''


In [None]:
import os
from keras.models import load_model
import matplotlib.pyplot as plt

# Parameters for Keras model: start from the container's defaults, then
# override the fields this notebook relies on.
keras_paras = KerasParas()
keras_paras.outID = 0
keras_paras.thd = 0.5
keras_paras.loss = 'dice_coef_loss'
keras_paras.img_format = 'channels_last'
keras_paras.model_path = '/content/drive/MyDrive/mri-dataset/rat_brain-2d_unet.hdf5'

# The checkpoint presumably references these custom functions by name, so they
# must be registered for deserialization.
seg_net = load_model(
    keras_paras.model_path,
    custom_objects={
        'dice_coef_loss': dice_coef_loss,
        'dice_coef': dice_coef,
    },
)
seg_net.summary()

# Continue training the restored model; the History object is plotted below.
history = seg_net.fit(
    train_dataset,
    epochs=5,
    validation_data=validation_dataset,
)

def loss_plot(history):
    """Plot per-epoch training loss, plus validation loss when it was recorded.

    Args:
        history: the object returned by ``Model.fit``. Its ``history`` dict
            must contain a ``'loss'`` list and may contain ``'val_loss'``.
    """
    plt.clf()
    ax = plt.gca()
    curves = history.history
    ax.plot(curves['loss'], color='grey', label='Training')
    # Keras records 'val_loss' only when fit() received validation data;
    # indexing it unconditionally would raise KeyError otherwise.
    if 'val_loss' in curves:
        ax.plot(curves['val_loss'], color='indigo', label='Validation')
    ax.set_title('Model loss')
    ax.set_ylabel('Loss')
    ax.set_xlabel('Epoch')
    ax.legend(loc='upper left')
    plt.show()

loss_plot(history)

# Qualitative check: input slice, ground-truth mask and predicted mask side by side.
# NOTE(review): this uses a training batch; predicting on validation_dataset
# would give a less optimistic view of generalization.
image_batch, label_batch = next(iter(train_dataset))
pred = seg_net.predict(image_batch)

N_SHOW = 7
# Never request more rows than the batch actually holds.
n_rows = min(N_SHOW, len(pred))
fig, axes = plt.subplots(n_rows, 3, figsize=(9, 18 * n_rows / N_SHOW), squeeze=False)
for row, (ax_img, ax_label, ax_pred) in enumerate(axes):
    # The sample index is the row index, kept separate from the subplot layout.
    ax_img.imshow(image_batch[row].numpy())
    ax_label.imshow(label_batch[row].numpy())
    # Output is channels_last; show channel 0 (presumably the single mask channel).
    ax_pred.imshow(pred[row][:, :, 0])
for ax, title in zip(axes[0], ['Image', 'Ground truth', 'Prediction']):
    ax.set_title(title)
plt.show()