<!-- ---
title: <required-title>
date: 2022-03-11
downloads: true
weight: 12
sidebar: True
summary: This example demonstrates how to use the [torch.optim.lr_scheduler](https://pytorch.org/docs/stable/optim.html#torch.optim.lr_scheduler.LambdaLR) to adjust the learning rate of a model.
tags:
  - lr scheduler
--- -->

# How to use LR-Schedulers

This how-to guide demonstrates how to use LR schedulers to adjust the optimizer's learning rate during training, combining a linear warm-up with a PyTorch scheduler.

## Basic Setup

Install Dependencies

In [1]:
%%capture
! pip install pytorch-ignite

Import Dependencies

In [2]:
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import ExponentialLR

from ignite.engine import Engine, Events
from ignite.handlers import create_lr_scheduler_with_warmup

## Create a `Dummy Trainer`

In [3]:
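def train_step(engine, batch):
    # Dummy step: no training happens, we only print the epoch, iteration and current learning rate.
    print(engine.state.epoch, engine.state.iteration, " | ", optimizer.param_groups[0]["lr"])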

In [4]:
trainer = Engine(train_step)
optimizer = optim.SGD([torch.tensor([0.1])], lr=0.1234)

## Initialize an `LRScheduler`

In [5]:
torch_lr_scheduler = ExponentialLR(optimizer=optimizer, gamma=0.5)

data = [0] * 8
epoch_length = len(data)
warmup_duration = 5
scheduler = create_lr_scheduler_with_warmup(torch_lr_scheduler,
                                            warmup_start_value=0.0,
                                            warmup_duration=warmup_duration)

## Trigger LR-Scheduler

- Step 1: Trigger the scheduler on `ITERATION_STARTED` events until the warm-up is complete.
- Step 2: Trigger the scheduler on `EPOCH_STARTED` events after the warm-up.

Note: Epochs are 1-based, so the epoch filter compares against `1 + warmup_duration / epoch_length`.
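To make the epoch threshold concrete, here is a small worked example in plain Python using the values from this guide (`epoch_length = 8`, `warmup_duration = 5`). The name `first_decay_epoch` is introduced only for illustration and is not part of Ignite.

```python
import math

# Values taken from the setup in this guide.
epoch_length = 8
warmup_duration = 5

# Epochs are 1-based, so the warm-up ends 5/8 of the way through epoch 1.
threshold = 1 + warmup_duration / epoch_length

# The epoch filter uses a strict ">", so the first epoch that passes it
# is the first whole number above the threshold.
first_decay_epoch = math.floor(threshold) + 1

print(threshold)          # 1.625
print(first_decay_epoch)  # 2
```

With these values the warm-up finishes inside the first epoch, so the epoch-level scheduler steps begin at epoch 2.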


In [6]:
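# Step 1: step the scheduler on every iteration while the warm-up is running.
combined_events = Events.ITERATION_STARTED(event_filter=lambda engine, _: engine.state.iteration <= warmup_duration)
# Step 2: after the warm-up, step the scheduler once at the start of each epoch.
combined_events |= Events.EPOCH_STARTED(event_filter=lambda engine, _: engine.state.epoch > 1 + warmup_duration / epoch_length)
trainer.add_event_handler(combined_events, scheduler)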

<ignite.engine.events.RemovableEventHandle at 0x7fb7a979bfa0>

## Execute Trainer

In [7]:
trainer.run(data, max_epochs=10)

1 1  |  0.0
1 2  |  0.03085
1 3  |  0.0617
1 4  |  0.09255
1 5  |  0.1234
1 6  |  0.1234
1 7  |  0.1234
1 8  |  0.1234
2 9  |  0.0617
2 10  |  0.0617
2 11  |  0.0617
2 12  |  0.0617
2 13  |  0.0617
2 14  |  0.0617
2 15  |  0.0617
2 16  |  0.0617
3 17  |  0.03085
3 18  |  0.03085
3 19  |  0.03085
3 20  |  0.03085
3 21  |  0.03085
3 22  |  0.03085
3 23  |  0.03085
3 24  |  0.03085
4 25  |  0.015425
4 26  |  0.015425
4 27  |  0.015425
4 28  |  0.015425
4 29  |  0.015425
4 30  |  0.015425
4 31  |  0.015425
4 32  |  0.015425
5 33  |  0.0077125
5 34  |  0.0077125
5 35  |  0.0077125
5 36  |  0.0077125
5 37  |  0.0077125
5 38  |  0.0077125
5 39  |  0.0077125
5 40  |  0.0077125
6 41  |  0.00385625
6 42  |  0.00385625
6 43  |  0.00385625
6 44  |  0.00385625
6 45  |  0.00385625
6 46  |  0.00385625
6 47  |  0.00385625
6 48  |  0.00385625
7 49  |  0.001928125
7 50  |  0.001928125
7 51  |  0.001928125
7 52  |  0.001928125
7 53  |  0.001928125
7 54  |  0.001928125
7 55  |  0.001928125
7 56  |  0.0019

State:
	iteration: 80
	epoch: 10
	epoch_length: 8
	max_epochs: 10
	output: <class 'NoneType'>
	batch: 0
	metrics: <class 'dict'>
	dataloader: <class 'list'>
	seed: <class 'NoneType'>
	times: <class 'dict'>
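
The printed values follow a simple pattern: a linear ramp from `warmup_start_value` to the optimizer's learning rate over the first `warmup_duration` iterations, then a halving at the start of each epoch from epoch 2 onward. The sketch below reproduces that pattern in plain Python. It is not how Ignite computes the schedule internally; `expected_lr` is a hypothetical helper, and it assumes the warm-up ends within the first epoch, as it does with the values in this guide.

```python
# Values taken from the setup in this guide.
base_lr = 0.1234
gamma = 0.5
warmup_start_value = 0.0
warmup_duration = 5


def expected_lr(epoch, iteration):
    """Learning rate expected for a 1-based epoch and iteration (hypothetical helper)."""
    if iteration <= warmup_duration:
        # Linear interpolation across warmup_duration points, end points included.
        fraction = (iteration - 1) / (warmup_duration - 1)
        return warmup_start_value + fraction * (base_lr - warmup_start_value)
    # Assumption: the warm-up ends inside epoch 1, so the rate is multiplied
    # by gamma once per epoch starting at epoch 2.
    return base_lr * gamma ** (epoch - 1)


# Print the expected rate at the first iteration of a few epochs.
for epoch, iteration in [(1, 1), (1, 3), (1, 8), (2, 9), (7, 49)]:
    print(epoch, iteration, " | ", expected_lr(epoch, iteration))
```

Comparing these numbers with the trainer output is a quick way to check that the event filters attach the scheduler to the intended events.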