Commit

async predictor base
ppwwyyxx committed Jul 17, 2016
1 parent e04d846 commit 4fc2108
Showing 5 changed files with 65 additions and 22 deletions.
2 changes: 1 addition & 1 deletion examples/Atari2600/README.md
@@ -1,4 +1,4 @@
-Reproduce the following methods:
+Reproduce the following reinforcement learning methods:
 
 + Nature-DQN in:
 [Human-level Control Through Deep Reinforcement Learning](http://www.nature.com/nature/journal/v518/n7540/full/nature14236.html)
41 changes: 39 additions & 2 deletions tensorpack/predict/base.py
@@ -5,9 +5,10 @@
 
 from abc import abstractmethod, ABCMeta, abstractproperty
 import tensorflow as tf
+import six
 from ..tfutils import get_vars_by_names
 
-__all__ = ['OnlinePredictor', 'OfflinePredictor']
+__all__ = ['OnlinePredictor', 'OfflinePredictor', 'AsyncPredictorBase']
 
 
 class PredictorBase(object):
@@ -31,7 +32,27 @@ def _do_call(self, dp):
         :param dp: input datapoint. must have the same length as input_var_names
         :return: output as defined by the config
         """
-        pass
+
+class AsyncPredictorBase(PredictorBase):
+    @abstractmethod
+    def put_task(self, dp, callback=None):
+        """
+        :param dp: a datapoint (list of components) as input.
+            (It should be either batched or non-batched, depending on the predictor implementation.)
+        :param callback: a thread-safe callback, to get called with either
+            the outputs or the (inputs, outputs) pair.
+        :return: a Future of the outputs.
+        """
+
+    @abstractmethod
+    def start(self):
+        """ Start workers """
+
+    def _do_call(self, dp):
+        assert six.PY3, "With Python 2, sync methods are not available for the async predictor"
+        fut = self.put_task(dp)
+        # in Tornado, Future.result() doesn't wait
+        return fut.result()
 
 class OnlinePredictor(PredictorBase):
     def __init__(self, sess, input_vars, output_vars, return_input=False):
@@ -64,3 +85,19 @@ def __init__(self, config):
         config.session_init.init(sess)
         super(OfflinePredictor, self).__init__(
                 sess, input_vars, output_vars, config.return_input)
+
+
+class AsyncOnlinePredictor(PredictorBase):
+    def __init__(self, sess, enqueue_op, output_vars, return_input=False):
+        """
+        :param enqueue_op: an op to feed inputs with.
+        :param output_vars: a list of directly-runnable (no extra feeding requirements)
+            vars producing the outputs.
+        """
+        self.session = sess
+        self.enqop = enqueue_op
+        self.output_vars = output_vars
+        self.return_input = return_input
+
+    def put_task(self, dp, callback):
+        pass
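
The new AsyncPredictorBase pins down the async contract: put_task() enqueues one datapoint and returns a Future of its outputs, start() launches the workers, and the base _do_call() simply blocks on fut.result() (hence the Python 3 assertion). Below is a minimal, self-contained sketch of that contract; ToyAsyncPredictor and predict_fn are hypothetical names used only to illustrate the pattern, not code from this commit.

```python
# Illustration only -- a toy class following the same put_task()/start()/Future
# contract as AsyncPredictorBase, without TensorFlow or tensorpack.
import queue
import threading
from concurrent.futures import Future


class ToyAsyncPredictor(object):
    def __init__(self, predict_fn, nr_thread=2):
        self._predict_fn = predict_fn          # stands in for an OnlinePredictor
        self._queue = queue.Queue()
        self._threads = [threading.Thread(target=self._worker, daemon=True)
                         for _ in range(nr_thread)]

    def start(self):
        """ Start workers. """
        for t in self._threads:
            t.start()

    def put_task(self, dp, callback=None):
        """ Enqueue one datapoint; return a Future of its outputs. """
        fut = Future()
        if callback is not None:
            fut.add_done_callback(callback)    # note: the callback receives the Future
        self._queue.put((dp, fut))
        return fut

    def _worker(self):
        while True:
            dp, fut = self._queue.get()
            fut.set_result(self._predict_fn(dp))


if __name__ == '__main__':
    pred = ToyAsyncPredictor(lambda dp: [x * x for x in dp])
    pred.start()
    fut = pred.put_task([1, 2, 3],
                        callback=lambda f: print('callback got', f.result()))
    print('blocking result:', fut.result())    # what _do_call() does on Python 3
```
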
39 changes: 22 additions & 17 deletions tensorpack/predict/concurrency.py
@@ -16,7 +16,7 @@
 from ..utils.timer import *
 from ..tfutils import *
 
-from .base import OfflinePredictor
+from .base import *
 
 try:
     if six.PY2:
@@ -116,34 +116,39 @@ def fetch_batch(self):
             cnt += 1
         return batched, futures
 
-class MultiThreadAsyncPredictor(object):
+class MultiThreadAsyncPredictor(AsyncPredictorBase):
     """
-    An multithread predictor which run a list of predict func.
-    Use async interface, support multi-thread and multi-GPU.
+    A multithreaded online async predictor which runs a list of OnlinePredictor.
+    It does an extra round of batching internally.
     """
-    def __init__(self, funcs, batch_size=5):
-        """ :param funcs: a list of predict func"""
-        self.input_queue = queue.Queue(maxsize=len(funcs)*10)
+    def __init__(self, predictors, batch_size=5):
+        """ :param predictors: a list of OnlinePredictor"""
+        for k in predictors:
+            assert isinstance(k, OnlinePredictor), type(k)
+        self.input_queue = queue.Queue(maxsize=len(predictors)*10)
         self.threads = [
             PredictorWorkerThread(
                 self.input_queue, f, id, batch_size=batch_size)
-            for id, f in enumerate(funcs)]
+            for id, f in enumerate(predictors)]
 
-        # TODO XXX set logging here to avoid affecting TF logging
-        import tornado.options as options
-        options.parse_command_line(['--logging=debug'])
+        if six.PY2:
+            # TODO XXX set logging here to avoid affecting TF logging
+            import tornado.options as options
+            options.parse_command_line(['--logging=debug'])
 
-    def run(self):
+    def start(self):
         for t in self.threads:
             t.start()
 
-    def put_task(self, inputs, callback=None):
+    def run(self):  # temporarily kept for backward compatibility
+        self.start()
+
+    def put_task(self, dp, callback=None):
         """
-        :param inputs: a data point (list of component) matching input_names (not batched)
-        :param callback: a thread-safe callback to get called with the list of outputs
-        :returns: a Future of output."""
+        dp must be non-batched, i.e. a single instance.
+        """
         f = Future()
         if callback is not None:
             f.add_done_callback(callback)
-        self.input_queue.put((inputs, f))
+        self.input_queue.put((dp, f))
         return f
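
The docstring above notes that the predictor does an extra round of batching internally: each PredictorWorkerThread pulls a first datapoint off the shared input_queue, opportunistically grabs more until batch_size is reached, runs one prediction on the whole batch, and fulfills each datapoint's Future with its own slice of the result. A self-contained sketch of that idea follows; it is not the actual fetch_batch() code, and batch_predict stands in for an OnlinePredictor that accepts batched input.

```python
# Illustration only -- a simplified version of the internal batching loop.
import queue
import threading
from concurrent.futures import Future


def batching_worker(input_queue, batch_predict, batch_size):
    """ Drain up to batch_size (dp, Future) pairs, predict once, fan results out. """
    while True:
        dps, futures = [], []
        dp, fut = input_queue.get()            # block until at least one task arrives
        dps.append(dp)
        futures.append(fut)
        while len(dps) < batch_size:           # then take whatever else is already queued
            try:
                dp, fut = input_queue.get_nowait()
            except queue.Empty:
                break
            dps.append(dp)
            futures.append(fut)
        outputs = batch_predict(dps)           # one call for the whole batch
        for fut, out in zip(futures, outputs):
            fut.set_result(out)                # wakes .result() and fires callbacks


if __name__ == '__main__':
    q = queue.Queue()
    threading.Thread(target=batching_worker,
                     args=(q, lambda dps: [sum(dp) for dp in dps], 4),
                     daemon=True).start()
    futs = []
    for i in range(10):
        f = Future()
        q.put(([i, i + 1], f))
        futs.append(f)
    print([f.result() for f in futs])
```
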
2 changes: 1 addition & 1 deletion tensorpack/train/base.py
@@ -62,7 +62,7 @@ def get_predict_funcs(self, input_names, output_names, n):
         Can be overwritten by subclasses to exploit more
         parallelism among funcs.
         """
-        return [self.get_predict_func(input_name, output_names) for k in range(n)]
+        return [self.get_predict_func(input_names, output_names) for k in range(n)]
 
     def trigger_epoch(self):
         self._trigger_epoch()
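
The one-character fix above (input_name to input_names) matters because get_predict_funcs() is the helper that produces the list of OnlinePredictor instances handed to MultiThreadAsyncPredictor. A hedged wiring sketch, assuming a trainer has already been built; the tensor names 'state' and 'policy' and the counts are placeholders, not values from this commit.

```python
# Hypothetical glue code; adjust the names to the actual graph.
from tensorpack.predict.concurrency import MultiThreadAsyncPredictor


def make_async_predictor(trainer, nr_predictors=4):
    # One OnlinePredictor per worker thread, built by the corrected helper.
    predictors = trainer.get_predict_funcs(['state'], ['policy'], nr_predictors)
    async_pred = MultiThreadAsyncPredictor(predictors, batch_size=5)
    async_pred.start()                      # .run() is kept as an alias for now
    return async_pred
```
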
3 changes: 2 additions & 1 deletion tensorpack/train/trainer.py
@@ -30,6 +30,7 @@ def __init__(self, trainer, towers):
         self.tower_built = False
 
     def get_predictor(self, input_names, output_names, tower):
+        """ Return an online predictor"""
         if not self.tower_built:
             self._build_predict_tower()
         tower = self.towers[tower % len(self.towers)]
@@ -204,7 +205,7 @@ def train(self):
         self.main_loop()
 
     def run_step(self):
-        """ just run self.train_op"""
+        """ Simply run self.train_op"""
         self.sess.run(self.train_op)
         #run_metadata = tf.RunMetadata()
         #self.sess.run([self.train_op],
