Commit 9f1af4c: fix argscope bug
ppwwyyxx committed Apr 14, 2016
1 parent e90acf2
Showing 4 changed files with 8 additions and 5 deletions.
2 changes: 1 addition & 1 deletion tensorpack/callbacks/param.py
@@ -49,7 +49,7 @@ def get_current_value(self):
         """
         ret = self._get_current_value()
         if ret is not None and ret != self.last_value:
-            logger.info("{} at epoch {} is changed to {}".format(
+            logger.info("{} at epoch {} will change to {}".format(
                 self.op_name, self.epoch_num, ret))
             self.last_value = ret
         return ret
3 changes: 2 additions & 1 deletion tensorpack/models/_common.py
@@ -5,6 +5,7 @@
 import tensorflow as tf
 from functools import wraps
 import six
+import copy

 from ..tfutils import *
 from ..tfutils.modelutils import *
@@ -34,7 +35,7 @@ def wrapped_func(*args, **kwargs):
            inputs = args[0]

            # update from current argument scope
-            actual_args = get_arg_scope()[func.__name__]
+            actual_args = copy.copy(get_arg_scope()[func.__name__])
            actual_args.update(kwargs)

            with tf.variable_scope(name) as scope:
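Why the copy matters: get_arg_scope() hands back the per-layer dict of default kwargs that is shared by every call inside the scope, so updating it in place would let one call's keyword arguments leak into all later calls. A minimal standalone sketch of the failure and the fix (plain Python; arg_scope and call_layer are hypothetical stand-ins, not tensorpack's real code):

import copy

# Stand-in for tensorpack's argument scope: layer name -> default kwargs
# shared by every call made inside the scope.
arg_scope = {'Conv2D': {'nl': 'relu', 'use_bias': True}}

def call_layer(name, **kwargs):
    # Buggy version (before this commit): actual_args = arg_scope[name]
    # would mutate the shared dict, so kwargs from one call leak into the next.
    actual_args = copy.copy(arg_scope[name])   # the fix: copy before updating
    actual_args.update(kwargs)
    return actual_args

print(call_layer('Conv2D', use_bias=False))  # {'nl': 'relu', 'use_bias': False}
print(call_layer('Conv2D'))                  # defaults intact: {'nl': 'relu', 'use_bias': True}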
6 changes: 4 additions & 2 deletions tensorpack/models/nonlin.py
@@ -62,12 +62,14 @@ def LeakyReLU(x, alpha, name=None):
     return tf.mul(x, 0.5, name=name)

+
-def BNReLU(is_training):
+def BNReLU(is_training, **kwargs):
     """
     :param is_traning: boolean
+    :param kwargs: args for BatchNorm
     :returns: a activation function that performs BN + ReLU (a too common combination)
     """
     def BNReLU(x, name=None):
-        x = BatchNorm('bn', x, is_training)
+        x = BatchNorm('bn', x, is_training, **kwargs)
         x = tf.nn.relu(x, name=name)
         return x
     return BNReLU
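The nonlin.py change lets the BNReLU factory forward extra keyword arguments to the BatchNorm call it wraps. A standalone sketch of that closure pattern (the batch_norm/relu stand-ins and the 'decay' kwarg are illustrative only, not tensorpack's real API):

def batch_norm(name, x, is_training, **kwargs):
    # Stand-in: just show which extra kwargs reached the call.
    print("BatchNorm got:", kwargs)
    return x

def relu(x, name=None):
    return max(x, 0.0)

def BNReLU(is_training, **kwargs):
    # Factory: capture is_training and any BatchNorm kwargs in a closure,
    # return an activation function that applies BN then ReLU.
    def _act(x, name=None):
        x = batch_norm('bn', x, is_training, **kwargs)
        return relu(x, name=name)
    return _act

act = BNReLU(True, decay=0.9)   # 'decay' is a hypothetical BatchNorm kwarg
print(act(-1.5))                # prints the forwarded kwargs, then 0.0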
2 changes: 1 addition & 1 deletion tensorpack/train/base.py
@@ -71,7 +71,7 @@ def main_loop(self):
         # some final operations that might modify the graph
         logger.info("Preparing for training...")
         self._init_summary()
-        get_global_step_var()
+        get_global_step_var()  # ensure there is such var, before finalizing the graph
         callbacks = self.config.callbacks
         callbacks.before_train(self)
         self.config.session_init.init(self.sess)
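The base.py comment records why get_global_step_var() is called at this point: it creates a variable, and the trainer finalizes the graph before the main loop, after which no new ops can be added. A small sketch of that constraint (TF1-style graph API; assumes an old TensorFlow where Graph.finalize() and graph-mode tf.Variable behave this way):

import tensorflow as tf

g = tf.Graph()
with g.as_default():
    # Any variable the loop needs (e.g. the global step) must exist now.
    step = tf.Variable(0, trainable=False, name='global_step')
g.finalize()   # graph becomes read-only, as tensorpack does before training

with g.as_default():
    try:
        tf.Variable(1, name='too_late')   # adding ops after finalize fails
    except RuntimeError as err:
        print(err)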
