Skip to content

Commit be30848

Browse files
authored
Add files via upload
1 parent 61b9f9e commit be30848

File tree

1 file changed

+17
-7
lines changed

1 file changed

+17
-7
lines changed

models/rbm.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,11 @@ def get_train_ops(self, learning_rate=0.1, k=1, persistent=None):
8888
# Compute the positive phase
8989
ph_mean, ph_sample = self.sample_h_given_v(self.input)
9090
# The old state of the chain
91-
chain_start = ph_sample
91+
if persistent is None:
92+
chain_start = ph_sample
93+
else:
94+
chain_start = persistent
95+
9296
# Use tf.while_loop to do the CD-k
9397
cond = lambda i, nv_mean, nv_sample, nh_mean, nh_sample: i < k
9498
body = lambda i, nv_mean, nv_sample, nh_mean, nh_sample: (i+1, ) + self.gibbs_hvh(nh_sample)
@@ -103,7 +107,11 @@ def get_train_ops(self, learning_rate=0.1, k=1, persistent=None):
103107
new_W = tf.assign(self.W, update_W)
104108
new_vbias = tf.assign(self.vbias, update_vbias)
105109
new_hbias = tf.assign(self.hbias, update_hbias)
106-
return [new_W, new_vbias, new_hbias] # use for training
110+
if persistent is not None:
111+
new_persistent = [tf.assign(persistent, nh_sample)]
112+
else:
113+
new_persistent = []
114+
return [new_W, new_vbias, new_hbias] + new_persistent # use for training
107115

108116
def get_reconstruction_cost(self):
109117
"""Compute the cross-entropy of the original input and the reconstruction"""
@@ -126,17 +134,19 @@ def get_reconstruction_cost(self):
126134
rbm = RBM(x, n_visiable=n_visiable, n_hidden=n_hidden)
127135

128136
learning_rate = 0.1
137+
batch_size = 20
129138
cost = rbm.get_reconstruction_cost()
130-
train_ops = rbm.get_train_ops(learning_rate=learning_rate, k=1)
139+
# Create the persistent variable
140+
persistent_chain = tf.Variable(tf.zeros([batch_size, n_hidden]), dtype=tf.float32)
141+
train_ops = rbm.get_train_ops(learning_rate=learning_rate, k=15, persistent=persistent_chain)
131142
init = tf.global_variables_initializer()
132143

133144
output_folder = "rbm_plots"
134145
if not os.path.isdir(output_folder):
135146
os.makedirs(output_folder)
136147
os.chdir(output_folder)
137148

138-
training_epochs = 30
139-
batch_size = 20
149+
training_epochs = 15
140150
display_step = 1
141151
print("Start training...")
142152

@@ -162,7 +172,7 @@ def get_reconstruction_cost(self):
162172
img_shape=(28, 28),
163173
tile_shape=(10, 10),
164174
tile_spacing=(1, 1)))
165-
image.save("filters_at_epoch_{0}.png".format(epoch))
175+
image.save("10filters_at_epoch_{0}.png".format(epoch))
166176

167177
end_time = timeit.default_timer()
168178
training_time = end_time - start_time
@@ -208,7 +218,7 @@ def get_reconstruction_cost(self):
208218
tile_shape=(1, n_chains),
209219
tile_spacing=(1, 1))
210220
image = Image.fromarray(image_data)
211-
image.save("original_and_{0}samples.png".format(n_samples))
221+
image.save("10original_and_{0}samples.png".format(n_samples))
212222

213223

214224

0 commit comments

Comments
 (0)