In [None]:
# Pixel-based classification on aerial/satellite images of shallow waters

# Initial Pytorch Implementation: Panagiotis Agrafiotis (https://github.com/pagraf)
# Email: agrafiotis.panagiotis@gmail.com

# Description: For the pixel-based classification problem we used U-Net, a well-known FCN, and 
# SegFormer (B5-sized), which is a hierarchical Transformer encoder with an all-MLP decode head. 
# Both models have already proven their high performance on similar tasks. 

# If you use this code please cite our paper: "Agrafiotis, P., Janowski, L., Skarlatos, D. & Demir, B. (2024) 
# MagicBathyNet: A Multimodal Remote Sensing Dataset for Bathymetry Prediction and Pixel-based Classification 
# in Shallow Waters, arXiv preprint arXiv:2405.15477, 2024."

# The structure of this .ipynb was inspired by the [DeepNetsForEO](https://github.com/nshaud/DeepNetsForEO) repository.





# Attribution-NonCommercial-ShareAlike 4.0 International License

# Copyright (c) 2024 Panagiotis Agrafiotis

# This license requires that reusers give credit to the creator. It allows reusers 
# to distribute, remix, adapt, and build upon the material in any medium or format,
# for noncommercial purposes only. If others modify or adapt the material, they 
# must license the modified material under identical terms.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.


# This work is part of MagicBathy project funded by the European Union’s HORIZON Europe research and innovation 
# programme under the Marie Skłodowska-Curie GA 101063294. Work has been carried out at the Remote Sensing Image 
# Analysis group. For more information about the project visit https://www.magicbathy.eu/.

In [None]:
## GPU

#Enable GPU with `Runtime->Change runtime type->Hardware Accelerator->GPU` in the top menu

In [None]:
# Core imports (standard library)
import itertools
import os
import random
import sys
import time
from glob import glob

# Numerics
import numpy as np
import scipy  # Consider importing only required submodules
import scipy.ndimage  # explicit: the code calls scipy.ndimage.zoom
from tqdm.notebook import tqdm  # Updated for modern `tqdm`

# Image and plotting
from skimage import io
import matplotlib.pyplot as plt
%matplotlib inline

# Notebook helpers
from IPython.display import clear_output

# Torch imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader  # Use specific imports if needed
from torch.optim import Adam, lr_scheduler  # Import specific optimizers/schedulers
from torch.autograd import Variable

# Geospatial and other libraries
import rasterio
import gdal

# Semantic Segmentation Models
from semanticsegmentation.models import *
from semanticsegmentation.models.backbones import MiT
from semanticsegmentation.models.segformer import SegFormer

In [None]:

# Root of the local MagicBathyNet copy; the leading '...' is presumably a placeholder to edit before running.
MAIN_FOLDER = '.../MagicBathyNet/'
sys.path.append(MAIN_FOLDER)

# Path templates for the aerial RGB tiles and their ground-truth maps; '{}' is replaced by the tile id.
DATA_FOLDER = MAIN_FOLDER + 'agia_napa/img/aerial/img_{}.tif'
LABEL_FOLDER = MAIN_FOLDER + 'agia_napa/gts/aerial/gts_{}.tif'
# NOTE(review): same template as LABEL_FOLDER, so the "eroded" labels are read from the plain label files.
ERODED_FOLDER = MAIN_FOLDER + 'agia_napa/gts/aerial/gts_{}.tif'

# Parameters

In [None]:
# Parameters

# Imaging modality to train on.
# NOTE(review): the Dataset class defined later in this notebook is also called `dataset`,
# so after that cell has run this name no longer refers to the modality string.
#dataset = "S2"
#dataset = "SPOT6"
dataset = "UAV"

#area = "puck_lagoon"
area = "agia_napa"

# One entry per supported modality:
#   norm_file - .npy file with the normalisation parameters
#   window    - training patch size
#   stride    - sliding-window stride used at inference time
#   batch     - batch size
#   min_side  - patches smaller than this are up-scaled to this side length
MODALITY_SETTINGS = {
    "UAV":   {"norm_file": 'norm_param_aerial.npy',    "window": (384, 384), "stride": 32, "batch": 15, "min_side": 128},
    "SPOT6": {"norm_file": 'norm_param_spot6_an2.npy', "window": (16, 16),   "stride": 1,  "batch": 20, "min_side": 256},
    "S2":    {"norm_file": 'norm_param_s2_an.npy',     "window": (16, 16),   "stride": 2,  "batch": 20, "min_side": 256},
}

if dataset not in MODALITY_SETTINGS:
    raise ValueError('Unknown dataset {!r}; expected one of {}'.format(dataset, sorted(MODALITY_SETTINGS)))

_settings = MODALITY_SETTINGS[dataset]
norm_param = np.load(_settings["norm_file"])
WINDOW_SIZE = _settings["window"]
STRIDE = _settings["stride"]
BATCH_SIZE = _settings["batch"]

# The SPOT6 and S2 branches used to run `MAIN_FOLDER = FOLDER`, but no `FOLDER` is defined in the
# visible notebook, so selecting either modality raised a NameError. MAIN_FOLDER is left untouched.

# Fixed train/test split (the same tile ids are used for every modality).
train_images = ['409', '418', '350', '399', '361', '430', '380', '359', '371', '377', '379', '360', '368', '419', '389', '420', '401', '408', '352', '388', '362', '421', '412', '351', '349', '390', '400', '378']
test_images = ['411', '387', '410', '398', '370', '369', '397']

# Patches below the minimum side are zoomed up by `scale` before entering the network,
# and predictions are zoomed back by `inv_scale`.
if WINDOW_SIZE[0] < _settings["min_side"]:
    scale = _settings["min_side"] / WINDOW_SIZE[0]
    test_window = _settings["min_side"]
else:
    scale = 1
    test_window = WINDOW_SIZE[0]
inv_scale = 1 / scale

print(norm_param)

# Network and label set for the selected area. The last label is always the "ignored" class.
if area == "agia_napa":
    #net = SegFormer('MiT-B5', 5)
    net = UNet(3, 5)
    LABELS = ["poseidonia", "rock", "macroalgae", "sand", "ignored"] # Label names
elif area == "puck_lagoon":
    net = SegFormer('MiT-B5', 3)
    #net = UNet(3, 3)
    LABELS = ["sand", "poseidonia", "ignored"] # Label names
else:
    # Previously an unknown area fell through and failed later with a NameError on WEIGHTS.
    raise ValueError('Unknown area {!r}; expected "agia_napa" or "puck_lagoon"'.format(area))

N_CLASSES = len(LABELS) # Number of classes
print(N_CLASSES)

# Weights for class balancing: every real class gets `l`, the trailing "ignored" class gets 0
# so that it does not contribute to the loss.
l = 1.
WEIGHTS = torch.ones(N_CLASSES)
WEIGHTS[:-1] = l
WEIGHTS[-1] = 0.
print(WEIGHTS)


base_lr = 0.000001


CACHE = True # Store the dataset in-memory
epsilon = 1e-20  # Small positive value



# Visualizing the dataset

In [None]:
# Colour palette: class index -> RGB triplet, chosen per study area.

if area == "agia_napa":
    palette = dict(enumerate([
        (0, 128, 0),     # 0: poseidonia
        (0, 0, 255),     # 1: rock
        (255, 0, 0),     # 2: macroalgae
        (255, 128, 0),   # 3: sand
        (0, 0, 0),       # 4: undefined (black)
    ]))
elif area == "puck_lagoon":
    palette = dict(enumerate([
        (255, 128, 0),   # 0: sand
        (0, 128, 0),     # 1: poseidonia
        (0, 0, 0),       # 2: undefined (black)
    ]))


# Reverse lookup: RGB triplet -> class index
invert_palette = dict((rgb, index) for index, rgb in palette.items())

def convert_to_color(arr_2d, palette=palette):
    """ Paint a 2D array of class indices as an RGB uint8 image.

    Every pixel whose value is a key of `palette` receives the matching colour;
    pixels holding any other value stay black.
    """
    height, width = arr_2d.shape[0], arr_2d.shape[1]
    rgb = np.zeros((height, width, 3), dtype=np.uint8)

    for label, colour in palette.items():
        rgb[arr_2d == label] = colour

    return rgb

def convert_from_color(arr_3d, palette=invert_palette):
    """ Turn an RGB-encoded label image into a 2D uint8 array of class indices.

    `palette` maps RGB triplets to class indices. Pixels whose colour is not
    in the palette keep the initial value 0.
    """
    labels = np.zeros((arr_3d.shape[0], arr_3d.shape[1]), dtype=np.uint8)

    for colour, label in palette.items():
        reference = np.array(colour).reshape(1, 1, 3)
        same_colour = np.all(arr_3d == reference, axis=2)
        labels[same_colour] = label

    return labels

# Load one tile and its ground truth, then show them side by side
img = io.imread(MAIN_FOLDER+'agia_napa/img/aerial/img_410.tif')
gt = io.imread(MAIN_FOLDER+'agia_napa/gts/aerial/gts_410.tif')

# Min-max normalisation with the per-band parameters loaded earlier
norm_img = (img - norm_param[0]) / (norm_param[1] - norm_param[0]) 

fig = plt.figure()
for position, picture in ((121, norm_img), (122, gt)):
    fig.add_subplot(position)
    plt.imshow(picture)
plt.show()

# Check that the colour-coded ground truth converts to a numeric label array
array_gt = convert_from_color(gt)
print("Ground truth in numerical format has shape ({},{}) : \n".format(*array_gt.shape[:2]), array_gt)

We need to define a few utility functions.

In [5]:
# Utils

def get_random_pos(img, window_shape):
    """ Draw the bounds of a random window of shape `window_shape` inside `img`.

    The window is positioned over the last two axes of `img`.
    Returns (x1, x2, y1, y2) so that the patch is img[..., x1:x2, y1:y2].
    """
    win_rows, win_cols = window_shape
    n_rows, n_cols = img.shape[-2:]
    # randint is inclusive on both ends, so the window may touch the far border
    row_start = random.randint(0, n_rows - win_rows)
    col_start = random.randint(0, n_cols - win_cols)
    return row_start, row_start + win_rows, col_start, col_start + win_cols

def CrossEntropy2d(input, target, weight=None, reduction=True):
    """ 2D version of the cross entropy loss.

    Args:
        input: logits, either (N, C) or (N, C, H, W).
        target: class indices with one entry per logit vector
            ((N,) for 2D input, N*H*W entries for 4D input).
        weight: optional per-class weight tensor passed to F.cross_entropy.
        reduction: True (default) averages the loss and False sums it.
            A string ('mean', 'sum', 'none') is forwarded unchanged.
            Previously this argument was accepted but ignored and the loss
            was always averaged; the default behaviour is unchanged.

    Returns:
        The loss tensor produced by F.cross_entropy.

    Raises:
        ValueError: if `input` is neither 2D nor 4D.
    """
    if isinstance(reduction, str):
        mode = reduction
    else:
        mode = 'mean' if reduction else 'sum'

    dim = input.dim()
    if dim == 2:
        return F.cross_entropy(input, target, weight, reduction=mode)
    if dim == 4:
        # (N, C, H, W) -> (N*H*W, C): one row of logits per pixel
        n_classes = input.size(1)
        flat_logits = input.permute(0, 2, 3, 1).reshape(-1, n_classes)
        return F.cross_entropy(flat_logits, target.reshape(-1), weight, reduction=mode)
    raise ValueError('Expected 2 or 4 dimensions (got {})'.format(dim))

#def accuracy(input, target):
#    return 100 * float(np.count_nonzero(input == target)) / target.size

def _window_starts(length, window, step):
    """ Start offsets of the windows along one axis.

    Offsets advance by `step`. The first window that reaches (or passes) the
    end of the axis is clamped so that it ends exactly on the border, and
    iteration stops there, so no offset is produced twice. If the axis is
    shorter than the window the single offset 0 is returned.
    """
    starts = []
    for start in range(0, length, step):
        if start + window >= length:
            starts.append(max(length - window, 0))
            break
        starts.append(start)
    return starts


def sliding_window(top, step=10, window_size=(20,20)):
    """ Slide a window_size window across the image with a stride of step.

    Yields (x, y, w, h) tuples in row-major order. Unlike the previous
    implementation, the border window of each axis is yielded only once:
    before, every step past the border was clamped to the same offset, so
    the last row/column of windows was emitted (and predicted) repeatedly.
    """
    row_starts = _window_starts(top.shape[0], window_size[0], step)
    col_starts = _window_starts(top.shape[1], window_size[1], step)
    for x in row_starts:
        for y in col_starts:
            yield x, y, window_size[0], window_size[1]


def count_sliding_window(top, step=10, window_size=(20,20)):
    """ Count the number of windows sliding_window yields for the same arguments """
    n_rows = len(_window_starts(top.shape[0], window_size[0], step))
    n_cols = len(_window_starts(top.shape[1], window_size[1], step))
    return n_rows * n_cols

def grouper(n, iterable):
    """ Yield successive tuples of at most n items taken from iterable.

    The last tuple may be shorter; iteration ends at the first empty chunk.
    """
    source = iter(iterable)
    # Two-argument iter(): keep calling the lambda until it returns the empty tuple
    yield from iter(lambda: tuple(itertools.islice(source, n)), ())

def _confusion_counts(predictions, gts, n_classes):
    """ Dense n_classes x n_classes confusion matrix (rows: ground truth, columns: prediction).

    Built with numpy so that row/column i always refers to class i, even when
    a class is absent from the inputs. Pairs with a value outside
    [0, n_classes) are dropped.
    """
    truth = np.asarray(gts).ravel().astype(np.int64)
    guess = np.asarray(predictions).ravel().astype(np.int64)
    in_range = (truth >= 0) & (truth < n_classes) & (guess >= 0) & (guess < n_classes)
    pair_codes = truth[in_range] * n_classes + guess[in_range]
    counts = np.bincount(pair_codes, minlength=n_classes * n_classes)
    return counts.reshape(n_classes, n_classes)


def _accuracy_and_f1(cm, n_labels):
    """ Return (pixel count, overall accuracy in percent, per-class F1 array) for a confusion matrix """
    total = cm.sum()
    # epsilon keeps the division defined when no valid pixel was counted
    accuracy = np.trace(cm) * (100 / (float(total) + epsilon))

    f1 = np.zeros(n_labels)
    # Classes beyond the matrix size keep a score of 0
    for i in range(min(n_labels, cm.shape[0])):
        f1[i] = 2. * cm[i, i] / (np.sum(cm[i, :]) + np.sum(cm[:, i]) + epsilon)
    return total, accuracy, f1


def metrics(predictions, gts, label_values=LABELS[:-1]):
    """ Print the confusion matrix, accuracy, per-class F1 and kappa; return the accuracy.

    Args:
        predictions: flat array of predicted class indices.
        gts: flat array of ground-truth class indices.
        label_values: names of the evaluated classes, used for the F1 report.

    Returns:
        Overall accuracy in percent, computed over the first N_CLASSES-1
        classes only (the trailing "ignored" class is excluded).

    The previous version called `confusion_matrix`, which is not imported in
    the visible notebook, without fixing the label set: when a class was
    missing from the inputs the matrix shrank and the slice below no longer
    removed the "ignored" class.
    """
    cols_rows_to_use = N_CLASSES-1
    cm = _confusion_counts(predictions, gts, N_CLASSES)[:cols_rows_to_use, :cols_rows_to_use]

    print("Confusion matrix :")
    print(cm)

    print("---")

    total, accuracy, F1Score = _accuracy_and_f1(cm, len(label_values))
    print("{} pixels processed".format(total))
    print("Total accuracy : {}%".format(accuracy))

    print("---")

    print("F1Score :")
    for l_id, score in enumerate(F1Score):
        print("{}: {}".format(label_values[l_id], score))

    print("---")

    # Compute kappa coefficient; it is undefined (NaN) without pixels or when
    # the expected agreement is exactly 1
    kappa = float('nan')
    if total > 0:
        pa = np.trace(cm) / float(total)
        pe = np.sum(np.sum(cm, axis=0) * np.sum(cm, axis=1)) / float(total) ** 2
        if pe != 1:
            kappa = (pa - pe) / (1 - pe)
    print("Kappa: " + str(kappa))
    return accuracy


def metrics_2(predictions, gts, label_values=LABELS[:-1]):
    """ Silent variant of metrics(): return only the overall accuracy in percent.

    `label_values` is kept for signature compatibility; it does not influence
    the returned value.
    """
    cols_rows_to_use = N_CLASSES-1
    cm = _confusion_counts(predictions, gts, N_CLASSES)[:cols_rows_to_use, :cols_rows_to_use]
    _, accuracy, _ = _accuracy_and_f1(cm, len(label_values))
    return accuracy

def read_geotiff(filename, b):
    """ Read band `b` of a raster file with GDAL.

    Returns:
        (arr, ds): the band contents as an array and the opened GDAL dataset.

    Raises:
        IOError: if GDAL cannot open the file (gdal.Open returned None).
        ValueError: if the band cannot be fetched.
    """
    ds = gdal.Open(filename)
    if ds is None:
        raise IOError('GDAL could not open {}'.format(filename))
    band = ds.GetRasterBand(b)
    if band is None:
        raise ValueError('Band {} is not available in {}'.format(b, filename))
    arr = band.ReadAsArray()
    return arr, ds

def write_geotiff(filename, arr, in_ds):
    """ Write a 2D array as a single-band GeoTIFF georeferenced like `in_ds`.

    Args:
        filename: output path.
        arr: 2D array; float32 is stored as Float32, any other floating type
            as Float64 (it used to be stored as Int32, dropping the
            fractional part), everything else as Int32.
        in_ds: GDAL dataset whose projection and geotransform are copied.

    Raises:
        IOError: if the output dataset cannot be created.
    """
    if arr.dtype == np.float32:
        arr_type = gdal.GDT_Float32
    elif np.issubdtype(arr.dtype, np.floating):
        arr_type = gdal.GDT_Float64
    else:
        arr_type = gdal.GDT_Int32

    driver = gdal.GetDriverByName("GTiff")
    # GDAL expects (x size, y size), i.e. (columns, rows)
    out_ds = driver.Create(filename, arr.shape[1], arr.shape[0], 1, arr_type)
    if out_ds is None:
        raise IOError('GDAL could not create {}'.format(filename))
    out_ds.SetProjection(in_ds.GetProjection())
    out_ds.SetGeoTransform(in_ds.GetGeoTransform())
    band = out_ds.GetRasterBand(1)
    band.WriteArray(arr)
    band.FlushCache()
    band.ComputeStatistics(False)
    # Flush the dataset explicitly so the file is complete when the function returns
    out_ds.FlushCache()


# Loading the dataset

We define a PyTorch dataset (`torch.utils.data.Dataset`) that loads the tiles in memory and performs random sampling. Tiles are stored in memory on the fly, the first time they are sampled (when caching is enabled).

The dataset also performs random data augmentation (horizontal and vertical flips) and normalizes the data in [0, 1].

In [None]:
# Dataset class
from torchvision.transforms import Resize

random.seed(1)
        
class dataset(torch.utils.data.Dataset):
    """ Random-patch dataset over a fixed list of image/label tiles.

    Each item is a random WINDOW_SIZE crop of a randomly chosen tile, so the
    index passed to __getitem__ is not used. Images are min-max normalised
    with `norm_param`; labels are converted from colours to class indices.

    Args:
        ids: tile identifiers inserted into the path templates.
        data_files: path template of the image tiles ('{}' is the tile id).
        label_files: path template of the label tiles.
        cache: keep every tile in memory after it has been read once.
        augmentation: apply random flips/mirrors to each patch.
        fill_zeros: replace patch values equal to 0 by 0.5. The original code
            attempted this inside `if dataset == "UAV"` style checks, but the
            name `dataset` is bound to this class, so the checks never matched
            and the substitution never ran. It is therefore off by default,
            which keeps the effective behaviour unchanged.

    Raises:
        KeyError: if one of the image or label files does not exist.
    """

    def __init__(self, ids, data_files=DATA_FOLDER, label_files=LABEL_FOLDER,
                            cache=False, augmentation=True, fill_zeros=False):
        super().__init__()

        self.augmentation = augmentation
        self.cache = cache
        self.fill_zeros = fill_zeros

        # List of files (built from the templates passed in; these arguments used to be ignored)
        self.data_files = [data_files.format(tile_id) for tile_id in ids]
        self.label_files = [label_files.format(tile_id) for tile_id in ids]

        # Sanity check : raise an error if some files do not exist
        # (KeyError is kept as the exception type for backward compatibility)
        for f in self.data_files + self.label_files:
            if not os.path.isfile(f):
                raise KeyError('{} is not a file !'.format(f))

        # Initialize cache dicts (tile index -> array)
        self.data_cache_ = {}
        self.label_cache_ = {}

    def __len__(self):
        # Default epoch size is 10 000 samples
        return 10000

    @classmethod
    def data_augmentation(cls, *arrays, flip=True, mirror=True):
        """ Apply the same random vertical flip / horizontal mirror to all arrays.

        2D arrays are flipped over their two axes, other arrays over their
        last two axes. Each transform is applied with probability 0.5.
        Returns a tuple of copies, one per input array.
        """
        will_flip = flip and random.random() < 0.5
        will_mirror = mirror and random.random() < 0.5

        results = []
        for array in arrays:
            if will_flip:
                if len(array.shape) == 2:
                    array = array[::-1, :]
                else:
                    array = array[:, ::-1, :]
            if will_mirror:
                if len(array.shape) == 2:
                    array = array[:, ::-1]
                else:
                    array = array[:, :, ::-1]
            # Copy so the result owns contiguous memory (the slices above are reversed views)
            results.append(np.copy(array))

        return tuple(results)

    def _load_data(self, idx):
        """ Return the normalised (C, H, W) image of tile `idx`, reading it if not cached """
        if idx in self.data_cache_:
            return self.data_cache_[idx]
        data = np.asarray(io.imread(self.data_files[idx]).transpose((2,0,1)), dtype='float32')
        low = norm_param[0][:, np.newaxis, np.newaxis]
        high = norm_param[1][:, np.newaxis, np.newaxis]
        data = (data - low) / (high - low)
        if self.cache:
            self.data_cache_[idx] = data
        return data

    def _load_label(self, idx):
        """ Return the (H, W) int64 label map of tile `idx`, reading it if not cached """
        if idx in self.label_cache_:
            return self.label_cache_[idx]
        # Labels are converted from RGB to their numeric values
        label = np.asarray(convert_from_color(io.imread(self.label_files[idx])), dtype='int64')
        if self.cache:
            self.label_cache_[idx] = label
        return label

    def __getitem__(self, i):
        # Pick a random image
        random_idx = random.randint(0, len(self.data_files) - 1)
        data = self._load_data(random_idx)
        label = self._load_label(random_idx)

        # Random crop, identical for image and label
        x1, x2, y1, y2 = get_random_pos(data, WINDOW_SIZE)
        data_p = data[:, x1:x2,y1:y2]
        label_p = label[x1:x2,y1:y2]

        if self.fill_zeros:
            # Work on a copy so the cached tile is never modified
            data_p = np.copy(data_p)
            data_p[data_p == 0] = 0.5

        if self.augmentation:
            data_p, label_p = self.data_augmentation(data_p, label_p)

        # Nearest-neighbour rescaling to the size expected by the network (a copy when scale == 1)
        data_p = scipy.ndimage.zoom(data_p, (1,scale,scale), order=0)
        label_p = scipy.ndimage.zoom(label_p, scale, order=0)

        return (torch.from_numpy(data_p),
                torch.from_numpy(label_p))

# Load the Network on GPU

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    # Only touch the CUDA runtime when it exists; an unconditional init fails on CPU-only machines
    torch.cuda.init()
net = net.to(device)
torch.cuda.is_available()


# Loading the data

We now create a train/test split. In our case, we specify a fixed train/test split for benchmarking MagicBathyNet.


In [None]:
# Load the datasets
# Tile ids found on disk, parsed from the 'gts_<id>.tif' file names
all_files = sorted(glob(LABEL_FOLDER.replace('{}', '*')))
all_ids = [os.path.splitext(os.path.basename(f))[0].split('gts_')[-1] for f in all_files]

# Fixed benchmark split
train_ids = train_images
test_ids = test_images

missing_ids = sorted(set(train_ids + test_ids) - set(all_ids))
if missing_ids:
    print("Warning: no label file found for tiles : ", missing_ids)

print("Tiles for training : ", train_ids)
print("Tiles for testing : ", test_ids)


train_set = dataset(train_ids, cache=CACHE)
train_loader = torch.utils.data.DataLoader(train_set,batch_size=BATCH_SIZE)


# Designing the optimizer
We use `torch.optim.lr_scheduler` to divide the learning rate by 10 once the milestone epoch configured in the code below is reached.

In [9]:

# Plain SGD with momentum over all network parameters, every one of them at base_lr.
# (The per-parameter `params` list that used to be built here was never passed to the
# optimizer and assigned the same learning rate in both branches, so it was removed.)
optimizer = torch.optim.SGD(net.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0005)
# We define the scheduler: multiply the learning rate by gamma at each milestone epoch
scheduler = lr_scheduler.MultiStepLR(optimizer, [200], gamma=0.1)

In [10]:
def test(net, test_ids, all=False, stride=STRIDE, batch_size=BATCH_SIZE, window_size=(test_window,test_window), output_dir=None):
    """ Run sliding-window inference on the test tiles and report metrics.

    Args:
        net: network to evaluate; inference runs on the device holding its parameters.
        test_ids: tile ids inserted into the DATA/LABEL/ERODED path templates.
        all: when True, also return the per-tile predictions and labels.
        stride: sliding-window stride.
        batch_size: number of windows per forward pass.
        window_size: (height, width) of the windows fed to the network.
        output_dir: directory receiving one result figure per tile. When None
            (default) no figure is written. The previous code saved to
            DATA_FOLDER + "/test_image_result", but DATA_FOLDER is a file
            name template, not a directory.

    Returns:
        The accuracy over all tiles, or (accuracy, all_preds, all_gts) when `all` is True.

    Raises:
        ValueError: if `test_ids` is empty.
    """
    if len(test_ids) == 0:
        raise ValueError('test_ids is empty: nothing to evaluate')

    # Lazily loaded inputs: normalised images, colour labels (for display) and numeric labels (for metrics)
    images = ((np.asarray(io.imread(DATA_FOLDER.format(tile_id)), dtype='float32') - norm_param[0]) / (norm_param[1] - norm_param[0]) for tile_id in test_ids)
    labels = (np.asarray(io.imread(LABEL_FOLDER.format(tile_id)), dtype='uint8') for tile_id in test_ids)
    eroded_labels = (convert_from_color(io.imread(ERODED_FOLDER.format(tile_id))) for tile_id in test_ids)
    all_preds = []
    all_gts = []

    run_device = next(net.parameters()).device

    # Switch the network to inference mode
    net.eval()

    for tile_id, img, gt, gt_e in tqdm(zip(test_ids, images, labels, eroded_labels), total=len(test_ids), leave=False):
        # Bring image and display labels to the resolution the network works at
        gt = scipy.ndimage.zoom(gt, (scale, scale) + (1,) * (gt.ndim - 2), order=0)
        img = scipy.ndimage.zoom(img, (scale,scale,1), order=0)
        img = np.clip(img, 0, 1)
        # Accumulator of the class scores of all overlapping windows
        pred = np.zeros(img.shape[:2] + (N_CLASSES,))

        total = count_sliding_window(img, step=stride, window_size=window_size) // batch_size
        # no_grad replaces the removed Variable(volatile=True): no autograd graph is kept during inference
        with torch.no_grad():
            for i, coords in enumerate(tqdm(grouper(batch_size, sliding_window(img, step=stride, window_size=window_size)), total=total, leave=False)):
                # Display in progress results
                if i > 0 and total > 10 and i % int(10 * total / 100) == 0:
                    _pred = np.argmax(pred, axis=-1)
                    fig = plt.figure()
                    fig.add_subplot(1,3,1)
                    plt.imshow(np.asarray(255 * img, dtype='uint8'))
                    fig.add_subplot(1,3,2)
                    plt.imshow(convert_to_color(_pred))
                    fig.add_subplot(1,3,3)
                    plt.imshow(gt)
                    clear_output()
                    plt.show()
                    plt.close(fig)

                # Build the tensor: one (C, h, w) patch per window
                image_patches = np.asarray([np.copy(img[x:x+w, y:y+h]).transpose((2,0,1)) for x,y,w,h in coords])
                image_patches = torch.from_numpy(image_patches).float().to(run_device)

                # Do the inference
                outs = net(image_patches)
                #outs = net(image_patches)['out'] # Use for ResNet and pytorch ready models
                outs = outs.cpu().numpy()

                # Fill in the results array
                for out, (x, y, w, h) in zip(outs, coords):
                    pred[x:x+w, y:y+h] += out.transpose((1,2,0))
                del(outs)

        # Back to the original resolution, then pick the best class per pixel
        pred = scipy.ndimage.zoom(pred, (inv_scale,inv_scale,1), order=0)
        pred = np.argmax(pred, axis=-1)

        # Display the result
        clear_output()
        fig = plt.figure(figsize=(14.0, 8.0))
        fig.add_subplot(1,3,1)
        plt.imshow(np.asarray(255* img, dtype='uint8'))
        plt.title('RGB Image')
        fig.add_subplot(1,3,2)
        plt.imshow(convert_to_color(pred))
        plt.title('Prediction')
        fig.add_subplot(1,3,3)
        plt.imshow(gt)
        plt.title('Ground Truth')
        plt.show()
        if output_dir is not None:
            os.makedirs(output_dir, exist_ok=True)
            fig.savefig(os.path.join(output_dir, "test_image_result_{}".format(tile_id)))
        plt.close(fig)

        all_preds.append(pred)
        all_gts.append(gt_e)

        clear_output()
        # Per-tile metrics
        metrics(pred.ravel(), gt_e.ravel())

    # Metrics over all tiles, computed once after the loop
    accuracy = metrics(np.concatenate([p.ravel() for p in all_preds]), np.concatenate([p.ravel() for p in all_gts]).ravel())
    if all:
        return accuracy, all_preds, all_gts
    else:
        return accuracy

# Training the network


In [11]:
from IPython.display import clear_output
import numpy.ma as ma


def train(net, optimizer, epochs, scheduler=None, weights=WEIGHTS, save_epoch = 200, output_dir=None):
    """ Train `net` on the global `train_loader`, plotting progress and saving checkpoints.

    Args:
        net: network to train; batches are moved to the device holding its parameters.
        optimizer: optimizer updating the parameters of `net`.
        epochs: number of passes over `train_loader`.
        scheduler: optional learning-rate scheduler, stepped once per epoch
            after that epoch's optimizer updates.
        weights: per-class weights of the cross-entropy loss.
        save_epoch: every `save_epoch` epochs the network is evaluated with test() and checkpointed.
        output_dir: folder for checkpoints and figures. None keeps the
            previous hard-coded location. The folder is created up front;
            before, it was only created at a checkpoint epoch, so the final
            save could fail when epochs < save_epoch.

    Side effects:
        Sets the module-level `epoch_folder`, writes 'model_epoch<e>' and
        'model_final' there, and saves progress figures every 1000 iterations.
        Figures used to be written under DATA_FOLDER, which is a file name
        template rather than a directory.
    """
    global epoch_folder
    epoch_folder = '/home/pagraf/Desktop/magicbathy/' if output_dir is None else output_dir
    os.makedirs(epoch_folder, exist_ok=True)

    run_device = next(net.parameters()).device
    weights = weights.to(run_device)

    # One slot per training iteration instead of fixed 10^7-10^8 element buffers
    total_iters = max(1, epochs * len(train_loader))
    losses = np.zeros(total_iters)
    mean_losses = np.zeros(total_iters)
    accuracies_plot = np.zeros(total_iters)
    mean_accuracies_plot = np.zeros(total_iters)

    iter_ = 0

    for e in range(1, epochs + 1):
        net.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            data = torch.clamp(data.to(run_device), min=0, max=1)
            target = target.to(run_device)
            optimizer.zero_grad()

            output = net(data.float())
            loss = CrossEntropy2d(output, target, weight=weights)
            loss.backward()
            optimizer.step()

            # Running means over the last 100 iterations, current one included
            # (the previous slice excluded it and was empty, hence NaN, at iteration 0)
            recent = slice(max(0, iter_ - 99), iter_ + 1)
            losses[iter_] = loss.item()
            mean_losses[iter_] = np.mean(losses[recent])

            # Accuracy is tracked on the first sample of the batch only
            pred = np.argmax(output.detach().cpu().numpy()[0], axis=0)
            gt = target.cpu().numpy()[0]
            sample_accuracy = metrics_2(pred.ravel(), gt.ravel())

            accuracies_plot[iter_] = sample_accuracy / 100
            mean_accuracies_plot[iter_] = np.mean(accuracies_plot[recent])

            if iter_ % 100 == 0:
                save_figures = iter_ % 1000 == 0 and iter_ != 0
                clear_output()
                rgb = np.asarray(255 * np.transpose(data.detach().cpu().numpy()[0],(1,2,0)), dtype='uint8')

                print('Train (epoch {}/{}) [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {}'.format(
                    e, epochs, batch_idx, len(train_loader),
                    100. * batch_idx / len(train_loader), loss.item(), sample_accuracy))

                # Plot loss
                fig1, ax1 = plt.subplots(figsize=(14.0, 8.0))
                ax1.plot(mean_losses[:iter_ + 1], 'blue')
                ax1.set_title('Training Loss')
                ax1.set_xlabel('Iteration')
                ax1.set_ylabel('Loss')
                ax1.grid(color='black', linestyle='-', linewidth=0.5)

                # Plot accuracy (measured on training patches, so it is labelled as training accuracy)
                fig2, ax2 = plt.subplots(figsize=(14.0, 8.0))
                ax2.plot(mean_accuracies_plot[:iter_ + 1], 'orange')
                ax2.set_title('Training Accuracy')
                ax2.set_xlabel('Iteration')
                ax2.set_ylabel('Accuracy')
                ax2.grid(color='black', linestyle='-', linewidth=0.5)

                plt.show()

                if save_figures:
                    fig1.savefig(os.path.join(epoch_folder, "train_loss_{}_out_of_{}".format(e, epochs)))
                    # File name kept from the previous version for compatibility
                    fig2.savefig(os.path.join(epoch_folder, "validation_accuracy_{}_out_of_{}".format(e, epochs)))
                plt.close(fig1)
                plt.close(fig2)

                # Plot images
                fig = plt.figure(figsize=(14.0, 8.0))
                fig.add_subplot(131)
                # NOTE(review): `dataset` names the Dataset class once that cell has run, so this
                # comparison with the modality string is not expected to match - confirm and fix upstream.
                if dataset == "S2":
                    rgb = rgb[:, :, ::-1]
                plt.imshow(rgb)
                plt.title('RGB')
                fig.add_subplot(132)
                plt.imshow(convert_to_color(gt))
                plt.title('Ground truth')
                fig.add_subplot(133)
                plt.title('Prediction')
                plt.imshow(convert_to_color(pred))
                plt.suptitle('Train (epoch {}/{}) [{}/{} ({:.0f}%)]\nLoss: {:.6f}\nAccuracy: {}'.format(
                    e, epochs, batch_idx, len(train_loader),
                    100. * batch_idx / len(train_loader), loss.item(), sample_accuracy))
                plt.show()

                if save_figures:
                    fig.savefig(os.path.join(epoch_folder, "train_images_{}_out_of_{}".format(e, epochs)))
                plt.close(fig)

            iter_ += 1

            del(data, target, loss)

        # Step the scheduler after the epoch's optimizer updates, as PyTorch >= 1.1 expects
        if scheduler is not None:
            scheduler.step()

        if e % save_epoch == 0:
            # We validate with the largest possible stride for faster computing
            test(net, test_ids, all=False, stride=min(WINDOW_SIZE))
            torch.save(net.state_dict(), os.path.join(epoch_folder, 'model_epoch{}'.format(e)))
    torch.save(net.state_dict(), os.path.join(epoch_folder, 'model_final'))

In [None]:
# Train for 100 epochs with the optimizer and scheduler defined earlier in the notebook
train(net, optimizer, 100, scheduler)


# Testing the network


In [None]:
# Load the trained weights; map_location lets a checkpoint saved on GPU load on a CPU-only machine.
# NOTE(review): train() saves 'model_final' under its output folder, not './' - confirm this path.
net.load_state_dict(torch.load('./model_final', map_location=device))


In [None]:
# Evaluate on the test tiles with a stride of 32; all=True also returns the per-tile predictions and labels
_, all_preds, all_gts = test(net, test_ids, all=True, stride=32)


# Saving the results

We can visualize and save the resulting tiles for qualitative assessment.


In [None]:
# Folder receiving the colour-coded prediction tiles; created if it does not exist yet
inference_folder = '/home/pagraf/Desktop/magicbathy/'
os.makedirs(inference_folder, exist_ok=True)

for p, id_ in zip(all_preds, test_ids):
    img = convert_to_color(p)
    plt.imshow(img)
    plt.show()
    io.imsave(inference_folder + 'inference_tile_{}.png'.format(id_), img)
    #nlcd02_arr_1, nlcd02_ds_1 = read_geotiff(MAIN_FOLDER + 'puck_laggon/img/s2/img_3051.tif', 3)
    #write_geotiff('./inference_tile_{}.tif'.format(id_), img[1], nlcd02_ds_1)