# Evaluation Notebook

## Setup

In [None]:
!pip install pandas seaborn torch torchvision torchsummary torchtext pytorch_lightning tensorboard matplotlib tqdm datetime time 

## Download Data

Download the dataset as described in the repository README and place it in the `data/ipeo_data/` folder.

## Your Plots and Results

In [None]:
from src.dataloader import LandCoverData

import torchvision.transforms as T
from PIL import Image
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.colors as clr
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.colors import ListedColormap
import numpy as np
from tqdm import tqdm
import os

import torch
from torch.utils.data import DataLoader

from torchmetrics.classification import MulticlassF1Score
from torchmetrics.classification import MulticlassConfusionMatrix
from torchvision.models.segmentation.deeplabv3 import DeepLabHead
from torchvision.models.segmentation import deeplabv3_resnet101

import time
import datetime

from src.model import UNet
from src.dataloader import LandCoverData, transformsNorm

import seaborn as sn
import pandas as pd

%load_ext autoreload
%autoreload 2

# 1. Data augmentation

## If you want to see the result of our data augmentation, run the following cells.
## Otherwise, jump to part 2, "Test model".

In [None]:
"""

YOU NEED TO DOWNLOAD THE DATA IN THE ipeo_data/ folder.
SEE README OF THE REPO GITHUB.

"""

In [None]:
# Dataset root, followed by the two folders holding the augmented RGB tiles
# and the augmented label masks.
path, path_augmented_rgb, path_augmented_label = (
    "data/ipeo_data/",
    "data/ipeo_data/augmented_data_rgb/",
    "data/ipeo_data/augmented_data_label/",
)

In [None]:
# Training split loaded with use_augmented=False and no transforms; the
# augmentation cell below iterates over it.
dataset2 = LandCoverData(
    path,
    split="train",
    transforms=None,
    ignore_last_number=11,
    use_augmented=False,
)

In [None]:
# Offline augmentation: every training tile containing at least one of the class
# indexes 1..5 is rotated by 90, 180 and 270 degrees, and the rotated image and
# label are written to the augmented-data folders.
# (Original note: Water and wetlands is less represented -> index=4.)
AUGMENTED_CLASS_INDEXES = set(range(1, 6))
ROTATION_STEPS = range(1, 4)  # i -> rotation of i * 90 degrees

# PIL does not create missing folders when saving, so make sure they exist.
os.makedirs(path_augmented_rgb, exist_ok=True)
os.makedirs(path_augmented_label, exist_ok=True)

toPIL = T.ToPILImage()

# The output file names depend only on the tile index and the rotation step, so
# one pass over the dataset writes exactly the same files as one pass per class.
# The global `path` is deliberately left untouched: it is reused afterwards to
# reload the dataset.
for index_label, (img, label) in enumerate(tqdm(dataset2)):
    classes_in_tile = set(torch.unique(label).tolist())
    if classes_in_tile.isdisjoint(AUGMENTED_CLASS_INDEXES):
        continue

    img_pil = toPIL(img)
    label_pil = toPIL(label)
    for i in ROTATION_STEPS:
        rotation_deg = i * 90
        name_image = f"{index_label}_{i}_rgb"
        name_label = f"{index_label}_{i}_label"

        # Image and label receive the same rotation so that they stay aligned.
        img_pil.rotate(rotation_deg, expand=0).save(path_augmented_rgb + name_image + ".tif")
        label_pil.rotate(rotation_deg, expand=0).save(path_augmented_label + name_label + ".tif")

In [None]:
# Same training split as before, this time loaded with use_augmented=True.
dataset_final = LandCoverData(
    path,
    split="train",
    transforms=None,
    ignore_last_number=11,
    use_augmented=True,
)

In [None]:
# Accumulate, over the whole dataset, the number of pixels carrying each class index.
tt_class = np.zeros(len(dataset_final.LABEL_CLASSES))
for _, label in tqdm(dataset_final):
    class_list, class_count = torch.unique(label, return_counts=True)
    for class_index, pixel_count in zip(class_list, class_count):
        tt_class[class_index] += pixel_count

In [None]:
# Horizontal bar chart of the pixel count of each class.
plt.style.use('ggplot')

labels = list(dataset_final.LABEL_CLASSES.keys())
indexes = np.arange(len(labels))

# Build the colormap in this cell. The previous version read a global `cmap`
# that is only created further down, so it failed on a fresh kernel.
# Assumes `dataset_final` exposes `colormap_names` like the other LandCoverData
# instance used in this notebook -- TODO confirm.
class_cmap = clr.LinearSegmentedColormap.from_list('custom_datacolor', dataset_final.colormap_names, N=256)


def rescale(indexes):
    """Min-max scale `indexes` to [0, 1]; returns zeros when all values are equal."""
    span = np.max(indexes) - np.min(indexes)
    if span == 0:
        return np.zeros(len(indexes))
    return (indexes - np.min(indexes)) / span


width = 0.8  # thickness of the bars (the `height` argument of barh)
plt.barh(indexes, tt_class, height=width, color=class_cmap(rescale(indexes)))
plt.yticks(indexes, labels)
plt.xlabel('Pixels count', fontsize=16)
plt.ylabel('Class', fontsize=16)
plt.title('Barchart - Frequency of each class',fontsize=20)
plt.show()

# 2. Test model

In [None]:
"""

YOU NEED TO DOWNLOAD THE MODEL IN THE checkpoints/ folder.
SEE README OF THE REPO GITHUB.

"""

PATH_MODEL = 'checkpoints/best_model_acc_cross_entropy_Batch_32_loss.pth'

# Build the 8-class U-Net and restore its weights from the checkpoint, mapped to the CPU.
unet = UNet(nbClasses=8)

# Alternative, to use the DeepLabV3 pretrained model instead:
#unet = deeplabv3_resnet101(pretrained=True, progress=True)
#unet.classifier = DeepLabHead(2048, 8)

checkpoint_weights = torch.load(PATH_MODEL, map_location=torch.device('cpu'))
unet.load_state_dict(checkpoint_weights)

# Switch to inference mode (also the value displayed by the cell).
unet.eval()

In [None]:
"""

YOU NEED TO DOWNLOAD THE TEST DATA IN THE ipeo_data/ folder.
SEE README OF THE REPO GITHUB.

"""

# Root folder of the dataset.
path = "data/ipeo_data/"

# transformsNorm returns the transform handed to the dataset and a second
# callable, `unnormalize`, which is applied to the input before it is displayed.
transformsData, unnormalize = transformsNorm(flag_plot=True)

test_dataset = LandCoverData(
    split="test",
    transforms=transformsData,
    path=path,
)

In [None]:
cmap = clr.LinearSegmentedColormap.from_list('custom_datacolor', test_dataset.colormap_names, N=256)
#cmap='viridis'
#cmap=dataset.colormap

# plot individual samples
from ipywidgets import widgets
from ipywidgets import interact
%matplotlib inline

style = {'description_width': 'initial'}
widget=widgets.BoundedIntText(
    value=0,
    min=0,
    max=len(test_dataset),
    step=1,
    description=f"Index data: (max={len(test_dataset)})",
    disabled=False,
    style=style
)
widget_cb=widgets.Checkbox(
    value=False,
    description='Colorbar',
    disabled=False,
    indent=False
)

@interact(idx=widget, flag_colorbar=widget_cb)
def plot_sample(idx=0, flag_colorbar=False):
    """Show one test sample: input image, model prediction and ground truth.

    Args:
        idx: index of the sample in `test_dataset`.
        flag_colorbar: when True, a colorbar is attached to the ground-truth
            panel; its ticks are the class names present in that mask.
    """
    # The last panel is made slightly wider when it has to host the colorbar.
    roffset = 1.07 if flag_colorbar else 1
    fig, ax = plt.subplots(1, 3, figsize=(12, 10), gridspec_kw={'width_ratios': [1, 1, roffset]})

    # index -> class name
    class_mapping = {v: k for k, v in test_dataset.LABEL_CLASSES.items()}

    x, y = test_dataset[idx]
    x = x.unsqueeze(dim=0)

    # Make the function independent of what other cells did to the global model:
    # force inference mode and run on whatever device the weights live on.
    unet.eval()
    model_device = next(unet.parameters()).device
    softmax = torch.nn.Softmax(dim=1)
    with torch.no_grad():  # plain inference, no autograd graph needed
        # For DeepLabV3 use: logits = unet(x.to(model_device))['out']
        logits = unet(x.to(model_device))
        preds = torch.argmax(softmax(logits), axis=1).cpu()

    x = unnormalize(x)
    img = np.transpose(np.array(x[0, :, :]), (1, 2, 0))  # channels-first -> channels-last for imshow
    preds = np.array(preds[0, :, :])
    mask = np.array(y[0, :, :])

    n_classes = len(test_dataset.LABEL_CLASSES)
    ax[0].imshow(img)
    ax[1].imshow(preds, cmap=cmap, vmin=0, vmax=n_classes)
    pim = ax[2].imshow(mask, cmap=cmap, vmin=0, vmax=n_classes)

    if flag_colorbar:
        class_list = torch.unique(y[0, :, :])
        divider = make_axes_locatable(ax[2])
        cax = divider.append_axes("right", size="5%", pad=0.1)
        cbar = fig.colorbar(pim, cax=cax, ax=ax.ravel().tolist())
        # With the dataset's own colormap the ticks are shifted by half a class.
        if cmap == test_dataset.colormap:
            cbar.set_ticks([i.item() + 0.5 for i in class_list])
        else:
            cbar.set_ticks([i.item() for i in class_list])
        cbar.set_ticklabels([class_mapping[i.item()] for i in class_list])

    for axis, title in zip(ax, ("Image", "Prediction", "Ground Truth")):
        axis.axis("off")
        axis.set_title(title)

    plt.show()

# Confusion Matrix

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

#unet=unet.to(DEVICE)
unet = UNet(nbClasses=8).to(DEVICE)
unet.load_state_dict(torch.load(PATH_MODEL, map_location=torch.device(DEVICE)))
# A freshly built module starts in training mode; switch to inference mode
# before predicting.
unet.eval()

softmax = torch.nn.Softmax(dim=1)

# Collect the per-sample results in lists and concatenate once at the end,
# instead of re-allocating a growing tensor at every iteration.
pred_batches = []
true_batches = []

# iterate over test data
with torch.no_grad():  # evaluation only, no autograd graph needed
    for x, labels in tqdm(test_dataset):
        (x, labels) = (x.to(DEVICE), labels.to(DEVICE))
        x = x.unsqueeze(dim=0)

        preds = torch.argmax(softmax(unet(x)), axis=1)
        pred_batches.append(preds)           # Save Prediction
        true_batches.append(labels.long())   # Save Truth (as class indexes)

# Both tensors hold integer class indexes, stacked along the first dimension.
y_pred = torch.cat(pred_batches) if pred_batches else torch.zeros(0, dtype=torch.long).to(DEVICE)
y_true = torch.cat(true_batches) if true_batches else torch.zeros(0, dtype=torch.long).to(DEVICE)

In [None]:
# constant for classes
classes = (
    'Grass and other',
    'Wald',
    'Bushes and sparse forest',
    'Water and wetlands',
    'Glaciers and permanent snow',
    'Sparse rocks (rocks mixed with grass)',
    'Loose rocks, scree',
    'Bed rocks',
)

# Confusion matrix over all collected predictions, without normalisation.
metric = MulticlassConfusionMatrix(num_classes=8, normalize='none').to(DEVICE)
cf_matrix = metric(y_pred, y_true)
cf_matrix_np = cf_matrix.cpu().numpy()

# Label both axes of the matrix with the class names, then draw it as an annotated heatmap.
class_names = list(classes)
df_cm = pd.DataFrame(cf_matrix_np, index=class_names, columns=class_names)

plt.figure(figsize=(12, 9))
sn.heatmap(df_cm, annot=True)
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.tight_layout()

In [None]:
# Per-class accuracy: diagonal of the confusion matrix divided by its row sums
# (dimension 1 summed out). The counts are converted to float explicitly so the
# ratio never depends on how the installed torch divides integer tensors.
# A class whose row is empty yields 0/0 = NaN.
cf_counts = cf_matrix.float()
per_class_acc = (cf_counts.diag() / cf_counts.sum(1)).cpu().numpy()

# Single-row heatmap, one column per class.
plt.figure(figsize = (12,1))
df_per_class_acc = pd.DataFrame([per_class_acc], columns = list(classes))
sn.heatmap(df_per_class_acc, annot=True)
# One layout pass once the axes exist (the earlier call on the empty figure did nothing useful).
plt.tight_layout()