Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed parameter scheduler bug with CosineAnnealingWarmRestarts #2938

Merged
merged 29 commits into from May 23, 2023

Conversation

AlexanderChaptykov
Copy link
Contributor

@AlexanderChaptykov AlexanderChaptykov commented May 8, 2023

Fixes #2910

Description:

Check list:

  • New tests are added (if a new feature is added)
  • New doc strings: description and/or example code are in RST format
  • Documentation is updated (if required)

Plotting learning rates:

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

from ignite.handlers import create_lr_scheduler_with_warmup


def plot(warmup_end_value):
    lr = 0.2
    warm_steps = 5
    steps = 10
    warm_start = 0.023

    def get_optim():
        t1 = torch.zeros([1], requires_grad=True)
        return torch.optim.SGD([t1], lr=lr)

    def get_cos_shed():
        return CosineAnnealingWarmRestarts(optimizer, T_0=12, T_mult=3, verbose=False)

    optimizer = get_optim()
    scheduler = get_cos_shed()
    cosine_lrs = []
    for i in range(steps):
        cosine_lrs.append(optimizer.param_groups[0]["lr"])
        scheduler.step()

    optimizer = get_optim()
    scheduler = create_lr_scheduler_with_warmup(
        get_cos_shed(), warmup_start_value=warm_start, warmup_end_value=warmup_end_value, warmup_duration=warm_steps
    )

    warm_lrs = []
    real_warm_steps = warm_steps if warmup_end_value is not None else (warm_steps - 1)
    for epoch in range(real_warm_steps + steps):
        scheduler(None)
        warm_lrs.append(optimizer.param_groups[0]["lr"])

    if warmup_end_value is not None:
        plt.title("warmup_end_value != lr")
        plt.scatter(range(len(warm_lrs[:real_warm_steps])), warm_lrs[:real_warm_steps])
        plt.scatter(range(warm_steps, len(warm_lrs[real_warm_steps:]) + warm_steps), warm_lrs[real_warm_steps:])
        plt.show()
    else:
        plt.title("warmup_end_value == lr or warmup_end_value is None")
        plt.scatter(range(len(warm_lrs[:warm_steps])), warm_lrs[:warm_steps])
        plt.scatter(range(warm_steps, len(warm_lrs[warm_steps:]) + warm_steps), warm_lrs[warm_steps:])
        plt.show()


plot(None)
plot(.26)
image image

@github-actions github-actions bot added the module: handlers Core Handlers module label May 8, 2023
@vfdev-5 vfdev-5 changed the title Bug cosine scheduler Fixed parameter scheduler bug with CosineAnnealingWarmRestarts May 8, 2023
Copy link
Collaborator

@vfdev-5 vfdev-5 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the updates @AlexanderChaptykov
I left few suggestions on how to improve the PR

ignite/handlers/param_scheduler.py Outdated Show resolved Hide resolved
ignite/handlers/param_scheduler.py Outdated Show resolved Hide resolved
ignite/handlers/param_scheduler.py Outdated Show resolved Hide resolved
tests/ignite/handlers/test_param_scheduler.py Outdated Show resolved Hide resolved
tests/ignite/handlers/test_param_scheduler.py Outdated Show resolved Hide resolved
ignite/handlers/param_scheduler.py Outdated Show resolved Hide resolved
assert warm_lrs[warm_steps:] == cosine_lrs
else:
assert (np.linspace(warm_start, lr, warm_steps).round(3) == np.array(warm_lrs[:warm_steps]).round(3)).all()
assert warm_lrs[warm_steps - 1 : -1] == cosine_lrs
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we need this, beacuse of shifting lrs if warmup_end_value == None

@vfdev-5
Copy link
Collaborator

vfdev-5 commented May 23, 2023

Let's make the test as following:

@pytest.mark.parametrize("warmup_end_value", [0.23, None])
@pytest.mark.parametrize("T_0", [1, 12])
@pytest.mark.parametrize("T_mult", [1, 3])
def test_create_lr_scheduler_with_warmup_cosine(warmup_end_value, T_0, T_mult):
    lr = 0.2
    steps = 200
    warm_steps = 50
    warm_start = 0.023

    def get_optim():
        t1 = torch.zeros([1], requires_grad=True)
        return torch.optim.SGD([t1], lr=lr)

    def get_cos_shed():
        return CosineAnnealingWarmRestarts(optimizer, T_0=T_0, T_mult=T_mult, verbose=False)

    optimizer = get_optim()
    scheduler = get_cos_shed()
    cosine_lrs = []
    for i in range(steps):
        cosine_lrs.append(optimizer.param_groups[0]["lr"])
        scheduler.step()

    optimizer = get_optim()
    scheduler = create_lr_scheduler_with_warmup(
        get_cos_shed(), warmup_start_value=warm_start, warmup_end_value=warmup_end_value, warmup_duration=warm_steps
    )

    warm_lrs = []
    for epoch in range(warm_steps + steps):
        scheduler(None)
        warm_lrs.append(optimizer.param_groups[0]["lr"])

    if warmup_end_value is not None:
        np.testing.assert_allclose(np.linspace(warm_start, warmup_end_value, warm_steps), warm_lrs[:warm_steps])
        assert warm_lrs[warm_steps:] == cosine_lrs
    else:
        np.testing.assert_allclose(np.linspace(warm_start, lr, warm_steps), warm_lrs[:warm_steps])
        assert warm_lrs[warm_steps - 1:-1] == cosine_lrs

@vfdev-5
Copy link
Collaborator

vfdev-5 commented May 23, 2023

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

from ignite.handlers import create_lr_scheduler_with_warmup


def plot(warmup_end_value):
    lr = 0.2
    warm_steps = 5
    steps = 100
    warm_start = 0.023

    def get_optim():
        t1 = torch.zeros([1], requires_grad=True)
        return torch.optim.SGD([t1], lr=lr)

    def get_cos_shed():
        return CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, verbose=False)

    optimizer = get_optim()
    scheduler = get_cos_shed()
    cosine_lrs = []
    for i in range(steps):
        cosine_lrs.append(optimizer.param_groups[0]["lr"])
        scheduler.step()

    optimizer = get_optim()
    scheduler = create_lr_scheduler_with_warmup(
        get_cos_shed(), warmup_start_value=warm_start, warmup_end_value=warmup_end_value, warmup_duration=warm_steps
    )

    warm_lrs = []
    for epoch in range(warm_steps + steps):
        scheduler(None)
        warm_lrs.append(optimizer.param_groups[0]["lr"])

    if warmup_end_value is not None:
        plt.figure(figsize=(10, 5))
        plt.subplot(121)
        plt.title("create_lr_scheduler_with_warmup +\nCosineAnnealingWarmRestarts\nwarmup_end_value != lr")
        plt.plot(warm_lrs, "-*")
        plt.subplot(122)
        plt.title("CosineAnnealingWarmRestarts")
        plt.plot(cosine_lrs, "-*")        
        plt.show()
    else:
        plt.figure(figsize=(10, 5))
        plt.subplot(121)
        plt.title("create_lr_scheduler_with_warmup +\nCosineAnnealingWarmRestarts\nwarmup_end_value == lr")
        plt.plot(warm_lrs, "-*")
        plt.subplot(122)
        plt.title("CosineAnnealingWarmRestarts")
        plt.plot(cosine_lrs, "-*")        
        plt.show()


plot(None)
plot(.26)

image
image

Copy link
Collaborator

@vfdev-5 vfdev-5 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks @AlexanderChaptykov for working on this issue!

@vfdev-5 vfdev-5 merged commit e9e5b45 into pytorch:master May 23, 2023
13 of 18 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
module: handlers Core Handlers module
Projects
None yet
Development

Successfully merging this pull request may close these issues.

WarmRestarts seems not working with create_lr_scheduler_with_warmup function
2 participants