In [30]:
import torch
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torch.utils.data.sampler import SubsetRandomSampler
import numpy as np

%matplotlib inline
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [31]:
class NumpyDataset(Dataset):
    """ Dataset for numpy arrays with explicit memmap support """

    def __init__(self, *arrays, dtype=torch.float):
        self.dtype = dtype
        self.memmap = []
        self.data = []
        self.n = None

        for array in arrays:
            if self.n is None:
                self.n = array.shape[0]
            assert array.shape[0] == self.n

            if isinstance(array, np.memmap):
                self.memmap.append(True)
                self.data.append(array)
            else:
                self.memmap.append(False)
                tensor = torch.from_numpy(array).to(self.dtype)
                self.data.append(tensor)

    def __getitem__(self, index):
        items = []
        for memmap, array in zip(self.memmap, self.data):
            if memmap:
                tensor = np.array(array[index])
                items.append(torch.from_numpy(tensor).to(self.dtype))
            else:
                items.append(array[index])
        return tuple(items)

    def __len__(self):
        return self.n

In [32]:
def make_dataloaders(dataset, validation_split, batch_size, num_workers=1, pin_memory=True, seed=None):
    if validation_split is None or validation_split <= 0.0:
        train_loader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
            pin_memory=pin_memory,
            num_workers=num_workers,
        ) 
        val_loader = None
    else:
        assert 0.0 < validation_split < 1.0, "Wrong validation split: {}".format(validation_split)

        n_samples = len(dataset)
        indices = list(range(n_samples))
        split = int(np.floor(validation_split * n_samples))
        if seed is not None:
            np.random.seed(seed)
        np.random.shuffle(indices)
        train_idx, valid_idx = indices[split:], indices[:split]

        train_sampler = SubsetRandomSampler(train_idx)
        val_sampler = SubsetRandomSampler(valid_idx)

        train_loader = DataLoader(
            dataset,
            sampler=train_sampler,
            batch_size=batch_size,
            pin_memory=pin_memory,
            num_workers=num_workers,
        ) 
        val_loader = DataLoader(
            dataset,
            sampler=val_sampler,
            batch_size=batch_size,
            pin_memory=pin_memory,
            num_workers=num_workers,
        )  

    return train_loader, val_loader

In [33]:
datasets = []
for i in range(5):
    x = np.load("../data/samples/x_train_ModelO_gamma_default_{}.npy".format(i), mmap_mode="r")
    x_aux = np.load("../data/samples/x_aux_train_ModelO_gamma_default_{}.npy".format(i))
    theta = np.load("../data/samples/theta_train_ModelO_gamma_default_{}.npy".format(i))
    datasets.append(NumpyDataset(x, x_aux, theta))

In [34]:
concat_dataset = ConcatDataset(datasets)

In [36]:
train_loader, val_loader = make_dataloaders(concat_dataset, 0.1, 2)

In [37]:
for batch in train_loader:
    print(batch)

[tensor([[[0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.]]]), tensor([[[2.1009, 2.5807]],

        [[1.9538, 2.3994]]]), tensor([[ 1.8175,  0.2346,  0.8662,  0.9101, 11.9363,  1.7165,  0.8127, 15.7822,
          1.3138, -4.8845, 16.2328,  2.8191,  2.1685, 15.5775,  1.6987, -8.2578,
         20.7385,  0.9238],
        [ 2.0953,  1.3828,  0.9456,  1.1062,  6.8554,  2.4969,  1.8113, 16.8100,
          1.1582, -7.4965, 17.8690,  1.0945,  1.6038, 15.9917,  1.4055, -3.5893,
         34.5689,  2.2626]])]
[tensor([[[0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.]]]), tensor([[[2.1150, 2.5909]],

        [[2.0811, 2.5695]]]), tensor([[ 0.3648,  1.0547,  0.2916,  1.1991, 10.5689,  5.5542,  1.1153, 13.0875,
          1.2392, -1.0709, 11.7635,  0.1278,  0.5847, 18.8948,  1.4850,  1.2052,
         15.1613,  1.0942],
        [ 1.5108,  0.8213,  0.2896,  0.3048, 11.9731,  3.0730,  0.9353, 14.6270,
          1.9865,  1.8233, 13.1451,  1.1871,  0.4610,

KeyboardInterrupt: 