In [None]:
# default_exp data.utils

In [None]:
# hide
%load_ext autoreload
%autoreload 2

# Data utils
> Utility functions to manipulate data.

In [None]:
# hide
from nbdev.showdoc import *

In [None]:
# export
from grade_classif.data.loaders import ImageLoader
from grade_classif.imports import *
from torch.utils.data import Sampler

In [None]:
# export
def np_to_tensor(x: NDArray[(Any, ...), Number], tensor_type: str) -> torch.Tensor:
    """Turn a numpy array into a `torch.Tensor`.

    Args:
        x: array to convert.
        tensor_type: kind of data held by `x`. For `"image"` and `"slide"` the
            axes are reordered with `transpose(2, 0, 1)` (presumably
            channel-last input turned into channel-first; confirm with
            callers) and `uint8` data is cast to `float32` and divided by 255.
            Any other value leaves the array untouched before conversion.

    Returns:
        A tensor built from the (possibly reordered and rescaled) array.
    """
    if tensor_type in ("image", "slide"):
        # Reorder the axes first, then rescale: same order as always so the
        # memory layout handed to torch does not change.
        x = x.transpose(2, 0, 1)
        if x.dtype == np.uint8:
            x = x.astype(np.float32) / 255
    return torch.tensor(x)

Convert a numpy ndarray into a tensor. If `tensor_type` is `'image'` or `'slide'`, put the channel axis first and rescale `uint8` values to `float32` in `[0, 1]`.

In [None]:
# export
def show_img(
    x: NDArray[(Any, Any, 3), Number],
    ax: Optional[Axes] = None,
    figsize: Tuple[int, int] = (3, 3),
    title: Optional[str] = None,
    hide_axis: bool = True,
    cmap: str = "viridis",
    alpha: Optional[float]=None,
    **kwargs
) -> Axes:
    """Draw an image on a matplotlib axes and return that axes.

    Args:
        x: image data handed to `ax.imshow`.
        ax: axes to draw on; when `None`, a new figure of size `figsize` is
            created with `plt.subplots`.
        figsize: size of the figure created when `ax` is `None`.
        title: axes title; nothing is set when it is `None` or empty.
        hide_axis: when true, the axis decorations are turned off.
        cmap: colormap forwarded to `imshow`.
        alpha: transparency forwarded to `imshow`.
        **kwargs: any extra keyword arguments forwarded to `imshow`.

    Returns:
        The axes the image was drawn on.
    """
    if ax is None:
        # Only the axes is needed; the figure stays reachable through it.
        _, ax = plt.subplots(figsize=figsize)
    ax.imshow(x, cmap=cmap, alpha=alpha, **kwargs)
    if hide_axis:
        ax.axis("off")
    if title:
        ax.set_title(title)
    return ax

Convenience function for plotting images.

In [None]:
# export
def load_batches(
    folder: Path,
    bs: int = 16,
    device: str = "cpu",
    filt: Optional[Callable[[Path], bool]] = None,
) -> Iterable[torch.Tensor]:
    """Lazily load the images of `folder` as stacked batches on `device`.

    Args:
        folder: directory whose entries are loaded with `ImageLoader`.
        bs: maximum number of images per batch; must be at least 1.
        device: device each image tensor is moved to before stacking.
        filt: optional predicate; entries for which it returns a falsy value
            are skipped.

    Yields:
        Tensors produced by `torch.stack` over up to `bs` images. Every batch
        holds exactly `bs` images except possibly the last one.

    Raises:
        ValueError: if `bs` is smaller than 1 (raised on first iteration).
    """
    if bs < 1:
        raise ValueError(f"bs must be a positive integer, got {bs}")
    batch = []
    image_loader = ImageLoader()
    for fn in folder.iterdir():
        if filt is not None and not filt(fn):
            continue
        img = image_loader(fn)
        batch.append(np_to_tensor(img, "image").to(device))
        # Flush only after the current image has been added; flushing before
        # loading it would drop one file per full batch.
        if len(batch) == bs:
            yield torch.stack(batch)
            batch = []
    if batch:
        # Remaining images that did not fill a whole batch.
        yield torch.stack(batch)

Generator function that loads images from `folder` in batches of size `bs` on `device`.

In [None]:
# export
class LabelSlideBalancedRandomSampler(Sampler[int]):
    """Random sampler that balances draws over labels, then over slides.

    Each sample is drawn in three uniform steps: a label, then a slide among
    those that have patches with that label, then a patch of that slide with
    that label.

    Args:
        labels: label of every patch.
        patch_slides: slide identifier of every patch; same length as `labels`.
        num_samples: number of indices produced per iteration.
        replacement: when false, a drawn patch is removed from the pool and
            cannot be drawn again during the same iteration.
        generator: optional `torch.Generator` passed to `torch.rand`.
    """

    def __init__(
        self,
        labels: Sequence[str],
        patch_slides: Sequence[str],
        num_samples: int,
        replacement: bool = True,
        generator=None,
    ):
        assert len(labels) == len(
            patch_slides
        ), "labels and slides must have same length"
        # Labels and slides are re-encoded as integer codes 0..k-1; only the
        # codes are kept.
        self.classes, self.labels = np.unique(labels, return_inverse=True)
        self.classes = np.arange(len(self.classes))
        self.slides, self.patch_slides = np.unique(patch_slides, return_inverse=True)
        self.slides = np.arange(len(self.slides))
        self.num_samples = num_samples
        self.replacement = replacement
        self.generator = generator

    def get_idxs(self):
        """Draw `num_samples` patch indices.

        Returns:
            A list of integer positions into the original `labels` sequence.

        Raises:
            ValueError: if samples are requested from an empty dataset, or if
                more samples than patches are requested without replacement.
        """
        n_patches = len(self.labels)
        if self.num_samples > 0 and n_patches == 0:
            raise ValueError("cannot draw samples from an empty set of patches")
        if not self.replacement and self.num_samples > n_patches:
            raise ValueError(
                f"cannot draw {self.num_samples} samples without replacement "
                f"from {n_patches} patches"
            )

        # tree[class][slide] -> indices of the patches of that slide that carry
        # that class. The class mask is required: without it a slide holding
        # several labels would expose all of its patches under every class.
        tree = {}
        for cl in self.classes:
            in_class = self.labels == cl
            tree[cl] = {
                slide: np.flatnonzero(in_class & (self.patch_slides == slide)).tolist()
                for slide in np.unique(self.patch_slides[in_class])
            }

        idxs = []
        for _ in range(self.num_samples):
            # One uniform number per level: class, slide, patch.
            x = torch.rand(3, generator=self.generator)
            # Rebuilt on every draw because exhausted entries are removed when
            # sampling without replacement.
            classes = list(tree.keys())
            cl = classes[int(x[0] * len(classes))]
            cl_slides = tree[cl]
            slides = list(cl_slides.keys())
            slide = slides[int(x[1] * len(slides))]
            slide_patches = cl_slides[slide]
            idx = int(x[2] * len(slide_patches))
            if self.replacement:
                patch = slide_patches[idx]
            else:
                patch = slide_patches.pop(idx)
                # Prune emptied branches so they can no longer be selected.
                if len(slide_patches) == 0:
                    cl_slides.pop(slide)
                    if len(cl_slides) == 0:
                        tree.pop(cl)
            idxs.append(patch)
        return idxs

    def __iter__(self):
        """Iterate over a freshly drawn list of indices."""
        return iter(self.get_idxs())

    def __len__(self):
        """Number of indices produced per iteration."""
        return self.num_samples

In [None]:
#hide
from nbdev.export import notebook2script
# Run nbdev's notebook-to-script export; the recorded output of this cell lists
# every notebook of the project as converted.
notebook2script()

Converted 00_core.ipynb.
Converted 01_train.ipynb.
Converted 02_predict.ipynb.
Converted 10_data.read.ipynb.
Converted 11_data.loaders.ipynb.
Converted 12_data.dataset.ipynb.
Converted 13_data.utils.ipynb.
Converted 14_data.transforms.ipynb.
Converted 15_data.color.ipynb.
Converted 16_data.modules.ipynb.
Converted 20_models.plmodules.ipynb.
Converted 21_models.modules.ipynb.
Converted 22_models.utils.ipynb.
Converted 23_models.hooks.ipynb.
Converted 24_models.metrics.ipynb.
Converted 25_models.losses.ipynb.
Converted 80_params.defaults.ipynb.
Converted 81_params.parser.ipynb.
Converted 99_index.ipynb.
