In [None]:
# Bathymetry prediction on aerial/satellite images using bathy-U-Net
Initial Pytorch Implementation: Panagiotis Agrafiotis (https://github.com/pagraf)
Email: agrafiotis.panagiotis@gmail.com

Description: magicbathy_unet.py is a simplified U-Net model modified for estimating water depth from RGB images. 
The model retains the encoder-decoder structure with reduced layers and channels, using skip connections to 
maintain spatial information during depth prediction. It outputs continuous values, suitable for depth estimation,
even with limited annotated data.

If you use this code please cite our paper: "Agrafiotis, P., Janowski, L., Skarlatos, D. & Demir, B. (2024) 
MagicBathyNet: A Multimodal Remote Sensing Dataset for Bathymetry Prediction and Pixel-based Classification 
in Shallow Waters, arXiv preprint arXiv:2405.15477, 2024."

The structure of this notebook was inspired by the [DeepNetsForEO](https://github.com/nshaud/DeepNetsForEO) repository.





Attribution-NonCommercial-ShareAlike 4.0 International License

Copyright (c) 2024 The MagicBathyNet Authors

This license requires that reusers give credit to the creator. It allows reusers 
to distribute, remix, adapt, and build upon the material in any medium or format,
for noncommercial purposes only. If others modify or adapt the material, they 
must license the modified material under identical terms.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


This work is part of MagicBathy project funded by the European Union’s HORIZON Europe research and innovation 
programme under the Marie Skłodowska-Curie GA 101063294. Work has been carried out at the Remote Sensing Image 
Analysis group. For more information about the project visit https://www.magicbathy.eu/.



In [None]:
## GPU
Enable GPU with `Runtime->Change runtime type->Hardware Accelerator->GPU` in the top menu

In [None]:
# imports and stuff
import time
import numpy as np
from skimage import io
from glob import glob
from tqdm import tqdm_notebook as tqdm
from sklearn.metrics import confusion_matrix, precision_score, recall_score
import random
import itertools

# Matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

# Torch imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
import torch.optim.lr_scheduler
import torch.nn.init
from torch.autograd import Variable

# Other imports
import rasterio
import gdal
import scipy
import cv2
from unet_bathy import *
import sys
from IPython.display import clear_output
import numpy.ma as ma
from torchvision.transforms import RandomCrop, Resize
from skimage.transform import resize, rotate, resize_local_mean
from sklearn.metrics import mean_squared_error
import scipy.ndimage 
import os
from skimage.transform import resize as skimage_resize

In [7]:
# Insert the path to your data folder here (keep the trailing slash: the
# templates below are built by plain string concatenation).
FOLDER = '/.../'
# NOTE(review): `from unet_bathy import *` runs in the imports cell, before this
# append, so on a fresh kernel unet_bathy must already be importable from the
# working directory / existing sys.path.
sys.path.append(FOLDER)

MAIN_FOLDER = FOLDER

# Filename templates; '{}' is replaced by the tile id.
DATA_FOLDER = f'{MAIN_FOLDER}.../img_{{}}.tif'      # RGB images
LABEL_FOLDER = f'{MAIN_FOLDER}.../depth_{{}}.tif'   # depth ground truth
ERODED_FOLDER = f'{MAIN_FOLDER}.../depth_{{}}.tif'  # eroded depth ground truth

# Parameters

In [None]:
# Parameters

# Select the dataset modality here: "S2", "SPOT6" or "UAV".
# This deliberately does not reuse the name `dataset`: the Dataset class defined
# further down is called `dataset`, so storing the modality under that name
# meant that re-running this cell after the class cell matched no modality and
# left every parameter undefined (or stale).
MODALITY = "SPOT6"

# Per-modality settings: (band normalisation file, training patch size, stride).
MODALITY_SETTINGS = {
    "UAV":   ('norm_param_aerial.npy',   (720, 720), 16),
    "SPOT6": ('norm_param_spot6_an.npy', (30, 30),   2),
    "S2":    ('norm_param_s2_an.npy',    (18, 18),   2),
}
if MODALITY not in MODALITY_SETTINGS:
    raise ValueError('Unknown modality {!r}; expected one of {}'.format(
        MODALITY, sorted(MODALITY_SETTINGS)))
norm_file, WINDOW_SIZE, STRIDE = MODALITY_SETTINGS[MODALITY]

norm_param = np.load(norm_file)
norm_param_depth = -30.443   # -30.443 for Agia Napa, -11 for Puck Lagoon
BATCH_SIZE = 1
MAIN_FOLDER = FOLDER

# Fixed MagicBathyNet benchmark split (identical for all modalities).
train_images = ['409', '418', '350', '399', '361', '430', '380', '359', '371', '377', '379', '360', '368', '419', '389', '420', '401', '408', '352', '388', '362', '421', '412', '351', '349', '390', '400', '378']
test_images = ['411', '387', '410', '398', '370', '369', '397']

print(norm_param)
print(norm_param.shape)

net = UNet_bathy(3, 1)
base_lr = 0.0001

CACHE = True # Store the dataset in-memory


# Visualizing the dataset

In [None]:
# Load one tile and its depth ground truth and display them side by side.
sample_id = '409'
img = io.imread(DATA_FOLDER.format(sample_id))
gt = io.imread(LABEL_FOLDER.format(sample_id))

# Same per-band min-max scaling as the training data; clipped to [0, 1] for display only.
norm_img = (img - norm_param[0]) / (norm_param[1] - norm_param[0])
# Same 1/norm_param_depth scaling that the Dataset applies to the labels.
scaled_gt = gt / norm_param_depth

fig = plt.figure()
fig.add_subplot(121)
plt.imshow(np.clip(norm_img, 0, 1))
plt.title('Image (normalised)')
fig.add_subplot(122)
plt.imshow(scaled_gt)
plt.title('Depth / norm_param_depth')
plt.show()

# Summarise the scaled depths instead of dumping the whole array.
print('Scaled depth range: [{:.3f}, {:.3f}]'.format(np.nanmin(scaled_gt), np.nanmax(scaled_gt)))

We need to define some utility functions.

In [10]:
# Utils

def get_random_pos(img, window_shape):
    """Choose a random window_shape patch lying fully inside the last two axes of img.

    Returns (x1, x2, y1, y2) so that img[..., x1:x2, y1:y2] has exactly
    window_shape. The row offset is drawn before the column offset.
    """
    win_rows, win_cols = window_shape
    n_rows, n_cols = img.shape[-2:]
    row_start = random.randint(0, n_rows - win_rows)
    col_start = random.randint(0, n_cols - win_cols)
    return row_start, row_start + win_rows, col_start, col_start + win_cols

def sliding_window(top, step=10, window_size=(20,20)):
    """Yield (x, y, w, h) windows over the first two axes of top.

    Offsets advance by `step`; a window that would overrun the border is
    shifted back so that it ends exactly on the border, which means the last
    windows along an axis may repeat.
    """
    n_rows, n_cols = top.shape[0], top.shape[1]
    win_h, win_w = window_size[0], window_size[1]
    for row in range(0, n_rows, step):
        row = min(row, n_rows - win_h)
        for col in range(0, n_cols, step):
            yield row, min(col, n_cols - win_w), win_h, win_w
            
def count_sliding_window(top, step=10, window_size=(20,20)):
    """Return how many windows sliding_window yields for the same arguments.

    Border clamping only moves windows, it never adds or removes any, so the
    count is the number of row offsets times the number of column offsets.
    `window_size` is accepted for signature parity with sliding_window.
    """
    row_offsets = range(0, top.shape[0], step)
    col_offsets = range(0, top.shape[1], step)
    return len(row_offsets) * len(col_offsets)

def grouper(n, iterable):
    """Yield consecutive tuples of up to n items from iterable (the last may be shorter)."""
    source = iter(iterable)
    # iter(callable, sentinel) keeps calling until an empty chunk comes back.
    yield from iter(lambda: tuple(itertools.islice(source, n)), ())

def metrics(predictions, gts, mask=None, depth_scale=None):
    """Print RMSE, MAE and error standard deviation in metres and return the RMSE.

    Args:
        predictions, gts: arrays of the same shape in normalised depth units
            (depth / norm_param_depth).
        mask: optional boolean array of the same shape; only elements where it
            is True are scored. Without it every element counts, so arrays that
            were zero-filled outside the annotation pull the scores towards 0.
        depth_scale: factor converting normalised units to metres; defaults to
            -norm_param_depth (the previous, hard-wired behaviour).

    Returns:
        RMSE in normalised units (multiply by depth_scale for metres), or NaN
        when there is nothing to score.
    """
    if depth_scale is None:
        depth_scale = -norm_param_depth

    errors = np.asarray(predictions) - np.asarray(gts)
    if mask is not None:
        errors = errors[np.asarray(mask, dtype=bool)]

    if errors.size == 0:
        # np.mean of an empty array would only warn and return NaN.
        print("No pixels to score")
        print("---")
        return float('nan')

    rmse = np.sqrt(np.mean(errors ** 2))
    mae = np.mean(np.abs(errors))
    std_dev = np.std(errors)

    print("RMSE : {:.3f}m".format(rmse * depth_scale))
    print("MAE : {:.3f}m".format(mae * depth_scale))
    print("Std_Dev : {:.3f}m".format(std_dev * depth_scale))
    print("---")

    return rmse

def read_geotiff(filename, b):
    """Read band `b` (1-based, GDAL convention) of a raster with GDAL.

    Returns (array, dataset). The dataset is returned so callers can reuse its
    projection / geotransform (e.g. for write_geotiff).

    Raises:
        IOError: if GDAL cannot open `filename`.
        ValueError: if the raster has no band `b`.
    """
    ds = gdal.Open(filename)
    if ds is None:
        # Without gdal.UseExceptions(), gdal.Open signals failure by returning None.
        raise IOError('GDAL could not open {}'.format(filename))
    band = ds.GetRasterBand(b)
    if band is None:
        raise ValueError('{} has no band {} (band count: {})'.format(filename, b, ds.RasterCount))
    arr = band.ReadAsArray()
    return arr, ds

def write_geotiff(filename, arr, in_ds):
    """Write a 2-D array as a single-band GeoTIFF georeferenced like `in_ds`.

    Projection and geotransform are copied from `in_ds`, so `arr` must cover
    the same pixel grid as `in_ds`. float64 arrays are stored as Float64,
    other floating arrays as Float32, everything else as Int32. Band
    statistics are computed and the file is closed before returning.

    Raises:
        ValueError: if `arr` is not 2-D.
        IOError: if GDAL cannot create `filename`.
    """
    arr = np.asarray(arr)
    if arr.ndim != 2:
        raise ValueError('write_geotiff expects a 2-D array, got shape {}'.format(arr.shape))

    if arr.dtype == np.float64:
        # Previously fell through to Int32, silently truncating the depths.
        arr_type = gdal.GDT_Float64
    elif np.issubdtype(arr.dtype, np.floating):
        arr = arr.astype(np.float32, copy=False)
        arr_type = gdal.GDT_Float32
    else:
        arr_type = gdal.GDT_Int32

    driver = gdal.GetDriverByName("GTiff")
    out_ds = driver.Create(filename, arr.shape[1], arr.shape[0], 1, arr_type)
    if out_ds is None:
        raise IOError('GDAL could not create {}'.format(filename))
    out_ds.SetProjection(in_ds.GetProjection())
    out_ds.SetGeoTransform(in_ds.GetGeoTransform())
    band = out_ds.GetRasterBand(1)
    band.WriteArray(arr)
    band.FlushCache()
    band.ComputeStatistics(False)
    # Flush and drop the references so GDAL closes the file now instead of
    # whenever the objects happen to be garbage-collected.
    out_ds.FlushCache()
    band = None
    out_ds = None


# Loading the dataset

We define a PyTorch dataset (`torch.utils.data.Dataset`) that samples random patches from the tiles. Tiles are read from disk on first use and, when caching is enabled, kept in memory afterwards.

The dataset also performs random data augmentation (horizontal and vertical flips) and min-max normalizes each band with the values stored in `norm_param`, which maps the data approximately to [0, 1].

In [11]:
# Dataset class

random.seed(1)  # seeds Python's `random`, which drives tile/patch sampling and flips below

class dataset(torch.utils.data.Dataset):
    """Training set that returns random patches cut from whole tiles.

    Every __getitem__ call ignores its index: it picks a random tile, cuts a
    random WINDOW_SIZE patch from it and, if augmentation is on, randomly
    flips/mirrors it. Tiles are read lazily and, with cache=True, kept in
    memory after the first read.

    Args:
        ids: tile ids substituted into the '{}' of the filename templates.
        data_files: filename template of the RGB images.
        label_files: filename template of the depth rasters.
        cache: keep loaded tiles in memory.
        augmentation: apply random flips/mirrors to each patch.

    Raises:
        KeyError: if an image or depth file does not exist (exception type kept
            for backward compatibility).
    """

    def __init__(self, ids, data_files=DATA_FOLDER, label_files=LABEL_FOLDER,
                            cache=False, augmentation=True):
        super(dataset, self).__init__()

        self.augmentation = augmentation
        self.cache = cache

        # List of files, built from the templates passed in (these arguments
        # used to be ignored in favour of the global templates).
        self.data_files = [data_files.format(id_) for id_ in ids]
        self.label_files = [label_files.format(id_) for id_ in ids]

        # Sanity check: raise an error if some files do not exist
        for f in self.data_files + self.label_files:
            if not os.path.isfile(f):
                raise KeyError('{} is not a file !'.format(f))

        # Per-tile caches, keyed by index into self.data_files
        self.data_cache_ = {}
        self.label_cache_ = {}

    def __len__(self):
        # Nominal epoch size: patches are drawn at random, so this only fixes
        # how many samples the DataLoader requests per epoch.
        return 10000

    @classmethod
    def data_augmentation(cls, *arrays, flip=True, mirror=True):
        """Apply the same random flip (rows) / mirror (columns) to every array.

        2-D arrays are (H, W); other arrays are treated as (C, H, W). Each
        result is a copy, so it is contiguous and safe for torch.from_numpy.
        """
        will_flip, will_mirror = False, False
        if flip and random.random() < 0.5:
            will_flip = True
        if mirror and random.random() < 0.5:
            will_mirror = True

        results = []
        for array in arrays:
            if will_flip:
                if len(array.shape) == 2:
                    array = array[::-1, :]
                else:
                    array = array[:, ::-1, :]
            if will_mirror:
                if len(array.shape) == 2:
                    array = array[:, ::-1]
                else:
                    array = array[:, :, ::-1]
            results.append(np.copy(array))

        return tuple(results)

    def __getitem__(self, i):
        # `i` is ignored: a random tile is drawn on every call.
        random_idx = random.randint(0, len(self.data_files) - 1)

        if random_idx in self.data_cache_:
            data = self.data_cache_[random_idx]
        else:
            # (H, W, C) image -> (C, H, W), min-max scaled per band with norm_param
            data = np.asarray(io.imread(self.data_files[random_idx]).transpose((2, 0, 1)), dtype='float32')
            band_min = norm_param[0][:, np.newaxis, np.newaxis]
            band_max = norm_param[1][:, np.newaxis, np.newaxis]
            data = (data - band_min) / (band_max - band_min)
            if self.cache:
                self.data_cache_[random_idx] = data

        if random_idx in self.label_cache_:
            label = self.label_cache_[random_idx]
        else:
            # Depths scaled by 1/norm_param_depth; train() and test() treat
            # scaled values > 0 as annotated pixels.
            label = 1/norm_param_depth * np.asarray(io.imread(self.label_files[random_idx]), dtype='float32')
            if self.cache:
                self.label_cache_[random_idx] = label

        x1, x2, y1, y2 = get_random_pos(data, WINDOW_SIZE)
        data_p = data[:, x1:x2, y1:y2]
        # NOTE(review): assumes the depth raster has the same height/width as
        # the image (the original notebook flagged this as SPOT6-only).
        label_p = label[x1:x2, y1:y2]

        if self.augmentation:
            data_p, label_p = self.data_augmentation(data_p, label_p)
        else:
            # Copy like data_augmentation does, so returned tensors never
            # share memory with the cached tiles.
            data_p, label_p = np.copy(data_p), np.copy(label_p)

        return (torch.from_numpy(data_p),
                torch.from_numpy(label_p))
        
        

# Load the Network on GPU

We can now instantiate the network using the specified parameters.

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    # Only initialise CUDA when it exists: torch.cuda.init() raises on
    # CPU-only installs, which defeated the CPU fallback above.
    torch.cuda.init()
net = net.to(device)


# Loading the data

We now create a train/test split. In our case, we specify a fixed train/test split for benchmarking MagicBathyNet.


In [None]:
# Load the datasets
# Tile ids available on disk, recovered from the '{}' placeholder of
# LABEL_FOLDER (e.g. 'depth_409.tif' -> '409'). Only needed for the optional
# random split below; the old split('area') parsing returned whole paths.
label_prefix, label_suffix = os.path.basename(LABEL_FOLDER).split('{}')
all_files = sorted(glob(LABEL_FOLDER.replace('{}', '*')))
all_ids = [name[len(label_prefix):len(name) - len(label_suffix)]
           for name in map(os.path.basename, all_files)]

# Random tile numbers for train/test split
# train_ids = random.sample(all_ids, 2 * len(all_ids) // 3 + 1)
# test_ids = list(set(all_ids) - set(train_ids))

# Fixed benchmark split defined in the parameters cell
train_ids = train_images
test_ids = test_images

print("Tiles for training : ", train_ids)
print("Tiles for testing : ", test_ids)

train_set = dataset(train_ids, cache=CACHE)
train_loader = torch.utils.data.DataLoader(train_set,batch_size=BATCH_SIZE)


# Designing the optimizer
We use `torch.optim.lr_scheduler.MultiStepLR` to divide the learning rate by 10 (`gamma=0.1`) once its milestone of 10 epochs is reached.


In [14]:
# Every parameter trains with the same learning rate, so Adam is built directly
# on net.parameters(); the former per-parameter group list had identical
# branches and was never passed to the optimizer.
optimizer = optim.Adam(net.parameters(), lr=base_lr)
# Multiply the learning rate by gamma=0.1 once the scheduler has been stepped
# 10 times (train() steps it once per epoch).
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [10], gamma=0.1)

In [15]:
crop_size = 256
pad_size = 32  # reflection padding (pixels) added around each zoomed tile before inference


def test(net, test_ids, all=True):
    """Run `net` on every test tile, display it and report depth errors.

    Each tile is normalised with norm_param, zoomed by crop_size / WINDOW_SIZE[0]
    (the same upsampling used in training), reflection-padded by pad_size,
    predicted in a single forward pass and un-padded. Errors are computed only
    on scored pixels: eroded depth (from ERODED_FOLDER) > 0 and image not zero
    in every band. Zero-filled unscored pixels no longer dilute RMSE/MAE.

    Args:
        net: the depth model; its output is squeezed to a 2-D map
            (presumably (1, 1, H, W) -- TODO confirm against UNet_bathy).
        test_ids: tile ids substituted into the filename templates.
        all: if True, also return per-tile predictions and eroded ground truths.

    Returns:
        rmse over the scored pixels of all tiles, in normalised units (multiply
        by -norm_param_depth for metres; None if test_ids is empty), or
        (rmse, all_preds, all_gts) when all=True.
    """
    # Images are normalised lazily (generator); depths are scaled by 1/norm_param_depth.
    test_tiles = ((np.asarray(io.imread(DATA_FOLDER.format(id_)), dtype='float32') - norm_param[0])
                  / (norm_param[1] - norm_param[0]) for id_ in test_ids)
    test_labels = [1 / norm_param_depth * np.asarray(io.imread(LABEL_FOLDER.format(id_)), dtype='float32')
                   for id_ in test_ids]
    # The eroded depths decide which pixels are scored (previously read from LABEL_FOLDER).
    eroded_labels = [1 / norm_param_depth * np.asarray(io.imread(ERODED_FOLDER.format(id_)), dtype='float32')
                     for id_ in test_ids]

    all_preds, all_gts = [], []
    valid_preds, valid_gts = [], []  # scored pixels of every tile seen so far
    rmse = None

    # Switch the network to inference mode
    net.eval()

    ratio = crop_size / WINDOW_SIZE[0]
    # Slice removing the padding again (also correct when pad_size == 0).
    unpad = slice(pad_size, -pad_size if pad_size else None)

    tiles = zip(test_ids, test_tiles, test_labels, eroded_labels)
    for id_, img, gt, gt_e in tqdm(tiles, total=len(test_ids), leave=False):
        img = scipy.ndimage.zoom(img, (ratio, ratio, 1), order=1)
        gt = scipy.ndimage.zoom(gt, (ratio, ratio), order=1)
        gt_e = scipy.ndimage.zoom(gt_e, (ratio, ratio), order=1)

        # Reflection padding reduces border artefacts. Only the image needs it:
        # the prediction is cropped back to the unpadded extent.
        padded = np.pad(img, ((pad_size, pad_size), (pad_size, pad_size), (0, 0)), mode='reflect')
        img_tensor = torch.from_numpy(np.ascontiguousarray(padded.transpose((2, 0, 1)))[np.newaxis])
        img_tensor = img_tensor.to(device)

        # Do the inference on the whole image
        with torch.no_grad():
            pred = net(img_tensor.float()).cpu().numpy().squeeze()
        pred = pred[unpad, unpad]

        # Display image | prediction | ground truth
        clear_output()
        fig = plt.figure()
        fig.add_subplot(1, 3, 1)
        plt.imshow(np.asarray(255 * np.clip(img, 0, 1), dtype='uint8'))
        fig.add_subplot(1, 3, 2)
        plt.imshow(pred)
        fig.add_subplot(1, 3, 3)
        plt.imshow(gt)
        plt.show()

        # Scored pixels: annotated eroded depth, image not zero in every band.
        valid = (gt_e > 0) & np.any(img != 0, axis=2)

        all_preds.append(pred)
        all_gts.append(gt_e)
        valid_preds.append(pred[valid])
        valid_gts.append(gt_e[valid])

        print('Tile {}:'.format(id_))
        metrics(pred[valid], gt_e[valid])
        print('All tiles so far:')
        rmse = metrics(np.concatenate(valid_preds), np.concatenate(valid_gts))

    # Returning all predictions and ground truths if 'all' is set to True
    if all:
        return rmse, all_preds, all_gts
    return rmse

# Training the network

In [16]:
crop_size = 256
epoch_folder = '/.../'  # where checkpoints and saved training figures are written


class CustomLoss(nn.Module):
    """Masked RMSE between a predicted and a reference depth map.

    loss = sqrt(sum(mask * (output - depth) ** 2) / sum(mask))

    `mask` holds per-pixel weights (0 drops a pixel) and must broadcast against
    output/depth. The result is NaN when mask sums to 0, so callers must skip
    such batches (train() does).
    """

    def __init__(self):
        super(CustomLoss, self).__init__()

    def forward(self, output, depth, mask):
        squared_error = F.mse_loss(output, depth, reduction='none')
        # Weighted sum of squared errors over unmasked elements
        weighted_sum = (squared_error * mask.float()).sum()
        return torch.sqrt(weighted_sum / mask.sum())


def _training_mask(data, target):
    """Per-pixel loss weights of shape (B, 1, H, W).

    Annotated depth (target > 0) multiplied by the fraction of image bands that
    are non-zero at that pixel, as in the original notebook.
    """
    target_mask = (target > 0).float()
    data_mask = (data != 0).float().mean(dim=1, keepdim=True)
    return target_mask * data_mask


def _masked_rmse(pred, gt, mask):
    """RMSE (normalised depth units) over elements where mask > 0; NaN if there are none."""
    errors = (pred - gt)[mask > 0]
    return float(np.sqrt(np.mean(errors ** 2))) if errors.size else float('nan')


def _show_training_progress(mean_losses, mean_rmse_plot, rgb, gt, masked_pred, title, save_tag=None):
    """Plot the loss/RMSE curves and an RGB | ground truth | prediction panel.

    When save_tag is given (e.g. "3_out_of_10"), the three figures are also
    written to epoch_folder with save_tag appended to their filenames.
    """
    fig1, ax1 = plt.subplots(figsize=(14.0, 8.0))
    ax1.plot(mean_losses, 'blue')
    ax1.set_title('Training Loss')
    ax1.set_xlabel('Iteration')
    ax1.set_ylabel('Loss')
    ax1.grid(color='black', linestyle='-', linewidth=0.5)

    fig2, ax2 = plt.subplots(figsize=(14.0, 8.0))
    ax2.plot(-norm_param_depth * np.asarray(mean_rmse_plot), 'red')
    ax2.set_title('Mean RMSE in m')
    ax2.set_xlabel('Iteration')
    ax2.set_ylabel('RMSE (m)')
    ax2.grid(color='black', linestyle='-', linewidth=0.5)

    fig3 = plt.figure(figsize=(14.0, 8.0))
    fig3.add_subplot(131)
    plt.imshow(rgb)
    plt.title('RGB')
    fig3.add_subplot(132)
    plt.imshow(gt[0, :, :], cmap='viridis_r', vmin=0, vmax=1)
    plt.title('Ground truth')
    fig3.add_subplot(133)
    plt.title('Prediction')
    plt.imshow(masked_pred[0, :, :], cmap='viridis_r', vmin=0, vmax=1)
    plt.suptitle(title)

    if save_tag is not None:
        # Save into the output folder (previously under the DATA_FOLDER filename template).
        os.makedirs(epoch_folder, exist_ok=True)
        fig1.savefig(os.path.join(epoch_folder, "train_loss_{}".format(save_tag)))
        fig2.savefig(os.path.join(epoch_folder, "validation_accuracy_{}".format(save_tag)))
        fig3.savefig(os.path.join(epoch_folder, "train_images_{}".format(save_tag)))
    plt.show()


def train(net, optimizer, epochs, scheduler=None, save_epoch = 15):
    """Train `net` with the masked RMSE loss on batches from train_loader.

    Each batch is upsampled to crop_size x crop_size. Pixels without depth or
    with an all-zero image are down-weighted or excluded by the mask, and
    batches with no usable pixel are skipped. Progress is plotted every 100
    iterations, and figures are saved to epoch_folder every 1000. The
    scheduler, if given, is stepped once per epoch, after that epoch's
    optimizer steps. Every `save_epoch` epochs the model is evaluated with
    test() and checkpointed; the final weights go to epoch_folder + 'model_final'.
    """
    criterion = CustomLoss()
    losses, mean_losses = [], []        # per-iteration loss and its running mean (last 100)
    rmse_plot, mean_rmse_plot = [], []  # per-iteration RMSE on scored pixels and its running mean
    iter_ = 0

    for e in range(1, epochs + 1):
        net.train()

        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()

            # Upsample the patch to the network input size. Both tensors end up
            # exactly crop_size x crop_size, so no further cropping is needed.
            # (The removed RandomCrop branch was dead, and it would have cropped
            # data and target at different offsets.)
            data = F.interpolate(data, size=(crop_size, crop_size), mode='nearest')
            # (B, H, W) -> (B, 1, H, W): add a channel axis. unsqueeze(0) turned
            # the batch axis into channels when BATCH_SIZE > 1.
            target = F.interpolate(target.unsqueeze(1), size=(crop_size, crop_size), mode='nearest')

            combined_mask = _training_mask(data, target)
            if torch.sum(combined_mask) == 0:
                continue  # no usable pixel in this patch

            data = torch.clamp(data, min=0, max=1)
            output = net(data.float())

            loss = criterion(output, target, combined_mask)
            loss.backward()
            optimizer.step()

            # Bookkeeping on the first sample of the batch
            pred = output.detach().cpu().numpy()[0]
            gt = target.detach().cpu().numpy()[0]
            mask = combined_mask.detach().cpu().numpy()[0]
            losses.append(loss.item())
            mean_losses.append(np.mean(losses[-100:]))
            rmse_plot.append(_masked_rmse(pred, gt, mask))
            mean_rmse_plot.append(np.nanmean(rmse_plot[-100:]))

            if iter_ % 100 == 0:
                clear_output()
                rgb = np.transpose(data.detach().cpu().numpy()[0], (1, 2, 0)).astype('float32')
                print('Train (epoch {}/{}) [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tRMSE in m: {}'.format(
                    e, epochs, batch_idx, len(train_loader),
                    100. * batch_idx / len(train_loader), loss.item(), -norm_param_depth * rmse_plot[-1]))
                title = 'Train (epoch {}/{}) [{}/{} ({:.0f}%)]\nLoss: {:.6f}'.format(
                    e, epochs, batch_idx, len(train_loader),
                    100. * batch_idx / len(train_loader), loss.item())
                save_tag = '{}_out_of_{}'.format(e, epochs) if iter_ % 1000 == 0 and iter_ != 0 else None
                _show_training_progress(mean_losses, mean_rmse_plot, rgb, gt, pred * mask, title, save_tag)

            iter_ += 1
            del data, target, loss

        if scheduler is not None:
            # Step after the epoch's optimizer steps, as PyTorch >= 1.1 expects.
            scheduler.step()

        if e % save_epoch == 0:
            os.makedirs(epoch_folder, exist_ok=True)
            # We validate with the largest possible stride for faster computing
            test(net, test_ids, all=False)
            torch.save(net.state_dict(), epoch_folder + 'model_epoch{}'.format(e))

    # The folder may not exist yet when epochs < save_epoch.
    os.makedirs(epoch_folder, exist_ok=True)
    torch.save(net.state_dict(), epoch_folder + 'model_final')

In [None]:
# Train for 10 epochs with the MultiStepLR scheduler defined above.
train(net, optimizer, epochs=10, scheduler=scheduler)


# Testing the network


In [None]:
# map_location remaps CUDA-saved tensors onto `device`, so checkpoints saved on
# a GPU also load on a CPU-only machine.
# NOTE(review): train() saves to epoch_folder + 'model_final', not to the
# working directory; point this path at wherever the checkpoint actually lives.
net.load_state_dict(torch.load('./model_final', map_location=device))

In [None]:
# Evaluate on the test tiles, keeping per-tile predictions and eroded ground truths.
_, all_preds, all_gts = test(net=net, test_ids=test_ids, all=True)
#print(all_preds)
#print(all_gts)


# Saving the results

We can visualize and save the resulting tiles for qualitative assessment.


In [None]:
# Write every test prediction back as a GeoTIFF at the native resolution of its
# input tile, georeferenced with that tile's own projection/geotransform.
# Previously every output borrowed the georeference of a hard-coded tile 410.
for pred, id_ in zip(all_preds, test_ids):
    # Undo the 1/norm_param_depth scaling applied to the labels.
    depth_map = pred * norm_param_depth

    _, ref_ds = read_geotiff(DATA_FOLDER.format(id_), 1)
    # Shrink from the crop_size inference grid back to the tile's own pixel grid,
    # so the array matches the georeference copied from ref_ds.
    zoom_factors = (ref_ds.RasterYSize / depth_map.shape[0], ref_ds.RasterXSize / depth_map.shape[1])
    depth_map = scipy.ndimage.zoom(depth_map, zoom_factors, order=1)

    plt.imshow(depth_map, cmap='viridis')
    plt.title('Predicted depth, tile {}'.format(id_))
    plt.colorbar()
    plt.show()

    write_geotiff('./inference_tile_{}.tif'.format(id_), depth_map, ref_ds)