# Preprocessing Demonstration

In this notebook, we explore the PyTorch Dataset class and show how it can be used to load input images together with their ground-truth masks. The dataset also performs some light preprocessing to make sure every input it needs is available.

Some things to note. The preprocessing:
* checks that there is a one-to-one correspondence between the image and ground-truth datasets
* performs the tiling
* exports any extra tiles
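The first two checks can be sketched in NumPy as follows. This is a minimal illustration under stated assumptions, not the WorldFloods implementation: `check_one_to_one` and `tile_array` are hypothetical helpers, tiles are assumed to be matched by file name, and "extra" tiles are read here as the partial tiles left over at the right and bottom edges of a scene.

```python
from pathlib import Path
from typing import Iterable, List, Tuple

import numpy as np


def check_one_to_one(image_files: Iterable[str], gt_files: Iterable[str]) -> List[str]:
    """Return the sorted tile names present in both lists; raise if the sets differ.

    Hypothetical helper: assumes an image tile and its ground truth share a file name.
    """
    image_names = {Path(f).name for f in image_files}
    gt_names = {Path(f).name for f in gt_files}
    missing_gt = image_names - gt_names
    missing_image = gt_names - image_names
    if missing_gt or missing_image:
        raise ValueError(
            f"no ground truth for {sorted(missing_gt)}; no image for {sorted(missing_image)}"
        )
    return sorted(image_names)


def tile_array(array: np.ndarray, tile_size: int) -> Tuple[List[np.ndarray], List[np.ndarray]]:
    """Split an (H, W, C) array into tile_size x tile_size tiles.

    Returns (full_tiles, edge_tiles). Edge tiles are the leftover strips at the
    right/bottom border that are smaller than tile_size (an assumed meaning of "extra").
    """
    height, width = array.shape[:2]
    full, edge = [], []
    for row in range(0, height, tile_size):
        for col in range(0, width, tile_size):
            tile = array[row:row + tile_size, col:col + tile_size]
            if tile.shape[0] == tile_size and tile.shape[1] == tile_size:
                full.append(tile)
            else:
                edge.append(tile)
    return full, edge
```

For example, a 300 x 260 scene cut into 128-pixel tiles yields four full tiles and five partial edge tiles.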

In [1]:
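# Helpful trick for loading the directories correctly
import sys, os
from pyprojroot import here
# search upward from the working directory for the `.here` marker file to find the project root
root = here(project_files=[".here"])
# append the project root to the path so that `src` is importable
sys.path.append(str(root))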

from pathlib import Path
from src.data.worldfloods.dataset import WorldFloodsDataset
from src.data.utils import get_files_in_directory, get_filenames_in_directory


output_image_dir = str(Path(root).joinpath("datasets/trials/image/image_tiles/"))
image_files = get_files_in_directory(output_image_dir, ".tif")

output_gt_dir = str(Path(root).joinpath("datasets/trials/image/gt_tiles/"))
gt_files = get_files_in_directory(output_gt_dir, ".tif")

image_prefix = "image_tiles"
gt_prefix = "gt_tiles"
pt_ds = WorldFloodsDataset(image_files, image_prefix, gt_prefix)
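The cell above builds `WorldFloodsDataset` from the image tile paths plus the two directory prefixes, which suggests each ground-truth path is derived from its image path. A minimal sketch of that pairing logic, assuming the only difference between the two paths is the `image_tiles` / `gt_tiles` directory name. `image_to_gt_path` and `PairedTileDataset` are hypothetical names, and the file reader is passed in as `loader` (for GeoTIFFs this would typically be a raster library, left out here so the sketch stays self-contained).

```python
from pathlib import Path
from typing import Callable, Dict, List, Optional


def image_to_gt_path(image_path: str, image_prefix: str, gt_prefix: str) -> str:
    """Swap the last directory component named image_prefix for gt_prefix."""
    parts = list(Path(image_path).parts)
    try:
        idx = len(parts) - 1 - parts[::-1].index(image_prefix)
    except ValueError:
        raise ValueError(f"{image_prefix!r} not found in {image_path!r}") from None
    parts[idx] = gt_prefix
    return str(Path(*parts))


class PairedTileDataset:
    """Map-style dataset: defines __len__ and __getitem__.

    In real code this would subclass torch.utils.data.Dataset; it is kept
    torch-free here so the sketch runs without PyTorch installed.
    """

    def __init__(self, image_files: List[str], image_prefix: str, gt_prefix: str,
                 loader: Callable[[str], object],
                 transforms: Optional[Callable[[Dict], Dict]] = None):
        self.image_files = sorted(image_files)  # deterministic ordering
        self.image_prefix = image_prefix
        self.gt_prefix = gt_prefix
        self.loader = loader
        self.transforms = transforms

    def __len__(self) -> int:
        return len(self.image_files)

    def __getitem__(self, idx: int) -> Dict:
        image_path = self.image_files[idx]
        gt_path = image_to_gt_path(image_path, self.image_prefix, self.gt_prefix)
        sample = {"image": self.loader(image_path), "mask": self.loader(gt_path)}
        # The whole sample dict goes through the transform pipeline.
        if self.transforms is not None:
            sample = self.transforms(sample)
        return sample
```

Returning a dict with `image` and `mask` keys mirrors the `pt_ds[1]['mask']` access used later in this notebook.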

In [2]:
# Imports for the transformations

import src.preprocess.transformations as transformations
from torchvision import transforms
import numpy as np

-----------

In [17]:

# Stacked Transforms
transform_permute = transformations.PermuteChannels()
transform_toTensor = transformations.ToTensor()
transform_oneHotEncoding = transformations.OneHotEncoding(num_classes=3)

mega_transform = transforms.Compose([transform_permute, transform_toTensor, transform_oneHotEncoding])

pt_ds = WorldFloodsDataset(image_files, image_prefix, gt_prefix, transforms=mega_transform)

pt_ds[1]['mask'].shape

torch.Size([1, 128, 128, 3])
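The printed shape `torch.Size([1, 128, 128, 3])` is consistent with the one-hot encoding appending a class axis of length `num_classes=3` after a `(1, 128, 128)` mask. A NumPy sketch of that behaviour, assuming integer class ids in `[0, 3)`; `one_hot_last_axis` is a hypothetical helper, not the `transformations.OneHotEncoding` source.

```python
import numpy as np


def one_hot_last_axis(mask: np.ndarray, num_classes: int) -> np.ndarray:
    """Encode an integer mask as one-hot, appending the class axis last."""
    if mask.min() < 0 or mask.max() >= num_classes:
        raise ValueError("mask values must lie in [0, num_classes)")
    # Indexing the identity matrix with the mask picks one one-hot row per pixel.
    return np.eye(num_classes, dtype=np.float32)[mask]


mask = np.random.default_rng(0).integers(0, 3, size=(1, 128, 128))
encoded = one_hot_last_axis(mask, num_classes=3)
print(encoded.shape)  # (1, 128, 128, 3)
```

PyTorch's `torch.nn.functional.one_hot` also places the class dimension last, so a permute would be needed if a downstream loss expects classes in the channel position.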

--------

### Transformations

* [ ] Flip
* [ ] GaussNoise
* [ ] MotionBlur
* [ ] Normalize
* [ ] PadIfNeeded
* [ ] RandomRotate90
* [ ] ShiftScaleRotate
---
* [ ] PerChannel Transformations
* [ ] ResizeFactor Transformation
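Flip and RandomRotate90 from the list above both have to move the image and the mask together, or the labels stop lining up with the pixels. A minimal NumPy sketch, assuming each transform receives and returns the sample dict (with `image` as `(H, W, C)` and `mask` as `(H, W)`), as the `pt_ds[1]['mask']` access above suggests. The class names mirror the checklist, but these are hypothetical implementations, not the ones in `src.preprocess.transformations`.

```python
from typing import Optional

import numpy as np


class RandomFlip:
    """Flip image and mask together along the width and/or height axis."""

    def __init__(self, p: float = 0.5, rng: Optional[np.random.Generator] = None):
        self.p = p
        self.rng = rng if rng is not None else np.random.default_rng()

    def __call__(self, sample: dict) -> dict:
        image, mask = sample["image"], sample["mask"]
        if self.rng.random() < self.p:  # horizontal flip: reverse the width axis
            image, mask = image[:, ::-1], mask[:, ::-1]
        if self.rng.random() < self.p:  # vertical flip: reverse the height axis
            image, mask = image[::-1], mask[::-1]
        # Copy into contiguous memory so later tensor conversion sees no negative strides.
        return {**sample, "image": np.ascontiguousarray(image),
                "mask": np.ascontiguousarray(mask)}


class RandomRotate90:
    """Rotate image and mask together by a random multiple of 90 degrees."""

    def __init__(self, rng: Optional[np.random.Generator] = None):
        self.rng = rng if rng is not None else np.random.default_rng()

    def __call__(self, sample: dict) -> dict:
        k = int(self.rng.integers(0, 4))  # number of quarter turns: 0, 1, 2 or 3
        # axes=(0, 1) rotates in the height/width plane, leaving channels alone.
        image = np.rot90(sample["image"], k, axes=(0, 1)).copy()
        mask = np.rot90(sample["mask"], k, axes=(0, 1)).copy()
        return {**sample, "image": image, "mask": mask}
```

The contiguous copies matter because flipping with `[::-1]` produces negative-stride views, which `torch.from_numpy` refuses to wrap.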

In [4]:
# Stacked Transforms
transform_permute = transformations.PermuteChannels()
transform_toTensor = transformations.ToTensor()
transform_oneHotEncoding = transformations.OneHotEncoding(num_classes=3)
transform_resizeFactor = transformations.ResizeFactor(downsampling_factor=4, always_apply=True, p=1)

# DO NOT CHANGE THE ORDER
mega_transform = transforms.Compose([transform_resizeFactor, transform_permute, transform_toTensor, transform_oneHotEncoding])

# Only the resize transform is passed here, so the sample stays a channels-last NumPy array
pt_ds = WorldFloodsDataset(image_files, image_prefix, gt_prefix, transforms=transform_resizeFactor)

pt_ds[2]['image'].shape

(32, 32, 13)
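With `downsampling_factor=4`, the 128 x 128 tiles seen above become 32 x 32 while the 13 channels are untouched, which matches the printed `(32, 32, 13)`. One way to implement such a resize is block averaging for the image and striding for the mask. This is an assumption for illustration, since the interpolation actually used by `transformations.ResizeFactor` is not shown here, and `downsample_by_factor` is a hypothetical helper.

```python
import numpy as np


def downsample_by_factor(image: np.ndarray, mask: np.ndarray, factor: int):
    """Downsample an (H, W, C) image by block averaging and an (H, W) mask by striding."""
    height, width = image.shape[:2]
    if height % factor or width % factor:
        raise ValueError("height and width must be divisible by factor")
    channels = image.shape[2]
    # Group pixels into factor x factor blocks and average each block per channel.
    blocks = image.reshape(height // factor, factor, width // factor, factor, channels)
    small_image = blocks.mean(axis=(1, 3))
    # Keep one label per block (its top-left pixel) so class ids stay valid integers.
    small_mask = mask[::factor, ::factor]
    return small_image, small_mask
```

Because it treats the first two axes as height and width, a resize like this has to run before `PermuteChannels`, which is one plausible reason for the ordering comment in the cell above.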

---
* [ ] use a NumPy array for every transformation other than "ToTensor"
* [ ] check that the Normalize transform works for the special sensor
* [ ] check channeljitter PerRotation
* [ ] make a notebook showing what the transformed images look like
* [ ] discuss augmentation, adversarial training, and related content for the presentation (ppt)
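For the Normalize item, a per-channel standardization over the last axis is one straightforward option that also keeps the sample a NumPy array until `ToTensor`. A minimal sketch, assuming channels-last images and a per-channel mean and standard deviation computed beforehand from the training tiles; `PerChannelNormalize` is a hypothetical class, and any statistics you pass in must come from your own data, not from this example.

```python
from typing import Sequence

import numpy as np


class PerChannelNormalize:
    """Standardize each channel of an (H, W, C) image with its own mean and std."""

    def __init__(self, mean: Sequence[float], std: Sequence[float]):
        self.mean = np.asarray(mean, dtype=np.float32)
        self.std = np.asarray(std, dtype=np.float32)
        if self.mean.shape != self.std.shape:
            raise ValueError("mean and std must have the same length")
        if np.any(self.std <= 0):
            raise ValueError("std must be positive")

    def __call__(self, sample: dict) -> dict:
        image = np.asarray(sample["image"], dtype=np.float32)
        if image.shape[-1] != self.mean.shape[0]:
            raise ValueError(
                f"expected {self.mean.shape[0]} channels, got {image.shape[-1]}"
            )
        # Broadcasting over the last axis applies one (mean, std) pair per channel.
        return {**sample, "image": (image - self.mean) / self.std}
```

The explicit channel-count check is a cheap way to catch a sensor whose band count differs from the one the statistics were computed for.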
