Commit f67dc2a: add an inception config

ppwwyyxx committed Apr 22, 2016
1 parent a949bfa
Showing 4 changed files with 207 additions and 2 deletions.
examples/Inception/inception-bn.py: 201 additions, 0 deletions
@@ -0,0 +1,201 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# File: inception-bn.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>

import cv2
import argparse
import numpy as np
import os
import tensorflow as tf

from tensorpack import *
from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import *


BATCH_SIZE = 64
INPUT_SHAPE = 224

"""
Inception-BN model on ILSVRC12.
See "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift", arxiv:1502.03167
This config reaches 71% single-crop validation error after 300k steps with 6 TitanX.
Learning rate may need a different schedule for different number of GPUs (because batch size will be different).
"""

class Model(ModelDesc):
    def __init__(self):
        super(Model, self).__init__()

    def _get_input_vars(self):
        return [InputVar(tf.float32, [None, INPUT_SHAPE, INPUT_SHAPE, 3], 'input'),
                InputVar(tf.int32, [None], 'label')]

    def _get_cost(self, input_vars, is_training):
        image, label = input_vars
        image = image / 128.0 - 1  # rescale [0, 255] input to roughly [-1, 1]

        def inception(name, x, nr1x1, nr3x3r, nr3x3, nr233r, nr233, nrpool, pooltype):
            # One Inception-BN module: parallel 1x1, 3x3 (with 1x1 reduction),
            # double-3x3 (with 1x1 reduction) and pooling branches, concatenated
            # along the channel axis. nr1x1 == 0 marks a stride-2 reduction
            # module that drops the 1x1 branch; nrpool == 0 means the pooling
            # branch passes through without a 1x1 projection.
            stride = 2 if nr1x1 == 0 else 1
            with tf.variable_scope(name) as scope:
                outs = []
                if nr1x1 != 0:
                    outs.append(Conv2D('conv1x1', x, nr1x1, 1))
                x2 = Conv2D('conv3x3r', x, nr3x3r, 1)
                outs.append(Conv2D('conv3x3', x2, nr3x3, 3, stride=stride))

                x3 = Conv2D('conv233r', x, nr233r, 1)
                x3 = Conv2D('conv233a', x3, nr233, 3)
                outs.append(Conv2D('conv233b', x3, nr233, 3, stride=stride))

                if pooltype == 'max':
                    x4 = MaxPooling('mpool', x, 3, stride, padding='SAME')
                else:
                    assert pooltype == 'avg'
                    x4 = AvgPooling('apool', x, 3, stride, padding='SAME')
                if nrpool != 0:  # pool + passthrough if nrpool == 0
                    x4 = Conv2D('poolproj', x4, nrpool, 1)
                outs.append(x4)
                return tf.concat(3, outs, name='concat')

        with argscope(Conv2D, nl=BNReLU(is_training), use_bias=False):
            l = Conv2D('conv0', image, 64, 7, stride=2)
            l = MaxPooling('pool0', l, 3, 2, padding='SAME')
            l = Conv2D('conv1', l, 64, 1)
            l = Conv2D('conv2', l, 192, 3)
            l = MaxPooling('pool2', l, 3, 2, padding='SAME')
            # feature map size: 28x28
            l = inception('incep3a', l, 64, 64, 64, 64, 96, 32, 'avg')
            l = inception('incep3b', l, 64, 64, 96, 64, 96, 64, 'avg')
            l = inception('incep3c', l, 0, 128, 160, 64, 96, 0, 'max')

            # first auxiliary classifier
            br1 = Conv2D('loss1conv', l, 128, 1)
            br1 = FullyConnected('loss1fc', br1, 1024)
            br1 = FullyConnected('loss1logit', br1, 1000, nl=tf.identity)
            loss1 = tf.nn.sparse_softmax_cross_entropy_with_logits(br1, label)
            loss1 = tf.reduce_mean(loss1, name='loss1')

            # feature map size: 14x14
            l = inception('incep4a', l, 224, 64, 96, 96, 128, 128, 'avg')
            l = inception('incep4b', l, 192, 96, 128, 96, 128, 128, 'avg')
            l = inception('incep4c', l, 160, 128, 160, 128, 160, 128, 'avg')
            l = inception('incep4d', l, 96, 128, 192, 160, 192, 128, 'avg')
            l = inception('incep4e', l, 0, 128, 192, 192, 256, 0, 'max')

            # second auxiliary classifier
            br2 = Conv2D('loss2conv', l, 128, 1)
            br2 = FullyConnected('loss2fc', br2, 1024)
            br2 = FullyConnected('loss2logit', br2, 1000, nl=tf.identity)
            loss2 = tf.nn.sparse_softmax_cross_entropy_with_logits(br2, label)
            loss2 = tf.reduce_mean(loss2, name='loss2')

            # feature map size: 7x7
            l = inception('incep5a', l, 352, 192, 320, 160, 224, 128, 'avg')
            l = inception('incep5b', l, 352, 192, 320, 192, 224, 128, 'max')
            l = GlobalAvgPooling('gap', l)

        logits = FullyConnected('linear', l, out_dim=1000, nl=tf.identity)
        prob = tf.nn.softmax(logits, name='output')
        loss3 = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, label)
        loss3 = tf.reduce_mean(loss3, name='loss3')

        # total cost: main loss plus down-weighted auxiliary losses
        cost = tf.add_n([loss3, 0.3 * loss2, 0.3 * loss1], name='weighted_cost')
        for k in [cost, loss1, loss2, loss3]:
            tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, k)

        wrong = prediction_incorrect(logits, label)
        nr_wrong = tf.reduce_sum(wrong, name='wrong')
        # monitor training error
        tf.add_to_collection(
            MOVING_SUMMARY_VARS_KEY, tf.reduce_mean(wrong, name='train_error'))

        # weight decay on all W (the '.*/W' pattern matches conv and fc weights alike)
        wd_w = tf.train.exponential_decay(0.0002, get_global_step_var(),
                                          80000, 0.7, True)
        wd_cost = tf.mul(wd_w, regularize_cost('.*/W', tf.nn.l2_loss), name='l2_regularize_loss')
        tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, wd_cost)

        add_param_summary([('.*/W', ['histogram'])])  # monitor W
        return tf.add_n([cost, wd_cost], name='cost')

def get_data(train_or_test):
    isTrain = train_or_test == 'train'
    ds = dataset.ILSVRC12(args.data, train_or_test, shuffle=isTrain)
    meta = dataset.ILSVRCMeta()
    pp_mean = meta.get_per_pixel_mean()

    if isTrain:
        augmentors = [
            imgaug.Resize((256, 256)),
            imgaug.MapImage(lambda x: x - pp_mean),
            imgaug.RandomCrop((224, 224)),
            imgaug.Flip(horiz=True),
        ]
    else:
        augmentors = [
            imgaug.Resize((256, 256)),
            imgaug.MapImage(lambda x: x - pp_mean),
            imgaug.CenterCrop((224, 224)),
        ]
    ds = AugmentImageComponent(ds, augmentors)
    ds = BatchData(ds, BATCH_SIZE, remainder=not isTrain)
    if isTrain:
        ds = PrefetchData(ds, 20, 5)
    return ds


def get_config():
    # prepare dataset
    dataset_train = get_data('train')
    step_per_epoch = 5000
    dataset_val = get_data('val')

    sess_config = get_default_sess_config(0.99)  # GPU memory fraction

    lr = tf.Variable(0.045, trainable=False, name='learning_rate')
    tf.scalar_summary('learning_rate', lr)

    return TrainConfig(
        dataset=dataset_train,
        optimizer=tf.train.MomentumOptimizer(lr, 0.9),
        callbacks=Callbacks([
            StatPrinter(),
            ModelSaver(),
            InferenceRunner(dataset_val, ClassificationError()),
            #HumanHyperParamSetter('learning_rate', 'hyper-googlenet.txt'),
            # schedule is a list of (epoch, learning rate) pairs
            ScheduledHyperParamSetter('learning_rate',
                [(8, 0.03), (13, 0.02), (21, 5e-3),
                 (28, 3e-3), (33, 1e-3), (44, 5e-4),
                 (49, 1e-4), (59, 2e-5)])
        ]),
        session_config=sess_config,
        model=Model(),
        step_per_epoch=step_per_epoch,
        max_epoch=80,
    )

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')  # nargs='*' in multi mode
    parser.add_argument('--load', help='load model')
    parser.add_argument('--data', help='ImageNet data root directory',
                        required=True)
    global args
    args = parser.parse_args()

    basename = os.path.basename(__file__)
    logger.set_logger_dir(
        os.path.join('train_log', basename[:basename.rfind('.')]))

    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    with tf.Graph().as_default():
        config = get_config()
        if args.load:
            config.session_init = SaverRestore(args.load)
        if args.gpu:
            config.nr_tower = len(args.gpu.split(','))
        QueueInputTrainer(config).train()
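
For reference, a typical invocation of the new script might look like this (the dataset path and GPU ids are placeholders, not part of the commit):

    ./inception-bn.py --data /path/to/ILSVRC12 --gpu 0,1,2,3,4,5
    # resume from a previously saved checkpoint (path is hypothetical):
    ./inception-bn.py --data /path/to/ILSVRC12 --gpu 0 --load /path/to/checkpoint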
scripts/dump_train_config.py: 1 addition, 1 deletion
@@ -42,7 +42,7 @@
             fname = os.path.join(args.output, '{:03d}-{}.png'.format(cnt, bi))
             cv2.imwrite(fname, img * args.scale)

-NR_DP_TEST = 100
+NR_DP_TEST = args.number
 logger.info("Testing dataflow speed:")
 with tqdm.tqdm(total=NR_DP_TEST, leave=True, unit='data points') as pbar:
     for idx, dp in enumerate(config.dataset.get_data()):
tensorpack/dataflow/prefetch.py: 2 additions, 0 deletions
@@ -54,7 +54,9 @@ def get_data(self):
             yield dp

     def __del__(self):
+        logger.info("Prefetch process exiting...")
         self.queue.close()
         for x in self.procs:
             x.terminate()
+        logger.info("Prefetch process exited.")

tensorpack/train/trainer.py: 3 additions, 1 deletion
@@ -53,7 +53,7 @@ def __init__(self, trainer, queue, enqueue_op, raw_input_var):
         super(EnqueueThread, self).__init__()
         self.sess = trainer.sess
         self.coord = trainer.coord
-        self.dataflow = trainer.config.dataset
+        self.dataflow = RepeatedData(trainer.config.dataset, -1)

         self.input_vars = raw_input_var
         self.op = enqueue_op
@@ -76,6 +76,8 @@ def run(self):
             logger.exception("Exception in EnqueueThread:")
             self.sess.run(self.close_op)
             self.coord.request_stop()
+        finally:
+            logger.info("Enqueue Thread Exited.")

 class QueueInputTrainer(Trainer):
     """
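
The switch to RepeatedData(trainer.config.dataset, -1) makes the enqueue thread loop over the dataflow forever instead of stopping after a single pass. A minimal sketch of the idea, assuming -1 means "repeat indefinitely" as in tensorpack's dataflow API:

    class RepeatForever(object):
        """Hypothetical stand-in for RepeatedData(ds, -1): yield datapoints
        in an endless loop so the input queue never runs dry; epoch
        accounting is handled by the trainer, not the dataflow."""
        def __init__(self, ds):
            self.ds = ds

        def get_data(self):
            while True:
                for dp in self.ds.get_data():
                    yield dp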
