In [2]:
!pip install mmap_dataset

[31mERROR: Could not find a version that satisfies the requirement mmap_dataset (from versions: none)[0m[31m
[0m[31mERROR: No matching distribution found for mmap_dataset[0m[31m
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.1[0m[39;49m -> [0m[32;49m24.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m


In [None]:
import os
import struct
import numpy as np
from functools import lru_cache
from itertools import accumulate

# Define data types
dtypes = {
    1: np.uint8,
    2: np.int8,
    3: np.int16,
    4: np.int32,
    5: np.int64,
    6: np.float32,
    7: np.float64,
    8: np.uint16,
}

def index_file_path(prefix_path):
    return prefix_path + ".idx"

def _warmup_mmap_file(path):
    with open(path, "rb") as stream:
        while stream.read(100 * 1024 * 1024):
            pass

class MMapIndexedDataset:
    class Index:
        _HDR_MAGIC = b"MMIDIDX\x00\x00"

        @classmethod
        def writer(cls, path, dtype):
            class _Writer:
                def __enter__(self):
                    self._file = open(path, "wb")

                    # Write Magic string so we can check the file format then opening it again.
                    self._file.write(cls._HDR_MAGIC)
                    # Write version number
                    # Little endian unsigned 64 Bit integer
                    self._file.write(struct.pack("<Q", 1))
                    # Little endian unsigned 8 Bit integer
                    self._file.write(struct.pack("<B", code(dtype)))

                    return self

                @staticmethod
                def _get_pointers(sizes):
                    pointers = np.zeros(len(sizes), dtype=np.int64)
                    sizes = np.array(sizes, dtype=np.int64)

                    np.cumsum(sizes[:-1], out=pointers[1:])
                    pointers = pointers * dtype().itemsize
                    return pointers

                def write(self, sizes, doc_idx):
                    pointers = self._get_pointers(sizes)

                    # Little endian unsigned 64 Bit integer
                    self._file.write(struct.pack("<Q", len(sizes)))
                    # Little endian unsigned 64 Bit integer
                    self._file.write(struct.pack("<Q", len(doc_idx)))

                    sizes = np.array(sizes, dtype=np.int32)
                    self._file.write(sizes.tobytes(order="C"))
                    del sizes

                    pointers = np.array(pointers, dtype=np.int64)
                    self._file.write(pointers.tobytes(order="C"))
                    del pointers

                    doc_idx = np.array(doc_idx, dtype=np.int64)
                    self._file.write(doc_idx.tobytes(order="C"))

                def __exit__(self, exc_type, exc_val, exc_tb):
                    self._file.close()

            return _Writer()

        def __init__(self, path, skip_warmup=False):
            with open(path, "rb") as stream:
                magic_test = stream.read(9)
                assert self._HDR_MAGIC == magic_test, (
                    "Index file doesn't match expected format. "
                    "Make sure that --dataset-impl is configured properly."
                )
                # Little endian unsigned 64 Bit integer
                version = struct.unpack("<Q", stream.read(8))
                assert (1,) == version

                # Little endian unsigned 8 Bit integer
                (dtype_code,) = struct.unpack("<B", stream.read(1))
                self._dtype = dtypes[dtype_code]
                self._dtype_size = self._dtype().itemsize

                self._len = struct.unpack("<Q", stream.read(8))[0]
                self._doc_count = struct.unpack("<Q", stream.read(8))[0]
                offset = stream.tell()

            if not skip_warmup:
                print("Warming up index mmap file...")
                _warmup_mmap_file(path)

            self._bin_buffer_mmap = np.memmap(path, mode="r", order="C")
            self._bin_buffer = memoryview(self._bin_buffer_mmap)
            print("Reading sizes...")
            self._sizes = np.frombuffer(
                self._bin_buffer, dtype=np.int32, count=self._len, offset=offset
            )
            print("Reading pointers...")
            self._pointers = np.frombuffer(
                self._bin_buffer,
                dtype=np.int64,
                count=self._len,
                offset=offset + self._sizes.nbytes,
            )
            print("Reading document index...")
            self._doc_idx = np.frombuffer(
                self._bin_buffer,
                dtype=np.int64,
                count=self._doc_count,
                offset=offset + self._sizes.nbytes + self._pointers.nbytes,
            )

        def __del__(self):
            self._bin_buffer_mmap._mmap.close()
            del self._bin_buffer_mmap

        @property
        def dtype(self):
            return self._dtype

        @property
        def sizes(self):
            return self._sizes

        @property
        def doc_idx(self):
            return self._doc_idx

        @lru_cache(maxsize=8)
        def __getitem__(self, i):
            return self._pointers[i], self._sizes[i]

        def __len__(self):
            return self._len

    def __init__(self, path, skip_warmup=False):
        super().__init__()

        self._path = None
        self._index = None
        self._bin_buffer = None

        if path.endswith(".bin") or path.endswith(".idx"):
            path = path[:-4]

        self._do_init(path, skip_warmup)

    def __getstate__(self):
        return self._path

    def __setstate__(self, state):
        self._do_init(state)

    def _do_init(self, path, skip_warmup):
        self._path = path
        self._index = self.Index(index_file_path(self._path), skip_warmup)

        if not skip_warmup:
            print("Warming up data mmap file...")
            for i in range(20):
                _warmup_mmap_file(f"{self._path}-{i:05d}-of-00020.bin")
        print("Creating numpy buffers of mmap...")
        self._bin_buffers = [
            np.memmap(f"{self._path}-{i:05d}-of-00020.bin", mode="r", order="C")
            for i in range(20)
        ]
        print("Creating memory views of numpy buffers...")
        self._bin_buffer_views = [memoryview(b) for b in self._bin_buffers]

    def __del__(self):
        for b in self._bin_buffers:
            b._mmap.close()
        del self._bin_buffers
        del self._index

    def __len__(self):
        return len(self._index)

    def __getitem__(self, idx):
        if isinstance(idx, int):
            ptr, size = self._index[idx]
            shard_idx = ptr // 1024 ** 3
            shard_offset = ptr % 1024 ** 3
            np_array = np.frombuffer(
                self._bin_buffer_views[shard_idx], dtype=self._index.dtype, count=size, offset=shard_offset
            )
            return np_array
        elif isinstance(idx, slice):
            start, stop, step = idx.indices(len(self))
            if step != 1:
                raise ValueError("Slices into indexed_dataset must be contiguous")
            ptr = self._index._pointers[start]
            sizes = self._index._sizes[idx]
            offsets = list(accumulate(sizes))
            total_size = sum(sizes)
            shard_idx = ptr // 1024 ** 3
            shard_offset = ptr % 1024 ** 3
            np_array = np.frombuffer(
                self._bin_buffer_views[shard_idx], dtype=self._index.dtype, count=total_size, offset=shard_offset
            )
            return np_array.reshape(-1, 2049)

    def get(self, idx, offset=0, length=None):
        """Retrieves a single item from the dataset with the option to only
        return a portion of the item.

        get(idx) is the same as [idx] but get() does not support slicing.
        """
        ptr, size = self._index[idx]
        if length is None:
            length = size - offset
        shard_idx = ptr // 1024 ** 3
        shard_offset = ptr % 1024 ** 3
        ptr = shard_offset + offset * np.dtype(self._index.dtype).itemsize
        np_array = np.frombuffer(
            self._bin_buffer_views[shard_idx], dtype=self._index.dtype, count=length, offset=ptr
        )
        return np_array

    @property
    def sizes(self):
        return self._index.sizes

    @property
    def doc_idx(self):
        return self._index.doc_idx

    def get_doc_idx(self):
        return self._index._doc_idx

    def set_doc_idx(self, doc_idx_):
        self._index._doc_idx = doc_idx_

    @property
    def supports_prefetch(self):
        return False

    @staticmethod
    def exists(path):
        return os.path.exists(index_file_path(path)) and any(
            os.path.exists(f"{path}-{i:05d}-of-00020.bin") for i in range(20)
        )

def load_partial_index_file(load_path, start_iteration=0, end_iteration=10):
    print(f"Loading entries from index file {load_path}, from iteration {start_iteration} to {end_iteration}...")
    dataset = MMapIndexedDataset(load_path, skip_warmup=True)
    indices = dataset[start_iteration*1024: end_iteration*1024]
    print(f"Loaded indices shape: {indices.shape}")
    return indices

def print_example_text(indices, num_samples=5):
    print("Printing example texts...")
    for i in range(min(num_samples, len(indices))):
        text = indices[i]
        print(f"Example {i+1}:\n{text}\n")

if __name__ == "__main__":
    load_path = "/workspace/data/datasets--EleutherAI--pile-standard-pythia-preshuffled/snapshots/bac79b6820adb34e451f9a02cc1dc7cd920febf0/document"
    start_iteration = 0  # Start iteration can be adjusted
    end_iteration = 10   # Adjust the end iteration to load more data
    print(f"Starting script with load path: {load_path}")
    
    indices = load_partial_index_file(load_path, start_iteration=start_iteration, end_iteration=end_iteration)
    print("Indices loaded successfully.")
    
    print_example_text(indices)
    print("Finished printing example text.")
