In [None]:
#| default_exp vision.data

In [None]:
#| export
from __future__ import annotations

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from fastai.vision.all import *
from fastgs.vision.core import *

# Multi-spectral Vision Data

> Data support functions for working with `TensorImageMS`

## Show methods -

In [None]:
#| export
def _show_one_sample(img: TensorImageMS, msk: TensorMask, row, mskovl, **kwargs):
    if mskovl:
        return [msk.show(ctx=c, **kwargs) for c in img.show(ctxs=row, **kwargs)]
    else:
        nimgs = img.num_images()
        return img.show(ctxs=row[:nimgs]) + [msk.show(row[nimgs])]

In [None]:
#| export
@typedispatch
def show_batch(
    x: TensorImageMS,  # Input(s) in the batch
    y: TensorMask,  # Target(s) in the batch
    samples: list,  # List of (`x`, `y`) pairs of length `max_n`
    ctxs=None,  # List of `ctx` objects to show data. Could be a matplotlib axis, DataFrame, etc.
    max_n: int=9,  # Maximum number of `samples` to show
    nrows:int=None, # Not supported: the number of rows is derived from `samples` and `max_n`
    ncols:int=None, # Not supported: the number of columns is derived from `x.num_images()`
    figsize=None, # Passed through to `get_grid`
    mskovl:bool=True, # mask is overlaid on the image
    **kwargs # Forwarded to the `show` calls of every image and mask
):
    "Show a batch of multi-spectral images with their masks, one sample per grid row."
    assert len(samples[0]) == 2 and not hasattr(samples[0], "show"), \
        "`samples` must contain (image, mask) pairs"
    assert nrows is None and ncols is None and ctxs is None, \
        "show_batch lays out its own grid; `nrows`, `ncols` and `ctxs` must not be given"

    nimgs = x.num_images()
    nrows = min(len(samples), max_n)
    # One column per rendered image, plus a dedicated mask column when not overlaying.
    ncols = nimgs if mskovl else nimgs + 1

    ctxs = get_grid(nrows * ncols, nrows, ncols, figsize=figsize)
    # Split the flat sequence of contexts into one row of `ncols` per sample.
    rows = [ctxs[pos : pos + ncols] for pos in range(0, len(ctxs), ncols)]
    imgs, msks = samples.itemgot(0), samples.itemgot(1)

    # `range(nrows)` caps the number of samples shown at `max_n`; kwargs are
    # forwarded so caller display options reach the individual `show` calls.
    return [_show_one_sample(img, msk, row, mskovl, **kwargs)
            for img, msk, row, _ in zip(imgs, msks, rows, range(nrows))]

In [None]:
#| hide
# Regenerate the library modules from the `#| export` cells of this notebook.
import nbdev
nbdev.nbdev_export()