In [None]:
# Enable IPython's autoreload extension. Mode 2 reloads imported modules before
# each cell execution, so edits to the project modules are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

import sys
from pathlib import Path

import lightning as L

In [None]:
# Make the parent directory importable so the `src` package imported below resolves.
# Guarded so that re-running this cell does not keep appending duplicate entries
# to sys.path.
if "../" not in sys.path:
    sys.path.append("../")  # include parent directory
from src.pond_data import PondDataset, PondDataModule
from src.config_utils import build_kwargs_from_config
from src import pixel_stats

## Rollout Data Inspection

This notebook checks the GeoTIFF files using the prebuilt PyTorch Lightning `Dataset` and `DataModule`.

## Input
- Satellite Images (geoTIFF files)
- Raster Masks (geoTIFF files)


## Output
- None

## Set up parameters from config yaml

Feel free to edit the parameters here as well

In [None]:
# Locations of the data and configuration directories, relative to this notebook.
DATA_PATH, CONFIG_PATH = Path("../data/"), Path("../config")

# YAML config file passed to build_kwargs_from_config.
CONFIG_FPATH = CONFIG_PATH.joinpath("rollout_pond_config.yaml")

In [None]:
# Build the keyword-argument dictionary from the data directory and the YAML
# config file; its "dataset_kwargs" entry is read in the next cell.
kwargs_dict = build_kwargs_from_config(DATA_PATH, CONFIG_FPATH)

In [None]:
# Keyword arguments that will be passed to PondDataset; shown for inspection.
DATASET_KWARGS = kwargs_dict["dataset_kwargs"]
DATASET_KWARGS

## Setting up the PyTorch Dataset

In [None]:
%%time
# Instantiate the dataset from the config-derived kwargs and display it;
# %%time reports how long construction takes.
pond_dataset = PondDataset(**DATASET_KWARGS)
pond_dataset

In [None]:
# Check labels: display the dataset's label mapping.
# NOTE(review): presumably maps class names to integer ids (or vice versa) — confirm in PondDataset.
pond_dataset.label_mapping

In [None]:
# Check a sample from the dataset.
# `i` is the index of the sample to inspect; change it to look at a different item.
i = 2
pond_dataset[i]

## Data Assessment

We have a few methods in the PyTorch dataset to assess the data quality:

1. Getting the min and max pixel values: this is to check that the values are within the preset `MAX_IMG_VAL` of the `PondDataset` class.
2. Getting the IDs of images that are all null: to check if there are erroneous images.
3. Check if the IDs of the images match up to the IDs of the masks.
4. Check the pixel dimensions of the images: how similar they are and at which index the channels are.
5. Plotting images and seeing if there are images that are fully black or have other anomalies.

In [None]:
%%time
# Compute mean and standard-deviation pixel statistics over the image/mask stream.
# NOTE(review): a new stream is requested before every check — presumably because
# the stream is consumed by iteration; confirm in PondDataset.img_mask_stream.
stream = pond_dataset.img_mask_stream()
channel_means, channel_stds = pixel_stats.get_mean_and_std_pixel_vals(stream)
channel_means, channel_stds

In [None]:
%%time
# List the IDs of images that are entirely null, using the dataset's NODATA_VAL
# as the null pixel value, to spot erroneous images.
stream = pond_dataset.img_mask_stream()
pixel_stats.get_null_data_ids(stream, null_pixel_val=pond_dataset.NODATA_VAL)

In [None]:
%%time
# NOTE(review): the image/mask ID consistency check is currently disabled, so this
# cell performs no check. Re-enable the line below once it is confirmed that
# PondDataset provides `check_imgs_masks_same_ids`.
# pond_dataset.check_imgs_masks_same_ids()

In [None]:
%%time
# Collect the distinct image shapes found in the stream, to see how similar the
# pixel dimensions are and where the channel axis sits.
stream = pond_dataset.img_mask_stream()
pixel_stats.get_img_unique_shapes(stream)

In [None]:
%%time
# Validate the image dimensions for segmentation.
# NOTE(review): the exact criteria live in pixel_stats.validate_image_dims_for_segmentation — see that helper.
stream = pond_dataset.img_mask_stream()
pixel_stats.validate_image_dims_for_segmentation(stream)

In [None]:
# Number of items in the dataset.
len(pond_dataset)

In [None]:
# Index of the sample to visualise; the histogram cell that follows reuses `i`.
i = 2
pond_dataset.plot_img(i)

In [None]:
# Plot the histogram for the same sample; relies on `i` set in the previous cell.
pond_dataset.plot_img_histogram(i)

In [None]:
# Optional full sweep (disabled): uncomment to plot the image/mask pair and the
# histogram for every sample in the dataset.
# for i in range(len(pond_dataset)):
#     pond_dataset.plot_img_and_mask(i)
#     pond_dataset.plot_img_histogram(i)