In [None]:
import torch
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from torch.nn import BCELoss
from torch.nn import MSELoss
from torch.optim import Adam
from torch.optim import SGD

import numpy as np
from tqdm import tqdm
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.colors as clr
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.colors import ListedColormap
import time
import datetime

from src.model import UNet
from src.dataloader import LandCoverData

%load_ext autoreload
%autoreload 2

In [None]:
# Load the trained UNet checkpoint onto the CPU and put it in inference mode.
import inspect

MODEL_PATH = '../unet_model_2022-12-22_11:12.pt'

# The checkpoint is a fully pickled model object (it is called directly as
# `unet(x)` later), not a state_dict. PyTorch >= 2.6 defaults torch.load to
# weights_only=True, which refuses to unpickle a whole module, so ask for full
# unpickling explicitly wherever the option exists (older versions lack it).
# SECURITY: full unpickling can execute arbitrary code -- only load trusted files.
_load_kwargs = {'map_location': torch.device('cpu')}
if 'weights_only' in inspect.signature(torch.load).parameters:
    _load_kwargs['weights_only'] = False
unet = torch.load(MODEL_PATH, **_load_kwargs)

# Inference only: use eval-mode behaviour for dropout / batch-norm layers.
unet.eval();

In [None]:
# Root directory handed to LandCoverData; presumably the "test" split is
# located relative to it (TODO confirm against src.dataloader).
MASK_DATASET_PATH = "../"
test_dataset = LandCoverData(
    split="test",
    transforms=None,
    path=MASK_DATASET_PATH,
)

In [None]:
# Colormap for the prediction / ground-truth panels in plot_sample.
# A continuous map is the default. To use the dataset's own class colours,
# set `cmap = test_dataset.colormap` instead (assumes LandCoverData exposes a
# `colormap` attribute -- TODO confirm); plot_sample then centres colorbar
# ticks on each class band.
cmap = 'viridis'

In [None]:
# plot individual samples
from ipywidgets import widgets
from ipywidgets import interact
%matplotlib inline

style = {'description_width': 'initial'}

# Valid dataset indices are 0 .. len-1 and BoundedIntText's `max` is
# inclusive, so cap at the last valid index (never below `min`).
last_index = max(len(test_dataset) - 1, 0)

# Selector for which test sample plot_sample displays.
widget = widgets.BoundedIntText(
    value=0,
    min=0,
    max=last_index,
    step=1,
    description=f"Index data: (max={last_index})",
    disabled=False,
    style=style,
)
# Toggle for drawing a class-labelled colorbar next to the ground truth.
widget_cb = widgets.Checkbox(
    value=False,
    description='Colorbar',
    disabled=False,
    indent=False,
)

@interact(idx=widget, flag_colorbar=widget_cb)
def plot_sample(idx=0, flag_colorbar=False):
    """Show one test sample: input image, UNet prediction and ground truth.

    Parameters
    ----------
    idx : int
        Index into ``test_dataset``.
    flag_colorbar : bool
        If True, attach a colorbar to the ground-truth panel whose ticks are
        labelled with the names of the classes present in that mask.
    """
    # Widen the third panel slightly so the colorbar does not shrink the mask.
    roffset = 1.07 if flag_colorbar else 1
    fig, ax = plt.subplots(1, 3, figsize=(12, 10), gridspec_kw={'width_ratios': [1, 1, roffset]})

    # Inverse of LABEL_CLASSES: class value -> class name, used for tick labels
    # (presumably LABEL_CLASSES maps name -> integer id; verify in dataloader).
    class_mapping = {v: k for k, v in test_dataset.LABEL_CLASSES.items()}
    n_classes = len(test_dataset.LABEL_CLASSES)

    x, y = test_dataset[idx]
    x = x.unsqueeze(dim=0)  # add a batch dimension of size 1

    # Inference only: eval mode (no-op if already set) and no autograd graph.
    unet.eval()
    with torch.no_grad():
        # Softmax is monotonic, so argmax over raw logits gives the same class.
        preds = torch.argmax(unet(x), dim=1)

    # assumes x is (1, C, H, W) and y is (1, H, W) -- TODO confirm
    img = np.transpose(np.array(x[0, :, :]), (1, 2, 0))
    preds = np.array(preds[0, :, :])
    mask = np.array(y[0, :, :])

    ax[0].imshow(img)
    ax[1].imshow(preds, cmap=cmap, vmin=0, vmax=n_classes)
    pim = ax[2].imshow(mask, cmap=cmap, vmin=0, vmax=n_classes)

    if flag_colorbar:
        class_list = torch.unique(y[0, :, :])
        divider = make_axes_locatable(ax[2])
        cax = divider.append_axes("right", size="5%", pad=0.1)
        cbar = fig.colorbar(pim, cax=cax)
        # When the dataset's own colormap is in use (presumably one discrete
        # colour per class, class i spanning [i, i+1]), centre ticks in each
        # band; otherwise place the tick on the class value itself.
        dataset_cmap = getattr(test_dataset, "colormap", None)
        tick_offset = 0.5 if dataset_cmap is not None and cmap == dataset_cmap else 0
        cbar.set_ticks([c.item() + tick_offset for c in class_list])
        cbar.set_ticklabels([class_mapping[c.item()] for c in class_list])

    for panel, title in zip(ax, ("Image", "Prediction", "Ground Truth")):
        panel.axis("off")
        panel.set_title(title)

    plt.show()