In [None]:
from google.colab import drive

# Mount Google Drive inside the Colab runtime so files stored there
# (the ISIC dataset used below) are readable under MOUNTPOINT.
MOUNTPOINT = '/content/gdrive'
drive.mount(mountpoint=MOUNTPOINT)

In [None]:
import tensorflow as tf
import glob
import sklearn.model_selection

SEED = 42  # fixes the train/val/test split so runs are reproducible

DATASET_DIR = MOUNTPOINT + '/My Drive/Colab Notebooks/datasets/ISIC2018_Task1-2_Training_Data'
input_dataset = DATASET_DIR + '/ISIC2018_Task1-2_Training_Input_x2/*.jpg'
output_dataset = DATASET_DIR + '/ISIC2018_Task1_Training_GroundTruth_x2/*.png'

# Images and masks are paired by position after sorting, which assumes each
# mask's filename sorts in the same order as its image's -- TODO confirm.
# (Named *_paths rather than input/output to avoid shadowing the built-in input().)
input_paths = sorted(glob.glob(input_dataset))
output_paths = sorted(glob.glob(output_dataset))
if not input_paths:
  raise FileNotFoundError(f"No images matched {input_dataset!r}; is Google Drive mounted?")
if len(input_paths) != len(output_paths):
  raise ValueError(f"Found {len(input_paths)} images but {len(output_paths)} masks; they cannot be paired.")

# 80:10:10 train/val/test split: hold out 10% for test, then 0.111 of the
# remaining 90% (~10% of the total) for validation.
x_train, x_test, y_train, y_test = sklearn.model_selection.train_test_split(
    input_paths, output_paths, test_size=0.1, random_state=SEED)
x_train, x_val, y_train, y_val = sklearn.model_selection.train_test_split(
    x_train, y_train, test_size=0.111, random_state=SEED)

# Datasets of (image_path, mask_path) string pairs; decoding happens later in `process`.
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
val_ds = tf.data.Dataset.from_tensor_slices((x_val, y_val))

# Shuffling paths (not decoded images) keeps the full-size shuffle buffer cheap.
train_ds = train_ds.shuffle(len(x_train))
test_ds = test_ds.shuffle(len(x_test))
val_ds = val_ds.shuffle(len(x_val))


def process(input_path, output_path, size=(256, 256)):
  """Load and preprocess one (image, mask) pair from their file paths.

  Args:
    input_path: scalar string tensor, path to a JPEG image.
    output_path: scalar string tensor, path to a PNG segmentation mask.
    size: (height, width) that both image and mask are resized to.

  Returns:
    img: float32 tensor of shape (height, width, 3), scaled to [0, 1].
    mask: float32 tensor of shape (height, width, 1), rounded to 0/1.
  """
  height, width = size

  img = tf.image.decode_jpeg(tf.io.read_file(input_path), channels=3)  # extract data
  img = tf.image.resize(img, [height, width])  # resize
  img = tf.cast(img, tf.float32) / 255.0  # normalize to [0, 1]
  # Pin the per-example static shape WITHOUT a leading batch axis. The previous
  # reshape to (-1, H, W, 3) made every element (1, H, W, 3), so `.batch(32)`
  # produced 5-D tensors instead of the 4-D (batch, H, W, C) a 2-D conv model
  # is normally built for.
  img = tf.ensure_shape(img, (height, width, 3))

  mask = tf.image.decode_png(tf.io.read_file(output_path), channels=1)
  mask = tf.image.resize(mask, [height, width])
  # Rescale to [0, 1] and round so interpolated edge pixels snap back to a binary mask.
  mask = tf.math.round(tf.cast(mask, tf.float32) / 255.0)
  mask = tf.ensure_shape(mask, (height, width, 1))

  return img, mask

# Apply preprocessing to the (shuffled) path datasets. Parallel calls let
# tf.data decode several files at once; element order stays deterministic
# by default.
train_ds = train_ds.map(process, num_parallel_calls=tf.data.AUTOTUNE)
test_ds = test_ds.map(process, num_parallel_calls=tf.data.AUTOTUNE)
val_ds = val_ds.map(process, num_parallel_calls=tf.data.AUTOTUNE)

In [None]:
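# Define `model(height, width, channels)`, which the training cell below calls: either paste the contents of model.py into this cell, or upload model.py to Google Drive and import `model` from it.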

In [None]:
# Input image dimensions passed to the model constructor.
height, width, channels = 256, 256, 3

"""Dice Coefficient calculates the similarity between two images"""
def dice_coefficient(y_pred, y_test): # This function was made with reference to the '45223499_improved_unet' folder
    flat_pred = tf.keras.backend.flatten(y_pred)
    flat_test = tf.keras.backend.flatten(y_test)
    return 2 * tf.keras.backend.sum(flat_pred * flat_test) / (tf.keras.backend.sum(flat_pred) + tf.keras.backend.sum(flat_test)) # DC = 2 * intersection / union

"""Dice coefficient loss function used in training"""
def dice_coefficient_loss(y_pred, y_test):
    return 1 - dice_coefficient(y_pred, y_test)

BATCH_SIZE = 32
EPOCHS = 100
LEARNING_RATE = 5e-4

# `model` must be defined (from model.py) before this cell runs.
unet = model(height, width, channels)
unet.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
    loss=dice_coefficient_loss,  # single-output model: pass the function, not a list
    metrics=["accuracy", dice_coefficient],
)
unet.summary()

# prefetch overlaps input preprocessing with the training/validation steps.
history = unet.fit(
    train_ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE),
    validation_data=val_ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE),
    epochs=EPOCHS,
)

In [None]:
# Average the per-image Dice coefficient over the held-out test set.
# The result gets its own name: assigning it to `dice_coefficient` would
# shadow the metric function, so re-running this cell (or the training cell
# that passes `dice_coefficient` as a metric) would fail.
total_dice = 0.0
num_test_images = 0

for img, mask in test_ds.batch(1):
  # Calling the model directly avoids predict()'s per-call setup cost in a loop.
  pred = unet(img, training=False)
  total_dice += dice_coefficient(pred, mask)
  num_test_images += 1

if num_test_images == 0:
  raise ValueError("test_ds is empty; cannot compute a mean Dice coefficient")

mean_test_dice = total_dice / num_test_images

# Mean test-set Dice coefficient:
tf.print(mean_test_dice, summarize=-1)


In [None]:
import matplotlib.pyplot as plt

# Training curves: per-epoch values recorded by unet.fit, with their
# validation counterparts (val_*) drawn dashed when present.
fig, ax = plt.subplots(figsize=(8, 5))
for key, label in [('accuracy', 'Accuracy'), ('loss', 'Loss'), ('dice_coefficient', 'Dice Coefficient')]:
  ax.plot(history.history[key], label=label)
  val_key = 'val_' + key
  if val_key in history.history:
    ax.plot(history.history[val_key], linestyle='--', label='Validation ' + label)
ax.set_xlabel('Epoch')
ax.set_ylabel('Value')
ax.set_title('Training history')
ax.legend()
plt.show()

# Compare predicted and actual masks for a few test images.
NUM_SAMPLES = 3
samples = []
for img, mask in test_ds.take(NUM_SAMPLES):  # test_ds is shuffled, so these are random samples
  img = tf.reshape(img, [-1, 256, 256, 3])  # predict() needs a leading batch axis
  samples.append((img, unet.predict(img), mask))

# One row per sample (the grid adapts if the test set has fewer items).
plt.figure(figsize=(10, 10))
for row, (img, pred, mask) in enumerate(samples):
  panels = [
      (tf.reshape(img, [256, 256, 3]), "Input"),
      (tf.reshape(pred, [256, 256]), "Predicted Mask"),
      (tf.reshape(mask, [256, 256]), "Actual Mask"),
  ]
  for col, (picture, title) in enumerate(panels):
    plt.subplot(len(samples), 3, row * 3 + col + 1)
    plt.imshow(picture)
    plt.axis('off')
    plt.title(title)

plt.show()