<a href="https://colab.research.google.com/github/stegmuel/binarization/blob/master/train_unet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Connect the notebook to Google Drive and clone the code

In [111]:
from google.colab import drive

# Mount Google Drive at /content/drive (remounting even if it is already
# mounted) so the data and model folders under "My Drive" are reachable.
drive.mount(mountpoint='/content/drive', force_remount=True)

Mounted at /content/drive


In [0]:
from google.colab import files

In [0]:
import getpass
import os
import subprocess

# Clone the project repository into /content/binarization (source_path below
# points inside it). Credentials are never hardcoded: the token is read from the
# GITHUB_TOKEN environment variable or typed at an interactive prompt. Any
# password that was ever committed in this notebook must be treated as leaked
# and revoked.
repo_dir = '/content/binarization'
if not os.path.isdir(repo_dir):  # skip on re-runs; git refuses existing targets anyway
    github_token = os.environ.get('GITHUB_TOKEN') or getpass.getpass('GitHub token for stegmuel: ')
    clone_url = 'https://stegmuel:{}@github.com/stegmuel/binarization.git'.format(github_token)
    # Report failure by a fixed message only: the command line contains the token.
    if subprocess.run(['git', 'clone', clone_url, repo_dir]).returncode != 0:
        raise RuntimeError('git clone of stegmuel/binarization failed')
    # Strip the token from the stored remote so it is not persisted in .git/config.
    subprocess.run(['git', '-C', repo_dir, 'remote', 'set-url', 'origin',
                    'https://github.com/stegmuel/binarization.git'])

## Import the project modules and Keras utilities

In [0]:
# Path of the `source/` folder inside the cloned binarization repository; the
# imports cell appends it to sys.path so that `models` and `classes` resolve.
source_path = '/content/binarization/source/'

In [0]:
import importlib.util
import os
import sys
import zipfile

sys.path.append(source_path)
from models import *
from classes import *

# Keras names are imported after the star imports so they cannot be shadowed by them.
from keras.optimizers import Adam
import keras.backend as K
from keras.models import load_model

## Segmentation metrics and losses (Jaccard / Dice)

In [0]:
def jaccard_accuracy(y_true, y_pred, smooth=1.0):
    """Smoothed Jaccard index (intersection over union) of two masks.

    The sums run over every element of the batch at once; they are not split
    per sample.

    Args:
        y_true: ground-truth mask tensor.
        y_pred: predicted mask tensor with the same shape as ``y_true``;
            presumably sigmoid probabilities in [0, 1] (soft IoU) -- TODO confirm.
        smooth: constant added once to the numerator and once to the
            denominator. It keeps the score finite (exactly 1) when both masks
            are empty.

    Returns:
        Scalar tensor ``(intersection + smooth) / (union + smooth)``. It lies in
        [0, 1] for mask values in [0, 1].
    """
    intersection = K.sum(y_true * y_pred)
    # union = |A| + |B| - intersection. The smoothing must be added to the
    # union itself, not folded into the intersection before subtracting.
    # Doing the latter cancels the smoothing in the denominator: empty masks
    # then divide by zero and a perfect prediction scores above 1.
    union = K.sum(y_true + y_pred) - intersection
    return (intersection + smooth) / (union + smooth)


def jaccard_loss(y_true, y_pred, smooth=1.0):
    """Jaccard loss ``1 - jaccard_accuracy``; 0 for a perfect prediction."""
    return 1 - jaccard_accuracy(y_true, y_pred, smooth)


def dice_accuracy(y_true, y_pred, smooth=1.0):
    """Smoothed Dice coefficient of two masks.

    The sums run over every element of the batch at once; they are not split
    per sample.

    Args:
        y_true: ground-truth mask tensor.
        y_pred: predicted mask tensor with the same shape as ``y_true``;
            presumably sigmoid probabilities in [0, 1] -- TODO confirm.
        smooth: constant added to numerator and denominator. It keeps the
            score finite (exactly 1) when both masks are empty.

    Returns:
        Scalar tensor ``(2 * intersection + smooth) / (|A| + |B| + smooth)``.
    """
    intersection = K.sum(y_true * y_pred)
    total = K.sum(y_true + y_pred)
    return (2 * intersection + smooth) / (total + smooth)


def dice_loss(y_true, y_pred, smooth=1.0):
    """Dice loss ``1 - dice_accuracy``; 0 for a perfect prediction."""
    return 1 - dice_accuracy(y_true, y_pred, smooth)

## Define base directories

In [0]:
# Every project folder lives under one root on the mounted Google Drive.
project_root = '/content/drive/My Drive/Colab Notebooks/binarization/'
data_path = project_root + 'data/'
training_path = data_path + 'training/'
models_path = project_root + 'models'

In [0]:
import shutil

# Remove any previous extraction so the archive is unpacked into a clean folder.
# ignore_errors=True keeps the cell re-runnable when the folder does not exist yet.
shutil.rmtree(training_path, ignore_errors=True)

In [127]:
# Inspect the data folder on Drive; training.zip is the archive extracted below.
!ls /content/drive/My\ Drive/Colab\ Notebooks/binarization/data

training  training.zip


In [0]:
import os

# (Re)create the extraction target. Using training_path keeps this in sync with
# the path constants, and exist_ok=True makes the cell safe to re-run.
os.makedirs(training_path, exist_ok=True)

In [0]:
# Unpack the training archive stored on Drive into the training folder.
archive_path = os.path.join(data_path, 'training.zip')
with zipfile.ZipFile(archive_path, mode='r') as archive:
    archive.extractall(path=training_path)

In [130]:
# Check the extracted contents: the .lst files and split folders read by the next section.
!ls /content/drive/My\ Drive/Colab\ Notebooks/binarization/data/training

train  train.lst  validation  validation.lst


## Build the training and validation data generators

In [0]:
# Get train and validation images
train_images_names, train_images_gt_names = \
  get_images_names(os.path.join(training_path, 'train.lst'))
validation_images_names, validation_images_gt_names = \
  get_images_names(os.path.join(training_path, 'validation.lst'))

# Create the generators
train_generator = DataGenerator(train_images_names, 
                                train_images_gt_names, 
                                32,
                                os.path.join(data_path, 'training/train/'))

validation_generator = DataGenerator(validation_images_names, 
                                     validation_images_gt_names, 
                                     32,
                                     os.path.join(data_path, 'training/validation'))

## Load the saved model, or build and compile a new one

In [0]:
model_file = os.path.join(models_path, 'UNet.h5')

if not os.path.exists(model_file):
    # No checkpoint yet: build a fresh U-Net, optimised on the Jaccard loss.
    UNet = unet()
    UNet.compile(optimizer=Adam(),
                 loss=jaccard_loss,
                 metrics=['accuracy', jaccard_accuracy, dice_accuracy])
else:
    # Resume from the checkpoint saved after training. The custom loss and
    # metric functions must be registered by name for Keras to deserialize them.
    UNet = load_model(model_file,
                      custom_objects=dict(jaccard_loss=jaccard_loss,
                                          jaccard_accuracy=jaccard_accuracy,
                                          dice_loss=dice_loss,
                                          dice_accuracy=dice_accuracy))

In [136]:
# Train for one epoch over every batch of both generators, loading data with
# 4 worker processes. Then save the model to the file the loading cell checks.
UNet.fit_generator(
    train_generator,
    steps_per_epoch=len(train_generator),
    epochs=1,
    validation_data=validation_generator,
    validation_steps=len(validation_generator),
    workers=4,
    use_multiprocessing=True,
    shuffle=True,
)
UNet.save(os.path.join(models_path, 'UNet.h5'))

Epoch 1/1


In [137]:
# Confirm that UNet.h5 was written to the models folder by the training cell.
!ls /content/drive/My\ Drive/Colab\ Notebooks/binarization/models/

UNet.h5


## Inspect the model on a validation sample

In [0]:
import matplotlib.pyplot as plt
import numpy as np

In [0]:
# Load one validation image and its ground truth, and shape the image as a
# single-sample batch for the model.
index = 0
# Validation files live in the same directory that validation_generator reads
# from (data_path + 'training/validation').
validation_dir = os.path.join(training_path, 'validation')
image_path = os.path.join(validation_dir, validation_images_names[index])
# The ground truth comes from the *_gt_names list, not the image-name list.
image_gt_path = os.path.join(validation_dir, validation_images_gt_names[index])
image = np.load(image_path)
image_gt = np.load(image_gt_path)
# Insert a channel axis at position 2, then a batch axis at position 0.
# Assumes `image` is 2-D (H, W), giving (1, H, W, 1) -- TODO confirm.
input_image = np.expand_dims(np.expand_dims(image, axis=2), axis=0)

## Get a prediction