In [None]:
#|default_exp ffcv.loader

In [None]:
#|exporti
# Contains code from:
# fastai - Apache License 2.0 - Copyright (c) 2023 fast.ai

# FFCV Loader
> fastxtend's fastai+FFCV Integrated DataLoader

fastxtend's `Loader` adds fastai features to [FFCV's Loader](https://docs.ffcv.io/making_dataloaders.html), including [<code>one_batch</code>](https://docs.fast.ai/data.core.html#dataloader.one_batch), [<code>show_batch</code>](https://docs.fast.ai/data.core.html#tfmddl.show_batch), [<code>show_results</code>](https://docs.fast.ai/data.core.html#tfmddl.show_results), and support for GPU batch transforms, to name a few.

In [None]:
#|export
from __future__ import annotations

from pathlib import Path
from typing import Mapping, Sequence

import numpy as np

from ffcv.fields.base import Field
from ffcv.loader.loader import Loader as _Loader
from ffcv.loader.loader import OrderOption, ORDER_TYPE, DEFAULT_OS_CACHE, ORDER_MAP
from ffcv.pipeline.compiler import Compiler
from ffcv.pipeline.operation import Operation
from ffcv.transforms.ops import ToDevice as _ToDevice

from fastcore.basics import GetAttr, detuplify, Inf
from fastcore.dispatch import retain_types, explode_types
from fastcore.meta import funcs_kwargs
from fastcore.transform import Pipeline

from fastai.data.core import show_batch, show_results

from fastxtend.ffcv.epoch_iterator import EpochIterator, AsyncEpochIterator
from fastxtend.imports import *

In [None]:
#|export
_all_ = ['OrderOption']

In [None]:
#|hide
from nbdev.showdoc import show_doc

In [None]:
#|exporti
@funcs_kwargs
class BaseDL(GetAttr):
    "Callback hooks shared by DataLoaders which inherit from `BaseLoader`"
    # Names of the hook methods handed to `funcs_kwargs`
    # NOTE(review): presumably this lets callers override each hook via a keyword argument - confirm against fastcore
    _methods = ['before_iter', 'after_batch', 'after_iter']

    def __init__(self, **kwargs):
        # Nothing to initialize in this body; any keyword arguments reaching it are ignored
        return None

    def before_iter(self, x=None, *args, **kwargs):
        "Hook run before `BaseLoader` starts to read/iterate over the dataset. Returns `x` unchanged by default."
        return x

    def after_batch(self, x=None, *args, **kwargs):
        "Hook which receives the collated mini-batch. Returns `x` unchanged by default."
        return x

    def after_iter(self, x=None, *args, **kwargs):
        "Hook run after `BaseLoader` has fully read/iterated over the dataset. Returns `x` unchanged by default."
        return x

In [None]:
#|export
class Loader(BaseDL, _Loader):
    "FFCV `Loader` with fastai Transformed DataLoader `TfmdDL` batch transforms"
    def __init__(self,
        fname:str|Path, # Path to the location of the dataset (FFCV beton format)
        batch_size:int, # Batch size
        num_workers:int=-1, # Number of CPU cores to use in parallel (default: All available up to 16)
        os_cache:bool=DEFAULT_OS_CACHE, # Leverage the OS for caching. Beneficial when there is enough memory to cache the dataset
        order:ORDER_TYPE=OrderOption.SEQUENTIAL, # Dataset traversal order, one of: `SEQUENTIAL`, `RANDOM`, `QUASI_RANDOM`
        distributed:bool=False, # Emulates the behavior of PyTorch's DistributedSampler for distributed training
        seed:int|None=None, # Random seed for batch ordering
        indices:Sequence[int]|None=None, # Subset dataset by returning only these indices
        pipelines:Mapping[str, Sequence[Operation|nn.Module]]={}, # Dictionary defining for each field the sequence of Decoders and transforms to apply
        custom_fields:Mapping[str, Field]={}, # Dictionary informing `Loader` of the types associated to fields that are using a custom type
        drop_last:bool|None=None, # Drop non-full batch in each epoch. Defaults to True unless order is `SEQUENTIAL`
        batches_ahead:int=2, # Number of batches prepared in advance; balances latency and memory
        recompile:bool=False, # Recompile at every epoch. Required if FFCV augmentations change during training
        device:str|int|torch.device|None=None, # Device to place batch. Defaults to fastai's `default_device`
        async_tfms:bool=False, # Asynchronously run `batch_tfms` before batch is drawn.
        n_inp:int|None=None, # Number of inputs to the model. Defaults to pipelines length minus 1
        split_idx:int|None=None, # Apply batch transform(s) to training (0) or validation (1) set. Defaults to valid if order is `SEQUENTIAL`
        do_setup:bool=True, # Run `setup()` for batch transform(s)
        **kwargs
    ):
        # `batch_tfms` is accepted as an alias of `after_batch`; passing both is ambiguous
        if 'batch_tfms' in kwargs:
            if 'after_batch' not in kwargs:
                kwargs['after_batch'] = kwargs.pop('batch_tfms')
            else:
                raise ValueError('Cannot pass both `after_batch` and `batch_tfms` to `Loader`')

        # Sequential order is treated as validation (1), any other order as training (0)
        if split_idx is None:
            self._split_idx = int(order==OrderOption.SEQUENTIAL)
        else:
            self._split_idx = split_idx

        # Wrap the batch transforms in a fastai `Pipeline` which is aware of the train/valid split
        kwargs['after_batch'] = Pipeline(kwargs.get('after_batch', None), split_idx=self._split_idx)
        if do_setup:
            kwargs['after_batch'].setup(self)

        # Asynchronous transforms are only enabled when there is at least one batch transform
        self.async_tfms = async_tfms and len(kwargs['after_batch'].fs) > 0
        # Created on first iteration in `_iter` and reused afterwards
        self.cuda_streams = None

        # Drop the last partial batch unless iterating sequentially
        if drop_last is None:
            drop_last = order != OrderOption.SEQUENTIAL

        _Loader.__init__(self,
            fname=str(Path(fname)),
            batch_size=batch_size,
            num_workers=num_workers,
            os_cache=os_cache,
            order=order,
            distributed=distributed,
            seed=seed,
            indices=indices,
            pipelines=pipelines,
            custom_fields=custom_fields,
            drop_last=drop_last,
            batches_ahead=batches_ahead,
            recompile=recompile
        )
        BaseDL.__init__(self, **kwargs)

        # Goes through the `device` property setter, which also updates the pipelines
        if device is None:
            self.device = default_device()
        else:
            self.device = device

        if n_inp is None:
            self._n_inp = len(pipelines) - 1
        else:
            self._n_inp = n_inp

        # Warn about fastai DataLoader arguments which this `Loader` never calls
        for name in ['item_tfms', 'after_item', 'before_batch']:
            if name in kwargs:
                if name != 'before_batch':
                    msg = f"fastxtend's `Loader` will not call any {name} methods. " \
                          f"{name} is for use with a fastai DataLoader.\n" \
                          f"Instead of passing fastai Item Transforms to {name}, " \
                          f"initialize the fastxtend `Loader` pipeline with FFCV transforms."
                else:
                    msg = f"fastxtend's `Loader` will not call any {name} methods. " \
                          f"{name} are for use with a fastai DataLoader."
                warn(msg)


    def one_batch(self, batches_ahead:bool=False):
        "Return one processed batch of input(s) and target(s), optionally loading `batches_ahead`"
        for b in self._n_batches(self.batches_ahead + 2 if batches_ahead else 1):
            # need to return the yield from _n_batches so `Loader` can reset to iterate the entire epoch
            pass
        # The last batch drawn is the one returned
        return b

    def show_batch(self,
        b=None, # Batch to show
        max_n:int=9, # Maximum number of items to show
        ctxs=None, # List of `ctx` objects to show data. Could be matplotlib axis, DataFrame etc
        show:bool=True, # Whether to display data
        unique:bool=False, # Whether to show only one
        **kwargs
    ):
        "Show `max_n` input(s) and target(s) from the batch."
        if unique:
            old_get_idxs = self.get_idxs
            self.get_idxs = lambda: Inf.zeros
        try:
            if b is None:
                b = self.one_batch()
            if not show:
                return self._pre_show_batch(b, max_n=max_n)
            # Uses Type Dispatch to call the correct `show_batch` for b
            show_batch(*self._pre_show_batch(b, max_n=max_n), ctxs=ctxs, max_n=max_n, **kwargs)
        finally:
            # Always undo the temporary `get_idxs` patch, including on the `show=False` early return
            if unique:
                self.get_idxs = old_get_idxs

    def show_results(self,
        b, # Batch to show results for
        out, # Predicted output from model for the batch
        max_n:int=9, # Maximum number of items to show
        ctxs=None, # List of `ctx` objects to show data. Could be matplotlib axis, DataFrame etc
        show:bool=True, # Whether to display data
        **kwargs
    ):
        "Show `max_n` results with input(s), target(s) and prediction(s)."
        x,y,its = self.show_batch(b, max_n=max_n, show=False)
        # Build a batch with the model inputs followed by the predictions in place of the targets
        b_out = type(b)(b[:self.n_inp] + (tuple(out) if is_listy(out) else (out,)))
        x1,_,outs = self.show_batch(b_out, max_n=max_n, show=False)
        if its is None:
            res = (x, x1, None, None)
        else:
            res = (x, y, its, outs.itemgot(slice(self.n_inp,None)))
        if not show:
            return res
        # Uses Type Dispatch to call the correct `show_results` for b
        show_results(*res, ctxs=ctxs, max_n=max_n, **kwargs)

    @property
    def n_inp(self) -> int:
        "Number of elements in a batch for model input"
        return self._n_inp

    @property
    def device(self):
        "Device which batches and batch transforms are placed on"
        return self._device

    @device.setter
    def device(self, device:int|str|torch.device):
        # parse device
        device, *_ = torch._C._nn._parse_to(device=device)
        self._device = device
        # Device setter for FFCV Pipeline
        for p in self.pipeline_specs.values():
            for t in p.transforms:
                if isinstance(t, _ToDevice):
                    t.device = device
        # Device setter for Loader.batch_tfms
        if hasattr(self.after_batch, 'fs'):
            self._pipeline_device(self.after_batch.fs)

    def to(self, device:int|str|torch.device):
        "Sets `self.device=device`."
        self.device = device
        return self

    @property
    def split_idx(self):
        "Split index used by the fastai batch transforms: train (0) or valid (1)"
        return self._split_idx

    @split_idx.setter
    def split_idx(self, split_idx:int):
        "Sets fastai batch transforms to train (split_idx=0) or valid (split_idx=1)"
        self._split_idx = split_idx
        if isinstance(self.after_batch, Pipeline):
            self.after_batch.split_idx = split_idx

    def decode(self, b):
        "Decode batch `b`"
        return to_cpu(self.after_batch.decode(self._retain_dl(b)))

    def decode_batch(self, b, max_n:int=9):
        "Decode up to `max_n` input(s) from batch `b`"
        return self._decode_batch(self.decode(b), max_n)

    def _pipeline_device(self, pipe):
        "Device setter for fastai pipeline"
        for tfm in pipe:
            if hasattr(tfm, 'to') and callable(tfm.to):
                tfm.to(self.device, non_blocking=True)
            else:
                # Otherwise move each attribute named in the transform's `parameters` list, if it has one
                for a in L(getattr(tfm, 'parameters', None)):
                    setattr(tfm, a, getattr(tfm, a).to(self.device, non_blocking=True))

    def _iter(self):
        "Create the epoch iterator for the next traversal order"
        Compiler.set_num_threads(self.num_workers)
        order = self.next_traversal_order()
        # Keep only as many samples as `len(self)` batches can hold
        selected_order = order[:len(self) * self.batch_size]
        self.next_epoch += 1

        # Compile at the first epoch
        if self.code is None or self.recompile:
            self.generate_code()

        # Asynchronous transforms require using the same Cuda streams for the entire run
        if self.cuda_streams is None:
            self.cuda_streams = [(torch.cuda.Stream() if torch.cuda.is_available() else None)
                                  for _ in range(self.batches_ahead + 2)]
        if self.async_tfms:
            return AsyncEpochIterator(self, selected_order, self.after_batch)
        else:
            return EpochIterator(self, selected_order)

    def __iter__(self):
        "Iterate over one epoch, calling the `before_iter`, `after_batch` and `after_iter` hooks"
        self.before_iter()
        if self.async_tfms:
            # `after_batch` is handed to the asynchronous iterator in `_iter`, so it is not called here
            yield from self._iter()
        else:
            for b in self._iter():
                yield self.after_batch(b)
        self.after_iter()
        if hasattr(self, 'it'):
            del(self.it)

    def _one_pass(self, b=None):
        "Record the types of batch `b`, drawing one batch if `b` is None"
        if b is None:
            b = self.one_batch()
        self._types = explode_types(b)

    def _retain_dl(self, b):
        "Restore the recorded batch types on `b`, recording them from `b` first if needed"
        if not getattr(self, '_types', None):
            self._one_pass(b)
        return retain_types(b, typs=self._types)

    def _decode_batch(self, b, max_n=9):
        "Split decoded batch `b` into at most `max_n` samples"
        return L(batch_to_samples(b, max_n=max_n))

    def _pre_show_batch(self, b, max_n=9):
        "Decode `b` to be ready for `show_batch`"
        b = self.decode(b)
        if hasattr(b, 'show'):
            return b,None,None
        its = self._decode_batch(b, max_n)
        if not is_listy(b):
            b,its = [b],L((o,) for o in its)
        return detuplify(b[:self.n_inp]),detuplify(b[self.n_inp:]),its

    def _n_batches(self, num_batches:int=1):
        "Yield `num_batches` batches from a temporary subset, then restore the `Loader` state"
        orig_traversal_order = self.traversal_order
        orig_indices = self.indices
        orig_drop_last = self.drop_last
        orig_next_epoch = self.next_epoch

        try:
            # Set Loader to only return `num_batches` batches per epoch
            num_samples = self.batch_size*num_batches
            if self._args['order'] == OrderOption.SEQUENTIAL:
                # Never request more samples than the dataset holds
                self.indices = np.arange(0, min(num_samples, self.reader.num_samples))
            else:
                # `randint` excludes the upper bound, so every index is a valid sample
                self.indices = np.random.randint(0, self.reader.num_samples, num_samples)
            self.traversal_order = ORDER_MAP[OrderOption.SEQUENTIAL](self)
            self.drop_last = False

            # yield num_batches
            yield from self.__iter__()
        finally:
            # Reset Loader state to its original status, even if iteration stops early or raises
            self.next_epoch = orig_next_epoch
            self.indices = orig_indices
            self.drop_last = orig_drop_last
            self.traversal_order = orig_traversal_order

Important `Loader` arguments:

- `order`: Controls how much memory is used for dataset caching and whether the dataset is randomly shuffled. Can be one of `RANDOM`, `QUASI_RANDOM`, or `SEQUENTIAL`. See the note below for more details. Defaults to `SEQUENTIAL`, which is unrandomized.

- `os_cache`: By default, FFCV will attempt to cache the entire dataset into RAM using the operating system's caching. This can be changed by setting `os_cache=False` or setting the environment variable 'FFCV_DEFAULT_CACHE_PROCESS' to "True" or "1". If `os_cache=False` then `order` must be set to `QUASI_RANDOM` for the training `Loader`.

- `num_workers`: If not set, will use all CPU cores up to 16 by default.

- `batches_ahead`: Controls the number of batches ahead the `Loader` works. Increasing uses more RAM, both CPU and GPU. Defaults to 2.

- `n_inp`: Controls which inputs to pass to the model. By default, set to number of pipelines minus 1.

- `drop_last`: Whether to drop the last partial batch. By default, will set to True if `order` is `RANDOM` or `QUASI_RANDOM`, False if `SEQUENTIAL`.

- `device`: The device to place the processed batches of data on. Defaults to `fastai.torch_core.default_device` if not set.

- `async_tfms`: Asynchronously apply `batch_tfms` before the batch is drawn. Can accelerate training if GPU compute isn't fully saturated (95% or less) or if only using `IntToFloatTensor` and `Normalize`.

- `split_idx`: This tells the fastai batch transforms what dataset they are operating on. By default will use 0 (train) if `order` is `RANDOM` or `QUASI_RANDOM`, 1 (valid) if `SEQUENTIAL`.

- `distributed`: For distributed training on multiple GPUs. Emulates the behavior of PyTorch's [`DistributedSampler`](https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler). `QUASI_RANDOM` is unavailable with distributed training.

:::{.callout-note collapse="false"}
#### Note: Order Memory Usage
Each `order` option requires differing amounts of system memory.

- `RANDOM` caches the entire dataset in memory for fast random sampling. `RANDOM` uses the most memory.

- `QUASI_RANDOM` caches a subset of the dataset at a time in memory and randomly samples from the subset. Use when the entire dataset cannot fit into memory.

- `SEQUENTIAL` requires the least memory. It loads a few samples ahead of time. As the name suggests, it is not random, and is primarily for validation.
:::

Asynchronous batch transforms can accelerate training by decreasing the draw time at the expense of a slightly longer batch step. If the GPU isn't fully saturated, usually 95% or less compute use, this will be a net gain in training performance. `async_tfms=True` pairs well with `ProgressiveResize`, as the GPU is almost never saturated when training on smaller than full-sized images. When near or fully saturated, asynchronous batch transforms usually result in a wash in training time.

In [None]:
show_doc(Loader.one_batch)

In [None]:
show_doc(Loader.show_batch)

In [None]:
show_doc(Loader.show_results)

In [None]:
show_doc(Loader.to)

In [None]:
show_doc(Loader.n_inp)

In [None]:
show_doc(Loader.decode)

In [None]:
show_doc(Loader.decode_batch)