# Semantic Segmentation with Keras

In this exercise, you'll use a U-Net network to perform binary semantic segmentation on images of airplanes: the model learns to classify each pixel as either *plane* or *background*.

> **Important**: Using the U-Net model is resource-intensive. Before running the code in this notebook, shut down all other notebooks in this library (in each open notebook other than this one, on the **File** menu, click **Close and Halt**). If you experience an Out-of-Memory (OOM) error when running code in this notebook, shut down this entire library, reopen it, and then open only this notebook.

## Install Keras

To begin with, we'll install the latest version of Keras.

In [None]:
!pip install --upgrade keras

## Explore the Training Data

The training data for a U-Net model consists of two kinds of input:

- **Image files**: The images that represent the *features* on which we want to train the model.
- **Mask files**: Images of the object masks that the network will be trained to predict - these are the *labels*.

In this example, we're going to use U-Net for binary segmentation of airplane images, so there's only one class of object, and therefore one class of mask. We've deliberately made this example as simple as possible, partly to make it easier to understand what's going on, and partly to ensure it can be run in a resource-constrained environment.
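To make the image/mask pairing concrete, here's a minimal NumPy sketch using made-up 4 x 4 arrays (they are not taken from the training data). It assumes masks are stored as 8-bit images in which 0 is background and 255 is plane, and shows one common way to turn such a mask into 0/1 per-pixel labels. The `to_binary_label` helper is hypothetical; the exercise's own **data.py** may prepare masks differently.

```python
import numpy as np

# Made-up 4x4 grayscale "image" (features); real images are loaded with
# skimage.io.imread(..., as_gray=True).
image = np.array([
    [0.1, 0.1, 0.2, 0.1],
    [0.1, 0.8, 0.9, 0.1],
    [0.2, 0.9, 0.8, 0.2],
    [0.1, 0.1, 0.1, 0.1],
])

# Made-up mask (labels), assumed to be stored as 0 = background, 255 = plane.
mask_raw = np.array([
    [0,   0,   0, 0],
    [0, 255, 255, 0],
    [0, 255, 255, 0],
    [0,   0,   0, 0],
], dtype=np.uint8)


def to_binary_label(mask, threshold=0.5):
    """Hypothetical helper: scale a mask to [0, 1], then threshold to 0 or 1."""
    mask = mask.astype(np.float32)
    if mask.max() > 1.0:  # assume 8-bit masks stored as 0..255
        mask = mask / 255.0
    return (mask > threshold).astype(np.float32)


label = to_binary_label(mask_raw)
print(label.sum())               # 4.0 (number of "plane" pixels)
print(image[label == 1].mean())  # average brightness of the plane pixels
```

Each label pixel answers one yes/no question about the image pixel at the same position, which is what makes this a binary, per-pixel classification problem.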

Let's take a look at the training images and masks:

In [None]:
import os
from matplotlib import pyplot as plt
import skimage.io as io
import numpy as np
%matplotlib inline


fig = plt.figure(figsize=(12, 60))


train_dir = '../../data/segmentation/train'
image_dir = os.path.join(train_dir,"image/plane")
mask_dir = os.path.join(train_dir,"mask/plane")

files = os.listdir(image_dir)
rows = len(files)
cell = 0
for file in files:
    cell += 1
    
    # Open the image and mask files
    img_path = os.path.join(image_dir, file)
    img = io.imread(img_path, as_gray = True)
    
    mask_path = os.path.join(mask_dir, file)
    mask = io.imread(mask_path, as_gray = True)
    
    # plot the image
    a=fig.add_subplot(rows,3,cell)
    imgplot=plt.imshow(img, "gray")
    cell += 1

    # plot the mask
    a=fig.add_subplot(rows,3,cell)
    imgplot=plt.imshow(mask, "gray")
    cell += 1
    
    # Plot them overlaid
    a=fig.add_subplot(rows,3,cell)
    imgplot=plt.imshow(img, "gray")
    imgplot=plt.imshow(mask, "gray", alpha=0.4)

plt.show()

The output from the code above shows the training images with their corresponding mask labels, and finally each mask overlaid on its image, so you can clearly see that the masks represent the pixels that belong to the plane objects in the images.

> **Note**: We deliberately chose images in which the plane objects are clearly contrasted with the background to make it easier to train with a very small number of training images and a very small amount of training!

## Import the U-Net Code

The code to implement U-Net is provided in two Python files:

- **model.py**: This file contains the code that implements the U-Net model.
- **data.py**: This file contains functions to help load and prepare training data.

> **Tip**: You should explore the code in these files to get a better understanding of the way the model works.

In [None]:
from unet_keras.data import *
from unet_keras.model import *
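U-Net is an encoder/decoder network: a contracting path repeatedly halves the spatial size of the feature maps while adding filters, and an expanding path upsamples them back to full resolution, concatenating each level with the matching encoder output (the skip connections). The following plain-Python sketch traces the tensor shapes through such a network. The input size (256 x 256 x 1), 64 base filters, and four pooling steps are assumptions based on common Keras U-Net implementations such as zhixuhao's; check `unet()` in **model.py** for the values this exercise actually uses.

```python
def unet_shapes(height=256, width=256, channels=1, base_filters=64, depth=4):
    """Trace (height, width, channels) through a U-Net-style encoder/decoder.

    Assumes 'same'-padded convolutions (spatial size unchanged), 2x2 max
    pooling on the way down, 2x upsampling on the way up, and skip
    connections joined by channel-wise concatenation.
    """
    if height % (2 ** depth) or width % (2 ** depth):
        raise ValueError("input size must be divisible by 2**depth")

    trace = [("input", (height, width, channels))]
    skips = []
    h, w = height, width

    # Contracting path: convolve, remember the feature map, then halve it.
    for level in range(depth):
        filters = base_filters * 2 ** level
        trace.append((f"enc{level}", (h, w, filters)))
        skips.append((h, w, filters))
        h, w = h // 2, w // 2

    # Bottleneck at the lowest resolution.
    trace.append(("bottleneck", (h, w, base_filters * 2 ** depth)))

    # Expanding path: double the size, concatenate the matching skip, convolve.
    for level in reversed(range(depth)):
        h, w = h * 2, w * 2
        filters = base_filters * 2 ** level
        skip_h, skip_w, skip_c = skips[level]
        assert (skip_h, skip_w) == (h, w)  # skip and upsampled maps must align
        trace.append((f"concat{level}", (h, w, filters + skip_c)))
        trace.append((f"dec{level}", (h, w, filters)))

    # A 1x1 convolution with a sigmoid gives one probability per pixel.
    trace.append(("output", (h, w, 1)))
    return trace


for name, shape in unet_shapes():
    print(f"{name:>10}: {shape}")
```

Because every pooling step halves the spatial size, the input height and width must be divisible by 2**depth (16 here) for the skip connections to line up, which is why the function rejects other sizes. The output has the same height and width as the input, so the model can predict a mask pixel for every image pixel.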

## Load the Training Data
We have a very small number of training images, so we'll apply some data augmentation to randomly flip, zoom, shear, and otherwise transform the images for each batch.

In [None]:
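# Random transformations applied to each training batch (ImageDataGenerator-style arguments).
# rotation_range is in degrees, so 0.2 allows only a very slight rotation;
# the width/height shifts are fractions of the image size, and zoom_range=0.05 zooms by up to 5%.
data_gen_args = dict(rotation_range=0.2,
                    width_shift_range=0.05,
                    height_shift_range=0.05,
                    shear_range=0.05,
                    zoom_range=0.05,
                    horizontal_flip=True,
                    fill_mode='nearest')
# Batch size 2; images come from train_dir/image and masks from train_dir/mask
train_images = trainGenerator(2, train_dir, 'image', 'mask', data_gen_args, save_to_dir=None)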
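With segmentation, every random transformation applied to an image must also be applied, identically, to its mask; otherwise the labels no longer line up with the pixels they describe. zhixuhao's original `trainGenerator`, for example, does this by passing the same random seed to separate image and mask generators. The sketch below illustrates the same idea with plain NumPy instead of `ImageDataGenerator`: the random parameters are drawn once and reused for both arrays. The `augment_pair` function is hypothetical, and it uses `np.roll` (which wraps pixels around the edges) rather than the `fill_mode='nearest'` padding used above.

```python
import numpy as np


def augment_pair(image, mask, rng, max_shift=2, flip_prob=0.5):
    """Apply one random flip and shift to an image and its mask together.

    The random parameters are drawn once and reused for both arrays, which
    keeps every label pixel aligned with the image pixel it describes.
    """
    if rng.random() < flip_prob:
        image, mask = image[:, ::-1], mask[:, ::-1]  # horizontal flip
    dy, dx = rng.integers(-max_shift, max_shift + 1, size=2)
    # np.roll wraps pixels around the border; a real pipeline would pad instead.
    image = np.roll(image, shift=(dy, dx), axis=(0, 1))
    mask = np.roll(mask, shift=(dy, dx), axis=(0, 1))
    return image, mask


rng = np.random.default_rng(seed=1)
image = np.arange(36, dtype=np.float32).reshape(6, 6) / 35.0
mask = (image > 0.5).astype(np.float32)  # toy mask: "plane" = bright pixels

aug_image, aug_mask = augment_pair(image, mask, rng)
# The mask still marks exactly the pixels where the augmented image is bright.
print(np.array_equal(aug_mask, (aug_image > 0.5).astype(np.float32)))  # True
```

If the image and mask each drew their own random flip or shift, this check would fail for many seeds, and the model would be trained against labels that point at the wrong pixels.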

## Download the Model Weights
The model has already been partially trained, so we'll download the trained weights as a starting point.

In [None]:
!wget "https://aka.ms/unet.h5" -O ~/unet.h5

## Train the Model

Now we're ready to train the U-Net model. We'll train it on batches from the training generator we created, and save the model to `unet.h5` after each epoch if the training loss has improved. In this example, to reduce the required compute resources, we'll train for just one epoch consisting of a single batch (`steps_per_epoch=1`). In practice, you'd need to train the model for many epochs, ideally on a GPU-based computer.

> _**Note**: This will take a while on a non-GPU machine - go get some coffee!_

In [None]:
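model = unet()

# Start from the downloaded, partially trained weights
home = os.path.expanduser("~")
weights_file = os.path.join(home, "unet.h5")
model.load_weights(weights_file)

# Overwrite unet.h5 at the end of an epoch only if the training loss improved
model_checkpoint = ModelCheckpoint(weights_file, monitor='loss', verbose=1, save_best_only=True)

# fit() accepts a generator directly; fit_generator() is deprecated in tf.keras and removed in Keras 3
model.fit(train_images, steps_per_epoch=1, epochs=1, callbacks=[model_checkpoint])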
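The checkpoint's `save_best_only=True` behavior boils down to remembering the lowest loss seen so far and saving only when an epoch beats it. Here's a simplified plain-Python sketch of that decision rule; the `BestLossTracker` class is hypothetical and writes no files. One consequence worth knowing: by default, Keras starts the comparison from infinity rather than from the loss of the loaded weights, so with `epochs=1`, as in the cell above, the single epoch always counts as an improvement and overwrites `unet.h5`.

```python
import math


class BestLossTracker:
    """Hypothetical stand-in for ModelCheckpoint(monitor='loss', save_best_only=True).

    It only records which epochs would trigger a save; it does not write files.
    """

    def __init__(self):
        self.best = math.inf  # like Keras's default starting point
        self.saved_epochs = []

    def on_epoch_end(self, epoch, loss):
        if loss < self.best:  # 'loss' is minimized, so lower is better
            self.best = loss
            self.saved_epochs.append(epoch)
            return True
        return False  # equal or worse loss: keep the previous checkpoint


tracker = BestLossTracker()
for epoch, loss in enumerate([0.52, 0.47, 0.49, 0.41, 0.41]):
    tracker.on_epoch_end(epoch, loss)
print(tracker.saved_epochs)  # [0, 1, 3]
```

Epoch 2 is skipped because its loss got worse, and epoch 4 is skipped because merely matching the best loss is not an improvement.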

## Test the Trained Model

OK, let's see how well our trained model does with some images of airplanes it hasn't seen.

In [None]:
import os
from matplotlib import pyplot as plt
import skimage.io as io
import numpy as np
%matplotlib inline

model = unet()
model.load_weights(weights_file)

fig = plt.figure(figsize=(12, 60))


test_dir = '../../data/segmentation/test'

files = os.listdir(test_dir)
rows = len(files)
cell = 0
for file in files:
    cell += 1
    # Open the file
    img_path = os.path.join(test_dir, file)
    img = io.imread(img_path, as_gray = True)
    
    # Plot the original image
    src_img = img
    fig.add_subplot(rows, 3, cell)
    plt.imshow(src_img, "gray")
    cell += 1

    # Add batch and channel axes, (H, W) -> (1, H, W, 1), as the model expects
    batch = np.reshape(img, (1,) + img.shape + (1,))
    mask_predictions = model.predict(batch)
    # Take the first (only) prediction and drop the channel axis: (H, W)
    img_mask = mask_predictions[0, :, :, 0]

    # Plot the predicted mask
    fig.add_subplot(rows, 3, cell)
    plt.imshow(img_mask, "gray")
    cell += 1

    # Plot the predicted mask overlaid on the image
    fig.add_subplot(rows, 3, cell)
    plt.imshow(src_img, "gray")
    plt.imshow(img_mask, "binary", alpha=0.6)

plt.show()



It's not fantastic, largely because we trained on so little data for so little time; but hopefully it serves to demonstrate the principles of semantic segmentation with U-Net.
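The prediction step in the test code reshapes each (H, W) image into a batch of one single-channel image, and the model returns one value per pixel. Assuming the final layer is a single-channel sigmoid (as in common Keras U-Net implementations), those values are probabilities, so the predicted mask can contain intermediate gray values. The following sketch shows the reshaping plus a 0.5 threshold that converts the probabilities into a hard 0/1 mask. `fake_predict` is a hypothetical stand-in for `model.predict` so the example runs without the trained model, and 0.5 is a common default rather than a value chosen in this exercise.

```python
import numpy as np


def fake_predict(batch):
    """Hypothetical stand-in for model.predict: returns per-pixel probabilities
    with the same (batch, height, width, 1) shape a sigmoid output layer gives."""
    return 1.0 / (1.0 + np.exp(-(batch - 0.5) * 10.0))


def predict_mask(predict, img, threshold=0.5):
    """Run one (H, W) grayscale image through a U-Net-style predict function."""
    batch = img[np.newaxis, :, :, np.newaxis]  # (H, W) -> (1, H, W, 1)
    probs = predict(batch)[0, :, :, 0]         # (1, H, W, 1) -> (H, W)
    return probs, (probs > threshold).astype(np.uint8)


img = np.array([[0.1, 0.9],
                [0.6, 0.3]])
probs, binary = predict_mask(fake_predict, img)
print(binary)
```

In this toy example, the two pixels brighter than 0.5 end up marked as plane. With the trained model, you'd pass `model.predict` in place of `fake_predict`.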

## Acknowledgements and Citations

The U-Net architecture is documented by its inventors (Olaf Ronneberger, Philipp Fischer, and Thomas Brox), at https://arxiv.org/abs/1505.04597.

The Keras implementation of U-Net used in this exercise is based on zhixuhao's work at https://github.com/zhixuhao/unet, with some simplifications. 

The data used in this exercise includes images adapted from the PASCAL Visual Object Classes Challenge (VOC2007) dataset at http://host.robots.ox.ac.uk/pascal/VOC/voc2007/.


    @misc{pascal-voc-2007,
        author = "Everingham, M. and Van~Gool, L. and Williams, C. K. I. and Winn, J. and Zisserman, A.",
        title = "The {PASCAL} {V}isual {O}bject {C}lasses {C}hallenge 2007 {(VOC2007)} {R}esults",
        howpublished = "http://www.pascal-network.org/challenges/VOC/voc2007/workshop/index.html"}

