In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import zarr
import numpy as np

In [2]:
class ZarrDataset(Dataset):
    '''
    Torch dataset backed by a zarr array. It is assumed that
    the first axis is the sample dimension and that the second
    axis holds the feature columns plus one label column.

    Each dataset item is a *minibatch*: a contiguous run of up to
    ``minibatch_size`` rows fetched with a single slice of the array.
    Use ``ZarrDataset.collator`` as the ``collate_fn`` of a ``DataLoader``
    so that the minibatches drawn for one batch are concatenated along
    the sample axis.

    Parameters
    ----------
    store : str
        Location handed to ``zarr.open_array``; the array is opened read-only.
    minibatch_size : int
        Number of rows per item. Must be >= 1. The last item may be shorter.
    label_idx : int
        Position of the label column along axis 1. Negative values count
        from the end, as with ordinary indexing.
    **kwargs
        Forwarded unchanged to ``zarr.open_array``.
    '''
    def __init__(self, store: str, minibatch_size: int=32, label_idx=12, **kwargs):
        # Reject bad sizes before opening the store; otherwise they surface
        # later as a ZeroDivisionError or a negative dataset length.
        if minibatch_size < 1:
            raise ValueError(f"minibatch_size must be >= 1, got {minibatch_size}")
        self.array_ = zarr.open_array(store, mode="r", **kwargs)
        self.minibatch_size_ = minibatch_size
        # Integer ceiling division: a trailing partial minibatch counts as one item.
        self.len_ = -(-self.array_.shape[0] // self.minibatch_size_)
        self.label_idx_ = label_idx

    def __len__(self):
        '''Number of minibatches (not rows) in the dataset.'''
        return self.len_

    def __getitem__(self, idx):
        '''
        Return minibatch ``idx`` as a ``(X, y)`` pair of tensors.

        ``X`` is the selected rows with the label column removed and ``y``
        is the label column. Negative ``idx`` counts from the last minibatch.

        Raises
        ------
        IndexError
            If ``idx`` is outside ``[-len(self), len(self))``. Raising here
            (instead of returning an empty slice) also lets plain iteration
            over the dataset terminate.
        '''
        n_items = self.len_
        position = idx + n_items if idx < 0 else idx
        if not 0 <= position < n_items:
            raise IndexError(
                f"minibatch index {idx} out of range for {n_items} minibatches"
            )

        start = position * self.minibatch_size_
        # Clamp so the final minibatch stops at the end of the array.
        end = min(start + self.minibatch_size_, self.array_.shape[0])
        sel = self.array_[start:end, ...]

        # np.delete drops exactly one column and handles a negative label
        # index correctly, unlike splicing two slices around it.
        X = np.delete(sel, self.label_idx_, axis=1)
        y = sel[:, self.label_idx_]

        return torch.tensor(X), torch.tensor(y)

    @staticmethod
    def collator(Xy):
        '''
        Merge a list of ``(X, y)`` minibatches into one ``(X, y)`` batch by
        concatenating along the sample axis (dim 0).

        ``Xy`` must be non-empty; ``torch.cat`` rejects an empty list.
        '''
        return (
            torch.cat([sample[0] for sample in Xy], dim=0),
            torch.cat([sample[1] for sample in Xy], dim=0)
        )

In [3]:
# Training data location (zarr array in a GCS bucket).
train_store = "gs://ads_training_data/single_pixel_year/training.zarr"

# Two-level batching: every dataset item is a contiguous run of `minibatch`
# rows, and the DataLoader groups `batch` such items into one training batch.
batch = 4
minibatch = 128

ds = ZarrDataset(train_store, minibatch_size=minibatch)

# Shuffling happens at minibatch granularity; the collator concatenates the
# drawn minibatches along the sample axis.
train_dataloader = DataLoader(dataset=ds, batch_size=batch, shuffle=True, collate_fn=ds.collator)

In [4]:
%time X, y = next(iter(train_dataloader))
assert X.shape[0] == minibatch * batch
assert y.shape[0] == minibatch * batch

CPU times: user 381 ms, sys: 53.4 ms, total: 434 ms
Wall time: 1.11 s


In [5]:
# Batches per epoch: total rows divided by rows per batch, rounded up.
batches = np.ceil(ds.array_.shape[0] / (batch * minibatch))

# Hand-entered latency for loading one batch, in milliseconds.
# NOTE(review): hard-coded; the %time output recorded for the first-batch cell
# shows 1.11 s wall - confirm which measurement this figure comes from.
time_per_batch_ms = 779

# Estimated time spent purely on data loading for one pass over the data (ms).
data_load_time = time_per_batch_ms * batches
print("Data loading time per epoch (sec):", data_load_time / 1000)

Data loading time per epoch (sec): 15188.163
