This notebook demonstrates how to run inference over a scene with a previously trained DeepDEM model, mosaic the predicted tiles into a single DEM GeoTIFF, and compute a hillshade of the result.

In [None]:
# rasterio imports
import rasterio
from rasterio.merge import merge

from rasterio import transform

# torchgeo imports
from torchgeo.samplers import GridGeoSampler
# from torchgeo.datasets import RasterDataset

# pytorch imports 
import torch

# misc imports
from functools import reduce, partial
# import operator
from pathlib import Path, PurePath
import numpy as np

import yaml
from pathlib import Path
import subprocess

import kornia.augmentation as K


# local imports
import sys
sys.path.insert(0, str(Path('.').absolute().parent/'scripts'))

from task_module import DeepDEMRegressionTask
from dataset_modules import CustomInputDataset, CustomDataModule
from torchgeo.datasets import stack_samples
from torchgeo.samplers import BatchGeoSampler
from torch import nn 
from functools import partial


In [None]:
model_path = Path('/mnt/working/karthikv/DeepDEM/scripts/checkpoints/experiment_group_1/version_001/')

# Sort so the same checkpoint is chosen on every run (raw glob order is
# filesystem-dependent), and fail with a clear message instead of a bare
# IndexError when the directory holds no checkpoint.
checkpoints = sorted(model_path.glob('*.ckpt'))
if not checkpoints:
    raise FileNotFoundError(f'No *.ckpt checkpoint found in {model_path}')
# NOTE(review): when several checkpoints exist, the lexicographically first
# one is used — confirm that is the intended checkpoint.
model_checkpoint = checkpoints[0]
model = DeepDEMRegressionTask.load_from_checkpoint(model_checkpoint).cuda().eval()
# model.model_kwargs['channel_swap'] = False

bands, datapath, chip_size = model.model_kwargs['bands'], model.model_kwargs['datapath'], model.model_kwargs['chip_size']
# NOTE: `bands` is the very list stored in model.model_kwargs, so this removal
# also edits the model's kwargs in place. 'lidar_data' is excluded from the
# inference inputs (presumably it is the training target — confirm).
bands.remove('lidar_data')
inference_dataset = CustomInputDataset(datapath, bands=bands)
# stride = chip_size // 2, so neighbouring chips overlap by half a chip.
data_sampler = GridGeoSampler(inference_dataset, size=chip_size, stride=chip_size//2)

In [None]:
# Reuse the rasterio profile (georeferencing + creation options) of the first
# input raster as the starting point for every GeoTIFF written below.
first_input_file = inference_dataset.files[0]
with rasterio.open(first_input_file) as src:
    template_profile = src.profile

In [None]:
def generate_write_inference(index, sample, dataset, model, output_path, template_profile):
    """Predict one sampled chip with ``model`` and write it as a GeoTIFF tile.

    Args:
        index: Tile number; the file is named ``inference_<index>.tif`` with
            the index zero-padded to 6 characters.
        sample: Bounding box from the sampler; its ``minx``, ``miny``,
            ``maxx`` and ``maxy`` become the tile's georeferenced bounds.
        dataset: Dataset indexed by ``sample``; the item's ``'image'`` tensor
            is fed to the model.
        model: Model whose ``forward`` accepts ``stage='inference'``. The
            input is moved to CUDA, so the model must live on the GPU.
        output_path: Directory for the tile (``str`` or path-like); created,
            including missing parents, if it does not exist.
        template_profile: rasterio profile used as the base for the tile. It
            is NOT modified; a per-tile copy is written instead.
    """
    output_path = Path(output_path)
    # parents=True: '../outputs' itself may not exist yet on a fresh checkout.
    output_path.mkdir(parents=True, exist_ok=True)

    image = dataset[sample]['image']
    # no_grad: pure inference, so skip building an autograd graph per chip.
    with torch.no_grad():
        prediction = model.forward(image.unsqueeze(0).cuda(), stage='inference')
    inference = prediction.cpu().numpy().squeeze()

    height, width = inference.shape[-2:]
    bounds = (sample.minx, sample.miny, sample.maxx, sample.maxy)

    # Build a fresh profile rather than calling template_profile.update(), which
    # would silently mutate the caller's (shared) template. count/dtype follow
    # the single-band prediction so float outputs are not truncated to the
    # template's dtype.
    tile_profile = {
        **template_profile,
        'count': 1,
        'dtype': inference.dtype.name,
        'width': width,
        'height': height,
        'transform': transform.from_bounds(*bounds, width, height),
    }

    tile_file = output_path / f"inference_{str(index).zfill(6)}.tif"
    with rasterio.open(tile_file, 'w', **tile_profile) as ds:
        # rasterio expects a (bands, rows, cols) array.
        ds.write(inference[np.newaxis, ...])


# Name the outputs after the last two components of the checkpoint directory,
# e.g. .../experiment_group_1/version_001 -> experiment_group_1_version_001.
model_name = '_'.join(model_path.parts[-2:])
output_path = Path(f'../outputs/{model_name}')

# Remove tiles left over from an earlier run. Otherwise, if that run produced
# more tiles than this one, the 'inference_*.tif' glob used for merging would
# pick the stale ones up too.
for stale_tile in output_path.glob('inference_*.tif'):
    stale_tile.unlink()

# Bind the fixed arguments under a new name instead of rebinding
# `generate_write_inference`, so the original function is not shadowed.
write_tile = partial(
    generate_write_inference,
    output_path=output_path,
    dataset=inference_dataset,
    model=model,
    template_profile=template_profile,
)

for i, sample in enumerate(data_sampler):
    write_tile(i, sample)

In [None]:
# Tile names are zero-padded, so lexicographic order matches tile order.
# rasterio's merge() (default method 'first') keeps the first valid value in
# list order where tiles overlap; reversing the order therefore lets the
# higher-index tiles take precedence.
inferences = sorted(output_path.glob('inference_*.tif'), reverse=True)
if not inferences:
    raise FileNotFoundError(f'No inference tiles found in {output_path}')
merged_inference, merge_transform = merge(inferences)

In [None]:
merged_path = output_path.parent / f'{output_path.name}.tif'

# Describe the mosaic from the merged array itself: band count and dtype come
# from merge()'s output, so the profile cannot disagree with what is written.
inference_profile = {
    **template_profile,
    'count': merged_inference.shape[0],
    'height': merged_inference.shape[-2],
    'width': merged_inference.shape[-1],
    'dtype': merged_inference.dtype.name,
    'transform': merge_transform,
}

with rasterio.open(merged_path, 'w', **inference_profile) as ds:
    ds.write(merged_inference)

In [None]:
merged_path = output_path.parent / f'{output_path.name}.tif'
hillshade_path = output_path.parent / f'{output_path.name}_hs.tif'
# check=True: a failing gdaldem run raises CalledProcessError instead of being
# silently ignored (which would leave a missing or stale hillshade file).
subprocess.run(
    ["gdaldem", "hillshade", "-compute_edges", merged_path, hillshade_path],
    check=True,
)