
Commit

fix some concurrency bug
ppwwyyxx committed Apr 23, 2016
1 parent dceac08 commit 67f37f2
Showing 3 changed files with 57 additions and 37 deletions.
56 changes: 33 additions & 23 deletions tensorpack/predict.py
@@ -7,7 +7,6 @@
 import argparse
 from collections import namedtuple
 import numpy as np
-import bisect
 from tqdm import tqdm
 from six.moves import zip

@@ -21,23 +20,24 @@

 __all__ = ['PredictConfig', 'DatasetPredictor', 'get_predict_func']

-PredictResult = namedtuple('PredictResult', ['input', 'output'])
-
 class PredictConfig(object):
     def __init__(self, **kwargs):
         """
         The config used by `get_predict_func`.
-        :param session_config: a `tf.ConfigProto` instance to instantiate the
-            session. default to a session running 1 GPU.
+        :param session_config: a `tf.ConfigProto` instance to instantiate the session.
         :param session_init: a `utils.sessinit.SessionInit` instance to
             initialize variables of a session.
         :param input_data_mapping: Decide the mapping from each component in data
             to the input tensor, since you may not need all input variables
-            of the graph to run the graph for prediction (for example
-            the `label` input is not used if you only need probability
-            distribution).
-            It should be a list with size=len(data_point),
-            where each element is an index of the input variables each
-            component of the data point should be fed into.
+            of the Model to run the graph for prediction (for example
+            the `label` input is not used if you only need probability distribution).
+            It should be a list of int with length equal to `len(data_point)`,
+            where each element in the list defines which input variables each
+            component in the data point should be fed into.
             If not given, defaults to range(len(input_vars))
             For example, in an image classification task, the testing
@@ -46,7 +46,7 @@ def __init__(self, **kwargs):
             input_vars: [image_var, label_var]
-            the mapping should look like: ::
+            the mapping should then look like: ::
             input_data_mapping: [0] # the first component in a datapoint should map to `image_var`
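
A minimal sketch of the mapping described in this docstring (not part of the commit; only keyword arguments named in the docstring are used, and `my_session_init` stands in for a real `SessionInit` instance):

    # testing set gives (image, label) datapoints; feed only the image:
    config = PredictConfig(
        session_init=my_session_init,   # a `utils.sessinit.SessionInit`
        input_data_mapping=[0],         # datapoint[0] -> image_var; label_var is never fed
    )
    predict_func = get_predict_func(config)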
@@ -95,19 +95,19 @@ def run_input(dp):
             "Graph has {} inputs but dataset only gives {} components!".format(
                 len(input_map), len(dp))
         feed = dict(zip(input_map, dp))
-        results = sess.run(output_vars, feed_dict=feed)
-        if len(output_vars) == 1:
-            return results[0]
-        else:
-            return results
+        return sess.run(output_vars, feed_dict=feed)
     return run_input

+PredictResult = namedtuple('PredictResult', ['input', 'output'])
+
+
 class PredictWorker(multiprocessing.Process):
+    """ A worker process to run predictor on one GPU """
     def __init__(self, idx, gpuid, inqueue, outqueue, config):
+        """
+        :param idx: index of the worker
+        :param gpuid: id of the GPU to be used
+        :param inqueue: input queue to get data points from
+        :param outqueue: output queue to put results into
+        :param config: a `PredictConfig`
+        """
         super(PredictWorker, self).__init__()
         self.idx = idx
         self.gpuid = gpuid
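
The `run_input` change above also simplifies the calling convention: the function returned by `get_predict_func` now always returns the list of output values instead of unwrapping a single output. A behavioral sketch:

    func = get_predict_func(config)
    outputs = func(dp)    # always a list, in the order of the configured output vars
    prob = outputs[0]     # before this commit, func(dp) returned this directly for a single output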
@@ -132,6 +132,15 @@ def run(self):
             self.outqueue.put((tid, res))

 def DFtoQueue(ds, size, nr_consumer):
+    """
+    Build a queue that produces data from `DataFlow`, and a process
+    that fills the queue.
+    :param ds: a `DataFlow`
+    :param size: size of the queue
+    :param nr_consumer: number of consumers of the queue.
+        Will add this many `DIE` sentinels to the end of the queue.
+    :returns: (queue, process)
+    """
     q = multiprocessing.Queue(size)
     class EnqueProc(multiprocessing.Process):
         def __init__(self, ds, q, nr_consumer):
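
A hedged usage sketch for `DFtoQueue` (not part of the commit); the `(index, datapoint)` item layout and the `(DIE, ...)` sentinel shape are assumptions inferred from how the queues are consumed elsewhere in this file:

    queue, proc = DFtoQueue(ds, size=5, nr_consumer=1)
    proc.start()
    while True:
        idx, dp = queue.get()
        if idx == DIE:       # the producer is exhausted
            break
        handle(dp)           # `handle` is a hypothetical per-datapoint function
    proc.join()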
@@ -172,17 +181,15 @@ def __init__(self, config, dataset):
                        for i in range(self.nr_gpu)]
             self.result_queue = OrderedResultGatherProc(self.outqueue)

-            # run the procs
+            # setup all the procs
             self.inqueue_proc.start()
             for p in self.workers: p.start()
             self.result_queue.start()

             ensure_proc_terminate(self.workers)
             ensure_proc_terminate([self.result_queue, self.inqueue_proc])
         else:
             self.func = get_predict_func(config)

-
     def get_result(self):
         """ A generator to produce a prediction for each datapoint """
         with tqdm(total=self.ds.size()) as pbar:
@@ -191,12 +198,15 @@ def get_result(self):
                     yield PredictResult(dp, self.func(dp))
                     pbar.update()
             else:
+                die_cnt = 0
                 while True:
                     res = self.result_queue.get()
                     if res[0] != DIE:
                         yield res[1]
                     else:
-                        break
+                        die_cnt += 1
+                        if die_cnt == self.nr_gpu:
+                            break
                     pbar.update()
         self.inqueue_proc.join()
         self.inqueue_proc.terminate()
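
This hunk is the actual concurrency fix: each of the `nr_gpu` workers forwards one `DIE` sentinel when its input is exhausted, so the consumer must count sentinels and stop only after seeing all of them. Breaking on the first `DIE` could silently drop results still in flight from slower workers. A self-contained sketch of the pattern, independent of tensorpack:

    import multiprocessing as mp

    def worker(inq, outq):
        while True:
            item = inq.get()
            if item is None:        # sentinel: no more input for this worker
                outq.put(None)      # forward exactly one sentinel per worker
                return
            outq.put(item * item)

    if __name__ == '__main__':
        nr_workers = 2
        inq, outq = mp.Queue(), mp.Queue()
        procs = [mp.Process(target=worker, args=(inq, outq))
                 for _ in range(nr_workers)]
        for p in procs:
            p.start()
        for i in range(10):
            inq.put(i)
        for _ in range(nr_workers):
            inq.put(None)            # one sentinel per consumer
        die_cnt, results = 0, []
        while die_cnt < nr_workers:  # stopping at the first sentinel would drop results
            res = outq.get()
            if res is None:
                die_cnt += 1
            else:
                results.append(res)
        for p in procs:
            p.join()
        print(sorted(results))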
20 changes: 11 additions & 9 deletions tensorpack/utils/concurrency.py
@@ -4,14 +4,10 @@
 # Credit belongs to Xinyu Zhou

 import threading
-import multiprocessing, multiprocess
-from contextlib import contextmanager
-import tensorflow as tf
+import multiprocessing
 import atexit
 import bisect
 import weakref
 from six.moves import zip

-from .naming import *
-
 __all__ = ['StoppableThread', 'ensure_proc_terminate',
            'OrderedResultGatherProc', 'OrderedContainer', 'DIE']
@@ -29,9 +25,9 @@ def stopped(self):


 class DIE(object):
+    """ A placeholder class indicating the end of a queue """
     pass


 def ensure_proc_terminate(proc):
     if isinstance(proc, list):
         for p in proc:
@@ -47,11 +43,14 @@ def stop_proc_by_weak_ref(ref):
         proc.terminate()
         proc.join()

-    assert isinstance(proc, (multiprocessing.Process, multiprocess.Process))
+    assert isinstance(proc, multiprocessing.Process)
     atexit.register(stop_proc_by_weak_ref, weakref.ref(proc))


 class OrderedContainer(object):
+    """
+    Like a priority queue, but will always wait for the item with index
+    (x+1) before producing the item with index (x+2).
+    """
     def __init__(self, start=0):
         self.ranks = []
         self.data = []
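
A behavioral sketch of the contract stated in the new docstring (not from the commit; the `put(rank, value)` signature and the `(rank, value)` return of `get()` are assumptions, since only `__init__` and the `get` hunk header are visible here):

    c = OrderedContainer(start=0)
    c.put(1, 'b')       # index 1 arrives early and is held back
    c.put(0, 'a')       # index 0 arrives; items can now flow in order
    print(c.get())      # (0, 'a')
    print(c.get())      # (1, 'b')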
@@ -78,9 +77,12 @@ def get(self):


 class OrderedResultGatherProc(multiprocessing.Process):
+    """
+    Gather indexed data from a data queue, and produce results in the
+    original index-based order.
+    """
     def __init__(self, data_queue, start=0):
         super(self.__class__, self).__init__()

         self.data_queue = data_queue
         self.ordered_container = OrderedContainer(start=start)
         self.result_queue = multiprocessing.Queue()
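
A hedged sketch of the gather process at work; the `(index, result)` item layout is inferred from `get_result` in predict.py above:

    data_q = multiprocessing.Queue()
    gather = OrderedResultGatherProc(data_q, start=0)
    gather.start()
    data_q.put((1, 'result-1'))   # arrives out of order
    data_q.put((0, 'result-0'))
    print(gather.get())           # (0, 'result-0') is produced first
    print(gather.get())           # then (1, 'result-1')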
18 changes: 13 additions & 5 deletions tensorpack/utils/logger.py
@@ -57,17 +57,25 @@ def _set_file(path):
         filename=path, encoding='utf-8', mode='w')
     logger.addHandler(hdl)

-def set_logger_dir(dirname):
+def set_logger_dir(dirname, action=None):
+    """
+    Set the directory for global logging.
+    :param dirname: log directory
+    :param action: an action (k/b/d/n) to be performed. Will ask the user by default.
+    """
     global LOG_FILE, LOG_DIR
     if os.path.isdir(dirname):
         logger.warn("""\
 Directory {} exists! Please either backup/delete it, or use a new directory \
 unless you're resuming from a previous task.""".format(dirname))
         logger.info("Select Action: k (keep) / b (backup) / d (delete) / n (new):")
-        while True:
-            act = input().lower().strip()
-            if act:
-                break
+        if not action:
+            while True:
+                act = input().lower().strip()
+                if act:
+                    break
+        else:
+            act = action
         if act == 'b':
             backup_name = dirname + get_time_str()
             shutil.move(dirname, backup_name)
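
The new `action` argument lets scripts skip the interactive prompt; a small usage sketch (the import path follows the usual tensorpack layout and is an assumption here):

    from tensorpack.utils import logger

    # interactive, as before: asks k/b/d/n if the directory already exists
    logger.set_logger_dir('train_log/run1')

    # non-interactive, new in this commit: pre-select the action
    logger.set_logger_dir('train_log/run1', action='b')   # back up the old dir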
