In [None]:
from src.dataloader import LandCoverData

import torch
import torchvision.transforms as T
from PIL import Image
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.colors as clr
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.colors import ListedColormap
import numpy as np
from tqdm import tqdm
import os

# Reload imported project modules automatically before each cell executes,
# so edits to the local `src` package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# 1. DATA ANALYSIS

In [None]:
# Root directory handed to LandCoverData. The default is the original cluster
# location; set the LANDCOVER_DATA_ROOT environment variable (e.g. to "../")
# to run the notebook elsewhere without editing this cell.
path = os.environ.get("LANDCOVER_DATA_ROOT", "/scratch/izar/damiani/")

# transforms=None is passed explicitly here.
# NOTE(review): presumably this disables the loader's default transforms so the
# raw images are analysed — confirm in src.dataloader.
dataset = LandCoverData(path, transforms=None, split="train")

In [None]:
# Continuous 256-step colormap interpolated between the entries of
# dataset.colormap_names (presumably one colour per class — confirm in
# src.dataloader). Used for the label images and bar charts in this notebook.
cmap = clr.LinearSegmentedColormap.from_list('custom_datacolor', dataset.colormap_names, N=256)
# Alternative colormaps kept for experimentation:
#cmap='viridis'
#cmap=dataset.colormap

In [None]:
# Accumulate, over every sample of `dataset`, how many label pixels carry each
# class index. tt_class[k] ends up holding the total pixel count of class k.
tt_class = np.zeros(len(dataset.LABEL_CLASSES))
for _, label in tqdm(dataset):
    class_list, class_count = torch.unique(label, return_counts=True)
    # class_list and class_count are aligned, so walk them pairwise.
    for c, n_pixels in zip(class_list, class_count):
        tt_class[c] += n_pixels

In [None]:
# Horizontal bar chart of the per-class pixel totals stored in tt_class.
plt.style.use('ggplot')

labels = list(dataset.LABEL_CLASSES.keys())
indexes = np.arange(len(labels))


def rescale(indexes):
    """Min-max scale `indexes` into [0, 1] so they can be passed to the colormap."""
    lowest, highest = np.min(indexes), np.max(indexes)
    return (indexes - lowest) / (highest - lowest)


width = 0.8  # defined but not passed to barh below
bar_colors = cmap(rescale(indexes))
plt.barh(indexes, tt_class, color=bar_colors)
plt.yticks(indexes, labels)
plt.title('Barchart - Frequency of each class', fontsize=20)
plt.xlabel('Pixels count', fontsize=16)
plt.ylabel('Class', fontsize=16)
plt.show()

In [None]:
# Share of all counted label pixels that belongs to each class
# (tt_class normalised by its sum).
total_pixels = tt_class.sum()
frac_class = list(tt_class / total_pixels)
print(frac_class)

In [None]:
# plot individual samples
from ipywidgets import widgets
from ipywidgets import interact
%matplotlib inline

names = [k for k in dataset.LABEL_CLASSES.keys()]
toPIL = T.ToPILImage()

style = {'description_width': 'initial'}
widget=widgets.BoundedIntText(
    value=0,
    min=0,
    max=len(dataset),
    step=1,
    description=f"Index data: (max={len(dataset)})",
    disabled=False,
    style=style
)
widget_cb=widgets.Checkbox(
    value=True,
    description='Colorbar',
    disabled=False,
    indent=False
)

@interact(idx=widget, flag_colorbar=widget_cb)
def plot_sample(idx=0, flag_colorbar=False):
    """Display sample `idx` of the global `dataset`: image on the left, label map on the right.

    Args:
        idx: position of the sample in `dataset`.
        flag_colorbar: when True, append a colorbar to the label panel whose
            ticks are the classes present in this label.

    Side effects:
        Shows a matplotlib figure and prints the pixel count of every class
        found in the label.
    """
    img, label = dataset[idx]
    # Invert LABEL_CLASSES (name -> index) to look up names by class index.
    class_mapping = {v: k for k, v in dataset.LABEL_CLASSES.items()}
    class_list, class_count = torch.unique(label, return_counts=True)

    # The label panel gets a slightly larger width ratio when a colorbar is
    # appended to it (presumably to compensate for the space the colorbar takes).
    roffset = 1.07 if flag_colorbar else 1
    fig, ax = plt.subplots(1, 2, figsize=(12, 10), gridspec_kw={'width_ratios': [1, roffset]})

    ax[0].imshow(toPIL(img))
    pim = ax[1].imshow(toPIL(label), cmap=cmap, vmin=0, vmax=len(dataset.LABEL_CLASSES))

    if flag_colorbar:
        divider = make_axes_locatable(ax[1])
        cax = divider.append_axes("right", size="5%", pad=0.1)
        # Only `cax` is passed: matplotlib ignores `ax` once `cax` is given.
        cbar = fig.colorbar(pim, cax=cax)
        # Ticks are shifted by half a unit when the dataset's own colormap is in use.
        tick_shift = 0.5 if cmap == dataset.colormap else 0
        cbar.set_ticks([i.item() + tick_shift for i in class_list])
        cbar.set_ticklabels([class_mapping[i.item()] for i in class_list])

    ax[0].axis("off")
    ax[1].axis("off")
    ax[0].set_title("Image")
    ax[1].set_title("Label")

    plt.show()

    print("Label contains:")
    for i, v in enumerate(class_list):
        print(f"   - {class_mapping[v.item()]}: {class_count[i]} times.")

# 1.2 Sanity check in the training dataset (remove all black images)

In [None]:
# Scan the training split and print the index of every image whose pixel
# values sum to zero, i.e. an all-black image (assuming non-negative pixels).
# The previous version also reset `tt_class` to zeros here, which discarded the
# class counts computed earlier without using them; that reset is removed.
for index_img, (img, _) in enumerate(tqdm(dataset)):
    if img.sum() == 0:
        print(index_img)

### Only the last 11 training images are all black

In [None]:
# Rebuild the training split with ignore_last_number=11 and use_augmented=False
# (presumably this drops the 11 trailing all-black images and leaves out the
# augmented files — confirm in src.dataloader), then show the resulting size.
dataset = LandCoverData(path, transforms=None, split="train", ignore_last_number=11, use_augmented=False)
print(len(dataset))

# 2. PREPROCESSING

## 2.1 Compute mean per channel (R,G,B)

In [None]:
def compute_mean(dataset):
    """Average the per-image channel means over every sample of `dataset`.

    Args:
        dataset: sized iterable of (image, label) pairs; each image is a tensor
            whose first dimension has three entries (the accumulator has length 3).

    Returns:
        NumPy array of length 3: the sum of the per-image means over dims 1 and 2,
        divided by len(dataset). The result is also printed.
    """
    channel_totals = np.zeros(3)

    for image, _ in tqdm(dataset):
        # Reduce dims 1 and 2, leaving one mean per entry of dim 0.
        channel_totals += image.mean(dim=(1, 2)).numpy()

    mean_rgb = channel_totals / len(dataset)
    print(f"mean : {mean_rgb}")
    return mean_rgb

In [None]:
# Same construction as the dataset built after the all-black sanity check;
# repeated so this section can be run on its own.
dataset = LandCoverData(path, transforms=None, split="train", ignore_last_number=11, use_augmented=False)

# Per-channel mean of the training images loaded with transforms=None.
meanRGB = compute_mean(dataset)

## 2.1 Compute std per channel (R,G,B)

In [None]:
def compute_std(dataset, mean_rgb):
    """Per-channel standard deviation of the images in `dataset` around `mean_rgb`.

    For every image the mean squared deviation from `mean_rgb` is computed per
    channel; these per-image values are averaged over len(dataset) and the
    square root is returned.

    Args:
        dataset: sized iterable of (image, label) pairs; each image is a tensor
            with the channel as its first dimension.
        mean_rgb: sequence with one mean per channel. Its length decides how
            many leading channels of each image are used (3 for RGB).

    Returns:
        NumPy float64 array with one standard deviation per channel. The result
        is also printed.
    """
    channel_means = torch.as_tensor(np.asarray(mean_rgb, dtype=np.float64))
    n_channels = channel_means.shape[0]
    squared_dev_sum = np.zeros(n_channels)

    for img, _ in tqdm(dataset):
        # Work in float64 and on all channels at once; broadcasting the means
        # over the two spatial dimensions replaces the per-channel Python loop.
        centered = img[:n_channels].double() - channel_means.view(-1, 1, 1)
        squared_dev_sum += centered.pow(2).mean(dim=(1, 2)).numpy()

    std = np.sqrt(squared_dev_sum / len(dataset))
    print(f"std : {std}")
    return std

In [None]:
# Per-channel std of the same dataset, measured around the mean computed above.
stdRGB = compute_std(dataset, meanRGB)

## 2.2 After normalization the mean should be close to 0 and std close to 1

In [None]:
# No `transforms` argument is passed here, so the loader's default applies.
# NOTE(review): presumably the default normalises the images — confirm in src.dataloader.
datasetNormed = LandCoverData(path, split="train", ignore_last_number=11, use_augmented=False)

# Per-channel mean of the dataset loaded with default transforms.
meanNormed=compute_mean(datasetNormed)

In [None]:
# Per-channel std of the default-transform dataset; the value is printed by
# compute_std and the return value is not stored.
compute_std(datasetNormed, meanNormed)

# 3. DATA AUGMENTATION

#### You need to create 2 folders before running the cells below:
#### cd data/
#### mkdir augmented_data_rgb
#### mkdir augmented_data_label

In [None]:
# Output folders for the augmented RGB images and label maps. They are derived
# from the data root `path` instead of repeating the user-specific directory,
# so only one location has to be changed to run elsewhere. With the default
# root this yields the same strings as before, including the trailing separator
# that later cells rely on when concatenating file names.
path_augmented_rgb = os.path.join(path, "ipeo_data", "augmented_data_rgb", "")
path_augmented_label = os.path.join(path, "ipeo_data", "augmented_data_label", "")

In [None]:
# Source dataset for the augmentation below: loaded with transforms=None and
# use_augmented=False (presumably so only original samples are rotated — confirm).
dataset2 = LandCoverData(path, transforms=None, split="train", ignore_last_number=11, use_augmented=False)

In [None]:
# Offline augmentation: every sample whose label contains at least one class
# index in 1..5 is saved three more times, rotated by 90, 180 and 270 degrees.
# (The original note singles out "Water and wetlands", index 4, as the least
# represented class.)
#
# Changes compared with the previous version of this cell:
#   * it no longer assigns `path = "/"`, which overwrote the global data root
#     used by every later LandCoverData(path, ...) call;
#   * the dataset is traversed once instead of once per class index — the file
#     names depend only on the sample index and rotation, so the extra passes
#     rewrote identical files;
#   * the unused reset of `tt_class` is gone;
#   * the output folders are created if they do not exist yet.
AUGMENTED_CLASS_INDICES = set(range(1, 6))
toPIL = T.ToPILImage()

os.makedirs(path_augmented_rgb, exist_ok=True)
os.makedirs(path_augmented_label, exist_ok=True)

for index_label, (img, label) in enumerate(tqdm(dataset2)):
    class_list = torch.unique(label)
    # Skip samples that contain none of the classes selected for augmentation.
    if AUGMENTED_CLASS_INDICES.isdisjoint(class_list.tolist()):
        continue

    img_pil = toPIL(img)
    label_pil = toPIL(label)
    for i in range(1, 4):
        rotation_deg = i * 90  # rotation by a multiple of 90 degrees
        name_image = f"{index_label}_{i}_rgb"
        name_label = f"{index_label}_{i}_label"

        # Image and label get the same rotation so they stay aligned.
        img_pil.rotate(rotation_deg, expand=0).save(
            os.path.join(path_augmented_rgb, name_image + ".tif"))
        label_pil.rotate(rotation_deg, expand=0).save(
            os.path.join(path_augmented_label, name_label + ".tif"))

##### You should see images in the cells below if the data augmentation worked

In [None]:
# Spot check: open one augmented pair (sample 420, third rotation) and show it.
# NOTE(review): these files exist only if sample 420 was selected for
# augmentation — pick another index if opening fails.
img = Image.open(path_augmented_rgb+"420_3_rgb.tif")
label = Image.open(path_augmented_label+"420_3_label.tif")

fig, ax = plt.subplots(1,2, figsize=(8, 10))
ax[0].imshow(img)
ax[1].imshow(label, cmap=cmap, vmin=0, vmax=len(dataset2.LABEL_CLASSES))

In [None]:
# Preview the first augmented RGB/label pair returned by the directory listing.
name_list = os.listdir(path_augmented_rgb)
for name in name_list[:1]:
    name_without_extension = name[:-4]  # strip the last four characters (".tif")
    rgb_file = path_augmented_rgb + name_without_extension + ".tif"
    # Drop the trailing "rgb" and append "label.tif" to get the matching label file.
    label_file = path_augmented_label + name_without_extension[:-3] + "label.tif"

    img = Image.open(rgb_file)
    label = Image.open(label_file)

    fig, ax = plt.subplots(1, 2, figsize=(8, 10))
    ax[0].imshow(img)
    ax[1].imshow(label, cmap=cmap, vmin=0, vmax=len(dataset2.LABEL_CLASSES))

## 3.2 Compute Barchart Frequency with data augmented

In [None]:
# Training split with use_augmented=True (presumably the original samples plus
# the rotated copies written above — confirm in src.dataloader).
dataset_final = LandCoverData(path, transforms=None, split="train", ignore_last_number=11, use_augmented=True)

In [None]:
# Per-class pixel totals over `dataset_final`: tt_class[k] accumulates the
# number of label pixels whose value is class index k.
tt_class = np.zeros(len(dataset_final.LABEL_CLASSES))
for _, label in tqdm(dataset_final):
    class_list, class_count = torch.unique(label, return_counts=True)
    # class_list and class_count are aligned, so walk them pairwise.
    for c, n_pixels in zip(class_list, class_count):
        tt_class[c] += n_pixels

In [None]:
# Horizontal bar chart of the per-class pixel totals (tt_class) for `dataset_final`.
plt.style.use('ggplot')

labels = list(dataset_final.LABEL_CLASSES.keys())
indexes = np.arange(len(labels))


def rescale(indexes):
    """Min-max scale `indexes` into [0, 1] so they can be passed to the colormap."""
    lowest, highest = np.min(indexes), np.max(indexes)
    return (indexes - lowest) / (highest - lowest)


width = 0.8  # defined but not passed to barh below
bar_colors = cmap(rescale(indexes))
plt.barh(indexes, tt_class, color=bar_colors)
plt.yticks(indexes, labels)
plt.title('Barchart - Frequency of each class', fontsize=20)
plt.xlabel('Pixels count', fontsize=16)
plt.ylabel('Class', fontsize=16)
plt.show()

In [None]:
# Share of all counted label pixels that belongs to each class
# (tt_class normalised by its sum), now for the augmented dataset.
total_pixels = tt_class.sum()
frac_class = list(tt_class / total_pixels)
print(frac_class)

In [None]:
# NOTE(review): literal list of eight values; it appears to be a pasted copy of
# class fractions printed by an earlier run — confirm it is still current. As a
# bare expression it is only displayed, not stored.
[0.24643106446206095, 0.06570622957842312, 0.062330682962183345, 0.020174477207666023, 0.03669090817922338, 0.05728250507725563, 0.2680321487788832, 0.24335198375430434]

# 3.3 Compute Mean and Std of Augmented Dataset

In [None]:
# Augmented training split loaded with transforms=None.
datasetAugmented = LandCoverData(path, transforms=None, split="train",
                                 ignore_last_number=11, use_augmented=True)

# Per-channel mean and std of that dataset.
meanAugmented=compute_mean(datasetAugmented)
stdAugmented=compute_std(datasetAugmented,meanAugmented)

#### Check that the mean is close to 0 and the std is close to 1 after normalization

In [None]:
# Same augmented split, but without `transforms=None`, so the loader's default
# transforms apply (presumably normalisation — confirm in src.dataloader).
# NOTE(review): this rebinds datasetAugmented, meanAugmented and stdAugmented,
# replacing the values computed in the previous cell.
datasetAugmented = LandCoverData(path, split="train", ignore_last_number=11, use_augmented=True)

meanAugmented=compute_mean(datasetAugmented)
stdAugmented=compute_std(datasetAugmented,meanAugmented)

# 4. Compute median class frequency

In [None]:
# Augmented training split used for the class-frequency statistics below.
# NOTE(review): this rebinds the global `dataset` used by earlier sections.
dataset = LandCoverData(path, split="train", ignore_last_number=11, use_augmented=True)

In [None]:
# Per-class pixel totals over `dataset`: tt_class[k] accumulates the number of
# label pixels whose value is class index k.
tt_class = np.zeros(len(dataset.LABEL_CLASSES))
for _, label in tqdm(dataset):
    class_list, class_count = torch.unique(label, return_counts=True)
    # class_list and class_count are aligned, so walk them pairwise.
    for c, n_pixels in zip(class_list, class_count):
        tt_class[c] += n_pixels

In [None]:
# Relative frequency of each class: its pixel count divided by the total pixel count.
freq_c=tt_class/np.sum(tt_class)
print(freq_c)

In [None]:
# Median of the per-class relative frequencies.
med_freq=np.median(freq_c)
print(med_freq)

In [None]:
# Median-frequency weights: med_freq / freq_c per class. Classes rarer than the
# median get a weight above 1, more frequent ones a weight below 1.
# NOTE(review): a class with zero pixels would produce a division by zero here.
print(med_freq/freq_c)

#### 5. Keep only data with balanced class labels

In [None]:
# Augmented training split built with restrict_classes=True.
# NOTE(review): presumably this keeps only samples with a balanced class mix —
# confirm what restrict_classes does in src.dataloader.
datasetAugmentedRestricted = LandCoverData(path, transforms=None, split="train",
                                 ignore_last_number=11,
                                 use_augmented=True,
                                 restrict_classes=True)

In [None]:
# Per-channel mean and std of the restricted dataset (loaded with transforms=None).
meanRestricted=compute_mean(datasetAugmentedRestricted)
stdRestricted=compute_std(datasetAugmentedRestricted,meanRestricted)

In [None]:
# Per-class pixel totals over `datasetAugmentedRestricted`: tt_class[k]
# accumulates the number of label pixels whose value is class index k.
tt_class = np.zeros(len(datasetAugmentedRestricted.LABEL_CLASSES))
for _, label in tqdm(datasetAugmentedRestricted):
    class_list, class_count = torch.unique(label, return_counts=True)
    # class_list and class_count are aligned, so walk them pairwise.
    for c, n_pixels in zip(class_list, class_count):
        tt_class[c] += n_pixels

In [None]:
# Horizontal bar chart of the per-class pixel totals (tt_class) for
# `datasetAugmentedRestricted`.
plt.style.use('ggplot')

labels = list(datasetAugmentedRestricted.LABEL_CLASSES.keys())
indexes = np.arange(len(labels))


def rescale(indexes):
    """Min-max scale `indexes` into [0, 1] so they can be passed to the colormap."""
    lowest, highest = np.min(indexes), np.max(indexes)
    return (indexes - lowest) / (highest - lowest)


width = 0.8  # defined but not passed to barh below
bar_colors = cmap(rescale(indexes))
plt.barh(indexes, tt_class, color=bar_colors)
plt.yticks(indexes, labels)
plt.title('Barchart - Frequency of each class', fontsize=20)
plt.xlabel('Pixels count', fontsize=16)
plt.ylabel('Class', fontsize=16)
plt.show()