This notebook demonstrates how to train a convolutional neural network for DEM refinement. We use `pytorch`, `pytorch-lightning`, and `torchgeo` as our ML stack. Our model follows the `UNet` architecture, with a few variations in the form of `ResNet` backbones and additional skip connections.

In [None]:
# misc imports
from pathlib import Path
from datetime import datetime

# pytorch imports
from torch import nn 
import torch

# pytorch-lightning imports
import lightning as L
from pytorch_lightning.loggers import TensorBoardLogger
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint

# torchgeo imports
from torchgeo.datasets import stack_samples

# Kornia imports for data augmentation
import kornia.augmentation as K

import sys
sys.path.insert(0, str(Path('.').absolute().parent))

# local imports
from scripts.task_module import DeepDEMRegressionTask
from scripts.dataset_modules import CustomInputDataset, CustomDataModule

In [None]:
# Location of the pre-processed data (defaults to the current directory).
# NOTE(review): `data_path` is not read by any later cell in this notebook;
# the data module and dataset below are built from `datapath` instead.
# Confirm whether this cell is still needed.
data_path = Path()

In [None]:
# Training hyper-parameters, shared by the data module and the task below.
CHIP_SIZE, BATCH_SIZE, NUM_WORKERS = 256, 12, 8  # passed as chip_size / batch_size / num_workers
LR = 5e-4  # passed to the task as `lr`

# Names of the input rasters handed to the dataset, data module and task (order preserved).
bands = "asp_dsm ortho_left ortho_right ndvi nodata_mask triangulation_error lidar_data".split()

In [None]:
# Directory of processed rasters, passed as `paths` to both CustomDataModule
# and CustomInputDataset below.
# The default is a machine-specific absolute path; set the DEEPDEM_DATA_PATH
# environment variable to point the notebook at your own data without editing it.
import os

datapath = os.environ.get(
    "DEEPDEM_DATA_PATH",
    '/mnt/1.0_TB_VOLUME/karthikv/DeepDEM/data/baker_csm_stack/processed_rasters/',
)

In [None]:
# Training-time augmentation: independent random horizontal and vertical flips,
# each applied with probability 0.5. Passed to the data module as `train_aug`.
transforms = nn.Sequential(
    *(
        K.RandomHorizontalFlip(p=0.5),
        K.RandomVerticalFlip(p=0.5),
    )
)

In [None]:
# Keyword arguments for the project's CustomDataModule.
# `dataset_class` must be the *dataset* class the data module works with, i.e.
# CustomInputDataset (the same class used in the next cell to compute ortho
# statistics). Previously the data-module class itself was passed here.
# NOTE(review): assumes CustomDataModule instantiates `dataset_class` (as
# torchgeo's GeoDataModule does) — confirm in scripts/dataset_modules.py.
datamodule_params = {
    'paths': datapath,                      # directory of processed rasters
    'dataset_class': CustomInputDataset,
    'chip_size': CHIP_SIZE,
    'batch_size': BATCH_SIZE,
    'num_workers': NUM_WORKERS,
    'collate_fn': stack_samples,            # torchgeo sample collation
    'cuda': torch.cuda.is_available(),
    'bands': bands,
    'train_aug': transforms,                # flips defined in the previous cell
}
datamodule = CustomDataModule(**datamodule_params)

In [None]:
# Build a throwaway dataset only to compute per-band normalisation statistics
# for the two orthoimages; these are handed to the task below.
tempdataset = CustomInputDataset(paths=datapath, bands=bands)
left_ortho_mean, left_ortho_std = tempdataset.compute_mean_std("ortho_left")
right_ortho_mean, right_ortho_std = tempdataset.compute_mean_std("ortho_right")

# Configuration for DeepDEMRegressionTask: smp U-Net with an ImageNet-initialised
# ResNet-18 encoder. `max_epochs` is also read by the Trainer in the next cell.
# NOTE(review): lr_scheduler_patience (150) is larger than max_epochs (100); if the
# scheduler is stepped once per epoch it will never lower the LR — confirm how
# DeepDEMRegressionTask steps it.
model_kwargs = dict(
    model='smp-unet',
    encoder='resnet18',
    encoder_weights='imagenet',
    bands=bands,
    left_ortho_mean=left_ortho_mean,
    left_ortho_std=left_ortho_std,
    right_ortho_mean=right_ortho_mean,
    right_ortho_std=right_ortho_std,
    chip_size=CHIP_SIZE,
    do_BN=False,
    bias_conv_layer=False,
    lr=LR,
    patience=10,
    num_workers=NUM_WORKERS,
    max_epochs=100,
    lr_scheduler=True,
    lr_scheduler_scale_factor=0.5,
    lr_scheduler_patience=150,
)
task = DeepDEMRegressionTask(**model_kwargs)

In [None]:
# Set up folders for checkpoints and logging, train the model, and save its weights.

# Compute the date stamp once so the checkpoint folder and the TensorBoard
# experiment name always agree (even if the cell runs across midnight). This
# also avoids nesting same-type quotes inside an f-string, which is a
# SyntaxError before Python 3.12.
run_date = datetime.now().strftime("%Y%m%d")
checkpoint_directory = Path(f"./checkpoints/checkpoint_directory_{run_date}")
checkpoint_directory.mkdir(exist_ok=True, parents=True)

# Next run number = highest existing version_NNN + 1. Merely counting
# sub-folders collides with an existing folder (making the mkdir below fail)
# whenever an earlier run has been deleted.
existing_versions = [
    int(d.name.split("_", 1)[1])
    for d in checkpoint_directory.glob("version_*")
    if d.is_dir() and d.name.split("_", 1)[1].isdigit()
]
model_count = max(existing_versions, default=0) + 1

checkpoint_directory = checkpoint_directory / f"version_{model_count:03d}"
checkpoint_directory.mkdir(exist_ok=False)  # fail loudly rather than reuse a previous run's folder

# Take the logger from the same `lightning` namespace as the Trainer and the
# callbacks; mixing classes from the separate `pytorch_lightning` package with
# `lightning` objects is not supported and can break Lightning's logger handling.
from lightning.pytorch.loggers import TensorBoardLogger

callbacks = [
    LearningRateMonitor(logging_interval='step'),
    # checkpoints are written to this run's folder, monitored on val_loss (lower is better)
    ModelCheckpoint(dirpath=checkpoint_directory, monitor='val_loss', mode='min'),
]
logger = TensorBoardLogger(save_dir="logs/", name=f"my_experiment_{run_date}")

trainer = L.Trainer(  # type: ignore
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    default_root_dir=checkpoint_directory,
    max_epochs=model_kwargs['max_epochs'],
    logger=logger,
    check_val_every_n_epoch=1,
    log_every_n_steps=1,
    fast_dev_run=False,  # set to True for a quick sanity-check run (no real training)
    callbacks=callbacks,
)

# Train model
trainer.fit(model=task, datamodule=datamodule)

# Save the final model weights next to this run's checkpoints.
weights_path = checkpoint_directory / f"model_weights_version{model_count}.pth"
torch.save(task.model.state_dict(), weights_path)

print(f"Model weights saved to {weights_path}")

Training progress can be monitored with TensorBoard by running the following command (the `TensorBoardLogger` configured above writes its logs under `./logs`):

```bash
tensorboard --logdir='./logs' --port=<port number of your choice>
```

and then navigating to the following URL in a browser:

```
http://localhost:<port number>/
```

If training is running on a remote machine, forward the TensorBoard port to your local machine over an SSH tunnel:

```bash
ssh -N -f -L <local port>:127.0.0.1:<remote machine port> <username>@<server>
```

Then open `http://localhost:<local port>/` in a local browser.
