Skip to content

Commit

Permalink
test: add LinearCyclicalScheduler using use_sawtooth option test
Browse files Browse the repository at this point in the history
  • Loading branch information
sihyeong671 committed Jan 17, 2024
1 parent d091eea commit 6d10fde
Showing 1 changed file with 78 additions and 0 deletions.
78 changes: 78 additions & 0 deletions tests/ignite/handlers/test_param_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,11 @@ def test_linear_scheduler_asserts():
with pytest.raises(ValueError, match=r"Argument cycle_size should be positive and larger than 1"):
LinearCyclicalScheduler(optimizer, "lr", 1, 0, cycle_size=1)

with pytest.raises(
ValueError, match=r"can not use use_sawtooth option and warmup duration please remove one of them"
):
LinearCyclicalScheduler(optimizer, "lr", 1, 0, cycle_size=2, warmup_duration=1, use_sawtooth=True)


def test_linear_scheduler():
tensor = torch.zeros([1], requires_grad=True)
Expand Down Expand Up @@ -144,6 +149,79 @@ def save_lr(engine):
scheduler.load_state_dict(state_dict)


def test_linear_scheduler_use_sawtooth():
tensor = torch.zeros([1], requires_grad=True)
optimizer = torch.optim.SGD([tensor], lr=0)
scheduler = LinearCyclicalScheduler(optimizer, "lr", 1, 0, 10, use_sawtooth=True)
state_dict = scheduler.state_dict()

def save_lr(engine):
lrs.append(optimizer.param_groups[0]["lr"])

trainer = Engine(lambda engine, batch: None)
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
trainer.add_event_handler(Events.ITERATION_COMPLETED, save_lr)
lr_values_in_cycle = [1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]
for _ in range(2):
lrs = []
trainer.run([0] * 10, max_epochs=2)
assert lrs == pytest.approx([*lr_values_in_cycle, *lr_values_in_cycle])
scheduler.load_state_dict(state_dict)

optimizer = torch.optim.SGD([tensor], lr=0)
scheduler = LinearCyclicalScheduler(optimizer, "lr", 1, 0, 10, cycle_mult=2, use_sawtooth=True)
state_dict = scheduler.state_dict()

trainer = Engine(lambda engine, batch: None)
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
trainer.add_event_handler(Events.ITERATION_COMPLETED, save_lr)

for _ in range(2):
lrs = []
trainer.run([0] * 10, max_epochs=3)

assert lrs == list(
map(
pytest.approx,
[
# Cycle 1
1.0,
0.9,
0.8,
0.7,
0.6,
0.5,
0.4,
0.3,
0.2,
0.1,
# Cycle 2
1.0,
0.95,
0.9,
0.85,
0.8,
0.75,
0.7,
0.65,
0.6,
0.55,
0.5,
0.45,
0.4,
0.35,
0.3,
0.25,
0.2,
0.15,
0.1,
0.05,
],
)
)
scheduler.load_state_dict(state_dict)


def test_linear_scheduler_cycle_size_two():
tensor = torch.zeros([1], requires_grad=True)
optimizer = torch.optim.SGD([tensor], lr=0)
Expand Down

0 comments on commit 6d10fde

Please sign in to comment.