Skip to content

Commit

Permalink
selectcomponent
Browse files Browse the repository at this point in the history
  • Loading branch information
ppwwyyxx committed Apr 21, 2016
1 parent 27ea283 commit 4d64098
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 6 deletions.
24 changes: 22 additions & 2 deletions tensorpack/dataflow/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@ def _aggregate_batch(data_holder):
return result

class FixedSizeData(ProxyDataFlow):
""" Generate data from another DataFlow, but with a fixed epoch size"""
""" Generate data from another DataFlow, but with a fixed epoch size.
The state of the underlying DataFlow is maintained among each epoch.
"""
def __init__(self, ds, size):
"""
:param ds: a :mod:`DataFlow` to produce data
Expand Down Expand Up @@ -165,7 +167,7 @@ class MapDataComponent(ProxyDataFlow):
def __init__(self, ds, func, index=0):
"""
:param ds: a :mod:`DataFlow` instance.
:param func: a function that takes a datapoint dp[index], returns a
:param func: a function that takes a datapoint component dp[index], returns a
new value of dp[index]. return None to skip this datapoint.
"""
super(MapDataComponent, self).__init__(ds)
Expand Down Expand Up @@ -269,3 +271,21 @@ def get_data(self):
for dp in d.get_data():
yield dp

class SelectComponent(ProxyDataFlow):
"""
Select component from a datapoint.
"""
def __init__(self, ds, idxs):
"""
:param ds: a :mod:`DataFlow` instance
:param idxs: a list of datapoint component index of the original dataflow
"""
super(SelectComponent, self).__init__(ds)
self.idxs = idxs

def get_data(self):
for dp in self.ds.get_data():
newdp = []
for idx in self.idxs:
newdp.append(dp[idx])
yield newdp
6 changes: 2 additions & 4 deletions tensorpack/dataflow/prefetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import multiprocessing

from six.moves import range
from .base import ProxyDataFlow
from ..utils.concurrency import ensure_procs_terminate

Expand Down Expand Up @@ -49,12 +50,9 @@ def __init__(self, ds, nr_prefetch, nr_proc=1):

def get_data(self):
tot_cnt = 0
while True:
for _ in range(tot_cnt):
dp = self.queue.get()
yield dp
tot_cnt += 1
if tot_cnt == self._size:
break

def __del__(self):
self.queue.close()
Expand Down

0 comments on commit 4d64098

Please sign in to comment.