"""Dataset examples for loading individual data points
Authors
* Aku Rouhe 2020
* Samuele Cornell 2020
"""
import copy
import contextlib
from types import MethodType
from torch.utils.data import Dataset
from speechbrain.utils.data_pipeline import DataPipeline
from speechbrain.dataio.dataio import load_data_json, load_data_csv
from speechbrain.utils.data_utils import batch_shuffle
import logging
import math
logger = logging.getLogger(__name__)


class DynamicItemDataset(Dataset):
    """Dataset that reads, wrangles, and produces dicts.

    Each data point dict provides some items (by key), for example, a path to a
    wavefile with the key "wav_file". When a data point is fetched from this
    Dataset, more items are produced dynamically, based on pre-existing items
    and other dynamically created items. For example, a dynamic item could take
    the wavfile path and load the audio from disk.

    The dynamic items can depend on other dynamic items: a suitable evaluation
    order is used automatically, as long as there are no circular dependencies.

    A specified list of keys is collected in the output dict. These can be items
    in the original data or dynamic items. If some dynamic items are not
    requested, nor depended on by other requested items, they won't be computed.
    So for example, if a user simply wants to iterate over the text, the
    time-consuming audio loading can be skipped.

    About the format:
    Takes a dict of dicts as the collection of data points to read/wrangle.
    The top-level keys are data point IDs.
    Each data point (example) dict should have the same keys, corresponding to
    different items in that data point.

    Altogether the data collection could look like this:

    >>> data = {
    ...     "spk1utt1": {
    ...         "wav_file": "/path/to/spk1utt1.wav",
    ...         "text": "hello world",
    ...         "speaker": "spk1",
    ...     },
    ...     "spk1utt2": {
    ...         "wav_file": "/path/to/spk1utt2.wav",
    ...         "text": "how are you world",
    ...         "speaker": "spk1",
    ...     }
    ... }

    NOTE
    ----
    The top-level key, the data point id, is implicitly added as an item
    in the data point, with the key "id".

    Each dynamic item is configured by three things: a key, a func, and a list
    of argkeys. The key should be unique among all the items (dynamic or not) in
    each data point. The func is any callable, and it returns the dynamic item's
    value. The callable is called with the values of other items as specified
    by the argkeys list (as positional args, passed in the order specified by
    argkeys).

    The dynamic_items configuration could look like this:

    >>> import torch
    >>> dynamic_items = [
    ...     {"func": lambda l: torch.Tensor(l),
    ...      "takes": ["wav_loaded"],
    ...      "provides": "wav"},
    ...     {"func": lambda path: [ord(c)/100 for c in path],  # Fake "loading"
    ...      "takes": ["wav_file"],
    ...      "provides": "wav_loaded"},
    ...     {"func": lambda t: t.split(),
    ...      "takes": ["text"],
    ...      "provides": "words"}]

    With these, different views of the data can be loaded:

    >>> from speechbrain.dataio.dataloader import SaveableDataLoader
    >>> from speechbrain.dataio.batch import PaddedBatch
    >>> dataset = DynamicItemDataset(data, dynamic_items)
    >>> dataloader = SaveableDataLoader(dataset, collate_fn=PaddedBatch,
    ...     batch_size=2)
    >>> # First, create an encoding for words:
    >>> dataset.set_output_keys(["words"])
    >>> encoding = {}
    >>> next_id = 1
    >>> for batch in dataloader:
    ...     for sent in batch.words:
    ...         for word in sent:
    ...             if word not in encoding:
    ...                 encoding[word] = next_id
    ...                 next_id += 1
    >>> # Next, add an encoded words_tensor dynamic item:
    >>> dataset.add_dynamic_item(
    ...     func=lambda ws: torch.tensor([encoding[w] for w in ws],
    ...         dtype=torch.long),
    ...     takes=["words"],
    ...     provides="words_encoded")
    >>> # Now we can get word and audio tensors:
    >>> dataset.set_output_keys(["id", "wav", "words_encoded"])
    >>> batch = next(iter(dataloader))
    >>> batch.id
    ['spk1utt1', 'spk1utt2']
    >>> batch.wav  # +ELLIPSIS
    PaddedData(data=tensor([[0.4700, 1.1200, ...
    >>> batch.words_encoded
    PaddedData(data=tensor([[1, 2, 0, 0],
            [3, 4, 5, 2]]), lengths=tensor([0.5000, 1.0000]))

    Output keys can also be a map:

    >>> dataset.set_output_keys({"id": "id", "signal": "wav", "words": "words_encoded"})
    >>> batch = next(iter(dataloader))
    >>> batch.words
    PaddedData(data=tensor([[1, 2, 0, 0],
            [3, 4, 5, 2]]), lengths=tensor([0.5000, 1.0000]))

    Arguments
    ---------
    data : dict
        Dictionary containing single data points (e.g. utterances).
    dynamic_items : list, optional
        Configuration for the dynamic items produced when fetching an example.
        List of DynamicItems or dicts with the format::

            func: <callable> # To be called
            takes: <list> # key or list of keys of args this takes
            provides: key # key or list of keys that this provides

    output_keys : dict, list, optional
        List of keys (either directly available in data or dynamic items)
        to include in the output dict when data points are fetched.
        If a dict is given, it maps internal keys to output keys: each dict
        key is the name that appears in the output, and the corresponding
        value is the internal key.
    """
    def __init__(
        self, data, dynamic_items=[], output_keys=[],
    ):
        self.data = data
        self.data_ids = list(self.data.keys())
        static_keys = list(self.data[self.data_ids[0]].keys())
        if "id" in static_keys:
            raise ValueError("The key 'id' is reserved for the data point id.")
        else:
            static_keys.append("id")
        self.pipeline = DataPipeline(static_keys, dynamic_items)
        self.set_output_keys(output_keys)

    def __len__(self):
        return len(self.data_ids)

    def __getitem__(self, index):
        data_id = self.data_ids[index]
        data_point = self.data[data_id]
        return self.pipeline.compute_outputs({"id": data_id, **data_point})
    def add_dynamic_item(self, func, takes=None, provides=None):
        """Makes a new dynamic item available on the dataset.

        Two calling conventions. For DynamicItem objects, just use:
        add_dynamic_item(dynamic_item).
        But otherwise, should use:
        add_dynamic_item(func, takes, provides).

        See `speechbrain.utils.data_pipeline`.

        Arguments
        ---------
        func : callable, DynamicItem
            If a DynamicItem is given, adds that directly. Otherwise a
            DynamicItem is created, and this specifies the callable to use. If
            a generator function is given, then create a GeneratorDynamicItem.
            Otherwise creates a normal DynamicItem.
        takes : list, str
            List of keys. When func is called, each key is resolved to
            either an entry in the data or the output of another dynamic_item.
            The func is then called with these as positional arguments,
            in the same order as specified here.
            A single arg can be given directly.
        provides : str
            Unique key or keys that this provides.
        """
        self.pipeline.add_dynamic_item(func, takes, provides)
    def set_output_keys(self, keys):
        """Use this to change the output keys.

        These are the keys that are actually evaluated when a data point
        is fetched from the dataset.

        Arguments
        ---------
        keys : dict, list
            List of keys (str) to produce in output.
            If a dict is given, it maps internal keys to output keys: each
            dict key is the name that appears in the output, and the
            corresponding value is the internal key.
        """
        self.pipeline.set_output_keys(keys)

    @contextlib.contextmanager
    def output_keys_as(self, keys):
        """Context manager to temporarily set output keys.

        Example
        -------
        >>> dataset = DynamicItemDataset({"a":{"x":1,"y":2},"b":{"x":3,"y":4}},
        ...     output_keys=["x"])
        >>> with dataset.output_keys_as(["y"]):
        ...     print(dataset[0])
        {'y': 2}
        >>> print(dataset[0])
        {'x': 1}

        NOTE
        ----
        Not thread-safe. While in this context manager, the output keys
        are affected for any call.
        """
        saved_output = self.pipeline.output_mapping
        self.pipeline.set_output_keys(keys)
        yield self
        self.pipeline.set_output_keys(saved_output)
    def filtered_sorted(
        self,
        key_min_value={},
        key_max_value={},
        key_test={},
        sort_key=None,
        reverse=False,
        select_n=None,
    ):
        """Get a filtered and/or sorted version of this, shares static data.

        The reason to implement these operations in the same method is that
        computing some dynamic items may be expensive, and this way the
        filtering and sorting steps don't need to compute the dynamic items
        twice.

        Arguments
        ---------
        key_min_value : dict
            Map from key (in data or in dynamic items) to limit, will only keep
            data_point if data_point[key] >= limit
        key_max_value : dict
            Map from key (in data or in dynamic items) to limit, will only keep
            data_point if data_point[key] <= limit
        key_test : dict
            Map from key (in data or in dynamic items) to func, will only keep
            data_point if bool(func(data_point[key])) == True
        sort_key : None, str
            If not None, sort by data_point[sort_key]. Default is ascending
            order.
        reverse : bool
            If True, sort in descending order.
        select_n : None, int
            If not None, only keep (at most) the first n filtered data_points.
            The possible sorting is applied, but only on the first n data
            points found. Meant for debugging.

        Returns
        -------
        FilteredSortedDynamicItemDataset
            Shares the static data, but has its own output keys and
            dynamic items (initially deep copied from this, so they have the
            same dynamic items available)

        NOTE
        ----
        Temporarily changes the output keys!
        """
        filtered_sorted_ids = self._filtered_sorted_ids(
            key_min_value, key_max_value, key_test, sort_key, reverse, select_n,
        )
        return FilteredSortedDynamicItemDataset(
            self, filtered_sorted_ids
        )  # NOTE: defined below
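
    # A minimal usage sketch, kept as a comment since this is library code.
    # The "duration" key is an assumption: it would have to exist in the data
    # manifest or be provided by a dynamic item.
    #
    #     short_first = dataset.filtered_sorted(
    #         key_max_value={"duration": 10.0},  # drop clips longer than 10 s
    #         sort_key="duration",               # ascending by duration
    #     )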
    def _filtered_sorted_ids(
        self,
        key_min_value={},
        key_max_value={},
        key_test={},
        sort_key=None,
        reverse=False,
        select_n=None,
    ):
        """Returns a list of data ids, fulfilling the sorting and filtering."""

        def combined_filter(computed):
            """Applies filter."""
            for key, limit in key_min_value.items():
                # NOTE: docstring promises >= so using that.
                # Mathematically could also use < for nicer syntax, but
                # maybe in some super special weird edge case someone can
                # depend on the >= operator
                if computed[key] >= limit:
                    continue
                return False
            for key, limit in key_max_value.items():
                if computed[key] <= limit:
                    continue
                return False
            for key, func in key_test.items():
                if bool(func(computed[key])):
                    continue
                return False
            return True

        temp_keys = (
            set(key_min_value.keys())
            | set(key_max_value.keys())
            | set(key_test.keys())
            | set([] if sort_key is None else [sort_key])
        )
        filtered_ids = []
        with self.output_keys_as(temp_keys):
            for i, data_id in enumerate(self.data_ids):
                if select_n is not None and len(filtered_ids) == select_n:
                    break
                data_point = self.data[data_id]
                data_point["id"] = data_id
                computed = self.pipeline.compute_outputs(data_point)
                if combined_filter(computed):
                    if sort_key is not None:
                        # Add (main sorting index, current index, data_id)
                        # so that the current order is maintained and data_id
                        # values are never compared directly.
                        filtered_ids.append((computed[sort_key], i, data_id))
                    else:
                        filtered_ids.append(data_id)
        if sort_key is not None:
            filtered_sorted_ids = [
                tup[2] for tup in sorted(filtered_ids, reverse=reverse)
            ]
        else:
            filtered_sorted_ids = filtered_ids
        return filtered_sorted_ids
    def overfit_test(self, sample_count, total_count):
        """Creates a subset of this dataset for an overfitting test,
        repeating the first sample_count samples to create a repeating
        dataset with a total of total_count samples.

        Arguments
        ---------
        sample_count: int
            the number of distinct samples to select
        total_count: int
            the total number of samples in the resulting dataset

        Returns
        -------
        dataset: FilteredSortedDynamicItemDataset
            a dataset with a repeated subset
        """
        num_repetitions = math.ceil(total_count / sample_count)
        overfit_samples = self.data_ids[:sample_count] * num_repetitions
        overfit_samples = overfit_samples[:total_count]
        return FilteredSortedDynamicItemDataset(self, overfit_samples)
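
    # For instance (a sketch; the counts are arbitrary): repeating the first
    # 2 data ids to fill 5 slots yields [id0, id1, id0, id1, id0].
    #
    #     tiny = dataset.overfit_test(sample_count=2, total_count=5)
    #     assert len(tiny) == 5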
    def batch_shuffle(self, batch_size):
        """Shuffles batches within the dataset. This is particularly
        useful in combination with length sorting: it keeps the length
        variation within each batch low, while the order of the batches
        themselves remains randomized.

        Arguments
        ---------
        batch_size: int
            the batch size

        Returns
        -------
        dataset: FilteredSortedDynamicItemDataset
            a shuffled dataset
        """
        data_ids = batch_shuffle(self.data_ids, batch_size)
        return FilteredSortedDynamicItemDataset(self, data_ids)
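
    # A sketch of the intended combination with length sorting; the
    # "duration" key is again an assumption about the data:
    #
    #     sorted_data = dataset.filtered_sorted(sort_key="duration")
    #     shuffled = sorted_data.batch_shuffle(batch_size=8)
    #
    # Each consecutive group of 8 ids keeps similar durations, but the
    # groups appear in random order.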
    @classmethod
    def from_json(
        cls, json_path, replacements={}, dynamic_items=[], output_keys=[]
    ):
        """Load a data prep JSON file and create a Dataset based on it."""
        data = load_data_json(json_path, replacements)
        return cls(data, dynamic_items, output_keys)

    @classmethod
    def from_csv(
        cls, csv_path, replacements={}, dynamic_items=[], output_keys=[]
    ):
        """Load a data prep CSV file and create a Dataset based on it."""
        data = load_data_csv(csv_path, replacements)
        return cls(data, dynamic_items, output_keys)
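
    # A hedged loading sketch: "train.csv" and its "wav" column are
    # assumptions about the data prep output; read_audio comes from
    # speechbrain.dataio.dataio.
    #
    #     from speechbrain.dataio.dataio import read_audio
    #     train_data = DynamicItemDataset.from_csv("train.csv")
    #     train_data.add_dynamic_item(read_audio, takes="wav", provides="sig")
    #     train_data.set_output_keys(["id", "sig"])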
    @classmethod
    def from_arrow_dataset(
        cls, dataset, replacements={}, dynamic_items=[], output_keys=[]
    ):
        """Loading a prepared HuggingFace dataset."""

        # Define an unbound method to generate pseudo keys
        def keys(self):
            """Returns the keys."""
            return [i for i in range(dataset.__len__())]

        # Bind this method to the Arrow dataset
        dataset.keys = MethodType(keys, dataset)
        return cls(dataset, dynamic_items, output_keys)
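
    # A hedged sketch using the HuggingFace datasets library (an assumed
    # dependency here, not imported by this module):
    #
    #     from datasets import load_dataset
    #     hf_data = load_dataset("csv", data_files="train.csv", split="train")
    #     dataset = DynamicItemDataset.from_arrow_dataset(hf_data)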


class FilteredSortedDynamicItemDataset(DynamicItemDataset):
    """Possibly filtered, possibly sorted DynamicItemDataset.

    Shares the static data (reference).
    Has its own dynamic_items and output_keys (deepcopy).
    """

    def __init__(self, from_dataset, data_ids):
        self.data = from_dataset.data
        self.data_ids = data_ids
        self.pipeline = copy.deepcopy(from_dataset.pipeline)

    @classmethod
    def from_json(
        cls, json_path, replacements={}, dynamic_items=None, output_keys=None
    ):
        raise TypeError("Cannot create FilteredSortedDynamicItemDataset directly!")

    @classmethod
    def from_csv(
        cls, csv_path, replacements={}, dynamic_items=None, output_keys=None
    ):
        raise TypeError("Cannot create FilteredSortedDynamicItemDataset directly!")


def add_dynamic_item(datasets, func, takes=None, provides=None):
    """Helper for adding the same dynamic item to multiple datasets."""
    for dataset in datasets:
        dataset.add_dynamic_item(func, takes, provides)
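
# For instance, to attach the same audio-loading item to several splits at
# once (train_data/valid_data/test_data and read_audio follow the assumed
# sketches above):
#
#     add_dynamic_item([train_data, valid_data, test_data],
#                      read_audio, takes="wav", provides="sig")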


def set_output_keys(datasets, output_keys):
    """Helper for setting the same output keys on multiple datasets."""
    for dataset in datasets:
        dataset.set_output_keys(output_keys)


def apply_overfit_test(
    overfit_test,
    overfit_test_sample_count,
    overfit_test_epoch_data_count,
    dataset,
):
    """Applies the overfit test to the specified dataset,
    as configured in the hyperparameters file.

    Arguments
    ---------
    overfit_test: bool
        when True, the overfitting test is performed
    overfit_test_sample_count: int
        the number of distinct samples used for the overfitting test
    overfit_test_epoch_data_count: int
        the number of samples in each epoch of the overfitting test
    dataset: DynamicItemDataset
        the dataset

    Returns
    -------
    dataset: DynamicItemDataset
        the dataset, with the overfit test applied
    """
    if overfit_test:
        sample_count = overfit_test_sample_count
        epoch_data_count = overfit_test_epoch_data_count
        dataset = dataset.overfit_test(sample_count, epoch_data_count)
    return dataset
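
# A hedged call sketch, with values as they might come from a hyperparameters
# file (the hparams keys are assumptions, not a fixed convention):
#
#     dataset = apply_overfit_test(
#         overfit_test=hparams["overfit_test"],
#         overfit_test_sample_count=hparams["overfit_test_sample_count"],
#         overfit_test_epoch_data_count=hparams["overfit_test_epoch_data_count"],
#         dataset=train_data,
#     )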