@@ -88,7 +88,11 @@ def get_train_ops(self, learning_rate=0.1, k=1, persistent=None):
8888 # Compute the positive phase
8989 ph_mean , ph_sample = self .sample_h_given_v (self .input )
9090 # The old state of the chain
91- chain_start = ph_sample
91+ if persistent is None :
92+ chain_start = ph_sample
93+ else :
94+ chain_start = persistent
95+
9296 # Use tf.while_loop to do the CD-k
9397 cond = lambda i , nv_mean , nv_sample , nh_mean , nh_sample : i < k
9498 body = lambda i , nv_mean , nv_sample , nh_mean , nh_sample : (i + 1 , ) + self .gibbs_hvh (nh_sample )
@@ -103,7 +107,11 @@ def get_train_ops(self, learning_rate=0.1, k=1, persistent=None):
103107 new_W = tf .assign (self .W , update_W )
104108 new_vbias = tf .assign (self .vbias , update_vbias )
105109 new_hbias = tf .assign (self .hbias , update_hbias )
106- return [new_W , new_vbias , new_hbias ] # use for training
110+ if persistent is not None :
111+ new_persistent = [tf .assign (persistent , nh_sample )]
112+ else :
113+ new_persistent = []
114+ return [new_W , new_vbias , new_hbias ] + new_persistent # use for training
107115
108116 def get_reconstruction_cost (self ):
109117 """Compute the cross-entropy of the original input and the reconstruction"""
@@ -126,17 +134,19 @@ def get_reconstruction_cost(self):
126134 rbm = RBM (x , n_visiable = n_visiable , n_hidden = n_hidden )
127135
128136 learning_rate = 0.1
137+ batch_size = 20
129138 cost = rbm .get_reconstruction_cost ()
130- train_ops = rbm .get_train_ops (learning_rate = learning_rate , k = 1 )
139+ # Create the persistent variable
140+ persistent_chain = tf .Variable (tf .zeros ([batch_size , n_hidden ]), dtype = tf .float32 )
141+ train_ops = rbm .get_train_ops (learning_rate = learning_rate , k = 15 , persistent = persistent_chain )
131142 init = tf .global_variables_initializer ()
132143
133144 output_folder = "rbm_plots"
134145 if not os .path .isdir (output_folder ):
135146 os .makedirs (output_folder )
136147 os .chdir (output_folder )
137148
138- training_epochs = 30
139- batch_size = 20
149+ training_epochs = 15
140150 display_step = 1
141151 print ("Start training..." )
142152
@@ -162,7 +172,7 @@ def get_reconstruction_cost(self):
162172 img_shape = (28 , 28 ),
163173 tile_shape = (10 , 10 ),
164174 tile_spacing = (1 , 1 )))
165- image .save ("filters_at_epoch_ {0}.png" .format (epoch ))
175+ image .save ("10filters_at_epoch_ {0}.png" .format (epoch ))
166176
167177 end_time = timeit .default_timer ()
168178 training_time = end_time - start_time
@@ -208,7 +218,7 @@ def get_reconstruction_cost(self):
208218 tile_shape = (1 , n_chains ),
209219 tile_spacing = (1 , 1 ))
210220 image = Image .fromarray (image_data )
211- image .save ("original_and_ {0}samples.png" .format (n_samples ))
221+ image .save ("10original_and_ {0}samples.png" .format (n_samples ))
212222
213223
214224
0 commit comments