# src - scheduler

## Notebook运行提示
- 代码已拆分为多个小单元，按顺序运行即可在每一步观察输出与中间变量。
- 涉及 `Path(__file__)` 或相对路径的脚本会自动注入 `__file__` 解析逻辑，因此在 Notebook 环境下也能引用原项目资源。
- 可在每个单元下追加说明或参数试验记录，以便跟踪核心算法和数据处理步骤。


In [None]:
from transformers import get_linear_schedule_with_warmup, \
  get_cosine_schedule_with_warmup, \
  get_polynomial_decay_schedule_with_warmup, get_constant_schedule_with_warmup

In [None]:


def get_scheduler(optimizer, config, num_train_steps):
    """Create a learning-rate scheduler chosen by ``config.scheduler.scheduler_type``.

    The scheduler type string doubles as the name of the sub-config under
    ``config.scheduler`` that holds that scheduler's settings
    (e.g. ``config.scheduler.cosine_schedule_with_warmup.n_cycles``).

    Args:
        optimizer: Optimizer handed straight to the transformers schedule
            factory.
        config: Object exposing ``config.scheduler.scheduler_type`` plus one
            attribute per supported scheduler type carrying its settings.
        num_train_steps: Total number of training steps; passed to every
            schedule except ``constant_schedule_with_warmup``.

    Returns:
        Whatever the matching transformers ``get_*_schedule_with_warmup``
        function returns.

    Raises:
        ValueError: If ``scheduler_type`` is not one of the four supported names.
    """
    sched_cfg = config.scheduler
    kind = sched_cfg.scheduler_type

    supported = (
        'constant_schedule_with_warmup',
        'linear_schedule_with_warmup',
        'cosine_schedule_with_warmup',
        'polynomial_decay_schedule_with_warmup',
    )
    # Validate before touching any per-type sub-config so an unknown type
    # surfaces as ValueError rather than an attribute lookup failure.
    if kind not in supported:
        raise ValueError(f'Unknown scheduler: {kind}')

    # Each supported type name is also the attribute holding its settings.
    params = getattr(sched_cfg, kind)

    if kind == 'constant_schedule_with_warmup':
        # The constant schedule does not take the total step count.
        return get_constant_schedule_with_warmup(
            optimizer,
            num_warmup_steps=params.n_warmup_steps,
        )
    if kind == 'linear_schedule_with_warmup':
        return get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=params.n_warmup_steps,
            num_training_steps=num_train_steps,
        )
    if kind == 'cosine_schedule_with_warmup':
        return get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=params.n_warmup_steps,
            num_cycles=params.n_cycles,
            num_training_steps=num_train_steps,
        )

    # Only 'polynomial_decay_schedule_with_warmup' can reach this point.
    return get_polynomial_decay_schedule_with_warmup(
        optimizer,
        num_warmup_steps=params.n_warmup_steps,
        num_training_steps=num_train_steps,
        power=params.power,
        lr_end=params.min_lr,  # the config names the final LR 'min_lr'
    )