# Train

TODO: Refer to dissertation chapter on why existing datasets are insufficient

This notebook trains a convolutional neural network (CNN) to segment UAV
imagery into `ground`, `water`, and `building` classes. This process is called
*semantic segmentation*.

The purpose of this functionality for the Lakehopper system is twofold:
- Provide awareness of the situation under the drone at the moment of flight
  ('in-situ'). This is as opposed to beforehand (when a certain dataset was
  captured or when the drone previously collected data). A lake might have dried
  up or a boat might be present where it wasn't previously.
- Provide a wide-area map based on orthographic imagery with features that are
  relevant to Lakehopper. Other maps do not necessarily have the same feature
  definitions as Lakehopper. The `water` class for Lakehopper is defined as 'a
  body of surface water where autonomous landing is possible'. Swamps or water
  obscured by bridges or vegetation do not fit this definition. They are, however,
  still classified as 'water' on most topographic maps.
  
The dataset is too large to upload via cloud storage (58.6 GiB). Please contact pieter@pfiers.net for a copy.

## Imports

In [None]:
%load_ext autoreload
%autoreload 2

from datetime import datetime
from itertools import islice
import sys, os, csv

from IPython.display import clear_output
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow.python.distribute.tpu_strategy import TPUStrategy
os.environ["SM_FRAMEWORK"] = "tf.keras"
import segmentation_models as sm

In [None]:
# Convenience methods
from visualise import show
from tpu import resolve_tpu_strategy, get_tpu_devices
from dataset import load_dataset, split_dataset_paths, filter_out_only_ground_paths
from masks import colorize_byte_mask
from models.unet import create_unet
from models.fcn import create_fcn8
from models.fpn import create_fpn

## Set up TPU

To speed up model training, we use [Google Cloud Platform](https://cloud.google.com/)'s [TPUs](https://cloud.google.com/tpu) (Tensor Processing Units). TPUs are ASICs designed for machine learning workloads.

In [None]:
# Obtain the distribution strategy from the project's `tpu` helper module.
# NOTE(review): 'semseg-node-us' is presumably the name of the TPU node — confirm.
tpu_strategy = resolve_tpu_strategy('semseg-node-us')

In [None]:
# Count the devices returned by the `get_tpu_devices` helper and report the total.
nbro_tpu_devices = len(get_tpu_devices())
print(f"{nbro_tpu_devices} TPU devices connected!")

In [None]:
# Parallelism setting handed to `load_dataset` in the cells below.
# `tf.data.AUTOTUNE` lets tf.data choose the level of parallelism at runtime.
PARALLEL_OPS = tf.data.AUTOTUNE
# Alternative setting, kept for quick toggling
# (presumably disables parallel loading — confirm in `load_dataset`):
# PARALLEL_OPS = None

## Prepare dataset

In [None]:
# Image size passed to `load_dataset`; also used to build the model's input shape.
IMAGE_SIZE = (128, 128)
# Glob pattern matching the `.tfr` files in the GCS bucket.
GCS_PATTERN = 'gs://lakehopper-semseg-data/aggregated-28-7-tfr/*.tfr'
# Expand the pattern, then pass the matches through `filter_out_only_ground_paths`
# (presumably drops files that contain only the ground class — confirm in `dataset`).
paths = filter_out_only_ground_paths(tf.io.gfile.glob(GCS_PATTERN))

In [None]:
# Ratios handed to `split_dataset_paths`, which returns three path lists.
# NOTE(review): the test share is presumably the remainder
# (1 - TRAIN_RATIO - VALIDATION_RATIO) — confirm in `dataset`.
TRAIN_RATIO = 0.7
VALIDATION_RATIO = 0.15
train_paths, validate_paths, test_paths = split_dataset_paths(paths, TRAIN_RATIO, VALIDATION_RATIO)
print(f"Found {len(paths)} tfrs. Splitting into {len(train_paths)} training, {len(validate_paths)} validation, and {len(test_paths)} test tfrs")

## Visualise dataset

In [None]:
# Load the training paths into a dataset used only for visual inspection.
visualisation_dataset = load_dataset(train_paths, IMAGE_SIZE, PARALLEL_OPS)
# Take one (image, label) pair. `islice(..., 0, None)` starts at index 0, so this
# is the first element; raising the start index selects a later sample instead.
sample_image, sample_label = next(islice(iter(visualisation_dataset), 0, None))
# Display the image next to its label mask after colourising the mask.
show((sample_image, colorize_byte_mask(sample_label)))

## Load dataset

In [None]:
def add_sample_weights(image, label):
    """Attach a per-pixel weight tensor derived from the label's class indices.

    Three hard-coded class weights are normalised to sum to 1.0, then looked up
    per pixel using the label value as the index.

    NOTE(review): which class each of the three weights belongs to depends on
    the label encoding, which is not visible here — confirm the order.

    Args:
        image: The input image, returned unchanged.
        label: Class-index mask; its values are cast to int32 and used as
            indices into the weight vector.

    Returns:
        A tuple `(image, label, sample_weights)` where `sample_weights` holds
        the normalised class weight for every label entry.
    """
    raw_weights = tf.constant([0.1, 0.01, 0.2])
    # Normalise so that the weights satisfy sum(weights) == 1.0.
    normalised_weights = raw_weights / tf.reduce_sum(raw_weights)

    # Each label value selects the weight of its class.
    class_indices = tf.cast(label, tf.int32)
    per_pixel_weights = tf.gather(normalised_weights, indices=class_indices)

    return image, label, per_pixel_weights

In [None]:
class Augment(tf.keras.layers.Layer):
    """Randomly augment an image together with its labels and sample weights.

    A left/right flip is applied to all three tensors using one shared seed,
    with the intent of keeping them spatially aligned. Brightness and hue
    jitter are applied to the image only.
    """

    def __init__(self):
        super().__init__()
        # Generator with a fixed seed; it supplies a fresh base seed on every call.
        self.rng = tf.random.Generator.from_seed(42, alg='philox')

    def call(self, image, labels, weights):
        """Return the augmented `(image, labels, weights)` triple."""
        base_seed = self.rng.make_seeds(2)[0]
        # Derive three independent seeds: flip, brightness, hue.
        derived_seeds = tf.random.experimental.stateless_split(base_seed, num=3)
        flip_seed = derived_seeds[0]
        brightness_seed = derived_seeds[1]
        hue_seed = derived_seeds[2]

        def flip(tensor):
            # Every tensor is flipped with the same seed.
            return tf.image.stateless_random_flip_left_right(tensor, flip_seed)

        image = flip(image)
        image = tf.image.stateless_random_brightness(
            image, max_delta=0.5, seed=brightness_seed
        )
        image = tf.image.stateless_random_hue(image, max_delta=0.1, seed=hue_seed)

        return image, flip(labels), flip(weights)

In [None]:
BATCH_SIZE = 64  # Using TPU v3-8 device => must be divisible by 8 for sharding
# Size of the shuffle buffer used for the training pipeline.
BUFFER_SIZE = 1000
# Unused backbone-specific preprocessing, kept for reference:
# BACKBONE = 'efficientnetb3'
# preprocess_input = sm.get_preprocessing(BACKBONE)

# Dataset generation *must* come after tpu resolution
# Training pipeline: sample weights are attached before `cache()`, so they are
# cached with the samples. `Augment()` is mapped after `cache()`, `batch()` and
# `repeat()`, so augmentation runs on whole batches and is re-applied on every
# pass instead of being cached.
training_dataset = (
    load_dataset(train_paths, IMAGE_SIZE, PARALLEL_OPS)
    .map(add_sample_weights)
    .cache()
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
    .repeat()
    .map(Augment())
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)
# Validation pipeline: batched only — no sample weights, shuffling or augmentation.
validation_dataset = (
    load_dataset(validate_paths, IMAGE_SIZE, PARALLEL_OPS)
    .batch(BATCH_SIZE)
)

## Create model

In [None]:
def create_model(nbro_classes: int, tpu_strategy: TPUStrategy = None) -> tf.keras.Model:
    """Build and compile the segmentation model, optionally under a TPU strategy.

    Args:
        nbro_classes: Number of segmentation classes, forwarded to `create_unet`.
        tpu_strategy: Strategy whose scope the model, loss and optimizer are
            created in. When `None`, they are created without a strategy scope.

    Returns:
        The compiled Keras model.
    """
    LR = 0.0001
    ENCODER = 'InceptionResNetV2' # MobileNetV2  EfficientNetB3  ResNet50

    def scoped_create_model():
        # NOTE(review): `from_logits=True` assumes the network's final layer
        # emits raw logits — confirm against `create_unet`.
        loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        optimizer = tf.keras.optimizers.Adam(LR)
        input_shape = IMAGE_SIZE + (3,)

        model = create_unet(ENCODER, nbro_classes, input_shape)
#         model = create_fpn(ENCODER, nbro_classes, input_shape)
#         model = create_fcn8(ENCODER, nbro_classes, input_shape)

        # `compile` takes the optimizer as its first parameter. The encoder
        # name must not be passed positionally here: combined with the
        # `optimizer=` keyword it raises a TypeError (multiple values for
        # `optimizer`).
        # The segmentation_models IoU / F-score metrics were previously
        # constructed but never handed to `compile`; add them to `metrics`
        # here if they are wanted.
        model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])

        return model

    if tpu_strategy is None:
        return scoped_create_model()

    # Model creation has to happen inside the strategy scope when one is used.
    with tpu_strategy.scope():
        return scoped_create_model()

In [None]:
# Build and compile the model for 3 classes, passing the `tpu_strategy` object along.
model = create_model(3, tpu_strategy)

In [None]:
def create_mask(pred_mask):
    """Turn a model prediction into a class-index mask for its first element.

    Args:
        pred_mask: Prediction tensor whose last axis holds per-class scores.

    Returns:
        The argmax over the last axis with a trailing singleton axis added,
        indexed at position 0 along the leading axis.
    """
    class_indices = tf.math.argmax(pred_mask, axis=-1)
    # Re-add a trailing axis, then keep only the first entry of the leading axis.
    return class_indices[..., tf.newaxis][0]

In [None]:
def show_predictions(dataset=[(sample_image, sample_label)]):
    """Display each image with its label and the model's predicted mask.

    Args:
        dataset: Iterable of `(image, label)` pairs. Defaults to the single
            sample pair captured when this function was defined.
    """
    display_rows = []
    for image, label in dataset:
        # The model is given a leading axis of size one for the single image.
        prediction = model.predict(image[tf.newaxis, ...], verbose=0)
        display_rows.append((image, label, create_mask(prediction)))
    show(display_rows)

In [None]:
class DisplayCallback(tf.keras.callbacks.Callback):
    """Keras callback that refreshes the sample prediction after every epoch."""

    def on_epoch_end(self, epoch, logs=None):
        """Clear the cell output, redraw the prediction and print the epoch number."""
        clear_output(wait=True)
        show_predictions()
        # `epoch` is zero-based, so report it one-based.
        epoch_banner = '\nSample Prediction after epoch {}\n'.format(epoch + 1)
        print(epoch_banner)

In [None]:
# Preview the prediction on the default sample (this cell precedes the training section).
show_predictions()

## Train model

In [None]:
# Upper bound on the number of training epochs passed to `model.fit`.
EPOCHS = 20
# Divisor applied to the number of validation batches per validation run.
VAL_SUBSPLITS = 2
# NOTE(review): both step counts are derived from the number of *paths*; this
# assumes each path contributes exactly one sample — confirm against `load_dataset`.
steps_per_epoch = len(train_paths) // BATCH_SIZE
validation_steps = len(validate_paths) // BATCH_SIZE // VAL_SUBSPLITS

print(f"With a batch size of {BATCH_SIZE}, there will be {steps_per_epoch} batches per training epoch and {validation_steps} batches per validation run.")

In [None]:
# TensorBoard logs go to a timestamped directory in the GCS bucket.
tensorboard_log_dir = "gs://lakehopper-semseg-data/model/" + datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
callbacks = [
    # Save weights only, and only when the monitored value improves (lower is better).
    # No `monitor` is passed, so the Keras default monitored quantity applies.
    tf.keras.callbacks.ModelCheckpoint('./best_model.h5', save_weights_only=True, save_best_only=True, mode='min'),
    # Learning-rate reduction with the callback's default settings.
    tf.keras.callbacks.ReduceLROnPlateau(),
    tf.keras.callbacks.TensorBoard(log_dir=tensorboard_log_dir, histogram_freq=1),
    # Stop when `val_loss` has not improved by at least 0.001 for 3 epochs.
    tf.keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=0.001, patience=3),
    # Redraw the sample prediction after every epoch.
    DisplayCallback()
]

In [None]:
# Train for up to EPOCHS epochs. The training dataset repeats indefinitely, so
# `steps_per_epoch` defines where an epoch ends.
history = model.fit(
    training_dataset,
    epochs=EPOCHS,
    steps_per_epoch=steps_per_epoch,
    validation_data=validation_dataset,
    validation_steps=validation_steps,
    callbacks=callbacks,
)
# Persist the model as it is when training ends (separate from the checkpoint file).
model.save('finished_model.h5')

model.summary()

## Metrics

In [None]:
# Plot the training and validation loss curves recorded in `history`.
for series_name in ('loss', 'val_loss'):
    plt.plot(history.history[series_name])
plt.title('model loss')
plt.xlabel('epoch')
plt.ylabel('loss')
# Legend entries follow the plotting order above.
plt.legend(['train', 'validation'], loc='upper left')
plt.show()

In [None]:
# Save metrics to csv: one header row, then one (loss, val_loss) row per epoch.
datetimestr = datetime.now().strftime('%Y-%m-%dT%H:%M:%S')
# The csv module requires files to be opened with newline=''; without it,
# platforms that translate line endings end up with an extra '\r' per row.
with open(f'metrics-{datetimestr}', 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(['loss', 'val_loss'])
    writer.writerows(zip(history.history['loss'], history.history['val_loss']))