In [1]:
from megatron.data.indexed_dataset import MMapIndexedDataset
from pathlib import Path
import numpy as np
from torch.utils.data import Dataset


class GPT2Dataset(Dataset):
    """Streamlined version of the GPT2Dataset in megatron.

    Each item is built by looking the requested index up in ``shuffle_idx``,
    reading that sample's start/end ``(document position, token offset)`` pairs
    from ``sample_idx``, and concatenating the matching token spans fetched from
    ``indexed_dataset`` (document positions are resolved through ``doc_idx``).

    Args:
        indexed_dataset: token store; only its
            ``get(doc, offset=..., length=...)`` method is used here.
        doc_idx: maps a document position (column 0 of ``sample_idx``) to the
            document id passed to ``indexed_dataset.get``.
        sample_idx: row ``i`` is where sample ``i`` starts and row ``i + 1`` is
            where it ends; column 0 is a position into ``doc_idx``, column 1 a
            token offset inside that document.
        shuffle_idx: applied to the requested index before ``sample_idx`` is
            consulted (presumably a permutation of sample ids).
    """

    def __init__(
        self,
        indexed_dataset: "MMapIndexedDataset",
        doc_idx: np.memmap,
        sample_idx: np.memmap,
        shuffle_idx: np.memmap,
    ):
        self.indexed_dataset = indexed_dataset
        self.doc_idx = doc_idx
        self.sample_idx = sample_idx
        self.shuffle_idx = shuffle_idx

        # sample_idx has N + 1 boundary rows for N samples.
        # NOTE(review): the same "- 1" is applied to shuffle_idx; confirm it
        # really carries N + 1 entries rather than N.
        self.shuffle_idx_len = self.shuffle_idx.shape[0] - 1
        self.sample_idx_len = self.sample_idx.shape[0] - 1

        if self.shuffle_idx_len != self.sample_idx_len:
            print(f"WARNING: {self.shuffle_idx_len=} != {self.sample_idx_len=}")

    def __len__(self):
        # Only expose indices that are valid for both arrays.
        return min(self.shuffle_idx_len, self.sample_idx_len)

    def _read_sample(self, idx):
        """Return ``{"text": int64 token array}`` for ``idx``; IndexError propagates."""
        # Keep the shuffled position separate from the caller's index so the
        # caller's index is still available for the fallback in __getitem__.
        shuffled = self.shuffle_idx[idx]
        # Start and end documents and offsets.
        doc_index_f = self.sample_idx[shuffled][0]
        doc_index_l = self.sample_idx[shuffled + 1][0]
        offset_f = self.sample_idx[shuffled][1]
        offset_l = self.sample_idx[shuffled + 1][1]
        if doc_index_f == doc_index_l:
            # Within one document: the "+ 1" makes the span include offset_l
            # (assuming get() returns `length` tokens starting at `offset`).
            sample = self.indexed_dataset.get(
                self.doc_idx[doc_index_f],
                offset=offset_f,
                length=offset_l - offset_f + 1,
            )
        else:
            # Rest of the initial document ...
            sample_list = [
                self.indexed_dataset.get(self.doc_idx[doc_index_f], offset=offset_f)
            ]
            # ... every document strictly in between, in full ...
            for i in range(doc_index_f + 1, doc_index_l):
                sample_list.append(self.indexed_dataset.get(self.doc_idx[i]))
            # ... and the leading part of the last document, through offset_l.
            sample_list.append(
                self.indexed_dataset.get(
                    self.doc_idx[doc_index_l], length=offset_l + 1
                )
            )
            sample = np.concatenate(sample_list)

        return {"text": np.array(sample, dtype=np.int64)}

    def __getitem__(self, idx):
        """Return sample ``idx``, wrapping out-of-range indices modulo ``len(self)``.

        Raises:
            IndexError: if the dataset is empty, or if ``idx`` is already in
                range but its sample still cannot be read (wrapping would only
                retry the same index forever).
        """
        try:
            return self._read_sample(idx)
        except IndexError:
            num_samples = len(self)
            if num_samples == 0:
                raise
            # Wrap the index the caller passed, not the shuffled one; the retry
            # goes through shuffle_idx again.
            new_idx = idx % num_samples
            if new_idx == idx:
                raise
            print(
                f"WARNING: Got index out of bounds error with index {idx} - taking modulo of index instead ({new_idx})"
            )
            return self[new_idx]

In [28]:
def read_dataset(
    file_path: Path,
    pattern: str,
    data_prefix: str = "pile_20B_tokenizer_text_document",
    batch_size: int = 1024,
) -> GPT2Dataset:
    """Load a doc/sample/shuffle index triple and wrap it in a GPT2Dataset.

    Args:
        file_path: directory that holds the ``{pattern}_*_idx.npy`` files and
            the tokenized dataset named by ``data_prefix``.
        pattern: shared prefix of the index files. For example,
            ``pile_20B_tokenizer_text_document_train_indexmap_120ns_2048sl_1234s``
            selects ``..._doc_idx.npy``, ``..._sample_idx.npy`` and
            ``..._shuffle_idx.npy``.
        data_prefix: path (relative to ``file_path``) handed to
            ``MMapIndexedDataset``.
        batch_size: samples per batch, used only for the printed batch count.

    Returns:
        A GPT2Dataset whose index arrays are opened read-only with memory mapping.
    """
    # No allow_pickle: the index files are memory-mapped, and arrays of
    # pickled Python objects cannot be memory-mapped anyway.
    doc_idx = np.load(file_path / f"{pattern}_doc_idx.npy", mmap_mode="r")
    sample_idx = np.load(file_path / f"{pattern}_sample_idx.npy", mmap_mode="r")
    shuffle_idx = np.load(file_path / f"{pattern}_shuffle_idx.npy", mmap_mode="r")
    indexed_dataset = MMapIndexedDataset(str(file_path / data_prefix), skip_warmup=True)

    ds = GPT2Dataset(
        indexed_dataset=indexed_dataset,
        doc_idx=doc_idx,
        sample_idx=sample_idx,
        shuffle_idx=shuffle_idx,
    )

    # Sanity checks. A sample's length is expected to be seq_len + 1 (the
    # recorded outputs show 2049 for the "2048sl" mappings).
    if len(ds) == 0:
        print("WARNING: dataset is empty")
    else:
        print("Seq length ==", len(ds[0]["text"]))
    print("Num batches ==", len(ds) / batch_size, "(should be 143k)")

    return ds

In [29]:
# Reference index mapping from the pile_deduped directory.
pattern = "pile_20B_tokenizer_text_document_train_indexmap_120ns_2048sl_1234s"
file_path = Path("/mnt/ssd-2/pile_deduped/")
ds = read_dataset(file_path=file_path, pattern=pattern)

    reading sizes...
    reading pointers...
    reading document index...
    creating numpy buffer of mmap...
    creating memory view of numpy buffer...
Seq length == 2049
Num batches == 153464.0029296875 (should be 143k)


In [45]:
# Same directory, but the 960ns index mapping (compare with the 120ns mapping in `ds`).
# Bind the result to a real name: assigning to `_` shadows IPython's last-output variable.
file_path = Path("/mnt/ssd-2/pile_deduped/")
pattern = "pile_20B_tokenizer_text_document_train_indexmap_960ns_2048sl_1234s"
ds_960ns = read_dataset(file_path, pattern)

    reading sizes...
    reading pointers...
    reading document index...
    creating numpy buffer of mmap...
    creating memory view of numpy buffer...
Seq length == 2049
Num batches == 153464.0029296875 (should be 143k)


In [30]:
# Index mapping from the pile_extra_seeds directory, compared against `ds` below.
pattern = "pile_20B_tokenizer_text_document_train_0_indexmap_258ns_2048sl_1s"
file_path = Path("/mnt/ssd-2/pile_extra_seeds/")
oskar_ds = read_dataset(file_path=file_path, pattern=pattern)

    reading sizes...
    reading pointers...
    reading document index...
    creating numpy buffer of mmap...
    creating memory view of numpy buffer...
Seq length == 2049
Num batches == 158364.9267578125 (should be 143k)


In [36]:
len(ds.indexed_dataset) == len(oskar_ds.indexed_dataset)  # same underlying token store size?

True

In [39]:
len(ds.sample_idx) == len(oskar_ds.sample_idx)  # same number of sample boundaries?

False

In [42]:
ds.indexed_dataset.sizes.shape[0] == oskar_ds.indexed_dataset.sizes.shape[0]  # same number of size entries?

True