# The implementation of U-Net fully convolutional neural network
This variant feeds a per-pixel weight map into the network alongside the image input tensor.  
"U-Net: Convolutional Networks for Biomedical Image Segmentation"  
https://arxiv.org/pdf/1505.04597.pdf

In [None]:
import datetime
import pathlib
import sys
import random

import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display

from PIL import Image, ImageDraw

from sklearn.model_selection import train_test_split

from functools import partial

%load_ext autoreload
%autoreload 2

Image.MAX_IMAGE_PIXELS = None

In [None]:
# One seed shared by NumPy's and the standard library's global random
# generators, so random draws repeat between runs of the notebook.
SEED = 241

np.random.seed(SEED)
random.seed(SEED)

In [None]:
# Default matplotlib figure size (width, height) for the plots in this notebook.
plt.rcParams["figure.figsize"] = [26,19]

# 1 Create dataset

In [None]:
def load_map_and_mask(map_file, mask_file, workspace_dir):
    """Load a map image and rasterize its polygon markup into a binary mask.

    The markup file holds one polygon per line, written as whitespace
    separated ``x,y`` integer pairs. Every polygon is filled with 1 on a
    1-bit image that has the same size as the map. A JPEG preview of the
    mask is also written to ``workspace_dir`` as ``<mask_file.stem>.jpg``.

    Args:
        map_file: path of the source map image.
        mask_file: pathlib.Path of the polygon markup text file.
        workspace_dir: pathlib.Path of the directory receiving the preview.

    Returns:
        Tuple ``(map_img, mask_img)``: the opened PIL map image and the
        mode ``'1'`` mask image.

    Raises:
        ValueError: if a coordinate is not an integer.
        IndexError: if a point has no ``,`` separated y coordinate.
    """
    map_img = Image.open(str(map_file))

    polygons = []
    with open(str(mask_file)) as f:
        for line in f:
            # split() without arguments tolerates blank lines, repeated
            # spaces and trailing whitespace; split(' ') produced empty
            # "points" for those and crashed on int('').
            points = line.split()
            if not points:
                continue
            coords = [point.split(',') for point in points]
            polygons.append([(int(xy[0]), int(xy[1])) for xy in coords])

    mask_img = Image.new('1', map_img.size, 0)

    # One drawing context is enough for all polygons.
    draw = ImageDraw.Draw(mask_img)
    for polygon in polygons:
        draw.polygon(polygon, fill=1)

    preview_path = workspace_dir / (mask_file.stem + '.jpg')
    mask_img.convert('RGB').save(str(preview_path), format='JPEG', quality=100)

    return map_img, mask_img

def plot_masks(map_imgs, mask_imgs, weights=None):
    """Show maps next to their masks (and weight maps, when given).

    Each entry of ``map_imgs`` gets its own row: the map in the first
    column, ``mask_imgs[i]`` in the second and, if ``weights`` is not None,
    ``weights[i]`` in the third. The figure is displayed with ``plt.show()``.
    """
    show_weights = weights is not None
    n_cols = 3 if show_weights else 2
    n_rows = len(map_imgs)

    for row, map_img in enumerate(map_imgs):
        # subplot indices are 1-based and run row by row
        first_cell = n_cols * row + 1

        plt.subplot(n_rows, n_cols, first_cell)
        plt.imshow(map_img)

        plt.subplot(n_rows, n_cols, first_cell + 1)
        plt.imshow(mask_imgs[row])

        if show_weights:
            plt.subplot(n_rows, n_cols, first_cell + 2)
            plt.imshow(weights[row])

    plt.show()

def map_stats(map_img, mask_img):
    """Print the map size and the share of 0 and 1 pixels in its mask.

    Ratios are rounded to three decimals; a class with no pixels is
    reported as 0.
    """
    pixels = np.array(mask_img).astype(np.byte)
    zeros = np.sum(pixels == 0)
    ones = np.sum(pixels == 1)
    total = zeros + ones

    # The conditional keeps the division from running when a count is zero.
    zeros_ratio = round(zeros / total, 3) if zeros != 0 else 0
    ones_ratio = round(ones / total, 3) if ones != 0 else 0

    print(map_img.size)
    print('zeros ratio', zeros_ratio)
    print('ones ratio', ones_ratio)

In [None]:
# Collect every .tif under ./data/train (recursively), skipping files whose
# name contains '.mask.'.
dataset_dir = pathlib.Path.cwd() / 'data' / 'train'
src_images = [
    img_p
    for img_p in dataset_dir.glob('**/*.tif')
    if '.mask.' not in img_p.name
]

print('src_images', len(src_images))

In [None]:
# Every run gets its own workspace folder named after the current time
# (second resolution). mkdir() is called without exist_ok, so it raises if
# the folder already exists.
timestamp = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
workspace_dir = pathlib.Path().cwd() / ('workspace_' + str(timestamp))
workspace_dir.mkdir(parents=True)

### 1.1 Read source Map images

In [None]:
# Load each source map together with the mask rasterized from the
# '<stem>.markup.txt' file that sits next to it, and print its class balance.
src_images_and_masks = []

for img_p in src_images:
    map_img, mask = load_map_and_mask(
        img_p,
        img_p.with_name(img_p.stem + '.markup.txt'),
        workspace_dir,
    )
    src_images_and_masks.append({'map': map_img, 'mask': mask})

    print('Stats for', img_p)
    map_stats(map_img, mask)
    print()

In [None]:
# Preview the first two source maps side by side with their masks.
plot_masks(
    [src['map'] for src in src_images_and_masks][:2], 
    [src['mask'] for src in src_images_and_masks][:2]
)

### 1.2 Create dataset images by cropping source Maps

In [None]:
import skimage
import scipy.ndimage.morphology as morphology


def create_weight_map(y, wc=None, w0=5, sigma=30):
    """
    Generate a weight map according to the U-Net paper:

        w(x) = wc(x) + w0 * exp(-(d1(x) + d2(x))**2 / (2 * sigma**2))

    where d1 and d2 are the distances from pixel x to the nearest and to the
    second nearest object. The border term is applied to background pixels
    only and requires at least two objects in the mask.

        Parameters:
            y:     numpy_array - array of shape (image_height, image_width)
                                 representing binary mask of objects
            wc:    dict        - optional class weights, mapping a value
                                 found in y to the weight of its pixels
            w0:    int         - border weight
            sigma: int         - border width

        Returns:
            numpy_array - float weights of shape (image_height, image_width).

    NOTE: with wc=None the result is 0 on every object pixel, and is all
    zeros when the mask holds fewer than two objects.
    """
    # Local imports: `import skimage` alone does not guarantee that the
    # `measure` submodule is loaded, and `scipy.ndimage.morphology` is a
    # deprecated namespace for `scipy.ndimage`.
    from scipy import ndimage
    from skimage import measure

    y = np.asarray(y)
    labels = measure.label(y)
    # Drop the background id explicitly instead of slicing off the first
    # sorted id, which removed a real object when no background was present.
    label_ids = [label_id for label_id in np.unique(labels) if label_id != 0]

    # Always float, whatever dtype the mask has.
    w = np.zeros(labels.shape, dtype=np.float64)

    if len(label_ids) > 1:
        background_ids = labels == 0

        # Track only the two smallest distances per pixel instead of
        # stacking and sorting one distance map per object.
        d1 = np.full(labels.shape, np.inf)
        d2 = np.full(labels.shape, np.inf)
        for label_id in label_ids:
            # Distance of every pixel to the closest pixel of this object.
            dist = ndimage.distance_transform_edt(labels != label_id)
            d2 = np.minimum(d2, np.maximum(d1, dist))
            d1 = np.minimum(d1, dist)

        w = w0 * np.exp(-1 / 2 * ((d1 + d2) / sigma) ** 2) * background_ids

    if wc:
        # A float buffer keeps fractional class weights intact; they were
        # truncated when the buffer inherited the integer dtype of the mask.
        class_weights = np.zeros(labels.shape, dtype=np.float64)
        for k, v in wc.items():
            class_weights[y == k] = v
        w = w + class_weights

    return w

In [None]:
def cut_map_into_tiles(map_img,
                       mask_img,
                       tile_size=100,
                       tile_resize=100,
                       tiles_count=100,
                       tile_prefix='',
                       save_tiles=True,
                       save_dir=None):
    """Sample random square tiles from a map and its mask.

    Each tile is cropped at a random position, resized to
    ``tile_resize`` x ``tile_resize`` and paired with the matching mask tile
    and a weight map computed by ``create_weight_map``.

    Args:
        map_img: PIL image of the source map.
        mask_img: PIL image of the mask, same size as ``map_img``.
        tile_size: side of the square cropped from the source, in pixels.
        tile_resize: side of the tile after resizing, in pixels.
        tiles_count: number of tiles to sample.
        tile_prefix: string prepended to the saved file names.
        save_tiles: when True, map, mask and weights are stored with
            ``np.save`` in ``save_dir``.
        save_dir: pathlib.Path of the output folder; required when
            ``save_tiles`` is True.

    Returns:
        Tuple ``(X, Y, W)`` of lists with ``tiles_count`` numpy arrays each:
        RGB tiles, mask tiles (``np.byte``) and weight maps.

    Raises:
        ValueError: if the map is smaller than ``tile_size`` or if
            ``save_tiles`` is True without a ``save_dir``.
    """
    if save_tiles and save_dir is None:
        raise ValueError('save_dir is required when save_tiles is True')

    width, height = map_img.size
    if width < tile_size or height < tile_size:
        raise ValueError(
            'map of size {}x{} is smaller than tile_size {}'.format(width, height, tile_size)
        )

    X = []
    Y = []
    W = []

    # randint's upper bound is exclusive: "+ 1" keeps the last valid offset
    # reachable and makes a map that is exactly one tile wide/high usable.
    top_left_coordinates = zip(
        np.random.randint(0, width - tile_size + 1, tiles_count),
        np.random.randint(0, height - tile_size + 1, tiles_count)
    )

    map_img_in_rgb = map_img.convert('RGB')

    for i, (x, y) in enumerate(top_left_coordinates):
        box = (x, y, x + tile_size, y + tile_size)
        tile = map_img_in_rgb.crop(box).resize((tile_resize, tile_resize))
        tile_mask = mask_img.crop(box).resize((tile_resize, tile_resize))

        mp = np.array(tile)
        mask = np.array(tile_mask).astype(np.byte)
        weights = create_weight_map(mask)

        X.append(mp)
        Y.append(mask)
        W.append(weights)

        if save_tiles:
            # np.save appends '.npy' to names that do not end with it.
            np.save(str(save_dir/(tile_prefix + 'map_' + str(i) + '.np')), mp)
            np.save(str(save_dir/(tile_prefix + 'mask_' + str(i) + '.np')), mask)
            # The weights file used to receive a second copy of the mask.
            np.save(str(save_dir/(tile_prefix + 'weights_' + str(i) + '.np')), weights)

    return X, Y, W

def tiles_stats(Y):
    """Print how many 0 and 1 pixels the masks in Y hold, and their ratios.

    A ratio line is printed only for a class that has at least one pixel.
    """
    # One pass over Y: (zeros, ones) per mask.
    per_mask = [(np.sum(mask == 0), np.sum(mask == 1)) for mask in Y]
    zeros_count = sum(zeros for zeros, _ in per_mask)
    ones_count = sum(ones for _, ones in per_mask)
    total = ones_count + zeros_count

    print('zeros', zeros_count)
    print('ones', ones_count)

    for label, count in (('zeros ratio', zeros_count), ('ones ratio', ones_count)):
        if count > 0:
            print(label, count / total)
    print()

In [None]:
# Side of the square cropped from a source map, number of crops per map,
# and the side the crops are resized to before they reach the network.
TILES_SIZE = 1024
TILES_COUNT = 1
UNET_INPUT_SIZE = 256

X, Y, W = [], [], []

tiles_folder = workspace_dir / 'train_tiles'
tiles_folder.mkdir(parents=True, exist_ok=True)

for i, src in enumerate(src_images_and_masks):
    # The map index is used as file prefix so tiles of different maps
    # do not overwrite each other.
    x, y, w = cut_map_into_tiles(
        src['map'],
        src['mask'],
        tile_size=TILES_SIZE,
        tile_resize=UNET_INPUT_SIZE,
        tiles_count=TILES_COUNT,
        tile_prefix=str(i),
        save_dir=tiles_folder,
    )

    X.extend(x)
    Y.extend(y)
    W.extend(w)
    print('done', i)

print('X', len(X))
print('Y', len(Y))
print('W', len(W))
tiles_stats(Y)

### 1.3 View cropped images

In [None]:
def binary_mask_to_img(data):
    """Convert a binary mask array into a PIL image of mode '1'.

    NOTE(review): assumes `data` is a 2-D (height, width) array - confirm
    against callers.
    """
    # PIL sizes are (width, height), i.e. the reversed array shape.
    size = data.shape[::-1]
    # Pack the pixels of every row into bytes, 8 pixels per byte.
    databytes = np.packbits(data, axis=1)
    
    return Image.frombytes(mode='1', size=size, data=databytes)

In [None]:
# View train data: the first `show_count` tiles with their masks and weight maps
show_count = 10
# Convert the arrays back to PIL images for plotting; weight maps are
# plotted as raw arrays.
train_maps = [Image.fromarray(x.astype('uint8'), 'RGB') for x in X[:show_count]]
train_masks = [binary_mask_to_img(y) for y in Y[:show_count]]
train_weights = [w for w in W[:show_count]]

# Larger figure for this preview.
plt.rcParams['figure.figsize'] = [50,50]
plot_masks(train_maps, train_masks, train_weights)

In [None]:
# Stack the per-tile lists into arrays; Y gets a trailing axis of length 1.
X = np.array(X)
Y = np.array(Y)[...,np.newaxis]
W = np.array(W)

print(X.shape, Y.shape, W.shape)

# Keep a copy of the dataset in the workspace. np.save appends '.npy' to
# names that do not end with it.
np.save(str(workspace_dir/'X.np'), X)
np.save(str(workspace_dir/'Y.np'), Y)
np.save(str(workspace_dir/'W.np'), W)

# 2 Create and train U-Net model

In [None]:
import tensorflow as tf
from tensorflow.python.client import device_lib
from tensorflow.keras import backend as K
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
from tensorflow.keras.optimizers import *
from tensorflow.keras.callbacks import (ModelCheckpoint, LearningRateScheduler, ModelCheckpoint, EarlyStopping, 
                                        ReduceLROnPlateau, TensorBoard, TerminateOnNaN, Callback)
from tensorflow.keras.models import load_model
from tensorflow.keras.models import model_from_json
from tensorflow.keras.preprocessing.image import ImageDataGenerator

In [None]:
# Show the installed TensorFlow version.
tf.__version__

In [None]:
def get_available_gpus():
    """Return the names of the local devices whose device type is 'GPU'."""
    gpu_names = []
    for device in device_lib.list_local_devices():
        if device.device_type == 'GPU':
            gpu_names.append(device.name)
    return gpu_names

get_available_gpus()

### 2.1 Create U-net

In [None]:
def create_unet(input_sz=512):
    """Build the U-Net with an extra, pass-through weight-map input.

    Layout:
      * contracting path: four stages of two 3x3 convolutions
        (64, 128, 256, 512 filters), each followed by 2x2 max pooling;
      * bottleneck: two 3x3 convolutions with 1024 filters;
      * expansive path: four stages of 2x up-sampling, a 2x2 convolution,
        concatenation with the matching contracting stage and two 3x3
        convolutions (512, 256, 128, 64 filters);
      * SpatialDropout2D(0.2) and a 1x1 sigmoid convolution with one filter.

    All hidden convolutions use ReLU, 'same' padding and 'he_normal'
    initialization.

    Args:
        input_sz: height and width of the square RGB input.

    Returns:
        Tuple ``(model, weights_input)``. ``model`` takes
        ``[image, weight map]`` as inputs; the weight-map input is not
        connected to the output and is returned so that a loss can use it.
    """
    conv_kwargs = dict(activation='relu', padding='same', kernel_initializer='he_normal')

    def double_conv(tensor, filters):
        # Two stacked 3x3 convolutions with the same number of filters.
        tensor = Conv2D(filters, 3, **conv_kwargs)(tensor)
        return Conv2D(filters, 3, **conv_kwargs)(tensor)

    def up_stage(tensor, skip, filters):
        # The 2x2 convolution layer is created before the up-sampling layer,
        # like in the nested call form, so layer creation order is unchanged.
        up_conv = Conv2D(filters, 2, **conv_kwargs)
        upsampled = up_conv(UpSampling2D(size=(2, 2))(tensor))
        merged = concatenate([skip, upsampled], axis=3)
        return double_conv(merged, filters)

    image_input = Input(shape=(input_sz, input_sz, 3))
    weights_input = Input(shape=(input_sz, input_sz))

    # contracting path (down-sampling)
    skips = []
    features = image_input
    for filters in (64, 128, 256, 512):
        features = double_conv(features, filters)
        skips.append(features)
        features = MaxPooling2D(pool_size=(2, 2))(features)

    features = double_conv(features, 1024)

    # expansive path (up-sampling)
    for filters, skip in zip((512, 256, 128, 64), reversed(skips)):
        features = up_stage(features, skip, filters)

    features = SpatialDropout2D(0.2)(features)
    output = Conv2D(1, 1, 1, activation='sigmoid')(features)

    return Model(inputs=[image_input, weights_input], outputs=output), weights_input

In [None]:
# Build the network for the tile size used in the dataset and keep the
# weight-map input tensor, which the training loss needs.
unet, weights_tensor = create_unet(input_sz=UNET_INPUT_SIZE)
unet.summary()

### 2.2 Train U-net

In [None]:
def preprocess_inputs(X):
    """Linearly rescale values so that 0 maps to -1 and 255 maps to 1."""
    scale = 2.0 / 255.0
    return scale * X - 1.0

In [None]:
# Rescale the training tiles before feeding them to the network.
X_preprocessed = preprocess_inputs(X)

In [None]:
# Take the loss from tensorflow.keras, like the model, backend and optimizers
# used in this notebook, instead of mixing in the standalone `keras` package.
from tensorflow.keras.losses import binary_crossentropy

SMOOTH = 1

# https://github.com/keras-team/keras/blob/613aeff37a721450d94906df1a3f3cc51e2299d4/keras/backend/tensorflow_backend.py#L3626
def weighted_bce(y_true, y_pred, weights, sample_weight=None):
    """Binary cross-entropy scaled by a weight map and averaged to a scalar.

    Args:
        y_true: ground truth tensor.
        y_pred: predicted probabilities, same shape as ``y_true``.
        weights: weight tensor multiplied with the cross-entropy; it must be
            broadcastable against the output of ``binary_crossentropy``
            (presumably (B, H, W) for (B, H, W, 1) targets - TODO confirm).
        sample_weight: accepted for signature compatibility; not used.

    Returns:
        Scalar tensor: mean of ``weights * binary_crossentropy``.
    """
    bce = binary_crossentropy(y_true, y_pred)
    return K.mean(weights * bce)

def jaccard_score(gt, pr, smooth=SMOOTH, threshold=None):
    """
        Smoothed Jaccard index (intersection over union):
        https://en.wikipedia.org/wiki/Jaccard_index
    Args:
        gt: ground truth 4D keras tensor (B, H, W, C)
        pr: prediction 4D keras tensor (B, H, W, C)
        smooth: value added to numerator and denominator to avoid
                division by zero
        threshold: if not `None`, predictions are binarized with
                   `pr > threshold` before scoring; if `None` they are
                   used as they are
    Returns:
        IoU/Jaccard score summed over axes 1 and 2 and averaged over
        the batch axis
    """
    spatial_axes = [1, 2]

    if threshold is not None:
        pr = K.cast(K.greater(pr, threshold), K.floatx())

    overlap = K.sum(gt * pr, axis=spatial_axes)
    total = K.sum(gt + pr, axis=spatial_axes)
    # union = total - overlap
    per_sample = (overlap + smooth) / (total - overlap + smooth)

    return K.mean(per_sample, axis=0)

In [None]:
def create_default_callbacks(workspace_dir, batch_sz=1, early_stopping=False):
    """Create the training callbacks and the folder that receives checkpoints.

    A timestamped folder is created under ``workspace_dir / 'checkpoints'``;
    TensorBoard logs go to a folder with the same timestamp under
    ``workspace_dir / 'tensorboard_logs'``.

    Args:
        workspace_dir: pathlib.Path of the run workspace.
        batch_sz: batch size passed to the TensorBoard callback.
        early_stopping: when True, also stop training after 200 epochs
            without improvement of the training loss. This callback used to
            be constructed but never returned; it is off by default so
            existing callers behave as before.

    Returns:
        Tuple ``(callbacks, checkpoint_folder)``.
    """
    timestamp = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    checkpoint_folder = workspace_dir / 'checkpoints' / str(timestamp)
    checkpoint_folder.mkdir(parents=True)
    tensorboard_folder = workspace_dir / 'tensorboard_logs' / str(timestamp)

    # Weights only, and only when the training loss improves; the loss
    # value becomes part of the file name.
    checkpoint = ModelCheckpoint(
        str(checkpoint_folder / 'model-{loss:.2f}.h5'),
        monitor='loss',
        verbose=1,
        save_best_only=True,
        save_weights_only=True,
        mode='auto',
        period=1
    )

    # Halve the learning rate after 5 epochs without loss improvement.
    reduce_lr = ReduceLROnPlateau(monitor='loss', factor=0.5, patience=5, min_lr=1e-9, verbose=1, mode='min')

    tensorboard = TensorBoard(log_dir=str(tensorboard_folder),
                              histogram_freq=0,
                              batch_size=batch_sz,
                              write_graph=False,
                              write_grads=False,
                              write_images=False,
                              embeddings_freq=0,
                              embeddings_layer_names=None,
                              embeddings_metadata=None,
                              embeddings_data=None)

    callbacks = [reduce_lr, TerminateOnNaN(), checkpoint, tensorboard]

    if early_stopping:
        callbacks.append(EarlyStopping(monitor='loss', patience=200, mode='min', verbose=1))

    return callbacks, checkpoint_folder


def train(model, weights_tensor, X_train, Weights, Y_train, workspace_dir, epochs=1, batch_sz=1):
    """Compile and fit ``model`` with the weighted binary cross-entropy loss.

    The model graph is stored as ``graph.json`` in the checkpoint folder
    before training starts.

    Args:
        model: two-input Keras model (image, weight map).
        weights_tensor: the weight-map input tensor of ``model``; the loss
            multiplies the cross-entropy with it.
        X_train: preprocessed input images.
        Weights: weight maps, fed to the second model input.
        Y_train: target masks.
        workspace_dir: pathlib.Path of the run workspace.
        epochs: number of training epochs.
        batch_sz: batch size.

    Returns:
        The object returned by ``model.fit``.
    """
    # A named function instead of functools.partial: partial objects have no
    # __name__, which newer Keras versions appear to read for function losses.
    def weighted_loss(y_true, y_pred):
        return weighted_bce(y_true, y_pred, weights=weights_tensor)

    # Alternative optimizer tried by the author:
    # SGD(lr=0.01, decay=1e-6, momentum=0.99, nesterov=True)

    # Compile the model that was passed in; the global `unet` was compiled
    # here before, which broke training of any other model instance.
    model.compile(
        optimizer='Adam',
        loss=weighted_loss,
        metrics=[jaccard_score, 'binary_accuracy']
    )

    callbacks, checkpoint_dir = create_default_callbacks(workspace_dir, batch_sz=batch_sz)

    # Checkpoints hold weights only, so the architecture is saved separately.
    model_json = model.to_json()
    with open(str(checkpoint_dir/'graph.json'), 'w') as json_file:
        json_file.write(model_json)

    return model.fit(
        [X_train, Weights], Y_train,
        batch_size=batch_sz,
        epochs=epochs,
        callbacks=callbacks,
        shuffle=True
    )

In [None]:
# Train on the preprocessed tiles, feeding the weight maps W as second input.
train(
    unet, 
    weights_tensor, 
    X_preprocessed, 
    W, 
    Y, 
    workspace_dir, 
    epochs=700, 
    batch_sz=16
)

# 3 Check the model

### 3.1 Load trained model

In [None]:
import json

# NOTE(review): the checkpoint folder and file name below are hard-coded;
# they have to be replaced with the ones produced by the training run.
final_model_path = workspace_dir / 'checkpoints' / '2019-06-26-08-44-13' / 'model-0.23.h5'

# Rebuild the architecture from the graph.json stored next to the checkpoint ...
with open(str(final_model_path.parent / 'graph.json'), 'r') as json_file:
    fitted_model = model_from_json(json_file.read())
    
# ... and load the trained weights into it.
fitted_model.load_weights(str(final_model_path), by_name=True)

In [None]:
# Print the layers of the restored model.
fitted_model.summary()

### 3.2 Use test images

In [None]:
# Sample 10 fresh random tiles per source map for a visual check; nothing
# is written to disk.
X_test = []
Y_test = []

for i, src in enumerate(src_images_and_masks):
    # cut_map_into_tiles returns (tiles, masks, weight maps); unpacking only
    # two values raised ValueError. The weight maps are not needed here.
    x, y, _ = cut_map_into_tiles(
        src['map'],
        src['mask'],
        tile_size=TILES_SIZE,
        tile_resize=UNET_INPUT_SIZE,
        tiles_count=10,
        tile_prefix=str(i),
        save_tiles=False
    )

    X_test += x
    Y_test += y
    print('done', i)

### 3.3 Run inference

In [None]:
def convert_grayscale_data_to_red_rgba(mask, alpha_value=50):
    """Turn a 2-D array of mask values into a red, semi-transparent RGBA image.

    The red channel is ``mask * 255``, green and blue stay 0 and the alpha
    channel is ``mask * alpha_value``; all channels are cast to uint8.
    """
    red = (mask * 255).astype('uint8')
    alpha = (mask * alpha_value).astype('uint8')

    rows, cols = red.shape[0], red.shape[1]

    # Start from a fully transparent black image and fill in the two
    # channels that carry information.
    rgba_array = np.zeros((rows, cols, 4), dtype='uint8')
    rgba_array[:, :, 0] = red.reshape((rows, cols))
    rgba_array[:, :, 3] = alpha

    return Image.fromarray(rgba_array, 'RGBA')

def apply_predicted_mask(orig_image, predicted_2d_values):
    """Overlay predicted mask values in red on top of an RGB image array.

    Returns a PIL RGB image.
    """
    overlay = convert_grayscale_data_to_red_rgba(predicted_2d_values)

    base = Image.fromarray(orig_image.astype('uint8'), 'RGB').convert('RGBA')
    # The overlay doubles as the paste mask, so its alpha channel decides
    # how strongly the red shows through.
    base.paste(overlay, (0, 0), overlay)

    return base.convert('RGB')

In [None]:
# Predict a mask for every test tile and display it on top of the tile.
for x, y in zip(X_test, Y_test):
    x_np = preprocess_inputs(np.array(x))
    batch = x_np[np.newaxis, :, :, :]

    # The model declares two inputs (image, weight map). The weight map is
    # not connected to the output, so zeros are fed to satisfy the input
    # signature; passing the image alone is rejected by predict().
    dummy_weights = np.zeros(batch.shape[:3], dtype=batch.dtype)

    predicted = fitted_model.predict([batch, dummy_weights])
    predicted_2d = predicted.reshape((predicted.shape[1], predicted.shape[2]))

    display(apply_predicted_mask(x, predicted_2d))