In [1]:
import torch
from hydra import initialize, compose
from hydra.utils import instantiate
from bliss.surveys.dc2 import DC2DataModule
import tqdm
from bliss.catalog import TileCatalog
from bliss.global_env import GlobalEnv
import time

In [2]:
# Compose the `notebook_config` Hydra config from this directory, then override the DC2
# survey settings: no training transforms and single-process loading (num_workers = 0).
with initialize(config_path=".", version_base=None):
    notebook_cfg = compose("notebook_config")
dc2_cfg = notebook_cfg.surveys.dc2
dc2_cfg.train_transforms = []
dc2_cfg.num_workers = 0

In [3]:
# Instantiate the DC2 data module from the (overridden) config, then call prepare_data()
# and setup("fit") before asking for the training dataloader that a later cell times.
dc2: DC2DataModule = instantiate(notebook_cfg.surveys.dc2)
dc2.prepare_data()
dc2.setup("fit")

dc2_train_dataloader = dc2.train_dataloader()
# NOTE(review): `device` is not referenced by any visible cell of this notebook.
device = torch.device("cpu")




In [5]:
# Time one full pass over the DC2 training dataloader.
# GlobalEnv fields are reset first, presumably so any seed/epoch-dependent logic in the
# DC2 pipeline behaves as at the start of training (TODO confirm).
GlobalEnv.seed_in_this_program = 0
GlobalEnv.current_encoder_epoch = 0
# perf_counter is monotonic and high-resolution; time.time() is a wall clock that can jump.
start_time = time.perf_counter()
for batch in tqdm.tqdm(dc2_train_dataloader):
    pass
end_time = time.perf_counter()
print(f"iterate data time: {end_time - start_time: .3f}s")

100%|██████████| 3047/3047 [04:35<00:00, 11.06it/s]

iterate data time:  275.401s





In [None]:
import math
import os
import random
import warnings
from typing import List

import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, IterableDataset
from torchvision import transforms

# prevent pytorch_lightning warning for num_workers = 2 in dataloaders with IterableDataset
warnings.filterwarnings(
    "ignore", ".*does not have many workers which may be a bottleneck.*", UserWarning
)
# an IterableDataset isn't supposed to have a __len__ method
warnings.filterwarnings("ignore", ".*Total length of .* across ranks is zero.*", UserWarning)


class MyIterableDataset(IterableDataset):
    """Stream examples from a list of cached files, one file at a time.

    Each file is read with ``torch.load`` and its contents are iterated example by example;
    when ``shuffle`` is set the contents are shuffled in place with ``random.shuffle``, so
    they are presumably lists of examples (TODO confirm against the cache writer).

    Args:
        file_paths: sequence of file paths to stream.
        shuffle: if True, shuffle the file order on every ``__iter__`` call and shuffle the
            examples within each file.
        transform: optional callable applied to every example before it is yielded.
    """

    def __init__(self, file_paths, shuffle=False, transform=None):
        self.file_paths = file_paths
        self.shuffle = shuffle
        self.transform = transform

    def get_stream(self, files):
        """Yield every (optionally transformed) example of every file in ``files``, in order."""
        for file_path in files:
            # NOTE: torch.load unpickles its input; only point this at trusted cache files.
            examples = torch.load(file_path)

            # each training worker also shuffles the examples within each file
            if self.shuffle:
                random.shuffle(examples)

            for ex in examples:
                yield ex if self.transform is None else self.transform(ex)

    def __iter__(self):
        """Return a generator over this process's share of the files.

        In the main process (no DataLoader workers) all files are streamed. In a DataLoader
        worker, the files are sharded so that the workers together cover every file exactly
        once.
        """
        worker_info = torch.utils.data.get_worker_info()
        files = list(self.file_paths)

        if worker_info is None:  # single-process data loading
            if self.shuffle:
                random.shuffle(files)
            return self.get_stream(files)

        # Every worker runs __iter__ independently, so all of them must agree on the file
        # permutation before taking their shard; shuffling with each worker's own RNG state
        # makes the shards overlap (duplicated files) and leave gaps (skipped files).
        # worker_info.seed is base_seed + worker_id, with base_seed drawn from the main-process
        # RNG for each DataLoader iterator: subtracting the id gives a seed that is shared by
        # all workers of an epoch yet changes from epoch to epoch.
        if self.shuffle:
            random.Random(worker_info.seed - worker_info.id).shuffle(files)

        # strided sharding gives each worker floor or ceil(len(files) / num_workers) files
        files_subset = files[worker_info.id :: worker_info.num_workers]
        return self.get_stream(files_subset)


class CachedSimulatedDataset(pl.LightningDataModule):
    """Lightning data module that streams cached simulated examples from ``.pt`` files.

    Every file ending in ``.pt`` directly inside ``cached_data_path`` is one shard; shards
    (not individual examples) are divided into train/val/test by percentage ranges.

    Args:
        splits: up to three ``"start:stop"`` percentage ranges separated by ``"/"``, in
            train/val/test order, e.g. ``"0:80/80:90/90:100"``. A blank bound is open
            (``":20"`` is the first 20% of files); splits that are not given are empty.
        batch_size: examples per batch.
        num_workers: DataLoader worker processes (0 loads in the main process).
        cached_data_path: directory holding the cached ``.pt`` files.
        train_transforms: callables composed with ``transforms.Compose`` and applied to each
            training example.
        nontrain_transforms: callables composed and applied to each val/test/predict example.
    """

    def __init__(
        self,
        splits: str,
        batch_size: int,
        num_workers: int,
        cached_data_path: str,
        train_transforms: List,
        nontrain_transforms: List,
    ):
        super().__init__()

        self.batch_size = batch_size
        self.num_workers = num_workers
        self.train_transforms = train_transforms
        self.nontrain_transforms = nontrain_transforms

        # os.listdir returns entries in arbitrary, filesystem-dependent order; sort so the
        # train/val/test assignment of files is reproducible across runs and machines.
        file_names = sorted(f for f in os.listdir(cached_data_path) if f.endswith(".pt"))
        self.file_paths = [os.path.join(cached_data_path, f) for f in file_names]

        # parse slices from percentages to indices
        self.slices = self.parse_slices(splits, len(self.file_paths))

    def _percent_to_idx(self, x, length):
        """Convert a percentage string (e.g. ``" 80"``) to an index into ``length`` items.

        Returns None for a blank string so that it acts as an open slice bound.
        """
        x = x.strip()
        return int(float(x) / 100 * length) if x else None

    def parse_slices(self, splits: str, length: int):
        """Turn ``"a:b/c:d/e:f"`` percentage ranges into three index slices over ``length`` items.

        Raises:
            ValueError: if more than three ``"/"``-separated splits are given.
        """
        data_splits = splits.split("/")
        if len(data_splits) > 3:
            raise ValueError(
                f"expected at most 3 '/'-separated splits (train/val/test), got {splits!r}"
            )
        slices = [slice(0, 0) for _ in range(3)]  # default to empty slice for each split
        for i, data_split in enumerate(data_splits):
            # map "start_percent:stop_percent" to slice(start_idx, stop_idx)
            slices[i] = slice(*(self._percent_to_idx(val, length) for val in data_split.split(":")))
        return slices

    def _make_dataloader(self, file_paths_subset, transform_list, shuffle):
        """Build a DataLoader that streams ``file_paths_subset`` with ``transform_list`` applied."""
        assert file_paths_subset, "No cached data found"
        my_dataset = MyIterableDataset(
            file_paths_subset, transform=transforms.Compose(transform_list), shuffle=shuffle
        )
        # No worker_init_fn on purpose: PyTorch already seeds Python's `random` in each worker
        # with base_seed + worker_id, where base_seed is redrawn for every epoch. The former
        # `worker_init_fn=random.seed` re-seeded each worker with its bare worker id, so the
        # "shuffled" order repeated identically every epoch.
        return DataLoader(
            my_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )

    def train_dataloader(self):
        """Loader over the train split; file order and within-file order are shuffled."""
        return self._make_dataloader(
            self.file_paths[self.slices[0]], self.train_transforms, shuffle=True
        )

    def _get_nontrain_dataloader(self, file_paths_subset):
        """Unshuffled loader over ``file_paths_subset`` using the non-train transforms."""
        return self._make_dataloader(file_paths_subset, self.nontrain_transforms, shuffle=False)

    def val_dataloader(self):
        return self._get_nontrain_dataloader(self.file_paths[self.slices[1]])

    def test_dataloader(self):
        return self._get_nontrain_dataloader(self.file_paths[self.slices[2]])

    def predict_dataloader(self):
        # prediction covers every cached file, regardless of the train/val/test split
        return self._get_nontrain_dataloader(self.file_paths)

In [None]:
# Baseline for comparison: the cached-simulation data module defined above, reading its train
# split (the first 80% of the cached .pt files) in a single process with no transforms.
cached_datamodule = CachedSimulatedDataset(
    splits="0:80/80:90/90:100",
    batch_size=64,
    num_workers=0,
    cached_data_path="../../../bliss_output/dc2_cached_data",
    train_transforms=[],
    nontrain_transforms=[],
)
ori_train_dataloader = cached_datamodule.train_dataloader()

In [None]:
# Time one full pass over the cached-simulation training dataloader.
# perf_counter is monotonic and high-resolution; time.time() is a wall clock that can jump.
start_time = time.perf_counter()
for batch in tqdm.tqdm(ori_train_dataloader):
    pass
end_time = time.perf_counter()
print(f"iterate data time: {end_time - start_time: .3f}s")