# Train a HED CNN for shoreline detection
This shoreline detector uses the Holistically-Nested Edge Detection (HED) framework (Xie and Tu, 2015): http://openaccess.thecvf.com/content_iccv_2015/papers/Xie_Holistically-Nested_Edge_Detection_ICCV_2015_paper.pdf.

This notebook contains the code to train a HED CNN.

Please note: the image and shoreline dataset used to train the network is not currently publicly available. Unless you have your own dataset, try running the "HED_edge_detection_predict_unseen" notebook instead.

In [1]:
# magic
# Automatically reload edited modules (e.g. the local `functions` package)
# before executing each cell, so library changes take effect without a restart.
%load_ext autoreload
%autoreload 2
# Turn on automatic post-mortem debugging: an exception in a cell drops into pdb.
%pdb 1
# Render matplotlib figures inline in the notebook.
%matplotlib inline

Automatic pdb calling has been turned ON


In [2]:
# imports
from functions.pytorch_models import hed_cnn, Trainer, pretrained_weights
from functions.data_preprocessing import load_images, augment_images_kp, mask_to_uv
from functions.data_preprocessing import load_train_test_imagedata, save_train_test_imagedata
from functions.data_visualisation import plot_predictions, plot_refined_predictions

import scipy.io as sio
import matplotlib.pyplot as plt
import numpy as np

import os

import torch

from sklearn.model_selection import train_test_split

import imgaug as ia
import imgaug.augmenters as iaa

from ipywidgets import interact, fixed, IntSlider, FloatSlider, interact_manual

## Setup HED model
1. Import the model (`hed_cnn`) from `functions/pytorch_models.py`
2. Define and apply training parameters

In [None]:
# Root directory of the training data. The training dataset is not publicly
# available, so stop early with an actionable message if it is missing rather
# than failing somewhere inside the loader.
basePath = './data/'
if not os.path.isdir(basePath):
    raise FileNotFoundError(
        "Training data directory {!r} not found. The training dataset is not "
        "publicly available; unless you have your own dataset, try the "
        "HED_edge_detection_predict_unseen notebook instead.".format(basePath)
    )

# Load the data partition (presumably train/validation sample IDs - confirm)
# and the matching labels; both are handed to the Trainer below.
partition, labels = load_train_test_imagedata(basePath)

In [None]:
# Build the HED network and optionally initialise it from the VGG16
# checkpoint file referenced below.
modelSave = 'A'        # tag appended to the saved model's filename
applyWeights = True    # forwarded to pretrained_weights; presumably toggles loading the checkpoint

# pretrained model
weightsPath = './pytorch/pretrained_models/vgg16-397923af.pth'

hedModel = pretrained_weights(hed_cnn(), weightsPath, applyWeights)

## Train the model
1. Specify and train the model
2. Save the model
3. Make predictions for the training and validation sets

In [None]:
# setup training
# Hyper-parameters handed to the Trainer; the optimiser is added below.
modelParams = {
    'epochs': 25,
    'batchSize': 8,
    'lr': 3e-4,
    'lrDecay': 5e-1,
    'lossFunction': 'weightedBCE',  # other option: 'weightedBCEReg'
    'cuda': False,
    'basePath': basePath,
}

# Adam over every trainable parameter; frozen parameters (requires_grad=False)
# are skipped. All parameters share the single learning rate modelParams['lr'].
modelParams['optimiser'] = torch.optim.Adam(
    (p for p in hedModel.parameters() if p.requires_grad),
    lr=modelParams['lr'],
)

# initialize trainer class
trainer = Trainer(hedModel, partition, labels, modelParams)

In [None]:
#train
# Run the training loop of the Trainer built above and keep whatever
# trainer.train() returns (presumably per-epoch loss history - confirm).
history = trainer.train()

In [None]:
# save the model
# Persist only the learned weights (state_dict); modelSave tags the filename.
# Create the output directory first so a finished training run is not lost
# to a missing ./models folder.
trainedHedModel = trainer.model
os.makedirs('./models', exist_ok=True)
torch.save(trainedHedModel.state_dict(), './models/shorelineDetectModel_{}.pt'.format(modelSave))

In [None]:
# predict
# Run the trained model over the 'train' and 'validation' partitions. Each call
# returns three values that the viewers below use as dataX, dataPred and dataY
# (presumably input images, predictions and ground-truth masks - confirm).
trainX, trainPred, trainY = trainer.predict('train')
valX, valPred, valY  = trainer.predict('validation')

## View model output
For a small sample (one batch) of the training and validation datasets

### Raw output
From each layer.

In [None]:
# training output
# Interactive viewer for the raw predictions on the training set: choose a
# sample index, an output index jj (0-5) and a threshold.
print('Training output...')

sampleSlider = IntSlider(value=0, min=0, max=trainX.shape[0] - 1, step=1,
                         continuous_update=False)
layerSlider = IntSlider(value=5, min=0, max=5, step=1,
                        continuous_update=False)
thresSlider = FloatSlider(value=0.5, min=0.05, max=0.95, step=0.05,
                          continuous_update=False)

interact(plot_predictions,
         prntNum=sampleSlider,
         dataX=fixed(trainX),
         dataY=fixed(trainY),
         dataPred=fixed(trainPred),
         jj=layerSlider,
         thres=thresSlider)

In [None]:
# Interactive viewer for the raw predictions on the validation set. The sliders
# mirror the training-set viewer: sample index, output index jj (0-5) and
# threshold. The initial sample is set by the slider's value, not a global.
print('Validation output...')
interact(plot_predictions,
         prntNum=IntSlider(
             value=0,
             min=0,
             step=1,
             max=valX.shape[0]-1,
             continuous_update=False,
         ),
         dataX=fixed(valX),
         dataY=fixed(valY),
         dataPred=fixed(valPred),
         jj=IntSlider(
             value=5,
             min=0,
             step=1,
             max=5,
             continuous_update=False,
         ),
         thres=FloatSlider(
             value=0.5,
             min=0.05,
             step=0.05,
             max=0.95,  # same threshold range as the other viewers
             continuous_update=False,
         ),
        )

### Final output
From the weighted combination of the layer outputs.

In [None]:
# Interactive viewer for the final (weighted-combination) prediction on the
# training set: choose a sample index and a threshold.
print('Training output...')
refinedControls = dict(
    prntNum=IntSlider(value=0, min=0, max=trainX.shape[0] - 1, step=1,
                      continuous_update=False),
    dataX=fixed(trainX),
    dataY=fixed(trainY),
    dataPred=fixed(trainPred),
    thres=FloatSlider(value=0.7, min=0.05, max=0.95, step=0.05,
                      continuous_update=False),
)
interact(plot_refined_predictions, **refinedControls)

In [None]:
# Interactive viewer for the final (weighted-combination) prediction on the
# validation set: choose a sample index and a threshold.
print('Validation output...')

valSampleSlider = IntSlider(value=0, min=0, max=valX.shape[0] - 1, step=1,
                            continuous_update=False)
valThresSlider = FloatSlider(value=0.7, min=0.05, max=0.95, step=0.05,
                             continuous_update=False)

interact(plot_refined_predictions,
         prntNum=valSampleSlider,
         dataX=fixed(valX),
         dataY=fixed(valY),
         dataPred=fixed(valPred),
         thres=valThresSlider)