update cifar number & fix multigpu restore bug
ppwwyyxx committed Apr 19, 2016
1 parent da3da39 commit 76fe1b6
Showing 7 changed files with 66 additions and 16 deletions.
2 changes: 1 addition & 1 deletion examples/ResNet/cifar10_resnet.py
@@ -24,7 +24,7 @@
Identity Mappings in Deep Residual Networks, arxiv:1603.05027
I can reproduce the results for
-n=5, about 7.6% val error
+n=5, about 7.2% val error after 93k steps with 2 TitanX (6.8 it/s)
n=18, about 6.05% val error after 62k steps with 2 TitanX (about 10hr)
n=30: a 182-layer network, about 5.5% val error after 51k steps with 2 GPUs
This model uses the whole training set instead of a 95:5 train-val split.
12 changes: 6 additions & 6 deletions examples/ResNet/svhn_resnet.py
@@ -20,8 +20,9 @@


"""
-Reach 1.9% validation error after 90 epochs, with 2 GPUs.
-You might need to adjust learning rate schedule when running with 1 GPU.
+ResNet-110 for SVHN Digit Classification.
+Reach 1.9% validation error after 90 epochs, with 2 TitanX xxhr, 2 it/s.
+You might need to adjust the learning rate schedule when running with 1 GPU.
"""

BATCH_SIZE = 128
@@ -98,8 +99,7 @@ def residual(name, l, increase_dim=False, first=False):
logits = FullyConnected('linear', l, out_dim=10, nl=tf.identity)
prob = tf.nn.softmax(logits, name='output')

-y = one_hot(label, 10)
-cost = tf.nn.softmax_cross_entropy_with_logits(logits, y)
+cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, label)
cost = tf.reduce_mean(cost, name='cross_entropy_loss')
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost)

@@ -167,8 +167,8 @@ def get_config():
optimizer=tf.train.MomentumOptimizer(lr, 0.9),
callbacks=Callbacks([
StatPrinter(),
-PeriodicSaver(),
-ValidationError(dataset_test, prefix='test'),
+ModelSaver(),
+ClassificationError(dataset_test, prefix='validation'),
ScheduledHyperParamSetter('learning_rate',
[(1, 0.1), (20, 0.01), (33, 0.001), (60, 0.0001)])
]),
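
The loss change above swaps an explicit one-hot target for TF's sparse cross-entropy, which indexes the correct logit directly instead of materializing a one-hot matrix. A minimal sketch of the equivalence, assuming the TF 0.x-era positional signature (logits, labels) used above; the placeholders, and the use of tf.one_hot in place of tensorpack's one_hot helper, are illustrative only:

    import tensorflow as tf

    logits = tf.placeholder(tf.float32, [None, 10])  # unnormalized class scores
    label = tf.placeholder(tf.int32, [None])         # integer class ids in [0, 10)

    # old form: build a dense one-hot target, then dense cross-entropy
    y = tf.one_hot(label, 10)
    dense_cost = tf.nn.softmax_cross_entropy_with_logits(logits, y)

    # new form: same per-example values, with no one-hot tensor in the graph
    sparse_cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, label)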
4 changes: 2 additions & 2 deletions examples/cifar10_convnet.py
@@ -114,7 +114,7 @@ def get_config():
lr = tf.train.exponential_decay(
learning_rate=1e-2,
global_step=get_global_step_var(),
-decay_steps=step_per_epoch * 30 if nr_gpu == 1 else 20,
+decay_steps=step_per_epoch * (30 if nr_gpu == 1 else 20),
decay_rate=0.5, staircase=True, name='learning_rate')
tf.scalar_summary('learning_rate', lr)

@@ -129,7 +129,7 @@ def get_config():
session_config=sess_config,
model=Model(),
step_per_epoch=step_per_epoch,
-max_epoch=3,
+max_epoch=20,
)

if __name__ == '__main__':
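
The decay_steps fix above is an operator-precedence bug: Python's conditional expression binds looser than `*`, so the old line decayed the learning rate every 20 steps (not every 20 epochs) whenever nr_gpu != 1. A self-contained demonstration, with step_per_epoch = 390 assumed for illustration:

    step_per_epoch, nr_gpu = 390, 2

    # old: parsed as (step_per_epoch * 30) if nr_gpu == 1 else 20
    buggy = step_per_epoch * 30 if nr_gpu == 1 else 20
    assert buggy == 20                    # decay every 20 *steps* on 2 GPUs

    # new: choose the epoch count first, then scale by steps per epoch
    fixed = step_per_epoch * (30 if nr_gpu == 1 else 20)
    assert fixed == 7800                  # decay every 20 *epochs*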
5 changes: 3 additions & 2 deletions tensorpack/callbacks/common.py
@@ -26,11 +26,12 @@ def __init__(self, keep_recent=10, keep_freq=0.5):
def _before_train(self):
self.path = os.path.join(logger.LOG_DIR, 'model')
self.saver = tf.train.Saver(
-var_list=self._get_vars(),
+var_list=ModelSaver._get_vars(),
max_to_keep=self.keep_recent,
keep_checkpoint_every_n_hours=self.keep_freq)

-def _get_vars(self):
+@staticmethod
+def _get_vars():
vars = tf.all_variables()
var_dict = {}
for v in vars:
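
Making _get_vars a @staticmethod lets _before_train call it through the class and signals that the variable dict depends only on the graph, not on any ModelSaver state. The loop body is truncated above; purely as a hedged sketch (the tower-name handling here is an assumption, mirroring the sessinit.py change below), the mapping it builds has roughly this shape:

    import re
    import tensorflow as tf

    class ModelSaver(object):
        @staticmethod
        def _get_vars():
            # assumed sketch: store each variable under a tower-less name so a
            # checkpoint written from a multi-GPU graph stays loadable elsewhere
            var_dict = {}
            for v in tf.all_variables():
                name = re.sub('tower[0-9]+/', '', v.op.name)
                if name not in var_dict:
                    var_dict[name] = v
            return var_dict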
1 change: 1 addition & 0 deletions tensorpack/callbacks/group.py
@@ -70,6 +70,7 @@ def before_train_context(self, trainer):
with create_test_session(trainer) as sess:
self.sess = sess
self.graph = sess.graph
+# no tower in the test graph; just keep the saver as it is
self.saver = tf.train.Saver()
with self.graph.as_default(), self.sess.as_default():
yield
56 changes: 52 additions & 4 deletions tensorpack/tfutils/sessinit.py
@@ -5,6 +5,8 @@
import os
from abc import abstractmethod, ABCMeta
import numpy as np
+from collections import defaultdict
+import re
import tensorflow as tf
import six

@@ -38,7 +40,7 @@ def _init(self, sess):

class SaverRestore(SessionInit):
    """
-    Restore an old model saved by `tf.Saver`.
+    Restore an old model saved by `ModelSaver`.
    """
    def __init__(self, model_path):
        """
@@ -52,14 +54,60 @@ def __init__(self, model_path):
        self.set_path(model_path)

    def _init(self, sess):
-        saver = tf.train.Saver()
-        saver.restore(sess, self.path)
        logger.info(
-            "Restore checkpoint from {}".format(self.path))
+            "Restoring checkpoint from {}.".format(self.path))
+        sess.run(tf.initialize_all_variables())
+        chkpt_vars = SaverRestore._read_checkpoint_vars(self.path)
+        vars_map = SaverRestore._get_vars_to_restore_multimap(chkpt_vars)
+        for dic in SaverRestore._produce_restore_dict(vars_map):
+            saver = tf.train.Saver(var_list=dic)
+            saver.restore(sess, self.path)

    def set_path(self, model_path):
        self.path = model_path

+    @staticmethod
+    def _produce_restore_dict(vars_multimap):
+        """
+        Produce {var_name: var} dicts that can be used by `tf.train.Saver`, from a {var_name: [vars]} dict.
+        """
+        while len(vars_multimap):
+            ret = {}
+            for k in list(vars_multimap.keys()):
+                v = vars_multimap[k]
+                ret[k] = v[-1]
+                del v[-1]
+                if not len(v):
+                    del vars_multimap[k]
+            yield ret


+    @staticmethod
+    def _read_checkpoint_vars(model_path):
+        reader = tf.train.NewCheckpointReader(model_path)
+        return set(reader.GetVariableToShapeMap().keys())

+    @staticmethod
+    def _get_vars_to_restore_multimap(vars_available):
+        """
+        Get a dict of {var_name: [var, var]} to restore
+        :param vars_available: variables available in the checkpoint, for existence checking
+        """
+        # TODO warn if some variable in checkpoint is not used
+        vars_to_restore = tf.all_variables()
+        var_dict = defaultdict(list)
+        for v in vars_to_restore:
+            name = v.op.name
+            if 'tower' in name:
+                name = re.sub('tower[0-9]+/', '', name)
+            if name in vars_available:
+                var_dict[name].append(v)
+            else:
+                logger.warn("Param {} not found in checkpoint! Will not restore.".format(v.op.name))
+        return var_dict


class ParamRestore(SessionInit):
    """
    Restore trainable variables from a dictionary.
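
Together these helpers are the multi-GPU restore fix named in the commit title: _read_checkpoint_vars lists what the checkpoint actually contains, _get_vars_to_restore_multimap groups graph variables by checkpoint name (stripping towerN/ so every tower replica maps to the same entry), and _produce_restore_dict peels that multimap into successive passes, because one tf.train.Saver can bind only a single variable per checkpoint name. A pure-Python sketch of the grouping and peeling, with invented variable names:

    import re
    from collections import defaultdict

    graph_vars = ['tower0/conv0/W', 'tower1/conv0/W', 'global_step', 'new_fc/W']
    vars_available = {'conv0/W', 'global_step'}   # names stored in the checkpoint

    vars_map = defaultdict(list)
    for name in graph_vars:
        ckpt_name = re.sub('tower[0-9]+/', '', name)
        if ckpt_name in vars_available:
            vars_map[ckpt_name].append(name)      # 'new_fc/W' is skipped (warned)

    def produce_restore_dict(multimap):
        # yield one {checkpoint_name: variable} dict per restore pass
        while multimap:
            ret = {k: v.pop() for k, v in multimap.items()}
            for k in list(multimap):              # list(): we delete while iterating
                if not multimap[k]:
                    del multimap[k]
            yield ret

    for dic in produce_restore_dict(vars_map):
        print(dic)
    # pass 1: {'conv0/W': 'tower1/conv0/W', 'global_step': 'global_step'}
    # pass 2: {'conv0/W': 'tower0/conv0/W'}

Each yielded dict backs one tf.train.Saver(var_list=dic) restore pass in _init above.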
2 changes: 1 addition & 1 deletion tensorpack/train/base.py
@@ -83,7 +83,7 @@ def main_loop(self):
self.global_step = get_global_step()
logger.info("Start training with global_step={}".format(self.global_step))

-for epoch in range(self.config.starting_epoch, self.config.max_epoch):
+for epoch in range(self.config.starting_epoch, self.config.max_epoch+1):
with timed_operation(
'Epoch {}, global_step={}'.format(
epoch, self.global_step + self.config.step_per_epoch)):
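
The +1 in the loop bound fixes an off-by-one: epochs are counted from starting_epoch (1-based) and range excludes its upper bound, so the final epoch was previously skipped. With assumed values:

    starting_epoch, max_epoch = 1, 3
    list(range(starting_epoch, max_epoch))       # [1, 2]   -- epoch 3 never ran
    list(range(starting_epoch, max_epoch + 1))   # [1, 2, 3]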
