# coding=utf-8
# Copyright 2020 The TensorFlow Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""DatasetBuilder base class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import functools
import inspect
import itertools
import os
import sys
from absl import logging
import six
import tensorflow.compat.v2 as tf
from tensorflow_datasets.core import api_utils
from tensorflow_datasets.core import constants
from tensorflow_datasets.core import download
from tensorflow_datasets.core import lazy_imports_lib
from tensorflow_datasets.core import naming
from tensorflow_datasets.core import registered
from tensorflow_datasets.core import splits as splits_lib
from tensorflow_datasets.core import tfrecords_reader
from tensorflow_datasets.core import tfrecords_writer
from tensorflow_datasets.core import units
from tensorflow_datasets.core import utils
from tensorflow_datasets.core.utils import gcs_utils
from tensorflow_datasets.core.utils import read_config as read_config_lib
import termcolor
FORCE_REDOWNLOAD = download.GenerateMode.FORCE_REDOWNLOAD
REUSE_CACHE_IF_EXISTS = download.GenerateMode.REUSE_CACHE_IF_EXISTS
REUSE_DATASET_IF_EXISTS = download.GenerateMode.REUSE_DATASET_IF_EXISTS
GCS_HOSTED_MSG = """\
Dataset %s is hosted on GCS. It will automatically be downloaded to your
local data directory. If you'd prefer to read directly from our public
GCS bucket (recommended if you're running on GCP), you can instead pass
`try_gcs=True` to `tfds.load` or set `data_dir=gs://tfds-data/datasets`.
"""
class BuilderConfig(object):
"""Base class for `DatasetBuilder` data configuration.
DatasetBuilder subclasses with data configuration options should subclass
`BuilderConfig` and add their own properties.
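
  For example, a hypothetical config exposing an `image_size` option (the
  names below are illustrative, not part of this module):

  ```python
  class MyDatasetConfig(tfds.core.BuilderConfig):
    """Hypothetical BuilderConfig for `MyDataset`."""

    def __init__(self, image_size=None, **kwargs):
      super(MyDatasetConfig, self).__init__(**kwargs)
      self.image_size = image_size
  ```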
"""
@api_utils.disallow_positional_args
def __init__(self, name, version=None, supported_versions=None,
description=None):
self._name = name
self._version = version
self._supported_versions = supported_versions or []
self._description = description
@property
def name(self):
return self._name
@property
def version(self):
return self._version
@property
def supported_versions(self):
return self._supported_versions
@property
def description(self):
return self._description
def __repr__(self):
return "<{cls_name} name={name}, version={version}>".format(
cls_name=type(self).__name__,
name=self.name,
version=self.version or "None")
@six.add_metaclass(registered.RegisteredDataset)
class DatasetBuilder(object):
"""Abstract base class for all datasets.
`DatasetBuilder` has 3 key methods:
* `tfds.DatasetBuilder.info`: documents the dataset, including feature
names, types, and shapes, version, splits, citation, etc.
* `tfds.DatasetBuilder.download_and_prepare`: downloads the source data
and writes it to disk.
* `tfds.DatasetBuilder.as_dataset`: builds an input pipeline using
`tf.data.Dataset`s.
**Configuration**: Some `DatasetBuilder`s expose multiple variants of the
dataset by defining a `tfds.core.BuilderConfig` subclass and accepting a
config object (or name) on construction. Configurable datasets expose a
pre-defined set of configurations in `tfds.DatasetBuilder.builder_configs`.
Typical `DatasetBuilder` usage:
```python
mnist_builder = tfds.builder("mnist")
mnist_info = mnist_builder.info
mnist_builder.download_and_prepare()
datasets = mnist_builder.as_dataset()
train_dataset, test_dataset = datasets["train"], datasets["test"]
assert isinstance(train_dataset, tf.data.Dataset)
# And then the rest of your input pipeline
train_dataset = train_dataset.repeat().shuffle(1024).batch(128)
train_dataset = train_dataset.prefetch(2)
features = tf.compat.v1.data.make_one_shot_iterator(train_dataset).get_next()
image, label = features['image'], features['label']
```
"""
# Name of the dataset, filled by metaclass based on class name.
name = None
# Semantic version of the dataset (ex: tfds.core.Version('1.2.0'))
VERSION = None
# List dataset versions which can be loaded using current code.
# Data can only be prepared with canonical VERSION or above.
SUPPORTED_VERSIONS = []
# Named configurations that modify the data generated by download_and_prepare.
BUILDER_CONFIGS = []
# Set to True for datasets that are under active development and should not
# be available through tfds.{load, builder} or documented in overview.md.
IN_DEVELOPMENT = False
# Must be set for datasets that use 'manual_dir' functionality - the ones
# that require users to do additional steps to download the data
# (this is usually due to some external regulations / rules).
#
# This field should contain a string with user instructions, including
# the list of files that should be present. It will be
# displayed in the dataset documentation.
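  #
  # A hypothetical example (dataset name and URL are illustrative only):
  #
  #   MANUAL_DOWNLOAD_INSTRUCTIONS = """\
  #   Download foo_data.zip from https://example.com/foo and place it in
  #   manual_dir.
  #   """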
MANUAL_DOWNLOAD_INSTRUCTIONS = None
@api_utils.disallow_positional_args
def __init__(self, data_dir=None, config=None, version=None):
"""Constructs a DatasetBuilder.
Callers must pass arguments as keyword arguments.
Args:
data_dir: `str`, directory to read/write data. Defaults to
"~/tensorflow_datasets".
config: `tfds.core.BuilderConfig` or `str` name, optional configuration
for the dataset that affects the data generated on disk. Different
`builder_config`s will have their own subdirectories and versions.
version: `str`. Optional version at which to load the dataset. An error is
        raised if the specified version cannot be satisfied, e.g. '1.2.3', '1.2.*'.
The special value "experimental_latest" will use the highest version,
even if not default. This is not recommended unless you know what you
are doing, as the version could be broken.
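
    Example (a sketch; `MyDataset` is a hypothetical builder class):

    ```python
    builder = MyDataset(version="1.0.*")  # Matches any available 1.0.x.
    ```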
"""
# For pickling:
self._original_state = dict(data_dir=data_dir, config=config,
version=version)
# To do the work:
self._builder_config = self._create_builder_config(config)
# Extract code version (VERSION or config)
if not self._builder_config and not self.VERSION:
raise AssertionError(
"DatasetBuilder {} does not have a defined version. Please add a "
"`VERSION = tfds.core.Version('x.y.z')` to the class.".format(
self.name))
self._version = self._pick_version(version)
# Compute the base directory (for download) and dataset/version directory.
self._data_dir_root, self._data_dir = self._build_data_dir(data_dir)
if tf.io.gfile.exists(self._data_dir):
self.info.read_from_directory(self._data_dir)
else: # Use the code version (do not restore data)
self.info.initialize_from_bucket()
def __getstate__(self):
return self._original_state
def __setstate__(self, state):
self.__init__(**state)
@utils.memoized_property
def canonical_version(self):
if self._builder_config:
return self._builder_config.version
else:
return self.VERSION
@utils.memoized_property
def supported_versions(self):
if self._builder_config:
return self._builder_config.supported_versions
else:
return self.SUPPORTED_VERSIONS
@utils.memoized_property
def versions(self):
"""Versions (canonical + availables), in preference order."""
return [
utils.Version(v) if isinstance(v, six.string_types) else v
for v in [self.canonical_version] + self.supported_versions
]
def _pick_version(self, requested_version):
"""Returns utils.Version instance, or raise AssertionError."""
if requested_version == "experimental_latest":
return max(self.versions)
for version in self.versions:
if requested_version is None or version.match(requested_version):
return version
available_versions = [str(v) for v in self.versions]
msg = "Dataset {} cannot be loaded at version {}, only: {}.".format(
self.name, requested_version, ", ".join(available_versions))
raise AssertionError(msg)
@property
def version(self):
return self._version
@property
def data_dir(self):
return self._data_dir
@utils.memoized_property
def info(self):
"""`tfds.core.DatasetInfo` for this builder."""
    # Ensure .info hasn't been called before versioning is set up.
    # Otherwise, backward compatibility cannot be guaranteed, as some code
    # will depend on the code version instead of the restored data version.
    if not getattr(self, "_version", None):
      # Message for developers creating a new dataset. Will trigger if they
      # use .info in the constructor before calling super().__init__.
      raise AssertionError(
          "Info should not be called before the version has been defined. "
          "Otherwise, the created .info may not match the info version from "
          "the restored dataset.")
return self._info()
@api_utils.disallow_positional_args
def download_and_prepare(self, download_dir=None, download_config=None):
"""Downloads and prepares dataset for reading.
Args:
download_dir: `str`, directory where downloaded files are stored.
Defaults to "~/tensorflow-datasets/downloads".
download_config: `tfds.download.DownloadConfig`, further configuration for
downloading and preparing dataset.
Raises:
IOError: if there is not enough disk space available.
"""
download_config = download_config or download.DownloadConfig()
data_exists = tf.io.gfile.exists(self._data_dir)
if data_exists and download_config.download_mode == REUSE_DATASET_IF_EXISTS:
logging.info("Reusing dataset %s (%s)", self.name, self._data_dir)
return
if self.version.tfds_version_to_prepare:
available_to_prepare = ", ".join(str(v) for v in self.versions
if not v.tfds_version_to_prepare)
raise AssertionError(
"The version of the dataset you are trying to use ({}:{}) can only "
"be generated using TFDS code synced @ {} or earlier. Either sync to "
"that version of TFDS to first prepare the data or use another "
"version of the dataset (available for `download_and_prepare`: "
"{}).".format(
self.name, self.version, self.version.tfds_version_to_prepare,
available_to_prepare))
# Only `cls.VERSION` or `experimental_latest` versions can be generated.
# Otherwise, users may accidentally generate an old version using the
# code from newer versions.
installable_versions = {
str(v) for v in (self.canonical_version, max(self.versions))
}
if str(self.version) not in installable_versions:
msg = (
"The version of the dataset you are trying to use ({}) is too "
"old for this version of TFDS so cannot be generated."
).format(self.info.full_name)
if self.version.tfds_version_to_prepare:
        msg += (
            "{} can only be generated using TFDS code synced @ {} or "
            "earlier. Either sync to that version of TFDS to first prepare "
            "the data or use another version of the dataset. "
        ).format(self.version, self.version.tfds_version_to_prepare)
else:
msg += (
"Either sync to a previous version of TFDS to first prepare the "
"data or use another version of the dataset. "
)
msg += "Available for `download_and_prepare`: {}".format(
list(sorted(installable_versions)))
raise ValueError(msg)
# Currently it's not possible to overwrite the data because it would
# conflict with versioning: If the last version has already been generated,
# it will always be reloaded and data_dir will be set at construction.
if data_exists:
raise ValueError(
"Trying to overwrite an existing dataset {} at {}. A dataset with "
"the same version {} already exists. If the dataset has changed, "
"please update the version number.".format(self.name, self._data_dir,
self.version))
logging.info("Generating dataset %s (%s)", self.name, self._data_dir)
if not utils.has_sufficient_disk_space(
self.info.dataset_size + self.info.download_size,
directory=self._data_dir_root):
raise IOError(
"Not enough disk space. Needed: {} (download: {}, generated: {})"
.format(
units.size_str(self.info.dataset_size + self.info.download_size),
units.size_str(self.info.download_size),
units.size_str(self.info.dataset_size),
))
self._log_download_bytes()
dl_manager = self._make_download_manager(
download_dir=download_dir,
download_config=download_config)
# Create a tmp dir and rename to self._data_dir on successful exit.
with utils.incomplete_dir(self._data_dir) as tmp_data_dir:
      # Temporarily point _data_dir at tmp_data_dir to avoid having to
      # forward it to every sub function.
with utils.temporary_assignment(self, "_data_dir", tmp_data_dir):
if (download_config.try_download_gcs and
gcs_utils.is_dataset_on_gcs(self.info.full_name)):
logging.warning(GCS_HOSTED_MSG, self.name)
gcs_utils.download_gcs_dataset(self.info.full_name, self._data_dir)
self.info.read_from_directory(self._data_dir)
else:
self._download_and_prepare(
dl_manager=dl_manager,
download_config=download_config)
# NOTE: If modifying the lines below to put additional information in
# DatasetInfo, you'll likely also want to update
# DatasetInfo.read_from_directory to possibly restore these attributes
# when reading from package data.
splits = list(self.info.splits.values())
statistics_already_computed = bool(
splits and splits[0].statistics.num_examples)
# Update DatasetInfo metadata by computing statistics from the data.
if (download_config.compute_stats == download.ComputeStatsMode.SKIP or
download_config.compute_stats == download.ComputeStatsMode.AUTO
and statistics_already_computed
):
logging.info(
"Skipping computing stats for mode %s.",
download_config.compute_stats)
      else:  # Mode is forced or stats do not exist yet
logging.info("Computing statistics.")
self.info.compute_dynamic_properties()
self.info.download_size = dl_manager.downloaded_size
# Write DatasetInfo to disk, even if we haven't computed statistics.
self.info.write_to_directory(self._data_dir)
self._log_download_done()
@api_utils.disallow_positional_args
def as_dataset(
self,
split=None,
batch_size=None,
shuffle_files=False,
decoders=None,
read_config=None,
as_supervised=False,
):
# pylint: disable=line-too-long
"""Constructs a `tf.data.Dataset`.
Callers must pass arguments as keyword arguments.
The output types vary depending on the parameters. Examples:
```python
builder = tfds.builder('imdb_reviews')
builder.download_and_prepare()
# Default parameters: Returns the dict of tf.data.Dataset
ds_all_dict = builder.as_dataset()
assert isinstance(ds_all_dict, dict)
print(ds_all_dict.keys()) # ==> ['test', 'train', 'unsupervised']
assert isinstance(ds_all_dict['test'], tf.data.Dataset)
# Each dataset (test, train, unsup.) consists of dictionaries
# {'label': <tf.Tensor: .. dtype=int64, numpy=1>,
# 'text': <tf.Tensor: .. dtype=string, numpy=b"I've watched the movie ..">}
# {'label': <tf.Tensor: .. dtype=int64, numpy=1>,
# 'text': <tf.Tensor: .. dtype=string, numpy=b'If you love Japanese ..'>}
# With as_supervised: tf.data.Dataset only contains (feature, label) tuples
ds_all_supervised = builder.as_dataset(as_supervised=True)
assert isinstance(ds_all_supervised, dict)
print(ds_all_supervised.keys()) # ==> ['test', 'train', 'unsupervised']
assert isinstance(ds_all_supervised['test'], tf.data.Dataset)
# Each dataset (test, train, unsup.) consists of tuples (text, label)
# (<tf.Tensor: ... dtype=string, numpy=b"I've watched the movie ..">,
# <tf.Tensor: ... dtype=int64, numpy=1>)
# (<tf.Tensor: ... dtype=string, numpy=b"If you love Japanese ..">,
# <tf.Tensor: ... dtype=int64, numpy=1>)
# Same as above plus requesting a particular split
ds_test_supervised = builder.as_dataset(as_supervised=True, split='test')
assert isinstance(ds_test_supervised, tf.data.Dataset)
# The dataset consists of tuples (text, label)
# (<tf.Tensor: ... dtype=string, numpy=b"I've watched the movie ..">,
# <tf.Tensor: ... dtype=int64, numpy=1>)
# (<tf.Tensor: ... dtype=string, numpy=b"If you love Japanese ..">,
# <tf.Tensor: ... dtype=int64, numpy=1>)
```
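
    The `decoders` argument (see below) can customize or skip feature
    decoding, e.g. to keep images encoded. A sketch using the public
    `tfds.decode.SkipDecoding` helper on a hypothetical image dataset:

    ```python
    ds = image_builder.as_dataset(
        split='train',
        decoders={'image': tfds.decode.SkipDecoding()},
    )
    ```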
Args:
split: `tfds.Split`, which subset(s) of the data to read. If None
(default), returns all splits in a dict
`<key: tfds.Split, value: tf.data.Dataset>`.
batch_size: `int`, batch size. Note that variable-length features will
be 0-padded if `batch_size` is set. Users that want more custom behavior
should use `batch_size=None` and use the `tf.data` API to construct a
custom pipeline. If `batch_size == -1`, will return feature
dictionaries of the whole dataset with `tf.Tensor`s instead of a
`tf.data.Dataset`.
shuffle_files: `bool`, whether to shuffle the input files. Defaults to
`False`.
decoders: Nested dict of `Decoder` objects which allow to customize the
decoding. The structure should match the feature structure, but only
customized feature keys need to be present. See
[the guide](https://github.com/tensorflow/datasets/tree/master/docs/decode.md)
for more info.
read_config: `tfds.ReadConfig`, Additional options to configure the
input pipeline (e.g. seed, num parallel reads,...).
as_supervised: `bool`, if `True`, the returned `tf.data.Dataset`
will have a 2-tuple structure `(input, label)` according to
`builder.info.supervised_keys`. If `False`, the default,
the returned `tf.data.Dataset` will have a dictionary with all the
features.
Returns:
      `tf.data.Dataset`, or if `split=None`, a `dict<key: tfds.Split, value:
      tf.data.Dataset>`.
If `batch_size` is -1, will return feature dictionaries containing
the entire dataset in `tf.Tensor`s instead of a `tf.data.Dataset`.
"""
# pylint: enable=line-too-long
logging.info("Constructing tf.data.Dataset for split %s, from %s",
split, self._data_dir)
if not tf.io.gfile.exists(self._data_dir):
raise AssertionError(
("Dataset %s: could not find data in %s. Please make sure to call "
"dataset_builder.download_and_prepare(), or pass download=True to "
"tfds.load() before trying to access the tf.data.Dataset object."
) % (self.name, self._data_dir_root))
# By default, return all splits
if split is None:
split = {s: s for s in self.info.splits}
read_config = read_config or read_config_lib.ReadConfig()
# Create a dataset for each of the given splits
build_single_dataset = functools.partial(
self._build_single_dataset,
shuffle_files=shuffle_files,
batch_size=batch_size,
decoders=decoders,
read_config=read_config,
as_supervised=as_supervised,
)
datasets = utils.map_nested(build_single_dataset, split, map_tuple=True)
return datasets
def _build_single_dataset(
self,
split,
shuffle_files,
batch_size,
decoders,
read_config,
as_supervised,
):
"""as_dataset for a single split."""
wants_full_dataset = batch_size == -1
if wants_full_dataset:
batch_size = self.info.splits.total_num_examples or sys.maxsize
# Build base dataset
ds = self._as_dataset(
split=split,
shuffle_files=shuffle_files,
decoders=decoders,
read_config=read_config,
)
    # Auto-cache datasets which are small enough to fit in memory.
if self._should_cache_ds(
split=split,
shuffle_files=shuffle_files,
read_config=read_config
):
ds = ds.cache()
if batch_size:
# Use padded_batch so that features with unknown shape are supported.
ds = ds.padded_batch(
batch_size, tf.compat.v1.data.get_output_shapes(ds))
if as_supervised:
if not self.info.supervised_keys:
raise ValueError(
"as_supervised=True but %s does not support a supervised "
"(input, label) structure." % self.name)
input_f, target_f = self.info.supervised_keys
ds = ds.map(lambda fs: (fs[input_f], fs[target_f]),
num_parallel_calls=tf.data.experimental.AUTOTUNE)
ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
    # If shuffling is enabled and no seed is set, allow the pipeline to be
    # non-deterministic.
    # This code should probably be moved inside tfrecords_reader, so that all
    # the tf.data.Options are centralized in a single place.
if (shuffle_files and
read_config.options.experimental_deterministic is None and
read_config.shuffle_seed is None):
options = tf.data.Options()
options.experimental_deterministic = False
ds = ds.with_options(options)
    # If shuffle is False, keep the default value (deterministic), which
    # allows the user to override it.
if wants_full_dataset:
return tf.data.experimental.get_single_element(ds)
return ds
def _should_cache_ds(self, split, shuffle_files, read_config):
"""Returns True if TFDS should auto-cache the dataset."""
# The user can explicitly opt-out from auto-caching
if not read_config.try_autocache:
return False
    # Skip datasets with unknown size.
    # Even using a heuristic based on `download_size` and
    # `MANUAL_DOWNLOAD_INSTRUCTIONS`, it wouldn't catch datasets which
    # hardcode the non-processed data dir, nor DatasetBuilders not based on
    # tf-records.
if not self.info.dataset_size:
return False
# Do not cache big datasets
# Instead of using the global size, we could infer the requested bytes:
# `self.info.splits[split].num_bytes`
# The info is available for full splits, and could be approximated
# for subsplits `train[:50%]`.
    # However, if the user is creating multiple small splits from a big
    # dataset, those could add up and fill the entire RAM.
    # 250 MiB was arbitrarily picked. For comparison, Cifar10 is about 150 MiB.
if self.info.dataset_size > 250 * units.MiB:
return False
    # We do not want to cache data which has more than one shard when
    # shuffling is enabled, as this would effectively disable shuffling.
    # An exception is made for a single shard (shuffling is then a no-op).
    # Another exception is when reshuffling is disabled (the cached shuffle
    # order is simply reused).
num_shards = len(self.info.splits[split].file_instructions)
if (shuffle_files and
        # Shuffling only matters when reshuffle is True or None (default)
read_config.shuffle_reshuffle_each_iteration is not False and # pylint: disable=g-bool-id-comparison
num_shards > 1):
return False
    # If the dataset satisfies all the right conditions, activate autocaching.
return True
def _relative_data_dir(self, with_version=True):
"""Relative path of this dataset in data_dir."""
builder_data_dir = self.name
builder_config = self._builder_config
if builder_config:
builder_data_dir = os.path.join(builder_data_dir, builder_config.name)
if not with_version:
return builder_data_dir
version_data_dir = os.path.join(builder_data_dir, str(self._version))
return version_data_dir
def _build_data_dir(self, given_data_dir):
"""Return the data directory for the current version.
Args:
given_data_dir: `Optional[str]`, root `data_dir` passed as
`__init__` argument.
Returns:
data_dir_root: `str`, The root dir containing all datasets, downloads,...
data_dir: `str`, The version data_dir
(e.g. `<data_dir_root>/<ds_name>/<config>/<version>`)
"""
builder_dir = self._relative_data_dir(with_version=False)
version_dir = self._relative_data_dir(with_version=True)
# If the data dir is explicitly given, no need to search everywhere.
if given_data_dir:
default_data_dir = os.path.expanduser(given_data_dir)
all_data_dirs = [default_data_dir]
else:
default_data_dir = os.path.expanduser(constants.DATA_DIR)
all_data_dirs = constants.list_data_dirs()
all_version_dirs = set()
requested_version_dirs = {}
for data_dir_root in all_data_dirs:
# List all existing versions
all_version_dirs.update(
_list_all_version_dirs(os.path.join(data_dir_root, builder_dir)))
      # Check for existence of the requested dir
requested_version_dir = os.path.join(data_dir_root, version_dir)
if requested_version_dir in all_version_dirs:
requested_version_dirs[data_dir_root] = requested_version_dir
if len(requested_version_dirs) > 1:
raise ValueError(
"Dataset was found in more than one directory: {}. Please resolve "
"the ambiguity by explicitly specifying `data_dir=`."
"".format(requested_version_dirs))
elif len(requested_version_dirs) == 1: # The dataset is found once
return next(iter(requested_version_dirs.items()))
# No dataset found, use default directory
data_dir = os.path.join(default_data_dir, version_dir)
if all_version_dirs:
logging.warning(
"Found a different version of the requested dataset:\n"
"%s\n"
"Using %s instead.",
"\n".join(sorted(all_version_dirs)),
data_dir
)
return default_data_dir, data_dir
def _log_download_done(self):
msg = ("Dataset {name} downloaded and prepared to {data_dir}. "
"Subsequent calls will reuse this data.").format(
name=self.name,
data_dir=self._data_dir,
)
termcolor.cprint(msg, attrs=["bold"])
def _log_download_bytes(self):
# Print is intentional: we want this to always go to stdout so user has
# information needed to cancel download/preparation if needed.
# This comes right before the progress bar.
termcolor.cprint(
"Downloading and preparing dataset {} (download: {}, generated: {}, "
"total: {}) to {}...".format(
self.info.full_name,
units.size_str(self.info.download_size),
units.size_str(self.info.dataset_size),
units.size_str(self.info.download_size + self.info.dataset_size),
self._data_dir,
), attrs=["bold"])
@abc.abstractmethod
def _info(self):
"""Construct the DatasetInfo object. See `DatasetInfo` for details.
    Warning: This function is only called once, and the result is cached for
    all subsequent calls to `.info`.
Returns:
dataset_info: (DatasetInfo) The dataset information
"""
raise NotImplementedError
@abc.abstractmethod
def _download_and_prepare(self, dl_manager, download_config=None):
"""Downloads and prepares dataset for reading.
    This is the internal implementation to override, called when the user
    calls `download_and_prepare`. It should download all required data and
    generate the pre-processed dataset files.
Args:
dl_manager: (DownloadManager) `DownloadManager` used to download and cache
data.
download_config: `DownloadConfig`, Additional options.
"""
raise NotImplementedError
@abc.abstractmethod
def _as_dataset(
self, split, decoders=None, read_config=None, shuffle_files=False):
"""Constructs a `tf.data.Dataset`.
    This is the internal implementation to override, called when the user
    calls `as_dataset`. It should read the pre-processed dataset files and
    generate the `tf.data.Dataset` object.
Args:
split: `tfds.Split` which subset of the data to read.
decoders: Nested structure of `Decoder` object to customize the dataset
decoding.
read_config: `tfds.ReadConfig`
shuffle_files: `bool`, whether to shuffle the input files. Optional,
defaults to `False`.
Returns:
`tf.data.Dataset`
"""
raise NotImplementedError
def _make_download_manager(self, download_dir, download_config):
"""Creates a new download manager object."""
download_dir = download_dir or os.path.join(self._data_dir_root,
"downloads")
extract_dir = (download_config.extract_dir or
os.path.join(download_dir, "extracted"))
# Use manual_dir only if MANUAL_DOWNLOAD_INSTRUCTIONS are set.
if self.MANUAL_DOWNLOAD_INSTRUCTIONS:
manual_dir = (
download_config.manual_dir or os.path.join(download_dir, "manual"))
else:
manual_dir = None
return download.DownloadManager(
dataset_name=self.name,
download_dir=download_dir,
extract_dir=extract_dir,
manual_dir=manual_dir,
manual_dir_instructions=utils.dedent(self.MANUAL_DOWNLOAD_INSTRUCTIONS),
force_download=(download_config.download_mode == FORCE_REDOWNLOAD),
force_extraction=(download_config.download_mode == FORCE_REDOWNLOAD),
register_checksums=download_config.register_checksums,
)
@property
def builder_config(self):
"""`tfds.core.BuilderConfig` for this builder."""
return self._builder_config
def _create_builder_config(self, builder_config):
"""Create and validate BuilderConfig object."""
if builder_config is None and self.BUILDER_CONFIGS:
builder_config = self.BUILDER_CONFIGS[0]
logging.info("No config specified, defaulting to first: %s/%s", self.name,
builder_config.name)
if not builder_config:
return None
if isinstance(builder_config, six.string_types):
name = builder_config
builder_config = self.builder_configs.get(name)
if builder_config is None:
raise ValueError("BuilderConfig %s not found. Available: %s" %
(name, list(self.builder_configs.keys())))
name = builder_config.name
if not name:
raise ValueError("BuilderConfig must have a name, got %s" % name)
is_custom = name not in self.builder_configs
if is_custom:
logging.warning("Using custom data configuration %s", name)
else:
if builder_config is not self.builder_configs[name]:
raise ValueError(
"Cannot name a custom BuilderConfig the same as an available "
"BuilderConfig. Change the name. Available BuilderConfigs: %s" %
(list(self.builder_configs.keys())))
if not builder_config.version:
raise ValueError("BuilderConfig %s must have a version" % name)
if not builder_config.description:
raise ValueError("BuilderConfig %s must have a description" % name)
return builder_config
@utils.classproperty
@classmethod
@utils.memoize()
def builder_configs(cls):
"""Pre-defined list of configurations for this builder class."""
config_dict = {config.name: config for config in cls.BUILDER_CONFIGS}
if len(config_dict) != len(cls.BUILDER_CONFIGS):
names = [config.name for config in cls.BUILDER_CONFIGS]
raise ValueError(
"Names in BUILDER_CONFIGS must not be duplicated. Got %s" % names)
return config_dict
def _list_all_version_dirs(root_dir):
"""Lists all dataset versions present on disk."""
if not tf.io.gfile.exists(root_dir):
return []
def _is_version_valid(version):
try:
return utils.Version(version) and True
except ValueError: # Invalid version (ex: incomplete data dir)
return False
return [ # Return all version dirs
os.path.join(root_dir, version)
for version in tf.io.gfile.listdir(root_dir)
if _is_version_valid(version)
]
class FileAdapterBuilder(DatasetBuilder):
"""Base class for datasets with data generation based on file adapter."""
@utils.memoized_property
def _example_specs(self):
return self.info.features.get_serialized_info()
@property
def _tfrecords_reader(self):
return tfrecords_reader.Reader(self._data_dir, self._example_specs)
@abc.abstractmethod
def _split_generators(self, dl_manager):
"""Specify feature dictionary generators and dataset splits.
This function returns a list of `SplitGenerator`s defining how to generate
data and what splits to use.
Example:
      return [
          tfds.core.SplitGenerator(
              name=tfds.Split.TRAIN,
              gen_kwargs={'file': 'train_data.zip'},
          ),
          tfds.core.SplitGenerator(
              name=tfds.Split.TEST,
              gen_kwargs={'file': 'test_data.zip'},
          ),
      ]
The above code will first call `_generate_examples(file='train_data.zip')`
to write the train data, then `_generate_examples(file='test_data.zip')` to
write the test data.
Datasets are typically split into different subsets to be used at various
stages of training and evaluation.
Note that for datasets without a `VALIDATION` split, you can use a
fraction of the `TRAIN` data for evaluation as you iterate on your model
so as not to overfit to the `TEST` data.
For downloads and extractions, use the given `download_manager`.
Note that the `DownloadManager` caches downloads, so it is fine to have each
generator attempt to download the source data.
    A good practice is to download all data in this function, and then
    distribute the relevant parts to each split with the `gen_kwargs` argument.
Args:
dl_manager: (DownloadManager) Download manager to download the data
Returns:
`list<SplitGenerator>`.
"""
raise NotImplementedError()
@abc.abstractmethod
def _prepare_split(self, split_generator, **kwargs):
"""Generate the examples and record them on disk.
Args:
split_generator: `SplitGenerator`, Split generator to process
**kwargs: Additional kwargs forwarded from _download_and_prepare (ex:
beam pipeline)
"""
raise NotImplementedError()
def _make_split_generators_kwargs(self, prepare_split_kwargs):
"""Get kwargs for `self._split_generators()` from `prepare_split_kwargs`."""
del prepare_split_kwargs
return {}
def _download_and_prepare(self, dl_manager, **prepare_split_kwargs):
if not tf.io.gfile.exists(self._data_dir):
tf.io.gfile.makedirs(self._data_dir)
# Generating data for all splits
split_dict = splits_lib.SplitDict(dataset_name=self.name)
split_generators_kwargs = self._make_split_generators_kwargs(
prepare_split_kwargs)
for split_generator in self._split_generators(
dl_manager, **split_generators_kwargs):
      if str(split_generator.split_info.name).lower() == "all":
        raise ValueError(
            "`all` is a special split keyword corresponding to the "
            "union of all splits, so it cannot be used as a key in "
            "`_split_generators()`."
        )
logging.info("Generating split %s", split_generator.split_info.name)
split_dict.add(split_generator.split_info)
# Prepare split will record examples associated to the split
self._prepare_split(split_generator, **prepare_split_kwargs)
# Update the info object with the splits.
self.info.update_splits_if_different(split_dict)
def _as_dataset(
self,
split=splits_lib.Split.TRAIN,
decoders=None,
read_config=None,
shuffle_files=False):
ds = self._tfrecords_reader.read(
name=self.name,
instructions=split,
split_infos=self.info.splits.values(),
read_config=read_config,
shuffle_files=shuffle_files,
)
decode_fn = functools.partial(
self.info.features.decode_example, decoders=decoders)
ds = ds.map(decode_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
return ds
class GeneratorBasedBuilder(FileAdapterBuilder):
"""Base class for datasets with data generation based on dict generators.
  `GeneratorBasedBuilder` is a convenience class that abstracts away much
  of the data writing and reading of `DatasetBuilder`. It expects subclasses
  to implement generators of feature dictionaries across the dataset splits
  (`_split_generators`). See the method docstrings for details.
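
  A minimal sketch of a subclass (all names and the URL are illustrative):

  ```python
  class MyDataset(tfds.core.GeneratorBasedBuilder):
    """Hypothetical line-per-example text dataset."""

    VERSION = tfds.core.Version('1.0.0')

    def _info(self):
      return tfds.core.DatasetInfo(
          builder=self,
          features=tfds.features.FeaturesDict({'text': tfds.features.Text()}),
      )

    def _split_generators(self, dl_manager):
      path = dl_manager.download('https://example.com/data.txt')
      return [
          tfds.core.SplitGenerator(
              name=tfds.Split.TRAIN,
              gen_kwargs={'file_path': path},
          ),
      ]

    def _generate_examples(self, file_path):
      with tf.io.gfile.GFile(file_path) as f:
        for i, line in enumerate(f):
          yield i, {'text': line.strip()}
  ```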
"""
@abc.abstractmethod
def _generate_examples(self, **kwargs):
"""Default function generating examples for each `SplitGenerator`.
    This function preprocesses the examples from the raw data into the
    pre-processed dataset files.
This function is called once for each `SplitGenerator` defined in
`_split_generators`. The examples yielded here will be written on
disk.
Args:
**kwargs: `dict`, Arguments forwarded from the SplitGenerator.gen_kwargs
Yields:
key: `str` or `int`, a unique deterministic example identification key.
        * Unique: An error will be raised if two examples are yielded with
          the same key.
* Deterministic: When generating the dataset twice, the same example
should have the same key.
Good keys can be the image id, or line number if examples are extracted
from a text file.
        The key will be hashed and sorted to shuffle examples
        deterministically, so that generating the dataset multiple times
        keeps examples in the same order.
example: `dict<str feature_name, feature_value>`, a feature dictionary
ready to be encoded and written to disk. The example will be
encoded with `self.info.features.encode_example({...})`.
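
    For example, yielding `(image_id, example)` pairs from a directory of
    images (a sketch; names and the label scheme are illustrative):

    ```python
    def _generate_examples(self, images_dir):
      for fname in tf.io.gfile.listdir(images_dir):
        yield fname, {
            'image': os.path.join(images_dir, fname),
            'label': fname.split('_')[0],  # e.g. 'cat_001.jpg' -> 'cat'
        }
    ```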
"""
raise NotImplementedError()
def _download_and_prepare(self, dl_manager, download_config):
# Extract max_examples_per_split and forward it to _prepare_split
super(GeneratorBasedBuilder, self)._download_and_prepare(
dl_manager=dl_manager,
max_examples_per_split=download_config.max_examples_per_split,
)
def _prepare_split(self, split_generator, max_examples_per_split):
generator = self._generate_examples(**split_generator.gen_kwargs)