"""Contains learning rate scheduler callbacks"""
import sys
# pylint: disable=unused-import
import warnings
import numpy as np
import torch
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim.lr_scheduler import CyclicLR
from torch.optim.lr_scheduler import ExponentialLR
from torch.optim.lr_scheduler import LambdaLR
from torch.optim.lr_scheduler import MultiStepLR
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.optim.lr_scheduler import StepLR
from torch.optim.optimizer import Optimizer
from skorch.callbacks import Callback
__all__ = ['LRScheduler', 'WarmRestartLR']


def _check_lr(name, optimizer, lr):
    """Return one learning rate for each param group."""
    n = len(optimizer.param_groups)
    if not isinstance(lr, (list, tuple)):
        return lr * np.ones(n)

    if len(lr) != n:
        raise ValueError("{} lr values were passed for {} but there are "
                         "{} param groups.".format(len(lr), name, n))
    return np.array(lr)
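

# A quick illustration of ``_check_lr``'s broadcasting; the optimizer below
# is a hypothetical one with two param groups, not part of this module:
#
#   p1, p2 = torch.ones(1), torch.ones(1)
#   opt = torch.optim.SGD([{'params': [p1]}, {'params': [p2]}], lr=0.1)
#   _check_lr('max_lr', opt, 0.05)         # -> array([0.05, 0.05])
#   _check_lr('max_lr', opt, [0.05, 0.1])  # -> array([0.05, 0.1 ])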


class LRScheduler(Callback):
    """Callback that sets the learning rate of each
    parameter group according to some policy.

    Parameters
    ----------
    policy : str or _LRScheduler class (default='WarmRestartLR')
      Learning rate policy name or scheduler to be used.

    monitor : str or callable (default='train_loss')
      Value of the history to monitor or function/callable. In
      the latter case, the callable receives the net instance as
      argument and is expected to return the score (float) used to
      determine the learning rate adjustment.

    event_name : str (default='event_lr')
      Name of event to be placed in history when the scheduler takes a step.
      Pass ``None`` to disable placing events in history.
      **Note:** This feature works only for pytorch version >=1.4

    step_every : str (default='epoch')
      Value for when to apply the learning rate scheduler step. Can be
      either 'batch' or 'epoch'.

    kwargs
      Additional arguments passed to the lr scheduler.
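
    Examples
    --------
    A minimal usage sketch; ``MyModule`` and the data ``X``, ``y`` are
    illustrative stand-ins, not part of this module::

        from torch.optim.lr_scheduler import StepLR
        from skorch import NeuralNetClassifier
        from skorch.callbacks import LRScheduler

        scheduler = LRScheduler(policy=StepLR, step_size=10, gamma=0.5)
        net = NeuralNetClassifier(MyModule, callbacks=[scheduler])
        net.fit(X, y)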
"""
def __init__(self,
policy='WarmRestartLR',
monitor='train_loss',
event_name="event_lr",
step_every='epoch',
**kwargs):
self.policy = policy
self.monitor = monitor
self.event_name = event_name
self.step_every = step_every
vars(self).update(kwargs)

    def simulate(self, steps, initial_lr):
        """
        Simulates the learning rate scheduler.

        Parameters
        ----------
        steps: int
          Number of steps to simulate

        initial_lr: float
          Initial learning rate

        Returns
        -------
        lrs: numpy ndarray
          Simulated learning rates
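
        Examples
        --------
        A small illustration using ``StepLR``; the commented output is what
        these calls produce::

            scheduler = LRScheduler(policy='StepLR', step_size=2, gamma=0.1)
            scheduler.simulate(steps=5, initial_lr=1.0)
            # array([1.  , 1.  , 0.1 , 0.1 , 0.01])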
"""
test = torch.ones(1, requires_grad=True)
opt = torch.optim.SGD([{'params': test, 'lr': initial_lr}])
policy_cls = self._get_policy_cls()
sch = policy_cls(opt, **self.kwargs)
lrs = []
for _ in range(steps):
opt.step() # suppress warning about .step call order
lrs.append(opt.param_groups[0]['lr'])
sch.step()
return np.array(lrs)

    def initialize(self):
        self.policy_ = self._get_policy_cls()
        self.lr_scheduler_ = None
        self.batch_idx_ = 0
        return self

    def _get_policy_cls(self):
        if isinstance(self.policy, str):
            return getattr(sys.modules[__name__], self.policy)
        return self.policy

    @property
    def kwargs(self):
        # These are the parameters that are passed to the
        # scheduler. Parameters that don't belong there must be
        # excluded.
        excluded = ('policy', 'monitor', 'event_name', 'step_every')
        kwargs = {key: val for key, val in vars(self).items()
                  if not (key in excluded or key.endswith('_'))}
        return kwargs

    def on_train_begin(self, net, **kwargs):
        if net.history:
            # Training is being resumed: restore the batch counter
            # from the existing history.
            try:
                self.batch_idx_ = sum(net.history[:, 'train_batch_count'])
            except KeyError:
                self.batch_idx_ = sum(len(b) for b in net.history[:, 'batches'])
        self.lr_scheduler_ = self._get_scheduler(
            net, self.policy_, **self.kwargs
        )

    def _step(self, net, lr_scheduler, score=None):
        """Helper method to step the lr scheduler.

        This takes care of two things:

        1. If the lr scheduler is ReduceLROnPlateau, we need to pass the score.
        2. If the net uses AccelerateMixin, stepping has to be skipped in
           certain conditions.

        For more info on the latter, see:
        https://huggingface.co/docs/accelerate/quicktour#mixed-precision-training

        """
        accelerator_maybe = getattr(net, 'accelerator', None)
        accelerator_step_skipped = (
            accelerator_maybe and accelerator_maybe.optimizer_step_was_skipped
        )
        if accelerator_step_skipped:
            return

        if score is None:
            lr_scheduler.step()
        else:
            lr_scheduler.step(score)

    def on_epoch_end(self, net, **kwargs):
        if self.step_every != 'epoch':
            return
        if isinstance(self.lr_scheduler_, ReduceLROnPlateau):
            if callable(self.monitor):
                score = self.monitor(net)
            else:
                try:
                    score = net.history[-1, self.monitor]
                except KeyError as e:
                    raise ValueError(
                        f"'{self.monitor}' was not found in history. A "
                        f"Scoring callback with name='{self.monitor}' "
                        "should be placed before the LRScheduler callback"
                    ) from e

            self._step(net, self.lr_scheduler_, score=score)
            # ReduceLROnPlateau does not expose the current lr so it can't be recorded
        else:
            if (
                    (self.event_name is not None)
                    and hasattr(self.lr_scheduler_, "get_last_lr")
            ):
                net.history.record(self.event_name,
                                   self.lr_scheduler_.get_last_lr()[0])
            self._step(net, self.lr_scheduler_)

    def on_batch_end(self, net, training, **kwargs):
        if not training or self.step_every != 'batch':
            return
        if (
                (self.event_name is not None)
                and hasattr(self.lr_scheduler_, "get_last_lr")
        ):
            net.history.record_batch(
                self.event_name, self.lr_scheduler_.get_last_lr()[0])
        self._step(net, self.lr_scheduler_)
        self.batch_idx_ += 1

    def _get_scheduler(self, net, policy, **scheduler_kwargs):
        """Return scheduler, based on indicated policy, with appropriate
        parameters.
        """
        if (
                (policy not in [ReduceLROnPlateau])
                and ('last_epoch' not in scheduler_kwargs)
        ):
            # Pick up where a previous fit left off, e.g. when warm starting;
            # for a fresh net, len(net.history) - 1 == -1, the pytorch default.
            last_epoch = len(net.history) - 1
            scheduler_kwargs['last_epoch'] = last_epoch

        return policy(net.optimizer_, **scheduler_kwargs)


class WarmRestartLR(_LRScheduler):
    """Stochastic Gradient Descent with Warm Restarts (SGDR) scheduler.

    This scheduler sets the learning rate of each parameter group
    according to stochastic gradient descent with warm restarts (SGDR)
    policy. This policy simulates periodic warm restarts of SGD, where
    in each restart the learning rate is initialized to some value and is
    scheduled to decrease.

    Parameters
    ----------
    optimizer : torch.optimizer.Optimizer instance.
      Optimizer algorithm.

    min_lr : float or list of float (default=1e-6)
      Minimum allowed learning rate during each period for all
      param groups (float) or each group (list).

    max_lr : float or list of float (default=0.05)
      Maximum allowed learning rate during each period for all
      param groups (float) or each group (list).

    base_period : int (default=10)
      Initial restart period to be multiplied at each restart.

    period_mult : int (default=2)
      Multiplicative factor to increase the period between restarts.

    last_epoch : int (default=-1)
      The index of the last valid epoch.

    References
    ----------
    .. [1] Ilya Loshchilov and Frank Hutter, 2017, "SGDR: Stochastic
        Gradient Descent with Warm Restarts", ICLR.
        `<https://arxiv.org/pdf/1608.03983.pdf>`_
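
    Examples
    --------
    A minimal standalone sketch; ``model`` and ``train_one_epoch`` are
    illustrative stand-ins, not part of this module::

        optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
        scheduler = WarmRestartLR(optimizer, base_period=10, period_mult=2)
        for epoch in range(70):
            train_one_epoch(model, optimizer)
            scheduler.step()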
"""
def __init__(
self, optimizer,
min_lr=1e-6,
max_lr=0.05,
base_period=10,
period_mult=2,
last_epoch=-1
):
self.min_lr = _check_lr('min_lr', optimizer, min_lr)
self.max_lr = _check_lr('max_lr', optimizer, max_lr)
self.base_period = base_period
self.period_mult = period_mult
super(WarmRestartLR, self).__init__(optimizer, last_epoch)
def _get_current_lr(self, min_lr, max_lr, period, epoch):
return min_lr + 0.5 * (max_lr - min_lr) * (
1 + np.cos(epoch * np.pi / period))
def get_lr(self):
epoch_idx = float(self.last_epoch)
current_period = float(self.base_period)
while epoch_idx / current_period > 1.0:
epoch_idx -= current_period + 1
current_period *= self.period_mult
current_lrs = self._get_current_lr(
self.min_lr,
self.max_lr,
current_period,
epoch_idx
)
return current_lrs.tolist()
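

# Worked example of the restart bookkeeping in ``get_lr`` above: with
# base_period=10 and period_mult=2, the period lengths are 10, 20, 40, ...
# Epochs 0-10 fall into the first period, epochs 11-31 into the second,
# and epochs 32-72 into the third; within each period the learning rate
# follows a cosine curve from max_lr down to min_lr.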