In [None]:
import datetime
import os
import time

import matplotlib
import matplotlib.colors as clr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sn
import torch
from matplotlib.colors import ListedColormap
from mpl_toolkits.axes_grid1 import make_axes_locatable
from torch.nn import BCELoss
from torch.nn import CrossEntropyLoss
from torch.nn import MSELoss
from torch.optim import Adam
from torch.optim import SGD
from torch.utils.data import DataLoader
from torchmetrics.classification import MulticlassConfusionMatrix
from torchmetrics.classification import MulticlassF1Score
from tqdm import tqdm

from src.dataloader import LandCoverData
from src.model import UNet

%load_ext autoreload
%autoreload 2

In [None]:
# Load the trained UNet on the CPU for the interactive viewer below.
# Earlier checkpoints that can be swapped in for CHECKPOINT_PATH (state dicts):
#   'best_model_fl_4_loss.pth'
#   'unet_model_epoch_200_fl_2_loss.pth'
# ('unet_model_epoch_100_focal_loss_gamma_2_loss.pt' was saved as a whole model, so it
#  is loaded with `unet = torch.load(path, map_location=...)` instead of load_state_dict.)
CHECKPOINT_PATH = 'unet_model_2023-01-13_0:6_focal_loss_gamma_2_loss.pth'

# .eval() switches layers such as BatchNorm/Dropout (if the UNet has any) to inference
# behaviour; it returns the module itself, so it can be chained onto the constructor.
unet = UNet(nbClasses=8).eval()
unet.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=torch.device('cpu')))

In [None]:
# Root folder of the land-cover data. Set the MASK_DATASET_PATH environment variable
# to point elsewhere (e.g. "../" for a local checkout); otherwise the original cluster
# scratch location is used.
MASK_DATASET_PATH = os.environ.get("MASK_DATASET_PATH", "/scratch/izar/nkaltenr/")
test_dataset = LandCoverData(path=MASK_DATASET_PATH, transforms=None, split="test")

In [None]:
# Continuous colormap interpolated (256 levels) from the dataset's colour list.
# Alternatives tried: the built-in 'viridis', or test_dataset.colormap (plot_sample
# shifts colorbar ticks by 0.5 when that one is in use).
cmap = clr.LinearSegmentedColormap.from_list(
    name='custom_datacolor',
    colors=test_dataset.colormap_names,
    N=256,
)

In [None]:
# plot individual samples
from ipywidgets import widgets
from ipywidgets import interact
%matplotlib inline

style = {'description_width': 'initial'}
widget=widgets.BoundedIntText(
    value=0,
    min=0,
    max=len(test_dataset),
    step=1,
    description=f"Index data: (max={len(test_dataset)})",
    disabled=False,
    style=style
)
widget_cb=widgets.Checkbox(
    value=False,
    description='Colorbar',
    disabled=False,
    indent=False
)

@interact(idx=widget, flag_colorbar=widget_cb)
def plot_sample(idx=0, flag_colorbar=False):
    roffset=1
    if flag_colorbar:
        roffset=1.07
    fig, ax = plt.subplots(1,3, figsize=(12, 10), gridspec_kw={'width_ratios': [1, 1, roffset]})    
    
    class_mapping = {v: k for k, v in test_dataset.LABEL_CLASSES.items()}  
    
    x, y = test_dataset[idx]
    x=x.unsqueeze(dim=0)
    
    softmax = torch.nn.Softmax(dim=1)
    preds = torch.argmax(softmax(unet(x)),axis=1)
    img = np.transpose(np.array(x[0,:,:]),(1,2,0))
    preds = np.array(preds[0,:,:])
    mask = np.array(y[0,:,:])
    
    ax[0].imshow(img)
    ax[1].imshow(preds, cmap=cmap, vmin=0, vmax=len(test_dataset.LABEL_CLASSES))
    pim=ax[2].imshow(mask, cmap=cmap, vmin=0, vmax=len(test_dataset.LABEL_CLASSES))
    
    if flag_colorbar:
        class_list, _ = torch.unique(y[0,:,:], return_counts=True)
        divider = make_axes_locatable(ax[2])
        cax = divider.append_axes("right", size="5%", pad=0.1)
        cbar = fig.colorbar(pim, cax=cax, ax=ax.ravel().tolist())
        if cmap==test_dataset.colormap:
            cbar.set_ticks([i.item()+0.5 for i in class_list])
        else:
            cbar.set_ticks([i.item() for i in class_list])
        cbar.set_ticklabels([class_mapping[i.item()] for i in class_list])

    ax[0].axis("off")
    ax[1].axis("off")
    ax[2].axis("off")
    ax[0].set_title(f"Image")
    ax[1].set_title(f"Prediction")
    ax[2].set_title(f"Ground Truth")
    
    plt.show()

## Compute Metrics on the test dataset

In [None]:
# Tag embedded in the filenames of the figures saved below,
# e.g. confusion_matrix_focal_loss_gamma_2.png.
name_for_save = 'focal_loss_gamma_2'

### Confusion Matrix

In [None]:
# Evaluate the checkpoint on every sample of the test split and collect the predicted
# and ground-truth label maps (y_pred, y_true) for the metrics below.
# Fall back to the CPU when no GPU is visible, so the cell also runs on CPU-only machines.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
CHECKPOINT_PATH = 'unet_model_2023-01-13_0:6_focal_loss_gamma_2_loss.pth'

# .eval(): inference-mode behaviour for layers such as BatchNorm/Dropout (if the UNet has any).
unet = UNet(nbClasses=8).to(DEVICE).eval()
unet.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=torch.device(DEVICE)))

# Gather per-sample tensors in lists and concatenate once at the end: calling
# torch.cat inside the loop re-copies the growing tensor every iteration (quadratic work).
pred_chunks = []
true_chunks = []

# No gradients are needed for evaluation; skipping autograd saves memory and time.
with torch.no_grad():
    for x, labels in tqdm(test_dataset):
        x = x.to(DEVICE).unsqueeze(dim=0)  # add batch dimension
        # argmax over raw logits == argmax over softmax probabilities (softmax is monotonic).
        preds = torch.argmax(unet(x), dim=1)
        pred_chunks.append(preds)            # Save Prediction
        true_chunks.append(labels.to(DEVICE))  # Save Truth

assert pred_chunks, "test_dataset is empty: nothing to evaluate"
# Concatenated along dim 0 (one leading entry per sample); predictions keep their integer dtype.
y_pred = torch.cat(pred_chunks)
y_true = torch.cat(true_chunks)

In [None]:
# Class names, assumed to be in label-index order (classes[i] <-> label i) — verify
# against test_dataset.LABEL_CLASSES. They label both axes of the confusion matrix.
# NOTE(review): 'Wald' is German ("forest") while the other names are English.
classes = ('Grass and other', 'Wald',
           'Bushes and sparse forest', 'Water and wetlands',
           'Glaciers and permanent snow', 'Sparse rocks (rocks mixed with grass)',
           'Loose rocks, scree', 'Bed rocks')

# torchmetrics' confusion matrix: rows are ground-truth classes, columns are predictions.
# num_classes is tied to the name list so the two cannot drift apart.
metric = MulticlassConfusionMatrix(num_classes=len(classes)).to(DEVICE)
cf_matrix = metric(y_pred, y_true)
cf_matrix_np = cf_matrix.cpu().numpy()

df_cm = pd.DataFrame(cf_matrix_np, index=list(classes), columns=list(classes))
plt.figure(figsize=(12, 7))
ax_cm = sn.heatmap(df_cm, annot=True)
ax_cm.set_xlabel('Predicted class')
ax_cm.set_ylabel('True class')
ax_cm.set_title('Confusion matrix (pixel counts)')
# bbox_inches='tight' keeps the long class-name tick labels inside the saved PNG.
plt.savefig(f'confusion_matrix_{name_for_save}.png', bbox_inches='tight')

### Per-class Accuracy (recall: fraction of each class's ground-truth pixels that are predicted correctly)

In [None]:
# Per-class accuracy, i.e. recall: correctly predicted pixels of class i divided by all
# ground-truth pixels of class i (diagonal over row sums, since the confusion-matrix rows
# are the true classes). A class absent from the test ground truth has an all-zero row
# and gives NaN, which seaborn leaves as a blank cell.
per_class_acc = (cf_matrix.diag() / cf_matrix.sum(1)).cpu().numpy()
df_per_class_acc = pd.DataFrame([per_class_acc], columns=list(classes))

plt.figure(figsize=(12, 1))
# Fix the colour scale to [0, 1] so figures from different checkpoints are comparable.
sn.heatmap(df_per_class_acc, annot=True, vmin=0, vmax=1, yticklabels=False)
plt.title('Per-class accuracy (recall)')
# bbox_inches='tight' stops the class-name labels being cut off in this 1-inch-tall figure.
plt.savefig(f'per_class_accuracy_{name_for_save}.png', bbox_inches='tight')