Merge pull request #346 from sony/fix/20190205-mgpu-update-interval
Fixed learning rate update interval when using multiple GPUs in CLI
TE-HidehoGomi committed Feb 7, 2019
2 parents 588b253 + 59e9b3d commit 45e80bd
Showing 1 changed file with 6 additions and 13 deletions.
19 changes: 6 additions & 13 deletions python/src/nnabla/utils/load.py
@@ -452,41 +452,34 @@ class Optimizer:
             optimizer.solver.set_states_from_protobuf(o)
 
         optimizer.comm = current_communicator()
+        comm_size = optimizer.comm.size if optimizer.comm else 1
         optimizer.scheduler = None
         if o.solver.lr_scheduler_type == 'Polynomial':
             if o.solver.polynomial_scheduler_param.power != 0.0:
                 optimizer.scheduler = PolynomialScheduler(
-                    init_lr, o.solver.polynomial_scheduler_param.max_iter // (optimizer.comm.size if optimizer.comm else 1), o.solver.polynomial_scheduler_param.power)
+                    init_lr, o.solver.polynomial_scheduler_param.max_iter // comm_size, o.solver.polynomial_scheduler_param.power)
         elif o.solver.lr_scheduler_type == 'Cosine':
             optimizer.scheduler = CosineScheduler(
-                init_lr, o.solver.cosine_scheduler_param.max_iter // (optimizer.comm.size if optimizer.comm else 1))
+                init_lr, o.solver.cosine_scheduler_param.max_iter // comm_size)
         elif o.solver.lr_scheduler_type == 'Exponential':
             if o.solver.exponential_scheduler_param.gamma != 1.0:
                 optimizer.scheduler = ExponentialScheduler(
-                    init_lr, o.solver.exponential_scheduler_param.gamma, o.solver.exponential_scheduler_param.iter_interval if o.solver.exponential_scheduler_param.iter_interval > 0 else 1)
+                    init_lr, o.solver.exponential_scheduler_param.gamma, o.solver.exponential_scheduler_param.iter_interval // comm_size if o.solver.exponential_scheduler_param.iter_interval > comm_size else 1)
         elif o.solver.lr_scheduler_type == 'Step':
             if o.solver.step_scheduler_param.gamma != 1.0 and len(o.solver.step_scheduler_param.iter_steps) > 0:
                 optimizer.scheduler = StepScheduler(
-                    init_lr, o.solver.step_scheduler_param.gamma, o.solver.step_scheduler_param.iter_steps)
+                    init_lr, o.solver.step_scheduler_param.gamma, [step // comm_size for step in o.solver.step_scheduler_param.iter_steps])
         elif o.solver.lr_scheduler_type == 'Custom':
             # ToDo
             raise NotImplementedError()
         elif o.solver.lr_scheduler_type == '':
             if o.solver.lr_decay_interval != 0 or o.solver.lr_decay != 0.0:
                 optimizer.scheduler = ExponentialScheduler(
-                    init_lr, o.solver.lr_decay if o.solver.lr_decay > 0.0 else 1.0, o.solver.lr_decay_interval if o.solver.lr_decay_interval > 0 else 1)
+                    init_lr, o.solver.lr_decay if o.solver.lr_decay > 0.0 else 1.0, o.solver.lr_decay_interval // comm_size if o.solver.lr_decay_interval > comm_size else 1)
         else:
             raise ValueError('Learning Rate Scheduler "' + o.solver.lr_scheduler_type +
                              '" is not supported.')
 
-        if optimizer.comm is not None:
-            new_interval = optimizer.lr_decay_interval // optimizer.comm.size
-            if new_interval == 0:
-                new_interval = 1
-            logger.log(99, 'LR Decay interval divide by {} ({} -> {})'.format(
-                optimizer.comm.size, optimizer.lr_decay_interval, new_interval))
-            optimizer.lr_decay_interval = new_interval
-
         optimizer.forward_sequence = optimizer.network.get_forward_sequence(
             optimizer.loss_variables)
         optimizer.backward_sequence = optimizer.network.get_backward_sequence(
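
For reference, the rule the patch applies is: every scheduler milestone expressed in iterations (max_iter, iter_interval, the Step milestones, and the lr_decay_interval fallback) is divided by the communicator size, with a floor of 1 so an interval can never become 0. Before this change, only the Polynomial and Cosine branches divided their max_iter inline, and the generic lr_decay_interval was rescaled in a separate post-processing block; the patch centralizes the division in a single comm_size value and applies it when each scheduler is constructed. The sketch below is illustrative only — the helper names and the 4-GPU figure are hypothetical, not nnabla code:

    # Minimal sketch of the rescaling rule used in this patch; the helper
    # names and the 4-GPU value below are illustrative, not nnabla API.

    def scale_interval(interval, comm_size):
        # Divide an iteration interval by the number of workers, as in the
        # Exponential and fallback branches of the diff; never return 0.
        return interval // comm_size if interval > comm_size else 1

    def scale_steps(iter_steps, comm_size):
        # Step-scheduler milestones are rescaled element-wise, as in the
        # Step branch of the diff.
        return [step // comm_size for step in iter_steps]

    if __name__ == '__main__':
        comm_size = 4                                   # e.g. 4 GPUs
        print(scale_interval(1000, comm_size))          # 250
        print(scale_interval(3, comm_size))             # 1 (clamped, avoids a zero interval)
        print(scale_steps([30000, 60000], comm_size))   # [7500, 15000]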
