# databatcher

> Iterable that will break a long data tensor into batches of samples.

In [None]:
#| default_exp common

In [None]:
# | hide
%load_ext autoreload
%autoreload 2

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import math

In [None]:
#| export
from fastcore.test import *
import torch


In [None]:
#| export
class DataBatcher:
    """Iterable that will break a long data tensor into batches of samples.

    Samples are sliding windows of `sample_len` consecutive elements, taken
    every `stride` elements from `data` with `Tensor.unfold`. Trailing elements
    that do not fill a complete window are dropped. Iterating yields 2D tensors
    of shape `(batch_size, sample_len)` with `batch_size <= max_batch_size`;
    only the final batch can be smaller. The object can be iterated repeatedly.

    Args:
        data: 1D tensor to cut into samples.
        sample_len: Number of elements in each sample (must be >= 1).
        max_batch_size: Maximum number of samples per batch (must be >= 1).
        stride: Offset between the starts of consecutive samples (must be >= 1).

    Raises:
        ValueError: If `sample_len`, `max_batch_size` or `stride` is < 1.
        AssertionError: If `data` is not 1D or is shorter than `sample_len`.
    """
    def __init__(
        self, data: torch.Tensor, sample_len: int, max_batch_size: int, stride: int
    ):
        # Reject non-positive sizes up front: otherwise max_batch_size=0 only
        # fails later (ZeroDivisionError in __len__, ValueError in __iter__), a
        # negative max_batch_size silently yields nothing, and a bad stride
        # surfaces as an opaque error from `unfold`.
        for arg_name, arg_value in (
            ("sample_len", sample_len),
            ("max_batch_size", max_batch_size),
            ("stride", stride),
        ):
            if arg_value < 1:
                raise ValueError(
                    f"{arg_name} must be a positive integer, got {arg_value!r}"
                )

        # These two checks used to be `assert` statements. They are raised
        # explicitly so they still run under `python -O`, and keep the
        # AssertionError type and messages for backward compatibility.
        if len(data.shape) != 1:
            raise AssertionError("Data must be a 1D tensor")
        if len(data) < sample_len:
            raise AssertionError("Data length must be at least sample_len")

        # Shape (num_samples, sample_len); row k is data[k*stride : k*stride + sample_len].
        self.samples = data.unfold(0, sample_len, stride)
        self.sample_len = sample_len
        self.max_batch_size = max_batch_size

    def __len__(self) -> int:
        """Returns the number of batches that will be produced."""
        # Round up so that a trailing partial batch is counted.
        return math.ceil(len(self.samples) / self.max_batch_size)

    def __iter__(self):
        """Yields consecutive batches of at most `max_batch_size` samples."""
        for start in range(0, len(self.samples), self.max_batch_size):
            # Slicing clamps at the end, so the last batch may be shorter.
            yield self.samples[start : start + self.max_batch_size]

In [None]:
# Tests for DataBatcher

# Overlapping samples: stride=1 shifts each window by a single element
data_batcher = DataBatcher(data=torch.arange(6), sample_len=3, max_batch_size=2, stride=1)
expected = [
    [[0, 1, 2], [1, 2, 3]],
    [[2, 3, 4], [3, 4, 5]],
]
test_eq(len(data_batcher), 2)
test_eq(list(data_batcher), expected)

# stride=2: windows start at 0, 2, 4, 6; the trailing element 9 fits no window
data_batcher = DataBatcher(data=torch.arange(10), sample_len=3, max_batch_size=2, stride=2)
expected = [
    [[0, 1, 2], [2, 3, 4]],
    [[4, 5, 6], [6, 7, 8]],
]
test_eq(len(data_batcher), 2)
test_eq(list(data_batcher), expected)

# No repeated elements: stride == sample_len
data_batcher = DataBatcher(data=torch.arange(6), sample_len=3, max_batch_size=2, stride=3)
expected = [
    [[0, 1, 2], [3, 4, 5]],
]
test_eq(len(data_batcher), 1)
test_eq(list(data_batcher), expected)

# Three samples with max_batch_size=2: the final batch holds a single sample
data_batcher = DataBatcher(data=torch.arange(7), sample_len=3, max_batch_size=2, stride=2)
expected = [
    [[0, 1, 2], [2, 3, 4]],
    [[4, 5, 6]],
]
test_eq(len(data_batcher), 2)
test_eq(list(data_batcher), expected)

# Data holds exactly one sample, so the only batch is not full
data_batcher = DataBatcher(data=torch.arange(3), sample_len=3, max_batch_size=2, stride=2)
expected = [
    [[0, 1, 2]],
]
test_eq(len(data_batcher), 1)
test_eq(list(data_batcher), expected)

# Like the arange(7) case, plus one trailing element that cannot start another sample
data_batcher = DataBatcher(data=torch.arange(8), sample_len=3, max_batch_size=2, stride=2)
expected = [
    [[0, 1, 2], [2, 3, 4]],
    [[4, 5, 6]],
]
test_eq(len(data_batcher), 2)
test_eq(list(data_batcher), expected)

# sample_len equals the data length: exactly one sample fits
data_batcher = DataBatcher(data=torch.arange(8), sample_len=8, max_batch_size=2, stride=2)
expected = [
    [[0, 1, 2, 3, 4, 5, 6, 7]],
]
test_eq(len(data_batcher), 1)
test_eq(list(data_batcher), expected)

# stride larger than the data: only the first window fits
data_batcher = DataBatcher(data=torch.arange(10), sample_len=3, max_batch_size=2, stride=12)
expected = [
    [[0, 1, 2]],
]
test_eq(len(data_batcher), 1)
test_eq(list(data_batcher), expected)


In [None]:
#| hide
# Run nbdev's export step, which writes the `#| export` cells to the library module.
import nbdev

nbdev.nbdev_export()