## Training U-Net

The first step is to load the modules we need.

In [1]:
import tensorflow as tf
from tensorflow import keras
from keras.optimizers import SGD, RMSprop

from utils import rate_scheduler, train_model_sample
from models import unet as unet

import os
import datetime
import numpy as np

Using TensorFlow backend.


We define the training and validation datasets we want to use to train the classifier.

In [2]:
# Every dataset sits under the same root folder; each imaging modality
# (CF = confocal, WF = wide-field, CFWF = both combined) has its own
# training_* / validation_* sub-folder.
_dataset_root = "datasets/nucleiSegmentation_E2Fs"

# Confocal
dataset_training_CF = _dataset_root + "/training_CF"
dataset_validation_CF = _dataset_root + "/validation_CF"

# Wide-field
dataset_training_WF = _dataset_root + "/training_WF"
dataset_validation_WF = _dataset_root + "/validation_WF"

# Confocal & wide-field
dataset_training_CFWF = _dataset_root + "/training_CFWF"
dataset_validation_CFWF = _dataset_root + "/validation_CFWF"

# Output folder handed to train_model_sample as `direc_save`.
direc_save = "./trainedClassifiers/nucleiSegmentation/"

We define the global parameters used for training the classifier: <br>
    1) the image dimensions (imaging_field_x and imaging_field_y) <br>
    2) the number of classes (nb_classes) <br>
    3) the number of images trained at once (batch_size) <br>
    4) the number of training epochs (nb_epochs) <br>
    5) the number of data augmentations (nb_augmentations) <br>
    6) the class to dilate, if any (class_to_dilate) <br>
    7) the dilation radius for the class to dilate, if any (dilation_radius) <br> <br>
We also set up the optimizer and the learning-rate schedule that will be used for training.

In [3]:
# parameters
# Image dimensions: both axes currently use the same size.
imaging_field_x = imaging_field_y = 256

# Number of classes.
nb_classes = 3

# Training loop sizing: images trained at once, number of epochs, and number
# of data augmentations.
batch_size = 1
nb_epochs = 10
nb_augmentations = 100

# Class to dilate (if any) and the dilation radius applied to it.
# NOTE(review): presumably one flag per class (length matches nb_classes) -- confirm
# against utils.train_model_sample.
class_to_dilate = [1, 0, 0]
dilation_radius = 1

# optimizer
# The base learning rate is defined once and fed to both the optimizer and the
# rate scheduler, so the two cannot drift apart when the value is tuned.
learning_rate = 1e-4
# Decay factor handed to utils.rate_scheduler.
# NOTE(review): presumably a per-epoch multiplicative decay -- confirm in utils.
lr_decay = 0.99

optimizer = RMSprop(lr=learning_rate)
lr_sched = rate_scheduler(lr=learning_rate, decay=lr_decay)

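We define a U-Net model and train it on each of the three datasets in turn: confocal, wide-field, and the two combined. A new model is built for every dataset.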

In [None]:
# Train one U-Net per imaging modality, always in this order:
# confocal (CF), wide-field (WF), then confocal & wide-field combined (CFWF).
# Each entry is (modality tag used in the experiment name, training set,
# validation set).
training_runs = [
    ("CF", dataset_training_CF, dataset_validation_CF),        # Confocal
    ("WF", dataset_training_WF, dataset_validation_WF),        # Wide-field
    ("CFWF", dataset_training_CFWF, dataset_validation_CFWF),  # Confocal & Wide-field
]

# Snapshot the optimizer hyper-parameters once, before any training, so that
# every run starts from exactly the same settings.
optimizer_config = optimizer.get_config()

for modality, run_training_set, run_validation_set in training_runs:
    # A new, untrained network for every dataset.
    model = unet(nb_classes, imaging_field_x, imaging_field_y)

    # Optimizer instances are stateful, so the three independent models must
    # not share one. Build a fresh optimizer of the same class and settings
    # for each run (the dict is copied because from_config may modify it).
    run_optimizer = type(optimizer).from_config(dict(optimizer_config))

    # The experiment name encodes the modality, the number of augmentations and
    # the number of epochs, e.g. "Unet_CFtraining_DA100_10ep"; it is derived
    # from the parameters so it cannot go stale when they are changed.
    expt = "Unet_{}training_DA{}_{}ep".format(modality, nb_augmentations, nb_epochs)

    train_model_sample(
        model=model,
        dataset_training=run_training_set,
        dataset_validation=run_validation_set,
        optimizer=run_optimizer,
        expt=expt,
        batch_size=batch_size,
        n_epoch=nb_epochs,
        imaging_field_x=imaging_field_x,
        imaging_field_y=imaging_field_y,
        direc_save=direc_save,
        lr_sched=lr_sched,
        nb_augmentations=nb_augmentations,
        class_to_dilate=class_to_dilate,
        dil_radius=dilation_radius,
    )

    # Drop the reference to the trained model before building the next one.
    del model

W0528 17:26:00.519341 140700929681216 deprecation_wrapper.py:119] From /home/thierry/.local/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:74: The name tf.get_default_graph is deprecated. Please use tf.compat.v1.get_default_graph instead.

W0528 17:26:00.521277 140700929681216 deprecation_wrapper.py:119] From /home/thierry/.local/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:517: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.

W0528 17:26:00.525899 140700929681216 deprecation_wrapper.py:119] From /home/thierry/.local/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:4138: The name tf.random_uniform is deprecated. Please use tf.random.uniform instead.

W0528 17:26:00.567018 140700929681216 deprecation_wrapper.py:119] From /home/thierry/.local/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:174: The name tf.get_default_session is deprecated. Please use tf.compat.v1.get_default_session in

9 training images
3 validation images
Epoch 1/10
  8/900 [..............................] - ETA: 39:25 - loss: 1.2224 - acc: 0.5256