Implement 1cycle learning rate policy #21258

Closed
wants to merge 23 commits into from
2 changes: 2 additions & 0 deletions docs/source/optim.rst
@@ -167,3 +167,5 @@ should write your code this way:
:members:
.. autoclass:: torch.optim.lr_scheduler.CyclicLR
:members:
.. autoclass:: torch.optim.lr_scheduler.OneCycleLR
:members:
66 changes: 63 additions & 3 deletions test/test_optim.py
@@ -13,7 +13,7 @@
from torch import sparse
from torch.optim.lr_scheduler import LambdaLR, StepLR, MultiStepLR, \
ExponentialLR, CosineAnnealingLR, ReduceLROnPlateau, _LRScheduler, \
CyclicLR, CosineAnnealingWarmRestarts
CyclicLR, CosineAnnealingWarmRestarts, OneCycleLR
from common_utils import TestCase, run_tests, TEST_WITH_UBSAN, load_tests, \
skipIfRocm

@@ -1013,6 +1013,58 @@ def test_cycle_lr_cycle_momentum_fail_with_momentumless_optimizer(self):
adam_opt = optim.Adam(self.net.parameters())
scheduler = CyclicLR(adam_opt, base_lr=1, max_lr=5, cycle_momentum=True)

def test_onecycle_lr_invalid_anneal_strategy(self):
with self.assertRaises(ValueError):
scheduler = OneCycleLR(self.opt, max_lr=1e-3, total_steps=10, anneal_strategy="CATS")

def test_onecycle_lr_invalid_pct_start(self):
with self.assertRaises(ValueError):
scheduler = OneCycleLR(self.opt, max_lr=1e-3, total_steps=10, pct_start=1.1)

def test_onecycle_lr_cannot_calculate_total_steps(self):
with self.assertRaises(ValueError):
scheduler = OneCycleLR(self.opt, max_lr=1e-3)

def test_onecycle_lr_linear_annealing(self):
lr_target = [1, 13, 25, 21.5, 18, 14.5, 11, 7.5, 4, 0.5]
momentum_target = [22, 11.5, 1, 4, 7, 10, 13, 16, 19, 22]
lr_targets = [lr_target, lr_target]
momentum_targets = [momentum_target, momentum_target]
scheduler = OneCycleLR(self.opt, max_lr=25, final_div_factor=2, base_momentum=1, max_momentum=22,
total_steps=10, anneal_strategy='linear')
self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10)
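
As a cross-check of the `lr_target` values in the linear-annealing test above, the following standalone sketch reproduces them from the arithmetic used elsewhere in this PR (initial_lr = max_lr / div_factor with div_factor defaulting to 25, min_lr = initial_lr / final_div_factor, step_size_up = pct_start * total_steps - 1); it is illustrative only and not part of the diff.

```python
# Reproduce lr_target for the linear-annealing test from the PR's arithmetic.
max_lr, div_factor, final_div_factor = 25, 25., 2.
total_steps, pct_start = 10, 0.3

initial_lr = max_lr / div_factor                 # 1.0
min_lr = initial_lr / final_div_factor           # 0.5
step_size_up = pct_start * total_steps - 1       # 2.0
step_size_down = total_steps - step_size_up - 1  # 7.0

def annealing_linear(start, end, pct):
    return (end - start) * pct + start

lrs = []
for step in range(total_steps):
    if step <= step_size_up:
        lrs.append(annealing_linear(initial_lr, max_lr, step / step_size_up))
    else:
        lrs.append(annealing_linear(max_lr, min_lr, (step - step_size_up) / step_size_down))

print(lrs)  # ~[1.0, 13.0, 25.0, 21.5, 18.0, 14.5, 11.0, 7.5, 4.0, 0.5] (up to float rounding)
```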

def test_onecycle_lr_cosine_annealing(self):
def annealing_cos(start, end, pct):
cos_out = math.cos(math.pi * pct) + 1
return end + (start - end) / 2.0 * cos_out
lr_target = [1, 13, 25, annealing_cos(25, 0.5, 1 / 7.0), annealing_cos(25, 0.5, 2 / 7.0),
annealing_cos(25, 0.5, 3 / 7.0), annealing_cos(25, 0.5, 4 / 7.0), annealing_cos(25, 0.5, 5 / 7.0),
annealing_cos(25, 0.5, 6 / 7.0), 0.5]
momentum_target = [22, 11.5, 1, annealing_cos(1, 22, 1 / 7.0), annealing_cos(1, 22, 2 / 7.0),
annealing_cos(1, 22, 3 / 7.0), annealing_cos(1, 22, 4 / 7.0), annealing_cos(1, 22, 5 / 7.0),
annealing_cos(1, 22, 6 / 7.0), 22]
lr_targets = [lr_target, lr_target]
momentum_targets = [momentum_target, momentum_target]
scheduler = OneCycleLR(self.opt, max_lr=25, final_div_factor=2, base_momentum=1, max_momentum=22,
total_steps=10)
self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10)
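
For concreteness, the first post-peak entry of the cosine `lr_target` above evaluates to roughly 23.79; a quick sanity check using the same formula as the test helper:

```python
import math

def annealing_cos(start, end, pct):
    # Same cosine interpolation as the helper in the test above.
    cos_out = math.cos(math.pi * pct) + 1
    return end + (start - end) / 2.0 * cos_out

print(annealing_cos(25, 0.5, 1 / 7.0))  # ~23.79: first lr value after the peak
```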

def test_cycle_lr_with_adam(self):
old_opt = self.opt
self.opt = optim.Adam(
[{'params': self.net.conv1.parameters()}, {'params': self.net.conv2.parameters(), 'lr': 0.5}],
lr=0.05)

lr_target = [1, 13, 25, 21.5, 18, 14.5, 11, 7.5, 4, 0.5]
momentum_target = [22, 11.5, 1, 4, 7, 10, 13, 16, 19, 22]
lr_targets = [lr_target, lr_target]
momentum_targets = [momentum_target, momentum_target]
scheduler = OneCycleLR(self.opt, max_lr=25, final_div_factor=2, base_momentum=1, max_momentum=22,
total_steps=10, anneal_strategy='linear')
self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10, use_beta1=True)
self.opt = old_opt # set optimizer back to SGD
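
Since the test above exercises the `use_beta1` path, here is a minimal usage sketch, mirroring the docstring example later in this diff, showing that with Adam the scheduler cycles beta1 (the first entry of `param_group['betas']`) in place of momentum; the model and step counts are placeholders, not taken from the PR.

```python
import torch

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.05)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, total_steps=100)

for _ in range(100):
    # ... forward/backward for one batch would go here ...
    optimizer.step()
    scheduler.step()  # updates lr and, for Adam, beta1 via param_group['betas']

beta1, beta2 = optimizer.param_groups[0]['betas']  # beta1 has followed the inverse cycle
```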

def test_lambda_lr(self):
epochs = 10
self.opt.param_groups[0]['lr'] = 0.05
@@ -1206,13 +1258,16 @@ def _test_reduce_lr_on_plateau(self, schedulers, targets, metrics, epochs=10, ve
msg='LR is wrong in epoch {}: expected {}, got {}'.format(
epoch, target[epoch], param_group['lr']), delta=1e-5)

def _test_cycle_lr(self, scheduler, lr_targets, momentum_targets, batch_iterations, verbose=False):
def _test_cycle_lr(self, scheduler, lr_targets, momentum_targets, batch_iterations, verbose=False, use_beta1=False):
for batch_num in range(batch_iterations):
scheduler.step(batch_num)
if verbose:
if 'momentum' in self.opt.param_groups[0].keys():
print('batch{}:\tlr={},momentum={}'.format(batch_num, self.opt.param_groups[0]['lr'],
self.opt.param_groups[0]['momentum']))
elif use_beta1 and 'betas' in self.opt.param_groups[0].keys():
print('batch{}:\tlr={},beta1={}'.format(batch_num, self.opt.param_groups[0]['lr'],
self.opt.param_groups[0]['betas'][0]))
else:
print('batch{}:\tlr={}'.format(batch_num, self.opt.param_groups[0]['lr']))

@@ -1222,7 +1277,12 @@ def _test_cycle_lr(self, scheduler, lr_targets, momentum_targets, batch_iteratio
msg='LR is wrong in batch_num {}: expected {}, got {}'.format(
batch_num, lr_target[batch_num], param_group['lr']), delta=1e-5)

if 'momentum' in param_group.keys():
if use_beta1 and 'betas' in param_group.keys():
self.assertAlmostEqual(
momentum_target[batch_num], param_group['betas'][0],
msg='Beta1 is wrong in batch_num {}: expected {}, got {}'.format(
batch_num, momentum_target[batch_num], param_group['betas'][0]), delta=1e-5)
elif 'momentum' in param_group.keys():
self.assertAlmostEqual(
momentum_target[batch_num], param_group['momentum'],
msg='Momentum is wrong in batch_num {}: expected {}, got {}'.format(
220 changes: 220 additions & 0 deletions torch/optim/lr_scheduler.py
@@ -759,3 +759,223 @@ def step(self, epoch=None):
self.last_epoch = math.floor(epoch)
for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
param_group['lr'] = lr

class OneCycleLR(_LRScheduler):
r"""Sets the learning rate of each parameter group according to the
1cycle learning rate policy. The 1cycle policy anneals the learning
rate from an initial learning rate to some maximum learning rate and then
from that maximum learning rate to some minimum learning rate much lower
than the initial learning rate.
This policy was initially described in the paper `Super-Convergence:
Very Fast Training of Neural Networks Using Large Learning Rates`_.

The 1cycle learning rate policy changes the learning rate after every batch.
`step` should be called after a batch has been used for training.

This scheduler is not chainable.

This class has two built-in annealing strategies:
"cos":
Cosine annealing
"linear":
Linear annealing

Note also that the total number of steps in the cycle can be determined in one
of two ways (listed in order of precedence):
1) A value for total_steps is explicitly provided.
2) A number of epochs (epochs) and a number of steps per epoch
(steps_per_epoch) are provided.
In this case, the number of total steps is inferred by
total_steps = epochs * steps_per_epoch
You must either provide a value for total_steps or provide a value for both
epochs and steps_per_epoch.

Args:
optimizer (Optimizer): Wrapped optimizer.
max_lr (float or list): Upper learning rate boundaries in the cycle
for each parameter group.
total_steps (int): The total number of steps in the cycle. Note that
if a value is not provided here, then it must be inferred by providing
a value for epochs and steps_per_epoch.
Default: None
epochs (int): The number of epochs to train for. This is used along
with steps_per_epoch in order to infer the total number of steps in the cycle
if a value for total_steps is not provided.
Default: None
steps_per_epoch (int): The number of steps per epoch to train for. This is
used along with epochs in order to infer the total number of steps in the
cycle if a value for total_steps is not provided.
Default: None
pct_start (float): The percentage of the cycle (in number of steps) spent
increasing the learning rate.
Default: 0.3
anneal_strategy (str): {'cos', 'linear'}
Specifies the annealing strategy.
Default: 'cos'
cycle_momentum (bool): If ``True``, momentum is cycled inversely
to learning rate between 'base_momentum' and 'max_momentum'.
Default: True
base_momentum (float or list): Lower momentum boundaries in the cycle
for each parameter group. Note that momentum is cycled inversely
to learning rate; at the peak of a cycle, momentum is
'base_momentum' and learning rate is 'max_lr'.
Default: 0.85
max_momentum (float or list): Upper momentum boundaries in the cycle
for each parameter group. Functionally,
it defines the cycle amplitude (max_momentum - base_momentum).
Note that momentum is cycled inversely
to learning rate; at the start of a cycle, momentum is 'max_momentum'
and learning rate is 'base_lr'
Default: 0.95
div_factor (float): Determines the initial learning rate via
initial_lr = max_lr/div_factor
Default: 25
final_div_factor (float): Determines the minimum learning rate via
min_lr = initial_lr/final_div_factor
Default: 1e4
last_epoch (int): The index of the last batch. This parameter is used when
resuming a training job. Since `step()` should be invoked after each
batch instead of after each epoch, this number represents the total
number of *batches* computed, not the total number of epochs computed.
When last_epoch=-1, the schedule is started from the beginning.
Default: -1

Example:
>>> data_loader = torch.utils.data.DataLoader(...)
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
>>> scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=len(data_loader), epochs=10)
>>> for epoch in range(10):
>>> for batch in data_loader:
>>> train_batch(...)
>>> scheduler.step()


.. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates:
https://arxiv.org/abs/1708.07120
"""
def __init__(self,
optimizer,
max_lr,
total_steps=None,
epochs=None,
steps_per_epoch=None,
pct_start=0.3,
anneal_strategy='cos',
cycle_momentum=True,
base_momentum=0.85,
max_momentum=0.95,
div_factor=25.,
final_div_factor=1e4,
last_epoch=-1):

# Validate optimizer
if not isinstance(optimizer, Optimizer):
raise TypeError('{} is not an Optimizer'.format(
type(optimizer).__name__))
self.optimizer = optimizer

# Validate total_steps
if total_steps is None and epochs is None and steps_per_epoch is None:
raise ValueError("You must define either total_steps OR (epochs AND steps_per_epoch)")
elif total_steps is not None:
if total_steps <= 0 or not isinstance(total_steps, int):
raise ValueError("Expected non-negative integer total_steps, but got {}".format(total_steps))
self.total_steps = total_steps
else:
if epochs <= 0 or not isinstance(epochs, int):
raise ValueError("Expected non-negative integer epochs, but got {}".format(epochs))
if steps_per_epoch <= 0 or not isinstance(steps_per_epoch, int):
raise ValueError("Expected non-negative integer steps_per_epoch, but got {}".format(steps_per_epoch))
self.total_steps = epochs * steps_per_epoch
self.step_size_up = float(pct_start * self.total_steps) - 1
self.step_size_down = float(self.total_steps - self.step_size_up) - 1

# Validate pct_start
if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float):
raise ValueError("Expected float between 0 and 1 pct_start, but got {}".format(pct_start))

# Validate anneal_strategy
if anneal_strategy not in ['cos', 'linear']:
raise ValueError("anneal_strategy must by one of 'cos' or 'linear', instead got {}".format(anneal_strategy))
elif anneal_strategy == 'cos':
self.anneal_func = self._annealing_cos
elif anneal_strategy == 'linear':
self.anneal_func = self._annealing_linear

# Initialize learning rate variables
max_lrs = self._format_param('max_lr', self.optimizer, max_lr)
if last_epoch == -1:
for idx, group in enumerate(self.optimizer.param_groups):
group['lr'] = max_lrs[idx] / div_factor
group['max_lr'] = max_lrs[idx]
group['min_lr'] = group['lr'] / final_div_factor

# Initialize momentum variables
self.cycle_momentum = cycle_momentum
if self.cycle_momentum:
if 'momentum' not in self.optimizer.defaults and 'betas' not in self.optimizer.defaults:
raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled')
self.use_beta1 = 'betas' in self.optimizer.defaults
max_momentums = self._format_param('max_momentum', optimizer, max_momentum)
base_momentums = self._format_param('base_momentum', optimizer, base_momentum)
if last_epoch == -1:
for m_momentum, b_momentum, group in zip(max_momentums, base_momentums, optimizer.param_groups):
if self.use_beta1:
_, beta2 = group['betas']
group['betas'] = (m_momentum, beta2)
else:
group['momentum'] = m_momentum
group['max_momentum'] = m_momentum
group['base_momentum'] = b_momentum

super(OneCycleLR, self).__init__(optimizer, last_epoch)

def _format_param(self, name, optimizer, param):
"""Return correctly formatted lr/momentum for each param group."""
if isinstance(param, (list, tuple)):
if len(param) != len(optimizer.param_groups):
raise ValueError("expected {} values for {}, got {}".format(
len(optimizer.param_groups), name, len(param)))
return param
else:
return [param] * len(optimizer.param_groups)

def _annealing_cos(self, start, end, pct):
"Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0."
cos_out = math.cos(math.pi * pct) + 1
return end + (start - end) / 2.0 * cos_out

def _annealing_linear(self, start, end, pct):
"Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0."
return (end - start) * pct + start

def get_lr(self):
lrs = []
step_num = self.last_epoch

if step_num > self.total_steps:
raise ValueError("Tried to step {} times. The specified number of total steps is {}"
.format(step_num + 1, self.total_steps))

for group in self.optimizer.param_groups:
if step_num <= self.step_size_up:
computed_lr = self.anneal_func(group['initial_lr'], group['max_lr'], step_num / self.step_size_up)
if self.cycle_momentum:
computed_momentum = self.anneal_func(group['max_momentum'], group['base_momentum'],
step_num / self.step_size_up)
else:
down_step_num = step_num - self.step_size_up
computed_lr = self.anneal_func(group['max_lr'], group['min_lr'], down_step_num / self.step_size_down)
if self.cycle_momentum:
computed_momentum = self.anneal_func(group['base_momentum'], group['max_momentum'],
down_step_num / self.step_size_down)

lrs.append(computed_lr)
if self.cycle_momentum:
if self.use_beta1:
_, beta2 = group['betas']
group['betas'] = (computed_momentum, beta2)
else:
group['momentum'] = computed_momentum

return lrs
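
To restate the bound and step-count arithmetic documented above in one place (the concrete numbers below are hypothetical):

```python
# Hypothetical numbers, restating the arithmetic from the docstring and __init__ above.
max_lr = 0.01
div_factor = 25.        # default
final_div_factor = 1e4  # default

initial_lr = max_lr / div_factor        # 4e-4: learning rate at the start of the cycle
min_lr = initial_lr / final_div_factor  # 4e-8: learning rate at the end of the cycle

epochs, steps_per_epoch = 10, 500       # e.g. 10 epochs over a 500-batch loader
total_steps = epochs * steps_per_epoch  # 5000, inferred when total_steps is None
```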