Skip to content

Latest commit

 

History

History
451 lines (314 loc) · 10.3 KB

infill.org

File metadata and controls

451 lines (314 loc) · 10.3 KB

Infill

Load a pre-trained model and generate some GeoTiffs with it

Imports

import os
import random
import toml

import numpy as np
from numpy.lib.stride_tricks import sliding_window_view
from osgeo import gdal

import matplotlib as mpl
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torchvision.transforms as transforms

from networks import Generator
from tools import random_bbox, mask_image
from blend import blend_images

Settings

# Load the shared configuration from disk.
with open('config.toml', 'r') as file:
    config = toml.load(file)


# Test specific settings: override the loaded values for this session
config.update({
    'cuda': False,
    'dataset_name': "../datac/NAC_DTM_WIENERF.TIF",
    'seed': 1433,
})

Tile Dataset

def normalise(arr):
    """Min-max scale ``arr`` into the range [-1, 1].

    Args:
        arr: numpy array to scale.

    Returns:
        Array of the same shape where ``arr.min()`` maps to -1 and
        ``arr.max()`` maps to 1. A constant array (max == min) maps to -1
        everywhere instead of producing NaN from a 0/0 division.
    """
    lowest = arr.min()
    span = arr.max() - lowest

    # A flat input has zero span; dividing by 1 instead keeps the result
    # finite (every value lands on the lower bound of the range).
    if span == 0:
        span = 1

    normalized = (arr - lowest) / span

    # Rescale 0<->1 to -1<->1
    normalized = 2*normalized - 1

    return normalized

def partial_norm(arr):
    """Min-max scale ``arr`` and shift it down by one, giving the range [-1, 0].

    Args:
        arr: numpy array to scale.

    Returns:
        Array of the same shape where ``arr.min()`` maps to -1 and
        ``arr.max()`` maps to 0. A constant array (max == min) maps to -1
        everywhere instead of producing NaN from a 0/0 division.
    """
    # NOTE(review): the "- 1" shift yields -1<->0 rather than 0<->1; this
    # looks like a leftover from the 2*x - 1 form -- confirm the intended range.
    lowest = arr.min()
    span = arr.max() - lowest

    # A flat input has zero span; dividing by 1 instead keeps the result finite.
    if span == 0:
        span = 1

    normalized = (arr - lowest) / span

    normalized = normalized - 1

    return normalized

def tile(dataset, kernel_size):
    """Split a single-band raster into normalised tiles with a slope channel.

    Args:
        dataset: path of the raster, passed to ``gdal.Open``.
        kernel_size: ``(tile_height, tile_width)`` in pixels.

    Returns:
        A 4-tuple ``(tiles, min_max, crs, geo_transform)``:
          * ``tiles``: torch tensor of shape (n_tiles, 2, tile_height,
            tile_width); channel 0 is the tile normalised with
            ``normalise``, channel 1 its normalised slope. Tiles are in
            row-major order (index = row * tiles_wide + col).
          * ``min_max``: list of ``(min, max)`` per tile, taken from the raw
            values before normalisation.
          * ``crs``: projection string of the source raster.
          * ``geo_transform``: GDAL geotransform of the source raster.

    Raises:
        RuntimeError: if GDAL cannot open ``dataset``.
        ValueError: if the raster is not 2-D or is smaller than one tile.
    """
    dem = gdal.Open(dataset)
    if dem is None:
        raise RuntimeError(f"GDAL could not open raster: {dataset}")

    crs = dem.GetProjection()
    geo_transform = dem.GetGeoTransform()

    image = dem.ReadAsArray()

    # GC should get this, but just to be safe
    dem = None

    if image.ndim != 2:
        raise ValueError(
            f"expected a single-band raster, got array of shape {image.shape}")

    img_height, img_width = image.shape
    tile_height, tile_width = kernel_size

    tiles_high = img_height // tile_height
    tiles_wide = img_width // tile_width

    if tiles_high == 0 or tiles_wide == 0:
        raise ValueError(
            f"raster of shape {image.shape} is smaller than one "
            f"{tile_height}x{tile_width} tile")

    # Drop the bottom/right remainder so the image divides into whole tiles
    image = image[:tiles_high * tile_height, :tiles_wide * tile_width]

    # (rows, tile_h, cols, tile_w) -> (rows, cols, tile_h, tile_w) -> (n, tile_h, tile_w)
    tiled_array = image.reshape(tiles_high,
                                tile_height,
                                tiles_wide,
                                tile_width)
    tiled_array = tiled_array.swapaxes(1, 2)
    tiled_array = tiled_array.reshape(tiles_high * tiles_wide, tile_height, tile_width)

    # Raw per-tile range, needed later to undo the normalisation
    min_max = [(raw_tile.min(), raw_tile.max()) for raw_tile in tiled_array]

    # Apply normalise() to each 2-D tile independently
    vectorized_normalise = np.vectorize(normalise, signature='(n,m)->(n,m)')
    # vectorized_partial_norm = np.vectorize(partial_norm, signature='(n,m)->(n,m)')

    tiled_array = vectorized_normalise(tiled_array)

    # Slope
    # The gradient is taken on the already-normalised heights, using the
    # pixel width from the geotransform as the sample spacing on both axes.
    cellsize = geo_transform[1]
    d_row, d_col = np.gradient(tiled_array, cellsize, axis=(1, 2))
    slope = np.arctan(np.sqrt(d_row ** 2 + d_col ** 2))
    slope = vectorized_normalise(slope)

    # # RDLS
    # windowed = sliding_window_view(tiled_array, (3,3), axis=(1,2))
    # rdls = np.ptp(windowed, axis=(3,4))
    # rdls = np.pad(rdls, ((0,0), (1,1), (1,1)), mode='constant', constant_values=0)
    # rdls = vectorized_normalise(rdls)

    # Stack the channels directly as (n, C, H, W)
    stacked = np.stack((tiled_array, slope), axis=1)

    return torch.from_numpy(stacked), min_max, crs, geo_transform

Image Transformations

# Undo the -1<->1 normalisation and return values on the original scale

def denormalize(tensor, max, min):
    """Convert a normalised C,H,W tensor back to an H,W,C array of raw values.

    Args:
        tensor: torch tensor with values in -1<->1; after squeezing it must
            be three-dimensional (channels first).
        max: upper bound of the original value range.
        min: lower bound of the original value range.

    Returns:
        numpy array, channels last, rescaled to min<->max.
    """
    channels_first = np.squeeze(tensor.cpu().detach().numpy())
    channels_last = np.transpose(channels_first, (1, 2, 0))

    # -1<->1 to 0<->1, then stretch to the original range
    unit = (channels_last * 0.5) + 0.5
    return (unit * (max - min)) + min

# Map data from the -1<->1 normalisation to the 0<->1 normalisation.
# Needed because poisson blending requires the DEM data to be 0<->1 but the
# inpainted DEM is returned as -1<->1
def partial_dn(tensor):
    """Convert a -1<->1 C,H,W tensor into a 0<->1 H,W,C numpy array.

    Args:
        tensor: torch tensor; after squeezing it must be three-dimensional
            (channels first).

    Returns:
        numpy array, channels last, with values halved and shifted by 0.5.
    """
    channels_first = np.squeeze(tensor.cpu().detach().numpy())
    channels_last = np.transpose(channels_first, (1, 2, 0))

    return (channels_last * 0.5) + 0.5

Setup

Seed

Can probably get rid of this, no training is happening

# Seed Python's and torch's random number generators when a seed is configured.
# NOTE(review): a seed of 0 is falsy and would be skipped by this check, and
# numpy's generator is not seeded here -- confirm neither matters for inference.
if config["seed"]:
    seed = config["seed"]
    random.seed(seed)
    torch.manual_seed(seed)

Get Tile

# Cut the configured raster into 256x256 tiles; also keep each tile's raw
# (min, max) plus the source projection and geotransform for writing results.
tiled, min_max, crs, geo_transform = tile(config["dataset_name"], (256, 256))

# Notebook-style inspection of the tile tensor's shape
tiled.shape

This is not the most efficient way of doing things, but individual DEM files are (probably) much larger than the tiles the network is trained on, so each file has to be cut into tiles first. A lot of the data also (annoyingly) seems to be basic slopes that aren't very interesting.

  • There may be something to be said for trying to find high-res (5 m) DEMs with consistently complex terrain.

Workflow

  • Manually iterate through tiles until an interesting tile is found
  • Generate infilled DEM
  • If the result is either really good or really bad, save it to file, as it will be useful for the report.

Select Tile

def display(image):
    """Plot channel 0 ("DEM") and channel 1 ("Slope") of an image side by side.

    Args:
        image: either an H,W,C numpy array or a torch tensor that squeezes
            to C,H,W; a tensor is converted to a channels-last numpy array.
    """
    if isinstance(image, torch.Tensor):
        # C,H,W tensor to H,W,C array
        image = np.transpose(image.cpu().detach().numpy().squeeze(), (1, 2, 0))

    plt.figure(figsize=(12,4))

    # (title, colour map, draw a colour bar?) for each channel, in order
    panels = (
        ("DEM", 'terrain', True),
        ("Slope", 'viridis', False),
    )
    for channel, (title, cmap, with_colorbar) in enumerate(panels):
        plt.subplot(1, 2, channel + 1)
        plt.imshow(image[:,:,channel], cmap=cmap)
        plt.title(title)
        if with_colorbar:
            plt.colorbar()
# Number of tiles available, i.e. the valid range for tile_n
print(len(tiled))

Great tiles: NAC_DTM_RIMASHARP

  • 1087

NAC_DTM_TYCHOPK05

  • 250

NAC_DTM_WIENERF

  • 240
  • 400
  • 500
# Find the tile with the largest raw elevation range (max - min).
biggest_dif = 0
big_dem = 0

# The loop variables are deliberately not called min/max: at module level
# those names would replace the builtins for every later cell.
for i, (tile_lo, tile_hi) in enumerate(min_max):
    relief = tile_hi - tile_lo
    if relief > biggest_dif:
        biggest_dif = relief
        big_dem = i


# Notebook-style display of the winning tile index
big_dem
# Index of the tile to work on
tile_n = 240

# Normalised 2-channel tile and the raw value range it was scaled from
dem = tiled[tile_n]
dem_min, dem_max = min_max[tile_n]

# Channels-last numpy copy of the normalised tile
tt = np.transpose(dem.cpu().detach().numpy(), (1,2,0))

# Show the tile rescaled to its original range
display(denormalize(dem, dem_max, dem_min))

# Colour bar for the axes that is current after display() returns
plt.colorbar()

Infill

#### Transforms

#### Infill void

# Different from normal bbox
# (y1, x1, y2, x2)
# bboxes = torch.tensor([(0, 80, 256, 160)], dtype=torch.int64)
bboxes = torch.tensor([(74, 74, 182, 182)], dtype=torch.int64)
x, mask = mask_image(dem, bboxes, config, train=False)

checkpoint_path = "../slope_out/saved_models/gen_00000500.pt"
# checkpoint_path = "../out_final/saved_models/gen_00000168.pt"
# checkpoint_path = "out/saved_models/gen_00000036.pt"

inpainted_result = None
x2 = None

with torch.no_grad():

    netG = Generator(config, config["cuda"])
    # When running without CUDA, remap the stored tensors to the CPU so a
    # checkpoint saved from a GPU run can still be loaded.
    netG.load_state_dict(
        torch.load(checkpoint_path,
                   map_location=None if config["cuda"] else "cpu"))
    # Inference only: put any dropout / batch-norm layers into evaluation mode
    netG.eval()
    x1, x2 = netG(x, mask)
    # Keep the network output where mask is 1 and the input everywhere else
    inpainted_result = x2 * mask + x * (1. - mask)

#### De-normalize
inpainted_result_dn = denormalize(inpainted_result, dem_max, dem_min)
# ground_truth_dn = denormalize(ground_truth, img_max, img_min)
# inpainted_result = np.squeeze(inpainted_result)


display(inpainted_result_dn)
# plt.imshow(np.squeeze(inpainted_result), cmap='terrain')

# Notebook-style display of the largest de-normalised value
inpainted_result_dn.max()
# Build a raster in which every pixel with a non-zero mask value is replaced
# by a constant lying 200 below the smallest de-normalised value, and every
# pixel with mask == 0 keeps its channel-0 value.
m = np.squeeze(mask)

# Not referenced anywhere else in the visible source
offset = 1000

dem_band = inpainted_result_dn[:,:,0]
void = np.full(dem_band.shape, inpainted_result_dn.min() - 200, dtype=np.float64)

# Copy all unmasked pixels in one boolean-indexed assignment instead of a
# per-pixel Python loop
known = np.asarray(m) == 0.
void[known] = dem_band[known]

plt.imshow(void, cmap='terrain')

# 0<->1 versions of the raw network output and of the original tile
infill = partial_dn(x2)
gt = partial_dn(dem)

# Add one zero row above and below each input before blending; the same rows
# are cropped off the result again afterwards.
padded_infill = np.pad(infill[:,:,0], ((1,1), (0,0)), mode='constant', constant_values=0)
padded_gt = np.pad(gt[:,:,0], ((1,1), (0,0)), mode='constant', constant_values=0)
padded_mask = np.pad(mask, ((1,1), (0,0)), mode='constant', constant_values=0)
blended = blend_images(padded_infill, padded_gt, padded_mask)
blended = blended[1:-1, :]

# 0<->1 back to the tile's original value range
blended = (blended * (dem_max - dem_min)) + dem_min
# blended = blended[5:-5, 5:-5]

ground_truth = denormalize(dem, dem_max, dem_min)
inpainted_full = denormalize(x2, dem_max, dem_min)
combined = denormalize(inpainted_result, dem_max, dem_min)

plt.imshow(infill[:,:,0], cmap='terrain')
plt.colorbar()

# Five-panel comparison figure
plt.figure(figsize=(20,4))
plt.subplot(1,5,1)
plt.imshow(ground_truth[:,:,0], cmap='terrain')
plt.title("Ground Truth")
plt.subplot(1,5,2)
plt.imshow(mask, cmap='gray')
plt.title("Mask")
plt.subplot(1,5,3)
plt.imshow(inpainted_full[:,:,0], cmap='terrain')
plt.title("Inpainted Result")
plt.subplot(1,5,4)
plt.imshow(combined[:,:,0], cmap='terrain')
plt.title("Combined")
plt.subplot(1,5,5)
plt.imshow(blended, cmap='terrain')
plt.title("Poisson Blended")

# The output directory must exist before the figure is written into it
os.makedirs('test_results', exist_ok=True)
plt.savefig(f'test_results/{tile_n}_fig.png', dpi=300, format='png')
plt.show()

Save

# Create the output directory; exist_ok avoids the check-then-create race
# and makes re-running the cell harmless.
os.makedirs('test_results', exist_ok=True)

def tile_geotransform(geo_transform, tile_index, tiles_wide, tile_size=(256, 256)):
    """Return the GDAL geotransform of one tile cut from a larger raster.

    Args:
        geo_transform: 6-tuple geotransform of the source raster
            (origin x, pixel width, row rotation, origin y, column rotation,
            pixel height).
        tile_index: row-major index of the tile (row * tiles_wide + col).
        tiles_wide: number of tiles per row of the source raster.
        tile_size: ``(tile_height, tile_width)`` in pixels.

    Returns:
        6-tuple geotransform whose origin is the tile's top-left corner;
        pixel sizes and rotation terms are unchanged.
    """
    origin_x, pixel_width, row_rotation, origin_y, col_rotation, pixel_height = geo_transform
    tile_height, tile_width = tile_size

    tile_row, tile_col = divmod(tile_index, tiles_wide)
    x_offset = tile_col * tile_width    # pixels from the raster's left edge
    y_offset = tile_row * tile_height   # lines from the raster's top edge

    # Standard GDAL affine mapping from (pixel, line) to map coordinates
    tile_x = origin_x + x_offset * pixel_width + y_offset * row_rotation
    tile_y = origin_y + x_offset * col_rotation + y_offset * pixel_height

    return (tile_x, pixel_width, row_rotation, tile_y, col_rotation, pixel_height)


def write_geotiff(filename, arr, tile_index=None, tiles_wide=None, tile_size=(256, 256)):
    """Write a 2-D array as a single-band Float32 GeoTIFF positioned at its tile.

    Reads the module-level ``crs`` and ``geo_transform`` of the source
    raster, and by default the module-level ``tile_n`` and
    ``config["dataset_name"]``.

    Args:
        filename: output path.
        arr: 2-D array; its shape sets the raster size.
        tile_index: row-major index of the tile; defaults to ``tile_n``.
        tiles_wide: tiles per row of the source raster; when omitted it is
            derived from the width of ``config["dataset_name"]``.
        tile_size: ``(tile_height, tile_width)`` used when the source was tiled.

    Raises:
        RuntimeError: if the source raster cannot be opened or the output
            file cannot be created.
    """
    if tile_index is None:
        tile_index = tile_n

    if tiles_wide is None:
        source = gdal.Open(config["dataset_name"])
        if source is None:
            raise RuntimeError(
                f"GDAL could not open raster: {config['dataset_name']}")
        tiles_wide = source.RasterXSize // tile_size[1]
        source = None

    driver = gdal.GetDriverByName("GTiff")
    out_ds = driver.Create(filename, arr.shape[1], arr.shape[0], 1, gdal.GDT_Float32)
    if out_ds is None:
        raise RuntimeError(f"GDAL could not create: {filename}")

    out_ds.SetProjection(crs)

    # Shift the source origin to this tile's top-left corner
    out_ds.SetGeoTransform(
        tile_geotransform(geo_transform, tile_index, tiles_wide, tile_size))

    band = out_ds.GetRasterBand(1)
    band.WriteArray(arr)
    band.FlushCache()
    band.ComputeStatistics(False)

    # Drop the references so GDAL flushes and closes the file now
    band = None
    out_ds = None

# Output path and raster for each product of this tile, written in order
outputs = (
    (f'test_results/{tile_n}_inpaint_poisson.tif', blended),
    (f'test_results/{tile_n}_inpaint.tif', combined[:,:,0]),
    (f'test_results/{tile_n}_gt.tif', ground_truth[:,:,0]),
    (f'test_results/{tile_n}_void.tif', void),
)
for out_path, raster in outputs:
    write_geotiff(out_path, raster)