In [1]:
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow import keras
import numpy as np
import os
%matplotlib inline
import matplotlib.pyplot as plt
import glob
import pathlib

In [2]:
# Record the TensorFlow version the outputs below were produced with.
tf.version.VERSION

'2.7.0'

## Data preprocessing

In [3]:
BATCH_SIZE: int = 32  # (image, mask) pairs per batch; used by the train, val and test pipelines

In [4]:
image_input = sorted(tf.io.gfile.glob('./ISIC2018_Task1-2_Training_Input_x2/*.jpg'))
ground_truth = sorted(tf.io.gfile.glob('./ISIC2018_Task1_Training_GroundTruth_x2/*.png'))

# Images and masks are paired purely by position in these two sorted lists (the
# split cells slice both lists with the image count). Fail fast if a directory is
# missing/empty or the lists are out of step: each ISIC_xxxxxxx.jpg must line up
# with ISIC_xxxxxxx_segmentation.png.
assert image_input, "no input images found - check the dataset directory"
assert len(image_input) == len(ground_truth), (
    f"{len(image_input)} images but {len(ground_truth)} masks"
)
assert all(
    os.path.basename(mask_path)
    == os.path.splitext(os.path.basename(image_path))[0] + "_segmentation.png"
    for image_path, mask_path in zip(image_input, ground_truth)
), "image and mask filenames are not paired one-to-one"

In [5]:
# Peek at the first few image paths to confirm the glob and sort order.
image_input[:3]

['.\\ISIC2018_Task1-2_Training_Input_x2\\ISIC_0000000.jpg',
 '.\\ISIC2018_Task1-2_Training_Input_x2\\ISIC_0000001.jpg',
 '.\\ISIC2018_Task1-2_Training_Input_x2\\ISIC_0000003.jpg']

In [6]:
# The mask paths should line up one-to-one with the image paths shown above.
ground_truth[0:3]

['.\\ISIC2018_Task1_Training_GroundTruth_x2\\ISIC_0000000_segmentation.png',
 '.\\ISIC2018_Task1_Training_GroundTruth_x2\\ISIC_0000001_segmentation.png',
 '.\\ISIC2018_Task1_Training_GroundTruth_x2\\ISIC_0000003_segmentation.png']

In [7]:
# Split the pairs train:test:val = 6:2:2, in sorted-filename order (no shuffling).
# This cell takes the first 20% as the validation set; the next cell assigns the
# following 20% to the test set and the remaining 60% to the training set.
length = len(image_input)
print(length)

image_input_val = image_input[:int(length * 0.2)]
ground_truth_val = ground_truth[:len(image_input_val)]
print(len(image_input_val), len(ground_truth_val), sep='\n')

2594
518
518


In [8]:
# Pairs in [20%, 40%) form the test set; everything from 40% onward is training data.
image_input_test, ground_truth_test = (
    image_input[int(length * 0.2):int(length * 0.4)],
    ground_truth[int(length * 0.2):int(length * 0.4)],
)

image_input_train, ground_truth_train = (
    image_input[int(length * 0.4):],
    ground_truth[int(length * 0.4):],
)

In [9]:
# Both counts should match: one mask per training image.
print(len(image_input_train), len(ground_truth_train), sep='\n')

1557
1557


In [10]:
# Pair each image path with its mask path. Only the training split is shuffled.
# Keras `fit` ignores its own `shuffle` argument for tf.data inputs, so without
# this every epoch would see identical batches in sorted-ID order. Shuffling here,
# before decoding, keeps the buffer cheap (it holds path strings, not images), so it
# can span the whole split; it is reshuffled on every iteration by default.
# max(..., 1) because shuffle() rejects a zero-sized buffer.
train_ds = tf.data.Dataset.from_tensor_slices((image_input_train, ground_truth_train)).shuffle(
    max(len(image_input_train), 1)
)
val_ds = tf.data.Dataset.from_tensor_slices((image_input_val, ground_truth_val))
test_ds = tf.data.Dataset.from_tensor_slices((image_input_test, ground_truth_test))

In [11]:
# Image processing function
def processing_jpg(path, size=(256, 256)):
    """Read a JPEG image as 3-channel RGB and resize it.

    Args:
        path: file path of the JPEG (a Python string or scalar string tensor).
        size: target (height, width); defaults to (256, 256).

    Returns:
        A float32 tensor of shape (height, width, 3). It is resized with
        tf.image.resize's default bilinear filter and is still on the 0-255
        scale; scaling to [0, 1] is done separately in `normal`.
    """
    image = tf.io.read_file(path)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.resize(image, size)

    return image
  
def processing_png(path, size=(256, 256)):
    """Read a single-channel PNG segmentation mask and resize it.

    Args:
        path: file path of the PNG mask (a Python string or scalar string tensor).
        size: target (height, width); defaults to (256, 256).

    Returns:
        A float32 tensor of shape (height, width, 1). Every output pixel is a copy
        of some input pixel, so the mask keeps only its original label values.
    """
    image = tf.io.read_file(path)
    image = tf.image.decode_png(image, channels=1)
    # The default bilinear filter would blend 0 and 255 along the lesion border
    # into fractional "labels"; nearest-neighbour keeps the mask label-valued.
    # Nearest resizing returns the input dtype (uint8), so cast back to float32
    # to keep this function's return dtype unchanged.
    image = tf.image.resize(image, size, method='nearest')
    image = tf.cast(image, tf.float32)

    return image

In [12]:
# normalize function
def normal(image, ground):
    """Cast an (image, mask) pair to float32 and divide both by 255.

    Args:
        image: image tensor on the 0-255 scale.
        ground: mask tensor on the 0-255 scale.

    Returns:
        The tuple (image, ground), both float32 and scaled by 1/255.
    """
    def to_unit_range(tensor):
        return tf.cast(tensor, tf.float32) / 255.0

    return to_unit_range(image), to_unit_range(ground)

def load_image(image_path, ground_path):
    """Map function for the path datasets: load one (image, mask) pair.

    Decodes and resizes the JPEG image and the PNG mask, then scales both by
    1/255 via `normal`.

    Returns:
        The (image, ground) float32 tensor pair.
    """
    return normal(processing_jpg(image_path), processing_png(ground_path))

In [13]:
# Decode, resize and scale every (image path, mask path) pair. Parallel calls let
# file reads and JPEG/PNG decoding for different pairs overlap; tf.data still
# yields elements in input order because deterministic ordering is the default.
train_ds = train_ds.map(load_image, num_parallel_calls=tf.data.AUTOTUNE)
val_ds = val_ds.map(load_image, num_parallel_calls=tf.data.AUTOTUNE)
test_ds = test_ds.map(load_image, num_parallel_calls=tf.data.AUTOTUNE)

In [14]:
# Group pairs into batches of BATCH_SIZE. Prefetching lets the input pipeline
# prepare the next batch while the current one is being consumed.
train_ds = train_ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
val_ds = val_ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
test_ds = test_ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

## Build the improved U-Net model

In [15]:
def context_module(input_layer, filters):
    """Context module: two (InstanceNorm -> 3x3 conv) stages with dropout between.

    Each conv has `filters` output channels, "same" padding and a
    LeakyReLU(alpha=0.01) activation; Dropout(0.3) sits between the two stages.

    Args:
        input_layer: Keras tensor to process.
        filters: number of output channels of both convolutions.

    Returns:
        The Keras tensor produced by the second convolution.
    """

    def norm_then_conv(tensor):
        # Same layer-creation order as a straight-line version: norm, then conv.
        normed = tfa.layers.InstanceNormalization()(tensor)
        conv = keras.layers.Conv2D(
            filters,
            (3, 3),
            padding="same",
            activation=keras.layers.LeakyReLU(alpha=0.01),
        )
        return conv(normed)

    features = norm_then_conv(input_layer)
    features = keras.layers.Dropout(0.3)(features)
    return norm_then_conv(features)

def upsampling(input_layer, filters):
    """Upsampling module: UpSampling2D by 2, a 3x3 conv, then instance normalization.

    Args:
        input_layer: Keras tensor to upsample.
        filters: number of output channels of the 3x3 convolution.

    Returns:
        The instance-normalized Keras tensor at twice the spatial resolution.
    """
    upscaled = keras.layers.UpSampling2D((2, 2))(input_layer)
    convolved = keras.layers.Conv2D(
        filters,
        (3, 3),
        padding="same",
        activation=keras.layers.LeakyReLU(alpha=0.01),
    )(upscaled)
    return tfa.layers.InstanceNormalization()(convolved)

def localization_module(input_layer, filters):
    """Localization module: a 3x3 conv then a 1x1 conv, each followed by instance norm.

    Both convolutions use `filters` output channels, "same" padding and a
    LeakyReLU(alpha=0.01) activation.

    Args:
        input_layer: Keras tensor to process.
        filters: number of output channels of both convolutions.

    Returns:
        The Keras tensor after the second instance normalization.
    """
    features = input_layer
    for kernel_size in ((3, 3), (1, 1)):
        features = keras.layers.Conv2D(
            filters,
            kernel_size,
            padding="same",
            activation=keras.layers.LeakyReLU(alpha=0.01),
        )(features)
        features = tfa.layers.InstanceNormalization()(features)
    return features