metadatasets.py
# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/008_data.metadatasets.ipynb.
# %% auto 0
__all__ = ['TSMetaDataset', 'TSMetaDatasets']
# %% ../../nbs/008_data.metadatasets.ipynb 3
from ..imports import *
from ..utils import *
from .validation import *
from .core import *
# %% ../../nbs/008_data.metadatasets.ipynb 4
class TSMetaDataset():
    "A dataset capable of indexing multiple datasets at the same time"
    _type = (TSTensor,)

    def __init__(self, dataset_list, **kwargs):
        if not is_listy(dataset_list): dataset_list = [dataset_list]
        self.datasets = dataset_list
        self.split = kwargs['split'] if 'split' in kwargs else None
        self.mapping = self._mapping()
        if hasattr(dataset_list[0], 'loss_func'):
            self.loss_func = dataset_list[0].loss_func
        else:
            self.loss_func = None

    def __len__(self):
        if self.split is not None:
            return len(self.split)
        else:
            return sum([len(ds) for ds in self.datasets])

    def __getitem__(self, idx):
        if self.datasets:
            if self.split is not None: idx = self.split[idx]
            idx = listify(idx)
            # Translate flat indices into (dataset index, position within dataset) pairs
            idxs = self.mapping[idx]
            idxs = idxs[idxs[:, 0].argsort()]
            self.mapping_idxs = idxs
            ds = np.unique(idxs[:, 0])
            # Gather the requested samples from each dataset and concatenate them into one batch
            b = [self.datasets[d][idxs[idxs[:, 0] == d, 1]] for d in ds]
            output = tuple(map(torch.cat, zip(*b)))
            output = self._type[0](output[0]), output[1]
            return output
        else:
            return

    def _mapping(self):
        # Build a (total samples, 2) array of [dataset index, sample index] pairs
        lengths = [len(ds) for ds in self.datasets]
        idx_pairs = np.zeros((np.sum(lengths), 2)).astype(np.int32)
        start = 0
        for i, length in enumerate(lengths):
            if i > 0:
                idx_pairs[start:start+length, 0] = i
            idx_pairs[start:start+length, 1] = np.arange(length)
            start += length
        return idx_pairs

    def new_empty(self):
        new_dset = type(self)(self.datasets, split=self.split)
        new_dset.datasets = None
        return new_dset

    @property
    def vars(self):
        s = self.datasets[0][0][0] if not isinstance(self.datasets[0][0][0], tuple) else self.datasets[0][0][0][0]
        return s.shape[-2]

    @property
    def len(self):
        s = self.datasets[0][0][0] if not isinstance(self.datasets[0][0][0], tuple) else self.datasets[0][0][0][0]
        return s.shape[-1]

    @property
    def vocab(self):
        return self.datasets[0].vocab

    @property
    def cat(self): return hasattr(self, "vocab")

class TSMetaDatasets(FilteredBase):
    "Wraps a TSMetaDataset and a pair of splits, exposing train and valid subsets"
    def __init__(self, metadataset, splits):
        store_attr()
        self.mapping = metadataset.mapping
        self.datasets = metadataset.datasets

    def subset(self, i):
        return type(self.metadataset)(self.metadataset.datasets, split=self.splits[i])

    @property
    def train(self):
        return self.subset(0)

    @property
    def valid(self):
        return self.subset(1)
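
# Usage sketch (illustrative, not part of the generated module). It shows how
# TSMetaDataset / TSMetaDatasets are intended to be combined. `_ToyTSDataset`
# is a hypothetical stand-in for a tsai dataset: anything that supports `len()`
# and fancy indexing with an array of positions, returning a (TSTensor, tensor)
# tuple, will work the same way.
if __name__ == "__main__":
    import torch

    class _ToyTSDataset:
        "Hypothetical stand-in dataset holding random (n_samples, n_vars, seq_len) data"
        def __init__(self, n_samples, n_vars=3, seq_len=20):
            self.X = torch.rand(n_samples, n_vars, seq_len)
            self.y = torch.randint(0, 2, (n_samples,))
        def __len__(self): return len(self.X)
        def __getitem__(self, idx):
            return TSTensor(self.X[idx]), self.y[idx]

    # Three datasets of different sizes, indexed as if they were a single one
    dsets = [_ToyTSDataset(n) for n in (40, 60, 80)]
    metadataset = TSMetaDataset(dsets)
    assert len(metadataset) == 180
    xb, yb = metadataset[[0, 45, 120]]  # indices span all three datasets

    # Wrap with train/valid splits (here a simple 150/30 split by position)
    splits = (list(range(150)), list(range(150, 180)))
    metadatasets = TSMetaDatasets(metadataset, splits)
    train_ds, valid_ds = metadatasets.train, metadatasets.valid
    x0, y0 = train_ds[0]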