Skip to content

Commit

Permalink
batch norm support float16
Browse files Browse the repository at this point in the history
  • Loading branch information
zsdonghao committed Dec 2, 2017
1 parent fc7f62a commit c99a6a2
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
13 changes: 10 additions & 3 deletions tensorlayer/layers.py
Expand Up @@ -1900,7 +1900,7 @@ def __init__(
):
if tf.__version__ < "1.4":
raise Exception("Deformable CNN layer requires tensrflow 1.4 or higher version")

Layer.__init__(self, name=name)
self.inputs = layer.outputs
self.offset_layer = offset_layer
Expand Down Expand Up @@ -3099,6 +3099,7 @@ class BatchNormLayer(Layer):
The initializer for initializing beta
gamma_init : gamma initializer
The initializer for initializing gamma
dtype : tf.float32 (default) or tf.float16
name : a string or None
An optional name to attach to this layer.
Expand All @@ -3116,6 +3117,7 @@ def __init__(
is_train = False,
beta_init = tf.zeros_initializer,
gamma_init = tf.random_normal_initializer(mean=1.0, stddev=0.002), # tf.ones_initializer,
dtype = tf.float32,
name ='batchnorm_layer',
):
Layer.__init__(self, name=name)
Expand All @@ -3136,10 +3138,13 @@ def __init__(
beta_init = beta_init()
beta = tf.get_variable('beta', shape=params_shape,
initializer=beta_init,
dtype=dtype,
trainable=is_train)#, restore=restore)

gamma = tf.get_variable('gamma', shape=params_shape,
initializer=gamma_init, trainable=is_train,
initializer=gamma_init,
dtype=dtype,
trainable=is_train,
)#restore=restore)

## 2.
Expand All @@ -3150,10 +3155,12 @@ def __init__(
moving_mean = tf.get_variable('moving_mean',
params_shape,
initializer=moving_mean_init,
trainable=False,)# restore=restore)
dtype=dtype,
trainable=False)# restore=restore)
moving_variance = tf.get_variable('moving_variance',
params_shape,
initializer=tf.constant_initializer(1.),
dtype=dtype,
trainable=False,)# restore=restore)

## 3.
Expand Down
2 changes: 1 addition & 1 deletion tensorlayer/prepro.py
Expand Up @@ -281,7 +281,7 @@ def crop_multi(x, wrg, hrg, is_random=False, row_index=0, col_index=1, channel_i
return np.asarray(results)

# flip
def flip_axis(x, axis, is_random=False):
def flip_axis(x, axis=1, is_random=False):
"""Flip the axis of an image, such as flip left and right, up and down, randomly or non-randomly,
Parameters
Expand Down

0 comments on commit c99a6a2

Please sign in to comment.