In [None]:
# Deep learning-based bathymetry retrieval without in-situ depths using remote sensing imagery and SfM-MVS DSMs with data gaps
# Initial Pytorch Implementation: Panagiotis Agrafiotis (https://github.com/pagraf)
# Email: agrafiotis.panagiotis@gmail.com

# Description:  Swin-BathyUNet, a deep learning model that combines U-Net with Swin Transformer self-attention 
# layers and a cross-attention mechanism, tailored specifically for SDB. Swin-BathyUNet is designed to improve 
# bathymetric accuracy by capturing long-range spatial relationships and can also function as a standalone solution 
# for standard bathymetric mapping with various training depth data, independent of SfM-MVS output.
# It outputs continuous values.

# If you use this code please cite our paper: P. Agrafiotis and B. Demir, "Deep learning-based bathymetry retrieval without in-situ depths using remote sensing imagery and SfM-MVS DSMs with data gaps" arXiv preprint arXiv:2504.11416 (2025).

# The structure of this notebook was inspired by the DeepNetsForEO repository (https://github.com/nshaud/DeepNetsForEO).

# Attribution-NonCommercial-ShareAlike 4.0 International License

# Copyright (c) 2025 Panagiotis Agrafiotis

# This license requires that reusers give credit to the creator. It allows reusers 
# to distribute, remix, adapt, and build upon the material in any medium or format,
# for noncommercial purposes only. If others modify or adapt the material, they 
# must license the modified material under identical terms.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.


# This work is part of MagicBathy project funded by the European Union’s HORIZON Europe research and innovation 
# programme under the Marie Skłodowska-Curie GA 101063294. Work has been carried out at the Remote Sensing Image 
# Analysis group. For more information about the project visit https://www.magicbathy.eu/.

In [None]:
# General imports
import os
import sys
import time
import random
import itertools
import numpy as np
import numpy.ma as ma
import scipy
import scipy.ndimage

# Image processing
from skimage import io, transform
from skimage.transform import resize, rotate

# PyTorch imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
import torch.optim.lr_scheduler
import torch.nn.init
from torch.autograd import Variable
from torchvision.transforms import RandomCrop, Resize

# Metrics
from sklearn.metrics import confusion_matrix, precision_score, recall_score, mean_squared_error

# Additional libraries
from glob import glob
from tqdm import tqdm_notebook as tqdm
from IPython.display import clear_output
import rasterio
import gdal
import cupy as cp

# Project-specific imports
from bathymetry.swin-bathyunet import *

In [2]:
# Select the compute device: the first CUDA GPU when one is usable, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    # Eagerly initialise the CUDA state. The call raises on hosts without a usable
    # CUDA runtime, so it must only run when the GPU was actually selected above.
    torch.cuda.init()

In [3]:
# Root directory of the MagicBathyNet dataset (https://www.magicbathy.eu/magicbathynet.html).
# Replace the placeholder with the local dataset location.
FOLDER = '/.../magicbathynet/'
# Put the dataset root on the module search path as well.
sys.path.extend([FOLDER])

# Parameters

In [None]:
# Parameters
# Every dataset path below is derived from MAIN_FOLDER, so editing FOLDER in the
# previous cell is enough to relocate the whole dataset.
MAIN_FOLDER = FOLDER

# Normalisation array for the aerial imagery: index 0 is subtracted from the image and
# (index 1 - index 0) is used as the divisor wherever images are normalised below.
norm_param = np.load(MAIN_FOLDER + 'agia_napa/norm_param_aerial.npy')
# Depths are divided by this (negative) value before being fed to the network.
norm_param_depth = -14.556  #-14.556 FOR AGIA NAPA, -5 FOR PUCK LAGOON
WINDOW_SIZE = (720, 720)
WINDOW_SIZE_sfm = (720, 720)
STRIDE = 1
BATCH_SIZE = 1

# NOTE: the trailing `...` is a placeholder and must be replaced with the real tile ids.
train_images = ['1','2',...] #The train-test splits used are the same with MagicBathyNet: https://github.com/pagraf/MagicBathyNet



# NOTE(review): the test split currently aliases the training split — replace it with
# the held-out tile ids before reporting any accuracy.
test_images = train_images
# Number of samples drawn per epoch (the dataset's __len__ returns this value).
num_train_images = len(train_images)*2

net = UNetWithAttention(in_channels=3, out_channels=1)

base_lr = 0.00025 # 0.000125 for PUCK LAGOON AERIAL

CACHE = False # Store the dataset in-memory
# File-name templates; `{}` is filled with a tile id.
IMG_FOLDER = MAIN_FOLDER + 'agia_napa/img/aerial/img_{}.tif'
SFM_FOLDER = MAIN_FOLDER + 'agia_napa/depth/aerial_sfm/sfm_{}.tif'
LIDAR_FOLDER = MAIN_FOLDER + 'agia_napa/depth/aerial_interpolated/interpolated_depth_{}.tif'


# Visualizing the dataset

In [None]:
# Load one example tile together with its two depth rasters and show them side by side.
img = io.imread(MAIN_FOLDER+'agia_napa/img/aerial/img_410.tif')
# Same normalisation as the one applied to the network inputs.
norm_img = (img - norm_param[0]) / (norm_param[1] - norm_param[0])

# SfM-MVS depth raster of the tile
gt = io.imread(MAIN_FOLDER+'agia_napa/depth/aerial_sfm/sfm_410.tif')

# LiDAR depth raster of the tile
gt_lidar = io.imread(MAIN_FOLDER+'agia_napa/depth/aerial/depth_410.tif')

# Explicit axes with titles so each panel is identifiable when the notebook is skimmed.
fig, (ax_rgb, ax_sfm, ax_lidar) = plt.subplots(1, 3, figsize=(15, 5))
ax_rgb.imshow(norm_img)
ax_rgb.set_title('Normalised image')
ax_sfm.imshow(gt/norm_param_depth)
ax_sfm.set_title('SfM-MVS depth (normalised)')
ax_lidar.imshow(gt_lidar/norm_param_depth)
ax_lidar.set_title('LiDAR depth (normalised)')

plt.show()

In [None]:
# Utils

def get_random_pos(img, window_shape):
    """Pick a random window of shape `window_shape` inside the last two axes of `img`.

    Returns `(x1, x2, y1, y2)`, the start/end indices along the second-to-last and the
    last axis respectively. Raises ValueError (from `random.randint`) when the window
    is larger than the image.
    """
    patch_w, patch_h = window_shape
    full_w, full_h = img.shape[-2:]
    # The two draws are made in this order so that seeded runs stay reproducible.
    start_x = random.randint(0, full_w - patch_w)
    start_y = random.randint(0, full_h - patch_h)
    return start_x, start_x + patch_w, start_y, start_y + patch_h

def sliding_window(top, step=10, window_size=(20,20)):
    """Yield `(x, y, height, width)` windows covering `top` with a stride of `step`.

    Start positions that would push a window past the image border are pulled back so
    the window ends exactly on the border; several consecutive start positions can
    therefore map to the same border window.
    """
    win_h, win_w = window_size[0], window_size[1]
    for x in range(0, top.shape[0], step):
        # Largest admissible start along each axis.
        x_start = min(x, top.shape[0] - win_h)
        for y in range(0, top.shape[1], step):
            y_start = min(y, top.shape[1] - win_w)
            yield x_start, y_start, win_h, win_w
            
def count_sliding_window(top, step=10, window_size=(20,20)):
    """Return how many windows `sliding_window(top, step, window_size)` yields.

    One window is produced per (row start, column start) pair of the two strided
    ranges; pulling a window back to the border does not change how many there are,
    so the count is the product of the two range lengths. `window_size` is accepted
    only to keep the signature parallel to `sliding_window`.
    """
    n_row_starts = len(range(0, top.shape[0], step))
    n_col_starts = len(range(0, top.shape[1], step))
    return n_row_starts * n_col_starts

def grouper(n, iterable):
    """Yield successive tuples of at most `n` items taken from `iterable`.

    The last tuple may be shorter; iteration stops at the first empty chunk.
    """
    source = iter(iterable)
    # Two-argument iter(): keep slicing until the empty-tuple sentinel comes back.
    yield from iter(lambda: tuple(itertools.islice(source, n)), ())

def metrics_test(predictions, gts, mask):
    """Print RMSE, MAE and error standard deviation over the pixels where `mask` is non-zero.

    Inputs are converted to CuPy arrays. The statistics are computed on the normalised
    values and multiplied by `-norm_param_depth` before printing. Returns the scaled RMSE.
    """
    pred_gpu = cp.asarray(predictions)
    gt_gpu = cp.asarray(gts)
    # Only pixels with a non-zero mask value contribute to the statistics.
    valid = cp.asarray(mask) != 0

    scale = -norm_param_depth
    residuals = (pred_gpu - gt_gpu)[valid]

    rmse = cp.sqrt(cp.mean(residuals ** 2)) * scale
    mae = cp.mean(cp.abs(residuals)) * scale
    std_dev = cp.std(residuals) * scale

    for template, value in (("RMSE : {:.3f}m", rmse),
                            ("MAE : {:.3f}m", mae),
                            ("Std_Dev : {:.3f}m", std_dev)):
        print(template.format(value))
    print("---")

    return rmse

def metrics(predictions, gts):
    """Print RMSE, MAE and error standard deviation over the pixels where `gts` is non-zero.

    Inputs are converted to CuPy arrays. The statistics are computed on the normalised
    values and multiplied by `-norm_param_depth` before printing. Returns the scaled
    RMSE, so callers must not rescale it again.
    """
    pred_gpu = cp.asarray(predictions)
    gt_gpu = cp.asarray(gts)

    # A ground-truth value of exactly 0 marks a pixel without annotation.
    valid = gt_gpu != 0

    scale = -norm_param_depth
    residuals = (pred_gpu - gt_gpu)[valid]

    rmse = cp.sqrt(cp.mean(residuals ** 2)) * scale
    mae = cp.mean(cp.abs(residuals)) * scale
    std_dev = cp.std(residuals) * scale

    for template, value in (("RMSE : {:.3f}m", rmse),
                            ("MAE : {:.3f}m", mae),
                            ("Std_Dev : {:.3f}m", std_dev)):
        print(template.format(value))
    print("---")

    return rmse

def read_geotiff(filename, b):
    """Read band `b` (1-based) of a raster file.

    Returns `(arr, ds)`: the band contents as an array and the open GDAL dataset,
    which callers can reuse as a georeferencing template.

    Raises:
        OSError: if GDAL cannot open `filename`.
        ValueError: if the dataset has no band `b`.
    """
    ds = gdal.Open(filename)
    # Without GDAL exceptions enabled a failed open returns None instead of raising.
    if ds is None:
        raise OSError('GDAL could not open raster file: {}'.format(filename))
    band = ds.GetRasterBand(b)
    if band is None:
        raise ValueError('{} has no raster band {}'.format(filename, b))
    arr = band.ReadAsArray()
    return arr, ds

def write_geotiff(filename, arr, in_ds):
    """Write the 2-D array `arr` as a single-band GeoTIFF georeferenced like `in_ds`.

    float32 arrays are stored as Float32, any other floating-point array as Float64
    (so fractional depths are never truncated to integers), and everything else as
    Int32. Projection and geotransform are copied from `in_ds`.

    Raises:
        OSError: if the output dataset cannot be created.
    """
    if arr.dtype == np.float32:
        arr_type = gdal.GDT_Float32
    elif np.issubdtype(arr.dtype, np.floating):
        # e.g. float64 results of NumPy arithmetic; keep the fractional part.
        arr = arr.astype(np.float64)
        arr_type = gdal.GDT_Float64
    else:
        arr_type = gdal.GDT_Int32

    driver = gdal.GetDriverByName("GTiff")
    out_ds = driver.Create(filename, arr.shape[1], arr.shape[0], 1, arr_type)
    if out_ds is None:
        raise OSError('GDAL could not create raster file: {}'.format(filename))
    out_ds.SetProjection(in_ds.GetProjection())
    out_ds.SetGeoTransform(in_ds.GetGeoTransform())
    band = out_ds.GetRasterBand(1)
    band.WriteArray(arr)
    band.FlushCache()
    band.ComputeStatistics(False)

    # Flush and release the dataset explicitly so the file is complete on disk as soon
    # as the function returns, independent of garbage-collection timing.
    out_ds.FlushCache()
    band = None
    out_ds = None


# Loading the dataset

We define a PyTorch dataset (`torch.utils.data.Dataset`) that picks a random tile for every sample. Tiles are read from disk on demand and, when caching is enabled, kept in memory after their first use.

The dataset also performs random data augmentation (horizontal and vertical flips) and normalizes the data in [0, 1].

In [None]:
# Dataset class
from scipy.ndimage import gaussian_filter
random.seed(1)

class MAGICBATHYNET_dataset(torch.utils.data.Dataset):
    """Random-tile dataset yielding `(image, sfm_depth, lidar_depth)` tensors.

    Every `__getitem__` call ignores the requested index and draws a random tile id.
    The image is normalised with `norm_param`; both depth rasters are divided by
    `norm_param_depth`. Whole tiles are returned (no cropping).

    Args:
        ids: tile ids used to fill the `{}` placeholder of the path templates.
        img_files, sfm_files, lidar_files: path templates for the image, the SfM depth
            and the LiDAR depth rasters.
        cache: keep every tile in memory after its first read.
        augmentation: apply random flips/mirrors to each sample.

    Raises:
        KeyError: if any of the resulting paths is not an existing file.
    """
    def __init__(self, ids, img_files=IMG_FOLDER, sfm_files=SFM_FOLDER, lidar_files=LIDAR_FOLDER,
                            cache=False, augmentation=True):
        super(MAGICBATHYNET_dataset, self).__init__()

        self.augmentation = augmentation
        self.cache = cache

        # Expand the templates passed to the constructor (not the module-level globals),
        # so a caller can point the dataset at another folder.
        self.img_files = [img_files.format(id) for id in ids]
        self.sfm_files = [sfm_files.format(id) for id in ids]
        self.lidar_files = [lidar_files.format(id) for id in ids]

        # Sanity check : raise an error if some files do not exist
        for f in self.img_files + self.sfm_files + self.lidar_files:
            if not os.path.isfile(f):
                raise KeyError('{} is not a file !'.format(f))

        # Per-tile caches, keyed by the position of the tile in the file lists
        self.img_cache_ = {}
        self.sfm_cache_ = {}
        self.lidar_cache_ = {}

    def __len__(self):
        # The epoch length is the module-level `num_train_images`, not the file count.
        return num_train_images

    @classmethod
    def data_augmentation(cls, *arrays, flip=True, mirror=True):
        """Apply the same random flip / mirror to every array and return copies.

        2-D arrays are flipped along axis 0 and mirrored along axis 1; other arrays are
        flipped along axis 1 and mirrored along axis 2. Each transform is applied with
        probability 0.5 when enabled.
        """
        will_flip, will_mirror = False, False
        if flip and random.random() < 0.5:
            will_flip = True
        if mirror and random.random() < 0.5:
            will_mirror = True

        results = []
        for array in arrays:
            if will_flip:
                if len(array.shape) == 2:
                    array = array[::-1, :]
                else:
                    array = array[:, ::-1, :]
            if will_mirror:
                if len(array.shape) == 2:
                    array = array[:, ::-1]
                else:
                    array = array[:, :, ::-1]
            # Copy: negative-stride views cannot be wrapped by torch.from_numpy
            results.append(np.copy(array))

        return tuple(results)

    def _cached(self, store, idx, loader):
        """Return `store[idx]`, calling `loader()` (and caching if enabled) on a miss."""
        if idx in store:
            return store[idx]
        value = loader()
        if self.cache:
            store[idx] = value
        return value

    def _read_image(self, idx):
        """Read image `idx` as a channel-first float32 array and normalise it."""
        data = np.asarray(io.imread(self.img_files[idx]).transpose((2,0,1)), dtype='float32')
        low = norm_param[0][:, np.newaxis, np.newaxis]
        high = norm_param[1][:, np.newaxis, np.newaxis]
        return (data - low) / (high - low)

    @staticmethod
    def _read_depth(path):
        """Read a depth raster and divide it by `norm_param_depth`."""
        return 1/norm_param_depth * np.asarray(io.imread(path), dtype='float32')

    def __getitem__(self, i):
        # The index `i` is intentionally unused: every sample is a random tile.
        random_idx = random.randint(0, len(self.img_files) - 1)

        data = self._cached(self.img_cache_, random_idx,
                            lambda: self._read_image(random_idx))
        label_sfm = self._cached(self.sfm_cache_, random_idx,
                                 lambda: self._read_depth(self.sfm_files[random_idx]))
        label_lidar = self._cached(self.lidar_cache_, random_idx,
                                   lambda: self._read_depth(self.lidar_files[random_idx]))

        # With augmentation disabled this only copies the arrays and draws no random numbers.
        data_p, label_sfm_p, label_lidar_p = self.data_augmentation(
            data, label_sfm, label_lidar,
            flip=self.augmentation, mirror=self.augmentation)

        return (torch.from_numpy(data_p),
                torch.from_numpy(label_sfm_p),
                torch.from_numpy(label_lidar_p))

      

# Network definition 

In [8]:
# Move the network's parameters and buffers to the selected compute device.
net = net.to(device)


# Loading the data

We now create a train/test split. If you want to use another dataset, you have to adjust the method to collect all filenames. In our case, we specify a fixed train/test split for the demo.


In [None]:
# Load the datasets
all_files = sorted(glob(SFM_FOLDER.replace('{}', '*')))
# Tile id = file stem without the 'sfm_' prefix (e.g. '.../sfm_410.tif' -> '410')
all_ids = [os.path.splitext(os.path.basename(f))[0].split('sfm_')[-1] for f in all_files]

# The split itself is fixed in the parameters cell.
train_ids = train_images
test_ids = test_images

print("Tiles for training : ", train_ids)
print("Tiles for testing : ", test_ids)

# The dataset already samples tiles at random, so the loader needs no shuffling.
train_set = MAGICBATHYNET_dataset(train_ids, cache=CACHE)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE)


# Designing the optimizer

In [None]:
from torch.optim.lr_scheduler import LambdaLR
from torch.optim.lr_scheduler import CosineAnnealingLR

params_dict = dict(net.named_parameters())
# One parameter group per tensor. Every group gets `base_lr`, whatever its name.
# NOTE(review): `params` is built but not handed to the optimizer below.
params = []
for key, value in params_dict.items():
    params.append({'params': [value], 'lr': base_lr})

# Adam over all network parameters with a single learning rate
optimizer = optim.Adam(net.parameters(), lr=base_lr)
# Cosine annealing with a period of 10 scheduler steps
scheduler = CosineAnnealingLR(optimizer, T_max=10)  # T_max is the number of epochs

In [None]:
crop_size_t = 720
pad_size = 16

def test(net, test_ids):
    """Run `net` on every tile in `test_ids` and return the list of predicted depth maps.

    Each image is normalised with `norm_param`, rescaled by `crop_size_t / WINDOW_SIZE[0]`,
    reflect-padded by `pad_size` pixels on every side, passed through the network in
    inference mode, and the padding is cropped from the prediction again. Predictions
    stay in normalised units (multiply by `norm_param_depth` to obtain depths).

    Only the images are read; no label raster is needed to run inference.
    """
    ratio = crop_size_t / WINDOW_SIZE[0]
    all_preds = []

    # Switch the network to inference mode
    net.eval()

    for tile_id in tqdm(test_ids, total=len(test_ids), leave=False):
        img = np.asarray(io.imread(IMG_FOLDER.format(tile_id)), dtype='float32')
        img = (img - norm_param[0]) / (norm_param[1] - norm_param[0])
        img = scipy.ndimage.zoom(img, (ratio, ratio, 1), order=1)

        # Reflect-pad the spatial axes; the channel axis is left untouched
        img = np.pad(img, ((pad_size, pad_size), (pad_size, pad_size), (0, 0)), mode='reflect')

        # HWC -> 1CHW, on whichever device was selected (works on CPU-only hosts too)
        img_tensor = np.ascontiguousarray(img.transpose((2, 0, 1)))[np.newaxis]
        img_tensor = torch.from_numpy(img_tensor).to(device)

        with torch.no_grad():
            outs = net(img_tensor.float())
        pred = outs.cpu().numpy().squeeze()

        # Remove the padding. Explicit end indices keep this correct for pad_size == 0,
        # where a `[0:-0]` slice would be empty.
        pred = pred[pad_size:pred.shape[0] - pad_size, pad_size:pred.shape[1] - pad_size]

        all_preds.append(pred)

    # Returning only the predictions
    return all_preds

# Training the network

Let's train the network for 1 epoch to see how it works (for real experiments, train for 50 epochs). The `matplotlib` graph is periodically updated with the loss plot and a sample inference. It might take a few minutes on GPUs in the cloud.

If using the notebook on your own machine with the full 50 epochs, depending on your GPU, this might take from a few hours (Titan Pascal) to a full day (old K20).


In [None]:
# Training configuration
crop_size = 720                     # tile edge length in pixels
size = (crop_size, crop_size)
loss_weight = 1.0                   # factor applied to the training loss
# Folder that receives the model checkpoints (keep the trailing slash)
epoch_folder = '/.../magicbathynet/epoch_folder/'

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as F2
from torch.autograd import Variable
import math

from scipy.ndimage import distance_transform_edt

import torch
import torch.nn as nn
import numpy as np
from scipy.ndimage import distance_transform_edt

class CustomLossW(nn.Module):
    """Masked RMSE whose per-pixel weights grow towards data gaps.

    For every valid pixel (mask == 1) the Euclidean distance to the nearest invalid
    pixel is computed, clipped to `[min_distance, max_distance]` and turned into a
    weight that decreases with distance (`decay` = 'linear' or 'exponential'). The
    weights are rescaled to `[1, 2]` over the batch and zeroed where the mask is zero.

    The loss is `sqrt(sum(weight * mask * (output - depth)^2) / sum(mask))`.
    A mask with no valid pixel yields a non-finite result; callers should skip such
    batches.

    Args:
        max_distance: distance (in pixels) beyond which the weight stops decreasing.
        min_distance: distance below which the weight stops increasing.
        decay: 'linear' or 'exponential'.

    Raises:
        ValueError: for an unknown `decay` or when `max_distance <= min_distance`.
    """
    def __init__(self, max_distance=5, min_distance=0, decay='linear'):
        super(CustomLossW, self).__init__()
        if decay not in ('linear', 'exponential'):
            raise ValueError("decay must be 'linear' or 'exponential', got {!r}".format(decay))
        if max_distance <= min_distance:
            raise ValueError('max_distance must be greater than min_distance')
        self.max_distance = max_distance
        self.min_distance = min_distance
        self.decay = decay

    def _gap_distances(self, mask):
        """Per-sample distance of every valid pixel to the nearest invalid pixel."""
        mask_np = mask.detach().cpu().numpy()
        distances = np.zeros_like(mask_np, dtype=np.float32)

        for i in range(mask_np.shape[0]):
            valid = mask_np[i] == 1
            if valid.all():
                # No gap in this sample: the distance to "the nearest gap" is undefined,
                # so treat every pixel as maximally far away (lowest weight).
                distances[i] = self.max_distance
            else:
                distances[i] = distance_transform_edt(valid)

        return torch.from_numpy(distances).to(mask.device)

    def forward(self, output, depth, mask):
        """Return the weighted, masked RMSE.

        `output` and `depth` must have the same shape; `mask` may either match that
        shape or lack the channel axis (e.g. `(B, H, W)` for `(B, 1, H, W)` outputs).
        """
        mask = mask.float()

        # Clip distances to min_distance and max_distance
        distances = torch.clamp(self._gap_distances(mask),
                                min=self.min_distance, max=self.max_distance)

        # Smaller distance -> larger weight
        scaled = (distances - self.min_distance) / (self.max_distance - self.min_distance)
        if self.decay == 'linear':
            weights = 1 - scaled
        else:
            weights = torch.exp(-scaled)

        # Rescale the weights to [min_weight, max_weight]. The denominator is clamped so
        # that a batch with identical weights maps to min_weight instead of NaN.
        min_weight = 1
        max_weight = 2
        lowest, highest = torch.min(weights), torch.max(weights)
        spread = (highest - lowest).clamp_min(torch.finfo(weights.dtype).eps)
        weights = min_weight + (max_weight - min_weight) * (weights - lowest) / spread

        # Set weights to zero where mask is zero
        weights = weights * mask

        # Align a channel-less mask with a (B, 1, H, W) output. Without this the
        # element-wise products would broadcast to (B, B, H, W) for batch sizes > 1.
        if mask.dim() == output.dim() - 1:
            mask = mask.unsqueeze(1)
            weights = weights.unsqueeze(1)

        loss = F.mse_loss(output, depth, reduction='none')

        weighted_loss = (loss * weights * mask).sum()  # Sum over the unmasked elements
        non_zero_elements = mask.sum()
        rmse_loss_val = torch.sqrt(weighted_loss / non_zero_elements)

        return rmse_loss_val
    

#MagicBathyNet's Loss
class CustomLoss(nn.Module):
    """Masked RMSE: `sqrt(sum(mask * (output - depth)^2) / sum(mask))`.

    `output` and `depth` must have the same shape; `mask` may either match that shape
    or lack the channel axis (e.g. `(B, H, W)` for `(B, 1, H, W)` outputs). A mask
    without any non-zero element yields a non-finite result.
    """
    def __init__(self):
        super(CustomLoss, self).__init__()

    def forward(self, output, depth, mask):
        mask = mask.float()
        # Align a channel-less mask with the output; otherwise the product below would
        # broadcast to (B, B, H, W) for batch sizes > 1.
        if mask.dim() == output.dim() - 1:
            mask = mask.unsqueeze(1)

        # Squared error, kept only where annotations exist
        loss = F.mse_loss(output, depth, reduction='none')
        loss = (loss * mask).sum()
        non_zero_elements = mask.sum()
        rmse_loss_val = torch.sqrt(loss / non_zero_elements)

        return rmse_loss_val


def _show_training_progress(header, mean_losses, mean_rmse, rgb, gt, pred, save_dir=None, tag=''):
    """Plot the running loss / RMSE curves and an RGB | target | prediction triptych.

    When `save_dir` is given, the RMSE curve and the triptych are also written there,
    with `tag` appended to the file names. All figures are closed afterwards.
    """
    fig_loss, ax1 = plt.subplots(figsize=(8, 5))
    ax1.plot(mean_losses, 'blue')
    ax1.set_title('Training Loss')
    ax1.set_xlabel('Iteration')
    ax1.set_ylabel('Loss')
    ax1.grid(color='black', linestyle='-', linewidth=0.5)

    fig_rmse, ax2 = plt.subplots(figsize=(8, 5))
    ax2.plot(mean_rmse, 'red', label='vs training target')
    ax2.set_title('Mean RMSE in m')
    ax2.set_xlabel('Iteration')
    ax2.set_ylabel('RMSE (m)')
    ax2.grid(color='black', linestyle='-', linewidth=0.5)
    ax2.legend()

    fig_tiles, (ax_rgb, ax_gt, ax_pred) = plt.subplots(1, 3, figsize=(16.0, 10.0))
    ax_rgb.imshow(rgb)
    ax_rgb.set_title('RGB')
    ax_gt.imshow(gt, cmap='viridis_r', vmin=0, vmax=1)
    ax_gt.set_title('Ground truth')
    ax_pred.imshow(pred, cmap='viridis_r', vmin=0, vmax=1)
    ax_pred.set_title('Prediction')
    fig_tiles.suptitle(header)

    if save_dir is not None:
        fig_rmse.savefig(os.path.join(save_dir, "validation_accuracy_" + tag))
        fig_tiles.savefig(os.path.join(save_dir, "train_images_" + tag))

    plt.show()
    # Close explicitly: figures otherwise accumulate for the whole training run.
    for figure in (fig_loss, fig_rmse, fig_tiles):
        plt.close(figure)


def train(net, optimizer, epochs, scheduler=None, save_epoch = 1000):
    """Train `net` on the module-level `train_loader` with the gap-weighted RMSE loss.

    Args:
        net: network to train (already on `device`).
        optimizer: optimizer over the parameters of `net`.
        epochs: number of passes over `train_loader`.
        scheduler: optional learning-rate scheduler, stepped once at the end of every
            epoch (a ReduceLROnPlateau scheduler receives the mean epoch loss).
        save_epoch: a checkpoint is written every `save_epoch` epochs.

    Side effects: creates `epoch_folder` if needed, writes checkpoints
    (`model_epoch<N>`, `model_final`) and, every 1000 iterations, progress figures
    into it; every 100 iterations the notebook output is replaced by fresh plots.
    """
    # Create the output folder up front so that neither the figures nor the final
    # checkpoint can fail on a missing directory after hours of training.
    os.makedirs(epoch_folder, exist_ok=True)

    criterion = CustomLossW().to(device)
    depth_scale = -norm_param_depth  # turns normalised errors into metres

    losses = []       # loss of every processed iteration
    mean_losses = []  # running mean over the last (up to) 100 iterations
    mean_rmse = []    # the same running mean expressed in metres
    iter_ = 0

    for e in range(1, epochs + 1):
        net.train()
        epoch_losses = []

        for batch_idx, (data, target, lidar) in enumerate(train_loader):
            data = data.to(device)
            lidar = lidar.to(device)
            # The SfM depths delivered by the loader are replaced by the LiDAR-derived
            # depths here; drop this line to train against the SfM depths instead.
            target = lidar
            optimizer.zero_grad()

            # A pixel is usable when it is annotated (depth != 0) and every image band
            # is non-zero there (mean of the per-band indicators >= 0.5 after the product).
            target_mask = (target != 0).float()
            data_mask = (data != 0).float().mean(dim=1)
            combined_mask = ((target_mask * data_mask) >= 0.5).float()

            # Nothing to learn from a batch without a single usable pixel
            if torch.sum(combined_mask) == 0:
                continue

            data = torch.clamp(data, min=0, max=1)
            target = torch.clamp(target, min=0, max=1)
            output = net(data.float())

            if torch.isnan(output).any():
                print("NaN values found in output!")
                print("Output:", output)
                continue  # Skip this batch

            loss = criterion(output, target.unsqueeze(1), combined_mask) * loss_weight

            # Check before back-propagating: stepping on a NaN loss would overwrite the
            # weights with NaN and ruin the run.
            if torch.isnan(loss):
                print("NaN loss detected. Investigate further.")
                continue

            loss.backward()
            optimizer.step()

            loss_value = loss.item()
            losses.append(loss_value)
            epoch_losses.append(loss_value)
            running_mean = float(np.mean(losses[-100:]))
            mean_losses.append(running_mean)
            mean_rmse.append(running_mean * depth_scale / loss_weight)

            if iter_ % 100 == 0:
                clear_output()
                # First sample of the batch, as NumPy arrays
                rgb = np.asarray(np.transpose(data.detach().cpu().numpy()[0], (1, 2, 0)), dtype='float32')
                pred_map = output.detach().cpu().numpy()[0, 0]
                gt_map = target.detach().cpu().numpy()[0]
                mask_map = combined_mask.cpu().numpy()[0]

                # metrics() already returns the RMSE in metres; it must not be rescaled again.
                rmse_m = metrics((pred_map * mask_map).ravel(), (gt_map * mask_map).ravel())
                print('Train (epoch {}/{}) [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tMean RMSE in m: {}'.format(
                    e, epochs, batch_idx, len(train_loader),
                    100. * batch_idx / len(train_loader), loss_value, rmse_m))

                header = 'Train (epoch {}/{}) [{}/{} ({:.0f}%)]\nLoss: {:.6f}'.format(
                    e, epochs, batch_idx, len(train_loader),
                    100. * batch_idx / len(train_loader), loss_value)
                save_figures = iter_ % 1000 == 0 and iter_ != 0
                _show_training_progress(
                    header, mean_losses, mean_rmse, rgb, gt_map, pred_map,
                    save_dir=epoch_folder if save_figures else None,
                    tag="{}_out_of_{}".format(e, epochs))

            iter_ += 1

            # Release the batch tensors before the next forward pass
            del data, target, lidar, output, loss

        # Learning-rate schedulers are stepped after the optimizer updates of the epoch.
        if scheduler is not None:
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                if epoch_losses:
                    scheduler.step(float(np.mean(epoch_losses)))
            else:
                scheduler.step()

        if e % save_epoch == 0:
            torch.save(net.state_dict(), os.path.join(epoch_folder, 'model_epoch{}'.format(e)))

    torch.save(net.state_dict(), os.path.join(epoch_folder, 'model_final'))

In [None]:
# Train for 10 epochs.
# NOTE(review): the `scheduler` created together with the optimizer is not passed
# here, so the learning rate stays at `base_lr` — confirm that this is intended.
train(net=net, optimizer=optimizer, epochs=10)


# Testing the network

Now that the training has ended, we can load the final weights and test the network using a reasonable stride, e.g. half or a quarter of the window size. Inference time depends on the chosen stride, e.g. a step size of 32 (75% overlap) will take 10 seconds / image.


In [None]:
# Load the final checkpoint. `map_location` remaps the stored tensors onto the current
# device, so weights saved on a GPU can also be loaded on a CPU-only machine.
net.load_state_dict(torch.load('/.../magicbathynet/epoch_folder/model_final', map_location=device))


# Saving the results

We can visualize and save the resulting tiles for qualitative assessment.


In [None]:
# Predict a (normalised) depth map for every test tile.
all_preds = test(net=net, test_ids=test_ids)

In [None]:
ratio = crop_size_t / WINDOW_SIZE[0]

for p, id_ in zip(all_preds, test_ids):
    # Back from normalised units to depths, and back to the original tile resolution
    img = p*norm_param_depth
    img = scipy.ndimage.zoom(img, (1/ratio, 1/ratio), order=1)

    plt.imshow(img, cmap='viridis')
    plt.title('Predicted depth, tile {}'.format(id_))
    plt.show()

    #io.imsave('/.../inference_tile_{}.png'.format(id_), img)
    # Georeference every prediction with the projection/geotransform of its own input
    # tile; only the dataset handle is needed, the band contents are discarded.
    _, reference_ds = read_geotiff(IMG_FOLDER.format(id_), 1)
    # float32 guarantees that the raster is written as Float32 rather than integers.
    write_geotiff('./inference_tile_{}.tif'.format(id_), img.astype(np.float32, copy=False), reference_ds)

In [18]:
# Hand cached, currently unused GPU memory back to the driver. The call involves no
# autograd work, so no `no_grad` context is required around it.
torch.cuda.empty_cache()