In [25]:
import glob
import math
import os
import sys
from copy import deepcopy
from multiprocessing import Pool
from pathlib import Path
from typing import Callable, Optional

import numpy as np
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, random_split


# Datasets and DataLoaders

In [26]:
# Byte offset of every record in the packed data file; CC100DataSet seeks to these.
sample_locs = np.load(Path("data") / "sample_locs.npy")


In [27]:
class CC100DataSet(torch.utils.data.Dataset):
    """Random-access dataset over a packed binary file of variable-length int32 records.

    Each record starts at a byte offset listed in ``sample_locs``. It consists of a
    2-byte unsigned element count (native byte order) followed by that many int32
    values (native byte order). ``__getitem__`` returns the values as a read-only
    ``np.ndarray`` of dtype int32.

    Args:
        file_path: Path to the packed binary file.
        num_workers: Number of DataLoader workers. Kept for backward compatibility.
            Handles are now opened lazily, one per process, so this value no longer
            changes how the file is accessed.
        locs: Optional sequence of record byte offsets. Defaults to the
            notebook-level ``sample_locs`` array.

    Raises:
        OSError: from ``__init__`` if ``file_path`` cannot be opened.
        EOFError: from ``__getitem__`` if a record header lies past the end of the file.
        ValueError: from ``__getitem__`` if a record payload is truncated.
    """

    def __init__(self, file_path, num_workers, locs=None):
        super().__init__()
        self.file_path = os.fspath(file_path)
        self.num_workers = num_workers
        self.sample_locs = sample_locs if locs is None else np.asarray(locs)
        self._fh = None  # read handle owned by the process recorded in _fh_pid
        self._fh_pid = None
        # Open once up front so a bad path fails at construction time, as before.
        self._handle()

    def __getstate__(self):
        # File objects cannot be pickled (needed when workers use "spawn").
        # Every process reopens the file on first access instead.
        state = self.__dict__.copy()
        state["_fh"] = None
        state["_fh_pid"] = None
        return state

    def _handle(self):
        """Return a read handle owned by the current process, opening one if needed."""
        pid = os.getpid()
        if self._fh is None or self._fh_pid != pid:
            # A handle inherited through fork shares its file offset with the parent
            # (and with sibling workers), so seek()+read() across processes would
            # race. Each process therefore opens its own private handle.
            self._fh = open(self.file_path, "rb")
            self._fh_pid = pid
        return self._fh

    def close(self):
        """Close the current handle, if any. Later reads reopen the file."""
        if self._fh is not None:
            self._fh.close()
            self._fh = None
            self._fh_pid = None

    def __len__(self):
        return len(self.sample_locs)

    def __getitem__(self, idx):
        fh = self._handle()
        fh.seek(int(self.sample_locs[idx]))
        header = fh.read(2)
        if len(header) != 2:
            # Without this check an offset at/after EOF would yield count == 0,
            # i.e. a silently empty sample.
            raise EOFError(f"truncated record header for sample {idx}")
        count = int.from_bytes(header, byteorder=sys.byteorder, signed=False)
        # frombuffer raises ValueError if fewer than count * 4 bytes were read.
        return np.frombuffer(fh.read(count * 4), count=count, dtype=np.int32)


In [42]:
class CC100DataModule(pl.LightningDataModule):
    """LightningDataModule that serves one dataset for every Lightning stage.

    The dataset is built eagerly as ``dataset(path, num_workers)``. For the
    ``fit``/``validate`` stages it is split into train/validation subsets with a
    seeded generator, so repeated ``setup`` calls give the same split. ``test``
    and ``predict`` use the whole dataset.

    Args:
        dataset: Dataset class (or factory) called as ``dataset(path, num_workers)``.
        path: Path to the data file, forwarded to ``dataset``.
        num_workers: DataLoader worker processes (0 = load in the main process).
        batch_size: Samples per batch for every loader.
        prefetch_factor: Batches prefetched per worker. Only passed when
            ``num_workers > 0``, because DataLoader rejects it otherwise.
        train_fraction: Fraction of samples assigned to the training split.
        seed: Seed for the train/validation split.

    Raises:
        ValueError: if ``train_fraction`` is outside ``[0, 1]``.
    """

    def __init__(
        self,
        dataset: Callable[[str, int], torch.utils.data.Dataset],
        path: str,
        num_workers: int = 0,
        batch_size: int = 32,
        prefetch_factor: int = 4,
        train_fraction: float = 0.8,
        seed: int = 0,
    ):
        super().__init__()
        if not 0.0 <= train_fraction <= 1.0:
            raise ValueError(f"train_fraction must be in [0, 1], got {train_fraction}")
        self.dataset = dataset(path, num_workers)
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.prefetch_factor = prefetch_factor
        self.train_fraction = train_fraction
        self.seed = seed

    def prepare_data(self):
        # Nothing to download or preprocess.
        pass

    def setup(self, stage: Optional[str] = None):
        # "validate" (trainer.validate) also needs dataset_val, not just "fit".
        if stage in ("fit", "validate", None):
            train_len = int(len(self.dataset) * self.train_fraction)
            generator = torch.Generator().manual_seed(self.seed)
            self.dataset_train, self.dataset_val = random_split(
                self.dataset,
                [train_len, len(self.dataset) - train_len],
                generator=generator,
            )

        if stage in ("test", None):
            # NOTE(review): the test set is the full dataset, so it overlaps the
            # training split; confirm this is intended.
            self.dataset_test = self.dataset

        if stage in ("predict", None):
            self.dataset_predict = self.dataset

    def train_dataloader(self):
        # Reshuffle every epoch; random_split alone only shuffles the indices once.
        return self._make_dataloader(self.dataset_train, shuffle=True)

    def val_dataloader(self):
        return self._make_dataloader(self.dataset_val)

    def test_dataloader(self):
        return self._make_dataloader(self.dataset_test)

    def predict_dataloader(self):
        return self._make_dataloader(self.dataset_predict)

    def _make_dataloader(self, dataset, shuffle: bool = False) -> DataLoader:
        """Build a DataLoader for ``dataset`` with this module's shared settings."""
        extra = {}
        if self.num_workers > 0:
            # DataLoader raises ValueError if prefetch_factor is set without workers.
            extra["prefetch_factor"] = self.prefetch_factor
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            collate_fn=self._collate_wrapper,
            num_workers=self.num_workers,
            **extra,
        )

    @staticmethod
    def _collate_wrapper(batch):
        """Return the batch unchanged, as a list of samples.

        Samples have varying lengths, so they are not stacked into one array. A
        staticmethod avoids pickling the whole data module into worker processes.
        """
        return batch


In [43]:
# Build the data module over data/ru_small.bin with 4 loader workers, set up every
# stage (stage=None), and take the training loader for inspection below.
data_module = CC100DataModule(CC100DataSet, "data/ru_small.bin", num_workers=4)
data_module.prepare_data()
data_module.setup(stage=None)
data_loader = data_module.train_dataloader()

In [45]:
# Run one full pass over the training loader, converting every sample to a tensor
# (to exercise the full load path), and report how many batches it produced.
num_batches = 0
for batch in data_loader:
    # torch.tensor keeps the int32 dtype; the legacy torch.Tensor constructor
    # would silently cast the values to float32.
    batch_tensors = [torch.tensor(sample) for sample in batch]
    num_batches += 1
num_batches

15182