In [None]:
import functools  # needed by functools.partial in the preprocessing cells
import os
#import glob
#import zipfile

#import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['axes.grid'] = False
mpl.rcParams['figure.figsize'] = (12,12)

from sklearn.model_selection import train_test_split
#from sklearn.model_selection import GridSearchCV
from itertools import combinations
#import matplotlib.image as mpimg
#import pandas as pd
#from PIL import Image
#import json


In [None]:
import tensorflow as tf
#import tensorflow.contrib as tfcontrib
#from tensorflow.python.keras import layers
#from tensorflow.python.keras import losses
#from tensorflow.python.keras import models
from tensorflow.python.keras import backend as K
#from tensorflow.python.keras.wrappers.scikit_learn import KerasClassifier


In [None]:
from Segmentation.solver import Solver
from Segmentation.architecture.unet import Unet
from Segmentation.preprocess import Data_Preprocess
from Segmentation.assistant import assistant
import Segmentation.vis_utils as vis

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# Dataset layout, relative to the notebook's working directory:
#   <cwd>/DATA/dataset/{training, testing, train.txt}
dirname = os.getcwd()
dataset_path = os.path.join(dirname, "DATA/dataset")
train_path, test_path, csv_path = (
    os.path.join(dataset_path, leaf) for leaf in ("training", "testing", "train.txt")
)



In [None]:
# Preprocessing helper used by the cells below (train/val path split,
# augmentation functions, tf.data pipelines).
# NOTE(review): image_shape and batch_size are hard-coded here and differ from
# the `batch_size` chosen in the "Set up" section (1 here vs 10 there) —
# confirm which batch size the tf.data pipeline actually uses.
test = Data_Preprocess(
    dataset_path=dataset_path,
    train_path=train_path,
    test_path=test_path,
    image_shape=(64, 64, 3),
    batch_size=1,
    csv_path=csv_path,
)

In [None]:
# Split the available files into training and validation lists of input (x)
# and target (y) paths.
(x_train_filenames, x_val_filenames,
 y_train_filenames, y_val_filenames) = test.get_train_val_split_paths()

# Visualize
Let's look at a few example image/label pairs from the training split.

In [None]:
# Preview 5 input/target pairs from the training split.
vis.visualize_pairs(x_train_filenames, y_train_filenames, num=5)

# Set up 

Let's begin by setting up some parameters. All images are resized to a common shape and rescaled, and we also set a few training parameters:

In [None]:
# Input size and training schedule.
# Alternatives noted previously: img_shape (256, 256, 3), batch_size 3.
img_shape, batch_size, epochs = (64, 64, 3), 10, 2

# Build our input pipeline with `tf.data`


## Set up train and validation datasets
Note that we apply image augmentation to our training dataset but not our validation dataset. 

In [None]:
# Training-time preprocessing: resize, scale by 1/255, plus the augmentation
# options (hue / horizontal flip / width & height shift) understood by
# `test._augment`.
tr_cfg = dict(
    resize=[img_shape[0], img_shape[1]],
    scale=1 / 255.,
    hue_delta=0.1,
    horizontal_flip=True,
    width_shift_range=0.1,
    height_shift_range=0.1,
)
# Bind the config so the dataset builder only has to supply the data arguments.
tr_preprocessing_fn = functools.partial(test._augment, **tr_cfg)

In [None]:
# Validation preprocessing: same resize and 1/255 scaling as training, but no
# augmentation keys are passed.
val_cfg = dict(
    resize=[img_shape[0], img_shape[1]],
    scale=1 / 255.,
)
val_preprocessing_fn = functools.partial(test._augment, **val_cfg)

In [None]:
# Training data goes through the augmenting preprocessor; validation data only
# through the resize/rescale preprocessor.
train_ds = test.get_baseline_dataset(
    x_train_filenames, y_train_filenames, preproc_fn=tr_preprocessing_fn)
val_ds = test.get_baseline_dataset(
    x_val_filenames, y_val_filenames, preproc_fn=val_preprocessing_fn)

## Check that the augmentation pipeline produces the expected results

In [None]:
# Build a separate copy of the training pipeline with shuffle=False
# (presumably keeps file order) and pull one batch from it to eyeball the
# augmentation output.
temp_ds = test.get_baseline_dataset(x_train_filenames,
                                    y_train_filenames,
                                    preproc_fn=tr_preprocessing_fn,
                                    shuffle=False)
data_aug_iter = temp_ds.make_one_shot_iterator()
next_element = data_aug_iter.get_next()

# TF1 graph mode: running `next_element` yields one (images, labels) batch.
# Only the fetch needs the session; plotting happens after it is closed.
with tf.Session() as sess:
    batch_of_imgs, label = sess.run(next_element)

# First augmented image of the batch next to channel 0 of its label.
fig, (ax_img, ax_label) = plt.subplots(1, 2, figsize=(10, 10))
ax_img.imshow(batch_of_imgs[0])
ax_img.set_title('Augmented image')
ax_label.imshow(label[0, :, :, 0])
ax_label.set_title('Augmented label (channel 0)')
plt.show()

# Build the model

## The Keras Functional API


## Train your model
Training your model with `tf.data` simply involves passing the model's `fit` function your training and validation datasets, the number of steps per epoch, and the number of epochs.  

We also include a model callback, [`ModelCheckpoint`](https://keras.io/callbacks/#modelcheckpoint), that saves the model to disk after each epoch. We configure it to keep only the best-performing model. Note that saving the model captures more than just its weights: by default, it saves the model architecture, the weights, and information about the training process such as the state of the optimizer.

In [None]:
# Dataset sizes, passed to the trainer (e.g. to size steps per epoch).
num_train_examples, num_val_examples = len(x_train_filenames), len(x_val_filenames)

In [None]:
# RMSprop hyper-parameters forwarded to the trainer.
# NOTE(review): if these are passed straight to Keras RMSprop, `decay` is a
# per-update learning-rate decay (lr / (1 + decay * iterations)), so 0.9
# shrinks the learning rate very quickly. Confirm this is intended and not
# meant to be `rho`.
optim_config = {
    'lr': 1e-5,
    'decay': 0.9,
    'rho': 0.9,
    'epsilon': 1e-10,
}

# Training configuration. Keys mirror the keyword arguments accepted by
# `Solver` in the legacy training cell further down.
params = {
    'num_train_examples': num_train_examples,
    'num_val_examples': num_val_examples,
    'batch_size': batch_size,
    'num_epochs': epochs,
    'loss': 'bce_dice_loss',
    'optimizer': 'rms',
    'optimizer_config': optim_config,
    'metrics': ['dice_loss', 'f1', 'recall', 'precision'],
    'save_model_path': os.path.join(os.getcwd(), 'weights.hdf5'),
    'verbose': True,
}

# FIXME(review): this cell previously called `assis.run_training()` without
# ever creating `assis`, which raises NameError, and `params` was never used.
# `assistant` is imported from Segmentation.assistant. The constructor call
# below is an ASSUMPTION modelled on `Solver(model, train_ds, val_ds, **kwargs)`;
# confirm the real signature.
assis = assistant(train_ds, val_ds, **params)
history = assis.run_training()

# Old version of training (does not use the encoder–decoder splitter)

# Legacy path: build the full U-Net and train it directly with `Solver`.
model = Unet()

optim_config = {
    'lr': 1e-5,
    'decay': 0.9,
    'rho': 0.9,
    'epsilon': 1e-10,
}

# NOTE(review): batch_size is hard-coded to 2 here instead of reusing the
# `batch_size` chosen in "Set up" (10). Confirm which value is intended.
solver = Solver(
    model.build_model(),
    train_ds,
    val_ds,
    num_train_examples=num_train_examples,
    num_val_examples=num_val_examples,
    batch_size=2,
    num_epochs=epochs,
    loss='bce_dice_loss',
    optimizer='rms',
    optimizer_config=optim_config,
    metrics=['dice_loss', 'f1', 'recall', 'precision'],
    save_model_path=os.path.join(os.getcwd(), 'weights.hdf5'),
    verbose=True,
)
history = solver.train()


# Visualize training process

In [None]:
# Per-epoch curves recorded by Keras during training.
train_dice = history.history['dice_loss']
val_dice = history.history['val_dice_loss']

train_loss = history.history['loss']
val_loss = history.history['val_loss']

# Size the x-axis from what was actually recorded rather than the configured
# `epochs`, so the plot still works if training ran for a different number of
# epochs. Epochs are numbered from 1 to match the "Number of epochs" label.
# Later plotting cells may reuse `epochs_range`.
epochs_range = range(1, len(train_loss) + 1)

fig, (ax_dice, ax_loss) = plt.subplots(1, 2, figsize=(20, 8))

ax_dice.plot(epochs_range, train_dice, label='Training Dice Loss')
ax_dice.plot(epochs_range, val_dice, label='Validation Dice Loss')
ax_dice.set_xlabel("Number of epochs")
ax_dice.legend(loc='upper right')
ax_dice.set_title('Training and Validation Dice Loss')

ax_loss.plot(epochs_range, train_loss, label='Training Loss')
ax_loss.plot(epochs_range, val_loss, label='Validation Loss')
ax_loss.set_xlabel("Number of epochs")
ax_loss.legend(loc='upper right')
ax_loss.set_title('Training and Validation Loss')

plt.show()

In [None]:
train_recall = history.history['recall']
val_recall = history.history['val_recall']

train_precision = history.history['precision']
val_precision = history.history['val_precision']

# Build this plot's own x-axis (1-based, sized from the recorded history)
# instead of depending on `epochs_range` from another cell.
pr_epochs = range(1, len(train_precision) + 1)

fig, (ax_prec, ax_rec) = plt.subplots(1, 2, figsize=(20, 8))

ax_prec.plot(pr_epochs, train_precision, label='Training Precision')
ax_prec.plot(pr_epochs, val_precision, label='Validation Precision')
ax_prec.set_xlabel("Number of epochs")
ax_prec.legend(loc='upper right')
ax_prec.set_title('Training and Validation Precisions')

ax_rec.plot(pr_epochs, train_recall, label='Training Recall')
ax_rec.plot(pr_epochs, val_recall, label='Validation Recall')
ax_rec.set_xlabel("Number of epochs")
ax_rec.legend(loc='upper right')
ax_rec.set_title('Training and Validation Recall')

plt.show()



In [None]:
# NOTE(review): the metric is requested as 'f1' in the training params, but
# the history key read here is 'F1' (matching the 'F1' custom object used when
# reloading the model). Confirm which key the trainer records.
train_f1 = history.history['F1']
val_f1 = history.history['val_F1']

# Own x-axis (1-based, sized from the recorded history) rather than relying on
# `epochs_range` from another cell.
f1_epochs = range(1, len(train_f1) + 1)

# Fresh figure, so nothing is drawn onto a figure left open by an earlier cell.
fig, ax_f1 = plt.subplots()
ax_f1.plot(f1_epochs, train_f1, label='Training F1 score')
ax_f1.plot(f1_epochs, val_f1, label='Validation F1 score')
ax_f1.legend(loc='upper right')
ax_f1.set_xlabel("Number of epochs")
ax_f1.set_title('Training and Validation F1 scores')

plt.show()

# Visualize actual performance 
We'll visualize our performance on the validation set.

Note that in an actual setting (competition, deployment, etc.) we'd evaluate on the test set with the full image resolution. 

In [None]:
# Reload the checkpoint saved during training. The custom loss and metric
# functions are supplied by name so Keras can deserialize the compiled model.
# `models` is never imported in this notebook (its import is commented out),
# so use the equivalent `tf.keras.models` from the already-imported `tf`.
# NOTE(review): `model_recover` is not used by the visualization cell below,
# which is given `assis.weights_path` instead.
model_recover = tf.keras.models.load_model(
    assis.weights_path,
    custom_objects={
        'bce_dice_loss': assis.decoder.bce_dice_loss,
        'dice_loss': assis.solver.dice_loss,
        'precision': assis.solver.precision,
        'recall': assis.solver.recall,
        'F1': assis.solver.F1,
    },
)

In [None]:
# Display the assistant's decoder object (inspection only; not used below).
assis.decoder

In [None]:
# Plot example model outputs on the validation dataset, given the assistant
# and the path of its saved weights.
vis.visualize_result_triples(assis, assis.weights_path, val_ds)