Commit 2d7bd1d

alexlee-gk authored and nealwu committed
Fixes for compatibility with TF 1.0 and python 3.
1 parent 5e38011 commit 2d7bd1d

2 files changed: +15 -19 lines

video_prediction/lstm_ops.py

Lines changed: 1 addition & 7 deletions
@@ -38,17 +38,11 @@ def init_state(inputs,
   if inputs is not None:
     # Handle both the dynamic shape as well as the inferred shape.
     inferred_batch_size = inputs.get_shape().with_rank_at_least(1)[0]
-    batch_size = tf.shape(inputs)[0]
     dtype = inputs.dtype
   else:
     inferred_batch_size = 0
-    batch_size = 0
-
   initial_state = state_initializer(
-      tf.stack([batch_size] + state_shape),
-      dtype=dtype)
-  initial_state.set_shape([inferred_batch_size] + state_shape)
-
+      [inferred_batch_size] + state_shape, dtype=dtype)
   return initial_state

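The change above drops the dynamic batch size (tf.shape(inputs)[0]) and builds the initial state from the statically inferred batch size alone, so the state initializer receives a plain shape list and the follow-up set_shape call is no longer needed. A minimal sketch of the two shape mechanisms, not part of the commit, assuming a TF 1.x graph; the placeholder shape is illustrative:

import tensorflow as tf

x = tf.placeholder(tf.float32, [8, 10])

# Static shape, known while the graph is being built (a tf.Dimension).
inferred_batch_size = x.get_shape().with_rank_at_least(1)[0]

# Dynamic shape, only known when the graph runs (an int32 Tensor).
dynamic_batch_size = tf.shape(x)[0]

# With the static size, the initializer can be handed a plain shape list, as
# the diff does with [inferred_batch_size] + state_shape, instead of
# tf.stack-ing the dynamic size and calling set_shape afterwards.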
video_prediction/prediction_train.py

Lines changed: 14 additions & 12 deletions
@@ -103,21 +103,24 @@ def __init__(self,
                actions=None,
                states=None,
                sequence_length=None,
-               reuse_scope=None):
+               reuse_scope=None,
+               prefix=None):
 
     if sequence_length is None:
       sequence_length = FLAGS.sequence_length
 
-    self.prefix = prefix = tf.placeholder(tf.string, [])
+    if prefix is None:
+      prefix = tf.placeholder(tf.string, [])
+    self.prefix = prefix
     self.iter_num = tf.placeholder(tf.float32, [])
     summaries = []
 
     # Split into timesteps.
-    actions = tf.split(axis=1, num_or_size_splits=actions.get_shape()[1], value=actions)
+    actions = tf.split(axis=1, num_or_size_splits=int(actions.get_shape()[1]), value=actions)
     actions = [tf.squeeze(act) for act in actions]
-    states = tf.split(axis=1, num_or_size_splits=states.get_shape()[1], value=states)
+    states = tf.split(axis=1, num_or_size_splits=int(states.get_shape()[1]), value=states)
     states = [tf.squeeze(st) for st in states]
-    images = tf.split(axis=1, num_or_size_splits=images.get_shape()[1], value=images)
+    images = tf.split(axis=1, num_or_size_splits=int(images.get_shape()[1]), value=images)
     images = [tf.squeeze(img) for img in images]
 
     if reuse_scope is None:
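The int(...) casts in the hunk above address a TF 1.0 / Python 3 issue: get_shape()[1] yields a tf.Dimension, while tf.split expects a plain Python int (or list of ints) for num_or_size_splits. A minimal sketch, not part of the commit, with a hypothetical 5-D image batch:

import tensorflow as tf

images = tf.placeholder(tf.float32, [16, 10, 64, 64, 3])  # batch, time, H, W, C

num_steps = int(images.get_shape()[1])                # Dimension(10) -> 10
image_list = tf.split(axis=1, num_or_size_splits=num_steps, value=images)
image_list = [tf.squeeze(img) for img in image_list]  # ten [16, 64, 64, 3] tensors

The new prefix argument in the same hunk keeps the old string placeholder as a fallback (prefix=None) while letting callers pass a constant string, which is what main() does in the hunks below.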
@@ -183,17 +186,18 @@ def __init__(self,
 
 def main(unused_argv):
 
-  print 'Constructing models and inputs.'
+  print('Constructing models and inputs.')
   with tf.variable_scope('model', reuse=None) as training_scope:
     images, actions, states = build_tfrecord_input(training=True)
-    model = Model(images, actions, states, FLAGS.sequence_length)
+    model = Model(images, actions, states, FLAGS.sequence_length,
+                  prefix='train')
 
   with tf.variable_scope('val_model', reuse=None):
     val_images, val_actions, val_states = build_tfrecord_input(training=False)
     val_model = Model(val_images, val_actions, val_states,
-                      FLAGS.sequence_length, training_scope)
+                      FLAGS.sequence_length, training_scope, prefix='val')
 
-  print 'Constructing saver.'
+  print('Constructing saver.')
   # Make saver.
   saver = tf.train.Saver(
       tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES), max_to_keep=0)
@@ -214,8 +218,7 @@ def main(unused_argv):
   # Run training.
   for itr in range(FLAGS.num_iterations):
     # Generate new batch of data.
-    feed_dict = {model.prefix: 'train',
-                 model.iter_num: np.float32(itr),
+    feed_dict = {model.iter_num: np.float32(itr),
                  model.lr: FLAGS.learning_rate}
     cost, _, summary_str = sess.run([model.loss, model.train_op, model.summ_op],
                                     feed_dict)
@@ -226,7 +229,6 @@ def main(unused_argv):
     if (itr) % VAL_INTERVAL == 2:
       # Run through validation set.
       feed_dict = {val_model.lr: 0.0,
-                   val_model.prefix: 'val',
                    val_model.iter_num: np.float32(itr)}
      _, val_summary_str = sess.run([val_model.train_op, val_model.summ_op],
                                     feed_dict)
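Because the prefix is now fixed at construction time, the training and validation feed_dicts above no longer carry a prefix entry. A minimal sketch of the updated usage, not part of the commit, assuming it runs inside prediction_train.py where Model, FLAGS, build_tfrecord_input, np and tf are already in scope; the single training step is illustrative:

with tf.variable_scope('model', reuse=None) as training_scope:
  images, actions, states = build_tfrecord_input(training=True)
  model = Model(images, actions, states, FLAGS.sequence_length, prefix='train')

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  tf.train.start_queue_runners(sess)  # assumed: the TFRecord input pipeline uses queues
  feed_dict = {model.iter_num: np.float32(0),  # no model.prefix entry any more
               model.lr: FLAGS.learning_rate}
  cost, _, summary_str = sess.run([model.loss, model.train_op, model.summ_op],
                                  feed_dict)

Validation works the same way with prefix='val', so the only remaining run-time feeds are iter_num and lr.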
