fix async training late-binding bug
ppwwyyxx committed Apr 25, 2016
1 parent b6a775f commit fd21c3b
Showing 2 changed files with 5 additions and 3 deletions.
3 changes: 2 additions & 1 deletion docs/conf.py
@@ -21,7 +21,8 @@
 sys.path.insert(0, os.path.abspath('../'))
 
 import mock
-MOCK_MODULES = ['numpy', 'scipy', 'tensorflow', 'scipy.misc', 'h5py', 'nltk', 'cv2']
+MOCK_MODULES = ['numpy', 'scipy', 'tensorflow', 'scipy.misc', 'h5py', 'nltk',
+                'cv2', 'scipy.io']
 for mod_name in MOCK_MODULES:
     sys.modules[mod_name] = mock.Mock()

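(Aside, not part of the commit: conf.py stubs out heavy dependencies so a Sphinx docs build can import the package without installing them; adding 'scipy.io' to MOCK_MODULES suggests some module now imports it at the top level. A minimal standalone sketch of the same pattern:)

import sys
import mock  # unittest.mock also works on Python 3

# Register fake modules *before* importing anything that needs them,
# so the import machinery finds the Mock objects in sys.modules.
for mod_name in ['scipy', 'scipy.io']:
    sys.modules[mod_name] = mock.Mock()

import scipy.io            # succeeds even if SciPy is not installed
print(scipy.io.loadmat)    # a Mock attribute; the real library is never needed
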
5 changes: 3 additions & 2 deletions tensorpack/train/trainer.py
@@ -6,6 +6,7 @@
 import threading
 import copy
 import re
+import functools
 from six.moves import zip
 
 from .base import Trainer
@@ -175,7 +176,7 @@ def train(self):
         else:
             grad_list = [self.process_grads(g) for g in grad_list]
             # pretend to average the grads, in order to make async and
-            # sync have consistent semantics
+            # sync have consistent effective learning rate
             def scale(grads):
                 return [(grad / self.config.nr_tower, var) for grad, var in grads]
             grad_list = map(scale, grad_list)
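(Aside, not part of the commit: the reworded comment captures what scale() is for. Sync training averages the per-tower gradients into one update; async training applies each tower's update separately, so dividing every gradient by nr_tower keeps the total update per data pass, and hence the effective learning rate, comparable between the two modes. A tiny numeric sketch under assumed values:)

nr_tower = 2
g1, g2 = 0.8, 0.4                              # per-tower gradients for one variable
sync_step = (g1 + g2) / nr_tower               # one averaged update: 0.6
async_steps = [g1 / nr_tower, g2 / nr_tower]   # two scaled updates: 0.4 and 0.2
assert abs(sync_step - sum(async_steps)) < 1e-9
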
@@ -192,7 +193,7 @@ def scale(grads):
         self.threads = []
         for k in range(1, self.config.nr_tower):
             train_op = self.config.optimizer.apply_gradients(grad_list[k])
-            f = lambda : self.sess.run([train_op])
+            f = lambda op=train_op: self.sess.run([op]) # avoid late-binding
             th = LoopThread(f)
             th.pause()
             th.start()
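(Aside, not part of the commit: the bug being fixed is Python's late binding of closures. Every lambda created in the for-loop above closed over the name train_op rather than its value, so by the time the worker threads ran, all of them executed the last tower's train op. Binding the current value as a default argument evaluates it once per iteration; functools.partial, which this commit also imports, is an equivalent alternative. A minimal standalone demonstration, independent of the repository's classes:)

# Late binding: every closure sees the final value of k.
broken = [lambda: k for k in range(3)]
print([f() for f in broken])     # [2, 2, 2]

# Fix 1: default arguments are evaluated when the lambda is created.
fixed = [lambda k=k: k for k in range(3)]
print([f() for f in fixed])      # [0, 1, 2]

# Fix 2: functools.partial binds the argument eagerly as well.
import functools
fixed2 = [functools.partial(lambda x: x, k) for k in range(3)]
print([f() for f in fixed2])     # [0, 1, 2]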
