# PreProcessing Demonstration

In this notebook, we will explore the PyTorch Dataset class and show how it can be used to handle input images. It should do some light preprocessing to make sure everything is available. 

Some things to note:
* checks to see that there is a 1-to-1 correspondence between the datasets
* does the tiling
* exports tiles that are extra

In [1]:
# Helpful trick for loading the directories correction
import sys, os
from pyprojroot import here
# spyder up to find the root
root = here(project_files=[".here"])
# append to path
sys.path.append(str(here()))

from pathlib import Path
from src.data.worldfloods.dataset import WorldFloodsDataset
from src.data.utils import get_files_in_directory, get_filenames_in_directory

# Imports for the transformations
import src.preprocess.transformations as transformations
# from torchvision import transforms
import numpy as np

output_image_dir = str(Path(root).joinpath("datasets/trials/image/image_tiles/"))
image_files = get_files_in_directory(output_image_dir, ".tif")

output_gt_dir = str(Path(root).joinpath("datasets/trials/image/gt_tiles/"))
gt_files = get_files_in_directory(output_gt_dir, ".tif")

image_prefix = "image_tiles"
gt_prefix = "gt_tiles"

pt_ds = WorldFloodsDataset(image_files, image_prefix, gt_prefix)

In [2]:
# from albumentations import Compose

# Stacked Transforms
tranform_permute = transformations.PermuteChannels()
tranform_toTensor = transformations.ToTensor()
tranform_oneHotEncoding = transformations.OneHotEncoding(num_classes=3)

mega_transform = transformations.Compose([
    tranform_permute, 
    tranform_toTensor, 
    tranform_oneHotEncoding,
    ])

pt_ds = WorldFloodsDataset(image_files, image_prefix, gt_prefix, transforms=mega_transform)

# pt_ds[1]['image'].shape
pt_ds[3]['mask'].shape


torch.Size([1, 128, 128, 3])

--------

### Transformations

* [ ] Flip
* [ ] GaussNoise
* [ ] MotionBlur
* [ ] Normalize
* [ ] PadIfNeeded
* [ ] RandomRotate90
* [ ] ShiftScaleRotate
---
* [ ] PerChannel Transformations
* [X] ResizeFactor Transformation

In [4]:
from src.preprocess.worldfloods import normalize as wf_normalize

# Stacked Transforms
transform_permute = transformations.PermuteChannels()
transform_toTensor = transformations.ToTensor()
# TODO: Check number of classes
transform_oneHotEncoding = transformations.OneHotEncoding(num_classes=4) 
transform_resizeFactor = transformations.ResizeFactor(downsampling_factor=4, always_apply=True, p=1)

use_channels = "all"
channel_mean, channel_std = wf_normalize.get_normalisation(use_channels)
transform_normalize = transformations.Normalize(mean=channel_mean, std=channel_std, max_pixel_value=1.0)

# DO NOT CHANGE THE ORDER
mega_transform = transforms.Compose([
    transform_resizeFactor, 
    transform_permute, 
    transform_toTensor, 
    transform_oneHotEncoding,
    transform_normalize,
    ])

pt_ds = WorldFloodsDataset(
    image_files, 
    image_prefix, 
    gt_prefix, 
    transforms=mega_transform,
    )

print("image:", pt_ds[2]['image'].shape)
print("mask:", pt_ds[2]['mask'].shape)





IndexError: index 13 is out of bounds for axis 0 with size 13

---
* [ ] use a numpy array for every trasnformation other than "ToTensor"
* [ ] check to make sure that the normalize thing works for the special sensor
* [ ] check channeljitter PerRotation
* [ ] do a notebook showing what the transformation pictures look like
* [ ] discuss augmentation, adversarial training, etc content for ppt


In [5]:
from src.preprocess.worldfloods import normalize as wf_normalize
import matplotlib.pyplot as plt

# Normalization -------------

use_channels = "all"
channel_mean, channel_std = wf_normalize.get_normalisation(use_channels)
transform_normalize = transformations.NormalizeCustom(mean=channel_mean, std=channel_std, max_pixel_value=1.0)

# pt_ds_norm = WorldFloodsDataset(image_files, image_prefix, gt_prefix, transforms=transform_gauss)
pt_ds_norm = transform_normalize(input_data=pt_ds[1])
pt_ds_norm['image'].shape

# Gaussian Noise -------------

# transform_gauss = transformations.GaussNoise(var_limit=(1e-6, 1e-3), p=1)
# pt_ds_gauss = transform_gauss(image=pt_ds[1]['image'])

# print(pt_ds_gauss['image'][:,:,6].shape)
# print(np.min(pt_ds_gauss['image'][:,:,6]))
# print(np.max(pt_ds_gauss['image'][:,:,6]))
# fig, ax = plt.subplots(nrows=1, ncols=2)
# ax[0].imshow(pt_ds[1]['image'][:,:,6])
# ax[1].imshow(pt_ds_gauss['image'][:,:,6])
# plt.show()

IndexError: index 13 is out of bounds for axis 0 with size 13