In [4]:
import numpy as np
from datasets import load_from_disk, concatenate_datasets
from tqdm.auto import tqdm


def _iter_token_batches(ds, column: str, batch_size: int):
    """Yield the token IDs of each run of ``batch_size`` rows as one flat int64 array.

    Slicing ``ds`` decodes a whole batch of rows per call instead of one
    example at a time.
    """
    for start in range(0, len(ds), batch_size):
        rows = ds[start:start + batch_size][column]
        # Convert each row with an explicit dtype so empty rows stay int64
        # (np.asarray([]) alone is float64 and would upcast the concatenation).
        yield np.concatenate([np.asarray(row, dtype=np.int64) for row in rows])


def merge_tokens(
    dataset_paths: list[str],
    output_path: str,
    dtype=np.uint16,
    column: str = "tokens",
    batch_size: int = 1_000,
) -> int:
    """Concatenate the token column of saved datasets into one flat binary file.

    The datasets at ``dataset_paths`` are loaded with ``load_from_disk`` and
    concatenated in order. Every row's token IDs are then written back-to-back
    into a raw ``np.memmap`` at ``output_path``. The file has no header; read
    it back with ``np.memmap(output_path, dtype=dtype, mode="r")``.

    Args:
        dataset_paths: Paths accepted by ``load_from_disk``; must be non-empty.
        output_path: Destination ``.bin`` file; created or overwritten.
        dtype: Integer dtype for the stored IDs. The default ``np.uint16`` only
            holds IDs up to 65,535; pass ``np.uint32`` for larger vocabularies.
        column: Name of the column holding each row's sequence of token IDs.
        batch_size: Number of rows fetched per slice while counting and writing.

    Returns:
        Number of tokens written.

    Raises:
        ValueError: If ``dataset_paths`` is empty, ``batch_size`` < 1,
            ``dtype`` is not an integer type, or a token ID does not fit in ``dtype``.
        KeyError: If ``column`` is missing from the concatenated dataset.
    """
    limits = np.iinfo(dtype)  # also fails fast on non-integer dtypes
    if not dataset_paths:
        raise ValueError("dataset_paths must contain at least one path")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    all_ds = concatenate_datasets([load_from_disk(p) for p in dataset_paths])
    if column not in all_ds.column_names:
        raise KeyError(f"column {column!r} not found; available: {all_ds.column_names}")
    # Drop every other column so the two passes below only decode token IDs.
    extra_columns = [c for c in all_ds.column_names if c != column]
    if extra_columns:
        all_ds = all_ds.remove_columns(extra_columns)

    # First pass: the memmap has to be sized before anything is written.
    total_tokens = sum(
        len(row)
        for start in range(0, len(all_ds), batch_size)
        for row in all_ds[start:start + batch_size][column]
    )
    print(f"Total tokens: {total_tokens:,}")

    if total_tokens == 0:
        # np.memmap refuses to map a zero-length file; emit an empty file instead.
        with open(output_path, "wb"):
            pass
        return 0

    arr = np.memmap(output_path, dtype=dtype, mode="w+", shape=(total_tokens,))
    pos = 0
    try:
        with tqdm(total=total_tokens, unit="tok", desc="Writing") as pbar:
            for ids in _iter_token_batches(all_ds, column, batch_size):
                if ids.size == 0:
                    continue
                lo, hi = int(ids.min()), int(ids.max())
                if lo < limits.min or hi > limits.max:
                    # The int64 -> dtype assignment below would wrap
                    # out-of-range IDs silently, so reject them explicitly.
                    raise ValueError(
                        f"token IDs in [{lo}, {hi}] do not fit in "
                        f"{np.dtype(dtype).name}; pass a wider dtype such as np.uint32"
                    )
                # Write straight into the memmap; no intermediate Python list.
                arr[pos:pos + ids.size] = ids
                pos += ids.size
                pbar.update(ids.size)
        arr.flush()
    finally:
        del arr  # drop our reference so the mapping and its file handle can be released

    if pos != total_tokens:
        raise RuntimeError(f"wrote {pos:,} tokens but expected {total_tokens:,}")
    return pos

In [None]:
# Held-out split: merge the tokenized robots test set into one flat token file.
test_datasets = ["tokenized_data/robots_test"]
merge_tokens(test_datasets, output_path="tokenized_data/instruction_test.bin")

Total tokens: 102,137,758


Writing:   0%|          | 0/102137758 [00:00<?, ?tok/s]

In [None]:
# Training split: merge the tokenized robots train set into one flat token file.
# Named train_datasets so it does not overwrite the test-split list above.
train_datasets = [
    "tokenized_data/robots_train",
]

merge_tokens(train_datasets, "tokenized_data/instruction_train.bin")

Total tokens: 1,505,334,453


Writing:   0%|          | 0/1505334453 [00:00<?, ?tok/s]