In [None]:
#|default_exp vision.load

In [None]:
#| export
from __future__ import annotations

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from typing import Callable
from fastai.vision.all import *
from fastgs.vision.core import *

# Tensor loading helpers

## `MSTensorGetter`

We create an abstraction that loads a multispectral tensor given a list of channel (band) ids and a tile (image) id.

In [None]:
#| export
class MSTensorGetter:
    """Loads a multispectral tensor from a list of band ids and an image id.

    Concrete getters subclass this and patch in their own `load_tensor`.
    """

@patch
def load_tensor(self: MSTensorGetter, band_ids: list[str], img_id: Any) -> TensorImageMS:
    """Load the multispectral tensor holding `band_ids` for image `img_id`.

    This is the abstract hook. The concrete getters built by
    `MSTensorGetter.from_files` and `MSTensorGetter.from_delegate` patch in
    their own implementation. Reaching this base version means a subclass did
    not provide one. It therefore fails loudly instead of silently returning
    `None`.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError(f"{type(self).__name__} does not implement load_tensor")

The common case is specified with two functions: one that returns the list of files for the requested channels, and another that loads a `TensorImageMS` from that list of files.

In [None]:
#| export
class _MSFileTensorGetter(MSTensorGetter):
    """Tensor getter that maps band ids to files, then reads those files."""

@patch
def __init__(
    self: _MSFileTensorGetter,
    files_getter: Callable[[list[str], Any], list[str]],
    chan_io_fn: Callable[[list[str]], TensorImageMS],
):
    """Build a getter that loads tensors from per-channel files.

    Args:
        files_getter: called as `files_getter(band_ids, img_id)`. It returns
            the list of files for those bands of that image.
        chan_io_fn: called with that list of files. It returns the loaded
            `TensorImageMS`.
    """
    # store_attr() saves every argument as an attribute:
    # self.files_getter and self.chan_io_fn
    store_attr()

@patch
def load_tensor(self: _MSFileTensorGetter, band_ids: list[str], img_id: Any) -> TensorImageMS:
    """Resolve `band_ids` of `img_id` to files, then load them as one tensor.

    File lookup is done by `files_getter`; reading is done by `chan_io_fn`.
    """
    paths = self.files_getter(band_ids, img_id)
    tensor = self.chan_io_fn(paths)
    return tensor

For unusual cases, we instead supply a single function that performs the complete tensor loading.

In [None]:
#| export
class _MSDelegatingTensorGetter(MSTensorGetter):
    """Tensor getter that forwards the whole load to a user-supplied function."""

@patch
def __init__(
    self: _MSDelegatingTensorGetter,
    tg_fn: Callable[[list[str], Any], TensorImageMS]
):
    """Keep `tg_fn`, called as `tg_fn(band_ids, img_id)` to produce the tensor."""
    # Explicitly name the single argument to persist as self.tg_fn
    store_attr('tg_fn')

@patch
def load_tensor(self: _MSDelegatingTensorGetter, band_ids: list[str], img_id: Any) -> TensorImageMS:
    """Hand the complete load of `band_ids` for `img_id` to the stored `tg_fn`."""
    delegate = self.tg_fn
    return delegate(band_ids, img_id)

Finally, we provide the factory class methods `from_files` and `from_delegate`, which construct the appropriate getter.

In [None]:
#| export
@patch(cls_method=True)
def from_files(
    cls: MSTensorGetter,
    files_getter: Callable[[list[str], Any], list[str]],
    chan_io_fn: Callable[[list[str]], TensorImageMS],
) -> MSTensorGetter:
    """Create a getter that loads tensors from per-channel files.

    Args:
        files_getter: called as `files_getter(band_ids, img_id)`. It returns
            the list of files for those bands.
        chan_io_fn: called with that list of files. It returns the
            `TensorImageMS`.

    Returns:
        A `_MSFileTensorGetter` wrapping the two functions.
    """
    return _MSFileTensorGetter(files_getter, chan_io_fn)

@patch(cls_method=True)
def from_delegate(
    cls: MSTensorGetter,
    tg_fn: Callable[[list[str], Any], TensorImageMS]
):
    """Create a getter that delegates the whole load to `tg_fn`.

    `tg_fn(band_ids, img_id)` must itself return the `TensorImageMS`.
    """
    getter = _MSDelegatingTensorGetter(tg_fn)
    return getter

## Mask loading helpers

In a similar fashion, we create helpers for mask loading.

In [None]:
#| export
class MSMaskGetter:
    """Loads a segmentation mask from a mask id and an image id.

    Concrete getters subclass this and patch in their own `load_mask`.
    """

@patch
def load_mask(self: MSMaskGetter, mask_id: str, img_id: Any) -> TensorMask:
    """Load the mask identified by `mask_id` for image `img_id`.

    This is the abstract hook. Its signature matches the concrete overrides,
    which all take a single `mask_id: str`. Reaching this base version means a
    subclass did not provide an implementation. It therefore fails loudly
    instead of silently returning `None`.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError(f"{type(self).__name__} does not implement load_mask")

In [None]:
#| export
class _MSFileMaskGetter(MSMaskGetter):
    """Mask getter that maps a mask id to a file, then reads that file."""

@patch
def __init__(
    self: _MSFileMaskGetter,
    files_getter: Callable[[list[str], Any], list[str]],
    chan_io_fn: Callable[[str], TensorMask],
):
    """Build a getter that loads a mask from a single file.

    Args:
        files_getter: called as `files_getter([mask_id], img_id)`. Only the
            first returned path is used.
        chan_io_fn: called with that single path (a `str`, not a list).
            It returns the loaded `TensorMask`.
    """
    # store_attr() saves every argument as an attribute:
    # self.files_getter and self.chan_io_fn
    store_attr()

@patch
def load_mask(self: _MSFileMaskGetter, mask_id: str, img_id: Any) -> TensorMask:
    """Resolve `mask_id` of `img_id` to a file and load it with `chan_io_fn`.

    `files_getter` is called with a one-element list. Only the first path it
    returns is used, and that path is passed on by itself (not as a list).
    """
    paths = self.files_getter([mask_id], img_id)
    mask_path = paths[0]
    return self.chan_io_fn(mask_path)

In [None]:
#| export
class _MSDelegatingMaskGetter(MSMaskGetter):
    """Mask getter that forwards the whole load to a user-supplied function."""

@patch
def __init__(
    self: _MSDelegatingMaskGetter,
    tg_fn: Callable[[str, Any], TensorMask]
):
    """Keep `tg_fn`, called as `tg_fn(mask_id, img_id)` to produce the mask."""
    # Explicitly name the single argument to persist as self.tg_fn
    store_attr('tg_fn')

@patch
def load_mask(self: _MSDelegatingMaskGetter, mask_id: str, img_id: Any) -> TensorMask:
    """Hand the complete load of `mask_id` for `img_id` to the stored `tg_fn`."""
    delegate = self.tg_fn
    return delegate(mask_id, img_id)

In [None]:
#| export
@patch(cls_method=True)
def from_files(
    cls: MSMaskGetter,
    files_getter: Callable[[list[str], Any], list[str]],
    chan_io_fn: Callable[[str], TensorMask],
) -> MSMaskGetter:
    """Create a getter that loads a mask from a single file.

    Args:
        files_getter: called as `files_getter([mask_id], img_id)`. It returns
            a list of paths, of which only the first is used.
        chan_io_fn: called with that single path. It returns the `TensorMask`.

    Returns:
        A `_MSFileMaskGetter` wrapping the two functions.
    """
    return _MSFileMaskGetter(files_getter, chan_io_fn)

@patch(cls_method=True)
def from_delegate(
    cls: MSMaskGetter,
    tg_fn: Callable[[str, Any], TensorMask]
):
    """Create a getter that delegates the whole load to `tg_fn`.

    `tg_fn(mask_id, img_id)` must itself return the `TensorMask`.
    """
    getter = _MSDelegatingMaskGetter(tg_fn)
    return getter

In [None]:
#| hide
# Regenerate the library modules from this notebook's exported cells
import nbdev
nbdev.nbdev_export()