-
Notifications
You must be signed in to change notification settings - Fork 154
/
extradatasets.py
133 lines (101 loc) · 3.75 KB
/
extradatasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#
"""Train PyTorch models directly from POSIX tar archive.
Code works locally or over HTTP connections.
"""
from . import utils
from .pytorch import IterableDataset
from .utils import PipelineStage
class MockDataset(IterableDataset):
    """A mock dataset for performance testing and unit testing.

    Yields one fixed sample object repeatedly, `length` times per epoch.
    """

    def __init__(self, sample, length):
        """Create a mock dataset instance.

        :param sample: the sample to be returned repeatedly
        :param length: the length of the mock dataset
        """
        # Call the base constructor for consistency with the other
        # IterableDataset subclasses in this module (e.g. with_epoch).
        super().__init__()
        self.sample = sample
        self.length = length

    def __len__(self):
        """Return the number of samples this dataset yields per epoch."""
        return self.length

    def __iter__(self):
        """Yield `self.sample` exactly `self.length` times."""
        # The loop index is irrelevant; only the repetition count matters.
        for _ in range(self.length):
            yield self.sample
class repeatedly(IterableDataset, PipelineStage):
    """Repeatedly yield samples from a source dataset."""

    def __init__(self, source, nepochs=None, nbatches=None, length=None):
        """Create an instance of repeatedly.

        :param source: the source dataset to repeat
        :param nepochs: repeat for a maximum of nepochs (None = unlimited)
        :param nbatches: repeat for a maximum of nbatches (None = unlimited)
        :param length: declared length of this dataset, if any
        """
        self.source = source
        self.length = length
        # BUG FIX: nepochs was never stored, so invoke() raised
        # AttributeError when it read self.nepochs.
        self.nepochs = nepochs
        self.nbatches = nbatches

    def invoke(self, source):
        """Return an iterator that iterates repeatedly over a source."""
        # Delegate the actual repetition logic to the shared helper.
        return utils.repeatedly(
            source,
            nepochs=self.nepochs,
            nbatches=self.nbatches,
        )
class with_epoch(IterableDataset):
    """Change the actual and nominal length of an IterableDataset.

    This continuously iterates through the original dataset, but imposes
    new epoch boundaries at the given length. It exists mainly as a
    workaround for the odd logic in DataLoader, and is also useful for
    choosing smaller nominal epoch sizes with very large datasets.
    """

    def __init__(self, dataset, length):
        """Set up an epoch boundary of the given length.

        :param dataset: accepted for interface compatibility but not stored;
            the source is supplied to invoke() instead
        :param length: number of samples per imposed epoch
        """
        super().__init__()
        self.length = length
        # Live iterator over the current source; kept across invocations so
        # consecutive epochs continue where the previous one left off.
        self.source = None

    def __getstate__(self):
        """Return the pickled state of the dataset.

        The live iterator cannot be pickled, so it is dropped from the
        state and will be recreated lazily on the next invocation.
        """
        state = self.__dict__.copy()
        state["source"] = None
        return state

    def invoke(self, dataset):
        """Yield up to `self.length` samples from the dataset.

        When the underlying dataset is exhausted mid-epoch, it is
        restarted transparently; an empty dataset ends the epoch early.
        """
        if self.source is None:
            self.source = iter(dataset)
        emitted = 0
        while emitted < self.length:
            try:
                sample = next(self.source)
            except StopIteration:
                # Source ran out mid-epoch: restart it from the beginning.
                self.source = iter(dataset)
                try:
                    sample = next(self.source)
                except StopIteration:
                    # The dataset is empty even after a restart; give up.
                    return
            yield sample
            emitted += 1
        # Epoch boundary reached: force a fresh iterator next time.
        self.source = None
class with_length(IterableDataset, PipelineStage):
    """Wrap a dataset so that it reports a fixed, user-specified length."""

    def __init__(self, dataset, length):
        """Wrap the given dataset and declare its length.

        :param dataset: source dataset
        :param length: stated length
        """
        super().__init__()
        self.length = length
        self.dataset = dataset

    def invoke(self, dataset):
        """Return an iterator over the source dataset, unchanged."""
        return iter(dataset)

    def __len__(self):
        """Return the user-specified length."""
        return self.length