multigpu predictor
ppwwyyxx committed Apr 23, 2016
1 parent 1dcc0e7 commit c9107ad
Showing 5 changed files with 163 additions and 17 deletions.
1 change: 1 addition & 0 deletions tensorpack/callbacks/dump.py
@@ -40,6 +40,7 @@ def __init__(self, var_name, prefix=None, map_func=None, scale=255, clip=False):
self.clip = clip

def _before_train(self):
# TODO might not work for multiGPU?
self.var = self.graph.get_tensor_by_name(self.var_name)

def _trigger_epoch(self):
7 changes: 4 additions & 3 deletions tensorpack/dataflow/prefetch.py
@@ -6,7 +6,7 @@

from six.moves import range
from .base import ProxyDataFlow
from ..utils.concurrency import ensure_procs_terminate
from ..utils.concurrency import ensure_proc_terminate
from ..utils import logger

__all__ = ['PrefetchData']
@@ -36,7 +36,8 @@ def __init__(self, ds, nr_prefetch, nr_proc=1):
"""
:param ds: a `DataFlow` instance.
:param nr_prefetch: size of the queue to hold prefetched datapoints.
:param nr_proc: number of processes to use.
:param nr_proc: number of processes to use. When larger than 1, the
order of returned datapoints is nondeterministic.
"""
super(PrefetchData, self).__init__(ds)
self._size = self.size()
@@ -45,7 +46,7 @@ def __init__(self, ds, nr_prefetch, nr_proc=1):
self.queue = multiprocessing.Queue(self.nr_prefetch)
self.procs = [PrefetchProcess(self.ds, self.queue)
for _ in range(self.nr_proc)]
ensure_procs_terminate(self.procs)
ensure_proc_terminate(self.procs)
for x in self.procs:
x.start()
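
A minimal usage sketch of PrefetchData, for context (ToyDataFlow and the import paths are illustrative assumptions, not part of this diff):

from tensorpack.dataflow import DataFlow
from tensorpack.dataflow.prefetch import PrefetchData

class ToyDataFlow(DataFlow):
    # hypothetical source yielding 100 dummy datapoints
    def size(self):
        return 100

    def get_data(self):
        for i in range(self.size()):
            yield [i]

# Four worker processes fill a queue holding up to 32 datapoints.
# With nr_proc > 1, the order in which datapoints arrive is nondeterministic.
ds = PrefetchData(ToyDataFlow(), nr_prefetch=32, nr_proc=4)
for dp in ds.get_data():
    pass  # consume prefetched datapoints here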

98 changes: 89 additions & 9 deletions tensorpack/predict.py
@@ -7,9 +7,13 @@
import argparse
from collections import namedtuple
import numpy as np
import bisect
from tqdm import tqdm
from six.moves import zip

import multiprocessing
from .utils.concurrency import ensure_proc_terminate, OrderedResultGatherProc, DIE

from .tfutils import *
from .utils import logger
from .tfutils.modelutils import describe_model
@@ -50,6 +54,7 @@ def __init__(self, **kwargs):
:param output_var_names: a list of names of the output variables to predict; the
variables can be any computable tensors in the graph.
Predicting a specific output might not require all input variables.
:param nr_gpu: default to 1. Use CUDA_VISIBLE_DEVICES to control which GPUs to use specifically.
"""
def assert_type(v, tp):
assert isinstance(v, tp), v.__class__
@@ -59,6 +64,7 @@ def assert_type(v, tp):
self.model = kwargs.pop('model')
self.input_data_mapping = kwargs.pop('input_data_mapping', None)
self.output_var_names = kwargs.pop('output_var_names')
self.nr_gpu = kwargs.pop('nr_gpu', 1)
assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
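
For reference, a hedged sketch of building such a config; MyModel, the checkpoint path, and the session_* fields are placeholders (this hunk only shows that model, input_data_mapping, output_var_names, and nr_gpu are consumed):

config = PredictConfig(
    session_config=get_default_sess_config(),  # assumed tensorpack helper
    session_init=SaverRestore('model.ckpt'),   # hypothetical checkpoint restore
    model=MyModel(),                           # your ModelDesc subclass
    input_data_mapping=[0],                    # feed component 0 of each datapoint
    output_var_names=['prob:0'],               # any computable tensors in the graph
    nr_gpu=2,                                  # one PredictWorker process per GPU
)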

def get_predict_func(config):
@@ -81,8 +87,6 @@ def get_predict_func(config):
output_vars = [tf.get_default_graph().get_tensor_by_name(get_op_var_name(n)[1])
for n in output_var_names]

describe_model()

sess = tf.Session(config=config.session_config)
config.session_init.init(sess)

@@ -101,27 +105,103 @@ def run_input(dp):

PredictResult = namedtuple('PredictResult', ['input', 'output'])

# TODO multigpu predictor

class PredictWorker(multiprocessing.Process):
def __init__(self, idx, gpuid, inqueue, outqueue, config):
super(PredictWorker, self).__init__()
self.idx = idx
self.gpuid = gpuid
self.inqueue = inqueue
self.outqueue = outqueue
self.config = config

def run(self):
os.environ['CUDA_VISIBLE_DEVICES'] = str(self.gpuid)  # env values must be strings; gpuid may be an int
G = tf.Graph() # build a graph for each process, because they don't need to share anything
with G.as_default(), tf.device('/gpu:0'):  # after CUDA_VISIBLE_DEVICES is set, the one visible GPU is device 0
self.func = get_predict_func(self.config)
if self.idx == 0:
describe_model()
while True:
tid, dp = self.inqueue.get()
if tid == DIE:
self.outqueue.put((DIE, None))
return
else:
res = PredictResult(dp, self.func(dp))
self.outqueue.put((tid, res))
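
The inqueue/outqueue handshake above is the classic poison-pill pattern: a worker that reads the DIE sentinel forwards it downstream before exiting, so consumers can count end-of-stream markers. A TensorFlow-free sketch of the same loop:

def worker_loop(inqueue, outqueue, func):
    # func stands in for any per-datapoint computation
    while True:
        tid, dp = inqueue.get()
        if tid == DIE:
            outqueue.put((DIE, None))  # propagate shutdown downstream
            return
        outqueue.put((tid, func(dp)))  # tag each result with its task id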

def DFtoQueue(ds, size, nr_consumer):
q = multiprocessing.Queue(size)
class EnqueProc(multiprocessing.Process):
def __init__(self, ds, q, nr_consumer):
super(EnqueProc, self).__init__()
self.ds = ds
self.q = q
self.nr_consumer = nr_consumer

def run(self):
for idx, dp in enumerate(self.ds.get_data()):
self.q.put((idx, dp))
print "Enqueue ends"
for _ in range(nr_consumer):
self.q.put((DIE, None))

proc = EnqueProc(ds, q, nr_consumer)
return q, proc
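
A hedged single-consumer wiring sketch for DFtoQueue (ds is any DataFlow; the queue size is arbitrary):

q, enqueue_proc = DFtoQueue(ds, size=10, nr_consumer=1)
enqueue_proc.start()
while True:
    tid, dp = q.get()
    if tid == DIE:  # with nr_consumer=1, one DIE marks the end of the stream
        break
    # process (tid, dp) here
enqueue_proc.join()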

class DatasetPredictor(object):
"""
Run a `PredictConfig` on a given `DataFlow`.
"""
def __init__(self, predict_config, dataset):
def __init__(self, config, dataset):
"""
:param predict_config: a `PredictConfig` instance.
:param config: a `PredictConfig` instance.
:param dataset: a `DataFlow` instance.
"""
assert isinstance(dataset, DataFlow)
self.ds = dataset
self.predict_func = get_predict_func(predict_config)
self.nr_gpu = config.nr_gpu
if self.nr_gpu > 1:
self.inqueue, self.inqueue_proc = DFtoQueue(self.ds, 10, self.nr_gpu)
self.outqueue = multiprocessing.Queue()
try:
gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
except KeyError:
gpus = range(self.nr_gpu)
self.workers = [PredictWorker(i, gpus[i], self.inqueue, self.outqueue, config)
for i in range(self.nr_gpu)]
self.result_queue = OrderedResultGatherProc(self.outqueue)

# run the procs
self.inqueue_proc.start()
for p in self.workers: p.start()
self.result_queue.start()

ensure_proc_terminate(self.workers)
ensure_proc_terminate([self.result_queue, self.inqueue_proc])
else:
self.func = get_predict_func(config)


def get_result(self):
""" A generator to produce prediction for each data"""
with tqdm(total=self.ds.size()) as pbar:
for dp in self.ds.get_data():
yield PredictResult(dp, self.predict_func(dp))
pbar.update()
if self.nr_gpu == 1:
for dp in self.ds.get_data():
yield PredictResult(dp, self.func(dp))
pbar.update()
else:
while True:
res = self.result_queue.get()
if res[0] != DIE:
yield res[1]
else:
break
pbar.update()
self.inqueue_proc.join()
self.inqueue_proc.terminate()
for p in self.workers:
p.join()
p.terminate()
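
Putting the pieces together, a typical call site might look like this sketch (config as in the PredictConfig example above; MyDataFlow is a placeholder):

pred = DatasetPredictor(config, MyDataFlow())
for res in pred.get_result():
    # res is a PredictResult; even with nr_gpu > 1, results arrive
    # in input order thanks to OrderedResultGatherProc.
    inputs, outputs = res.input, res.output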

def get_all_result(self):
"""
2 changes: 2 additions & 0 deletions tensorpack/train/trainer.py
@@ -116,9 +116,11 @@ def get_model_inputs():
# get gradients to update:
if self.config.nr_tower > 1:
logger.info("Training a model of {} tower".format(self.config.nr_tower))

# to avoid repeated summaries from each device
coll_keys = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_VARS_KEY]
kept_summaries = {}

grad_list = []
for i in range(self.config.nr_tower):
with tf.device('/gpu:{}'.format(i)), \
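
The new lines snapshot the summary collections once tower 0 is built and restore them after all towers are done, so summary ops are not duplicated per device. A hedged sketch of that pattern (nr_tower and build_tower are placeholders, and it assumes tf.get_collection returns the live collection list in this TF version, as the surrounding code relies on):

import copy

coll_keys = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_VARS_KEY]
kept_summaries = {}
for i in range(nr_tower):
    with tf.device('/gpu:{}'.format(i)):
        build_tower(i)  # placeholder for the per-tower graph construction
    if i == 0:
        for k in coll_keys:  # keep only tower-0's summaries
            kept_summaries[k] = copy.copy(tf.get_collection(k))
for k in coll_keys:  # drop summaries accumulated by towers 1..n-1
    del tf.get_collection(k)[:]
    tf.get_collection(k).extend(kept_summaries[k])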
72 changes: 67 additions & 5 deletions tensorpack/utils/concurrency.py
@@ -1,9 +1,10 @@
# -*- coding: UTF-8 -*-
# File: concurrency.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Credit belongs to Xinyu Zhou

import threading
import multiprocessing
import multiprocessing, multiprocess  # `multiprocess` is the dill-based fork of multiprocessing
import bisect  # used by OrderedContainer below
from contextlib import contextmanager
import tensorflow as tf
import atexit
@@ -12,6 +13,9 @@

from .naming import *

__all__ = ['StoppableThread', 'ensure_proc_terminate',
'OrderedResultGatherProc', 'OrderedContainer', 'DIE']

class StoppableThread(threading.Thread):
def __init__(self):
super(StoppableThread, self).__init__()
@@ -24,7 +28,16 @@ def stopped(self):
return self._stop.isSet()


class DIE(object):
pass


def ensure_proc_terminate(proc):
if isinstance(proc, list):
for p in proc:
ensure_proc_terminate(p)
return

def stop_proc_by_weak_ref(ref):
proc = ref()
if proc is None:
@@ -34,9 +47,58 @@ def stop_proc_by_weak_ref(ref):
proc.terminate()
proc.join()

assert isinstance(proc, multiprocessing.Process)
assert isinstance(proc, (multiprocessing.Process, multiprocess.Process))
atexit.register(stop_proc_by_weak_ref, weakref.ref(proc))

def ensure_procs_terminate(procs):
for p in procs:
ensure_proc_terminate(p)
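
With the list handling added above, callers can register one process or a list of them; the weakref indirection keeps the atexit hook from holding process objects alive. For instance:

procs = [multiprocessing.Process(target=work) for _ in range(4)]  # `work` is a placeholder
ensure_proc_terminate(procs)  # equivalent to calling it on each process
for p in procs:
    p.start()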

class OrderedContainer(object):
def __init__(self, start=0):
self.ranks = []
self.data = []
self.wait_for = start

def put(self, rank, val):
idx = bisect.bisect(self.ranks, rank)
self.ranks.insert(idx, rank)
self.data.insert(idx, val)

def has_next(self):
if len(self.ranks) == 0:
return False
return self.ranks[0] == self.wait_for

def get(self):
assert self.has_next()
ret = self.data[0]
rank = self.ranks[0]
del self.ranks[0]
del self.data[0]
self.wait_for += 1
return rank, ret
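
OrderedContainer is a small reordering buffer: items may be put in any order, but get() releases them strictly by rank, starting from `start`. For example:

oc = OrderedContainer(start=0)
oc.put(1, 'b')               # arrives early, held back
assert not oc.has_next()     # still waiting for rank 0
oc.put(0, 'a')
assert oc.get() == (0, 'a')  # released in rank order
assert oc.get() == (1, 'b')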


class OrderedResultGatherProc(multiprocessing.Process):
def __init__(self, data_queue, start=0):
super(OrderedResultGatherProc, self).__init__()

self.data_queue = data_queue
self.ordered_container = OrderedContainer(start=start)
self.result_queue = multiprocessing.Queue()

def run(self):
try:
while True:
task_id, data = self.data_queue.get()
if task_id == DIE:
self.result_queue.put((task_id, data))
else:
self.ordered_container.put(task_id, data)
while self.ordered_container.has_next():
self.result_queue.put(self.ordered_container.get())
except Exception:
import traceback
traceback.print_exc()
raise

def get(self):
return self.result_queue.get()
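
An end-to-end sketch of the gatherer: results are fed in out of order, and get() returns them sorted by task id, with the DIE sentinel forwarded unchanged (the process keeps running afterwards; ensure_proc_terminate cleans it up at interpreter exit):

data_q = multiprocessing.Queue()
gather = OrderedResultGatherProc(data_q)
ensure_proc_terminate(gather)
gather.start()

data_q.put((1, 'one'))  # deliberately out of order
data_q.put((0, 'zero'))
data_q.put((DIE, None))

assert gather.get() == (0, 'zero')
assert gather.get() == (1, 'one')
assert gather.get() == (DIE, None)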
