Optimizer Reinitialization

Overview

~finetuning_scheduler.fts.FinetuningScheduler (FTS) supports the initialization of new optimizers according to a user-specified fine-tuning schedule. As with Fine-Tuning Scheduler's lr scheduler reinitialization feature<lrs-reinit-overview>, one can initialize new optimizers (or reinitialize an existing one) at the beginning of one or more scheduled training phases.

Optimizer reinitialization is supported:

  • In both explicit and implicit fine-tuning schedule modes (see the Fine-Tuning Scheduler intro<motivation> for more on basic usage modes)
  • With or without concurrent lr scheduler reinitialization
  • In the context of all supported training strategies (including FSDP)
  • With FTS >= 2.0.2

We'll cover both implicit and explicit configuration modes below and provide a slightly altered version of the lr scheduler reinitialization example<advanced-fine-tuning-lr-example> that demonstrates concurrent reinitialization of optimizers and lr schedulers at different phases.

Specifying Optimizer Configurations For Specific Fine-Tuning Phases

When defining a fine-tuning schedule (see the intro<specifying schedule> for basic schedule specification), a new optimizer configuration can be applied to the existing training session at the beginning of a given phase by specifying the desired configuration in the new_optimizer key. The new_optimizer dictionary format is described in the annotated yaml schedule below and can be explored using the advanced usage example<advanced-fine-tuning-optimizer-example>.

When specifying an optimizer configuration for a given phase, the new_optimizer dictionary requires at minimum an optimizer_init dictionary containing a class_path key indicating the class of the optimizer (list of supported optimizers<supported_reinit_optimizers>) to be instantiated.

Any arguments with which you would like to initialize the specified optimizer should be specified in the init_args key of the optimizer_init dictionary.

0:
  params:
  - model.classifier.bias
  - model.classifier.weight
1:
  params:
  - model.pooler.dense.bias
  - model.pooler.dense.weight
  - model.deberta.encoder.LayerNorm.bias
  - model.deberta.encoder.LayerNorm.weight
  new_optimizer:
    optimizer_init:
      class_path: torch.optim.SGD
      init_args:
        lr: 2.0e-03
        momentum: 0.9
        weight_decay: 2.0e-06
...
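To make the class_path / init_args convention concrete, here is a minimal sketch of how such a dictionary could be resolved into an object. It is an illustration only, not FTS's actual implementation: the helper name instantiate_from_cfg is hypothetical, and a standard library class stands in for the optimizer so the sketch runs without torch installed.

```python
import importlib
from typing import Any, Dict


def instantiate_from_cfg(cfg: Dict[str, Any], *args: Any) -> Any:
    """Resolve ``class_path`` to a class and call it with ``init_args``.

    Hypothetical helper for illustration, not part of FTS.
    """
    module_name, _, class_name = cfg["class_path"].rpartition(".")
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(*args, **cfg.get("init_args", {}))


# Stand-in stdlib class so the sketch runs without torch installed.
demo_cfg = {
    "class_path": "fractions.Fraction",
    "init_args": {"numerator": 3, "denominator": 4},
}
demo_obj = instantiate_from_cfg(demo_cfg)
```

With torch available, the phase 1 optimizer_init dictionary above could be passed the same way, with the parameters to optimize supplied positionally, e.g. instantiate_from_cfg(optimizer_init, model.parameters()).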

Optionally, one can also provide an lr scheduler reinitialization directive in the same phase as an optimizer reinitialization directive. If one does not provide a new_lr_scheduler directive, the latest lr state will still be restored and wrapped around the new optimizer prior to the execution of the new phase (as with lr scheduler reinitialization):

0:
  ...
1:
  params:
  - model.pooler.dense.bias
  ...
  new_optimizer:
    optimizer_init:
      class_path: torch.optim.SGD
      init_args:
        lr: 2.0e-03
        momentum: 0.9
        weight_decay: 2.0e-06
  new_lr_scheduler:
    lr_scheduler_init:
      class_path: torch.optim.lr_scheduler.StepLR
      init_args:
        ...
    pl_lrs_cfg:
        ...
    init_pg_lrs: [2.0e-06, 2.0e-06]

All optimizer reinitialization configurations specified in the fine-tuning schedule will have their configurations sanity-checked prior to training initiation.

Note

When reinitializing optimizers, FTS does not fully simulate/evaluate all compatibility scenarios, so it is the user's responsibility to ensure compatibility between optimizer instantiations or to set ~finetuning_scheduler.fts.FinetuningScheduler.restore_best to False. For example, consider the following training scenario:

Phase 0: SGD training
Phase 1: Reinitialize the optimizer and continue training with an Adam optimizer
Phase 2: Restore best checkpoint from phase 0 (w/ `restore_best` default of `True`)

Phase 2 would fail due to incompatibility between Adam and SGD optimizer states. This issue could be avoided by either reinitializing the Adam optimizer again in phase 2 or setting ~finetuning_scheduler.fts.FinetuningScheduler.restore_best to False. [1]

Both lr scheduler and optimizer reinitialization configurations are only supported for phases >= 1. This is because for fine-tuning phase 0, the training component configurations are the ones the user initiated the training session with, usually via the configure_optimizers method of :external+pl~lightning.pytorch.core.module.LightningModule.
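The two structural rules stated so far (a new_optimizer directive needs an optimizer_init with a class_path, and reinitialization directives belong only in phases >= 1) can be sketched as a small pre-flight check. This is a hypothetical illustration operating on a schedule loaded as a plain dictionary. The function name check_reinit_directives is an assumption, and the sketch does not reproduce the sanity checks FTS itself performs.

```python
from typing import Any, Dict, List

REINIT_KEYS = ("new_optimizer", "new_lr_scheduler")


def check_reinit_directives(schedule: Dict[int, Dict[str, Any]]) -> List[str]:
    """Return human-readable problems found in a schedule (hypothetical helper)."""
    problems: List[str] = []
    for phase in sorted(schedule):
        phase_cfg = schedule[phase]
        # Reinitialization directives are only supported for phases >= 1.
        for key in REINIT_KEYS:
            if key in phase_cfg and phase < 1:
                problems.append(f"phase {phase}: {key} is only supported for phases >= 1")
        # new_optimizer requires at minimum optimizer_init with a class_path.
        new_optimizer = phase_cfg.get("new_optimizer")
        if new_optimizer is not None:
            optimizer_init = new_optimizer.get("optimizer_init")
            if not isinstance(optimizer_init, dict) or "class_path" not in optimizer_init:
                problems.append(f"phase {phase}: new_optimizer needs optimizer_init.class_path")
    return problems


valid_schedule = {
    0: {"params": ["model.classifier.bias", "model.classifier.weight"]},
    1: {
        "params": ["model.pooler.dense.bias", "model.pooler.dense.weight"],
        "new_optimizer": {
            "optimizer_init": {"class_path": "torch.optim.SGD", "init_args": {"lr": 2.0e-03}}
        },
    },
}
# Two problems: a directive in phase 0, and a missing class_path.
invalid_schedule = {0: {"params": [], "new_optimizer": {"optimizer_init": {}}}}
```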

As you can observe in the explicit mode optimizer reinitialization example<advanced-fine-tuning-optimizer-example> below, optimizers specified in different fine-tuning phases can be of differing types.

0:
  params:
  - model.classifier.bias
  - model.classifier.weight
1:
  params:
  - model.pooler.dense.bias
  - model.pooler.dense.weight
  - model.deberta.encoder.LayerNorm.bias
  - model.deberta.encoder.LayerNorm.weight
  new_optimizer:
    optimizer_init:
      class_path: torch.optim.SGD
      init_args:
        lr: 2.0e-03
        momentum: 0.9
        weight_decay: 2.0e-06
  ...
2:
  params:
  - model.deberta.encoder.rel_embeddings.weight
  - model.deberta.encoder.layer.{0,11}.(output|attention|intermediate).*
  - model.deberta.embeddings.LayerNorm.bias
  - model.deberta.embeddings.LayerNorm.weight
  new_optimizer:
    optimizer_init:
      class_path: torch.optim.AdamW
      init_args:
        weight_decay: 1.0e-05
        eps: 1.0e-07
        lr: 1.0e-05
  ...

Once an optimizer has been reinitialized, it continues to be used in subsequent phases unless it is replaced by another optimizer configuration defined in a later schedule phase.
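The carry-forward rule can be sketched as a short resolution pass over a schedule dictionary. The helper name optimizer_by_phase and the four-phase demo schedule are hypothetical and are only meant to show which optimizer class would be active in each phase.

```python
from typing import Any, Dict, List


def optimizer_by_phase(schedule: Dict[int, Dict[str, Any]], initial_class_path: str) -> List[str]:
    """Return the optimizer class path in effect for each phase, in phase order."""
    current = initial_class_path
    resolved: List[str] = []
    for phase in sorted(schedule):
        new_optimizer = schedule[phase].get("new_optimizer")
        if new_optimizer is not None:
            # A new directive replaces the optimizer from this phase onward.
            current = new_optimizer["optimizer_init"]["class_path"]
        resolved.append(current)
    return resolved


demo_schedule = {
    0: {},
    1: {"new_optimizer": {"optimizer_init": {"class_path": "torch.optim.SGD"}}},
    2: {},  # no directive: the phase 1 optimizer carries over
    3: {"new_optimizer": {"optimizer_init": {"class_path": "torch.optim.AdamW"}}},
}
active = optimizer_by_phase(demo_schedule, "torch.optim.AdamW")
```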

Optimizer Reinitialization With Generated (Implicit Mode) Fine-Tuning Schedules

One can also specify optimizer reinitialization in the context of implicit mode fine-tuning schedules. Since the fine-tuning schedule is automatically generated, the same optimizer configuration will be applied at each of the phase transitions. In implicit mode, the optimizer reconfiguration should be supplied to the ~finetuning_scheduler.fts.FinetuningScheduler.reinit_optim_cfg parameter of ~finetuning_scheduler.fts.FinetuningScheduler.

For example, configuring this dictionary via the :external+pl~lightning.pytorch.cli.LightningCLI, one could use:

model:
  ...
trainer:
  callbacks:
    - class_path: finetuning_scheduler.FinetuningScheduler
      init_args:
        reinit_optim_cfg:
          optimizer_init:
            class_path: torch.optim.AdamW
            init_args:
              weight_decay: 1.0e-05
              eps: 1.0e-07
              lr: 1.0e-05
        reinit_lr_cfg:
          lr_scheduler_init:
            class_path: torch.optim.lr_scheduler.StepLR
            ...

Note that an initial optimizer configuration should still be provided as usual (again, typically via the configure_optimizers method of :external+pl~lightning.pytorch.core.module.LightningModule). The initial optimizer configuration can differ in optimizer type and configuration from the one specified in ~finetuning_scheduler.fts.FinetuningScheduler.reinit_optim_cfg, which is applied at each phase transition. As with explicit mode, concurrent ~finetuning_scheduler.fts.FinetuningScheduler.reinit_lr_cfg configurations can also be specified in implicit mode.
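For sessions configured in Python rather than through the CLI, the same dictionary could be passed directly to the callback. The following is a minimal sketch that assumes the finetuning_scheduler package is installed and that reinit_optim_cfg is accepted as a constructor argument, as the init_args in the CLI configuration above suggest. The helper name build_fts_callback is hypothetical.

```python
from typing import Any, Dict

# Same optimizer reconfiguration as in the CLI example above.
reinit_optim_cfg: Dict[str, Any] = {
    "optimizer_init": {
        "class_path": "torch.optim.AdamW",
        "init_args": {"weight_decay": 1.0e-05, "eps": 1.0e-07, "lr": 1.0e-05},
    }
}


def build_fts_callback() -> Any:
    """Build the callback lazily so this module imports without FTS installed."""
    # Import path taken from the class_path used in the CLI configuration.
    from finetuning_scheduler import FinetuningScheduler

    return FinetuningScheduler(reinit_optim_cfg=reinit_optim_cfg)
```

The resulting object would then be supplied in the Trainer's callbacks list, alongside a LightningModule that still defines its own initial optimizer.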

Advanced Usage Examples: Explicit and Implicit Mode Concurrent Optimizer and LR Scheduler Reinitialization

Demonstration optimizer and concurrent lr scheduler reinitialization configurations for both explicit and implicit fine-tuning scheduling contexts are available under ./fts_examples/stable/config/advanced/reinit_optim_lr.

The concurrent optimizer and lr scheduler reinitialization examples use the same code and have the same dependencies as the lr scheduler reinitialization-only examples<advanced-fine-tuning-lr-example>, with the exception of requiring FTS >= 2.0.2.

The two different demo schedule configurations are composed with shared defaults (./config/fts_defaults.yaml).

# Demo concurrent optimizer and lr scheduler reinitializations...
cd ./fts_examples/stable

# with an explicitly defined fine-tuning schedule:
python fts_superglue.py fit --config config/advanced/reinit_optim_lr/fts_explicit_reinit_optim_lr.yaml

# with an implicitly defined fine-tuning schedule:
python fts_superglue.py fit --config config/advanced/reinit_optim_lr/fts_implicit_reinit_optim_lr.yaml

# with non-default `use_current_optimizer_pg_lrs` mode (and an implicit schedule):
python fts_superglue.py fit --config config/advanced/reinit_optim_lr/fts_implicit_reinit_optim_lr_use_curr.yaml

Similar to the explicitly defined lr reinitialization-only schedule example, we are using three distinct lr schedulers for three different training phases. In this case, there are also distinctly configured optimizers being used:

  • The configured phase 0<explicit-phase-0-config> in yellow uses an :external+torch~torch.optim.AdamW optimizer and :external+torch~torch.optim.lr_scheduler.LinearLR scheduler with the initial lr and optimizer defined via the shared initial optimizer configuration.
  • The configured phase 1<explicit-phase-1-config> in blue uses a :external+torch~torch.optim.SGD optimizer and :external+torch~torch.optim.lr_scheduler.StepLR scheduler, including the specified initial lr for the existing parameter groups (2.0e-06).
  • The configured phase 2<explicit-phase-2-config> in green switches back to an :external+torch~torch.optim.AdamW optimizer but uses a :external+torch~torch.optim.lr_scheduler.CosineAnnealingWarmRestarts scheduler, with an assigned initial lr for each of the parameter groups.

Because we turned on DEBUG-level logging to trace reinitializations, we observe the following in our training log upon the phase 1 optimizer reinitialization:

Epoch 8: 100%|██████████| 78/78 ...
...
Fine-Tuning Scheduler has reinitialized the optimizer as directed:
Previous optimizer state: AdamW
... (followed by parameter group config details)
New optimizer state: SGD
... (followed by parameter group initial config details, note existing lr state or user directives may subsequently override the `lr`s in this initial config)

In the implicitly defined schedule scenario, we begin using the :external+torch~torch.optim.AdamW optimizer but the :external+torch~torch.optim.SGD optimizer and :external+torch~torch.optim.lr_scheduler.StepLR lr scheduler are specified via ~finetuning_scheduler.fts.FinetuningScheduler.reinit_optim_cfg and ~finetuning_scheduler.fts.FinetuningScheduler.reinit_lr_cfg respectively. Both training components are reinitialized at each phase transition and applied to all optimizer parameter groups.

...
- class_path: finetuning_scheduler.FinetuningScheduler
  init_args:
    # note, we're not going to see great performance due
    # to the shallow depth, just demonstrating the lr scheduler
    # reinitialization behavior in implicit mode
    max_depth: 4
    restore_best: false  # disable restore_best for lr pattern clarity
    logging_level: 10  # enable DEBUG logging to trace all reinitializations
    reinit_optim_cfg:
      optimizer_init:
        class_path: torch.optim.SGD
        init_args:
          lr: 1.0e-05
          momentum: 0.9
          weight_decay: 1.0e-06
    reinit_lr_cfg:
      lr_scheduler_init:
        class_path: torch.optim.lr_scheduler.StepLR
        init_args:
          step_size: 1
          gamma: 0.7
      pl_lrs_cfg:
        interval: epoch
        frequency: 1
        name: Implicit_Reinit_LR_Scheduler
      # non-default behavior set in `fts_implicit_reinit_optim_lr_use_curr.yaml`
      use_current_optimizer_pg_lrs: true

LR log for parameter group 1 reflecting repeated reinitialization of the :external+torch~torch.optim.SGD optimizer and :external+torch~torch.optim.lr_scheduler.StepLR lr scheduler (initial target lr = 1.0e-05) at each phase transition. The behavioral impact of use_current_optimizer_pg_lrs (the final line of the configuration above) on the lr scheduler reinitializations can be clearly observed.

Note that we have disabled ~finetuning_scheduler.fts.FinetuningScheduler.restore_best in both examples for clarity of lr patterns.
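To help read the lr pattern within a single phase, the decay that StepLR applies can be written in closed form as base_lr * gamma ** (epoch // step_size). The sketch below evaluates it with the values from the implicit configuration above (lr 1.0e-05, gamma 0.7, step_size 1). It models only StepLR's own decay between reinitializations, not how FTS selects the starting lr at each phase transition. The helper name step_lr is hypothetical.

```python
from typing import List


def step_lr(base_lr: float, gamma: float, step_size: int, epoch: int) -> float:
    """Closed-form StepLR value ``epoch`` epochs after the scheduler was created."""
    return base_lr * gamma ** (epoch // step_size)


# First four epochs after a reinitialization with the implicit-mode settings.
step_lrs: List[float] = [step_lr(1.0e-05, 0.7, 1, epoch) for epoch in range(4)]
```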

Note

Optimizer reinitialization with ~finetuning_scheduler.fts.FinetuningScheduler is currently in beta.

Configuration Appendix

Effective phase 0 config defined in ./config/advanced/reinit_optim_lr/fts_explicit_reinit_optim_lr.yaml, applying defaults defined in ./config/fts_defaults.yaml

...
model:
  class_path: fts_examples.stable.fts_superglue.RteBoolqModule
  init_args:
    optimizer_init:
      class_path: torch.optim.AdamW
      init_args:
        weight_decay: 1.0e-05
        eps: 1.0e-07
        lr: 1.0e-05
    ...
    lr_scheduler_init:
      class_path: torch.optim.lr_scheduler.LinearLR
      init_args:
        start_factor: 0.1
        total_iters: 4
    pl_lrs_cfg:
      interval: epoch
      frequency: 1
      name: Explicit_Reinit_LR_Scheduler
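As a rough guide to what this phase 0 scheduler does, LinearLR scales the base lr by a factor that moves linearly from start_factor to end_factor (1.0 by default in torch) over total_iters steps and then stays constant. The sketch below evaluates that closed form with the values above. The helper name linear_lr is hypothetical, and the sketch ignores any phase transition that occurs before the warmup completes.

```python
from typing import List


def linear_lr(base_lr: float, start_factor: float, total_iters: int,
              epoch: int, end_factor: float = 1.0) -> float:
    """Closed-form LinearLR value at ``epoch`` (constant once total_iters is reached)."""
    progress = min(epoch, total_iters) / total_iters
    return base_lr * (start_factor + (end_factor - start_factor) * progress)


# Warmup from 1.0e-06 to 1.0e-05 over the first four epochs of phase 0.
warmup_lrs: List[float] = [linear_lr(1.0e-05, 0.1, 4, epoch) for epoch in range(6)]
```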

Phase 1 config, defined in our explicit schedule ./config/advanced/reinit_optim_lr/explicit_reinit_optim_lr.yaml

...
1:
  params:
  - model.pooler.dense.bias
  - model.pooler.dense.weight
  - model.deberta.encoder.LayerNorm.bias
  - model.deberta.encoder.LayerNorm.weight
  new_optimizer:
    optimizer_init:
      class_path: torch.optim.SGD
      init_args:
        lr: 1.0e-05
        momentum: 0.9
        weight_decay: 1.0e-06
  new_lr_scheduler:
    lr_scheduler_init:
      class_path: torch.optim.lr_scheduler.StepLR
      init_args:
        step_size: 1
        gamma: 0.7
    pl_lrs_cfg:
      interval: epoch
      frequency: 1
      name: Explicit_Reinit_LR_Scheduler
    init_pg_lrs: [2.0e-06, 2.0e-06]

Phase 2 config, like all non-zero phases, defined in our explicit schedule ./config/advanced/reinit_optim_lr/explicit_reinit_optim_lr.yaml

...
2:
  params:
  - model.deberta.encoder.rel_embeddings.weight
  - model.deberta.encoder.layer.{0,11}.(output|attention|intermediate).*
  - model.deberta.embeddings.LayerNorm.bias
  - model.deberta.embeddings.LayerNorm.weight
  new_optimizer:
    optimizer_init:
      class_path: torch.optim.AdamW
      init_args:
        weight_decay: 1.0e-05
        eps: 1.0e-07
        lr: 1.0e-05
  new_lr_scheduler:
    lr_scheduler_init:
      class_path: torch.optim.lr_scheduler.CosineAnnealingWarmRestarts
      init_args:
        T_0: 3
        T_mult: 2
        eta_min: 1.0e-07
    pl_lrs_cfg:
      interval: epoch
      frequency: 1
      name: Explicit_Reinit_LR_Scheduler
    init_pg_lrs: [1.0e-06, 1.0e-06, 2.0e-06, 2.0e-06]
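The T_0 and T_mult settings above determine when the cosine cycle restarts: the first cycle lasts T_0 epochs and each subsequent cycle is T_mult times longer, so with T_0 = 3 and T_mult = 2 restarts fall 3, 9 and 21 epochs after the scheduler is created. The sketch below evaluates the standard cosine annealing with warm restarts formula at integer epochs. The helper name is hypothetical, and the base lr of 2.0e-06 is taken from one of the init_pg_lrs values purely for illustration.

```python
import math
from typing import List


def cosine_warm_restart_lr(base_lr: float, eta_min: float, t_0: int,
                           t_mult: int, epoch: int) -> float:
    """Cosine annealing with warm restarts, evaluated at an integer epoch."""
    cycle_len, t_cur = t_0, epoch
    # Walk through completed cycles; each one is t_mult times longer than the last.
    while t_cur >= cycle_len:
        t_cur -= cycle_len
        cycle_len *= t_mult
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * t_cur / cycle_len)) / 2


# First ten epochs after the phase 2 scheduler is created.
phase_2_lrs: List[float] = [
    cosine_warm_restart_lr(2.0e-06, 1.0e-07, 3, 2, epoch) for epoch in range(10)
]
```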

Footnotes


  1. While FTS could theoretically cache optimizer state prior to checkpoint restoration for potentially incompatible optimizer reinitialization configurations, that functionality is not currently implemented because of the resource overhead and unnecessary complexity it would add to the default restoration path. If there is sufficient interest in the user community, that functionality may be added in the future.