Commit 9d96e9f

Wrap the cifar10 multigpu model construction part with a variable_scope
Without the new variable_scope, creating apply_gradient_op raises an error saying that additional moving-average or optimizer slot variables cannot be created. This is caused by 'leaky reuse' of the variable scope: the reuse_variables() call inside the tower loop leaks into the enclosing scope. We correct the problem by explicitly introducing a new variable scope around the tower loop. Related issues: #901, tensorflow/tensorflow#6220
1 parent eb62b91 commit 9d96e9f

File tree: 1 file changed, +20 −19 lines


tutorials/image/cifar10/cifar10_multi_gpu_train.py

Lines changed: 20 additions & 19 deletions
@@ -162,25 +162,26 @@ def train():
 
     # Calculate the gradients for each model tower.
     tower_grads = []
-    for i in xrange(FLAGS.num_gpus):
-      with tf.device('/gpu:%d' % i):
-        with tf.name_scope('%s_%d' % (cifar10.TOWER_NAME, i)) as scope:
-          # Calculate the loss for one tower of the CIFAR model. This function
-          # constructs the entire CIFAR model but shares the variables across
-          # all towers.
-          loss = tower_loss(scope)
-
-          # Reuse variables for the next tower.
-          tf.get_variable_scope().reuse_variables()
-
-          # Retain the summaries from the final tower.
-          summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, scope)
-
-          # Calculate the gradients for the batch of data on this CIFAR tower.
-          grads = opt.compute_gradients(loss)
-
-          # Keep track of the gradients across all towers.
-          tower_grads.append(grads)
+    with tf.variable_scope(tf.get_variable_scope()):
+      for i in xrange(FLAGS.num_gpus):
+        with tf.device('/gpu:%d' % i):
+          with tf.name_scope('%s_%d' % (cifar10.TOWER_NAME, i)) as scope:
+            # Calculate the loss for one tower of the CIFAR model. This function
+            # constructs the entire CIFAR model but shares the variables across
+            # all towers.
+            loss = tower_loss(scope)
+
+            # Reuse variables for the next tower.
+            tf.get_variable_scope().reuse_variables()
+
+            # Retain the summaries from the final tower.
+            summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, scope)
+
+            # Calculate the gradients for the batch of data on this CIFAR tower.
+            grads = opt.compute_gradients(loss)
+
+            # Keep track of the gradients across all towers.
+            tower_grads.append(grads)
 
     # We must calculate the mean of each gradient. Note that this is the
     # synchronization point across all towers.
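
Why the wrapper helps: in TF 1.x, tf.variable_scope(tf.get_variable_scope()) re-enters the current scope as a copy on the scope stack, so the reuse_variables() call inside the tower loop flips only that copy into reuse mode; once the with block exits, the outer scope can still create fresh variables. Below is a minimal sketch of the pattern, assuming TensorFlow 1.x semantics; the shared variable 'w' and the Momentum optimizer are hypothetical stand-ins for tower_loss() and the tutorial's optimizer, not part of the commit.

import tensorflow as tf  # assumes TensorFlow 1.x scoping semantics

# MomentumOptimizer creates slot variables inside apply_gradients(), which
# is exactly the step that fails when the root scope is stuck in reuse mode.
opt = tf.train.MomentumOptimizer(learning_rate=0.1, momentum=0.9)

# Re-entering the current scope pushes a copy onto the scope stack, so the
# reuse_variables() call below affects only the body of this with block.
with tf.variable_scope(tf.get_variable_scope()):
  for i in range(2):  # stand-in for the per-GPU tower loop
    with tf.name_scope('tower_%d' % i):
      # Hypothetical stand-in for tower_loss(): both towers share 'w',
      # since name_scope does not affect tf.get_variable() names.
      w = tf.get_variable('w', shape=[], initializer=tf.zeros_initializer())
      loss = tf.square(w - 1.0)
      # Share variables with the next tower.
      tf.get_variable_scope().reuse_variables()

# Back outside the block, the root scope is no longer in reuse mode, so
# apply_gradients() can create its momentum slot variables without error.
train_op = opt.apply_gradients(opt.compute_gradients(loss))

Dropping the outer with block (so that reuse_variables() flips the root scope itself) reproduces the failure described in the commit message: the slot-variable creation inside apply_gradients() raises a variable-does-not-exist error.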
