Load a pre-trained model and use it to generate some GeoTIFFs.
import os
import random
import toml
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view
from osgeo import gdal
import matplotlib as mpl
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from networks import Generator
from tools import random_bbox, mask_image
from blend import blend_images
# Read the shared configuration from config.toml.
with open('config.toml', 'r') as file:
    config = toml.load(file)

# Test specific settings
config.update({
    'cuda': False,
    'dataset_name': "../datac/NAC_DTM_WIENERF.TIF",
    'seed': 1433,
})
def normalise(arr):
    """Linearly rescale ``arr`` so its minimum maps to -1 and its maximum to 1.

    Args:
        arr: NumPy array to rescale.

    Returns:
        Array of the same shape as ``arr`` with values in [-1, 1]. A constant
        input (max == min) has no range to stretch, so every element comes
        back as -1 rather than the NaN a 0/0 division would produce.
    """
    lowest = arr.min()
    value_range = arr.max() - lowest
    if value_range == 0:
        # Flat input: (arr - lowest) is all zeros, so any non-zero divisor
        # gives a finite result and avoids the 0/0 division.
        value_range = 1
    normalized = (arr - lowest) / value_range
    return 2 * normalized - 1
def partial_norm(arr):
    """Rescale ``arr`` to [0, 1] by its own min/max, then shift it down by 1.

    Args:
        arr: NumPy array to rescale.

    Returns:
        Array of the same shape as ``arr`` with values in [-1, 0]. A constant
        input (max == min) comes back as all -1 rather than NaN.

    NOTE(review): the resulting range is [-1, 0], not [0, 1]; confirm this is
    the intended "partial" normalisation before re-enabling its use.
    """
    lowest = arr.min()
    value_range = arr.max() - lowest
    if value_range == 0:
        # Flat input: the numerator is all zeros, so dividing by 1 avoids the
        # 0/0 division while keeping the result finite.
        value_range = 1
    normalized = (arr - lowest) / value_range
    return normalized - 1
def tile(dataset, kernel_size):
    """Split a single-band raster into normalised tiles plus a slope channel.

    The raster is cropped to a whole number of tiles (partial tiles on the
    bottom/right edges are dropped) and cut into row-major tiles, so tile
    ``n`` sits at grid position ``divmod(n, tiles_wide)``.

    Args:
        dataset: Path of the raster to open with GDAL. It must read as a
            two-dimensional array.
        kernel_size: ``(tile_height, tile_width)`` in pixels.

    Returns:
        A 4-tuple ``(tiles, min_max, crs, geo_transform)``:
            tiles: torch tensor of shape
                ``(n_tiles, 2, tile_height, tile_width)``; channel 0 is the
                tile normalised to [-1, 1], channel 1 is the slope of that
                normalised tile, itself normalised to [-1, 1].
            min_max: list with one ``(min, max)`` pair per tile, taken before
                normalisation.
            crs: projection string of the source raster.
            geo_transform: GDAL geotransform of the source raster.

    Raises:
        RuntimeError: if GDAL cannot open ``dataset``.
        ValueError: if the raster is smaller than one tile.
    """
    dem = gdal.Open(dataset)
    if dem is None:
        raise RuntimeError(f"GDAL could not open dataset: {dataset}")
    crs = dem.GetProjection()
    geo_transform = dem.GetGeoTransform()
    image = dem.ReadAsArray()
    # GC should get this, but just to be safe
    dem = None

    img_height, img_width = image.shape
    tile_height, tile_width = kernel_size
    tiles_high = img_height // tile_height
    tiles_wide = img_width // tile_width
    if tiles_high == 0 or tiles_wide == 0:
        raise ValueError(
            f"raster of {img_height}x{img_width} pixels is smaller than one "
            f"{tile_height}x{tile_width} tile"
        )

    # If the raster can't be divided perfectly, drop the leftover rows/columns.
    image = image[:tiles_high * tile_height, :tiles_wide * tile_width]

    tiled_array = image.reshape(tiles_high, tile_height, tiles_wide, tile_width)
    tiled_array = tiled_array.swapaxes(1, 2)
    tiled_array = tiled_array.reshape(tiles_high * tiles_wide, tile_height, tile_width)

    # Per-tile value range, recorded before normalisation.
    min_max = [(arr.min(), arr.max()) for arr in tiled_array]

    # Apply normalise() to each 2-D tile independently.
    vectorized_normalise = np.vectorize(normalise, signature='(n,m)->(n,m)')
    # vectorized_partial_norm = np.vectorize(partial_norm, signature='(n,m)->(n,m)')
    tiled_array = vectorized_normalise(tiled_array)

    # Slope, computed on the normalised tiles with the pixel width as spacing.
    cellsize = geo_transform[1]
    d_rows, d_cols = np.gradient(tiled_array, cellsize, axis=(1, 2))
    slope = np.arctan(np.sqrt(d_rows ** 2 + d_cols ** 2))
    slope = vectorized_normalise(slope)

    # # RDLS
    # windowed = sliding_window_view(tiled_array, (3,3), axis=(1,2))
    # rdls = np.ptp(windowed, axis=(3,4))
    # rdls = np.pad(rdls, ((0,0), (1,1), (1,1)), mode='constant', constant_values=0)
    # rdls = vectorized_normalise(rdls)

    stacked = np.stack((tiled_array, slope), axis=3)
    # tiled_array = np.expand_dims(tiled_array, axis=3)
    # H,W,C to C,H,W
    stacked = np.transpose(stacked, (0, 3, 1, 2))
    return torch.from_numpy(stacked), min_max, crs, geo_transform
def denormalize(tensor, max, min):
    """Return data from -1<->1 normalisation to its original value range.

    Args:
        tensor: torch tensor that is three-dimensional once singleton
            dimensions are removed.
        max: Upper bound of the original value range.
        min: Lower bound of the original value range.

    Returns:
        NumPy array with the leading axis moved to the end, rescaled so that
        -1 maps to ``min`` and 1 maps to ``max``.
    """
    values = tensor.cpu().detach().numpy().squeeze().transpose(1, 2, 0)
    unit_scaled = (values * 0.5) + 0.5  # -1<->1 becomes 0<->1
    return (unit_scaled * (max - min)) + min
def partial_dn(tensor):
    """Return data from -1<->1 normalisation to 0<->1 normalisation.

    Used because Poisson blending requires the DEM data to be 0<->1, but the
    inpainted DEM is returned as -1<->1.

    Args:
        tensor: torch tensor that is three-dimensional once singleton
            dimensions are removed.

    Returns:
        NumPy array with the leading axis moved to the end and values mapped
        from -1<->1 onto 0<->1.
    """
    values = tensor.cpu().detach().numpy().squeeze().transpose(1, 2, 0)
    return (values * 0.5) + 0.5
The seeding below can probably be removed, since no training happens here.
# Seed Python's and torch's random number generators when a non-zero seed is
# configured.
if config["seed"]:
    seed = config["seed"]
    for seed_rng in (random.seed, torch.manual_seed):
        seed_rng(seed)

# Cut the configured DEM into 256x256 tiles.
tiled, min_max, crs, geo_transform = tile(config["dataset_name"], kernel_size=(256, 256))
# Bare expression: only shows the shape when run interactively.
tiled.shape
This is not the most efficient way of doing things, but individual DEM files are (probably) much larger than the tiles the network is trained on, so each DEM is tiled and the tiles are inspected by hand. Annoyingly, a lot of the data also seems to be basic slopes that aren't very interesting.
- There may be something to be said for trying to find high-res (5 m) DEMs with consistently complex terrain.
- Manually iterate through tiles until an interesting tile is found.
- Generate the infilled DEM.
- If it is either really good or really bad, save it to file, as it will be useful for the report.
def display(image):
    """Plot channel 0 as "DEM" and channel 1 as "Slope", side by side.

    Args:
        image: Either a channel-last array indexable as ``image[:, :, c]``,
            or a torch tensor, which is converted to NumPy, squeezed and
            transposed with ``(1, 2, 0)`` first.
    """
    if isinstance(image, torch.Tensor):
        as_array = image.cpu().detach().numpy().squeeze()
        image = np.transpose(as_array, (1, 2, 0))

    plt.figure(figsize=(12, 4))
    # (title, colour map, draw a colour bar?) for each channel, in order.
    panels = (
        ("DEM", 'terrain', True),
        ("Slope", 'viridis', False),
    )
    for channel, (title, cmap, with_colorbar) in enumerate(panels):
        plt.subplot(1, 2, channel + 1)
        plt.imshow(image[:, :, channel], cmap=cmap)
        plt.title(title)
        if with_colorbar:
            plt.colorbar()
# Number of tiles produced from the DEM, i.e. the valid range for tile_n.
print(len(tiled))
Great tiles:
NAC_DTM_RIMASHARP
- 1087
NAC_DTM_TYCHOPK05
- 250
NAC_DTM_WIENERF
- 240
- 400
- 500
# Find the tile with the largest value range (max - min). The first tile
# wins ties; both results stay 0 if no tile has a positive range.
biggest_dif = 0
big_dem = 0
# Loop targets are deliberately not called min/max: at module level that
# would rebind the builtins for everything that runs afterwards.
for i, (tile_min, tile_max) in enumerate(min_max):
    tile_range = tile_max - tile_min
    if tile_range > biggest_dif:
        biggest_dif = tile_range
        big_dem = i
# Bare expression: only shows the index when run interactively.
big_dem
# Index of the tile to work on, and that tile's original value range.
tile_n = 240
dem, (dem_min, dem_max) = tiled[tile_n], min_max[tile_n]

# Channel-last NumPy version of the normalised tile (not read again in this file).
tt = np.transpose(dem.cpu().detach().numpy(), (1, 2, 0))

# Show the tile in its original value range; the trailing colorbar call acts
# on whatever pyplot considers the current image after display() returns.
display(denormalize(dem, dem_max, dem_min))
plt.colorbar()
#### Transforms
#### Infill void
# Different from normal bbox
# (y1, x1, y2, x2)
# bboxes = torch.tensor([(0, 80, 256, 160)], dtype=torch.int64)
bboxes = torch.tensor([(74, 74, 182, 182)], dtype=torch.int64)
# presumably x is the tile with the boxed region removed and mask marks that
# region with 1 — verify against tools.mask_image
x, mask = mask_image(dem, bboxes, config, train=False)

checkpoint_path = "../slope_out/saved_models/gen_00000500.pt"
# checkpoint_path = "../out_final/saved_models/gen_00000168.pt"
# checkpoint_path = "out/saved_models/gen_00000036.pt"

inpainted_result = None
x2 = None
with torch.no_grad():
    netG = Generator(config, config["cuda"])
    # A checkpoint saved from a GPU run cannot be deserialised on a CPU-only
    # machine unless its tensors are remapped; keep torch's default placement
    # when CUDA is enabled.
    state_dict = torch.load(
        checkpoint_path,
        map_location=None if config["cuda"] else "cpu",
    )
    netG.load_state_dict(state_dict)
    # Inference only: put any dropout / batch-norm layers into evaluation mode.
    netG.eval()
    x1, x2 = netG(x, mask)
    # Take the generator output where mask is 1 and the input everywhere else.
    inpainted_result = x2 * mask + x * (1. - mask)
#### De-normalize
inpainted_result_dn = denormalize(inpainted_result, dem_max, dem_min)
display(inpainted_result_dn)
# Bare expression: only shows the maximum when run interactively.
inpainted_result_dn.max()

# Build the "void" raster: keep channel 0 wherever the mask is 0 and sink
# every other pixel to 200 below the lowest de-normalised value so the hole
# stands out.
m = np.squeeze(mask)
offset = 1000  # NOTE(review): never read in the visible code
# mask may be a torch tensor or an array; compare it as a NumPy array.
mask_values = m.detach().cpu().numpy() if isinstance(m, torch.Tensor) else np.asarray(m)
dem_channel = inpainted_result_dn[:, :, 0]
n_rows, n_cols = dem_channel.shape
unmasked = mask_values[:n_rows, :n_cols] == 0.
void = np.full(dem_channel.shape, inpainted_result_dn.min() - 200, dtype=np.float64)
# One boolean-mask assignment replaces the per-pixel Python double loop.
void[unmasked] = dem_channel[unmasked]
plt.imshow(void, cmap='terrain')
# Bring the generator output and the original tile into 0<->1 for blending.
infill = partial_dn(x2)
gt = partial_dn(dem)

# Add one row of zeros above and below each input before blending; the
# matching rows are stripped from the result afterwards.
row_padding = ((1, 1), (0, 0))
pad_options = dict(mode='constant', constant_values=0)
padded_infill = np.pad(infill[:, :, 0], row_padding, **pad_options)
padded_gt = np.pad(gt[:, :, 0], row_padding, **pad_options)
padded_mask = np.pad(mask, row_padding, **pad_options)

blended = blend_images(padded_infill, padded_gt, padded_mask)[1:-1, :]
# Scale the 0<->1 blend back onto the tile's original value range.
blended = (blended * (dem_max - dem_min)) + dem_min
# blended = blended[5:-5, 5:-5]

# De-normalised versions of the original tile, the raw generator output and
# the masked combination of the two.
ground_truth = denormalize(dem, dem_max, dem_min)
inpainted_full = denormalize(x2, dem_max, dem_min)
combined = denormalize(inpainted_result, dem_max, dem_min)

plt.imshow(infill[:, :, 0], cmap='terrain')
plt.colorbar()
# The output directory has to exist before savefig writes into it.
os.makedirs('test_results', exist_ok=True)

# (title, image, colour map) for each panel of the comparison figure.
panels = (
    ("Ground Truth", ground_truth[:, :, 0], 'terrain'),
    ("Mask", mask, 'gray'),
    ("Inpainted Result", inpainted_full[:, :, 0], 'terrain'),
    ("Combined", combined[:, :, 0], 'terrain'),
    ("Poisson Blended", blended, 'terrain'),
)

plt.figure(figsize=(20, 4))
for position, (title, raster, cmap) in enumerate(panels, start=1):
    plt.subplot(1, len(panels), position)
    plt.imshow(raster, cmap=cmap)
    plt.title(title)

plt.savefig(f'test_results/{tile_n}_fig.png', dpi=300, format='png')
plt.show()
def write_geotiff(filename, arr, *, tile_index=None, tile_size=(256, 256), tiles_wide=None):
    """Write one tile to a single-band Float32 GeoTIFF, georeferenced to the source DEM.

    The projection and geotransform come from the module-level ``crs`` and
    ``geo_transform``. The output origin is moved to the upper-left corner of
    the tile, using the tile's row/column in the row-major tile grid and the
    standard GDAL affine relation:

        x = gt[0] + col * gt[1] + row * gt[2]
        y = gt[3] + col * gt[4] + row * gt[5]

    Args:
        filename: Path of the GeoTIFF to create.
        arr: Two-dimensional array of values to write.
        tile_index: Flat index of the tile in the tile grid. Defaults to the
            module-level ``tile_n``.
        tile_size: ``(tile_height, tile_width)`` in pixels used when tiling.
        tiles_wide: Number of tiles per grid row. When omitted it is derived
            from the width of the raster named by ``config["dataset_name"]``.

    Raises:
        RuntimeError: if the source raster cannot be opened to determine
            ``tiles_wide``, or the output file cannot be created.
    """
    if tile_index is None:
        tile_index = tile_n
    tile_height, tile_width = tile_size
    if tiles_wide is None:
        source = gdal.Open(config["dataset_name"])
        if source is None:
            raise RuntimeError(
                f"GDAL could not open dataset: {config['dataset_name']}"
            )
        tiles_wide = source.RasterXSize // tile_width
        source = None

    # Pixel offset of the tile's upper-left corner within the source raster.
    tile_row, tile_col = divmod(tile_index, tiles_wide)
    col_offset = tile_col * tile_width
    row_offset = tile_row * tile_height

    # Get properties from input DEM
    origin_x, pixel_width, row_rotation, origin_y, col_rotation, pixel_height = geo_transform
    # Calculate tile coordinates
    tile_origin_x = origin_x + col_offset * pixel_width + row_offset * row_rotation
    tile_origin_y = origin_y + col_offset * col_rotation + row_offset * pixel_height

    driver = gdal.GetDriverByName("GTiff")
    out_ds = driver.Create(filename, arr.shape[1], arr.shape[0], 1, gdal.GDT_Float32)
    if out_ds is None:
        raise RuntimeError(f"GDAL could not create output file: {filename}")
    out_ds.SetProjection(crs)
    # Set Geo-transform
    out_ds.SetGeoTransform(
        (tile_origin_x, pixel_width, row_rotation,
         tile_origin_y, col_rotation, pixel_height)
    )

    band = out_ds.GetRasterBand(1)
    band.WriteArray(arr)
    band.FlushCache()
    band.ComputeStatistics(False)

    # Drop the band before the dataset so GDAL closes and flushes the file.
    band = None
    out_ds = None
# Output path and raster for each GeoTIFF written for the selected tile.
outputs = (
    (f'test_results/{tile_n}_inpaint_poisson.tif', blended),
    (f'test_results/{tile_n}_inpaint.tif', combined[:,:,0]),
    (f'test_results/{tile_n}_gt.tif', ground_truth[:,:,0]),
    (f'test_results/{tile_n}_void.tif', void),
)
for out_path, raster in outputs:
    write_geotiff(out_path, raster)