From 59e9b3db365e691f993e5f3c3f1483422b7ed6ad Mon Sep 17 00:00:00 2001
From: Yoshiyuki Kobayashi
Date: Thu, 7 Feb 2019 16:00:32 +0900
Subject: [PATCH] Fixed learning rate update interval when using multiple GPUs
 in CLI

---
 python/src/nnabla/utils/load.py | 19 ++++++-------------
 1 file changed, 6 insertions(+), 13 deletions(-)

diff --git a/python/src/nnabla/utils/load.py b/python/src/nnabla/utils/load.py
index 5e29a8b34..f31777bae 100644
--- a/python/src/nnabla/utils/load.py
+++ b/python/src/nnabla/utils/load.py
@@ -452,41 +452,34 @@ class Optimizer:
         optimizer.solver.set_states_from_protobuf(o)
 
     optimizer.comm = current_communicator()
+    comm_size = optimizer.comm.size if optimizer.comm else 1
     optimizer.scheduler = None
     if o.solver.lr_scheduler_type == 'Polynomial':
         if o.solver.polynomial_scheduler_param.power != 0.0:
             optimizer.scheduler = PolynomialScheduler(
-                init_lr, o.solver.polynomial_scheduler_param.max_iter // (optimizer.comm.size if optimizer.comm else 1), o.solver.polynomial_scheduler_param.power)
+                init_lr, o.solver.polynomial_scheduler_param.max_iter // comm_size, o.solver.polynomial_scheduler_param.power)
     elif o.solver.lr_scheduler_type == 'Cosine':
         optimizer.scheduler = CosineScheduler(
-            init_lr, o.solver.cosine_scheduler_param.max_iter // (optimizer.comm.size if optimizer.comm else 1))
+            init_lr, o.solver.cosine_scheduler_param.max_iter // comm_size)
     elif o.solver.lr_scheduler_type == 'Exponential':
         if o.solver.exponential_scheduler_param.gamma != 1.0:
             optimizer.scheduler = ExponentialScheduler(
-                init_lr, o.solver.exponential_scheduler_param.gamma, o.solver.exponential_scheduler_param.iter_interval if o.solver.exponential_scheduler_param.iter_interval > 0 else 1)
+                init_lr, o.solver.exponential_scheduler_param.gamma, o.solver.exponential_scheduler_param.iter_interval // comm_size if o.solver.exponential_scheduler_param.iter_interval > comm_size else 1)
     elif o.solver.lr_scheduler_type == 'Step':
         if o.solver.step_scheduler_param.gamma != 1.0 and len(o.solver.step_scheduler_param.iter_steps) > 0:
             optimizer.scheduler = StepScheduler(
-                init_lr, o.solver.step_scheduler_param.gamma, o.solver.step_scheduler_param.iter_steps)
+                init_lr, o.solver.step_scheduler_param.gamma, [step // comm_size for step in o.solver.step_scheduler_param.iter_steps])
     elif o.solver.lr_scheduler_type == 'Custom':
         # ToDo
         raise NotImplementedError()
     elif o.solver.lr_scheduler_type == '':
         if o.solver.lr_decay_interval != 0 or o.solver.lr_decay != 0.0:
             optimizer.scheduler = ExponentialScheduler(
-                init_lr, o.solver.lr_decay if o.solver.lr_decay > 0.0 else 1.0, o.solver.lr_decay_interval if o.solver.lr_decay_interval > 0 else 1)
+                init_lr, o.solver.lr_decay if o.solver.lr_decay > 0.0 else 1.0, o.solver.lr_decay_interval // comm_size if o.solver.lr_decay_interval > comm_size else 1)
     else:
         raise ValueError('Learning Rate Scheduler "' + o.solver.lr_scheduler_type +
                          '" is not supported.')
 
-    if optimizer.comm is not None:
-        new_interval = optimizer.lr_decay_interval // optimizer.comm.size
-        if new_interval == 0:
-            new_interval = 1
-        logger.log(99, 'LR Decay interval divide by {} ({} -> {})'.format(
-            optimizer.comm.size, optimizer.lr_decay_interval, new_interval))
-        optimizer.lr_decay_interval = new_interval
-
     optimizer.forward_sequence = optimizer.network.get_forward_sequence(
         optimizer.loss_variables)
     optimizer.backward_sequence = optimizer.network.get_backward_sequence(
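
Note: the hunk above applies one rule to every scheduler parameter that counts
iterations: divide it by the communicator size (number of GPUs), and for decay
intervals clamp the result to at least 1 so the scheduler still fires. The
following is a minimal standalone sketch of that rule; the helper name
scale_interval and the example numbers are illustrative only and do not appear
in load.py.

    def scale_interval(interval, comm_size):
        # With data-parallel training over comm_size GPUs, each iteration
        # consumes comm_size mini-batches, so iteration-based intervals
        # shrink by the same factor. Clamp to 1 to avoid a zero interval.
        return interval // comm_size if interval > comm_size else 1

    if __name__ == '__main__':
        print(scale_interval(1000, 4))  # 250: decay 4x sooner on 4 GPUs
        print(scale_interval(3, 4))     # 1:   never rounds down to zero
        print(scale_interval(1000, 1))  # 1000: single-GPU case unchanged

Doing this scaling where the schedulers are constructed also makes the separate
post-hoc adjustment of optimizer.lr_decay_interval (the block removed at the
end of the hunk) unnecessary.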