### Training with train/validation split

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import numpy as np
import keras.backend as K
from keras.models import load_model
from keras.optimizers import SGD, Adam
from keras_contrib.losses import DSSIMObjective

from dataset import ProbaVDataset
from model import create_model, BatchNorm, PSNR
from training import lr_finder, train, predict, predict_on_test

# Dataset with a held-out validation split (validation_split=0.2, presumably
# the fraction of training scenes held out -- confirm in dataset.py). Inputs
# are upsampled (upsample_input=True), presumably so they match the model's
# 384x384 single-channel input shape built below.
dataset = ProbaVDataset(
    batch_size=8,
    validation_split=0.2,
    upsample_input=True,
)

model = create_model(input_shape=(384, 384, 1))
model.summary()

# Name -> object mapping passed to lr_finder() and load_model() below so Keras
# can resolve the project's custom metric, layer and loss by name.
custom_objects = {
    'PSNR': PSNR,
    'BatchNorm': BatchNorm,
    'DSSIMObjective': DSSIMObjective(),
}

First, we use `lr_finder` to choose the maximum learning rate for the cyclical learning-rate schedule. A good value lies just before the point where the loss starts to explode:

In [None]:
# Learning-rate range test: sweep the learning rate from start_lr to end_lr and
# inspect where the loss blows up. max_loss is presumably the cut-off at which
# the sweep stops early -- confirm in training.py.
# NOTE(review): lr_finder() receives the same `model` that the next cell
# trains; verify that it restores the initial weights afterwards.
hyperparams = dict(
    optimizer=Adam(),
    loss_fn=DSSIMObjective(),
    start_lr=1e-6,
    end_lr=1e2,
    max_loss=1.0,
)

lr_finder(model, dataset, hyperparams, custom_objects)

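Now we train the network with a cyclical learning-rate (`'clr'`) schedule, using the maximum learning rate picked from the plot above.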

In [None]:
# Single-cycle CLR training. max_lr comes from the lr_finder sweep above.
# div_factor and max_momentum presumably follow the fast.ai one-cycle
# convention (start LR = max_lr / div_factor) -- confirm in training.py.
hyperparams = dict(
    optimizer=Adam(),
    loss_fn=DSSIMObjective(),
    num_epochs=20,
    max_lr=3e-3,
    div_factor=25.0,
    max_momentum=0.90,
    num_cycles=1,
)

model = train('clr', model, dataset, hyperparams)

Let's see how well it performs on the validation set.

In [None]:
# Super-resolve the held-out validation inputs with the trained model.
# display_n presumably limits how many examples predict() renders inline.
sr_maps = predict(
    model,
    dataset.validation_data,
    display_n=20,
    batch_size=dataset.batch_size,
)

In [None]:
# Score the validation predictions with the dataset's own metric.
# The clipped/squeezed maps get a new name instead of overwriting `sr_maps`,
# so the raw predictions from the previous cell stay available for inspection.
sr_maps_clipped = np.clip(sr_maps.squeeze(), 0.0, 1.0)  # force values into [0, 1]
highres_maps = dataset.validation_data[1].squeeze()      # ground-truth targets
status_maps = dataset.validation_sms                     # presumably per-scene status masks -- confirm
scene_ids = dataset.validation_scene_ids

val_score = dataset.score_images(sr_maps_clipped, highres_maps, status_maps, scene_ids)
print('Score on validation set:', val_score)

### Training with the whole training set

Now let's train again, this time on the whole training set without holding out a validation split, reusing the hyperparameters selected above.

In [None]:
# Retrain from scratch on all training scenes and persist the result.
# Clear the Keras session first so the previous model's graph/state does not
# linger, and drop the old dataset before building a new one.
K.clear_session()
del dataset

# validation_split=0.0: no held-out scenes; hyperparameters tuned on the
# validation split above are reused unchanged.
dataset = ProbaVDataset(batch_size=8, validation_split=0.0, upsample_input=True)
model = create_model(input_shape=(384, 384, 1))

# Rebuilt rather than reusing the earlier `hyperparams`: the previous Adam /
# DSSIM objects were created before clear_session() and may hold backend
# state tied to the discarded session.
hyperparams = {
    'optimizer': Adam(),
    'loss_fn': DSSIMObjective(),
    'num_epochs': 20,
    'max_lr': 3e-3,
    'div_factor': 25.0,
    'max_momentum': 0.90,
    'num_cycles': 1,
}

model = train('clr', model, dataset, hyperparams)
model.save('model.h5')

### Inference on test set

Finally, we run inference on the test set and write the predicted images to disk for submission.

In [None]:
# Presumably switches the dataset over to the test scenes -- see dataset.py.
dataset.reset_for_testing()

# Reload the saved model; this checks that 'model.h5' deserialises with the
# project's custom objects before we generate the submission.
model = load_model('model.h5', custom_objects=custom_objects)
predict_on_test(model, dataset, display_n=20, path='submission')