# Scheduling Weight Decay and Learning Rate in PyTorch

This example demonstrates scheduling multiple hyperparameters concurrently using ScheduleAnything.

We'll schedule both learning rate (cosine annealing) and weight decay (linear annealing with inverse warmup) to show how different hyperparameters can follow different curves during training.

The most important thing to notice is how simple the training loop remains. At that level of abstraction, it looks like we are just stepping a schedule as usual.

Do not expect this exercise to follow standard practices such as stashing metrics or evaluating on a held-out distribution. These were deliberately left out to keep the example simple and easy to learn from.

## Environment Setup and Imports

We use magic commands to ensure the environment is set up, then run all the needed imports. Note the canonical ScheduleAnything import pattern, `torch_schedule_anything` -> `tsa`:

```
import torch_schedule_anything as tsa
```

In [1]:
# Setup
%pip install -q transformers datasets torch-schedule-anything torch

# Imports
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets import load_dataset
from torch.optim import AdamW
from torch.utils.data import DataLoader
import torch_schedule_anything as tsa

# Type hints
from torch_schedule_anything import SynchronousSchedule
from transformers import PreTrainedTokenizer, PreTrainedModel
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler

## Configuration

For easy experimentation, we place the majority of the hyperparameters right here, though we do hardwire the dataset. For the most part, we stick to some fairly boring configurations that should be familiar boilerplate to anyone in NLP.

Training duration and details are specified in terms of number of batches, learning rate has been set to something that is known to train, and the schedules are functional.

### Schedule Overview

Scheduling with the builtins in this library generally works by specifying a number of warmup steps (in this case batches), a number of training steps, and some parameters for the warmup and final target values.

Always keep in mind that torch schedules are applied as
`value(t) = base_hyperparameter * lambda(t)`, meaning the final value is the base value times a multiplier.

The warmup target tells you what lambda will be when warmup finishes, while the final target tells you what it will be at the end of training. Largely, the various builtin curves say how we get there. In this case, we use cosine annealing for the learning rate and a linear schedule with inverse warmup for weight decay.
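As a rough illustration of the multiplier rule, the sketch below reimplements two plausible curve shapes in plain Python. The function names `lr_multiplier` and `wd_multiplier` are hypothetical, and the formulas are an assumption about the general shape (warmup followed by cosine or linear annealing), not ScheduleAnything's actual implementation. The constants mirror the values configured in this notebook.

```python
import math

# Values mirroring the notebook's configuration
BASE_LR, BASE_WD = 1e-5, 0.1
WARMUP, TOTAL = 100, 2000

def lr_multiplier(t: int, warmup_to: float = 1.0, anneal_to: float = 0.1) -> float:
    """Assumed shape: linear warmup from 0, then cosine annealing."""
    if t < WARMUP:
        return warmup_to * t / WARMUP
    progress = min(1.0, (t - WARMUP) / (TOTAL - WARMUP))
    return anneal_to + (warmup_to - anneal_to) * 0.5 * (1.0 + math.cos(math.pi * progress))

def wd_multiplier(t: int, start: float = 2.0, warmup_to: float = 1.0,
                  anneal_to: float = 0.1) -> float:
    """Assumed shape: inverse warmup (start high, fall to target), then linear annealing."""
    if t < WARMUP:
        return start + (warmup_to - start) * t / WARMUP
    progress = min(1.0, (t - WARMUP) / (TOTAL - WARMUP))
    return warmup_to + (anneal_to - warmup_to) * progress

# value(t) = base_hyperparameter * lambda(t)
for t in (0, 5, 100, 2000):
    lr = BASE_LR * lr_multiplier(t)
    wd = BASE_WD * wd_multiplier(t)
    print(f"t={t:4d} | lr={lr:.4e} | wd={wd:.4e}")
```

Under these assumptions, the values after 5 steps agree with the first logged line of the training output in this notebook (LR 5.0000e-07, WD 1.9500e-01). The shape after warmup is a guess and should be checked against the library itself.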

### Schedule Config

We adopt the view from "Understanding Decoupled and Early Weight Decay" that weight decay should get weaker as training proceeds, and additionally assume that it should weaken faster than the learning rate. This is more or less arbitrary, however, and only serves to provide some sort of default.

### Tuning and Purpose

This notebook exists primarily to demonstrate the technology, not to present a well-tuned example. It has not been tuned beyond verifying convergence, so do not treat these settings as optimal.

In [2]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' # Device
LOGGING_RATE = 5 # How frequently to print to console

# Model/Pipeline config
MODEL_NAME = "distilbert-base-uncased" # Model
MAX_LENGTH = 256 # Maximum number of tokens in sample
BATCH_SIZE = 32 # Samples per batch
TOTAL_BATCHES = 2000 # All batches used over training
WARMUP_BATCHES = 100 # Number of batches used for warmup

# The learning rate and schedule.
#
# Keep in mind when reaching warmup target,
# the actual lr value is
# lr = BASE_LR*LR_WARMUP_TARGET = 1e-5
BASE_LR = 0.00001
LR_WARMUP_TARGET = 1.0
LR_FINAL_TARGET = 0.1

# The weight decay schedule
#
# We already get some decay from decreasing
# the learning rate. We just add a bit more
# on top using an inverse warmup linear schedule
BASE_WD = 0.1
INVERSE_WARMUP_STARTING_MULTIPLIER = 2
WD_WARMUP_TARGET = 1.0
WD_FINAL_TARGET = 0.1

## Standard Boilerplate

Largely standard boilerplate here. We make a model, we make an AdamW optimizer, and we make a pipeline that loads IMDB and tokenizes it.

In [3]:
def make_model()->PreTrainedModel:
    """Load pretrained model with classification head."""
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME,
        num_labels=2
    )
    return model.to(DEVICE)

def make_dataloader()->DataLoader:
    """Load and tokenize IMDB dataset, return DataLoader."""
    dataset = load_dataset("imdb", split="train")  # Full IMDB training split
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    def tokenize(examples):
        result = tokenizer(
            examples["text"],
            truncation=True,
            max_length=MAX_LENGTH,
            padding="max_length"
        )
        result["labels"] = examples["label"]
        return result

    dataset = dataset.map(tokenize, batched=True)
    dataset = dataset.shuffle(seed=42)
    dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

    return DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

def make_optimizer(model: PreTrainedModel)->Optimizer:
    """Create optimizer with base hyperparameter values that schedules multiply against."""
    return AdamW(
        model.parameters(),
        lr=BASE_LR,
        weight_decay=BASE_WD
    )

## Schedule Factory, and the Novelty

This is where ScheduleAnything comes in. We're going to create two different schedules and bind them to two different hyperparameters.

PyTorch schedulers only work on learning rate by default. ScheduleAnything lets you bind schedulers to *any* hyperparameter using the `schedule_target` parameter.

**The pattern:**
1. Create a schedule for each hyperparameter you want to control
2. Use `schedule_target` to specify which hyperparameter each schedule controls
3. Wrap them in `SynchronousSchedule` to keep them coordinated

These are standard PyTorch `LRScheduler` objects under the hood; we're just routing them to different hyperparameters instead of having them all control 'lr'. The result is a schedule object that, as far as downstream code is concerned, still has a usable `step` method by duck typing along with all the other scheduler methods, plus some additional logging extensions if desired.

**Crucially**, this means downstream code does not need to know or care that anything unusual is happening to the schedule at all. In theory, if we do not want custom logging, it does not matter in the slightest to the main training loop.
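To make the idea concrete, here is a minimal pure-Python sketch of what routing a schedule to an arbitrary hyperparameter could look like. Everything in it is hypothetical: `KeyedSchedule`, `StepTogether`, and the stand-in optimizer are illustrative names, not ScheduleAnything's API or implementation. The only real PyTorch convention it relies on is that optimizers expose their hyperparameters through a `param_groups` list of dicts.

```python
from types import SimpleNamespace
from typing import Callable, List

class KeyedSchedule:
    """Hypothetical scheduler: writes base * multiplier(t) into one param_groups key."""

    def __init__(self, optimizer, multiplier: Callable[[int], float], key: str):
        self.optimizer = optimizer
        self.multiplier = multiplier
        self.key = key
        # Remember the base values that the multiplier is applied against
        self.base_values = [group[key] for group in optimizer.param_groups]
        self.t = 0
        self._apply()

    def _apply(self) -> None:
        for group, base in zip(self.optimizer.param_groups, self.base_values):
            group[self.key] = base * self.multiplier(self.t)

    def step(self) -> None:
        self.t += 1
        self._apply()

class StepTogether:
    """Hypothetical wrapper: one step() call advances every wrapped schedule."""

    def __init__(self, schedules: List[KeyedSchedule]):
        self.schedules = schedules

    def step(self) -> None:
        for schedule in self.schedules:
            schedule.step()

# Stand-in for a torch optimizer: only the param_groups attribute is mimicked
optimizer = SimpleNamespace(param_groups=[{"lr": 1e-5, "weight_decay": 0.1}])

schedule = StepTogether([
    KeyedSchedule(optimizer, lambda t: min(1.0, t / 100), key="lr"),
    KeyedSchedule(optimizer, lambda t: max(1.0, 2.0 - t / 100), key="weight_decay"),
])

for _ in range(50):
    schedule.step()  # the training loop only ever sees this single call

print(optimizer.param_groups[0])
```

The design point is that the training loop interacts with a single object exposing `step()`, so adding or removing a scheduled hyperparameter changes only the factory that builds the schedule.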

In [4]:
def make_schedule(optimizer: Optimizer)->SynchronousSchedule:
    """
    Create coordinated schedules for learning rate and weight decay.

    Returns a SynchronousSchedule that steps both schedules together.
    """
    # Learning rate: cosine annealing
    lr_scheduler = tsa.cosine_annealing_with_warmup(
        optimizer,
        warmup_to_value=LR_WARMUP_TARGET,
        anneal_to_value=LR_FINAL_TARGET,
        num_warmup_steps=WARMUP_BATCHES,
        num_training_steps=TOTAL_BATCHES,
        schedule_target='lr'
    )

    # Weight decay: linear with inverse warmup for heavy constraint while
    # starting up.
    wd_scheduler = tsa.linear_schedule_with_inverse_warmup(
        optimizer,
        warmup_to_value=WD_WARMUP_TARGET,
        anneal_to_value=WD_FINAL_TARGET,
        num_warmup_steps=WARMUP_BATCHES,
        num_training_steps=TOTAL_BATCHES,
        warmup_multiplier=INVERSE_WARMUP_STARTING_MULTIPLIER,
        schedule_target='weight_decay'
    )

    # Coordinate them to step together
    return tsa.SynchronousSchedule([lr_scheduler, wd_scheduler])


In [5]:
def get_schedule_statistics(schedule: LRScheduler)->str:
    """Reports the last statistics of the scheduling system"""

    # Setup output report
    output = ""

    # Handle lr
    last_lr = schedule.get_last_lr()[0]
    output += f" | LR: {last_lr:.4e}"

    # Handle wd
    if isinstance(schedule, tsa.SynchronousSchedule):
        last_wd = schedule.get_last_schedule('weight_decay')[0]
        output += f" | WD: {last_wd:.4e}"

    return output

## Train Loop
Standard PyTorch training loop as used in NLP, with the schedule stepped once per batch. The logging changes are abstracted away into `get_schedule_statistics`.

In [6]:
def train(model: PreTrainedModel,
          dataloader: DataLoader,
          optimizer: Optimizer,
          schedule: LRScheduler,
          ):
    """Train for TOTAL_BATCHES batches."""
    model.train()
    data_iter = iter(dataloader)

    for batch_idx in range(TOTAL_BATCHES):
        # Get next batch
        try:
            batch = next(data_iter)
        except StopIteration:
            data_iter = iter(dataloader)
            batch = next(data_iter)

        # Move to device
        input_ids = batch['input_ids'].to(DEVICE)
        attention_mask = batch['attention_mask'].to(DEVICE)
        labels = batch['labels'].to(DEVICE)

        # Forward pass
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Step schedules
        schedule.step()

        # Log progress
        if (batch_idx + 1) % LOGGING_RATE == 0:
            msg = (f"Batch {batch_idx+1:4d}/{TOTAL_BATCHES}"
                  f" | Loss: {loss.item():.4f}"
                  )
            msg += get_schedule_statistics(schedule)
            print(msg)
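
The try/except pattern in the loop above restarts the iterator whenever the dataloader is exhausted. A minimal sketch of the same idea as a reusable generator follows; `take_batches` is a hypothetical helper name, not part of PyTorch or ScheduleAnything. It starts a fresh iteration on each pass rather than using `itertools.cycle`, since `cycle` replays cached items and a shuffling `DataLoader` would then never reshuffle.

```python
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")

def take_batches(loader: Iterable[T], total: int) -> Iterator[T]:
    """Yield exactly `total` items, restarting `loader` whenever it runs out."""
    produced = 0
    while produced < total:
        yielded_any = False
        # Each pass creates a fresh iterator (and a reshuffle, for a DataLoader)
        for batch in loader:
            yielded_any = True
            yield batch
            produced += 1
            if produced == total:
                return
        if not yielded_any:
            raise ValueError("loader is empty")

# Usage with a plain list standing in for a DataLoader
batches = list(take_batches(["a", "b", "c"], total=7))
print(batches)
```

With such a helper, the loop header could become `for batch_idx, batch in enumerate(take_batches(dataloader, TOTAL_BATCHES)):`, leaving the rest of the loop body unchanged.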


## Putting It All Together

Create the components and train.


In [None]:
def main():
    print("Setting up model and data...")
    model = make_model()
    dataloader = make_dataloader()

    print("Creating optimizer and schedules...")
    optimizer = make_optimizer(model)
    schedule = make_schedule(optimizer)

    #print(f"Scheduling: {schedule.schedule_names}")
    print(f"Training for {TOTAL_BATCHES} batches with {WARMUP_BATCHES} warmup")
    print(f"Device: {DEVICE}\n")

    train(model, dataloader, optimizer, schedule)

    print("\nTraining complete!")
    lr = schedule.get_last_lr()[0]
    wd = schedule.get_last_schedule('weight_decay')[0]
    print(f"Final LR: {lr}")
    print(f"Final WD: {wd}")

if __name__ == '__main__':
    main()

Setting up model and data...


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Creating optimizer and schedules...
Training for 2000 batches with 100 warmup
Device: cuda

Batch    5/2000 | Loss: 0.6726 | LR: 5.0000e-07 | WD: 1.9500e-01
Batch   10/2000 | Loss: 0.7182 | LR: 1.0000e-06 | WD: 1.9000e-01
Batch   15/2000 | Loss: 0.6881 | LR: 1.5000e-06 | WD: 1.8500e-01
Batch   20/2000 | Loss: 0.7019 | LR: 2.0000e-06 | WD: 1.8000e-01
Batch   25/2000 | Loss: 0.7012 | LR: 2.5000e-06 | WD: 1.7500e-01
Batch   30/2000 | Loss: 0.7089 | LR: 3.0000e-06 | WD: 1.7000e-01
Batch   35/2000 | Loss: 0.7043 | LR: 3.5000e-06 | WD: 1.6500e-01
Batch   40/2000 | Loss: 0.7020 | LR: 4.0000e-06 | WD: 1.6000e-01
Batch   45/2000 | Loss: 0.6898 | LR: 4.5000e-06 | WD: 1.5500e-01
Batch   50/2000 | Loss: 0.6743 | LR: 5.0000e-06 | WD: 1.5000e-01
Batch   55/2000 | Loss: 0.6827 | LR: 5.5000e-06 | WD: 1.4500e-01
Batch   60/2000 | Loss: 0.6794 | LR: 6.0000e-06 | WD: 1.4000e-01
Batch   65/2000 | Loss: 0.6656 | LR: 6.5000e-06 | WD: 1.3500e-01
Batch   70/2000 | Loss: 0.6819 | LR: 7.0000e-06 | WD: 1.3000e-0