# Scheduling GNTS in PyTorch

This example demonstrates an implementation of the Gradient Norm Threshold Scheduling (GNTS) algorithm that runs alongside an existing AdamW optimizer. The implementation reads a scheduled threshold value and requires the gradient norm to fall to or below that threshold before taking an optimizer step.

## Environment Setup and Imports

We use magic commands to ensure the environment is set up, then run all the needed imports. Note the canonical ScheduleAnything import pattern, `torch-schedule-anything` -> `tsa`:

```
import torch_schedule_anything as tsa
```

In [None]:
# Setup
%pip install -q transformers datasets torch-schedule-anything torch

# Imports
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets import load_dataset
from torch.optim import AdamW
from torch.utils.data import DataLoader
import torch_schedule_anything as tsa

# Type hints
from torch_schedule_anything import SynchronousSchedule
from transformers import PreTrainedTokenizer, PreTrainedModel
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler

## Configuration

For easy experimentation, we place the majority of the hyperparameters right here, though we do hardwire the dataset. For the most part, we stick to some fairly boring configurations that should be familiar boilerplate to anyone in NLP.

Training duration and warmup are specified as a number of batches, the learning rate is set to a value that is known to train, and the schedules are functional rather than tuned.

### Schedule Overview

Scheduling with the builtins in this library generally works by specifying a number of warmup steps (in this case batches), a number of training steps, and some parameters describing the warmup and final target values.

Always keep in mind that torch schedules are applied as `value(t) = base_hyperparameter * lambda(t)`, meaning the final value is the base value times a multiplier.

The warmup target tells you what lambda will be when warmup finishes, while the final target tells you what it will be at the end of training. The various builtin curves largely determine how we get there. In this case, the learning rate is held constant after warmup, the weight decay follows a cosine annealing curve, and the gradient norm threshold follows a cosine annealing curve with an inverse warmup.
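
As a rough illustration of the multiplier rule, the sketch below computes a warmup-plus-cosine multiplier in plain Python and applies it to a base value. The function names and the linear warmup from zero are assumptions made for illustration; the library's own curves may differ in detail. The numbers reuse the weight decay settings from the configuration in this example.

```python
import math


def cosine_with_warmup_multiplier(
    step: int,
    num_warmup_steps: int,
    num_training_steps: int,
    warmup_to_value: float,
    anneal_to_value: float,
) -> float:
    """Hypothetical multiplier lambda(t): linear warmup, then cosine annealing."""
    if step < num_warmup_steps:
        # Assumption: warmup rises linearly from 0 to warmup_to_value.
        return warmup_to_value * step / num_warmup_steps
    # Fraction of the annealing phase completed, clamped to [0, 1].
    span = max(1, num_training_steps - num_warmup_steps)
    progress = min(1.0, (step - num_warmup_steps) / span)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return anneal_to_value + (warmup_to_value - anneal_to_value) * cosine


def scheduled_value(base: float, step: int) -> float:
    """value(t) = base_hyperparameter * lambda(t)."""
    return base * cosine_with_warmup_multiplier(step, 4000, 30000, 1.0, 0.01)


# Base weight decay of 0.01, sampled at the start, mid-warmup,
# end of warmup, and end of training.
for t in (0, 2000, 4000, 30000):
    print(t, scheduled_value(0.01, t))
```

The multiplier reaches `warmup_to_value` exactly when warmup ends and `anneal_to_value` at the last training step, so the base value set on the optimizer is the only place the absolute scale lives.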

### Schedule Config

We are going to schedule the logical batch size. This is largely inspired by Smith's work, but does not use his exact algorithm, as this is simply a demonstration.
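
To make "logical batch size" concrete: it is the number of micro-batches whose gradients are accumulated before the optimizer steps, times the per-batch sample count. The sketch below is a toy, pure-Python illustration with hand-picked 2-D gradients. `accumulate_until_below` is a hypothetical helper, not part of the library and not Smith's algorithm.

```python
import math
from typing import List, Sequence, Tuple


def accumulate_until_below(
    micro_batch_grads: Sequence[Sequence[float]],
    threshold: float,
) -> Tuple[int, float]:
    """Accumulate gradients until the norm of their mean is at or below threshold.

    Returns (micro-batches used, norm of the mean gradient). If the threshold
    is never reached, all supplied gradients are used.
    """
    total: List[float] = [0.0] * len(micro_batch_grads[0])
    mean_norm = float("inf")
    steps = 0
    for grad in micro_batch_grads:
        steps += 1
        total = [t + g for t, g in zip(total, grad)]
        # Norm of the summed gradient divided by the count is the norm of the mean.
        mean_norm = math.sqrt(sum(t * t for t in total)) / steps
        if mean_norm <= threshold:
            break
    return steps, mean_norm


# Toy gradients: a shared signal of (0.5, 0.0) plus noise that cancels on average.
grads = [(2.5, 1.0), (-1.5, 1.0), (0.5, -3.0), (0.5, 1.0)]
steps, norm = accumulate_until_below(grads, threshold=0.55)
print(f"stepped after {steps} micro-batches; logical batch size = {8 * steps}")
```

A lower threshold forces more accumulation before a step, which is why scheduling the threshold downward acts like scheduling the batch size upward.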

### Tuning and Purpose

This notebook exists primarily to demonstrate the technology, not to present a well-tuned example. It has not been tuned beyond verifying convergence, so do not treat these settings as optimal.

In [None]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' # Device
LOGGING_RATE = 50 # How frequently to print to console

# Model/Pipeline config
MODEL_NAME = "distilbert-base-uncased" # Model
MAX_LENGTH = 256 # Maximum number of tokens in sample
BATCH_SIZE = 8 # Samples per batch
TOTAL_BATCHES = 30000 # All batches used over training
WARMUP_BATCHES = 4000 # Number of batches used for warmup

# The learning rate/weight decay/norm base details
BASE_LR = 6e-5
BASE_WD = 0.01
BASE_NORM = 1.0

ANNEALING_START_SCHEDULE = 1.0
ANNEALING_END_SCHEDULE = 0.01

# The threshold annealing instead proceeds as...

WARMUP_MULTIPLIER = 20.0
THRESHOLD_START_SCHEDULE = 0.95
THRESHOLD_END_SCHEDULE = 0.25

## Standard Boilerplate

Largely standard boilerplate here. We make a model, an AdamW optimizer, and a pipeline that loads IMDB and tokenizes it.

In [None]:
def make_model()->PreTrainedModel:
    """Load pretrained model with classification head."""
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME,
        num_labels=2
    )
    return model.to(DEVICE)

def make_dataloader()->DataLoader:
    """Load and tokenize IMDB dataset, return DataLoader."""
    dataset = load_dataset("imdb", split="train")  # Full training split
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    def tokenize(examples):
        result = tokenizer(
            examples["text"],
            truncation=True,
            max_length=MAX_LENGTH,
            padding="max_length"
        )
        result["labels"] = examples["label"]
        return result

    dataset = dataset.map(tokenize, batched=True)
    dataset = dataset.shuffle(seed=42)
    dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

    return DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

def make_optimizer(model: PreTrainedModel)->Optimizer:
    """Create optimizer with base hyperparameter values that schedules multiply against."""
    return AdamW(
        model.parameters(),
        lr=BASE_LR,
        weight_decay=BASE_WD
    )

## Schedule Factory, and the Novelty

This is where ScheduleAnything comes in. We're going to bind a new field to the optimizer, bind three schedules (learning rate, weight decay, and gradient norm threshold), and define a helper that can tell when it is time to step.

**The pattern:**
1. Add optimizer fields as needed
2. Create a schedule for each hyperparameter you want to control
3. Use `schedule_target` to specify which hyperparameter each schedule controls
4. Wrap them in `SynchronousSchedule` to keep them coordinated
5. In the same place, define utilities built on tsa that read your extra hyperparameters, and invoke them in the training loop.

**Crucially**, this means downstream code can access the extra hyperparameters through well-abstracted utilities, maintaining separation of concerns.
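
PyTorch optimizers keep their hyperparameters in `optimizer.param_groups`, a list of dictionaries, which is what makes step 1 possible. The sketch below imitates that structure with plain dictionaries so it runs without any dependencies. `add_field` and `max_field` are hypothetical helpers that only illustrate the idea; they do not reflect how the library implements `extend_optimizer` or its lookup utilities.

```python
from typing import Any, Dict, List

ParamGroup = Dict[str, Any]


def add_field(param_groups: List[ParamGroup], name: str, default_value: float) -> None:
    """Add a new hyperparameter to every group that does not already define it."""
    for group in param_groups:
        group.setdefault(name, default_value)


def max_field(param_groups: List[ParamGroup], name: str) -> float:
    """Collapse a per-group hyperparameter to one value by taking the maximum."""
    return max(group[name] for group in param_groups)


# Stand-in for optimizer.param_groups: one dict of hyperparameters per group.
param_groups: List[ParamGroup] = [
    {"lr": 6e-5, "weight_decay": 0.01},
    {"lr": 6e-5, "weight_decay": 0.0, "gradient_norm_threshold": 2.0},
]
add_field(param_groups, "gradient_norm_threshold", 1.0)
print(max_field(param_groups, "gradient_norm_threshold"))
```

Keeping the read path behind a small helper means the training loop never needs to know how many parameter groups exist or where the value is stored.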

In [None]:
def make_schedule(optimizer: Optimizer)->SynchronousSchedule:
    """
    Create coordinated schedules for learning rate, weight decay, and the
    gradient norm threshold.

    Returns a SynchronousSchedule that steps all three schedules together.
    """
    # Extend optimizer to include threshold
    tsa.extend_optimizer(optimizer,
                         name="gradient_norm_threshold",
                         default_value=BASE_NORM)

    # Learning rate: constant with warmup
    lr_scheduler = tsa.constant_with_warmup(
        optimizer,
        warmup_to_value=1.0, # Base learning rate already encoded
        num_warmup_steps=WARMUP_BATCHES,
        schedule_target='lr'
    )

    # Weight decay: Simulates normal learning rate annealing
    wd_schedule = tsa.cosine_annealing_with_warmup(
        optimizer,
        warmup_to_value=ANNEALING_START_SCHEDULE,
        anneal_to_value=ANNEALING_END_SCHEDULE,
        num_warmup_steps=WARMUP_BATCHES,
        num_training_steps=TOTAL_BATCHES,
        schedule_target='weight_decay'
    )

    # Gradient threshold: cosine annealing with an inverse warmup
    grad_schedule = tsa.cosine_annealing_with_inverse_warmup(
        optimizer,
        warmup_to_value=THRESHOLD_START_SCHEDULE,
        anneal_to_value=THRESHOLD_END_SCHEDULE,
        num_warmup_steps=WARMUP_BATCHES,
        num_training_steps=TOTAL_BATCHES,
        warmup_multiplier=WARMUP_MULTIPLIER,
        schedule_target="gradient_norm_threshold"
    )

    # Coordinate them to step together
    return tsa.SynchronousSchedule([lr_scheduler, wd_schedule, grad_schedule])

In [None]:
def get_grad_norm_threshold(optimizer: Optimizer)->float:
    """Get the grad norm threshold used to decide step time"""
    items = []
    for value, _, _ in tsa.get_param_groups_regrouped_by_key(optimizer, 'gradient_norm_threshold'):
        items.append(value)
    return max(items)

In [None]:
def get_grad_norm(model: PreTrainedModel)->float:
    """
    Gets the relevant norm out of the model using
    torch utilities
    """
    grads = [param.grad for param in model.parameters() if param.grad is not None]
    # get_total_norm returns a 0-dim tensor; convert to match the float annotation
    return torch.nn.utils.get_total_norm(grads).item()

## Train Loop
Standard PyTorch training loop as used in NLP, with the schedules stepped once per batch. The logging changes are factored out into a helper.
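
The loop accumulates summed gradients and compares the norm of their mean against the threshold. This works because the L2 norm is absolutely homogeneous: dividing the norm of the summed gradient by the number of accumulation steps gives the same number as taking the norm of the averaged gradient. A minimal check in plain Python, with made-up numbers and a hypothetical `l2_norm` helper:

```python
import math
from typing import Sequence


def l2_norm(vector: Sequence[float]) -> float:
    """Euclidean norm of a flat vector."""
    return math.sqrt(sum(x * x for x in vector))


accum_steps = 4
summed_grad = [3.0, -4.0, 12.0]  # made-up gradient summed over 4 micro-batches
mean_grad = [g / accum_steps for g in summed_grad]

norm_then_divide = l2_norm(summed_grad) / accum_steps  # what the loop computes
divide_then_norm = l2_norm(mean_grad)                  # norm of the true mean
print(norm_then_divide, divide_then_norm)
```

Because the two quantities agree, the gradients only need to be rescaled in place once, right before the optimizer step.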

In [None]:
def report_progress(schedule: SynchronousSchedule,
                    batch_idx: int,
                    loss: torch.Tensor,
                    norm: float,
                    accum_steps: int
                    ):
    last_lr = schedule.get_last_lr()[0]
    last_threshold = schedule.get_last_schedule("gradient_norm_threshold")[0]
    last_batch_size = BATCH_SIZE * accum_steps
    msg = (f"Batch {batch_idx+1:4d}/{TOTAL_BATCHES}"
          f" | Loss: {loss.item():.4f}"
          f" | LR: {last_lr:.4e}"
          f" | Target_Threshold: {last_threshold:.4f}"
          f" | Last Norm: {norm:.4f}"
          f" | Last Accum Steps: {accum_steps}"
          f" | Last Batch Size: {last_batch_size}"
          )
    print(msg)


In [None]:
def train(model: PreTrainedModel,
          dataloader: DataLoader,
          optimizer: Optimizer,
          schedule: SynchronousSchedule,
          ):
    """Train for TOTAL_BATCHES batches."""
    model.train()
    data_iter = iter(dataloader)
    accum_steps = 0
    last_norm = 0
    last_num_accum_steps = 0

    for batch_idx in range(TOTAL_BATCHES):
        # Get next batch
        try:
            batch = next(data_iter)
        except StopIteration:
            data_iter = iter(dataloader)
            batch = next(data_iter)

        # Move to device
        input_ids = batch['input_ids'].to(DEVICE)
        attention_mask = batch['attention_mask'].to(DEVICE)
        labels = batch['labels'].to(DEVICE)

        # Forward and backward pass. Gradients accumulate across iterations,
        # so each extra micro-batch grows the logical batch size.
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        accum_steps += 1

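        # The optimizer steps once the norm of the mean gradient is at or
        # below the scheduled threshold. Norms are homogeneous,
        # ||a*x|| = |a|*||x||, so we can divide the norm of the summed
        # gradients by the step count instead of rescaling them first.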
        current_mean_norm = get_grad_norm(model)/accum_steps
        if get_grad_norm_threshold(optimizer) >= current_mean_norm:
            # Rescale gradients as mean by dividing
            # by number of accum steps.
            for param in model.parameters():
                if param.grad is not None:
                    param.grad /= accum_steps

            # Run update
            optimizer.step()
            optimizer.zero_grad()
            last_num_accum_steps = accum_steps
            last_norm = current_mean_norm
            accum_steps = 0

        # Step schedules
        schedule.step()

        # Log progress
        if (batch_idx + 1) % LOGGING_RATE == 0:
            assert len(schedule.get_last_lr()) == 1, "update logging system when adding param groups"
            report_progress(schedule, batch_idx, loss, last_norm, last_num_accum_steps)


## Putting It All Together

Create the components and train.


In [None]:

def main():
    print("Setting up model and data...")
    model = make_model()
    dataloader = make_dataloader()

    print("Creating optimizer and schedules...")
    optimizer = make_optimizer(model)
    schedule = make_schedule(optimizer)

    #print(f"Scheduling: {schedule.schedule_names}")
    print(f"Training for {TOTAL_BATCHES} batches with {WARMUP_BATCHES} warmup")
    print(f"Device: {DEVICE}\n")

    train(model, dataloader, optimizer, schedule)

    print("\nTraining complete!")

if __name__ == '__main__':
    main()