Skip to content

Commit

Permalink
batch_norm and xavier init added
Browse files Browse the repository at this point in the history
  • Loading branch information
yunjey committed Dec 2, 2016
1 parent 230f3f4 commit 9ad74fd
Showing 1 changed file with 18 additions and 2 deletions.
20 changes: 18 additions & 2 deletions core/model.py
Expand Up @@ -48,14 +48,14 @@ def __init__(self, word_to_idx, dim_feature=[196, 512], dim_embed=512, dim_hidde
self._start = word_to_idx['<START>']
self._null = word_to_idx['<NULL>']

self.weight_initializer = tf.random_normal_initializer(mean=0.0, stddev=0.01)
self.weight_initializer = tf.contrib.layers.xavier_initializer()
self.const_initializer = tf.constant_initializer(0.0)
self.emb_initializer = tf.random_uniform_initializer(minval=-1.0, maxval=1.0)

# Place holder for features and captions
self.features = tf.placeholder(tf.float32, [None, self.L, self.D])
self.captions = tf.placeholder(tf.int32, [None, self.T + 1])

def _get_initial_lstm(self, features):
with tf.variable_scope('initial_lstm'):
features_mean = tf.reduce_mean(features, 1)
Expand Down Expand Up @@ -126,6 +126,15 @@ def _decode_lstm(self, x, h, context, dropout=False, reuse=False):
h_logits = tf.nn.dropout(h_logits, 0.5)
out_logits = tf.matmul(h_logits, w_out) + b_out
return out_logits

def _batch_norm(self, x, mode='train', name=None):
    """Apply batch normalization to `x` via tf.contrib.layers.batch_norm.

    Args:
        x: Input tensor to normalize.
        mode: 'train' runs the layer in training mode (is_training=True);
            any other value runs it in inference mode.
        name: Optional prefix for the variable scope. The scope is the
            prefix immediately followed by 'batch_norm' (no separator), or
            just 'batch_norm' when name is None.

    Returns:
        The tensor returned by tf.contrib.layers.batch_norm.
    """
    # The declared default name=None used to crash on `None + 'batch_norm'`;
    # fall back to the bare scope name instead. The concatenation without a
    # separator is kept on purpose so scope (and therefore variable) names
    # for existing callers stay exactly as before.
    scope = 'batch_norm' if name is None else name + 'batch_norm'
    return tf.contrib.layers.batch_norm(inputs=x,
                                        decay=0.95,
                                        center=True,
                                        scale=True,
                                        is_training=(mode == 'train'),
                                        # None: no update ops are collected;
                                        # per the TF docs the moving stats
                                        # are then updated in place.
                                        updates_collections=None,
                                        scope=scope)

def build_model(self):
features = self.features
Expand All @@ -136,6 +145,10 @@ def build_model(self):
captions_out = captions[:, 1:]
mask = tf.to_float(tf.not_equal(captions_out, self._null))


# batch normalize feature vectors
features = self._batch_norm(features, mode='train', name='conv_features')

c, h = self._get_initial_lstm(features=features)
x = self._word_embedding(inputs=captions_in)
features_proj = self._project_features(features=features)
Expand Down Expand Up @@ -168,6 +181,9 @@ def build_model(self):
def build_sampler(self, max_len=20):
features = self.features

# batch normalize feature vectors
features = self._batch_norm(features, mode='test', name='conv_features')

c, h = self._get_initial_lstm(features=features)
features_proj = self._project_features(features=features)

Expand Down

0 comments on commit 9ad74fd

Please sign in to comment.