This notebook demonstrates how to train a convolutional neural network for DEM refinement. We leverage `pytorch`, `pytorch-lightning` and `torchgeo` as part of our ML stack. Our model follows the `UNet` architecture and incorporates a few variations in the form of `ResNet` backbones and additional skip connections.

In [2]:
# pytorch imports
from pytorch_lightning.loggers import TensorBoardLogger
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
import lightning as L
import torch
from torch import nn 

# torchgeo imports
from torchgeo.datasets import stack_samples

# misc imports
from pathlib import Path
import kornia.augmentation as K
import sys

# local imports
# Make the project's local code importable. Two directories are prepended to
# sys.path: the parent of the notebook's working directory (presumably the
# repository root, which makes `lib.*` importable) and its `lib/` folder
# (presumably so modules inside lib can import each other by bare name).
# The parent is inserted last, so it ends up first on sys.path.
_repo_root = Path('.').absolute().parent
for _import_dir in (_repo_root / 'lib', _repo_root):
    sys.path.insert(0, str(_import_dir))
from lib.dataset_modules import CustomDataModule
from lib.task_module import DeepDEMRegressionTask

# Let float32 matrix multiplications use faster, reduced-precision internal
# arithmetic where the hardware supports it (speed over last-bit accuracy).
torch.set_float32_matmul_precision('medium')

In [None]:
# ---- Basic hyperparameters ----------------------------------------------
BATCH_SIZE: int = 24
NUM_WORKERS: int = 12
CHIP_SIZE: int = 256  # size of the model input chips

# Backbone used by the encoder of the selected architecture.
MODEL_ENCODER: str = 'resnet18'
# Architecture selector:
#   'unet'     -> ResDepth architecture
#   'smp-unet' -> UNet with a ResNet encoder
MODEL_TYPE: str = 'smp-unet'

# Swap the two stereo channels to help the model generalize
# (forwarded to the task as `channel_swap`).
CHANNEL_SWAP: bool = True
# Debugging switch: True performs a quick sanity-check run instead of full
# training (forwarded to the Trainer as `fast_dev_run`).
FAST_DEV_RUN: bool = False

# Fraction of the image (along the x-axis) used for training; chosen by hand.
# The Mt Baker scene has large no-data areas on one side, so the usual
# 90/10 split does not work well and a 65/35 split is used instead.
TRAIN_SPLIT: float = 0.65

In [None]:
# Training-time image augmentations: independent random horizontal and
# vertical flips, each applied with probability 0.5.
# This module is handed to the data module as `train_aug`.
_flip_augmentations = [
    K.RandomHorizontalFlip(p=0.5),
    K.RandomVerticalFlip(p=0.5),
]
transforms = nn.Sequential(*_flip_augmentations)

In [None]:
# Raster bands loaded by the data module, in this order.
# Names presumably match the rasters written by 0c_Data_Preprocessing.ipynb
# into the processed_rasters directory -- verify there.
# NOTE(review): "lidar_data" looks like the reference (target) DEM rather
# than a model input -- confirm how CustomDataModule / DeepDEMRegressionTask
# separate inputs from targets.
bands = [
    # presumably the ASP stereo DSM and the left/right orthoimage pair
    "asp_dsm", "ortho_left", "ortho_right",
    # auxiliary layers
    "ndvi", "nodata_mask", "triangulation_error",
    "lidar_data",
]

In [None]:
# Path to processed rasters generated from 0c_Data_Preprocessing.ipynb
# NOTE(review): machine-specific absolute path -- adjust before running elsewhere.
datapath = '/mnt/working/karthikv/DeepDEM/data/mt_baker/WV01_20150911_1020010042D39D00_1020010043455300/processed_rasters'

# Keyword arguments for CustomDataModule (chipping, batching, augmentation, split).
datamodule_params = dict(
    # NOTE(review): the data module is given its own class as `dataset_class`;
    # confirm this is intended rather than a dataset class.
    dataset_class=CustomDataModule,
    chip_size=CHIP_SIZE,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    collate_fn=stack_samples,  # torchgeo's sample-collation function
    cuda=torch.cuda.is_available(),
    train_aug=transforms,
    train_split=TRAIN_SPLIT,
    paths=datapath,
    bands=bands,
)

# Keyword arguments for DeepDEMRegressionTask. `max_epochs`, `early_stopping`
# and `earlystopping_patience` are also read when building the Trainer.
model_kwargs = dict(
    chip_size=CHIP_SIZE,
    encoder_weights='imagenet',
    channel_swap=CHANNEL_SWAP,
    do_BN=False,
    bias_conv_layer=False,
    lr=5e-4,
    num_workers=NUM_WORKERS,
    max_epochs=300,
    lr_scheduler=True,
    lr_scheduler_scale_factor=0.5,
    lr_scheduler_patience=50,
    early_stopping=True,
    earlystopping_patience=75,
    datapath=datapath,
    train_split=TRAIN_SPLIT,
    bands=bands,
    encoder=MODEL_ENCODER,
    model=MODEL_TYPE,
)

In [None]:
# Instantiate the data module and the Lightning task from the keyword-argument
# dicts defined in the previous cell.
datamodule = CustomDataModule(**datamodule_params)
task = DeepDEMRegressionTask(**model_kwargs)

In [None]:
checkpoint_directory = Path(f'./checkpoints/deep_dem_experiment')
checkpoint_directory.mkdir(exist_ok=True, parents=True)
model_count = len([x for x in list(checkpoint_directory.glob('*')) if x.is_dir()]) + 1

checkpoint_directory = checkpoint_directory / f"version_{str(model_count).zfill(3)}"
checkpoint_directory.mkdir(exist_ok=False)

# The Trainer and callbacks come from the `lightning` package. `pytorch_lightning`
# is a separate distribution with its own class hierarchy, so take the logger
# from the same namespace. That way the Trainer's isinstance checks (e.g. in
# Trainer.log_dir) recognise it, and no second package is needed.
from lightning.pytorch.loggers import TensorBoardLogger

# Callbacks get passed to the trainer:
#  - LearningRateMonitor logs the learning rate at every optimisation step.
#  - ModelCheckpoint writes checkpoints into checkpoint_directory, ranked by
#    the lowest `val_loss`.
callbacks = [
    LearningRateMonitor(logging_interval='step'),
    ModelCheckpoint(dirpath=checkpoint_directory, monitor='val_loss', mode='min'),
]

# If early stopping is enabled in the model kwargs, stop once `val_loss` has not
# improved by at least `min_delta` for `earlystopping_patience` validation checks.
if model_kwargs['early_stopping']:
    callbacks.append(
        EarlyStopping(
            monitor='val_loss',
            min_delta=0.05,
            patience=model_kwargs['earlystopping_patience'],
            verbose=True,
            mode='min',
        )  # type: ignore
    )

# TensorBoard event files are written under logs/deep_dem_experiment/.
logger = TensorBoardLogger(save_dir='logs/', name='deep_dem_experiment')

# Define the trainer; uses the GPU when CUDA is available, otherwise the CPU.
trainer = L.Trainer(
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    default_root_dir=checkpoint_directory,
    max_epochs=model_kwargs['max_epochs'],  # type: ignore
    logger=logger,
    check_val_every_n_epoch=1,
    log_every_n_steps=1,
    # Set FAST_DEV_RUN to True for a quick sanity-check (dummy) run before training.
    fast_dev_run=FAST_DEV_RUN,
    callbacks=callbacks,  # type: ignore
)

In [None]:
# Run training (with a validation pass every epoch, as configured on the
# trainer), using the data module's loaders.
trainer.fit(task, datamodule=datamodule)

In [None]:
# Save the network's state_dict (weights only) into the run's checkpoint directory.
# NOTE(review): these are the in-memory weights after fit() returns, i.e. from
# the last epoch run. They are not necessarily the best-`val_loss` weights that
# ModelCheckpoint tracks. Use trainer.checkpoint_callback.best_model_path if the
# best weights are wanted.
weights_path = checkpoint_directory / "deepdem_model_weights.pth"
torch.save(task.model.state_dict(), weights_path)

Training progress can be monitored with TensorBoard. The `TensorBoardLogger` configured above writes its event files under `logs/`, so run:

```
tensorboard --logdir='./logs' --port=<port number of your choice>
```

and then navigate to the following URL in a browser:

```
http://localhost:<port number>/
```

If training is running on a remote machine, forward the TensorBoard port to your local machine over SSH, where `<remote machine port>` is the port TensorBoard is listening on:

```
ssh -N -f -L <local port>:127.0.0.1:<remote machine port> <username>@<server>
```

Then open `http://localhost:<local port>/` in a local browser.
