In [30]:
from itertools import islice

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torch.utils.data.sampler import SubsetRandomSampler

# Render matplotlib figures inline in the notebook.
%matplotlib inline
# Re-import edited modules before each cell runs (mode 2 = reload all modules).
# %load_ext prints a notice instead of failing if the extension is already loaded.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [31]:
class NumpyDataset(Dataset):
    """Dataset for numpy arrays with explicit memmap support.

    Each positional array contributes one element to every sample, so
    ``dataset[i]`` returns ``(arrays[0][i], arrays[1][i], ...)`` as torch
    tensors of ``dtype``.

    In-memory arrays are converted to tensors once, in ``__init__``.
    ``np.memmap`` arrays (e.g. from ``np.load(..., mmap_mode="r")``) stay on
    disk; only the indexed rows are read and converted in ``__getitem__``.

    Parameters
    ----------
    *arrays : np.ndarray or np.memmap
        Arrays that all share the same length along axis 0.
    dtype : torch.dtype, optional
        Tensor dtype every item is converted to (default ``torch.float``).

    Raises
    ------
    ValueError
        If no arrays are given, or their first dimensions differ.
    """

    def __init__(self, *arrays, dtype=torch.float):
        if not arrays:
            # Without this, self.n would stay None and len() would fail much later.
            raise ValueError("NumpyDataset needs at least one array")

        self.dtype = dtype
        self.n = arrays[0].shape[0]
        self.memmap = []  # per-array flag: True -> array kept as np.memmap
        self.data = []    # np.memmap or pre-converted torch.Tensor, parallel to self.memmap

        for array in arrays:
            # Explicit check rather than `assert`, which is stripped under `python -O`.
            if array.shape[0] != self.n:
                raise ValueError(
                    "All arrays must have the same length along axis 0, got {} and {}".format(
                        self.n, array.shape[0]))

            is_memmap = isinstance(array, np.memmap)
            self.memmap.append(is_memmap)
            if is_memmap:
                # Keep memmaps lazy: only rows touched in __getitem__ are read.
                self.data.append(array)
            else:
                self.data.append(torch.from_numpy(array).to(self.dtype))

    def __getitem__(self, index):
        """Return a tuple with row ``index`` of every array, as tensors of ``self.dtype``."""
        items = []
        for is_memmap, array in zip(self.memmap, self.data):
            if is_memmap:
                # np.array() copies the slice out of the (possibly read-only) memmap
                # into a regular in-memory ndarray before handing it to torch.
                items.append(torch.from_numpy(np.array(array[index])).to(self.dtype))
            else:
                items.append(array[index])
        return tuple(items)

    def __len__(self):
        """Number of samples (shared length along axis 0)."""
        return self.n

In [116]:
def make_dataloaders(dataset, validation_split, batch_size, num_workers=1, pin_memory=True, seed=None):
    """Build a training DataLoader and, optionally, a validation DataLoader over ``dataset``.

    Parameters
    ----------
    dataset : torch.utils.data.Dataset
        Map-style dataset that supports ``len()``.
    validation_split : float or None
        Fraction of samples held out for validation. ``None`` or a value ``<= 0``
        disables the split. Otherwise the value must lie strictly between 0 and 1.
    batch_size : int
        Batch size for both loaders.
    num_workers : int, optional
        Passed through to ``DataLoader``.
    pin_memory : bool, optional
        Passed through to ``DataLoader``.
    seed : int or None, optional
        Seeds the train/validation index shuffle. A private ``RandomState`` is
        used, so numpy's global RNG is not reseeded. The resulting split matches
        what ``np.random.seed(seed); np.random.shuffle(indices)`` would give.

    Returns
    -------
    tuple
        ``(train_loader, val_loader)``. ``val_loader`` is ``None`` when there is
        no split; in that case ``train_loader`` shuffles the whole dataset.

    Raises
    ------
    ValueError
        If ``validation_split`` is positive but not below 1.
    """
    loader_kwargs = dict(batch_size=batch_size, pin_memory=pin_memory, num_workers=num_workers)

    if validation_split is None or validation_split <= 0.0:
        return DataLoader(dataset, shuffle=True, **loader_kwargs), None

    if not 0.0 < validation_split < 1.0:
        raise ValueError("Wrong validation split: {}".format(validation_split))

    n_samples = len(dataset)
    indices = list(range(n_samples))
    split = int(np.floor(validation_split * n_samples))

    # A local RandomState avoids reseeding numpy's global RNG as a hidden side
    # effect. When seed is None, keep drawing from the global RNG as before.
    rng = np.random if seed is None else np.random.RandomState(seed)
    rng.shuffle(indices)
    train_idx, valid_idx = indices[split:], indices[:split]

    train_loader = DataLoader(dataset, sampler=SubsetRandomSampler(train_idx), **loader_kwargs)
    val_loader = DataLoader(dataset, sampler=SubsetRandomSampler(valid_idx), **loader_kwargs)
    return train_loader, val_loader

In [249]:
n_files = 4

from collections import OrderedDict

# Every sample shard follows one naming scheme: <field>_train_ModelO_gamma_default_<i>.npy
SAMPLE_PATH_TEMPLATE = "../data/samples/{}_train_ModelO_gamma_default_{}.npy"


def _load_samples(field, mmap_mode=None):
    """Load shards 0..n_files-1 of ``field``; ``mmap_mode="r"`` memory-maps them read-only."""
    return [np.load(SAMPLE_PATH_TEMPLATE.format(field, i), mmap_mode=mmap_mode) for i in range(n_files)]


# Key order matters: it fixes the item order of each sample, which the
# `x, x_aux, theta = batch` unpacking relies on.
data = OrderedDict()
data['x'] = _load_samples('x', mmap_mode="r")  # memory-mapped (presumably the large array) — confirm
data['x_aux'] = _load_samples('x_aux')
data['theta'] = _load_samples('theta')


In [250]:
import six

def make_datasets(data, dtype=torch.float):
    """Build one ``NumpyDataset`` per file shard and concatenate them.

    Parameters
    ----------
    data : Mapping[str, list of np.ndarray]
        Maps a field name to a list of per-shard arrays. All lists must have
        the same length. Shard ``i`` is built from
        ``[value[i] for value in data.values()]``, so the items of each sample
        follow the mapping's key order.
    dtype : torch.dtype, optional
        Tensor dtype passed to every ``NumpyDataset`` (default ``torch.float``).

    Returns
    -------
    ConcatDataset
        The per-shard datasets, concatenated in shard order.

    Raises
    ------
    ValueError
        If ``data`` is empty, or its lists have different lengths.
    """
    shard_lists = list(data.values())
    if not shard_lists:
        raise ValueError("make_datasets needs at least one field")

    n_shards = len(shard_lists[0])
    if any(len(shards) != n_shards for shards in shard_lists):
        raise ValueError("All fields must have the same number of shards, got {}".format(
            {key: len(value) for key, value in data.items()}))

    # zip(*...) transposes the field-major lists into one tuple of arrays per shard.
    datasets = [NumpyDataset(*shard_arrays, dtype=dtype) for shard_arrays in zip(*shard_lists)]
    return ConcatDataset(datasets)

In [269]:
# Parallel lists of field names and their per-shard array lists, in `data`'s key order.
data_labels = list(data.keys())
data_arrays_list = list(data.values())

In [270]:
# Hold out 10% of the samples for validation; batches of 64, no worker processes.
train_loader, val_loader = make_dataloaders(
    make_datasets(data),
    validation_split=0.1,
    batch_size=64,
    num_workers=0,
    pin_memory=False,
)

In [268]:
# Collect the first `n_batches` training batches, one list per field
# (presumably for estimating z-score normalisation statistics — confirm).
# islice stops after exactly n_batches, so no extra batch is loaded just to hit a `break`.
theta_z_score, x_z_score, x_aux_z_score = [], [], []
n_batches = 4
for x, x_aux, theta in islice(train_loader, n_batches):
    x_z_score.append(x)
    x_aux_z_score.append(x_aux)
    theta_z_score.append(theta)

In [281]:
# Drop the memmap left over from a previous run (if any) before the next cell
# recreates 'test.npy'. Guarded because `a` is only defined further down, so an
# unconditional `del a` raises NameError on a fresh kernel.
# NOTE(review): presumably meant to release the old file mapping — confirm.
if "a" in globals():
    del a

In [286]:
# mode='w+' creates (or overwrites) a raw memory-mapped file: np.memmap writes no
# .npy header despite the extension. 500000 * 64 * 64 float32 values is
# 8,192,000,000 bytes (~8.2 GB).
a = np.memmap(filename='test.npy', mode='w+', dtype='float32', shape=(500000, 64, 64))

In [288]:
# Fill the first 20 slabs with 12.0, then flush so the change is written to
# 'test.npy' before the file is reopened with np.memmap(..., mode='c').
a[:20] = 12.
a.flush()

In [289]:
# Inspect the memmap: numpy's repr is summarised with `...`. The leading slabs
# show 12. and the trailing ones are still 0.
a

memmap([[[12., 12., 12., ..., 12., 12., 12.],
         [12., 12., 12., ..., 12., 12., 12.],
         [12., 12., 12., ..., 12., 12., 12.],
         ...,
         [12., 12., 12., ..., 12., 12., 12.],
         [12., 12., 12., ..., 12., 12., 12.],
         [12., 12., 12., ..., 12., 12., 12.]],

        [[12., 12., 12., ..., 12., 12., 12.],
         [12., 12., 12., ..., 12., 12., 12.],
         [12., 12., 12., ..., 12., 12., 12.],
         ...,
         [12., 12., 12., ..., 12., 12., 12.],
         [12., 12., 12., ..., 12., 12., 12.],
         [12., 12., 12., ..., 12., 12., 12.]],

        [[12., 12., 12., ..., 12., 12., 12.],
         [12., 12., 12., ..., 12., 12., 12.],
         [12., 12., 12., ..., 12., 12., 12.],
         ...,
         [12., 12., 12., ..., 12., 12., 12.],
         [12., 12., 12., ..., 12., 12., 12.],
         [12., 12., 12., ..., 12., 12., 12.]],

        ...,

        [[ 0.,  0.,  0., ...,  0.,  0.,  0.],
         [ 0.,  0.,  0., ...,  0.,  0.,  0.],
         [ 0.,  0.

In [293]:
# Reopen the same raw file copy-on-write (mode='c'): b reads the stored data,
# but assignments to b stay in memory and are never saved to 'test.npy'.
b = np.memmap(filename='test.npy', mode='c', dtype='float32', shape=(500000, 64, 64))

In [294]:
# Confirm the reopened mapping has the expected shape.
np.shape(b)

(500000, 64, 64)