Skip to content

Commit

Permalink
Added in Large-Batch SGD with a warmup, and a LARS startegy. Also add… (
Browse files Browse the repository at this point in the history
apache#8918)

* Added in Large-Batch SGD with a warmup, and a LARS startegy. Also added in a Polynomial Decay learning rate scheduler. Modified the example image fit code to allow these options to be selectable.

* Fix pylint issues

* pylint fixes

* remove duplicate num_update

* remove unused count
  • Loading branch information
ashokei authored and zhreshold committed Jan 29, 2018
1 parent c20730b commit 200a06a
Show file tree
Hide file tree
Showing 3 changed files with 324 additions and 36 deletions.
138 changes: 102 additions & 36 deletions example/image-classification/common/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,14 @@
# specific language governing permissions and limitations
# under the License.

import mxnet as mx
""" example train fit utility """
import logging
import os
import time
import re
import math
import mxnet as mx


def _get_lr_scheduler(args, kv):
if 'lr_factor' not in args or args.lr_factor >= 1:
Expand All @@ -27,17 +31,26 @@ def _get_lr_scheduler(args, kv):
if 'dist' in args.kv_store:
epoch_size /= kv.num_workers
begin_epoch = args.load_epoch if args.load_epoch else 0
if 'pow' in args.lr_step_epochs:
lr = args.lr
max_up = args.num_epochs * epoch_size
pwr = float(re.sub('pow[- ]*', '', args.lr_step_epochs))
poly_sched = mx.lr_scheduler.PolyScheduler(max_up, lr, pwr)
return (lr, poly_sched)
step_epochs = [int(l) for l in args.lr_step_epochs.split(',')]
lr = args.lr
for s in step_epochs:
if begin_epoch >= s:
lr *= args.lr_factor
if lr != args.lr:
logging.info('Adjust learning rate to %e for epoch %d' %(lr, begin_epoch))
logging.info('Adjust learning rate to %e for epoch %d',
lr, begin_epoch)

steps = [epoch_size * (x-begin_epoch) for x in step_epochs if x-begin_epoch > 0]
steps = [epoch_size * (x - begin_epoch)
for x in step_epochs if x - begin_epoch > 0]
return (lr, mx.lr_scheduler.MultiFactorScheduler(step=steps, factor=args.lr_factor))


def _load_model(args, rank=0):
if 'load_epoch' not in args or args.load_epoch is None:
return (None, None, None)
Expand All @@ -50,6 +63,7 @@ def _load_model(args, rank=0):
logging.info('Loaded model %s_%04d.params', model_prefix, args.load_epoch)
return (sym, arg_params, aux_params)


def _save_model(args, rank=0):
if args.model_prefix is None:
return None
Expand All @@ -59,6 +73,7 @@ def _save_model(args, rank=0):
return mx.callback.do_checkpoint(args.model_prefix if rank == 0 else "%s-%d" % (
args.model_prefix, rank))


def add_fit_args(parser):
"""
parser : argparse.ArgumentParser
Expand All @@ -68,7 +83,8 @@ def add_fit_args(parser):
train.add_argument('--network', type=str,
help='the neural network to use')
train.add_argument('--num-layers', type=int,
help='number of layers in the neural network, required by some networks such as resnet')
help='number of layers in the neural network, \
required by some networks such as resnet')
train.add_argument('--gpus', type=str,
help='list of gpus to run, e.g. 0 or 0,2,5. empty means using cpu')
train.add_argument('--kv-store', type=str, default='device',
Expand All @@ -81,6 +97,8 @@ def add_fit_args(parser):
help='the ratio to reduce lr on each step')
train.add_argument('--lr-step-epochs', type=str,
help='the epochs to reduce the lr, e.g. 30,60')
train.add_argument('--initializer', type=str, default='default',
help='the initializer type')
train.add_argument('--optimizer', type=str, default='sgd',
help='the optimizer type')
train.add_argument('--mom', type=float, default=0.9,
Expand Down Expand Up @@ -108,8 +126,16 @@ def add_fit_args(parser):
takes `2bit` or `none` for now')
train.add_argument('--gc-threshold', type=float, default=0.5,
help='threshold for 2bit gradient compression')
# additional parameters for large batch sgd
train.add_argument('--macrobatch-size', type=int, default=0,
help='distributed effective batch size')
train.add_argument('--warmup-epochs', type=int, default=5,
help='the epochs to ramp-up lr to scaled large-batch value')
train.add_argument('--warmup-strategy', type=str, default='linear',
help='the ramping-up strategy for large batch sgd')
return train


def fit(args, network, data_loader, **kwargs):
"""
train a model
Expand All @@ -135,14 +161,13 @@ def fit(args, network, data_loader, **kwargs):
for i, batch in enumerate(train):
for j in batch.data:
j.wait_to_read()
if (i+1) % args.disp_batches == 0:
logging.info('Batch [%d]\tSpeed: %.2f samples/sec' % (
i, args.disp_batches*args.batch_size/(time.time()-tic)))
if (i + 1) % args.disp_batches == 0:
logging.info('Batch [%d]\tSpeed: %.2f samples/sec', i,
args.disp_batches * args.batch_size / (time.time() - tic))
tic = time.time()

return


# load model
if 'arg_params' in kwargs and 'aux_params' in kwargs:
arg_params = kwargs['arg_params']
Expand All @@ -156,22 +181,22 @@ def fit(args, network, data_loader, **kwargs):
checkpoint = _save_model(args, kv.rank)

# devices for training
devs = mx.cpu() if args.gpus is None or args.gpus is '' else [
devs = mx.cpu() if args.gpus is None or args.gpus == "" else [
mx.gpu(int(i)) for i in args.gpus.split(',')]

# learning rate
lr, lr_scheduler = _get_lr_scheduler(args, kv)

# create model
model = mx.mod.Module(
context = devs,
symbol = network
context=devs,
symbol=network
)

lr_scheduler = lr_scheduler
lr_scheduler = lr_scheduler
optimizer_params = {
'learning_rate': lr,
'wd' : args.wd,
'wd': args.wd,
'lr_scheduler': lr_scheduler,
'multi_precision': True}

Expand All @@ -180,40 +205,81 @@ def fit(args, network, data_loader, **kwargs):
if args.optimizer in has_momentum:
optimizer_params['momentum'] = args.mom

monitor = mx.mon.Monitor(args.monitor, pattern=".*") if args.monitor > 0 else None
monitor = mx.mon.Monitor(
args.monitor, pattern=".*") if args.monitor > 0 else None

if args.network == 'alexnet':
# AlexNet will not converge using Xavier
initializer = mx.init.Normal()
else:
initializer = mx.init.Xavier(
rnd_type='gaussian', factor_type="in", magnitude=2)
# A limited number of optimizers have a warmup period
has_warmup = {'lbsgd', 'lbnag'}
if args.optimizer in has_warmup:
if 'dist' in args.kv_store:
nworkers = kv.num_workers
else:
nworkers = 1
epoch_size = args.num_examples / args.batch_size / nworkers
if epoch_size < 1:
epoch_size = 1
macrobatch_size = args.macrobatch_size
if macrobatch_size < args.batch_size * nworkers:
macrobatch_size = args.batch_size * nworkers
#batch_scale = round(float(macrobatch_size) / args.batch_size / nworkers +0.4999)
batch_scale = math.ceil(
float(macrobatch_size) / args.batch_size / nworkers)
optimizer_params['updates_per_epoch'] = epoch_size
optimizer_params['begin_epoch'] = args.load_epoch if args.load_epoch else 0
optimizer_params['batch_scale'] = batch_scale
optimizer_params['warmup_strategy'] = args.warmup_strategy
optimizer_params['warmup_epochs'] = args.warmup_epochs
optimizer_params['num_epochs'] = args.num_epochs

if args.initializer == 'default':
if args.network == 'alexnet':
# AlexNet will not converge using Xavier
initializer = mx.init.Normal()
else:
initializer = mx.init.Xavier(
rnd_type='gaussian', factor_type="in", magnitude=2)
# initializer = mx.init.Xavier(factor_type="in", magnitude=2.34),
elif args.initializer == 'xavier':
initializer = mx.init.Xavier()
elif args.initializer == 'msra':
initializer = mx.init.MSRAPrelu()
elif args.initializer == 'orthogonal':
initializer = mx.init.Orthogonal()
elif args.initializer == 'normal':
initializer = mx.init.Normal()
elif args.initializer == 'uniform':
initializer = mx.init.Uniform()
elif args.initializer == 'one':
initializer = mx.init.One()
elif args.initializer == 'zero':
initializer = mx.init.Zero()

# evaluation metrices
eval_metrics = ['accuracy']
if args.top_k > 0:
eval_metrics.append(mx.metric.create('top_k_accuracy', top_k=args.top_k))
eval_metrics.append(mx.metric.create(
'top_k_accuracy', top_k=args.top_k))

# callbacks that run after each batch
batch_end_callbacks = [mx.callback.Speedometer(args.batch_size, args.disp_batches)]
batch_end_callbacks = [mx.callback.Speedometer(
args.batch_size, args.disp_batches)]
if 'batch_end_callback' in kwargs:
cbs = kwargs['batch_end_callback']
batch_end_callbacks += cbs if isinstance(cbs, list) else [cbs]

# run
model.fit(train,
begin_epoch = args.load_epoch if args.load_epoch else 0,
num_epoch = args.num_epochs,
eval_data = val,
eval_metric = eval_metrics,
kvstore = kv,
optimizer = args.optimizer,
optimizer_params = optimizer_params,
initializer = initializer,
arg_params = arg_params,
aux_params = aux_params,
batch_end_callback = batch_end_callbacks,
epoch_end_callback = checkpoint,
allow_missing = True,
monitor = monitor)
begin_epoch=args.load_epoch if args.load_epoch else 0,
num_epoch=args.num_epochs,
eval_data=val,
eval_metric=eval_metrics,
kvstore=kv,
optimizer=args.optimizer,
optimizer_params=optimizer_params,
initializer=initializer,
arg_params=arg_params,
aux_params=aux_params,
batch_end_callback=batch_end_callbacks,
epoch_end_callback=checkpoint,
allow_missing=True,
monitor=monitor)
32 changes: 32 additions & 0 deletions python/mxnet/lr_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,35 @@ def __call__(self, num_update):
else:
return self.base_lr
return self.base_lr

class PolyScheduler(LRScheduler):
""" Reduce the learning rate by given a list of steps.
Calculate the new learning rate by::
base_lr * (1-nup/max_nup)^pwr
if nup < max_nup, 0 otherwise.
Parameters
----------
max_update: maximum number of updates before the decay reaches 0.
base_lr: base learning rate
pwr: power of the decay term as a funtion of the current number of updates.
"""

def __init__(self, max_update, base_lr=0.01, pwr=2):
super(PolyScheduler, self).__init__(base_lr)
assert isinstance(max_update, int)
if max_update < 1:
raise ValueError("maximum number of updates must be strictly positive")
self.base_lr_orig = self.base_lr
self.max_update = max_update
self.power = pwr
self.base_lr = self.base_lr_orig

def __call__(self, num_update):
if num_update <= self.max_update:
self.base_lr = self.base_lr_orig * pow(1.0 - float(num_update) / float(self.max_update),
self.power)
return self.base_lr
Loading

0 comments on commit 200a06a

Please sign in to comment.