In [1]:
from __future__ import annotations

from abc import ABC, abstractmethod
from collections import deque
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Iterator, Optional, Tuple

import jax
import jax.numpy as jnp
import jax.random as random
import numpy as np

import deepscratch

# Dataset

Datasets in this package have a length and support integer indexing, so the user can inspect individual observations. They are also one-time iterable: they yield observations sequentially until `StopIteration` is raised, at which point the dataset is exhausted.

In [2]:
class Dataset(ABC):
    """Abstract interface for an indexable collection of observations.

    Concrete datasets must report their size via ``len()`` and return a single
    observation for an integer position via ``dataset[index]``.
    """

    @abstractmethod
    def __getitem__(self, index: int) -> Any:
        """Return the observation stored at position ``index``."""

    @abstractmethod
    def __len__(self) -> int:
        """Return the number of observations in the dataset."""

## DataLoader

In [None]:
class DataLoader:
    """
    A PyTorch-like DataLoader for JAX with multi-threaded batching.

    Iterating the loader performs one pass over ``dataset`` and yields batches.
    Two fetch strategies are available:

    * ``iobound=True`` (default): items are fetched with ``dataset[i]`` on a
      thread pool of ``num_workers`` threads. Each item must unpack into an
      ``(image, label)`` pair, and each batch is
      ``(np.stack(images), np.stack(labels))``. Only a bounded window of items
      is fetched ahead of the consumer, so memory use does not grow with the
      dataset size, and stopping iteration early does not read the rest.
    * ``iobound=False``: each batch is produced by ``jax.vmap`` over the
      dataset's ``_jitgetitem`` method applied to a slice of indices; the
      batch is whatever that vmapped call returns.

    Args:
        dataset: Indexable source of observations that supports ``len()``.
        batch_size: Number of items per batch; must be >= 1.
        shuffle: If True, permute the index order at the start of every pass
            using a JAX PRNG key seeded from ``seed``.
        drop_last: If True, drop a final batch smaller than ``batch_size``.
        num_workers: Number of threads used by the ``iobound`` path.
        seed: Seed for the shuffling PRNG; a random seed is drawn if None.
        iobound: Select the threaded path (True) or the vmap path (False).

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """

    def __init__(
        self,
        dataset: Dataset,
        batch_size: int = 1,
        shuffle: bool = False,
        drop_last: bool = False,
        num_workers: int = 1,
        seed: Optional[int] = None,
        iobound: bool = True
    ):
        if batch_size < 1:
            # A non-positive size would make __len__ divide by zero and the
            # threaded path emit everything as one batch.
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.num_workers = num_workers
        self.seed = seed if seed is not None else np.random.randint(0, int(1e6))
        self.indices = jnp.arange(len(dataset))
        self.rng_key = random.PRNGKey(self.seed)
        self.iobound = iobound

    def _fetch_item(self, index: int) -> Any:
        """Return ``dataset[index]``, converting ``index`` to a Python int first."""
        return self.dataset[int(index)]

    def __iter__(self) -> Iterator[Tuple[np.ndarray, np.ndarray]]:
        """Yield the batches of one pass over the dataset (see class docstring)."""
        if self.shuffle:
            # Split so every pass draws a fresh permutation while the sequence
            # of passes stays reproducible from ``seed``.
            self.rng_key, subkey = random.split(self.rng_key)
            self.indices = random.permutation(subkey, self.indices)

        if self.iobound:
            yield from self._iter_threaded()
        else:
            yield from self._iter_vmapped()

    def _iter_threaded(self) -> Iterator[Tuple[np.ndarray, np.ndarray]]:
        """Fetch items on a thread pool with bounded look-ahead and stack them into batches."""
        # Convert the order to a host array once; iterating the jax array
        # directly would create one device scalar per item.
        order = np.asarray(self.indices)
        total = len(order)
        # Heuristic look-ahead: enough in-flight fetches to keep every worker
        # busy and the next batch under way, without queueing the whole
        # dataset the way ``executor.map`` does (it submits all tasks eagerly).
        window = 2 * max(self.batch_size, self.num_workers)

        pending: deque = deque()
        cursor = 0
        batch_images = []
        batch_labels = []
        with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
            try:
                while cursor < total or pending:
                    # Top up the window of in-flight fetches.
                    while cursor < total and len(pending) < window:
                        pending.append(executor.submit(self._fetch_item, order[cursor]))
                        cursor += 1
                    # Futures are consumed in submission order, so batches follow
                    # ``order`` exactly; a worker's exception is re-raised here.
                    image, label = pending.popleft().result()
                    batch_images.append(image)
                    batch_labels.append(label)
                    if len(batch_images) == self.batch_size:
                        yield np.stack(batch_images), np.stack(batch_labels)
                        batch_images = []
                        batch_labels = []
            finally:
                # On early exit (consumer stopped iterating, or a fetch raised),
                # cancel fetches that have not started so the executor's
                # shutdown only waits for those already running.
                for future in pending:
                    future.cancel()

        if batch_images and not self.drop_last:
            yield np.stack(batch_images), np.stack(batch_labels)

    def _iter_vmapped(self) -> Iterator[Any]:
        """Yield batches by vmapping the dataset's ``_jitgetitem`` over index slices."""
        # NOTE(review): ``_jitgetitem`` is not declared on ``Dataset``; confirm
        # that every dataset used with ``iobound=False`` provides it.
        batched_getitem = jax.vmap(self.dataset._jitgetitem)
        for i in range(len(self)):
            # The final slice may be shorter when drop_last is False.
            batch_indices = self.indices[i * self.batch_size : (i + 1) * self.batch_size]
            yield batched_getitem(batch_indices)

    def __len__(self) -> int:
        """Return the number of batches one pass yields, honoring ``drop_last``."""
        size = len(self.dataset) // self.batch_size
        if not self.drop_last and len(self.dataset) % self.batch_size != 0:
            size += 1
        return size