@@ -103,21 +103,24 @@ def __init__(self,
103103 actions = None ,
104104 states = None ,
105105 sequence_length = None ,
106- reuse_scope = None ):
106+ reuse_scope = None ,
107+ prefix = None ):
107108
108109 if sequence_length is None :
109110 sequence_length = FLAGS .sequence_length
110111
111- self .prefix = prefix = tf .placeholder (tf .string , [])
112+ if prefix is None :
113+ prefix = tf .placeholder (tf .string , [])
114+ self .prefix = prefix
112115 self .iter_num = tf .placeholder (tf .float32 , [])
113116 summaries = []
114117
115118 # Split into timesteps.
116- actions = tf .split (axis = 1 , num_or_size_splits = actions .get_shape ()[1 ], value = actions )
119+ actions = tf .split (axis = 1 , num_or_size_splits = int ( actions .get_shape ()[1 ]) , value = actions )
117120 actions = [tf .squeeze (act ) for act in actions ]
118- states = tf .split (axis = 1 , num_or_size_splits = states .get_shape ()[1 ], value = states )
121+ states = tf .split (axis = 1 , num_or_size_splits = int ( states .get_shape ()[1 ]) , value = states )
119122 states = [tf .squeeze (st ) for st in states ]
120- images = tf .split (axis = 1 , num_or_size_splits = images .get_shape ()[1 ], value = images )
123+ images = tf .split (axis = 1 , num_or_size_splits = int ( images .get_shape ()[1 ]) , value = images )
121124 images = [tf .squeeze (img ) for img in images ]
122125
123126 if reuse_scope is None :
@@ -183,17 +186,18 @@ def __init__(self,
183186
184187def main (unused_argv ):
185188
186- print 'Constructing models and inputs.'
189+ print ( 'Constructing models and inputs.' )
187190 with tf .variable_scope ('model' , reuse = None ) as training_scope :
188191 images , actions , states = build_tfrecord_input (training = True )
189- model = Model (images , actions , states , FLAGS .sequence_length )
192+ model = Model (images , actions , states , FLAGS .sequence_length ,
193+ prefix = 'train' )
190194
191195 with tf .variable_scope ('val_model' , reuse = None ):
192196 val_images , val_actions , val_states = build_tfrecord_input (training = False )
193197 val_model = Model (val_images , val_actions , val_states ,
194- FLAGS .sequence_length , training_scope )
198+ FLAGS .sequence_length , training_scope , prefix = 'val' )
195199
196- print 'Constructing saver.'
200+ print ( 'Constructing saver.' )
197201 # Make saver.
198202 saver = tf .train .Saver (
199203 tf .get_collection (tf .GraphKeys .GLOBAL_VARIABLES ), max_to_keep = 0 )
@@ -214,8 +218,7 @@ def main(unused_argv):
214218 # Run training.
215219 for itr in range (FLAGS .num_iterations ):
216220 # Generate new batch of data.
217- feed_dict = {model .prefix : 'train' ,
218- model .iter_num : np .float32 (itr ),
221+ feed_dict = {model .iter_num : np .float32 (itr ),
219222 model .lr : FLAGS .learning_rate }
220223 cost , _ , summary_str = sess .run ([model .loss , model .train_op , model .summ_op ],
221224 feed_dict )
@@ -226,7 +229,6 @@ def main(unused_argv):
226229 if (itr ) % VAL_INTERVAL == 2 :
227230 # Run through validation set.
228231 feed_dict = {val_model .lr : 0.0 ,
229- val_model .prefix : 'val' ,
230232 val_model .iter_num : np .float32 (itr )}
231233 _ , val_summary_str = sess .run ([val_model .train_op , val_model .summ_op ],
232234 feed_dict )
0 commit comments