-
Notifications
You must be signed in to change notification settings - Fork 99
/
onecycle_lr.py
42 lines (38 loc) · 1.49 KB
/
onecycle_lr.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
""" OneCycle Scheduler
"""
import logging
import math
import numpy as np
import torch
from .scheduler import Scheduler
_logger = logging.getLogger(__name__)
class OneCycleLRScheduler(Scheduler):
    """Piecewise-linear "one cycle" learning-rate schedule.

    The LR ramps linearly from 0 up to the base LR at 40% of ``t_initial``,
    decays linearly to base_lr / 20 at 80%, then decays linearly to 0 at
    ``t_initial``. Warmup is built into the ramp, so an external warmup phase
    is rejected.

    Note: most constructor arguments (``t_mul``, ``lr_min``, ``decay_rate``,
    ``warmup_lr_init``, ``warmup_prefix``, ``cycle_limit``, ``t_in_epochs``)
    are accepted only for signature compatibility with sibling schedulers and
    are currently ignored.
    """
    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 t_initial: int,
                 t_mul: float = 1.,
                 lr_min: float = 0.,
                 decay_rate: float = 1.,
                 warmup_t=0,
                 warmup_lr_init=0,
                 warmup_prefix=False,
                 cycle_limit=0,
                 t_in_epochs=True,
                 noise_range_t=None,
                 noise_pct=0.67,
                 noise_std=1.0,
                 noise_seed=42,
                 initialize=True) -> None:
        """
        Args:
            optimizer: wrapped optimizer whose ``lr`` param groups are scheduled.
            t_initial: total schedule length (epochs); must be > 0.
            warmup_t: must be 0 — this schedule has warmup built in.
            noise_*: forwarded to the base ``Scheduler`` for LR noise injection.
            initialize: forwarded to the base ``Scheduler``.

        Raises:
            AssertionError: if ``warmup_t != 0`` or ``t_initial <= 0``.
        """
        super().__init__(
            optimizer, param_group_field="lr",
            noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
            initialize=initialize)
        assert warmup_t == 0, "this schedule has warmup built in"
        assert t_initial > 0
        self.t_initial = t_initial
        # Breakpoints of the piecewise-linear cycle, computed once:
        # peak at 40% of the schedule, lr/20 at 80%, zero at the end.
        self._xp = [0, t_initial * 2 // 5, t_initial * 4 // 5, t_initial]

    def get_frac_epoch_values(self, frac_epoch: float):
        """Return per-param-group LRs at a (possibly fractional) epoch.

        ``frac_epoch`` may be a float; values are linearly interpolated
        between the cycle breakpoints for each base LR.
        """
        # np.interp accepts a scalar query directly; clamps outside [0, t_initial].
        return [float(np.interp(frac_epoch, self._xp, [0, lr_max, lr_max / 20.0, 0]))
                for lr_max in self.base_values]

    def get_epoch_values(self, epoch: int):
        """Return per-param-group LRs at an integer epoch boundary."""
        return self.get_frac_epoch_values(epoch)