async training.
ppwwyyxx committed Apr 24, 2016
1 parent 08821b5 commit 3f74330
Showing 3 changed files with 78 additions and 17 deletions.
3 changes: 2 additions & 1 deletion tensorpack/train/base.py
@@ -48,6 +48,7 @@ def trigger_epoch(self):

@abstractmethod
def _trigger_epoch(self):
""" This is called right after all steps in an epoch are finished"""
pass

def _init_summary(self):
@@ -94,7 +95,7 @@ def main_loop(self):
if self.coord.should_stop():
return
self.run_step()
callbacks.trigger_step()
#callbacks.trigger_step() # not useful?
self.global_step += 1
self.trigger_epoch()
except (KeyboardInterrupt, Exception):
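For context, main_loop runs run_step once per step and calls trigger_epoch after all steps of an epoch are done; the per-step callback hook is what gets commented out here. A runnable toy paraphrase of that control flow (stub trainer and made-up step counts; only the call order mirrors base.py):

class ToyTrainer(object):
    def __init__(self):
        self.global_step = 0
    def run_step(self):
        pass                                 # one optimization step
    def trigger_epoch(self):                 # "called right after all steps in an epoch are finished"
        print("epoch finished at step %d" % self.global_step)
    def main_loop(self, steps_per_epoch=5, max_epoch=2):
        for epoch in range(1, max_epoch + 1):
            for _ in range(steps_per_epoch):
                self.run_step()
                # callbacks.trigger_step()   # the per-step hook disabled by this commit
                self.global_step += 1
            self.trigger_epoch()

ToyTrainer().main_loop()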
67 changes: 52 additions & 15 deletions tensorpack/train/trainer.py
@@ -11,6 +11,7 @@
from .base import Trainer
from ..dataflow.common import RepeatedData
from ..utils import *
from ..utils.concurrency import LoopThread
from ..tfutils.summary import summary_moving_average
from ..tfutils import *

@@ -79,13 +80,14 @@ def run(self):
finally:
logger.info("Enqueue Thread Exited.")


class QueueInputTrainer(Trainer):
"""
Trainer which builds a FIFO queue for input.
Support multi GPU.
"""

def __init__(self, config, input_queue=None):
def __init__(self, config, input_queue=None, async=False):
"""
:param config: a `TrainConfig` instance
:param input_queue: a `tf.QueueBase` instance to be used to buffer datapoints.
@@ -98,6 +100,9 @@ def __init__(self, config, input_queue=None):
100, [x.dtype for x in self.input_vars], name='input_queue')
else:
self.input_queue = input_queue
self.async = async
if self.async:
assert self.config.nr_tower > 1

@staticmethod
def _average_grads(tower_grads):
@@ -122,14 +127,15 @@ def _get_model_inputs(self):
qv.set_shape(v.get_shape())
return ret

def _single_tower_grad_cost(self):
def _single_tower_grad(self):
""" Get grad and cost for single-tower case"""
model_inputs = self._get_model_inputs()
cost_var = self.model.get_cost(model_inputs, is_training=True)
grads = self.config.optimizer.compute_gradients(cost_var)
return (grads, cost_var)
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost_var)
return grads

def _multi_tower_grad_cost(self):
def _multi_tower_grads(self):
logger.info("Training a model of {} tower".format(self.config.nr_tower))

# to avoid repeated summary from each device
@@ -140,6 +146,7 @@ def _multi_tower_grad_cost(self):
for i in range(self.config.nr_tower):
with tf.device('/gpu:{}'.format(i)), \
tf.name_scope('tower{}'.format(i)) as scope:
logger.info("Building graph for tower {}...".format(i))
model_inputs = self._get_model_inputs() # each tower dequeue from input queue
cost_var = self.model.get_cost(model_inputs, is_training=True) # build tower

@@ -148,30 +155,49 @@ def _multi_tower_grad_cost(self):
self.config.optimizer.compute_gradients(cost_var, gate_gradients=0))

if i == 0:
cost_var_t0 = cost_var
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost_var)
tf.get_variable_scope().reuse_variables()
for k in collect_dedup:
kept_summaries[k] = copy.copy(tf.get_collection(k))
logger.info("Graph built for tower {}.".format(i))
for k in collect_dedup:
del tf.get_collection_ref(k)[:]
tf.get_collection_ref(k).extend(kept_summaries[k])
grads = QueueInputTrainer._average_grads(grad_list)
return (grads, cost_var_t0)
return grad_list

def train(self):
enqueue_op = self.input_queue.enqueue(self.input_vars)

grads, cost_var = self._single_tower_grad_cost() \
if self.config.nr_tower == 0 else self._multi_tower_grad_cost()
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost_var)
avg_maintain_op = summary_moving_average()

grads = self.process_grads(grads)
if self.config.nr_tower > 1:
grad_list = self._multi_tower_grads()
if not self.async:
grads = QueueInputTrainer._average_grads(grad_list)
grads = self.process_grads(grads)
else:
grad_list = [self.process_grads(g) for g in grad_list]
# pretend to average the grads, in order to make async and
# sync have consistent semantics
def scale(grads):
return [(grad / self.config.nr_tower, var) for grad, var in grads]
grad_list = map(scale, grad_list)
grads = grad_list[0] # use grad from the first tower for routinely stuff
else:
grads = self._single_tower_grad()
grads = self.process_grads(grads)

self.train_op = tf.group(
self.config.optimizer.apply_gradients(grads, get_global_step_var()),
avg_maintain_op)
summary_moving_average())

if self.async:
self.threads = []
for k in range(1, self.config.nr_tower):
train_op = self.config.optimizer.apply_gradients(grad_list[k])
f = lambda : self.sess.run([train_op])
th = LoopThread(f)
th.pause()
th.start()
self.threads.append(th)
self.async_running = False

self.init_session_and_coord()
# create a thread that keeps filling the queue
@@ -183,14 +209,25 @@ def _start_all_threads(self):
self.input_th.start()

def run_step(self):
if self.async:
if not self.async_running:
self.async_running = True
for th in self.threads: # resume all threads
th.resume()
self.sess.run([self.train_op]) # faster since train_op return None

def _trigger_epoch(self):
# note that summary_op will take a data from the queue
if self.async:
self.async_running = False
for th in self.threads:
th.pause()
if self.summary_op is not None:
summary_str = self.summary_op.eval()
self._process_summary(summary_str)


def start_train(config):
tr = QueueInputTrainer(config)
tr.train()

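The async path above divides every tower's gradients by nr_tower before applying them, so the asynchronous updates take steps comparable to the synchronous path, which averages the gradients across towers and applies them once. A toy illustration of that consistency, with made-up numbers and no TensorFlow (a sketch only, not the trainer's actual code):

nr_tower = 3
lr = 0.1
tower_grads = [4.0, 2.0, 6.0]    # pretend per-tower gradients of one scalar variable

# synchronous: average the gradients across towers and apply a single update
sync_step = lr * sum(tower_grads) / nr_tower

# asynchronous: each tower applies its own gradient, scaled by 1/nr_tower
async_steps = [lr * (g / nr_tower) for g in tower_grads]

print(sync_step)          # ~0.4
print(sum(async_steps))   # ~0.4 as well: same total movement after nr_tower async updates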
25 changes: 24 additions & 1 deletion tensorpack/utils/concurrency.py
@@ -9,7 +9,7 @@
import bisect
import weakref

__all__ = ['StoppableThread', 'ensure_proc_terminate',
__all__ = ['StoppableThread', 'LoopThread', 'ensure_proc_terminate',
'OrderedResultGatherProc', 'OrderedContainer', 'DIE']

class StoppableThread(threading.Thread):
@@ -23,6 +23,29 @@ def stop(self):
def stopped(self):
return self._stop.isSet()

class LoopThread(threading.Thread):
""" A pausable thread that simply runs a loop"""
def __init__(self, func):
"""
:param func: the function to run
"""
super(LoopThread, self).__init__()
self.func = func
self.lock = threading.Lock()
self.daemon = True

def run(self):
while True:
self.lock.acquire()
self.lock.release()
self.func()

def pause(self):
self.lock.acquire()

def resume(self):
self.lock.release()


class DIE(object):
""" A placeholder class indicating end of queue """

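A small usage sketch for the LoopThread added above: pause() holds the internal lock so run() blocks on the acquire() at the top of its loop, and resume() releases it so the loop continues. Mirroring the trainer's async_running flag, pause() and resume() should alternate, since releasing a threading.Lock that is not held raises an error. The counter below is made up for illustration, and the import assumes this tensorpack checkout is on the Python path:

import itertools
import time
from tensorpack.utils.concurrency import LoopThread   # the class added in this commit

counter = itertools.count(1)
ticks = [0]                     # written by the loop thread, read by the main thread

def body():                     # stand-in for sess.run([train_op]) on one tower
    ticks[0] = next(counter)

th = LoopThread(body)
th.pause()                      # start paused, like the extra towers in train()
th.start()

time.sleep(0.2)
print("while paused: %d" % ticks[0])      # 0 -> body() has not run yet

th.resume()                     # run_step() resumes the towers on the first step
time.sleep(0.2)
th.pause()                      # _trigger_epoch() pauses them before the summary op
frozen = ticks[0]
print("after resume: %d" % frozen)        # many iterations have happened

time.sleep(0.2)                 # at most one in-flight body() call may still finish
print("still paused: %s" % (ticks[0] in (frozen, frozen + 1)))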