# DeepDEM demo
- Seth Vanderwilt
- 11/1/23

## Overview

**Today's demo**
- Overview of the workflow:
    - Process stereo imagery into a DEM, orthoimages, and an intersection-error raster that are self-consistent and aligned with the lidar target DEM.
    - Using torchgeo and Lightning, load these raster stacks (onto the GPU, if available) and train neural networks to refine the initial stereo DEM, using the lidar DEM as supervision, with the goal of restoring terrain features, eliminating artifacts, etc.
    - Run inference on training datasets or other raster stacks to evaluate refinement quality
- Show training & inference with either the 3 km x 3 km Easton Glacier demo subset or the larger 128 km^2 Mount Baker stack
- With identical layer & normalization settings (the 5 input layers, `meanstd` normalization), reuse existing checkpoints such as `version_145/checkpoints/epoch=12366-step=3091750.ckpt` as a starting point for training or directly for inference
- Tips for processing, experiments, and evaluation
- Code suggestions to increase flexibility for future experiments:
    - Continue modularization and use of LightningCLI / torchgeo tools to ingest and preprocess datasets in a more flexible way than currently implemented.
    - Fix a naming scheme for input layers to avoid having to specify each individual dataset in configuration files and `torchgeo_dataset.py`
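The 5-layer `DSM-Ortho1-Ortho2-TriError-NodataMask` input with `meanstd` normalization can be pictured with the minimal sketch below. It is a hypothetical illustration, not the repo's implementation: it assumes each of the four raster layers is standardized by its own mean and standard deviation over valid pixels, that nodata pixels arrive as NaN, and that the mask channel uses 1 for nodata. Check `torchgeo_dataset.py` and the data module for the actual statistics and channel order before relying on it.

```python
import numpy as np


def normalize_meanstd(layers, eps=1e-8):
    """Hypothetical helper: standardize a (C, H, W) stack and append a nodata mask.

    layers: float array where NaN marks nodata (e.g. DSM, ortho1, ortho2, triangulation error).
    Returns a (C + 1, H, W) array: each layer shifted/scaled by its own mean and std
    over valid pixels, nodata filled with 0, plus a final mask channel (1 = nodata).
    """
    layers = np.asarray(layers, dtype=float)
    nodata = np.isnan(layers).any(axis=0)  # a pixel is nodata if any layer is missing
    mean = np.nanmean(layers, axis=(1, 2), keepdims=True)
    std = np.nanstd(layers, axis=(1, 2), keepdims=True)
    normed = np.nan_to_num((layers - mean) / (std + eps), nan=0.0)
    return np.concatenate([normed, nodata[None].astype(float)], axis=0)


rng = np.random.default_rng(0)
stack = rng.normal(size=(4, 64, 64))
stack[0, :2, :2] = np.nan  # simulate a hole in the DSM
model_input = normalize_meanstd(stack)
print(model_input.shape)  # (5, 64, 64)
```

Whatever the exact statistics, the key constraint from the checkpoint notes still applies: a checkpoint only makes sense with the same layer order and normalization it was trained with.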

**Notes:**
- The latest [torchgeo](https://github.com/microsoft/torchgeo/) v0.5.0 may cause issues with GridGeoSampler output for rasters whose dimensions are not a multiple of the patch size (i.e. the grid sampler fails along the edges, where the available pixels don't fill a full patch). The specific changes since v0.4.1 to file loading and output-raster creation that cause this are unknown.
- Unknown changes in LightningCLI may also affect the functionality of this code

**Next steps for any users:**
- Configure environment appropriately
- Create/download desired dataset stack to use, input normalization strategy, etc. and revise the dataset implementation and filepaths as needed
    - Suggestion: 1 training stack + 1 validation stack, OR split a single stack of GeoTIFFs into pre-defined tiles (e.g. using USGS 3DEP tile boundaries, 116km^2 of training tiles interspersed with 12km^2 of validation tiles)
- Choose a model checkpoint from Google Drive experiment checkpoints (for the case of 5 input layers `DSM-Ortho1-Ortho2-TriError-NodataMask` and `meanstd` normalization), or start from scratch.
- Revise configuration YAML file to match desired training hyperparameters, model, ...
- Look through torchgeo and lightning docs, especially with LightningCLI integration updates described in [0.5.0](https://github.com/microsoft/torchgeo/releases/tag/v0.5.0)
    - this points to a potentially better way to define and load the dataset compared to current `torchgeo_dataset.py` implementation
    - LightningCLI fit / predict / validate / test setup may be useful to avoid current redundant code in training and inference scripts (basically, separate configuration from code more)
- Pick a descriptive file naming scheme for experiments and checkpoints (not just version numbers) and include metadata (descriptions) to track each run
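One way to implement the "split a single stack into pre-defined tiles" suggestion is sketched below. The function name, the regular tile grid, and the diagonal `(row + col) % val_every` pattern used to intersperse validation tiles are all illustrative assumptions; to follow USGS 3DEP tile boundaries instead, substitute those tile extents for the generated grid.

```python
import math


def split_tiles(xmin, ymin, xmax, ymax, tile_size, val_every=10):
    """Hypothetical helper: cut a bounding box (map units) into square tiles and
    assign roughly 1 in `val_every` tiles to validation along diagonals, so that
    validation tiles are interspersed with training tiles.
    Returns (train_boxes, val_boxes) as (minx, miny, maxx, maxy) tuples."""
    train, val = [], []
    nrows = math.ceil((ymax - ymin) / tile_size)
    ncols = math.ceil((xmax - xmin) / tile_size)
    for r in range(nrows):
        for c in range(ncols):
            x0 = xmin + c * tile_size
            y0 = ymin + r * tile_size
            # Clip edge tiles to the bounding box
            box = (x0, y0, min(x0 + tile_size, xmax), min(y0 + tile_size, ymax))
            (val if (r + c) % val_every == 0 else train).append(box)
    return train, val


train_boxes, val_boxes = split_tiles(0, 0, 4000, 4000, tile_size=1000, val_every=4)
print(len(train_boxes), len(val_boxes))  # 12 4
```

Each box can then be passed as `bbox` to the `open_raster` helper defined later, or as `-te` bounds to `gdalwarp`, to cut the stack into training and validation pieces.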

**Steps for larger changes:**
- Look into refactoring the LightningModule `resdepth_lightning_module.py` and LightningDataModule `tgdsm_lightning_data_module.py` implementations and their supporting files to allow specifying input layers, parameters, normalization, and different models/losses in configuration files. The main thing to handle is getting all layers in the expected order with the correct transformations, which were hardcoded per layer for the initial experiments.
- Follow [Lightning documentation](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.cli.LightningCLI.html) for modifying configuration files.

## Setting up the code (& get the repo from GitHub)

In [None]:
main_dir = "/Users/seth/Desktop/deepDEM_test"
script_dir = os.path.join(main_dir, "DeepDEM/ResDepth/torchgeo_experiments")

%cd $main_dir
# git clone DeepDEM # into this directory
repo_dir = os.path.join(main_dir, "DeepDEM")

### Environment

* Note: this uses a separate conda environment without the CUDA requirements and with torchgeo pinned to 0.4.1, because torchgeo changes over the last few months appear to break the `inference.py` script when copying output patches into the resulting raster.

In [None]:
!mamba env create --file DeepDEM/environment_torchgeo041_no_gpu.yml
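After activating the new environment (and selecting it as the notebook kernel), a quick check that the pinned torchgeo version is the one actually installed can save debugging time. This is a hypothetical helper built only on the standard library's `importlib.metadata`; it reads installed package metadata and does not import torchgeo itself.

```python
from importlib.metadata import PackageNotFoundError, version


def check_pinned(package, expected):
    """Return the installed version of `package` (None if missing) and warn on mismatch."""
    try:
        installed = version(package)
    except PackageNotFoundError:
        print(f"{package} is not installed in this environment")
        return None
    if installed != expected:
        print(f"Warning: {package} {installed} is installed, but {expected} is expected")
    return installed


check_pinned("torchgeo", "0.4.1")
```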

### Files to modify:
* `DeepDEM/ResDepth/torchgeo_experiments/torchgeo_dataset.py`
    * change the `files` attribute of the class to some other name
* `DeepDEM/ResDepth/torchgeo_experiments/inference.py`
    * revise path to Stucker & Schindler ResDepth repo
    * reduce tile (patch) size
    * specify validation directory and dataset type for `torchgeo_dataset.py` loading
* `DeepDEM/ResDepth/torchgeo_experiments/resdepth_lightning_module.py`
    * revise path to Stucker & Schindler ResDepth repo
* `DeepDEM/configs/` suggested way: copy the example, create a new configuration YAML file and set the necessary values
    * To create a new YAML from scratch, use `python train_cli.py fit --print_config` based on the [Lightning docs](https://lightning.ai/docs/pytorch/stable/cli/lightning_cli_advanced.html#prepare-a-config-file-for-the-cli)
    * `max_epochs`: set to a very large value (e.g. 1234567890) so that resuming from previous checkpoints trained for more than 1000 epochs does not stop immediately (Lightning defaults to 1000 epochs when neither `max_epochs` nor `max_steps` is set)
    * `ckpt_path`: path to a `.ckpt` checkpoint to resume from
    * `normalization`: set to `meanstd`
    * `batch_size`: reduce (e.g. to 2) depending on hardware and model
    * `train_directory` and `val_directory`: paths to the desired datasets
* `DeepDEM/environment.yml`
    * remove the CUDA dependencies if no GPU is available
    * pin versions: only `torchgeo` is pinned (to 0.4.1)

## Imports

In [None]:
import os
import glob
import sys

import numpy as np
import matplotlib.pyplot as plt
import rioxarray as rxr
from osgeo import gdal

from matplotlib_scalebar.scalebar import ScaleBar

## Make a smaller subset of the Mount Baker dataset (~3 km x 3 km)

In [None]:
# Because the latest torchgeo may handle raster edges differently, use exact multiples of the patch size for this test dataset
# Easton glacier demo area
xmin = 584000
xmax = 587072
ymin = 5397000
ymax = 5400072
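To confirm that a chosen extent really is a whole number of patches, the check can be done directly on the coordinates. This sketch assumes 1 m pixels (as the `1.0m` filenames suggest) and a 1024-pixel patch (as in the `inference_tg041_1024x1024.tif` output name); both function names are hypothetical.

```python
import math


def patches_along(extent_m, res_m, patch_px):
    """Return (full patches, leftover pixels) along one axis of a raster."""
    n_px = round(extent_m / res_m)
    return divmod(n_px, patch_px)


def snap_max(cmin, cmax, res_m, patch_px):
    """Shrink the max coordinate so the extent is a whole number of patches."""
    n_full = math.floor((cmax - cmin) / (res_m * patch_px))
    return cmin + n_full * patch_px * res_m


# Easton Glacier demo extent: 3072 m = 3 patches of 1024 px at 1 m, no leftover
print(patches_along(587072 - 584000, 1.0, 1024))  # (3, 0)
print(snap_max(584000, 587500, 1.0, 1024))        # 587072.0
```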

In [None]:
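# Crop every raster in the full stack to the demo extent (paths are relative to main_dir)
os.makedirs("baker_csm_stack_small", exist_ok=True)
for full_size_fn in glob.glob("baker_csm_stack/*.tif"):
    small_fn = os.path.join("baker_csm_stack_small", os.path.basename(full_size_fn))
    print(f"Crop {full_size_fn} to {small_fn}")
    !gdalwarp -overwrite -r cubic -te $xmin $ymin $xmax $ymax "$full_size_fn" "$small_fn"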

## Helper functions copied from `presentation_figures.ipynb` in repo

In [None]:
def plot_ortho(ax, ortho, title, scale=1, fixed_value=50, show_scalebar=True):
    # Draw the unstretched image first so the colorbar can report the true value range
    raw_im = ax.imshow(ortho, cmap="gray", rasterized=True)
    # Draw the contrast-stretched version on top for display
    ax.imshow(stretch_image(ortho), cmap="gray", rasterized=True)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_aspect(1)
    ax.set_title(title)

    # Include scalebar (scale = meters per pixel)
    if show_scalebar:
        ax.add_artist(ScaleBar(scale, fixed_value=fixed_value))
    # Colorbar uses the raw image so it shows the original ortho values
    plt.colorbar(raw_im, ax=ax, fraction=0.04)

    return ax

In [None]:
def open_raster(fn, bbox=None):
    rxr_obj = rxr.open_rasterio(fn, masked=True).squeeze()
    if bbox:
        minx,miny,maxx,maxy = bbox
        # Sometimes have min & max ordering wrong when copying & pasting
        minx,miny,maxx,maxy = min(minx,maxx),min(miny,maxy),max(minx,maxx),max(miny,maxy)
        return rxr_obj.rio.clip_box(minx,miny,maxx,maxy)
    else:
        return rxr_obj

In [None]:
def stretch_image(img):
    # Percentile contrast stretch, adapted from:
    # https://stackoverflow.com/questions/60449340/contrast-enhancement-using-a-percentage-cumulative-count-in-matplolib
    min_percent = 2   # Low percentile
    max_percent = 98  # High percentile
    img = np.asarray(img, dtype=float)
    # nanpercentile ignores nodata pixels (open_raster masks them to NaN)
    lo, hi = np.nanpercentile(img, (min_percent, max_percent))

    # Apply linear "stretch": lo goes to 0 and hi goes to 1
    res_img = (img - lo) / (hi - lo)

    # Multiply by 255, clamp to [0, 255], set nodata to 0 and convert to uint8
    res_img = np.nan_to_num(np.clip(res_img * 255, 0, 255), nan=0.0).astype(np.uint8)
    return res_img

In [None]:
def plot_shaded_relief(ax, dem_fn, bbox, title):
    hs_fn = "tmphillshade.tif"
    ds = gdal.Open(dem_fn)
    # Creates a new file in this directory, could delete or retain if desired
    gdal.DEMProcessing(hs_fn, ds, "hillshade", computeEdges=True)
    ds = None
    dem = open_raster(dem_fn, bbox=bbox)
    hs = open_raster(hs_fn, bbox=bbox)
    # dem = dem.squeeze()
    # hs = hillshade(dem)
    ax.imshow(hs, cmap="gray", rasterized=True)
    im = ax.imshow(dem, cmap="viridis", alpha=0.5, rasterized=True)
    plt.colorbar(im, ax=ax, fraction=0.04)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_aspect(1)
    ax.set_title(title)

    return ax
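`plot_shaded_relief` writes a temporary `tmphillshade.tif` via `gdal.DEMProcessing`. If you would rather compute a hillshade in memory from the clipped DEM, a NumPy sketch follows. It is a swapped-in technique, not GDAL's algorithm: it uses `np.gradient` central differences rather than GDAL's 3x3 Horn kernel, and it assumes a north-up raster (row index increasing southward) with the sun at azimuth 315° and altitude 45° (the `gdaldem hillshade` defaults), so values will not match GDAL's output exactly.

```python
import numpy as np


def hillshade(dem, res=1.0, azimuth_deg=315.0, altitude_deg=45.0):
    """Return a 0-255 hillshade of a north-up DEM array (NaN cells and their neighbours stay NaN)."""
    dem = np.asarray(dem, dtype=float)
    d_row, d_col = np.gradient(dem, res)  # derivatives along rows, then columns
    dz_dx = d_col    # east is +column
    dz_dy = -d_row   # north is -row in a north-up raster
    az = np.radians(azimuth_deg)  # measured clockwise from north
    alt = np.radians(altitude_deg)
    # Unit vector pointing toward the sun in (east, north, up) coordinates
    sun = (np.sin(az) * np.cos(alt), np.cos(az) * np.cos(alt), np.sin(alt))
    # Dot product of the unit surface normal (-dz_dx, -dz_dy, 1)/|n| with the sun vector
    shade = (-dz_dx * sun[0] - dz_dy * sun[1] + sun[2]) / np.sqrt(1 + dz_dx**2 + dz_dy**2)
    return 255 * np.clip(shade, 0, 1)
```

Usage inside the plotting code could look like `ax.imshow(hillshade(open_raster(dem_fn, bbox=bbox).values), cmap="gray")`, which avoids writing any file to disk.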

## Commands to run training & inference

In [None]:
%pwd

In [None]:
%cd $script_dir

### Training with LightningCLI and a configuration file
* (We encourage switching to the LightningCLI approach of `train_cli.py` instead of `train.py`. `train.py` was used for most previous experiments and implements all features, but it was limiting because it mixes configuration and code.)
* To train with a different dataset, modify the configuration file by providing lists of training & validation directories appropriately.
* To load a model from a checkpoint, define this in the YAML. Example:
    * `max_epochs: 1234567890`
    * `ckpt_path: "/path/to/epoch=12366-step=3091750.ckpt"`
    * Make sure normalization choice (i.e. `meanstd`) & list of input layers match what was used to train the model!
* During training, monitor the losses, evaluate the outputs carefully, and look for high-loss examples (e.g. deep forest had offsets orders of magnitude larger than non-forested areas, which inhibited progress in the 2022 experiments)
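When switching between checkpoints or datasets, it can be handy to override individual config values on the command line instead of editing the YAML each time. LightningCLI accepts dotted overrides after `-c`, such as `--trainer.max_epochs=...` and, for the `fit` subcommand, `--ckpt_path=...`. The helper below is hypothetical, and any `data.*` or `model.*` keys passed through `extra` must match the init arguments of this repo's LightningModule/LightningDataModule (`data.batch_size` here is an assumption).

```python
import shlex


def fit_command(config, ckpt_path=None, max_epochs=None, extra=None):
    """Build a `train_cli.py fit` command; arguments after -c override the YAML values."""
    cmd = ["python", "train_cli.py", "fit", "-c", config]
    if max_epochs is not None:
        cmd.append(f"--trainer.max_epochs={max_epochs}")
    if ckpt_path is not None:
        cmd.append(f"--ckpt_path={ckpt_path}")
    for key, value in (extra or {}).items():
        cmd.append(f"--{key}={value}")  # e.g. {"data.batch_size": 2}; key names are assumed
    return cmd


cmd = fit_command(
    "configs/example_config.yml",
    ckpt_path="/path/to/epoch=12366-step=3091750.ckpt",
    max_epochs=1234567890,
    extra={"data.batch_size": 2},
)
print(shlex.join(cmd))
# To run it from script_dir: subprocess.run(cmd, check=True)
```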

In [None]:
!python train_cli.py fit -c $repo_dir/configs/example_config.yml

### Run inference
* The inference script could also be replaced by LightningCLI if there is a way to sample the entire dataset and stitch the outputs into a raster. This would avoid maintaining a separate CLI with redundant, brittle code for layer normalization, etc.
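The core of such an inference step, independent of torchgeo or LightningCLI, is sampling overlapping windows over the whole stack and averaging the predictions back into one array. The sketch below is a hypothetical NumPy version, not the repo's `inference.py`: it shifts the last window back so it ends exactly at the raster edge (sidestepping the partial edge patches mentioned in the torchgeo notes), and `predict_fn` stands in for the trained model applied to one normalized patch.

```python
import numpy as np


def _window_starts(length, patch, stride):
    """Start offsets covering [0, length); the last window is shifted back to end at the edge."""
    if length <= patch:
        return [0]
    starts = list(range(0, length - patch + 1, stride))
    if starts[-1] != length - patch:
        starts.append(length - patch)
    return starts


def predict_by_patches(stack, patch, stride, predict_fn):
    """Run predict_fn on overlapping (C, patch, patch) windows of a (C, H, W) stack and
    average the (patch, patch) outputs where windows overlap."""
    if stride > patch:
        raise ValueError("stride must not exceed patch size, or pixels would be skipped")
    _, h, w = stack.shape
    total = np.zeros((h, w))
    count = np.zeros((h, w))
    for r in _window_starts(h, patch, stride):
        for c in _window_starts(w, patch, stride):
            pred = predict_fn(stack[:, r:r + patch, c:c + patch])
            total[r:r + patch, c:c + patch] += pred
            count[r:r + patch, c:c + patch] += 1
    return total / count


# Identity "model" returning the first layer: the stitched output must equal that layer
stack = np.random.default_rng(0).normal(size=(5, 10, 13))
refined = predict_by_patches(stack, patch=4, stride=3, predict_fn=lambda x: x[0])
print(np.allclose(refined, stack[0]))  # True
```

Averaging overlaps is one design choice; weighting window centres more heavily than edges is a common alternative when models perform worse near patch borders.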

In [None]:
inference_output_fn = "inference_tg041_1024x1024.tif"
!python inference.py $main_dir/epoch=12366-step=3091750.ckpt \
    meanstd  \
    $inference_output_fn \
    stucker_unet \
    cpu

In [None]:
# Reuse Baker CSM stack filenames from torchgeo_dataset.py

# initial_dem_unfilled_root = "try_pc_align_to_lidar_15m_maxdisp_rotationallowed-1.0m-DEM.tif"
baker_stack_dir = f"{main_dir}/baker_csm_stack_small"

# Bring in the inference result
inference_output_path = os.path.join(baker_stack_dir, os.path.basename(inference_output_fn))
!cp $inference_output_fn $inference_output_path

initial_dem_root = os.path.join(baker_stack_dir, "try_pc_align_to_lidar_15m_maxdisp_rotationallowed-1.0m-DEM_holes_filled.tif")
ortho_left_root = os.path.join(baker_stack_dir, "final_ortho_left_1.0m.tif")
ortho_right_root = os.path.join(baker_stack_dir, "final_ortho_right_1.0m.tif")
triangulation_error_root = os.path.join(baker_stack_dir, "try_pc_align_to_lidar_15m_maxdisp_rotationallowed-1.0m-IntersectionErr.tif")#_holes_filled.tif")
target_root = os.path.join(baker_stack_dir, "mosaic_full128_USGS_LPC_WA_MtBaker_2015___LAS_2017_32610_first_filt_v1.3_1.0m-DEM_holes_filled.tif")
# Note: target_root had a `_*_` in it before, filename changed at some point
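Because the workflow relies on every layer being aligned with the lidar target DEM, it is worth confirming that the rasters above share the same size, geotransform, and projection before training or plotting. The sketch below is hypothetical; `raster_grid` uses standard GDAL dataset methods (`RasterXSize`, `RasterYSize`, `GetGeoTransform()`, `GetProjection()`) and imports GDAL lazily, while the comparison itself is plain Python.

```python
def raster_grid(path):
    """Return (width, height, geotransform, projection WKT) for a raster (requires GDAL)."""
    from osgeo import gdal  # imported here so mismatched_grids works without GDAL

    ds = gdal.Open(path)
    if ds is None:
        raise FileNotFoundError(path)
    return ds.RasterXSize, ds.RasterYSize, ds.GetGeoTransform(), ds.GetProjection()


def mismatched_grids(grids, tol=1e-6):
    """grids: {name: (width, height, geotransform, wkt)}; return names that differ from the first."""
    names = list(grids)
    ref_w, ref_h, ref_gt, ref_wkt = grids[names[0]]
    bad = []
    for name in names[1:]:
        w, h, gt, wkt = grids[name]
        same_gt = all(abs(a - b) <= tol for a, b in zip(gt, ref_gt))
        if (w, h, wkt) != (ref_w, ref_h, ref_wkt) or not same_gt:
            bad.append(name)
    return bad


example = {
    "dsm": (3072, 3072, (584000.0, 1.0, 0.0, 5400072.0, 0.0, -1.0), "WKT"),
    "ortho_left": (3072, 3072, (584000.0, 1.0, 0.0, 5400072.0, 0.0, -1.0), "WKT"),
    "lidar": (3072, 3072, (584000.5, 1.0, 0.0, 5400072.0, 0.0, -1.0), "WKT"),
}
print(mismatched_grids(example))  # ['lidar']
# In the notebook: mismatched_grids({"initial": raster_grid(initial_dem_root), "target": raster_grid(target_root)})
```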

In [None]:
%pwd

In [None]:
%cd $main_dir

## Plot output vs input datasets

**It is also a good idea to use hillshades and difference maps to compare the refined DEM to input and target DEMs, with evaluation scripts, interactive notebooks, and/or QGIS.**
More plotting code can be found within the notebook directory or the `eval.py` script
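For difference maps, summary statistics of `refined - lidar` (and `initial - lidar`) give a quick numeric comparison alongside the hillshades. This sketch is an assumption about useful metrics, not what `eval.py` computes: it reports mean, median, RMSE, and NMAD (1.4826 times the median absolute deviation from the median, a common robust spread measure for DEM errors), ignoring NaN nodata pixels and assuming both arrays are on the same grid.

```python
import numpy as np


def dem_diff_stats(dem, reference):
    """Summary statistics of (dem - reference) over pixels valid in both arrays."""
    diff = np.asarray(dem, dtype=float) - np.asarray(reference, dtype=float)
    diff = diff[np.isfinite(diff)]  # drop nodata (NaN) pixels
    median = np.median(diff)
    return {
        "mean": float(diff.mean()),
        "median": float(median),
        "rmse": float(np.sqrt(np.mean(diff**2))),
        "nmad": float(1.4826 * np.median(np.abs(diff - median))),
        "n": int(diff.size),
    }


print(dem_diff_stats([1.0, 2.0, 3.0, 4.0, 5.0, np.nan], np.zeros(6)))
```

In the notebook, the inputs could be `open_raster(output_choice, bbox=...).values` and `open_raster(target_root, bbox=...).values` clipped to the same box.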

In [None]:
train_channel_and_tree_spot = (585825,5398225, 586900,5397475)

In [None]:
dpi = 500
output_choice = inference_output_path
# To plot a previous run instead, set output_choice to that run's output raster path, e.g.
# output_choice = output_v129_11758  # or output_v140_11632 (these variables are not defined in this notebook)
run_name = os.path.splitext(os.path.basename(output_choice))[0]
plot_title = f"Plot Baker training area {train_channel_and_tree_spot}\n{run_name}"
nrows = 2
ncols = 2
fig, axes = plt.subplots(ncols=ncols, nrows=nrows, dpi=dpi, figsize=(10, 7.5))
axes = axes.reshape(-1)

fig.suptitle(plot_title)

plot_ortho(axes[0], open_raster(ortho_left_root, bbox=train_channel_and_tree_spot), "Orthoimage (stretched)")
plot_shaded_relief(axes[1], initial_dem_root, train_channel_and_tree_spot, "Initial DEM (meters)")
plot_shaded_relief(axes[2], output_choice, train_channel_and_tree_spot, "Refined DEM (meters)")
plot_shaded_relief(axes[3], target_root, train_channel_and_tree_spot, "Lidar DEM (meters)")

plt.tight_layout()