/
common.py
67 lines (57 loc) · 1.83 KB
/
common.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: batch.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import numpy as np
from .base import DataFlow
__all__ = ['BatchData', 'FixedSizeData']
class BatchData(DataFlow):
    """
    Group consecutive datapoints of an underlying DataFlow into batches.

    Each produced batch is a tuple holding one numpy array per datapoint
    component: entry k of the batch is np.array of component k of every
    datapoint in the batch, so it gains a new leading axis whose length is
    the number of datapoints in that batch.
    """

    def __init__(self, ds, batch_size, remainder=False):
        """
        Group data in ds into batches.

        ds: a DataFlow instance
        batch_size: number of datapoints per batch; must be positive.
        remainder: whether to return the remaining data smaller than a
            batch_size. If set, the last batch may have a smaller leading
            dimension (a different shape) than the others.

        Raises ValueError if batch_size is not positive.
        """
        if batch_size <= 0:
            # A non-positive size would make size() divide by zero and
            # get_data() accumulate datapoints forever without yielding.
            raise ValueError(
                "batch_size must be positive, got {}".format(batch_size))
        self.ds = ds
        self.batch_size = batch_size
        self.remainder = remainder

    def size(self):
        """
        Return the number of batches get_data() produces.

        A trailing partial batch is counted only when `remainder` is set.
        """
        ds_size = self.ds.size()
        # divmod performs floor division, so `div` is an int under both
        # Python 2 and Python 3 (plain `/` yields a float on Python 3).
        div, rem = divmod(ds_size, self.batch_size)
        if rem == 0:
            return div
        return div + int(self.remainder)

    def get_data(self):
        """
        Yield batches built from consecutive datapoints of ds.

        Full batches of `batch_size` datapoints are yielded as soon as they
        fill up. The leftover partial batch at the end is yielded only if
        `remainder` is set; otherwise it is dropped.
        """
        holder = []
        for data in self.ds.get_data():
            holder.append(data)
            if len(holder) == self.batch_size:
                yield BatchData.aggregate_batch(holder)
                holder = []
        if self.remainder and holder:
            yield BatchData.aggregate_batch(holder)

    @staticmethod
    def aggregate_batch(data_holder):
        """
        Stack a non-empty list of datapoints into a single batch.

        data_holder: non-empty list of datapoints. Each datapoint is an
            indexable sequence; the number of components is taken from the
            first datapoint, and every other datapoint is indexed with the
            same component indices.

        Returns a tuple whose k-th entry is np.array of every datapoint's
        k-th component. When the first datapoint's component has a `dtype`
        attribute (numpy arrays / numpy scalars), that dtype is used;
        otherwise (e.g. plain Python numbers) numpy infers the dtype.
        """
        first = data_holder[0]
        result = []
        for k in range(len(first)):
            # Fall back to numpy inference instead of crashing with
            # AttributeError on components that carry no dtype.
            dtype = getattr(first[k], 'dtype', None)
            result.append(
                np.array([x[k] for x in data_holder], dtype=dtype))
        return tuple(result)
class FixedSizeData(DataFlow):
    """
    Yield a fixed number of datapoints from another DataFlow.

    The underlying DataFlow is truncated if it is longer than `size`.
    If it is shorter, it is restarted from the beginning (by calling its
    get_data() again) as many times as needed.
    """

    def __init__(self, ds, size):
        """
        ds: a DataFlow instance
        size: number of datapoints each call to get_data() yields.
        """
        self.ds = ds
        self._size = size

    def size(self):
        """Return the fixed number of datapoints passed to __init__."""
        return self._size

    def get_data(self):
        """
        Yield up to `size` datapoints, cycling through ds.get_data().

        Yields nothing when size is not positive. If ds.get_data() yields
        no datapoints at all, iteration stops early instead of restarting
        the empty DataFlow forever.
        """
        cnt = 0
        while cnt < self._size:
            produced_any = False
            for dp in self.ds.get_data():
                produced_any = True
                yield dp
                cnt += 1
                if cnt >= self._size:
                    return
            if not produced_any:
                # An empty underlying DataFlow would otherwise make this
                # loop re-enter it endlessly without ever yielding.
                return