In [None]:
#| default_exp vision.data

In [None]:
#| export
from __future__ import annotations

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import torch
import fastai
from fastai.vision.all import *

from fastai_geospatial.vision.core import *

In [None]:
def get_grid_batch(nrows: int, ncols: int, figsize=None) -> list:
    "Request a grid of `nrows * ncols` contexts from `get_grid`, laid out as `nrows` by `ncols`"
    # Both dimensions are mandatory here; the cell count is derived from them.
    assert not (nrows is None or ncols is None)
    return get_grid(nrows * ncols, nrows=nrows, ncols=ncols, figsize=figsize)

In [None]:
def chunk_grid(ctxs: list, ncols: int) -> list:
    "Split `ctxs` into consecutive slices of at most `ncols` items each (the last slice may be shorter)"
    rows = []
    for start in range(0, len(ctxs), ncols):
        rows.append(ctxs[start : start + ncols])
    return rows

In [None]:
@typedispatch
def show_batch(
    x: TensorImageMS,  # Input(s) in the batch
    y: TensorMask,  # Target(s) in the batch
    samples: list,  # List of (`x`, `y`) pairs of length `max_n`
    ctxs: list=None,  # Must be left as `None`; the grid of contexts is created here
    max_n: int=9,  # Maximum number of `samples` to show
    nrows: int=None,  # Must be left as `None`; one grid row is used per shown sample
    ncols: int=None,  # Must be left as `None`; taken from `x.num_images()`
    figsize=None,  # Passed through to `get_grid_batch`
    **kwargs,  # Passed through to every `show` call
):
    "Show up to `max_n` samples, one grid row per sample, then show each target on the contexts returned by its input's `show`"
    # Layout is fully determined by the data, so callers may not override it.
    assert ctxs is None and nrows is None and ncols is None
    first = samples[0]
    assert len(first) == 2 and not hasattr(first, "show")

    n_rows = min(len(samples), max_n)
    n_cols = x.num_images()
    grid = get_grid_batch(n_rows, n_cols, figsize=figsize)
    grid_rows = chunk_grid(grid, n_cols)
    inputs = samples.itemgot(0)
    targets = samples.itemgot(1)

    shown = []
    for inp, tgt, row, _ in zip(inputs, targets, grid_rows, range(n_rows)):
        # The input fills its row of contexts; the target is then shown on each returned context.
        inp_ctxs = inp.show(ctxs=row, **kwargs)
        shown.append([tgt.show(ctx=c, **kwargs) for c in inp_ctxs])
    return shown

In [None]:
@typedispatch
def show_results(
    x: TensorImageMS,  # Input(s) in the batch
    y: TensorMask,  # Target(s) in the batch
    samples: list,  # List of (`x`, `y`) pairs of length `max_n`
    outs: list,  # List of predicted output(s) from the model
    ctxs=None,  # Must be left as `None`; the grid of contexts is created here
    max_n: int=9,  # Maximum number of `samples` to show
    nrows:int=None,  # Must be left as `None`; one grid row is used per shown sample
    ncols:int=None,  # Must be left as `None`; taken from `x.num_images()`
    figsize=None,  # Passed through to `get_grid_batch`
    **kwargs,  # Passed through to every `show` call
):
    "Show up to `max_n` samples (input, then target) one grid row per sample, then show each predicted output on the same contexts"
    assert nrows is None and ncols is None and ctxs is None
    assert len(samples[0]) == 2 and not hasattr(samples[0], "show")

    # Build the grid exactly once: a separate `get_grid` call whose result is
    # discarded would only leave an additional, empty figure behind.
    nrows, ncols = min(len(samples), max_n), x.num_images()
    ctxs = get_grid_batch(nrows, ncols, figsize=figsize)
    chks = chunk_grid(ctxs, ncols)

    # After this step `ctxs` is a list of rows, each row a list of contexts.
    ctxs = [
        [msk.show(ctx=c, **kwargs) for c in img.show(ctxs=chk, **kwargs)]
        for img, msk, chk, _ in zip(
            samples.itemgot(0), samples.itemgot(1), chks, range(nrows)
        )
    ]
    for i in range(len(outs[0])):
        # Each entry of `ctxs` is a row (list) of contexts, so the prediction is
        # shown on every context of its row instead of receiving the list as `ctx`.
        # NOTE(review): assumes `b.show` takes a single `ctx`, like `msk.show` above.
        ctxs = [
            [b.show(ctx=c, **kwargs) for c in row]
            for b, row, _ in zip(outs.itemgot(i), ctxs, range(max_n))
        ]
    return ctxs