Skip to content

Commit

Permalink
support batch
Browse files Browse the repository at this point in the history
  • Loading branch information
GreatYYX committed Feb 15, 2019
1 parent c15ddeb commit a2137ca
Showing 1 changed file with 19 additions and 15 deletions.
34 changes: 19 additions & 15 deletions rltk/parallel_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,9 @@ def __init__(self, instance, output_handler):
self.instance = instance

def run(self):
for o in self.instance.get_output():
self.output_handler(*o)
for batched_output in self.instance.get_output():
for o in batched_output:
self.output_handler(*o)


class ParallelProcessor(object):
Expand All @@ -62,6 +63,7 @@ class ParallelProcessor(object):
The return result is one by one, order is arbitrary.
enable_process_id (bool, optional): If it's true, an additional argument `_idx` (process id) will be
passed to `input_handler`. It defaults to False.
batch_size (int, optional): Batch size, defaults to 1.
Note:
Expand All @@ -77,7 +79,7 @@ class ParallelProcessor(object):
def __init__(self, input_handler: Callable, num_of_processor: int,
max_size_per_input_queue: int = 0, max_size_per_output_queue: int = 0,
output_handler: Callable = None, enable_process_id: bool = False,
input_batch_size: int = 1, output_batch_size: int = 1):
batch_size: int = 1):
self.num_of_processor = num_of_processor
self.input_queues = [mp.Queue(maxsize=max_size_per_input_queue) for _ in range(num_of_processor)]
self.output_queues = [mp.Queue(maxsize=max_size_per_output_queue) for _ in range(num_of_processor)]
Expand All @@ -88,10 +90,8 @@ def __init__(self, input_handler: Callable, num_of_processor: int,
self.input_queue_index = 0
self.output_queue_index = 0
self.enable_process_id = enable_process_id
self.input_batch_size = input_batch_size
self.output_batch_size = output_batch_size
self.input_batch = []
self.output_batch = []
self.batch_size = batch_size
self.batch_data = []

# output can be handled in each process or in main process after merging (output_handler needs to be set)
# if output_handler is set, output needs to be handled in main process; otherwise, it assumes there's no output.
Expand Down Expand Up @@ -121,9 +121,9 @@ def task_done(self):
Indicate that all resources which need to compute are added to processes.
(main process, blocked)
"""
if len(self.input_batch) > 0:
self._compute(self.input_batch)
self.input_batch = []
if len(self.batch_data) > 0:
self._compute(self.batch_data)
self.batch_data = []

for q in self.input_queues:
q.put( (ParallelProcessor.CMD_STOP,) )
Expand All @@ -133,11 +133,11 @@ def compute(self, *args, **kwargs):
Add data to one of the input queues.
(main process, unblocked, using round robin to find next available queue)
"""
self.input_batch.append( (args, kwargs) )
self.batch_data.append( (args, kwargs) )

if len(self.input_batch) == self.input_batch_size:
self._compute(self.input_batch)
self.input_batch = [] # reset buffer
if len(self.batch_data) == self.batch_size:
self._compute(self.batch_data)
self.batch_data = [] # reset buffer

def _compute(self, batched_args):
while True:
Expand All @@ -162,6 +162,7 @@ def run(self, idx: int, input_queue: mp.Queue, output_queue: mp.Queue):
output_queue.put( (ParallelProcessor.CMD_STOP,) )
return
elif data[0] == ParallelProcessor.CMD_DATA:
batch_result = []
for d in data[1]:
args, kwargs = d[0], d[1]
# print(idx, 'data')
Expand All @@ -170,7 +171,10 @@ def run(self, idx: int, input_queue: mp.Queue, output_queue: mp.Queue):
if self.output_handler:
if not isinstance(result, tuple): # output must represent as tuple
result = (result,)
output_queue.put( (ParallelProcessor.CMD_DATA, result) )
batch_result.append(result)
if len(batch_result) > 0:
output_queue.put( (ParallelProcessor.CMD_DATA, batch_result) )
batch_result = [] # reset buffer

def get_output(self):
"""
Expand Down

0 comments on commit a2137ca

Please sign in to comment.