Commit: MultiProcessMapData with strict (#414)

ppwwyyxx committed Dec 9, 2017
1 parent be3a07a commit 68e8d9e
Showing 2 changed files with 99 additions and 75 deletions.
10 changes: 2 additions & 8 deletions .github/ISSUE_TEMPLATE.md
@@ -1,8 +1,8 @@
 Bug Reports/Feature Requests/Usage Questions Only:
 
-Bug Reports: PLEASE always include
+Bug reports or other problems with code: PLEASE always include
 1. What you did. (command you run and changes you made if using examples; post or describe your code if not)
-2. What you observed, e.g. logs.
+2. What you observed, e.g. as many logs as possible.
 3. What you expected, if not obvious.
 4. Your environment (TF version, cudnn version, number & type of GPUs), if it matters.
 5. About efficiency, PLEASE first read http://tensorpack.readthedocs.io/en/latest/tutorial/performance-tuning.html
@@ -14,10 +14,4 @@ Feature Requests:
 It may not have to be added to tensorpack unless you have a good reason.
 3. Note that we don't implement papers at others' requests.
 
-Usage Questions, e.g.:
-"How do I do [this specific thing] in tensorpack?"
-"Why certain examples need to be written in this way?"
-We don't answer general machine learning questions like:
-"I want to do [this machine learning task]. What specific things do I need to do?"
-
 You can also use gitter (https://gitter.im/tensorpack/users) for more casual discussions.
164 changes: 97 additions & 67 deletions tensorpack/dataflow/prefetch.py
@@ -67,20 +67,11 @@ def _zmq_catch_error(name):
 
 
 class _MultiProcessZMQDataFlow(DataFlow):
-    def __init__(self, ds):
+    def __init__(self):
         assert os.name != 'nt', "ZMQ IPC doesn't support windows!"
         self._reset_done = False
         self._procs = []
 
-        self.ds = ds
-        try:
-            self._size = ds.size()
-        except NotImplementedError:
-            self._size = -1
-
-    def size(self):
-        return self.ds.size()
-
     def reset_state(self):
         """
         All forked dataflows are reset **once and only once** in spawned processes.
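
Dropping `ds` and `size()` from this base class makes it usable as a plain mixin: a subclass that also inherits from a DataFlow wrapper can now initialize each parent explicitly, which is exactly what MultiProcessMapDataZMQ does further down. A minimal sketch of that pattern (Base, Proxy and Combined are hypothetical stand-ins, not tensorpack classes):

class Base(object):
    def __init__(self):
        self._procs = []

class Proxy(object):
    def __init__(self, ds):
        self.ds = ds

class Combined(Proxy, Base):
    def __init__(self, ds):
        # Initialize each base explicitly; super() is awkward here because
        # the two constructors take different arguments.
        Proxy.__init__(self, ds)
        Base.__init__(self)

c = Combined(ds=[1, 2, 3])
assert c.ds == [1, 2, 3] and c._procs == []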
@@ -265,10 +256,17 @@ def __init__(self, ds, nr_proc=1, hwm=50):
         if nr_proc > 1:
             logger.info("[PrefetchDataZMQ] Will fork a dataflow more than one times. "
                         "This assumes the datapoints are i.i.d.")
+        try:
+            self._size = ds.size()
+        except NotImplementedError:
+            self._size = -1
 
     def _recv(self):
         return loads(self.socket.recv(copy=False).bytes)
 
+    def size(self):
+        return self.ds.size()
+
     def get_data(self):
         with self._guard, _zmq_catch_error('PrefetchDataZMQ'):
             for k in itertools.count():
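
The try/except added above caches the wrapped dataflow's size at construction time, keeping -1 as a sentinel when the size is unknown. A minimal sketch of the probing pattern, with a hypothetical InfiniteStream standing in for a dataflow that cannot report its size:

class InfiniteStream(object):
    """Hypothetical dataflow-like object with no well-defined size."""
    def get_data(self):
        while True:
            yield [0]

    def size(self):
        raise NotImplementedError()

try:
    _size = InfiniteStream().size()
except NotImplementedError:
    _size = -1  # sentinel: size unknown
assert _size == -1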
@@ -311,7 +309,59 @@ def _start_processes(self):
             proc.start()
 
 
-class MultiThreadMapData(ProxyDataFlow):
+class _ParallelMapData(ProxyDataFlow):
+    def __init__(self, ds, buffer_size):
+        super(_ParallelMapData, self).__init__(ds)
+        assert buffer_size > 0, buffer_size
+        self._buffer_size = buffer_size
+
+    def _recv(self):
+        pass
+
+    def _send(self, dp):
+        pass
+
+    def _recv_filter_none(self):
+        ret = self._recv()
+        assert ret is not None, \
+            "[{}] Map function cannot return None when strict mode is used.".format(type(self).__name__)
+        return ret
+
+    def _fill_buffer(self):
+        try:
+            for _ in range(self._buffer_size):
+                dp = next(self._iter)
+                self._send(dp)
+        except StopIteration:
+            logger.error(
+                "[{}] buffer_size cannot be larger than the size of the DataFlow!".format(type(self).__name__))
+            raise
+
+    def get_data_non_strict(self):
+        for dp in self._iter:
+            self._send(dp)
+            yield self._recv()
+
+        self._iter = self.ds.get_data()  # refresh
+        for _ in range(self._buffer_size):
+            self._send(next(self._iter))
+            yield self._recv()
+
+    def get_data_strict(self):
+        for dp in self._iter:
+            self._send(dp)
+            yield self._recv_filter_none()
+        self._iter = self.ds.get_data()  # refresh
+
+        # first clear the buffer, then fill
+        for k in range(self._buffer_size):
+            dp = self._recv_filter_none()
+            if k == self._buffer_size - 1:
+                self._fill_buffer()
+            yield dp
+
+
+class MultiThreadMapData(_ParallelMapData):
     """
     Same as :class:`MapData`, but start threads to run the mapping function.
     This is useful when the mapping function is the bottleneck, but you don't
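
The new _ParallelMapData base factors the buffering logic out of the thread and process implementations: a subclass only supplies _send() (push a datapoint toward its workers) and _recv() (pull back a mapped result), while the base keeps buffer_size datapoints in flight; in strict mode it refills the buffer just before yielding the last buffered datapoint, so each pass of get_data_strict() produces exactly one epoch. A toy synchronous subclass, shown only to illustrate that contract (SyncMapDemo and its deque "worker" are hypothetical; it assumes the _ParallelMapData defined above is in scope):

from collections import deque

class SyncMapDemo(_ParallelMapData):
    """Hypothetical toy subclass: the 'worker pool' is a plain deque."""
    def __init__(self, ds, map_func, buffer_size):
        super(SyncMapDemo, self).__init__(ds, buffer_size)
        self._map_func = map_func
        self._queue = deque()

    def _send(self, dp):
        self._queue.append(dp)  # submit a datapoint

    def _recv(self):
        return self._map_func(self._queue.popleft())  # collect a result

    def get_data(self):
        self._iter = self.ds.get_data()
        self._fill_buffer()  # put buffer_size datapoints in flight
        for dp in self.get_data_non_strict():
            yield dp

With a five-element dataflow and buffer_size=2, one call to get_data() yields all five mapped datapoints, but the buffer then already holds the first two datapoints of the next pass: exactly the cross-pass mixing that strict mode is meant to rule out.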
@@ -367,11 +417,10 @@ def __init__(self, ds, nr_thread, map_func, buffer_size=200, strict=False):
             buffer_size (int): number of datapoints in the buffer
             strict (bool): use "strict mode", see notes above.
         """
-        super(MultiThreadMapData, self).__init__(ds)
+        super(MultiThreadMapData, self).__init__(ds, buffer_size)
 
         self._strict = strict
         self.nr_thread = nr_thread
-        self.buffer_size = buffer_size
         self.map_func = map_func
         self._threads = []
         self._evt = None
@@ -398,43 +447,20 @@ def reset_state(self):
         # only call once, to ensure inq+outq has a total of buffer_size elements
         self._fill_buffer()
 
-    def _fill_buffer(self):
-        n = self.buffer_size - self._in_queue.qsize() - self._out_queue.qsize()
-        assert n >= 0, n
-        if n == 0:
-            return
-        try:
-            for _ in range(n):
-                self._in_queue.put(next(self._iter))
-        except StopIteration:
-            logger.error("[MultiThreadMapData] buffer_size cannot be larger than the size of the DataFlow!")
-            raise
-
     def _recv(self):
-        ret = self._out_queue.get()
-        if ret is None:
-            assert not self._strict, \
-                "[MultiThreadMapData] Map function cannot return None when strict mode is used."
-        return ret
+        return self._out_queue.get()
 
     def _send(self, dp):
         self._in_queue.put(dp)
 
     def get_data(self):
         with self._guard:
-            for dp in self._iter:
-                self._in_queue.put(dp)
-                yield self._recv()
-
-            self._iter = self.ds.get_data()
             if self._strict:
-                # first call get() to clear the queues, then fill
-                for k in range(self.buffer_size):
-                    dp = self._recv()
-                    if k == self.buffer_size - 1:
-                        self._fill_buffer()
+                for dp in self.get_data_strict():
                     yield dp
             else:
-                for _ in range(self.buffer_size):
-                    self._in_queue.put(next(self._iter))
-                    yield self._recv()
+                for dp in self.get_data_non_strict():
+                    yield dp
 
     def __del__(self):
         if self._evt is not None:
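
Both branches of get_data() are now thin wrappers over the shared helpers. A hedged usage sketch of the strict path (Counter is a toy dataflow written for this example, not part of tensorpack; note that buffer_size must stay below the dataflow's size, as the error message above enforces):

from tensorpack.dataflow import DataFlow, MultiThreadMapData

class Counter(DataFlow):
    def __init__(self, n):
        self._n = n

    def get_data(self):
        for k in range(self._n):
            yield [k]

    def size(self):
        return self._n

ds = MultiThreadMapData(Counter(100), nr_thread=4,
                        map_func=lambda dp: [dp[0] * 2],
                        buffer_size=20, strict=True)
ds.reset_state()
# Strict mode: one pass is exactly one epoch; order is not guaranteed.
assert sorted(dp[0] for dp in ds.get_data()) == [2 * k for k in range(100)]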
@@ -447,10 +473,20 @@ def __del__(self):
 ThreadedMapData = MultiThreadMapData
 
 
-class MultiProcessMapDataZMQ(_MultiProcessZMQDataFlow):
+class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
     """
     Same as :class:`MapData`, but start processes to run the mapping function,
     and communicate with ZeroMQ pipe.
+
+    Note:
+        1. Processes run in parallel and can take different time to run the
+           mapping function. Therefore the order of datapoints won't be
+           preserved, and datapoints from one pass of `df.get_data()` might get
+           mixed with datapoints from the next pass.
+
+           You can use **strict mode**, where `MultiProcessMapData.get_data()`
+           is guaranteed to produce the exact set which `df.get_data()`
+           produces, although the order of data still isn't preserved.
     """
     class _Worker(mp.Process):
         def __init__(self, identity, map_func, pipename, hwm):
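
Concretely, with the default strict=False the buffer is drained by pre-fetching from the next epoch, so a single pass may interleave datapoints from two epochs. A sketch of that caveat (reusing the toy Counter dataflow from the sketch above; Linux assumed, since the constructor asserts against Windows):

ds = MultiProcessMapData(Counter(100), nr_proc=3,
                         map_func=lambda dp: [dp[0] + 1],
                         buffer_size=10)  # strict=False: passes can mix
ds.reset_state()
first_pass = [dp[0] for dp in ds.get_data()]
# first_pass has 100 elements, but there is no guarantee that it is a
# permutation of {1, ..., 100}: up to buffer_size of them may come from
# the second epoch instead. strict=True restores that guarantee.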
@@ -472,30 +508,32 @@ def run(self):
                 dp = self.map_func(dp)
                 socket.send(dumps(dp), copy=False)
 
-    def __init__(self, ds, nr_proc, map_func, buffer_size=200):
+    def __init__(self, ds, nr_proc, map_func, buffer_size=200, strict=False):
         """
         Args:
             ds (DataFlow): the dataflow to map
             nr_proc(int): number of processes to use
             map_func (callable): datapoint -> datapoint | None
             buffer_size (int): number of datapoints in the buffer
+            strict (bool): use "strict mode", see notes above.
         """
-        super(MultiProcessMapDataZMQ, self).__init__(ds)
+        _ParallelMapData.__init__(self, ds, buffer_size)
+        _MultiProcessZMQDataFlow.__init__(self)
         self.nr_proc = nr_proc
         self.map_func = map_func
-        self.buffer_size = buffer_size
+        self._strict = strict
         self._procs = []
         self._guard = DataFlowReentrantGuard()
 
     def _reset_once(self):
         self.context = zmq.Context()
         self.socket = self.context.socket(zmq.ROUTER)
-        self.socket.set_hwm(self.buffer_size * 2)
+        self.socket.set_hwm(self._buffer_size * 2)
         pipename = _get_pipe_name('dataflow-map')
         _bind_guard(self.socket, pipename)
 
         self._proc_ids = [u'{}'.format(k).encode('utf-8') for k in range(self.nr_proc)]
-        worker_hwm = int(self.buffer_size * 2 // self.nr_proc)
+        worker_hwm = int(self._buffer_size * 2 // self.nr_proc)
         self._procs = [MultiProcessMapDataZMQ._Worker(
             self._proc_ids[k], self.map_func, pipename, worker_hwm)
             for k in range(self.nr_proc)]
@@ -507,14 +545,8 @@ def _reset_once(self):
         self._start_processes()
         self._fill_buffer()
 
-    def _fill_buffer(self):
-        # Filling the buffer.
-        try:
-            for _ in range(self.buffer_size):
-                self._send(next(self._iter))
-        except StopIteration:
-            logger.error("[MultiProcessMapData] buffer_size cannot be larger than the size of the DataFlow!")
-            raise
+    def reset_state(self):
+        _MultiProcessZMQDataFlow.reset_state(self)
 
     def _send(self, dp):
         # round-robin assignment
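
The explicit reset_state() override matters because of the diamond above: with class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow), Python's MRO would otherwise pick up ProxyDataFlow's reset_state via _ParallelMapData, which only resets the wrapped dataflow and never starts the worker processes. A small demonstration of that dispatch rule (A, B and C are hypothetical):

class A(object):
    def reset(self):
        return "A.reset"

class B(object):
    def reset(self):
        return "B.reset"

class C(A, B):
    def reset(self):
        return B.reset(self)  # explicit delegation, as in the diff

assert C().reset() == "B.reset"  # without the override it would be "A.reset"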
@@ -529,14 +561,12 @@ def _recv(self):
 
     def get_data(self):
         with self._guard, _zmq_catch_error('MultiProcessMapData'):
-            for dp in self._iter:
-                self._send(dp)
-                yield self._recv()
-
-            self._iter = self.ds.get_data()  # refresh
-            for _ in range(self.buffer_size):
-                self._send(next(self._iter))
-                yield self._recv()
+            if self._strict:
+                for dp in self.get_data_strict():
+                    yield dp
+            else:
+                for dp in self.get_data_non_strict():
+                    yield dp
 
 
 MultiProcessMapData = MultiProcessMapDataZMQ  # alias
@@ -549,13 +579,13 @@ def __init__(self, size):
 
         def get_data(self):
             for k in range(self._size):
-                yield [0]
+                yield [k]
 
         def size(self):
             return self._size
 
     ds = Zero(300)
-    ds = MultiProcessMapData(ds, 3, lambda x: [x[0] + 1])
+    ds = MultiProcessMapData(ds, 3, lambda x: [x[0] + 1], strict=True)
     ds.reset_state()
     for k in ds.get_data():
         print("Bang!", k)
