# Generic Samplers

We create `Sampler` classes for recursive data types, in particular `Mapping` and `Sequence`/`Collection`.


In [None]:
# Notebook setup: output display, module autoreloading, plotting and logging.
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # display the last expression, or the value of a trailing assignment.
%config InlineBackend.figure_format = 'svg'
# Enable IPython's autoreload extension in mode 2 so edited modules are picked up.
%load_ext autoreload
%autoreload 2
%matplotlib inline

import logging

# Configure the root logger to emit messages at INFO level and above.
logging.basicConfig(level=logging.INFO)

In [None]:
import numpy as np

# Unseeded generator: results are not reproducible across sessions.
rng = np.random.default_rng()
# Print floats with exactly 4 decimals and without scientific notation for tiny values.
np.set_printoptions(suppress=True, floatmode="fixed", precision=4)

In [None]:
import logging
from collections.abc import Iterator, Mapping
from itertools import chain
from typing import Iterator, Optional, Union

from pandas import Series

from tsdm.utils.types import KeyType, ValueType

In [None]:
from torch.utils.data import Sampler

In [None]:
class MappingSampler(Sampler):
    r"""Sample from a Mapping, either its keys directly or via per-key subsamplers.

    Without ``subsamplers``, iteration yields the keys of ``data_source``, each once.
    With ``subsamplers``, iteration yields ``(key, index)`` pairs, where ``index`` is
    the next value produced by iterating ``subsamplers[key]``.

    Args:
        data_source: The mapping to sample from.
        subsamplers: Optional per-key samplers; must contain every key of
            ``data_source`` (extra keys are ignored).
        shuffle: Whether to sample in random order. If ``False``, samples are
            produced in key order, grouped by key.
        early_stop: If ``True``, draw only ``min(sizes)`` samples from every key,
            so that all keys are represented equally often.

    Raises:
        KeyError: If ``subsamplers`` is given but lacks a key of ``data_source``.
    """

    early_stop: bool = False
    r"""Whether to draw only ``min(sizes)`` samples from each key."""
    shuffle: bool = True
    r"""Whether to sample in random order."""
    sizes: Series
    r"""The sizes of the subsamplers (1 per key when no subsamplers are given)."""
    partition: Series
    r"""Contains each key as many times as it is sampled in one pass."""
    subsamplers: Optional[Mapping[KeyType, Sampler]]
    r"""The subsamplers to sample from the collection, or ``None``."""

    def __init__(
        self,
        data_source: Mapping[KeyType, ValueType],
        subsamplers: Optional[Mapping[KeyType, Sampler]] = None,
        shuffle: bool = True,
        early_stop: bool = False,
    ):
        super().__init__(data_source)
        self.data = data_source
        self.shuffle = shuffle
        self.early_stop = early_stop
        # Live view of the data keys; defines which keys are sampled.
        self.idx = data_source.keys()

        if subsamplers is None:
            # No subsamplers: every key is itself one sample, drawn exactly once.
            self.subsamplers = None
            sizes = {key: 1 for key in self.idx}
        else:
            self.subsamplers = dict(subsamplers)
            missing = [key for key in self.idx if key not in self.subsamplers]
            if missing:
                raise KeyError(f"No subsampler given for keys {missing}.")
            sizes = {key: len(self.subsamplers[key]) for key in self.idx}

        self.sizes = Series(sizes, dtype=int)

        if early_stop:
            # Sample min(sizes) from each key; `default` handles an empty mapping.
            n_min = min(sizes.values(), default=0)
            counts = {key: n_min for key in sizes}
        else:
            # Sample every element of each subsampler.
            counts = sizes
        self.partition = Series(
            list(chain.from_iterable([key] * count for key, count in counts.items()))
        )

    def __len__(self) -> int:
        r"""Return the number of samples produced by one pass of ``__iter__``."""
        if self.subsamplers is None:
            return len(self.data)
        # The partition lists exactly the keys drawn in one pass.
        return len(self.partition)

    def __iter__(self) -> Iterator:
        r"""Yield samples, in random order if ``shuffle`` is set.

        Without subsamplers, yields each key of ``data_source`` once.
        With subsamplers, yields ``(key, index)`` pairs. When ``early_stop=True``
        exactly ``min(sizes) * len(data_source)`` pairs are produced, otherwise
        ``sum(sizes)``. A subsampler that yields fewer items than its ``len``
        makes the generator raise ``RuntimeError`` (PEP 479).
        """
        if self.subsamplers is None:
            yield from self._ordered(list(self.idx))
            return

        active_iterators = {
            key: iter(sampler) for key, sampler in self.subsamplers.items()
        }
        # `tolist` returns the original key objects rather than numpy scalars.
        for key in self._ordered(self.partition.tolist()):
            yield key, next(active_iterators[key])

    def _ordered(self, items: list) -> list:
        r"""Return ``items`` randomly permuted if ``shuffle`` is set, else unchanged."""
        if not self.shuffle:
            return items
        # Permute positions rather than values so key objects are preserved as-is.
        return [items[i] for i in np.random.permutation(len(items))]

    def __getitem__(self, key: KeyType) -> ValueType:
        r"""Return the data for the given key."""
        return self.data[key]

In [None]:
class HierarchicalSampler(Sampler):
    """Sampler for hierarchical (nested) data.

    Two construction modes are described:

    - Initialize with a nested dictionary; this can handle arbitrary nesting.
    - Or specify which subsampler to use; this can handle product types and
      allows parametrization as input.

    NOTE(review): the visible class body only declares the annotations below;
    any sampling logic presumably lives further down or is not written yet.
    """

    # `nested_dict` is not imported or defined in the visible source. Class-level
    # annotations are evaluated when the class is created, so they are quoted here
    # to avoid a NameError.
    index: "nested_dict"
    samplers: "nested_dict"