In [None]:
from src.dataloader import LandCoverData

import torch
import torchvision.transforms as T
from PIL import Image
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.colors as clr
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.colors import ListedColormap
import numpy as np
from tqdm import tqdm

%load_ext autoreload
%autoreload 2

In [None]:
# Location handed to LandCoverData (presumably the dataset root — confirm).
path = "../"

# Training split, with no transforms applied to the samples.
dataset = LandCoverData(
    path,
    split="train",
    transforms=None,
)

In [None]:
# Colormap used by every plot below: a 256-step interpolation through the
# colours listed in dataset.colormap_names.
cmap = clr.LinearSegmentedColormap.from_list(
    name='custom_datacolor',
    colors=dataset.colormap_names,
    N=256,
)
# Alternatives that can be swapped in instead:
#cmap='viridis'
#cmap=dataset.colormap

In [None]:
# Total pixel count per class, accumulated over every label in the dataset.
# tt_class[k] is the number of pixels whose label value is k.
tt_class = np.zeros(len(dataset.LABEL_CLASSES))
for _, label in tqdm(dataset):
    present_classes, pixel_counts = torch.unique(label, return_counts=True)
    for class_idx, n_pixels in zip(present_classes.tolist(), pixel_counts.tolist()):
        tt_class[class_idx] += n_pixels

In [None]:
plt.style.use('ggplot')

# One horizontal bar per class, in the order LABEL_CLASSES lists them.
labels = list(dataset.LABEL_CLASSES.keys())
indexes = np.arange(len(labels))


def rescale(indexes):
    """Min-max normalise ``indexes`` to [0, 1] so they can be passed to the colormap."""
    lowest = np.min(indexes)
    highest = np.max(indexes)
    return (indexes - lowest) / (highest - lowest)


width = 0.8  # bar thickness; same value as barh's default ``height``

fig, ax = plt.subplots()
ax.barh(indexes, tt_class, height=width, color=cmap(rescale(indexes)))
ax.set_yticks(indexes)
ax.set_yticklabels(labels)
ax.set_xlabel('Pixels count', fontsize=16)
ax.set_ylabel('Class', fontsize=16)
ax.set_title('Barchart - Frequency of each class', fontsize=20)
plt.show()

In [None]:
# plot individual samples
from ipywidgets import widgets
from ipywidgets import interact
%matplotlib inline

names = [k for k in dataset.LABEL_CLASSES.keys()]
toPIL = T.ToPILImage()

style = {'description_width': 'initial'}
widget=widgets.BoundedIntText(
    value=0,
    min=0,
    max=len(dataset),
    step=1,
    description=f"Index data: (max={len(dataset)})",
    disabled=False,
    style=style
)
widget_cb=widgets.Checkbox(
    value=True,
    description='Colorbar',
    disabled=False,
    indent=False
)

@interact(idx=widget, flag_colorbar=widget_cb)
def plot_sample(idx=0, flag_colorbar=False):
    """Show one dataset sample: the image next to its label mask.

    Args:
        idx: Position of the sample in ``dataset``.
        flag_colorbar: When true, draw a colorbar beside the label mask with
            one tick per class present in that mask.

    After the figure is shown, prints every class present in the label
    together with its pixel count.
    """
    img, label = dataset[idx]
    # Invert LABEL_CLASSES (name -> id) so class ids can be turned back into names.
    class_mapping = {v: k for k, v in dataset.LABEL_CLASSES.items()}
    # Computed once; reused for the colorbar ticks and the printed summary.
    class_list, class_count = torch.unique(label, return_counts=True)
    class_ids = class_list.tolist()
    class_names = [class_mapping[class_id] for class_id in class_ids]

    # Give the label panel a little extra width when the colorbar is drawn.
    roffset = 1.07 if flag_colorbar else 1
    fig, ax = plt.subplots(1, 2, figsize=(12, 10), gridspec_kw={'width_ratios': [1, roffset]})

    ax[0].imshow(toPIL(img))
    # NOTE(review): vmax is the number of classes, not the largest class id. With
    # the interpolated `cmap` this may shift colours relative to the bar chart,
    # which normalises by the largest index — confirm which is intended.
    pim = ax[1].imshow(toPIL(label), cmap=cmap, vmin=0, vmax=len(dataset.LABEL_CLASSES))

    if flag_colorbar:
        divider = make_axes_locatable(ax[1])
        cax = divider.append_axes("right", size="5%", pad=0.1)
        # `ax=` is only consulted when no `cax` is given, so it is not passed here.
        cbar = fig.colorbar(pim, cax=cax)
        # Ticks are offset by half a step when cmap is dataset.colormap;
        # otherwise they sit directly on the class values.
        tick_shift = 0.5 if cmap == dataset.colormap else 0
        cbar.set_ticks([class_id + tick_shift for class_id in class_ids])
        cbar.set_ticklabels(class_names)

    ax[0].axis("off")
    ax[1].axis("off")
    ax[0].set_title("Image")
    ax[1].set_title("Label")

    plt.show()

    print("Label contains:")
    for class_name, n_pixels in zip(class_names, class_count.tolist()):
        print(f"   - {class_name}: {n_pixels} times.")

In [None]:
# DATA AUGMENTATION TEST
#
# Find the first sample whose label contains the target class and show it next
# to a rotated copy. Image and mask are rotated by the same angle.

# Water and wetlands is less represented, -> index=4
# NOTE(review): the note above says index 4 but the value used below is 3 —
# confirm which class is intended.

index_augmentation=3
rotation_angle = 156  # degrees, applied to both image and mask
toPIL = T.ToPILImage()

# No accumulator is reset here: tt_class holds the per-class pixel counts
# computed earlier and must stay intact so the bar chart can be re-run.
# enumerate replaces the hand-maintained counter; index_label still holds the
# position of the plotted sample once the loop breaks.
for index_label, (img, label) in enumerate(tqdm(dataset)):
    if not (label == index_augmentation).any():
        continue

    print(index_label)
    n_classes = len(dataset.LABEL_CLASSES)
    fig, ax = plt.subplots(1, 4, figsize=(8, 10))

    ax[0].imshow(toPIL(img))
    ax[0].set_title("Image")
    ax[1].imshow(toPIL(label), cmap=cmap, vmin=0, vmax=n_classes)
    ax[1].set_title("Label")

    rotated = toPIL(img).rotate(rotation_angle, expand=0)
    ax[2].imshow(rotated)
    ax[2].set_title("Rotated image")

    rotated = toPIL(label).rotate(rotation_angle, expand=0)
    ax[3].imshow(rotated, cmap=cmap, vmin=0, vmax=n_classes)
    ax[3].set_title("Rotated label")

    plt.show()
    break