In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import zarr
import numpy as np

In [2]:
class ZarrDataset(Dataset):
    '''
    Torch dataset backed by a zarr array. It is assumed that
    the first axis is the sample dimension.
    '''
    def __init__(self, store: str, minibatch_size: int=32, label_idx=12, **kwargs):
        # Open the underlying array. This particular store is a 2-D array with shape (n_rows, n_columns).
        self.array_ = zarr.open_array(store, mode="r", **kwargs)
        # Define how many samples we pull from the array at a time.
        self.minibatch_size_ = minibatch_size
        # Calculate length of the dataset. This "length" is the number of minibatches we can generate,
        # NOT the number of samples we can generate.
        self.len_ = int(np.ceil(self.array_.shape[0] / self.minibatch_size_))
        # Track the index of the mortality column because I was messy setting it up lol
        self.label_idx_ = label_idx

    def __len__(self):
        # Simply return the length of the dataset
        return self.len_

    def __getitem__(self, idx):
        # Calculate where in the array we will take a minibatch-sized sample.
        start = idx * self.minibatch_size_
        end = min(start + self.minibatch_size_, self.array_.shape[0])

        # This line is actually pulling data over the internet
        sel = self.array_[start:end, ...]

        # These lines reshape the data and turn it into a tensor that torch can use
        X = np.concatenate((sel[:, 0:self.label_idx_], sel[:, self.label_idx_+1:]), axis=1)
        y = sel[:, self.label_idx_]

        return torch.tensor(X), torch.tensor(y)

    @staticmethod
    def collator(Xy):
        # The output of __getitem__ is two arrays of shape (minibatch_size, n_columns) and 
        # (minibatch_size, 1). These are "collated" together to make new arrays of shape
        # (minibatch_size * batch_size, n_columns) and (minibatch_size * batch_size, 1).
        return (
            torch.cat([sample[0] for sample in Xy], dim=0),
            torch.cat([sample[1] for sample in Xy], dim=0)
        )

In [3]:
minibatch = 128
batch = 4
train_store = "gs://ads_training_data/single_pixel_year/training.zarr"
ds = ZarrDataset(train_store, minibatch_size=minibatch)

train_dataloader = DataLoader(
    ds, 
    batch_size=batch, 
    shuffle=True, 
    collate_fn=ZarrDataset.collator
)

In [4]:
# Time how long it takes to laod data and make sure
# that the shape of the array is what we expect.
%time X, y = next(iter(train_dataloader))
assert X.shape[0] == minibatch * batch
assert y.shape[0] == minibatch * batch

CPU times: user 372 ms, sys: 51.1 ms, total: 423 ms
Wall time: 896 ms


In [5]:
time_per_batch_ms = 779
batches = np.ceil(ds.array_.shape[0] / (minibatch*batch))
data_load_time = batches * time_per_batch_ms
print("Data loading time per epoch (sec):", data_load_time / 1000)

Data loading time per epoch (sec): 15188.163
