Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
ppwwyyxx committed Mar 6, 2020
1 parent f128a5c commit c2d99a4
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 6 deletions.
3 changes: 2 additions & 1 deletion tensorpack/dataflow/imgaug/imgproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ class Hue(PhotometricAugmentor):
def __init__(self, range=(0, 180), rgb=True):
"""
Args:
range(list or tuple): range from which the applied hue offset is selected (maximum [-90,90] or [0,180])
range(list or tuple): range from which the applied hue offset is selected
(maximum range can be [-90,90] for both uint8 and float32)
rgb (bool): whether input is RGB or BGR.
"""
super(Hue, self).__init__()
Expand Down
14 changes: 11 additions & 3 deletions tensorpack/predict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import tensorflow as tf

from ..input_source import PlaceholderInput
from ..tfutils.common import get_tensors_by_names
from ..tfutils.common import get_tensors_by_names, get_op_tensor_name
from ..tfutils.tower import PredictTowerContext

__all__ = ['PredictorBase',
Expand All @@ -34,6 +34,9 @@ def __call__(self, *dp):
.. code-block:: python
predictor(e1, e2)
Returns:
list[array]: list of outputs
"""
output = self._do_call(dp)
if self.return_input:
Expand Down Expand Up @@ -98,9 +101,14 @@ def __init__(self, input_tensors, output_tensors,
will use the default session at the first call.
Note that in TensorFlow, default session is thread-local.
"""
def normalize_name(t):
if isinstance(t, six.string_types):
return get_op_tensor_name(t)[1]
return t

self.return_input = return_input
self.input_tensors = input_tensors
self.output_tensors = output_tensors
self.input_tensors = [normalize_name(x) for x in input_tensors]
self.output_tensors = [normalize_name(x) for x in output_tensors]
self.sess = sess

if sess is not None:
Expand Down
3 changes: 1 addition & 2 deletions tensorpack/train/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,8 +396,7 @@ def __init__(self, average=True, compression=None):
compression: `hvd.Compression.fp16` or `hvd.Compression.none`
"""
if 'pyarrow' in sys.modules:
logger.warn("Horovod and pyarrow may conflict due to pyarrow bugs. "
"Uninstall pyarrow and use msgpack instead.")
logger.warn("Horovod and pyarrow may conflict due to pyarrow bugs.")
# lazy import
import horovod.tensorflow as hvd
import horovod
Expand Down

0 comments on commit c2d99a4

Please sign in to comment.