-
Notifications
You must be signed in to change notification settings - Fork 814
/
iterator.py
293 lines (252 loc) · 12 KB
/
iterator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
from __future__ import division
import math
import random
import logging
from .utils import RandomShuffler
from .batch import Batch
from .dataset import Dataset
# Module-level logger; Iterator.__init__ uses it to warn about integer `device` arguments.
logger = logging.getLogger(__name__)
class Iterator(object):
    """Defines an iterator that loads batches of data from a Dataset.

    Attributes:
        dataset: The Dataset object to load Examples from.
        batch_size: Batch size.
        batch_size_fn: Function of three arguments (new example to add, current
            count of examples in the batch, and current effective batch size)
            that returns the new effective batch size resulting from adding
            that example to a batch. This is useful for dynamic batching, where
            this function would add to the current effective batch size the
            number of tokens in the new example.
        sort_key: A key to use for sorting examples in order to batch together
            examples with similar lengths and minimize padding. The sort_key
            provided to the Iterator constructor overrides the sort_key
            attribute of the Dataset, or defers to it if None.
        train: Whether the iterator represents a train set.
        repeat: Whether to repeat the iterator for multiple epochs. Default: False.
        shuffle: Whether to shuffle examples between epochs.
        sort: Whether to sort examples according to self.sort_key.
            Note that shuffle and sort default to train and (not train).
        sort_within_batch: Whether to sort (in descending order according to
            self.sort_key) within each batch. If None, defaults to self.sort.
            If self.sort is True and this is False, the batch is left in the
            original (ascending) sorted order.
        device (str or `torch.device`): A string or instance of `torch.device`
            specifying which device the Variables are going to be created on.
            If left as default, the tensors will be created on cpu. Default: None.
    """

    def __init__(self, dataset, batch_size, sort_key=None, device=None,
                 batch_size_fn=None, train=True,
                 repeat=False, shuffle=None, sort=None,
                 sort_within_batch=None):
        """Store the batching configuration; see the class docstring for arguments."""
        self.batch_size, self.train, self.dataset = batch_size, train, dataset
        self.batch_size_fn = batch_size_fn
        self.iterations = 0
        self.repeat = repeat
        # shuffle defaults to `train`, sort defaults to `not train`.
        self.shuffle = train if shuffle is None else shuffle
        self.sort = (not train) if sort is None else sort

        if sort_within_batch is None:
            self.sort_within_batch = self.sort
        else:
            self.sort_within_batch = sort_within_batch

        # An explicit sort_key wins; otherwise defer to the dataset's own key.
        if sort_key is None:
            self.sort_key = dataset.sort_key
        else:
            self.sort_key = sort_key

        # Integer device ids are no longer honoured: warn and fall back to the
        # default (None).
        if isinstance(device, int):
            logger.warning("The `device` argument should be set by using `torch.device`"
                           " or passing a string as an argument. This behavior will be"
                           " deprecated soon and currently defaults to cpu.")
            device = None
        self.device = device
        self.random_shuffler = RandomShuffler()

        # For state loading/saving only
        self._iterations_this_epoch = 0
        self._random_state_this_epoch = None
        self._restored_from_state = False

    @classmethod
    def splits(cls, datasets, batch_sizes=None, **kwargs):
        """Create Iterator objects for multiple splits of a dataset.

        Arguments:
            datasets: Tuple of Dataset objects corresponding to the splits. The
                first such object should be the train set.
            batch_sizes: Tuple of batch sizes to use for the different splits,
                or None to use the same batch_size for all splits.
            Remaining keyword arguments: Passed to the constructor of the
                iterator class being used.

        Returns:
            A tuple with one iterator per dataset, in the same order. Only the
            first one is constructed with train=True.
        """
        if batch_sizes is None:
            # A single `batch_size` keyword is shared by every split.
            batch_sizes = [kwargs.pop('batch_size')] * len(datasets)
        ret = []
        for i, dataset in enumerate(datasets):
            ret.append(cls(
                dataset, batch_size=batch_sizes[i], train=(i == 0), **kwargs))
        return tuple(ret)

    def data(self):
        """Return the examples in the dataset in order, sorted, or shuffled."""
        if self.sort:
            xs = sorted(self.dataset, key=self.sort_key)
        elif self.shuffle:
            xs = [self.dataset[i] for i in self.random_shuffler(range(len(self.dataset)))]
        else:
            xs = self.dataset
        return xs

    def init_epoch(self):
        """Set up the batch generator for a new epoch."""
        # Either replay the shuffler state captured in a loaded state dict, or
        # remember the current one so state_dict() can save it.
        if self._restored_from_state:
            self.random_shuffler.random_state = self._random_state_this_epoch
        else:
            self._random_state_this_epoch = self.random_shuffler.random_state

        self.create_batches()

        # A restored iterator keeps its per-epoch counter so __iter__ can
        # fast-forward; a fresh epoch starts counting from zero.
        if self._restored_from_state:
            self._restored_from_state = False
        else:
            self._iterations_this_epoch = 0

        # NOTE(review): this also resets an `iterations` value that was just
        # restored via load_state_dict when repeat is False - confirm intended.
        if not self.repeat:
            self.iterations = 0

    def create_batches(self):
        """Build self.batches by chunking self.data() with the module-level batch()."""
        self.batches = batch(self.data(), self.batch_size, self.batch_size_fn)

    @property
    def epoch(self):
        """Number of completed passes: floor(iterations / len(self))."""
        return math.floor(self.iterations / len(self))

    def __len__(self):
        """Return the number of batches per epoch, ceil(len(dataset) / batch_size).

        Raises:
            NotImplementedError: if a batch_size_fn is set, because the number
                of batches is then not derived from the example count here.
        """
        if self.batch_size_fn is not None:
            raise NotImplementedError
        # int() keeps the result a valid __len__ value on interpreters where
        # math.ceil returns a float.
        return int(math.ceil(len(self.dataset) / self.batch_size))

    def __iter__(self):
        """Yield Batch objects, one epoch at a time; loops forever if repeat is set."""
        while True:
            self.init_epoch()
            for idx, minibatch in enumerate(self.batches):
                # fast-forward if loaded from state
                if self._iterations_this_epoch > idx:
                    continue
                self.iterations += 1
                self._iterations_this_epoch += 1
                if self.sort_within_batch:
                    # NOTE: `rnn.pack_padded_sequence` requires that a minibatch
                    # be sorted by decreasing order, which requires reversing
                    # relative to typical sort keys
                    if self.sort:
                        minibatch.reverse()
                    else:
                        minibatch.sort(key=self.sort_key, reverse=True)
                yield Batch(minibatch, self.dataset, self.device)
            if not self.repeat:
                return

    def state_dict(self):
        """Return the iteration counters and this epoch's shuffler state as a dict."""
        return {
            "iterations": self.iterations,
            "iterations_this_epoch": self._iterations_this_epoch,
            "random_state_this_epoch": self._random_state_this_epoch}

    def load_state_dict(self, state_dict):
        """Restore the values produced by state_dict() and flag the next epoch as resumed."""
        self.iterations = state_dict["iterations"]
        self._iterations_this_epoch = state_dict["iterations_this_epoch"]
        self._random_state_this_epoch = state_dict["random_state_this_epoch"]
        self._restored_from_state = True
class BPTTIterator(Iterator):
    """Iterator for language modeling with backpropagation through time (BPTT).

    Produces contiguous slices of a token stream together with targets that
    are shifted one timestep forward. Expects a Dataset with a single example
    and a single field called 'text', and yields Batches with `text` and
    `target` attributes.

    Attributes:
        dataset: The Dataset object to load Examples from.
        batch_size: Batch size.
        bptt_len: Length of sequences for backpropagation through time.
        sort_key: A key to use for sorting examples in order to batch together
            examples with similar lengths and minimize padding. The sort_key
            provided to the Iterator constructor overrides the sort_key
            attribute of the Dataset, or defers to it if None.
        train: Whether the iterator represents a train set.
        repeat: Whether to repeat the iterator for multiple epochs. Default: False.
        shuffle: Whether to shuffle examples between epochs.
        sort: Whether to sort examples according to self.sort_key.
            Note that shuffle and sort default to train and (not train).
        device (str or torch.device): A string or instance of `torch.device`
            specifying which device the Variables are going to be created on.
            If left as default, the tensors will be created on cpu. Default: None.
    """

    def __init__(self, dataset, batch_size, bptt_len, **kwargs):
        """Record bptt_len, then defer everything else to Iterator.__init__."""
        self.bptt_len = bptt_len
        super(BPTTIterator, self).__init__(dataset, batch_size, **kwargs)

    def __len__(self):
        """Return ceil((len(text) / batch_size - 1) / bptt_len)."""
        steps_per_column = len(self.dataset[0].text) / self.batch_size - 1
        return math.ceil(steps_per_column / self.bptt_len)

    def __iter__(self):
        """Yield (text, target) Batches of at most bptt_len timesteps each."""
        tokens = self.dataset[0].text
        text_field = self.dataset.fields['text']
        # The field's eos_token is cleared (on the shared field object) before
        # the stream is numericalized.
        text_field.eos_token = None

        # Pad the stream so its length is a multiple of batch_size.
        pad_count = int(math.ceil(len(tokens) / self.batch_size)
                        * self.batch_size - len(tokens))
        tokens = tokens + ([text_field.pad_token] * pad_count)

        stream = text_field.numericalize([tokens], device=self.device)
        # Split into batch_size rows, then transpose so indexing walks timesteps.
        stream = stream.view(self.batch_size, -1).t().contiguous()

        lm_dataset = Dataset(
            examples=self.dataset.examples,
            fields=[('text', text_field), ('target', text_field)])

        while True:
            for start in range(0, len(self) * self.bptt_len, self.bptt_len):
                self.iterations += 1
                # Leave one trailing timestep so every input has a target.
                span = min(self.bptt_len, len(stream) - start - 1)
                inputs = stream[start:start + span]
                targets = stream[start + 1:start + 1 + span]
                if text_field.batch_first:
                    inputs = inputs.t().contiguous()
                    targets = targets.t().contiguous()
                yield Batch.fromvars(
                    lm_dataset, self.batch_size,
                    text=inputs,
                    target=targets)
            if not self.repeat:
                return
class BucketIterator(Iterator):
    """Iterator that groups examples of similar lengths into the same batch.

    Keeps padding low while still producing freshly shuffled batches each
    epoch. See pool for the bucketing procedure used.
    """

    def create_batches(self):
        """Populate self.batches: plain chunking when sorting, pooling otherwise."""
        if not self.sort:
            # Unsorted data goes through pool(), which buckets it chunk by chunk.
            self.batches = pool(
                self.data(), self.batch_size, self.sort_key, self.batch_size_fn,
                random_shuffler=self.random_shuffler,
                shuffle=self.shuffle,
                sort_within_batch=self.sort_within_batch)
            return
        # Data is already globally sorted, so consecutive chunks are enough.
        self.batches = batch(self.data(), self.batch_size, self.batch_size_fn)
def batch(data, batch_size, batch_size_fn=None):
    """Yield elements from data in chunks of batch_size.

    Arguments:
        data: Iterable of examples.
        batch_size: Target effective size of each chunk.
        batch_size_fn: Optional function of (new example, number of examples
            in the chunk including the new one, effective size so far) that
            returns the new effective size. Defaults to counting examples.

    Yields:
        Non-empty lists of examples. A chunk is emitted as soon as its
        effective size equals batch_size; if adding an example overshoots
        batch_size, the chunk without that example is emitted and the example
        starts the next chunk. An example that exceeds batch_size on its own
        is emitted as a chunk by itself.
    """
    if batch_size_fn is None:
        def batch_size_fn(new, count, sofar):
            return count
    minibatch, size_so_far = [], 0
    for ex in data:
        minibatch.append(ex)
        size_so_far = batch_size_fn(ex, len(minibatch), size_so_far)
        if size_so_far == batch_size:
            yield minibatch
            minibatch, size_so_far = [], 0
        elif size_so_far > batch_size:
            # Emit what was collected before `ex`. When `ex` is the only
            # element there is nothing to emit, so skip rather than yield [].
            if len(minibatch) > 1:
                yield minibatch[:-1]
            minibatch, size_so_far = [ex], batch_size_fn(ex, 1, 0)
    if minibatch:
        yield minibatch
def pool(data, batch_size, key, batch_size_fn=lambda new, count, sofar: count,
         random_shuffler=None, shuffle=False, sort_within_batch=False):
    """Sort within buckets, then batch, then shuffle batches.

    Partitions data into chunks of size 100*batch_size. When
    sort_within_batch is true, the examples of each chunk are sorted with
    `key` before being split into batches; otherwise the chunk is batched in
    its existing order. When shuffle is true, the batches of each chunk are
    shuffled before being yielded.

    Arguments:
        data: Iterable of examples.
        batch_size: Target effective batch size.
        key: Sort key applied to examples within a chunk.
        batch_size_fn: Effective-size function forwarded to batch().
        random_shuffler: Callable that takes a list and returns its elements
            in shuffled order. Defaults to a shuffled copy via random.sample.
        shuffle: Whether to shuffle the batches of each chunk.
        sort_within_batch: Whether to sort each chunk by `key` before batching.

    Yields:
        Lists of examples, one per batch.
    """
    if random_shuffler is None:
        # random.shuffle shuffles in place and returns None, so its result
        # cannot be iterated; use a shuffler that returns a new list instead.
        def random_shuffler(items):
            return random.sample(items, len(items))

    for chunk in batch(data, batch_size * 100, batch_size_fn):
        if sort_within_batch:
            chunk = sorted(chunk, key=key)
        chunk_batches = batch(chunk, batch_size, batch_size_fn)
        if shuffle:
            # The shuffler needs a concrete sequence, not a generator.
            for b in random_shuffler(list(chunk_batches)):
                yield b
        else:
            for b in chunk_batches:
                yield b