In [495]:
import numpy as np
import torch
from torch.utils.data import IterableDataset, DataLoader, DistributedSampler
from torchdata.datapipes.iter import Shuffler

class LocalDataset(IterableDataset):
    def __init__(self, file_path, context_size):
        self.data = np.memmap(file_path, dtype=np.uint16, mode="r")
        self.context_size = context_size

    def __len__(self):
        return len(self.data) - self.context_size - 1

    def __iter__(self):
        sampler = DistributedSampler(self, num_replicas=1, rank=0, shuffle=False)
        sampler_iter = iter(sampler)
        next_sample = next(sampler_iter)
        for idx in range(len(self.data) - self.context_size - 1):
            if idx == next_sample:
                next_sample = next(sampler_iter)
                x = torch.from_numpy(
                    (self.data[idx : idx + self.context_size]).astype(np.int64)
                )
                y = torch.from_numpy(
                    (self.data[idx + 1 : idx + self.context_size + 1]).astype(np.int64)
                )
                yield x, y
            else:
                continue


class ParallelDataset(IterableDataset):
    def __init__(self, dataset):
        self.dataset = dataset

    def __iter__(self):
        sampler = DistributedSampler(self.dataset, num_replicas=1, rank=0, shuffle=False)
        for sample in iter(sampler), iter(self.dataset):
            yield sample


path = '/Users/yifeiyan/yif-AI/datasets/full_harry_potter/full_harry_potter_train.bin'
context_size = 5
dataset = LocalDataset(path, context_size)
parallel_dataset = ParallelDataset(dataset)
parallel_dataset = Shuffler(parallel_dataset, buffer_size=2)
loader = DataLoader(parallel_dataset, batch_size=5, num_workers=0)
a = iter(loader)


In [496]:
next(a)

tensor([0, 1, 2, 3, 4])