Skip to content

Commit

Permalink
StagingInputWrapper takes a list of int
Browse files Browse the repository at this point in the history
  • Loading branch information
ppwwyyxx committed Oct 18, 2017
1 parent 4c5cdf9 commit e649385
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 15 deletions.
2 changes: 1 addition & 1 deletion examples/GAN/GAN.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def __init__(self, config):
raw_devices = ['/gpu:{}'.format(k) for k in config.tower]

# setup input
input = StagingInputWrapper(QueueInput(config.dataflow), raw_devices)
input = StagingInputWrapper(QueueInput(config.dataflow), config.tower)
model = config.model
cbs = input.setup(model.get_inputs_desc())
config.callbacks.extend(cbs)
Expand Down
7 changes: 6 additions & 1 deletion tensorpack/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# File: __init__.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>

import os as _os

from tensorpack.libinfo import __version__, _HAS_TF

Expand All @@ -15,7 +16,11 @@
from tensorpack.callbacks import *
from tensorpack.tfutils import *

from tensorpack.train import *
# In development. Default to v1
if _os.environ.get('TENSORPACK_TRAIN_API', 'v1') == 'v2':
from tensorpack.trainv2 import *
else:
from tensorpack.train import *
from tensorpack.graph_builder import *
from tensorpack.input_source import *
from tensorpack.predict import *
20 changes: 14 additions & 6 deletions tensorpack/input_source/input_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from ..tfutils.tower import get_current_tower_context
from ..utils import logger
from ..utils.concurrency import ShareSessionThread
from ..utils.develop import log_deprecated
from ..callbacks.base import Callback
from ..callbacks.graph import RunOp

Expand Down Expand Up @@ -457,7 +458,8 @@ def _get_input_tensors(self):

class StagingInputWrapper(FeedfreeInput):
"""
A wrapper around a feedfree input, to prefetch it in StagingArea (usually on GPUs).
A wrapper around a feedfree input,
to prefetch the input in StagingArea (on GPUs).
"""
class StagingCallback(Callback):
"""
Expand All @@ -478,16 +480,22 @@ def _before_train(self):
def _before_run(self, ctx):
return self.fetches

def __init__(self, input, devices, nr_stage=5):
def __init__(self, input, towers, nr_stage=5):
"""
Args:
input: a :class:`FeedfreeInput`
devices: list of devices to be used for each training tower
nr_stage: number of elements to prefetch
input (FeedfreeInput):
towers ([int]): list of GPU ids to prefetch on.
nr_stage: number of elements to prefetch on each GPU.
"""
assert isinstance(input, FeedfreeInput), input
self._input = input
self._devices = devices
if not isinstance(towers[0], int):
# API changed
log_deprecated("StagingInputWrapper(devices=)", "Use (towers=) instead!", "2018-01-31")
self._devices = towers
else:
self._devices = ['/gpu:{}'.format(k) for k in towers]

self._nr_stage = nr_stage
self._areas = []
self._stage_ops = []
Expand Down
3 changes: 1 addition & 2 deletions tensorpack/train/multigpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,7 @@ def apply_prefetch_policy(config, gpu_prefetch=True):

# seem to only improve on >1 GPUs
if not isinstance(config.data, (StagingInputWrapper, DummyConstantInput)):
devices = ['/gpu:{}'.format(k) for k in config.tower]
config.data = StagingInputWrapper(config.data, devices)
config.data = StagingInputWrapper(config.data, config.tower)


class SyncMultiGPUTrainerParameterServer(Trainer):
Expand Down
5 changes: 0 additions & 5 deletions tensorpack/train/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,4 @@ def QueueInputTrainer(config, input_queue=None):
else:
config.data = QueueInput(config.dataflow, input_queue)
config.dataflow = None

# debug
# from tensorpack.train.input_source import StagingInputWrapper, DummyConstantInput
# config.data = StagingInputWrapper(config.data, ['/gpu:0'])
# config.data = DummyConstantInput([[128,224,224,3], [128]])
return SimpleTrainer(config)

0 comments on commit e649385

Please sign in to comment.