diff --git a/docs/source/optim.rst b/docs/source/optim.rst
index d6d89c915596..7ca7725d5d4e 100644
--- a/docs/source/optim.rst
+++ b/docs/source/optim.rst
@@ -167,3 +167,5 @@ should write your code this way:
     :members:
 .. autoclass:: torch.optim.lr_scheduler.CyclicLR
     :members:
+.. autoclass:: torch.optim.lr_scheduler.OneCycleLR
+    :members:
diff --git a/test/test_optim.py b/test/test_optim.py
index e334962c91a7..f66e53d343cc 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -13,7 +13,7 @@
 from torch import sparse
 from torch.optim.lr_scheduler import LambdaLR, StepLR, MultiStepLR, \
     ExponentialLR, CosineAnnealingLR, ReduceLROnPlateau, _LRScheduler, \
-    CyclicLR, CosineAnnealingWarmRestarts
+    CyclicLR, CosineAnnealingWarmRestarts, OneCycleLR
 from common_utils import TestCase, run_tests, TEST_WITH_UBSAN, load_tests, \
     skipIfRocm
 
@@ -1013,6 +1013,58 @@ def test_cycle_lr_cycle_momentum_fail_with_momentumless_optimizer(self):
         adam_opt = optim.Adam(self.net.parameters())
         scheduler = CyclicLR(adam_opt, base_lr=1, max_lr=5, cycle_momentum=True)
 
+    def test_onecycle_lr_invalid_anneal_strategy(self):
+        with self.assertRaises(ValueError):
+            scheduler = OneCycleLR(self.opt, max_lr=1e-3, total_steps=10, anneal_strategy="CATS")
+
+    def test_onecycle_lr_invalid_pct_start(self):
+        with self.assertRaises(ValueError):
+            scheduler = OneCycleLR(self.opt, max_lr=1e-3, total_steps=10, pct_start=1.1)
+
+    def test_onecycle_lr_cannot_calculate_total_steps(self):
+        with self.assertRaises(ValueError):
+            scheduler = OneCycleLR(self.opt, max_lr=1e-3)
+
+    def test_onecycle_lr_linear_annealing(self):
+        lr_target = [1, 13, 25, 21.5, 18, 14.5, 11, 7.5, 4, 0.5]
+        momentum_target = [22, 11.5, 1, 4, 7, 10, 13, 16, 19, 22]
+        lr_targets = [lr_target, lr_target]
+        momentum_targets = [momentum_target, momentum_target]
+        scheduler = OneCycleLR(self.opt, max_lr=25, final_div_factor=2, base_momentum=1, max_momentum=22,
+                               total_steps=10, anneal_strategy='linear')
+        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10)
+
+    def test_onecycle_lr_cosine_annealing(self):
+        def annealing_cos(start, end, pct):
+            cos_out = math.cos(math.pi * pct) + 1
+            return end + (start - end) / 2.0 * cos_out
+        lr_target = [1, 13, 25, annealing_cos(25, 0.5, 1 / 7.0), annealing_cos(25, 0.5, 2 / 7.0),
+                     annealing_cos(25, 0.5, 3 / 7.0), annealing_cos(25, 0.5, 4 / 7.0), annealing_cos(25, 0.5, 5 / 7.0),
+                     annealing_cos(25, 0.5, 6 / 7.0), 0.5]
+        momentum_target = [22, 11.5, 1, annealing_cos(1, 22, 1 / 7.0), annealing_cos(1, 22, 2 / 7.0),
+                           annealing_cos(1, 22, 3 / 7.0), annealing_cos(1, 22, 4 / 7.0), annealing_cos(1, 22, 5 / 7.0),
+                           annealing_cos(1, 22, 6 / 7.0), 22]
+        lr_targets = [lr_target, lr_target]
+        momentum_targets = [momentum_target, momentum_target]
+        scheduler = OneCycleLR(self.opt, max_lr=25, final_div_factor=2, base_momentum=1, max_momentum=22,
+                               total_steps=10)
+        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10)
+
+    def test_cycle_lr_with_adam(self):
+        old_opt = self.opt
+        self.opt = optim.Adam(
+            [{'params': self.net.conv1.parameters()}, {'params': self.net.conv2.parameters(), 'lr': 0.5}],
+            lr=0.05)
+
+        lr_target = [1, 13, 25, 21.5, 18, 14.5, 11, 7.5, 4, 0.5]
+        momentum_target = [22, 11.5, 1, 4, 7, 10, 13, 16, 19, 22]
+        lr_targets = [lr_target, lr_target]
+        momentum_targets = [momentum_target, momentum_target]
+        scheduler = OneCycleLR(self.opt, max_lr=25, final_div_factor=2, base_momentum=1, max_momentum=22,
+                               total_steps=10, anneal_strategy='linear')
+        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10, use_beta1=True)
+        self.opt = old_opt  # set optimizer back to SGD
+
     def test_lambda_lr(self):
         epochs = 10
         self.opt.param_groups[0]['lr'] = 0.05
@@ -1206,13 +1258,16 @@ def _test_reduce_lr_on_plateau(self, schedulers, targets, metrics, epochs=10, ve
                 msg='LR is wrong in epoch {}: expected {}, got {}'.format(
                     epoch, target[epoch], param_group['lr']), delta=1e-5)
 
-    def _test_cycle_lr(self, scheduler, lr_targets, momentum_targets, batch_iterations, verbose=False):
+    def _test_cycle_lr(self, scheduler, lr_targets, momentum_targets, batch_iterations, verbose=False, use_beta1=False):
         for batch_num in range(batch_iterations):
             scheduler.step(batch_num)
             if verbose:
                 if 'momentum' in self.opt.param_groups[0].keys():
                     print('batch{}:\tlr={},momentum={}'.format(batch_num, self.opt.param_groups[0]['lr'],
                                                                self.opt.param_groups[0]['momentum']))
+                elif use_beta1 and 'betas' in self.opt.param_groups[0].keys():
+                    print('batch{}:\tlr={},beta1={}'.format(batch_num, self.opt.param_groups[0]['lr'],
+                                                            self.opt.param_groups[0]['betas'][0]))
                 else:
                     print('batch{}:\tlr={}'.format(batch_num, self.opt.param_groups[0]['lr']))
 
@@ -1222,7 +1277,12 @@ def _test_cycle_lr(self, scheduler, lr_targets, momentum_targets, batch_iteratio
                     msg='LR is wrong in batch_num {}: expected {}, got {}'.format(
                         batch_num, lr_target[batch_num], param_group['lr']), delta=1e-5)
 
-            if 'momentum' in param_group.keys():
+            if use_beta1 and 'betas' in param_group.keys():
+                self.assertAlmostEqual(
+                    momentum_target[batch_num], param_group['betas'][0],
+                    msg='Beta1 is wrong in batch_num {}: expected {}, got {}'.format(
+                        batch_num, momentum_target[batch_num], param_group['betas'][0]), delta=1e-5)
+            elif 'momentum' in param_group.keys():
                 self.assertAlmostEqual(
                     momentum_target[batch_num], param_group['momentum'],
                     msg='Momentum is wrong in batch_num {}: expected {}, got {}'.format(
diff --git a/torch/optim/lr_scheduler.py b/torch/optim/lr_scheduler.py
index 6101e8ca9b03..ea18caa276f9 100644
--- a/torch/optim/lr_scheduler.py
+++ b/torch/optim/lr_scheduler.py
@@ -759,3 +759,223 @@ def step(self, epoch=None):
         self.last_epoch = math.floor(epoch)
         for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
             param_group['lr'] = lr
+
+class OneCycleLR(_LRScheduler):
+    r"""Sets the learning rate of each parameter group according to the
+    1cycle learning rate policy. The 1cycle policy anneals the learning
+    rate from an initial learning rate to some maximum learning rate and then
+    from that maximum learning rate to some minimum learning rate much lower
+    than the initial learning rate.
+    This policy was initially described in the paper `Super-Convergence:
+    Very Fast Training of Neural Networks Using Large Learning Rates`_.
+
+    The 1cycle learning rate policy changes the learning rate after every batch.
+    `step` should be called after a batch has been used for training.
+
+    This scheduler is not chainable.
+
+    This class has two built-in annealing strategies:
+    "cos":
+        Cosine annealing
+    "linear":
+        Linear annealing
+
+    Note also that the total number of steps in the cycle can be determined in one
+    of two ways (listed in order of precedence):
+    1) A value for total_steps is explicitly provided.
+    2) A number of epochs (epochs) and a number of steps per epoch
+       (steps_per_epoch) are provided.
+       In this case, the number of total steps is inferred by
+       total_steps = epochs * steps_per_epoch
+    You must either provide a value for total_steps or provide a value for both
+    epochs and steps_per_epoch.
+
+    Args:
+        optimizer (Optimizer): Wrapped optimizer.
+        max_lr (float or list): Upper learning rate boundaries in the cycle
+            for each parameter group.
+        total_steps (int): The total number of steps in the cycle. Note that
+            if a value is not provided here, then it must be inferred by providing
+            a value for epochs and steps_per_epoch.
+            Default: None
+        epochs (int): The number of epochs to train for. This is used along
+            with steps_per_epoch in order to infer the total number of steps in the cycle
+            if a value for total_steps is not provided.
+            Default: None
+        steps_per_epoch (int): The number of steps per epoch to train for. This is
+            used along with epochs in order to infer the total number of steps in the
+            cycle if a value for total_steps is not provided.
+            Default: None
+        pct_start (float): The percentage of the cycle (in number of steps) spent
+            increasing the learning rate.
+            Default: 0.3
+        anneal_strategy (str): {'cos', 'linear'}
+            Specifies the annealing strategy.
+            Default: 'cos'
+        cycle_momentum (bool): If ``True``, momentum is cycled inversely
+            to learning rate between 'base_momentum' and 'max_momentum'.
+            Default: True
+        base_momentum (float or list): Lower momentum boundaries in the cycle
+            for each parameter group. Note that momentum is cycled inversely
+            to learning rate; at the peak of a cycle, momentum is
+            'base_momentum' and learning rate is 'max_lr'.
+            Default: 0.85
+        max_momentum (float or list): Upper momentum boundaries in the cycle
+            for each parameter group. Functionally,
+            it defines the cycle amplitude (max_momentum - base_momentum).
+            Note that momentum is cycled inversely
+            to learning rate; at the start of a cycle, momentum is 'max_momentum'
+            and learning rate is 'base_lr'.
+            Default: 0.95
+        div_factor (float): Determines the initial learning rate via
+            initial_lr = max_lr/div_factor
+            Default: 25
+        final_div_factor (float): Determines the minimum learning rate via
+            min_lr = initial_lr/final_div_factor
+            Default: 1e4
+        last_epoch (int): The index of the last batch. This parameter is used when
+            resuming a training job. Since `step()` should be invoked after each
+            batch instead of after each epoch, this number represents the total
+            number of *batches* computed, not the total number of epochs computed.
+            When last_epoch=-1, the schedule is started from the beginning.
+            Default: -1
+
+    Example:
+        >>> data_loader = torch.utils.data.DataLoader(...)
+        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
+        >>> scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=len(data_loader), epochs=10)
+        >>> for epoch in range(10):
+        >>>     for batch in data_loader:
+        >>>         train_batch(...)
+        >>>         scheduler.step()
+
+
+    .. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates:
+        https://arxiv.org/abs/1708.07120
+    """
+    def __init__(self,
+                 optimizer,
+                 max_lr,
+                 total_steps=None,
+                 epochs=None,
+                 steps_per_epoch=None,
+                 pct_start=0.3,
+                 anneal_strategy='cos',
+                 cycle_momentum=True,
+                 base_momentum=0.85,
+                 max_momentum=0.95,
+                 div_factor=25.,
+                 final_div_factor=1e4,
+                 last_epoch=-1):
+
+        # Validate optimizer
+        if not isinstance(optimizer, Optimizer):
+            raise TypeError('{} is not an Optimizer'.format(
+                type(optimizer).__name__))
+        self.optimizer = optimizer
+
+        # Validate total_steps
+        if total_steps is None and epochs is None and steps_per_epoch is None:
+            raise ValueError("You must define either total_steps OR (epochs AND steps_per_epoch)")
+        elif total_steps is not None:
+            if total_steps <= 0 or not isinstance(total_steps, int):
+                raise ValueError("Expected positive integer total_steps, but got {}".format(total_steps))
+            self.total_steps = total_steps
+        else:
+            if epochs <= 0 or not isinstance(epochs, int):
+                raise ValueError("Expected positive integer epochs, but got {}".format(epochs))
+            if steps_per_epoch <= 0 or not isinstance(steps_per_epoch, int):
+                raise ValueError("Expected positive integer steps_per_epoch, but got {}".format(steps_per_epoch))
+            self.total_steps = epochs * steps_per_epoch
+        self.step_size_up = float(pct_start * self.total_steps) - 1
+        self.step_size_down = float(self.total_steps - self.step_size_up) - 1
+
+        # Validate pct_start
+        if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float):
+            raise ValueError("Expected float between 0 and 1 for pct_start, but got {}".format(pct_start))
+
+        # Validate anneal_strategy
+        if anneal_strategy not in ['cos', 'linear']:
+            raise ValueError("anneal_strategy must be one of 'cos' or 'linear', instead got {}".format(anneal_strategy))
+        elif anneal_strategy == 'cos':
+            self.anneal_func = self._annealing_cos
+        elif anneal_strategy == 'linear':
+            self.anneal_func = self._annealing_linear
+
+        # Initialize learning rate variables
+        max_lrs = self._format_param('max_lr', self.optimizer, max_lr)
+        if last_epoch == -1:
+            for idx, group in enumerate(self.optimizer.param_groups):
+                group['lr'] = max_lrs[idx] / div_factor
+                group['max_lr'] = max_lrs[idx]
+                group['min_lr'] = group['lr'] / final_div_factor
+
+        # Initialize momentum variables
+        self.cycle_momentum = cycle_momentum
+        if self.cycle_momentum:
+            if 'momentum' not in self.optimizer.defaults and 'betas' not in self.optimizer.defaults:
+                raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled')
+            self.use_beta1 = 'betas' in self.optimizer.defaults
+            max_momentums = self._format_param('max_momentum', optimizer, max_momentum)
+            base_momentums = self._format_param('base_momentum', optimizer, base_momentum)
+            if last_epoch == -1:
+                for m_momentum, b_momentum, group in zip(max_momentums, base_momentums, optimizer.param_groups):
+                    if self.use_beta1:
+                        _, beta2 = group['betas']
+                        group['betas'] = (m_momentum, beta2)
+                    else:
+                        group['momentum'] = m_momentum
+                    group['max_momentum'] = m_momentum
+                    group['base_momentum'] = b_momentum
+
+        super(OneCycleLR, self).__init__(optimizer, last_epoch)
+
+    def _format_param(self, name, optimizer, param):
+        """Return correctly formatted lr/momentum for each param group."""
+        if isinstance(param, (list, tuple)):
+            if len(param) != len(optimizer.param_groups):
+                raise ValueError("expected {} values for {}, got {}".format(
+                    len(optimizer.param_groups), name, len(param)))
+            return param
+        else:
+            return [param] * len(optimizer.param_groups)
+
+    def _annealing_cos(self, start, end, pct):
+        "Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0."
+        cos_out = math.cos(math.pi * pct) + 1
+        return end + (start - end) / 2.0 * cos_out
+
+    def _annealing_linear(self, start, end, pct):
+        "Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0."
+        return (end - start) * pct + start
+
+    def get_lr(self):
+        lrs = []
+        step_num = self.last_epoch
+
+        if step_num > self.total_steps:
+            raise ValueError("Tried to step {} times. The specified number of total steps is {}"
+                             .format(step_num + 1, self.total_steps))
+
+        for group in self.optimizer.param_groups:
+            if step_num <= self.step_size_up:
+                computed_lr = self.anneal_func(group['initial_lr'], group['max_lr'], step_num / self.step_size_up)
+                if self.cycle_momentum:
+                    computed_momentum = self.anneal_func(group['max_momentum'], group['base_momentum'],
+                                                         step_num / self.step_size_up)
+            else:
+                down_step_num = step_num - self.step_size_up
+                computed_lr = self.anneal_func(group['max_lr'], group['min_lr'], down_step_num / self.step_size_down)
+                if self.cycle_momentum:
+                    computed_momentum = self.anneal_func(group['base_momentum'], group['max_momentum'],
+                                                         down_step_num / self.step_size_down)
+
+            lrs.append(computed_lr)
+            if self.cycle_momentum:
+                if self.use_beta1:
+                    _, beta2 = group['betas']
+                    group['betas'] = (computed_momentum, beta2)
+                else:
+                    group['momentum'] = computed_momentum
+
+        return lrs
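
The lr targets asserted in test_onecycle_lr_linear_annealing follow directly from the annealing math above: with max_lr=25 and the default div_factor=25 the initial lr is 1, final_div_factor=2 gives min_lr=0.5, and with total_steps=10 and the default pct_start=0.3 the warm-up spans step_size_up = 0.3 * 10 - 1 = 2 steps. The following standalone sketch reproduces those targets outside the test harness; the helper names (annealing_linear, annealing_cos) are illustrative and not part of the patch.

    # Standalone sketch of the two annealing curves (_annealing_cos / _annealing_linear).
    import math

    def annealing_cos(start, end, pct):
        # Cosine interpolation from `start` to `end` as `pct` goes 0 -> 1.
        cos_out = math.cos(math.pi * pct) + 1
        return end + (start - end) / 2.0 * cos_out

    def annealing_linear(start, end, pct):
        # Linear interpolation from `start` to `end` as `pct` goes 0 -> 1.
        return (end - start) * pct + start

    # Reproduce the expected linear schedule: initial_lr=1, max_lr=25, min_lr=0.5.
    step_size_up = 0.3 * 10 - 1                 # 2 steps spent increasing the lr
    step_size_down = 10 - step_size_up - 1      # 7 steps spent decreasing it
    for step in range(10):
        if step <= step_size_up:
            lr = annealing_linear(1, 25, step / step_size_up)
        else:
            lr = annealing_linear(25, 0.5, (step - step_size_up) / step_size_down)
        print(step, round(lr, 2))               # 1, 13, 25, 21.5, 18, ..., 0.5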
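Since total_steps can also be inferred as epochs * steps_per_epoch, a typical driving loop calls scheduler.step() once per batch, as in the docstring example. A minimal sketch, assuming this patch is applied; the model and the loop body are stand-ins, only the scheduler wiring comes from the patch.

    import torch
    from torch.optim.lr_scheduler import OneCycleLR

    model = torch.nn.Linear(10, 2)                      # stand-in model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    epochs, steps_per_epoch = 10, 100                   # total_steps inferred as 10 * 100
    scheduler = OneCycleLR(optimizer, max_lr=0.01,
                           epochs=epochs, steps_per_epoch=steps_per_epoch)

    for epoch in range(epochs):
        for batch in range(steps_per_epoch):
            # ... forward / backward / optimizer.step() on one batch ...
            scheduler.step()                            # lr (and momentum) updated once per batch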
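For optimizers that expose betas rather than momentum (e.g. Adam), cycle_momentum drives beta1 instead, which is what test_cycle_lr_with_adam exercises via use_beta1. A small inspection sketch, again assuming this patch is applied; parameter values are arbitrary.

    import torch
    from torch.optim.lr_scheduler import OneCycleLR

    params = [torch.nn.Parameter(torch.zeros(3))]
    adam = torch.optim.Adam(params, lr=0.05)
    sched = OneCycleLR(adam, max_lr=0.5, total_steps=10,
                       base_momentum=0.85, max_momentum=0.95)

    for step in range(10):
        # ... backward pass and adam.step() would go here ...
        sched.step()
        group = adam.param_groups[0]
        # beta1 is cycled inversely to the learning rate
        print(step, round(group['lr'], 4), round(group['betas'][0], 4))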