In [None]:
#| default_exp vision.testio

In [None]:
#| export
from __future__ import annotations

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from PIL import Image
from fastai.vision.all import *

# Sample IO operations for Testing

We provide a number of I/O helper functions that the library notebooks use to list and read files in the repository's image directory.

These functions will vary based on the naming conventions and file format of your own data files.

## File I/O from PNG files

We create a multi-spectral tensor by reading (in this case) Sentinel-2 images from files. Each Sentinel channel is saved in a separate `.png` file. We use `PIL` to read the files into `numpy` arrays, which are then converted into `Tensor`s.

Sentinel channels have so-called `DN` (digital number) values less than `10000`. By convention, `55537` is assigned to a pixel when the actual data is missing or unknown.

We normalize the data with a min-max of `(0, 10000)` after replacing missing values with `9999`.

In [None]:
#| export
def _filter_masked(raw_arr, in_msk: int, out_msk: int):
    "Replace input mask pixel value with selected value"
    return np.select([raw_arr == in_msk], [out_msk], raw_arr)

def read_chn_file(path: str) -> Tensor:
    "Load one channel image from `path` as a tensor of scaled pixel values"
    pixels = np.array(Image.open(path))
    # Swap the missing-data marker (55537) for 9999 before dividing by 10000.
    pixels = _filter_masked(pixels, 55537, 9999)
    scaled = pixels / 10000
    return Tensor(scaled)

def read_multichan_files(files: list[str]) -> Tensor:
    """Read individual channel files into a single tensor of channels.

    `files` is a list of paths, each one read with `read_chn_file`.
    Every channel tensor gets a new leading dimension and the results are
    concatenated along it, so channels appear in the same order as `files`.
    """
    # `[None]` adds the leading channel dimension that `torch.cat` joins on.
    return torch.cat([read_chn_file(path)[None] for path in files])

To keep the io in one file, we use the following to read label masks for segmentation.

In [None]:
#| export
# TODO abstract this filter
def _to_bin_seg(img_arr):
    return np.select([img_arr == 255, img_arr < 6, img_arr == 6],[0, 0, 1],img_arr)

def read_mask_file(path: str) -> TensorMask:
    """Load the label image at `path` and return it as a `TensorMask` after `_to_bin_seg` remapping."""
    return TensorMask(_to_bin_seg(np.array(Image.open(path))))

## File and directory names

In [None]:
#| export
def _get_input(stem: str) -> str:
    "Get full input path for stem"
    return "./images/" + stem

def _tile_img_name(chn_id: str, tile_num: int) -> str:
    "File name from channel id and tile number"
    return f"Sentinel20m-{chn_id}-20200215-{tile_num:03d}.png"

def get_channel_filenames(chn_ids, tile_idx):
    "Return the input path of each channel's file for tile `tile_idx`, in `chn_ids` order"
    filenames = []
    for chn_id in chn_ids:
        filenames.append(_get_input(_tile_img_name(chn_id, tile_idx)))
    return filenames

In [None]:
#| hide
# Trigger nbdev's export step so the `#| export` cells are written out to the library module.
import nbdev; nbdev.nbdev_export()