# Map-Reduce

Many data processing operations become neat and scalable when expressed as **MapReduce** operations.

[MapReduce](https://en.wikipedia.org/wiki/MapReduce) is a programming model in which operations over a set of data entries (**rows** in a **table**) are represented as a sequence of **map** and **reduce** operations. Almost any data transformation can be expressed this way. The operations themselves are easy to parallelize, so you can process huge amounts of data by running them on a distributed system, for example, [YTSaurus](https://ytsaurus.tech/).

What it means:
1. **Map** is a function $f$ that takes a row $x_i$ of the table $T$ and maps it to a set of new rows: $f(x_i) = \{y_1, \ldots, y_{j_i}\}$. The function is applied to every row in the table and the results are flattened into one table.
2. **Reduce** is a function $g$ that takes a *key* $k$ and the set $X_k$ of all rows that share this key and returns a set of new rows: $g(k, X_k) = \{y_1, \ldots, y_{j_k}\}$.
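To make the two definitions concrete, here is a minimal sketch on a toy table. The names `run_map_reduce`, `map_orders`, and `reduce_total`, as well as the `city` and `amount` columns, are invented for illustration. It runs in a single process, so it only shows the data flow, not the parallelism.

```python
from collections import defaultdict
from typing import Callable, Hashable, Iterable, Iterator

Row = dict


def run_map_reduce(
    rows: Iterable[Row],
    map_fn: Callable[[Row], Iterator[tuple[Hashable, Row]]],
    reduce_fn: Callable[[Hashable, list[Row]], Iterator[Row]],
) -> list[Row]:
    # Map: every input row may emit any number of (key, row) pairs.
    groups: dict[Hashable, list[Row]] = defaultdict(list)
    for row in rows:
        for key, mapped in map_fn(row):
            groups[key].append(mapped)
    # Reduce: each key is processed independently of all other keys,
    # which is what makes this step easy to distribute.
    out: list[Row] = []
    for key in sorted(groups):
        out.extend(reduce_fn(key, groups[key]))
    return out


def map_orders(row: Row) -> Iterator[tuple[Hashable, Row]]:
    # f(x): one output row per input row, keyed by city.
    yield row["city"], {"amount": row["amount"]}


def reduce_total(key: Hashable, rows: list[Row]) -> Iterator[Row]:
    # g(k, X_k): collapse all rows of one city into a single row.
    yield {"city": key, "total": sum(r["amount"] for r in rows)}


orders = [
    {"city": "Oslo", "amount": 3},
    {"city": "Rome", "amount": 5},
    {"city": "Oslo", "amount": 4},
]
totals = run_map_reduce(orders, map_orders, reduce_total)
print(totals)  # [{'city': 'Oslo', 'total': 7}, {'city': 'Rome', 'total': 5}]
```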

Let's implement a toy MapReduce system. All the operations will be configured with `pydantic` config objects.

In [None]:
import pydantic

from itertools import groupby
from operator import itemgetter
from typing import Any, Callable, Iterable, Iterator, Literal, TypeVar


class TransformConfigBase(pydantic.BaseModel): ...


class MapTransformConfigBase(TransformConfigBase): ...


class MapReduceTransformConfigBase(TransformConfigBase): ...


Key = TypeVar("Key")
Row = dict[Any, Any]


class TransformBase:
    def run(self, iterator: Iterable[dict[Any, Any]]) -> list[dict[Any, Any]]:
        pass


class MapTransformBase(TransformBase):
    def __init__(self, map_fn: Callable[[Row], Iterator[Row]], config: MapTransformConfigBase):
        self._map_fn = map_fn
        self.config = config

    def run(self, iterator: Iterable[Row]) -> list[Row]:
        return [row for input_row in iterator for row in self._map_fn(input_row)]


class MapReduceTransformBase(TransformBase):
    def __init__(
        self,
        map_fn: Callable[[Row], Iterator[tuple[Key, Row]]],
        reduce_fn: Callable[[Key, Iterable[Row]], Iterator[Row]],
        config: MapReduceTransformConfigBase,
    ) -> None:
        self._map_fn = map_fn
        self._reduce_fn = reduce_fn
        self.config = config

    def run(self, iterator: Iterable[Row]) -> list[Row]:
        # 1) Map: produce (key, row) pairs
        pairs: list[tuple[Key, Row]] = [pair for input_row in iterator for pair in self._map_fn(input_row)]
        if not pairs:
            return []

        # 2) Sort by key so groupby works correctly
        pairs.sort(key=itemgetter(0))

        # 3) Group by key and feed each group's rows into reduce_fn
        out: list[Row] = []
        for key, group in groupby(pairs, key=itemgetter(0)):
            values_iter = [row for _, row in group]
            out.extend(self._reduce_fn(key, values_iter))

        # 4) Return the flattened reduce results
        return out

Let's start with a simple map example. Our table contains some columns and one of them contains text. We want to transform the table by changing the case of text in a specified column to either upper or lower.

In [None]:
class ChangeCaseTransformConfig(MapTransformConfigBase):
    type: Literal["change_case"] = "change_case"
    column: str
    to_lower: bool = False


class ChangeCaseTransform(MapTransformBase):
    """
    Takes a dictionary and changes the case of value with the key `column`
    """

    class _MapFn:
        def __init__(self, config: ChangeCaseTransformConfig):
            self.config = config

        def __call__(self, row: Row) -> Iterator[Row]:
            # your code goes here
            yield row

    def __init__(
        self,
        config: ChangeCaseTransformConfig,
    ) -> None:
        super().__init__(self._MapFn(config), config)


test_config = ChangeCaseTransformConfig(
    column="content",
    to_lower=True,
)
test_input = [
    {"content": "A String"},
    {"content": "abcc"},
    {"content": "BCD"},
]

assert ChangeCaseTransform(test_config).run(test_input) == [
    {"content": "a string"},
    {"content": "abcc"},
    {"content": "bcd"},
]

A classic example of a map-reduce operation is counting the occurrences of every letter (character) in the texts of a specified table column. To do so, we first run a mapper that, for each input row, produces multiple rows with the letter as the **key** and a column that stores the number of occurrences of this letter in that particular input row. The reducer then just sums up the counters for each letter.

In [None]:
from collections import Counter


class CountCharsTransformConfig(MapReduceTransformConfigBase):
    """
    If `lowercase` is `False`, don't change the case,
    otherwise, make the string lowercase first
    """

    lowercase: bool = False
    column: str


class CountCharsTransform(MapReduceTransformBase):
    """
    Takes a table with strings, optionally casts them to lower case, and
    produces a table with columns `letter` and `cnt` holding the total number of occurrences
    of each character in the input table strings.
    """

    class _MapFn:
        def __init__(self, config: CountCharsTransformConfig):
            self.config = config

        def __call__(self, row: Row) -> Iterator[tuple[str, Row]]:
            # your code goes here

            letters_count = ...
            for letter, cnt in letters_count.items():
                yield letter, {"cnt": cnt}

    class _ReduceFn:
        def __init__(self, config: CountCharsTransformConfig):
            self.config = config

        def __call__(self, key: Key, rows: Iterable[Row]) -> Iterator[Row]:
            # your code goes here
            pass

    def __init__(
        self,
        config: CountCharsTransformConfig,
    ) -> None:
        super().__init__(
            self._MapFn(config),
            self._ReduceFn(config),
            config,
        )


test_config = CountCharsTransformConfig(
    column="content",
    lowercase=True,
)

test_input = [
    {"content": "A String"},
    {"content": "abcc"},
    {"content": "BCD"},
]

assert CountCharsTransform(test_config).run(test_input) == [
    {"letter": " ", "cnt": 1},
    {"letter": "a", "cnt": 2},
    {"letter": "b", "cnt": 2},
    {"letter": "c", "cnt": 3},
    {"letter": "d", "cnt": 1},
    {"letter": "g", "cnt": 1},
    {"letter": "i", "cnt": 1},
    {"letter": "n", "cnt": 1},
    {"letter": "r", "cnt": 1},
    {"letter": "s", "cnt": 1},
    {"letter": "t", "cnt": 1},
]

# Language Modeling Datasets

Language models don't operate on text, they operate on tokens. However, the list of tokens alone is not always sufficient for training. Sometimes we want two additional things:
1. **Loss mask.** It marks the tokens on which we calculate the loss. It helps in two cases: 1) whenever there is padding in the tensors, we don't want to calculate the loss on pads; 2) when the data has "inputs" and "outputs" and only the output is written by the LM, we want to calculate the loss only on what the LM is actually required to predict.
2. **Position IDs.** Sequential numbers indicating the position of each token. They're required when we have truncated sequences and learned or sinusoidal positional encodings. They are not strictly required with RoPE, but it's still very helpful to store them for auxiliary needs we'll cover later.

### What do they actually look like?

Position IDs are simple. Just an `arange` from 0 to the number of tokens. The loss mask, however, depends on what we want. Let's cover some standard situations.

#### Pre-Training

In pre-training, the loss mask is all `True`s. Padding may complicate this once we implement specific packing algorithms, but at the level of an individual example it's all `True`s.

#### Fill-In-the-Middle

In FIM, the loss mask might still be all `True`s, as in pre-training, but we might also want an option to calculate the loss only on the **middle**, because this is what we actually expect the model to generate; everything else is the input.
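A small sketch of both options on pre-tokenized toy data. The sentinel IDs (`101`, `102`, `103`), the EOS ID `0`, and the helper name `build_fim_example` are made up for illustration; a real tokenizer defines its own IDs. The sketch also assumes that a trailing EOS token is appended and counted as part of the middle, which is one possible convention.

```python
# Hypothetical special-token IDs, chosen only for this sketch.
FIM_PREFIX, FIM_SUFFIX, FIM_MIDDLE, EOS = 101, 102, 103, 0


def build_fim_example(
    prefix: list[int],
    middle: list[int],
    suffix: list[int],
    fim_format: str = "psm",
    only_middle: bool = False,
) -> tuple[list[int], list[bool], list[int]]:
    if fim_format == "psm":
        # <fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>
        context = [FIM_PREFIX, *prefix, FIM_SUFFIX, *suffix, FIM_MIDDLE]
    elif fim_format == "spm":
        # <fim_suffix>{suffix}<fim_prefix>{prefix}<fim_middle>
        context = [FIM_SUFFIX, *suffix, FIM_PREFIX, *prefix, FIM_MIDDLE]
    else:
        raise ValueError(f"Unknown FIM format: {fim_format}")
    target = [*middle, EOS]  # what the model is expected to generate
    input_ids = context + target
    # Either train on everything or only on the generated part.
    loss_mask = [not only_middle] * len(context) + [True] * len(target)
    position_ids = list(range(len(input_ids)))
    return input_ids, loss_mask, position_ids


psm_ids, psm_mask, psm_pos = build_fim_example([7, 8], [9], [10, 11], "psm")
spm_ids, spm_mask, _ = build_fim_example([7, 8], [9], [10, 11], "spm", only_middle=True)
print(psm_ids)   # [101, 7, 8, 102, 10, 11, 103, 9, 0]
print(spm_ids)   # [102, 10, 11, 101, 7, 8, 103, 9, 0]
print(spm_mask)  # seven False values followed by two True values
```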

#### Supervised Fine-Tuning of Instruct Models

In instruct-SFT, we usually only calculate the loss on the assistant's responses.
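As a rough sketch, the mask can be built turn by turn. The delimiter IDs and the helper name `build_sft_example` below are hypothetical, and the turns are assumed to be tokenized already. This sketch also treats the assistant delimiters and the final EOS as part of the response; other conventions are possible.

```python
# Hypothetical chat-delimiter IDs, chosen only for this sketch.
USER_START, USER_END, ASSISTANT_START, ASSISTANT_END, EOS = 201, 202, 203, 204, 0


def build_sft_example(turns: list[tuple[str, list[int]]]) -> tuple[list[int], list[bool]]:
    """`turns` is a list of (role, token_ids) pairs."""
    input_ids: list[int] = []
    loss_mask: list[bool] = []
    for role, tokens in turns:
        if role == "user":
            wrapped, trainable = [USER_START, *tokens, USER_END], False
        elif role == "assistant":
            wrapped, trainable = [ASSISTANT_START, *tokens, ASSISTANT_END], True
        else:
            raise ValueError(f"Unknown role: {role}")
        input_ids.extend(wrapped)
        # Every token of a turn shares the same mask value.
        loss_mask.extend([trainable] * len(wrapped))
    # The closing EOS inherits the mask of the last turn.
    input_ids.append(EOS)
    loss_mask.append(loss_mask[-1] if loss_mask else False)
    return input_ids, loss_mask


sft_ids, sft_mask = build_sft_example([("user", [11, 12]), ("assistant", [13, 14, 15])])
print(sft_ids)   # [201, 11, 12, 202, 203, 13, 14, 15, 204, 0]
print(sft_mask)  # four False values (user turn), then six True values
```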



Let's implement all these strategies.

In [None]:
import numpy as np
from jaxtyping import UInt32, Bool
from transformers import AutoTokenizer, PreTrainedTokenizer

TokenizedRow = dict[str, UInt32[np.typing.ArrayLike, "example_len"] | Bool[np.typing.ArrayLike, "example_len"]]

tokenizer = AutoTokenizer.from_pretrained("JetBrains/Mellum-4b-base")

In [None]:
def pretrain_processing(
    row: Row,
    tokenizer: PreTrainedTokenizer,
) -> TokenizedRow:
    """
    Takes a row with a column `content`, returns a new row with three columns:
    1. `input_ids`
    2. `position_ids`
    3. `loss_mask`

    All of them are NumPy arrays. Token IDs and positions fit into UInt32, so
    use it as the dtype. For the loss mask, use `bool`.
    """
    text = row["content"]
    # your code goes here

    return {
        "input_ids": input_ids,
        "position_ids": position_ids,
        "loss_mask": loss_mask,
    }


result = pretrain_processing({"content": "A sample text"}, tokenizer)
assert np.allclose(result["input_ids"], np.array([59, 5875, 1378, 0], dtype=np.uint32))
assert np.allclose(result["loss_mask"], np.array([True, True, True, True], dtype=np.bool))
assert np.allclose(result["position_ids"], np.array([0, 1, 2, 3], dtype=np.uint32))

In [None]:
def fim_processing(
    row: Row,
    tokenizer: PreTrainedTokenizer,
    only_middle_not_masked: bool = False,
    fim_format: Literal["spm", "psm"] = "spm",
) -> TokenizedRow:
    """
    The input is a row with three columns: `prefix`, `suffix`, `middle`.
    The output is the same as with pre-training. The FIM format indicates the type of FIM we're using.
    SPM stands for `<fim_suffix>{suffix}<fim_prefix>{prefix}<fim_middle>{middle}`,
    PSM for `<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>{middle}`.
    """
    assert len(tokenizer.encode("<fim_prefix>")) == 1, "Tokenizer doesn't have special FIM tokens"
    assert len(tokenizer.encode("<fim_suffix>")) == 1, "Tokenizer doesn't have special FIM tokens"
    assert len(tokenizer.encode("<fim_middle>")) == 1, "Tokenizer doesn't have special FIM tokens"
    fim_prefix_id = tokenizer.encode("<fim_prefix>")[0]
    fim_suffix_id = tokenizer.encode("<fim_suffix>")[0]
    fim_middle_id = tokenizer.encode("<fim_middle>")[0]
    prefix = row["prefix"]
    suffix = row["suffix"]
    middle = row["middle"]

    # your code goes here


example = {
    "prefix": "import numpy as",
    "middle": "np",
    "suffix": "\nimport pandas as pd",
}
result_psm = fim_processing(example, tokenizer, only_middle_not_masked=False, fim_format="psm")
assert np.allclose(
    result_psm["input_ids"], np.array([1, 653, 9440, 609, 3, 225, 653, 17984, 609, 8559, 2, 3057, 0], dtype=np.uint32)
)
assert np.allclose(result_psm["loss_mask"], np.ones(13, dtype=np.bool))
assert np.allclose(result_psm["position_ids"], np.arange(13, dtype=np.uint32))

result_spm = fim_processing(example, tokenizer, only_middle_not_masked=False, fim_format="spm")
assert np.allclose(
    result_spm["input_ids"], np.array([3, 225, 653, 17984, 609, 8559, 1, 653, 9440, 609, 2, 3057, 0], dtype=np.uint32)
)
assert np.allclose(result_spm["loss_mask"], np.ones(13, dtype=np.bool))
assert np.allclose(result_spm["position_ids"], np.arange(13, dtype=np.uint32))

result_spm_masked = fim_processing(example, tokenizer, only_middle_not_masked=True, fim_format="spm")
assert np.allclose(
    result_spm_masked["input_ids"],
    np.array([3, 225, 653, 17984, 609, 8559, 1, 653, 9440, 609, 2, 3057, 0], dtype=np.uint32),
)
assert np.allclose(
    result_spm_masked["loss_mask"], np.concatenate([np.zeros(11, dtype=np.bool), np.ones(2, dtype=np.bool)])
)
assert np.allclose(result_spm_masked["position_ids"], np.arange(13, dtype=np.uint32))

In [None]:
def sft_processing(
    row: Row,
    tokenizer: PreTrainedTokenizer,
) -> TokenizedRow:
    """
    The row has a column `content` which is a list of turns in the
    format of `[{"role": user/assistant, "content": msg_text}]`.
    """
    assert len(tokenizer.encode("<user>")) == 1, "Tokenizer chat format is not supported or doesn't exist"
    assert len(tokenizer.encode("</user>")) == 1, "Tokenizer chat format is not supported or doesn't exist"
    assert len(tokenizer.encode("<assistant>")) == 1, "Tokenizer chat format is not supported or doesn't exist"
    assert len(tokenizer.encode("</assistant>")) == 1, "Tokenizer chat format is not supported or doesn't exist"
    user_start_id = tokenizer.encode("<user>")[0]
    user_end_id = tokenizer.encode("</user>")[0]
    assistant_start_id = tokenizer.encode("<assistant>")[0]
    assistant_end_id = tokenizer.encode("</assistant>")[0]
    conversation = row["content"]

    # your code goes here


example = {"content": [{"role": "user", "content": "Hi!"}, {"role": "assistant", "content": "How can I help you?"}]}
result = sft_processing(example, tokenizer)
assert np.allclose(
    result["input_ids"], np.array([21, 4479, 27, 22, 23, 7166, 867, 391, 2739, 691, 57, 24, 0], dtype=np.uint32)
)
assert np.allclose(result["loss_mask"], np.concatenate([np.zeros(4, dtype=np.bool), np.ones(9, dtype=np.bool)]))
assert np.allclose(result["position_ids"], np.arange(13, dtype=np.uint32))

In [None]:
ProcessingStrategy = Literal["pretrain", "fim", "sft"]


def process_row(row: Row, tokenizer: PreTrainedTokenizer, strategy: ProcessingStrategy, **kwargs) -> TokenizedRow:
    if strategy == "pretrain":
        return pretrain_processing(row, tokenizer)
    elif strategy == "fim":
        return fim_processing(row, tokenizer, **kwargs)
    elif strategy == "sft":
        return sft_processing(row, tokenizer)
    else:
        raise ValueError(f"Strategy `{strategy}` is not supported")

# Examples Packing

You may notice that the texts all have different lengths. However, we need to feed them all into the model somehow. We can't leave the examples as is, because then we couldn't build a batch (unless we use batch size 1 and PyTorch eager mode): all tensors would have different shapes. We could pad them all to a pre-defined number of tokens (the context size), but this would be very inefficient: with a context size of, say, 8192, an example of length 1 would carry 8191 pad tokens.

To solve this, the examples are usually *packed* into tensors of size `seq_len`, so multiple short sequences get concatenated. Does this mean that when looking at the second example, the model sees the tokens from the first one? No! To avoid that, we'll build an attention mask matrix. It is used in the attention layers so that, for each token, attention is calculated only over the previous tokens that belong to the same example in the packing.

There are two main strategies used for packing:
1. **Dense (pre-training).** We concatenate all the sequences and cut the result into pieces of `seq_len` tokens.
2. **No incomplete (fine-tuning).** We concatenate examples as long as they fit into `seq_len`. When the next sequence doesn't fit, we *pad the current pack* and start a new one.
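The difference between the two strategies is easiest to see on plain lists of token IDs. The sketch below ignores loss masks and position IDs and uses made-up helper names (`pack_dense_tokens`, `pack_whole_examples`) and a made-up pad ID of `99`. It assumes that leftover tokens are dropped in the dense case and that examples longer than `seq_len` are dropped in the no-incomplete case.

```python
def pack_dense_tokens(examples: list[list[int]], seq_len: int) -> list[list[int]]:
    # Concatenate everything, then cut into full chunks; the remainder is dropped.
    flat = [tok for ex in examples for tok in ex]
    n_full = len(flat) // seq_len
    return [flat[i * seq_len : (i + 1) * seq_len] for i in range(n_full)]


def pack_whole_examples(examples: list[list[int]], seq_len: int, pad_id: int) -> list[list[int]]:
    packs: list[list[int]] = []
    current: list[int] = []
    for ex in examples:
        if len(ex) > seq_len:
            continue  # can never fit without truncation, so skip it
        if len(current) + len(ex) > seq_len:
            # The next example doesn't fit: pad the current pack and start a new one.
            packs.append(current + [pad_id] * (seq_len - len(current)))
            current = []
        current = current + list(ex)
    if current:
        packs.append(current + [pad_id] * (seq_len - len(current)))
    return packs


toy_examples = [[1, 2, 3, 4, 5], [0, 1], [10, 10], [11]]
dense_packs = pack_dense_tokens(toy_examples, seq_len=4)
whole_packs = pack_whole_examples(toy_examples, seq_len=4, pad_id=99)
print(dense_packs)  # [[1, 2, 3, 4], [5, 0, 1, 10]]
print(whole_packs)  # [[0, 1, 10, 10], [11, 99, 99, 99]]
```

Dense packing wastes no space but splits the first example across two packs; the second strategy keeps every surviving example intact at the cost of padding.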

Let's implement them.

In [None]:
def position_fix_helper(position_ids: UInt32[np.typing.ArrayLike, "example_len"]) -> UInt32[np.ndarray, "example_len"]:
    """
    Rebase each segment (between decreases) so it starts at 0.
    Example: [2,3,4,0,1,2,3] -> [0,1,2,0,1,2,3]
             [4,0,1,0]       -> [0,0,1,0]
    """
    ids = np.asarray(position_ids, dtype=np.uint32)
    n = ids.shape[0]
    if n == 0:
        return ids

    # Find starts of new segments where the sequence decreases.
    disc = np.flatnonzero(ids[1:] < ids[:-1]) + 1
    starts = np.concatenate(([0], disc))

    out = ids.astype(np.int64, copy=True)  # avoid uint underflow
    for i, s in enumerate(starts):
        e = starts[i + 1] if i + 1 < len(starts) else n
        base = int(out[s])
        out[s:e] -= base

    return out.astype(np.uint32)

In [None]:
def pack_dense(rows: Iterable[TokenizedRow], seq_len: int) -> list[TokenizedRow]:
    """
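    Greedily packs all the rows into seq_len tensors, dropping the trailing tokens that don't fill a complete tensor
    """
    rows = list(rows)  # materialize once: a generator would be exhausted after the first pass below
    input_ids = np.concatenate([row["input_ids"] for row in rows])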
    loss_mask = np.concatenate([row["loss_mask"] for row in rows])
    position_ids = np.concatenate([row["position_ids"] for row in rows])
    # your code goes here


examples = [
    {
        "input_ids": np.array([1, 2, 3, 4, 5], dtype=np.uint32),
        "loss_mask": np.ones(5, dtype=np.bool),
        "position_ids": np.arange(5, dtype=np.uint32),
    },
    {
        "input_ids": np.array([0, 1], dtype=np.uint32),
        "loss_mask": np.ones(2, dtype=np.bool),
        "position_ids": np.arange(2, dtype=np.uint32),
    },
    {
        "input_ids": np.array([10, 10], dtype=np.uint32),
        "loss_mask": np.ones(2, dtype=np.bool),
        "position_ids": np.arange(2, dtype=np.uint32),
    },
    {
        "input_ids": np.array([11], dtype=np.uint32),
        "loss_mask": np.ones(1, dtype=np.bool),
        "position_ids": np.arange(1, dtype=np.uint32),
    },
]

result = pack_dense(examples, 4)
assert np.allclose(result[0]["input_ids"], np.array([1, 2, 3, 4], dtype=np.uint32))
assert np.allclose(result[1]["input_ids"], np.array([5, 0, 1, 10], dtype=np.uint32))
assert np.allclose(result[0]["position_ids"], np.array([0, 1, 2, 3], dtype=np.uint32))
assert np.allclose(result[1]["position_ids"], np.array([0, 0, 1, 0], dtype=np.uint32))

In [None]:
def pack_no_incomplete(rows: Iterable[TokenizedRow], seq_len: int, pad_token_id: int) -> list[TokenizedRow]:
    """
    Ensures there are no incomplete (truncated) examples. If an example exceeds seq_len, drop it. Pad the tensors
    with pad_token_id, zeros in position_ids, and False in loss_mask. We don't want pads to be included in the gradient
    calculation graphs.
    """
    # your code goes here


result = pack_no_incomplete(examples, 4, 0)
assert np.allclose(result[0]["input_ids"], np.array([0, 1, 10, 10], dtype=np.uint32))
assert np.allclose(result[0]["position_ids"], np.array([0, 1, 0, 1], dtype=np.uint32))
assert np.allclose(result[0]["loss_mask"], np.array([True, True, True, True], dtype=np.bool))
assert np.allclose(result[1]["input_ids"], np.array([11, 0, 0, 0], dtype=np.uint32))
assert np.allclose(result[1]["position_ids"], np.array([0, 0, 0, 0], dtype=np.uint32))
assert np.allclose(result[1]["loss_mask"], np.array([True, False, False, False], dtype=np.bool))

In [None]:
PackingStrategy = Literal["dense", "no_incomplete"]


def pack_sequences(
    rows: Iterable[TokenizedRow], seq_len: int, pad_token_id: int, strategy: PackingStrategy
) -> list[TokenizedRow]:
    if strategy == "dense":
        return pack_dense(rows, seq_len)
    elif strategy == "no_incomplete":
        return pack_no_incomplete(rows, seq_len, pad_token_id)
    else:
        raise ValueError(f"Strategy `{strategy}` is not supported")

# Using Map-Reduce for Preparing Datasets

Tokenization and packing can be scaled with MapReduce. In the mapper, we tokenize each example and assign it a "group" ID (e.g., a hash % n_groups). Then we reduce by that group ID and pack the sequences that landed in the same group.
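The grouping idea on its own can be sketched as follows. The helper name `group_id` and the sample texts are illustrative. The sketch hashes the raw text with `hashlib` rather than Python's built-in `hash`, because the built-in string hash is randomized per process by default and would not be stable across workers.

```python
import hashlib
from collections import defaultdict


def group_id(text: str, n_groups: int) -> int:
    # A content hash gives the same bucket for the same text on every worker.
    digest = hashlib.sha256(text.encode("utf-8")).digest()
    return int.from_bytes(digest, byteorder="big") % n_groups


texts = ["Hello, World!", "A test string", "import numpy as np", "import pandas as pd"]

# Map step: attach a group ID to every example.
groups: dict[int, list[str]] = defaultdict(list)
for text in texts:
    groups[group_id(text, n_groups=3)].append(text)

# Reduce step: each bucket can now be packed independently of the others.
for gid in sorted(groups):
    print(gid, groups[gid])
```

The number of groups trades parallelism against packing quality: more groups mean more independent reducers, but each reducer sees fewer examples to combine.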

In [None]:
class PrepareForTrainingTransformConfig(MapReduceTransformConfigBase):
    processing_strategy: ProcessingStrategy
    packing_strategy: PackingStrategy
    tokenizer_name: str
    seq_len: int
    pad_token_id: int | None = None
    fim_kwargs: dict[str, Any] = {}
    num_reduce_chunks: int

In [None]:
import hashlib


def array_hash(arr: np.ndarray) -> int:
    """
    Calculate a hash of a NumPy array from its raw bytes.
    Includes dtype and shape so that arrays with the same bytes but a different layout hash differently.
    """
    arr_bytes = arr.tobytes()
    meta = f"{arr.shape}-{arr.dtype}".encode()
    digest = hashlib.sha256(meta + arr_bytes).digest()
    return int.from_bytes(digest, byteorder="big")


class PrepareForTrainingTransform(MapReduceTransformBase):
    class _MapFn:
        def __init__(self, config: PrepareForTrainingTransformConfig):
            self.config = config
            self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name)

        def __call__(self, row: Row) -> Iterator[tuple[int, TokenizedRow]]:
            result = ...
            # the key is the array hash mod number of reduce chunks
            key = array_hash(result["input_ids"]) % self.config.num_reduce_chunks
            yield key, result

    class _ReduceFn:
        def __init__(self, config: PrepareForTrainingTransformConfig):
            self.config = config
            self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name)
            self.pad_token_id = (
                self.config.pad_token_id if self.config.pad_token_id is not None else self.tokenizer.eos_token_id
            )

        def __call__(self, key: Key, rows: Iterable[TokenizedRow]) -> Iterator[TokenizedRow]:
            """
            Packs the sequences that ended up in the same reduce chunk
            """
            yield from ...

    def __init__(self, config: PrepareForTrainingTransformConfig):
        super().__init__(
            map_fn=self._MapFn(config),
            reduce_fn=self._ReduceFn(config),
            config=config,
        )


test_config = PrepareForTrainingTransformConfig(
    processing_strategy="pretrain",
    packing_strategy="dense",
    tokenizer_name="JetBrains/Mellum-4b-base",
    seq_len=8,
    num_reduce_chunks=2,
)

examples = [
    {"content": "Hello, World!"},
    {"content": "A test string"},
    {"content": "import numpy as np"},
    {"content": "import pandas as pd"},
]

transform = PrepareForTrainingTransform(test_config)
result = transform.run(examples)

assert np.allclose(result[0]["input_ids"], np.array([10626, 38, 8098, 27, 0, 653, 17984, 609], dtype=np.uint32))
assert np.allclose(result[1]["input_ids"], np.array([59, 247, 110, 95, 109, 110, 987, 0], dtype=np.uint32))
assert np.allclose(result[0]["position_ids"], np.array([0, 1, 2, 3, 4, 0, 1, 2], dtype=np.uint32))
assert np.allclose(result[1]["position_ids"], np.array([0, 1, 2, 3, 4, 5, 6, 7], dtype=np.uint32))

Let's try running it over a real dataset.

In [None]:
from datasets import load_dataset

ds = load_dataset("HuggingFaceFW/fineweb-edu", name="sample-10BT", split="train", streaming=True)

rows = []
idx = 0
for row in ds:
    rows.append({"content": row["text"]})
    idx += 1
    if idx > 1000:
        break

In [None]:
config = PrepareForTrainingTransformConfig(
    processing_strategy="pretrain",
    packing_strategy="dense",
    tokenizer_name="JetBrains/Mellum-4b-base",
    seq_len=1025,
    num_reduce_chunks=4,
)

transform = PrepareForTrainingTransform(config)
result = transform.run(rows)
assert len(result) == 1053

# Making a Dataset Object

Now let's build an actual PyTorch dataset that we can use for training. For that, we'll need:
1. Input tokens: the ones for which we'll predict the next token (LM objective). These are all the tokens except the last one, because the last one has no next token.
2. Target tokens (classification "labels"): the ground-truth next tokens. These are all the tokens except the first one, so the input and label tokens are shifted by 1.
3. Loss mask for labels.
4. Position IDs of input tokens.
5. Attention mask. A square matrix in which the entry in row $i$, column $j$ says whether token $i$ should attend to token $j$. It can easily be constructed from position IDs: within an example they are consecutive, so when the $(i+1)$-th element isn't greater than the $i$-th, a new sequence starts in the packing. Since we patched position IDs earlier so that every segment starts from 0, we can just find all the zeros in position IDs; they correspond to the beginnings of segments in the packing.
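One possible way to turn that observation into code, shown here as a NumPy sketch with an illustrative function name (`block_causal_mask`): count the zeros cumulatively to get a segment ID per token, then allow attention only within the same segment and only to earlier or equal positions. The input and label shift is included at the end for completeness.

```python
import numpy as np


def block_causal_mask(position_ids: np.ndarray) -> np.ndarray:
    n = position_ids.shape[0]
    # Every zero starts a new segment, so a running count of zeros labels the segments.
    segment_ids = np.cumsum(position_ids == 0)
    same_segment = segment_ids[:, None] == segment_ids[None, :]
    # Lower-triangular part: token i may look at tokens j <= i only.
    causal = np.tril(np.ones((n, n), dtype=bool))
    return same_segment & causal


mask = block_causal_mask(np.array([0, 1, 2, 0, 1], dtype=np.uint32))
print(mask.astype(int))
# [[1 0 0 0 0]
#  [1 1 0 0 0]
#  [1 1 1 0 0]
#  [0 0 0 1 0]
#  [0 0 0 1 1]]

# Shifting by one: the model reads `tokens` and is scored against `labels`.
input_ids = np.array([5, 6, 7, 8, 9], dtype=np.uint32)
tokens, labels = input_ids[:-1], input_ids[1:]
print(tokens, labels)  # [5 6 7 8] [6 7 8 9]
```

Under this scheme, padding tokens that all carry position 0 each form their own one-token segment, so they attend only to themselves.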

In [None]:
def make_attention_mask_from_positions(
    position_ids: UInt32[np.typing.ArrayLike, "seq_len"],
) -> Bool[np.typing.ArrayLike, "seq_len seq_len"]:
    """
    Creates an attention mask matrix from `position_ids`

    :param np.ndarray position_ids: An array representing the position IDs of tokens in a sequence.
        It should have a shape of (seq_len,).
    :return: An array representing the attention mask. It has a shape of (seq_len, seq_len).
        The elements of the attention mask are binary values indicating whether each token in
        the sequence should be attended to (True) or not (False).
    """
    # your code goes here


assert np.allclose(
    make_attention_mask_from_positions(np.array([0, 1, 2, 0, 1], dtype=np.uint32)),
    np.array(
        [
            [True, False, False, False, False],
            [True, True, False, False, False],
            [True, True, True, False, False],
            [False, False, False, True, False],
            [False, False, False, True, True],
        ]
    ),
)

In [None]:
import torch
from torch.utils.data import Dataset

Inputs = UInt32[np.typing.ArrayLike, "seq_len"]
Labels = UInt32[np.typing.ArrayLike, "seq_len"]
PositionIds = UInt32[np.typing.ArrayLike, "seq_len"]
AttentionMask = Bool[np.typing.ArrayLike, "seq_len seq_len"]
LossMask = Bool[np.typing.ArrayLike, "seq_len"]
LMExample = dict[str, Inputs | Labels | PositionIds | AttentionMask | LossMask]


class LMDataset(Dataset):
    def __init__(self, rows: list[TokenizedRow]):
        super().__init__()
        self.rows = rows

    def __len__(self) -> int:
        return len(self.rows)

    def __getitem__(self, idx: int) -> LMExample:
        """
        Language modeling objective.
        tokens: input tokens, all except the last one
        labels: target next tokens, all except the first one
        attention_mask: since we're doing causal language modeling
        loss_mask: whether loss should be calculated on specific `target` tokens
        position_ids: positions of input tokens
        """
        record = self.rows[idx]
        text = record["input_ids"]
        # your code goes here

        return {
            "tokens": tokens.contiguous(),
            "labels": labels.contiguous(),
            "attention_mask": attention_mask,
            "loss_mask": loss_mask,
            "position_ids": position_ids,
        }


dataset = LMDataset(result)

assert isinstance(dataset[0]["tokens"], torch.Tensor)
assert dataset[0]["tokens"].shape[0] == 1024
assert dataset[0]["labels"].shape[0] == 1024
assert dataset[0]["attention_mask"].shape[1] == 1024
assert dataset[0]["loss_mask"].shape[0] == 1024
assert dataset[0]["position_ids"].shape[0] == 1024

assert dataset[0]["tokens"][0] == 89963
assert dataset[0]["tokens"][-1] == 360
assert dataset[0]["labels"][0] == 52
assert dataset[0]["labels"][-1] == 3045
assert dataset[0]["position_ids"][-1] == 364

Finally, a dataloader. We'll need to write a data collator that stacks all the examples.

In [None]:
from torch.utils.data import DataLoader

InputsBatch = UInt32[np.typing.ArrayLike, "bs seq_len"]
LabelsBatch = UInt32[np.typing.ArrayLike, "bs seq_len"]
PositionIdsBatch = UInt32[np.typing.ArrayLike, "bs seq_len"]
AttentionMaskBatch = Bool[np.typing.ArrayLike, "bs seq_len seq_len"]
LossMaskBatch = Bool[np.typing.ArrayLike, "bs seq_len"]
LMExampleBatch = dict[str, InputsBatch | LabelsBatch | PositionIdsBatch | AttentionMaskBatch | LossMaskBatch]


def collate_lm(batch: list[LMExample]) -> LMExampleBatch:
    """
    Stacks all the examples in the batch
    """
    # your code goes here


loader = DataLoader(
    dataset,
    batch_size=8,
    shuffle=True,
    drop_last=True,
    num_workers=0,
    pin_memory=False,
    collate_fn=collate_lm,
)

batch = next(iter(loader))
assert batch["tokens"].shape == (8, 1024)
assert batch["labels"].shape == (8, 1024)
assert batch["position_ids"].shape == (8, 1024)
assert batch["attention_mask"].ndim == 3 and batch["attention_mask"].shape[0] == 8