Skip to content
Permalink
Browse files

Use SmartInit globally - a simpler interface to initialization

  • Loading branch information...
ppwwyyxx committed Sep 1, 2019
1 parent cbd698a commit 17b34c69d4bc8128ae6d4dba55e209ae34bd8821
Showing with 85 additions and 105 deletions.
  1. +5 −5 docs/tutorial/save-load.md
  2. +2 −2 examples/A3C-Gym/train-atari.py
  3. +1 −2 examples/CTC-TIMIT/train-timit.py
  4. +1 −3 examples/CaffeModels/load-alexnet.py
  5. +1 −2 examples/CaffeModels/load-cpm.py
  6. +1 −1 examples/CaffeModels/load-vgg16.py
  7. +1 −1 examples/CaffeModels/load-vgg19.py
  8. +2 −3 examples/Char-RNN/char-rnn.py
  9. +2 −3 examples/DeepQNetwork/DQN.py
  10. +4 −5 examples/DoReFa-Net/alexnet-dorefa.py
  11. +2 −2 examples/DoReFa-Net/resnet-dorefa.py
  12. +1 −2 examples/DynamicFilterNetwork/steering-filter.py
  13. +3 −3 examples/FasterRCNN/predict.py
  14. +2 −2 examples/FasterRCNN/train.py
  15. +1 −1 examples/GAN/BEGAN.py
  16. +2 −2 examples/GAN/ConditionalGAN-mnist.py
  17. +1 −1 examples/GAN/CycleGAN.py
  18. +2 −2 examples/GAN/DCGAN.py
  19. +1 −1 examples/GAN/DiscoGAN-CelebA.py
  20. +2 −2 examples/GAN/Image2Image.py
  21. +1 −1 examples/GAN/Improved-WGAN.py
  22. +2 −2 examples/GAN/InfoGAN-mnist.py
  23. +1 −1 examples/GAN/WGAN.py
  24. +2 −3 examples/HED/hed.py
  25. +1 −1 examples/ImageNetModels/imagenet_utils.py
  26. +1 −2 examples/ImageNetModels/inception-bn.py
  27. +3 −4 examples/ImageNetModels/shufflenet.py
  28. +2 −2 examples/OpticalFlow/flownet2.py
  29. +1 −2 examples/PennTreebank/PTB-LSTM.py
  30. +1 −1 examples/ResNet/cifar10-preact18-mixup.py
  31. +1 −1 examples/ResNet/cifar10-resnet.py
  32. +3 −4 examples/ResNet/imagenet-resnet.py
  33. +2 −2 examples/ResNet/load-resnet.py
  34. +2 −3 examples/Saliency/CAM-resnet.py
  35. +1 −1 examples/Saliency/saliency-maps.py
  36. +3 −5 examples/SimilarityLearning/mnist-embeddings.py
  37. +2 −3 examples/SpatialTransformer/mnist-addition.py
  38. +3 −3 examples/SuperResolution/enet-pat.py
  39. +1 −2 examples/basics/cifar-convnet.py
  40. +4 −4 examples/basics/export-model.py
  41. +1 −1 examples/basics/svhn-digit-convnet.py
  42. +1 −3 examples/boilerplate.py
  43. +9 −9 tensorpack/tfutils/sessinit.py
@@ -39,13 +39,13 @@ For inference, use `session_init` in `PredictConfig(...)`.

There are a few ways a session can be initialized:
```
session_init=SmartRestore("path/to/checkpoint") # load a TF checkpoint
session_init=SmartRestore("path/to/model_zoo.npz") # load tensorpack model zoo
session_init=SmartRestore(dict_of_parameters) # load a dictionary
session_init=SmartRestore(["path1", dict2]) # load them sequentially
session_init=SmartInit("path/to/checkpoint") # load a TF checkpoint
session_init=SmartInit("path/to/model_zoo.npz") # load tensorpack model zoo
session_init=SmartInit(dict_of_parameters) # load a dictionary
session_init=SmartInit(["path1", dict2]) # load them sequentially
```

[SmartRestore](../modules/tfutils.html#tensorpack.tfutils.sessinit.SmartRestore)
[SmartInit](../modules/tfutils.html#tensorpack.tfutils.sessinit.SmartInit)
is in fact a small helper which uses some heuristics to return one of
[SaverRestore](../modules/tfutils.html#tensorpack.tfutils.sessinit.SaverRestore) or
[DictRestore](../modules/tfutils.html#tensorpack.tfutils.sessinit.DictRestore).
@@ -265,7 +265,7 @@ def train():
],
session_creator=sesscreate.NewSessionCreator(config=get_default_sess_config(0.5)),
steps_per_epoch=STEPS_PER_EPOCH,
session_init=get_model_loader(args.load) if args.load else None,
session_init=SmartInit(args.load),
max_epoch=1000,
)
trainer = SimpleTrainer() if num_gpu == 1 else AsyncMultiGPUTrainer(train_tower)
@@ -294,7 +294,7 @@ def train():
assert args.load is not None
pred = OfflinePredictor(PredictConfig(
model=Model(),
session_init=get_model_loader(args.load),
session_init=SmartInit(args.load),
input_names=['state'],
output_names=['policy']))
if args.task == 'play':
@@ -119,6 +119,5 @@ def get_config(ds_train, ds_test):
ds_test = get_data(args.test, False, args.stat)

config = get_config(ds_train, ds_test)
if args.load:
config.session_init = SaverRestore(args.load)
config.session_init = SmartInit(args.load)
launch_train_with_config(config, SimpleTrainer())
@@ -5,7 +5,6 @@

from __future__ import print_function
import argparse
import numpy as np
import os
import cv2
import tensorflow as tf
@@ -39,11 +38,10 @@ def tower_func(image):


def run_test(path, input):
param_dict = dict(np.load(path))
predictor = OfflinePredictor(PredictConfig(
input_signature=[tf.TensorSpec((None, 227, 227, 3), tf.float32, 'input')],
tower_func=tower_func,
session_init=DictRestore(param_dict),
session_init=SmartInit(path),
input_names=['input'],
output_names=['prob']
))
@@ -95,11 +95,10 @@ def add_stage(stage, l):


def run_test(model_path, img_file):
param_dict = dict(np.load(model_path))
predict_func = OfflinePredictor(PredictConfig(
input_signature=[tf.TensorSpec((None, 368, 368, 3), tf.float32, 'input')],
tower_func=CPM,
session_init=DictRestore(param_dict),
session_init=SmartInit(model_path),
input_names=['input'],
output_names=['resized_map']
))
@@ -61,7 +61,7 @@ def run_test(path, input):
predict_func = OfflinePredictor(PredictConfig(
input_signature=[tf.TensorSpec((None, 224, 224, 3), tf.float32, 'input')],
tower_func=tower_func,
session_init=DictRestore(param_dict),
session_init=SmartInit(param_dict),
input_names=['input'],
output_names=['prob'] # prob:0 is the probability distribution
))
@@ -64,7 +64,7 @@ def run_test(path, input):
predict_func = OfflinePredictor(PredictConfig(
input_signature=[tf.TensorSpec((None, 224, 224, 3), tf.float32, 'input')],
tower_func=tower_func,
session_init=DictRestore(param_dict),
session_init=SmartInit(param_dict),
input_names=['input'],
output_names=['prob'] # prob:0 is the probability distribution
))
@@ -141,7 +141,7 @@ def sample(path, start, length):

pred = OfflinePredictor(PredictConfig(
model=Model(),
session_init=SaverRestore(path),
session_init=SmartInit(path),
input_names=['input', 'c0', 'h0', 'c1', 'h1'],
output_names=['prob', 'last_state']))

@@ -193,6 +193,5 @@ def pick(prob):
else:
param.corpus = args.corpus
config = get_config()
if args.load:
config.session_init = SaverRestore(args.load)
config.session_init = SmartInit(args.load)
launch_train_with_config(config, SimpleTrainer())
@@ -171,7 +171,7 @@ def get_config(model):
assert args.load is not None
pred = OfflinePredictor(PredictConfig(
model=model,
session_init=get_model_loader(args.load),
session_init=SmartInit(args.load),
input_names=['state'],
output_names=['Qvalue']))
if args.task == 'play':
@@ -183,6 +183,5 @@ def get_config(model):
os.path.join('train_log', 'DQN-{}'.format(
os.path.basename(args.env).split('.')[0])))
config = get_config(model)
if args.load:
config.session_init = get_model_loader(args.load)
config.session_init = SmartInit(args.load)
launch_train_with_config(config, SimpleTrainer())
@@ -12,7 +12,7 @@

from tensorpack import *
from tensorpack.dataflow import dataset
from tensorpack.tfutils.sessinit import get_model_loader
from tensorpack.tfutils.sessinit import SmartInit
from tensorpack.tfutils.summary import add_param_summary
from tensorpack.tfutils.varreplace import remap_variables
from tensorpack.utils.gpu import get_num_gpu
@@ -214,12 +214,12 @@ def run_image(model, sess_init, inputs):

if args.run:
assert args.load.endswith('.npz')
run_image(Model(), DictRestore(dict(np.load(args.load))), args.run)
run_image(Model(), SmartInit(args.load), args.run)
sys.exit()
if args.eval:
BATCH_SIZE = 128
ds = get_data('val')
eval_classification(Model(), get_model_loader(args.load), ds)
eval_classification(Model(), SmartInit(args.load), ds)
sys.exit()

nr_tower = max(get_num_gpu(), 1)
@@ -229,6 +229,5 @@ def run_image(model, sess_init, inputs):
logger.info("Batch per tower: {}".format(BATCH_SIZE))

config = get_config()
if args.load:
config.session_init = SaverRestore(args.load)
config.session_init = SmartInit(args.load)
launch_train_with_config(config, SyncMultiGPUTrainerReplicated(nr_tower))
@@ -163,7 +163,7 @@ def run_image(model, sess_init, inputs):
ds = dataset.ILSVRC12(args.data, 'val', shuffle=False)
ds = AugmentImageComponent(ds, get_inference_augmentor())
ds = BatchData(ds, 192, remainder=True)
eval_classification(Model(), get_model_loader(args.load), ds)
eval_classification(Model(), SmartInit(args.load), ds)
elif args.run:
assert args.load.endswith('.npz')
run_image(Model(), DictRestore(dict(np.load(args.load))), args.run)
run_image(Model(), SmartInit(args.load), args.run)
@@ -255,6 +255,5 @@ def get_config():
with change_gpu(args.gpu):
NGPU = len(args.gpu.split(','))
config = get_config()
if args.load:
config.session_init = SaverRestore(args.load)
config.session_init = SmartInit(args.load)
launch_train_with_config(config, SyncMultiGPUTrainer(NGPU))
@@ -14,7 +14,7 @@

import tensorpack.utils.viz as tpviz
from tensorpack.predict import MultiTowerOfflinePredictor, OfflinePredictor, PredictConfig
from tensorpack.tfutils import get_model_loader, get_tf_version_tuple
from tensorpack.tfutils import SmartInit, get_tf_version_tuple
from tensorpack.tfutils.export import ModelExporter
from tensorpack.utils import fs, logger

@@ -38,7 +38,7 @@ def do_visualize(model, model_path, nr_visualize=100, output_dir='output'):

pred = OfflinePredictor(PredictConfig(
model=model,
session_init=get_model_loader(model_path),
session_init=SmartInit(model_path),
input_names=['image', 'gt_boxes', 'gt_labels'],
output_names=[
'generate_{}_proposals/boxes'.format('fpn' if cfg.MODE_FPN else 'rpn'),
@@ -146,7 +146,7 @@ def do_predict(pred_func, input_file):
else:
predcfg = PredictConfig(
model=MODEL,
session_init=get_model_loader(args.load),
session_init=SmartInit(args.load),
input_names=MODEL.get_inference_tensor_names()[0],
output_names=MODEL.get_inference_tensor_names()[1])

@@ -103,9 +103,9 @@
else:
if args.load:
# ignore mismatched values, so you can `--load` a model for fine-tuning
session_init = SmartRestore(args.load, ignore_mismatch=True)
session_init = SmartInit(args.load, ignore_mismatch=True)
else:
session_init = SmartRestore(cfg.BACKBONE.WEIGHTS)
session_init = SmartInit(cfg.BACKBONE.WEIGHTS)

traincfg = TrainConfig(
model=MODEL,
@@ -146,5 +146,5 @@ def optimizer(self):
StatMonitorParamSetter(
'learning_rate', 'losses/measure', lambda x: x * 0.5, 0, 10)
],
session_init=SaverRestore(args.load) if args.load else None,
session_init=SmartInit(args.load),
steps_per_epoch=500, max_epoch=400)
@@ -114,7 +114,7 @@ def get_data():

def sample(model_path):
pred = PredictConfig(
session_init=get_model_loader(model_path),
session_init=SmartInit(model_path),
model=Model(),
input_names=['label', 'z'],
output_names=['gen/gen'])
@@ -145,5 +145,5 @@ def sample(model_path):
callbacks=[ModelSaver()],
steps_per_epoch=500,
max_epoch=100,
session_init=SaverRestore(args.load) if args.load else None
session_init=SmartInit(args.load),
)
@@ -224,5 +224,5 @@ def _trigger(self):
],
max_epoch=195,
steps_per_epoch=data.size(),
session_init=SaverRestore(args.load) if args.load else None
session_init=SmartInit(args.load)
)
@@ -121,7 +121,7 @@ def get_data():

def sample(model, model_path, output_name='gen/gen'):
pred = PredictConfig(
session_init=get_model_loader(model_path),
session_init=SmartInit(model_path),
model=model,
input_names=['z'],
output_names=[output_name, 'z'])
@@ -167,5 +167,5 @@ def get_args(default_batch=128, default_z_dim=100):
callbacks=[ModelSaver()],
steps_per_epoch=300,
max_epoch=200,
session_init=SaverRestore(args.load) if args.load else None
session_init=SmartInit(args.load),
)
@@ -211,5 +211,5 @@ def get_filelist(idxlist):
callbacks=[ModelSaver()],
steps_per_epoch=300,
max_epoch=250,
session_init=SaverRestore(args.load) if args.load else None
session_init=SmartInit(args.load),
)
@@ -179,7 +179,7 @@ def get_data():

def sample(datadir, model_path):
pred = PredictConfig(
session_init=get_model_loader(model_path),
session_init=SmartInit(model_path),
model=Model(),
input_names=['input', 'output'],
output_names=['viz'])
@@ -226,5 +226,5 @@ def sample(datadir, model_path):
],
steps_per_epoch=data.size(),
max_epoch=300,
session_init=SaverRestore(args.load) if args.load else None
session_init=SmartInit(args.load)
)
@@ -97,5 +97,5 @@ def optimizer(self):
callbacks=[ModelSaver()],
steps_per_epoch=300,
max_epoch=200,
session_init=SaverRestore(args.load) if args.load else None
session_init=SmartInit(args.load)
)
@@ -218,7 +218,7 @@ def get_data():

def sample(model_path):
pred = OfflinePredictor(PredictConfig(
session_init=get_model_loader(model_path),
session_init=SmartInit(model_path),
model=Model(),
input_names=['z_code', 'z_noise'],
output_names=['gen/viz']))
@@ -276,5 +276,5 @@ def sample(model_path):
callbacks=[ModelSaver(keep_checkpoint_every_n_hours=0.1)],
steps_per_epoch=500,
max_epoch=100,
session_init=SaverRestore(args.load) if args.load else None
session_init=SmartInit(args.load)
)
@@ -80,5 +80,5 @@ def _trigger_step(self):
callbacks=[ModelSaver(), ClipCallback()],
steps_per_epoch=500,
max_epoch=200,
session_init=SaverRestore(args.load) if args.load else None
session_init=SmartInit(args.load)
)
@@ -271,7 +271,7 @@ def get_config():
def run(model_path, image_path, output):
pred_config = PredictConfig(
model=Model(),
session_init=get_model_loader(model_path),
session_init=SmartInit(model_path),
input_names=['image'],
output_names=['output' + str(k) for k in range(1, 7)])
predictor = OfflinePredictor(pred_config)
@@ -309,8 +309,7 @@ def run(model_path, image_path, output):
run(args.load, args.run, args.output)
else:
config = get_config()
if args.load:
config.session_init = get_model_loader(args.load)
config.session_init = SmartInit(args.load)
launch_train_with_config(
config,
SyncMultiGPUTrainer(max(get_num_gpu(), 1)))
@@ -431,7 +431,7 @@ def create_predict_config(self, session_init):
Examples:
pred = OfflinePredictor(model.create_predict_config(get_model_loader(args.load)))
pred = OfflinePredictor(model.create_predict_config(SmartInit(args.load)))
prob = pred(NCHW_image)[0] # Nx1000 probabilities
"""
return PredictConfig(model=self, input_names=['input'], output_names=['prob'], session_init=session_init)
@@ -166,8 +166,7 @@ def get_config():
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

config = get_config()
if args.load:
config.session_init = SaverRestore(args.load)
config.session_init = SmartInit(args.load)
nr_tower = get_num_gpu()
assert nr_tower == NUM_GPU
launch_train_with_config(config, SyncMultiGPUTrainer(NUM_GPU))
@@ -11,7 +11,7 @@

from tensorpack import *
from tensorpack.dataflow import imgaug
from tensorpack.tfutils import argscope, get_model_loader, model_utils
from tensorpack.tfutils import argscope, SmartInit, model_utils
from tensorpack.tfutils.scope_utils import under_name_scope
from tensorpack.utils import logger
from tensorpack.utils.gpu import get_num_gpu
@@ -251,7 +251,7 @@ def get_config(model, nr_tower):
if args.eval:
batch = 128 # something that can run on one gpu
ds = get_data('val', batch)
eval_classification(model, get_model_loader(args.load), ds)
eval_classification(model, SmartInit(args.load), ds)
elif args.flops:
# manually build the graph with batch=1
with TowerContext('', is_training=False):
@@ -277,6 +277,5 @@ def get_config(model, nr_tower):

nr_tower = max(get_num_gpu(), 1)
config = get_config(model, nr_tower)
if args.load:
config.session_init = get_model_loader(args.load)
config.session_init = SmartInit(args.load)
launch_train_with_config(config, SyncMultiGPUTrainerParameterServer(nr_tower))

0 comments on commit 17b34c6

Please sign in to comment.
You can’t perform that action at this time.