In [None]:
from src.dataloader import LandCoverData

import torch
import torchvision.transforms as T
from PIL import Image
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.colors import ListedColormap
import numpy as np

%load_ext autoreload
%autoreload 2

In [None]:
# Root directory passed to LandCoverData (a relative path string).
path = "../"

# Training split, constructed with no transforms.
dataset = LandCoverData(path, split="train", transforms=None)

In [None]:
# plot individual samples
from ipywidgets import widgets
from ipywidgets import interact
%matplotlib inline

names = [k for k in dataset.LABEL_CLASSES.keys()]
toPIL = T.ToPILImage()

style = {'description_width': 'initial'}
widget=widgets.BoundedIntText(
    value=0,
    min=0,
    max=len(dataset),
    step=1,
    description=f"Index data: (max={len(dataset)})",
    disabled=False,
    style=style
)

@interact(idx=widget)
def plot_sample(idx=0):
    """Plot sample ``idx`` of ``dataset``: the image beside its label map.

    The label panel gets a colorbar whose ticks name only the classes that
    occur in this sample. A per-class count is printed below the figure.

    Args:
        idx: Index into the module-level ``dataset`` (driven by ``widget``).
    """
    img, label = dataset[idx]
    # Invert name -> id so the class ids found in `label` can be named.
    class_mapping = {v: k for k, v in dataset.LABEL_CLASSES.items()}
    class_list, class_count = torch.unique(label, return_counts=True)

    # Second panel is slightly wider to leave room for the colorbar.
    fig, ax = plt.subplots(1, 2, figsize=(12, 10), gridspec_kw={'width_ratios': [1, 1.07]})
    ax[0].imshow(toPIL(img))
    pim = ax[1].imshow(toPIL(label), cmap=dataset.colormap, vmin=0, vmax=len(dataset.LABEL_CLASSES))

    # Dedicated colorbar axis appended to the right of the label panel.
    divider = make_axes_locatable(ax[1])
    cax = divider.append_axes("right", size="5%", pad=0.1)
    cbar = fig.colorbar(pim, cax=cax)
    # The +0.5 offset presumably centres each tick in its class's colour band
    # (assumes dataset.colormap has one colour per class over [0, n_classes]) — TODO confirm.
    cbar.set_ticks([i.item() + 0.5 for i in class_list])
    cbar.set_ticklabels([class_mapping[i.item()] for i in class_list])

    for panel, title in zip(ax, ("Image", "Label")):
        panel.axis("off")
        panel.set_title(title)

    plt.show()

    print("Label contains:")
    for cls, count in zip(class_list, class_count):
        print(f"   - {class_mapping[cls.item()]}: {count.item()} times.")