amoudgl/celo


Celo: Training Versatile Learned Optimizers on a Compute Diet

Code style: black

JAX implementation of our paper Celo: Training Versatile Learned Optimizers on a Compute Diet:

@article{moudgil2025learning,
  title={Learning Versatile Optimizers on a Compute Diet},
  author={Moudgil, Abhinav and Knyazev, Boris and Lajoie, Guillaume and Belilovsky, Eugene},
  journal={arXiv preprint arXiv:2501.12670},
  year={2025}
}

We open-source all the pretrained optimizers, including baselines, along with meta-training and evaluation scripts. We also provide scripts to test learned optimizers on the MLCommons AlgoPerf benchmark.

Overview

We propose Celo, a compute-efficient learned optimizer that meta-generalizes and scales better than prior learned optimizer approaches for a given meta-training compute budget. Celo's design is much simpler than that of previous learned optimizers and consists of two main modules: (1) a learned MLP update rule and (2) a learned scheduler. These two modules are learned in a decoupled, two-phase meta-training procedure with task augmentation, using Persistent Evolution Strategies (PES).
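
To make the decoupling concrete, the toy JAX sketch below separates a per-parameter MLP update rule from a scalar learned schedule. Everything in it is an illustrative assumption rather than Celo's actual architecture: the MLP inputs (gradient and momentum), the layer sizes, the softplus schedule over training progress, and the helper names (init_mlp, mlp_update_rule, learned_schedule, update_param) are hypothetical. Setting use_schedule=False mimics a phase 1 style update (update rule only, no scheduler); in a phase 2 style setup, the MLP weights would be held fixed while only the schedule parameters are meta-trained.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, in_dim=2, hidden=8):
    """Tiny MLP shared by every parameter entry (hypothetical sizes)."""
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (in_dim, hidden)) * 0.1,
        "b1": jnp.zeros(hidden),
        "w2": jax.random.normal(k2, (hidden, 1)) * 0.1,
        "b2": jnp.zeros(1),
    }


def mlp_update_rule(mlp, grad, momentum):
    """Per-entry update direction from (grad, momentum) features (assumed inputs)."""
    feats = jnp.stack([grad, momentum], axis=-1)  # shape (..., 2)
    h = jax.nn.relu(feats @ mlp["w1"] + mlp["b1"])  # shape (..., hidden)
    return (h @ mlp["w2"] + mlp["b2"])[..., 0]  # shape (...), same as grad


def learned_schedule(sched, step, num_steps):
    """Scalar step-size multiplier as a function of training progress (assumed form)."""
    progress = step / num_steps
    return jax.nn.softplus(sched["a"] + sched["b"] * progress)


def update_param(p, g, m, mlp, sched, step, num_steps, use_schedule=True):
    """Apply the MLP direction, optionally scaled by the learned schedule."""
    direction = mlp_update_rule(mlp, g, m)
    # phase-1 style: no scheduler, only the MLP update rule
    scale = learned_schedule(sched, step, num_steps) if use_schedule else 1.0
    return p - scale * direction
```

Keeping the schedule a scalar function of progress, separate from the per-entry MLP, is what allows the two pieces to be meta-trained in separate phases.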

We evaluate Celo's meta-generalization capabilities on a set of 17 diverse evaluation tasks from the VeLOdrome benchmark, including convnets, transformer language modeling, ViTs, and autoencoders, among others, and compare it against 15 tuned optimizers on these tasks, including Adam and Shampoo. We refer the reader to our paper for more details.

Setup

First, set up the learned_optimization repo along with its dependencies by following the installation instructions in the google/learned_optimization repository.

Then, install the celo package:

git clone git@github.com:amoudgl/celo.git
cd celo
pip install -e .

Usage

All the optimizers implemented in this repo follow an optax-style functional syntax: they take a state containing the params as input and return an updated state with the updated params:

opt = Adam(...)

# init method initializes all internal buffers such as momentum, second-moment, etc
# using model param shapes and returns an optimizer state containing model params
# NOTE: learned optimizers may require more inputs such as number of steps
opt_state = opt.init(params=model_params, num_steps=num_steps)

# in the training step, the update method takes the grad as input along with the previous state
# and returns the next state
# NOTE: learned optimizers may take additional input at each step such as loss
opt_state = opt.update(opt_state, grad=grad, loss=loss)
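
As an illustration of this interface, here is a hypothetical hand-designed optimizer (ToyAdam) written in the same init/update/get_params style. It is not the repository's base class; the class and state names are ours, and it simply accepts and ignores the extra inputs (num_steps, loss) that learned optimizers may need.

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp


class AdamState(NamedTuple):
    params: Any  # model params live inside the optimizer state
    m: Any  # first-moment buffers, same tree structure as params
    v: Any  # second-moment buffers
    step: jnp.ndarray


class ToyAdam:
    """Hypothetical Adam following the init/update/get_params pattern."""

    def __init__(self, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
        self.lr, self.b1, self.b2, self.eps = lr, b1, b2, eps

    def init(self, params, num_steps=None, **kwargs):
        # hand-designed optimizers can ignore extra inputs such as num_steps
        zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
        return AdamState(params, zeros, zeros, jnp.array(0))

    def get_params(self, state):
        return state.params

    def update(self, state, grad, loss=None, **kwargs):
        # loss is accepted for interface compatibility but unused here
        step = state.step + 1
        m = jax.tree_util.tree_map(
            lambda mu, g: self.b1 * mu + (1 - self.b1) * g, state.m, grad)
        v = jax.tree_util.tree_map(
            lambda nu, g: self.b2 * nu + (1 - self.b2) * g**2, state.v, grad)
        # bias-correction factors for the moment estimates
        m_scale = 1 / (1 - self.b1**step)
        v_scale = 1 / (1 - self.b2**step)
        params = jax.tree_util.tree_map(
            lambda p, mu, nu: p - self.lr * (mu * m_scale)
            / (jnp.sqrt(nu * v_scale) + self.eps),
            state.params, m, v)
        return AdamState(params, m, v, step)
```

Keeping the params inside the state means a training loop only has to carry one value between steps, which is the pattern the snippet above relies on.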

Check out the base class here for more details on the init and update method arguments, which are consistently followed by all optimizers implemented in this repository.

Download

Trained checkpoints for Celo and all the learned optimizer baselines are available on HuggingFace for download:

Optimizer         HuggingFace repo          Download command
celo              amoudgl/celo              huggingface-cli download amoudgl/celo --local-dir ./
celo-adam         amoudgl/celo-adam         huggingface-cli download amoudgl/celo-adam --local-dir ./
rnn-mlp-lopt      amoudgl/rnn-mlp-lopt      huggingface-cli download amoudgl/rnn-mlp-lopt --local-dir ./
adafac-mlp-lopt   amoudgl/adafac-mlp-lopt   huggingface-cli download amoudgl/adafac-mlp-lopt --local-dir ./
velo-s            amoudgl/velo-s            huggingface-cli download amoudgl/velo-s --local-dir ./
velo-4000         amoudgl/velo-4000         huggingface-cli download amoudgl/velo-4000 --local-dir ./

All the above learned optimizers except velo-4000 are meta-trained on a set of 4 image classification MLP tasks with task augmentation. We show that Celo meta-generalizes well beyond its meta-training distribution and that its performance on harder/larger tasks scales better than the baselines as the meta-training budget increases. So while our proposed recipe is scalable, the pretrained optimizers released above are mainly intended for research purposes: in our preliminary testing on AlgoPerf, they do not scale to every large-scale task, likely due to limited meta-training. Please refer to our paper for more details.

velo-4000 is a pretrained optimizer released by Google. It was meta-trained on a mixture of tasks with large-scale compute (4000 TPU months). We simply provide the released checkpoint in a format compatible with our repo.

Checkpoints can also be downloaded with the HuggingFace Python API:

from huggingface_hub import hf_hub_download
# hf_hub_download fetches a single file, so filename is required
hf_hub_download(repo_id="amoudgl/celo", filename="theta.state", local_dir="./")

To set up the huggingface-cli tool for command-line downloads, refer to the HuggingFace Hub documentation.
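
Each CLI command in the table downloads into ./, so fetching several checkpoints into the same directory could mix or overwrite files. The sketch below is a hypothetical helper, not part of this repo: it gives each HuggingFace repository its own subdirectory and downloads the full repo with huggingface_hub.snapshot_download. The ./checkpoints root and the directory naming are our own choices.

```python
from pathlib import Path

# Repository ids from the download table above.
CHECKPOINT_REPOS = [
    "amoudgl/celo",
    "amoudgl/celo-adam",
    "amoudgl/rnn-mlp-lopt",
    "amoudgl/adafac-mlp-lopt",
    "amoudgl/velo-s",
    "amoudgl/velo-4000",
]


def checkpoint_dirs(root="./checkpoints"):
    """Map each repo id to its own local directory, e.g. ./checkpoints/celo."""
    return {repo: str(Path(root) / repo.split("/")[-1]) for repo in CHECKPOINT_REPOS}


def download_all(root="./checkpoints"):
    """Download every repo into its own directory; returns {repo_id: local path}."""
    # imported here so checkpoint_dirs works even without huggingface_hub installed
    from huggingface_hub import snapshot_download

    return {
        repo: snapshot_download(repo_id=repo, local_dir=local_dir)
        for repo, local_dir in checkpoint_dirs(root).items()
    }
```

Calling download_all() requires network access; checkpoint_dirs() alone can be used to look up where a given checkpoint would land.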

Loading pretrained optimizer

To load learned optimizer params from a checkpoint path, use the init_lopt_from_ckpt method:

from celo.factory import get_optimizer
from celo.utils import init_lopt_from_ckpt

path = 'path/to/pretrained/checkpoint/theta.state'
celo = get_optimizer('celo')
opt = init_lopt_from_ckpt(celo, path)  # ready-to-use learned optimizer!

# use like a standard optimizer as described above
opt_state = opt.init(...)

Example script

The following simple training script trains an MLP on the Fashion-MNIST task using the pretrained Celo optimizer:

import jax
from celo.factory import get_optimizer
from learned_optimization.tasks.fixed.image_mlp import ImageMLP_FashionMnist_Relu128x128
from huggingface_hub import hf_hub_download
from celo.utils import init_lopt_from_ckpt
key = jax.random.PRNGKey(7)

# initialize task
task = ImageMLP_FashionMnist_Relu128x128()  # simple example task from learned_optimization
key, key1 = jax.random.split(key)
params, state = task.init_with_state(key1)

# load pretrained optimizer
hf_hub_download(repo_id="amoudgl/celo", filename="theta.state", local_dir="./")
celo = get_optimizer('celo')
opt = init_lopt_from_ckpt(celo, "theta.state")

# initialize optimizer state
num_steps = 1000
opt_state = opt.init(params, model_state=state, num_steps=num_steps)

# training loop
for i in range(num_steps):
    # forward + backward pass
    batch = next(task.datasets.train)
    key, key1 = jax.random.split(key)
    params = opt.get_params(opt_state)
    loss, grad = jax.value_and_grad(task.loss)(params, key1, batch)

    # optimizer step
    opt_state = opt.update(opt_state, grad, loss=loss)

    # log
    if i % 50 == 0:
        print(f'[step {i}] loss: ', loss)
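
The loop above calls jax.value_and_grad eagerly at every step. A common JAX pattern is to wrap it in jax.jit so the loss-and-gradient computation is compiled once and reused. The self-contained sketch below shows that pattern with a hypothetical stand-in task (toy_loss, a least-squares regression with the same (params, key, batch) call signature as task.loss) and plain SGD standing in for opt.update, so it runs without learned_optimization or a checkpoint download.

```python
import jax
import jax.numpy as jnp


def toy_loss(params, key, batch):
    """Hypothetical stand-in for task.loss(params, key, batch): least squares.

    key is unused; it is kept only to match the call signature in the script above.
    """
    x, y = batch
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


# compiled on the first call, then reused for later calls with the same shapes
loss_and_grad = jax.jit(jax.value_and_grad(toy_loss))


def train(num_steps=200, lr=0.1, seed=0):
    key = jax.random.PRNGKey(seed)
    true_w = jnp.array([2.0, -1.0, 0.5])
    params = {"w": jnp.zeros(3), "b": jnp.array(0.0)}
    losses = []
    for _ in range(num_steps):
        key, kb, k1 = jax.random.split(key, 3)
        x = jax.random.normal(kb, (32, 3))  # fresh synthetic batch
        y = x @ true_w + 1.0
        loss, grad = loss_and_grad(params, k1, (x, y))
        # plain SGD as a placeholder for opt.update(opt_state, grad, loss=loss)
        params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grad)
        losses.append(float(loss))
    return params, losses
```

With the real task, the analogous change would be jitting jax.value_and_grad(task.loss) once outside the loop, assuming the batch shapes stay fixed across steps.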

Check out the test script celo/test.py to evaluate any optimizer on all the tasks considered in our work.

Meta-training

To run phase 1 meta-training of Celo with task augmentation, where only the MLP update rule is meta-trained:

python -m celo.train --optimizer celo_phase1 --train_partial --outer_iterations 100000 --max_unroll_length 2000 --seed 0 --trainer pes --name train_celo_phase1 --outer_lr 3e-4

To continue with phase 2 meta-training of Celo from the phase 1 checkpoint (meta-training the scheduler with the MLP update rule frozen), run:

python -m celo.train --optimizer celo --train_partial --init_from_ckpt ./train_celo_phase1/theta.state --outer_iterations 100000 --max_unroll_length 2000 --seed 0 --task fast_velo --outer_lr 3e-4 --aug reparam --aug_reparam_level global --name train_celo_phase2
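
Conceptually, --train_partial keeps some optimizer params fixed during meta-training (in phase 2, the MLP update rule). One simple way to realize this, sketched below under our own assumptions, is to mask the outer update so that frozen groups pass through unchanged. The grouping into "mlp" and "scheduler" keys, the helper name partial_outer_update, and the plain gradient-style outer step are hypothetical and need not match celo/train.py.

```python
import jax
import jax.numpy as jnp


def partial_outer_update(theta, meta_grad, frozen=("mlp",), lr=1e-3):
    """Apply a plain gradient-style outer step only to non-frozen groups.

    theta and meta_grad are dicts of pytrees keyed by (hypothetical) group names.
    Frozen groups are returned unchanged, e.g. the MLP update rule in phase 2.
    """
    new_theta = {}
    for name, group in theta.items():
        if name in frozen:
            new_theta[name] = group
        else:
            new_theta[name] = jax.tree_util.tree_map(
                lambda p, g: p - lr * g, group, meta_grad[name]
            )
    return new_theta
```

A fuller version would likely also avoid perturbing frozen groups inside the PES estimator; that detail is omitted here.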

To meta-train any other learned optimizer (e.g., an RNN with an MLP update rule) in a single stage with task augmentation, simply omit the --train_partial flag:

python -m celo.train --optimizer rnn_mlp --outer_iterations 100000 --max_unroll_length 2000 --seed 0 --task fast_velo --outer_lr 1e-4 --trainer pes --aug reparam --aug_reparam_level global --name train_rnn

To meta-train without task augmentation, omit the --aug flag, which defaults to None.
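
Task augmentation via reparametrization changes how a task looks to the optimizer without changing the underlying problem. Our reading, based on the reparametrization augmentation in learned_optimization and not verified against celo/train.py, is sketched below: the optimizer works on u = params / s while the loss is evaluated at u * s, so the loss value is unchanged but gradients with respect to u are scaled by s. The log-uniform range for s, the helper names (sample_scales, reparam_task), and the interpretation of the level ("global" = one scale shared by all tensors, "tensor" = one scale per tensor) are assumptions.

```python
import jax
import jax.numpy as jnp


def _log_uniform(key, low, high):
    """Scalar drawn log-uniformly from [low, high] (range is an assumption)."""
    return jnp.exp(jax.random.uniform(key, (), minval=jnp.log(low), maxval=jnp.log(high)))


def sample_scales(key, params, level="global", low=1e-3, high=1e3):
    """One scale per tensor; 'global' shares a single scale across all tensors."""
    leaves, treedef = jax.tree_util.tree_flatten(params)
    if level == "global":
        s = _log_uniform(key, low, high)
        scales = [s] * len(leaves)
    elif level == "tensor":
        keys = jax.random.split(key, len(leaves))
        scales = [_log_uniform(k, low, high) for k in keys]
    else:
        raise ValueError(f"unknown level: {level}")
    return jax.tree_util.tree_unflatten(treedef, scales)


def reparam_task(loss_fn, params, scales):
    """Return (u0, new_loss): the optimizer sees u, the loss sees u * scale."""
    u0 = jax.tree_util.tree_map(lambda p, s: p / s, params, scales)

    def new_loss(u, *args):
        p = jax.tree_util.tree_map(lambda x, s: x * s, u, scales)
        return loss_fn(p, *args)

    return u0, new_loss
```

By the chain rule, the gradient of new_loss with respect to u is s times the original gradient, so an optimizer meta-trained on such variants has to cope with a wide range of gradient and parameter scales.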

Key flags:

  • --optimizer celo_phase1 -- uses only MLP update rule for parameter update without any scheduler
  • --optimizer celo -- does full celo update using both MLP update rule and learned scheduler
  • --train_partial -- meta-trains with some optimizer params frozen; the set of frozen params is returned by the specified optimizer
  • --task -- meta-training task; we use fast_velo, a set of 4 small MLP tasks from the VeLO paper
  • --trainer -- outer trainer for meta-gradient estimation (we use PES)
  • --outer_iterations -- total number of meta-iterations i.e. outer updates to optimizer
  • --max_unroll_length -- maximum rollout length during meta-training
  • --init_from_ckpt -- if specified, optimizer params are loaded from this checkpoint in meta-training instead of random initialization
  • --aug -- if specified, task augmentation via reparametrization will be used
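
For intuition on what --trainer pes estimates, the toy sketch below implements the PES (Persistent Evolution Strategies) estimator on a tiny inner problem: gradient descent on 0.5·x² with step size θ. It follows the published algorithm (antithetic perturbation pairs, inner states carried across truncated unrolls, and an accumulated perturbation ξ per particle), but it is our own illustration, not the learned_optimization PES trainer. θ is held fixed across truncations so the estimate can be compared with the exact derivative; a real meta-training loop would apply an outer update to θ after each truncation.

```python
import jax
import jax.numpy as jnp


def unroll(x, theta, num_steps):
    """Toy inner problem: num_steps of gradient descent on f(x) = 0.5 * x**2.

    theta is the step size (the meta-parameter). Returns the final inner state
    and the sum of inner losses over this truncated unroll.
    """
    total = 0.0
    for _ in range(num_steps):
        x = x - theta * x  # gradient of 0.5 * x**2 is x
        total = total + 0.5 * x**2
    return x, total


def pes_step(key, theta, states, xis, sigma, num_steps):
    """One PES truncation over N particles (N even).

    Particles i and i + N/2 form antithetic pairs (+eps, -eps).
    states: (N,) inner states carried over from the previous truncation.
    xis:    (N,) perturbations accumulated since the start of the inner run.
    """
    n = states.shape[0]
    half = sigma * jax.random.normal(key, (n // 2,))
    eps = jnp.concatenate([half, -half])
    xis = xis + eps  # accumulate this truncation's perturbation
    unroll_one = lambda s, e: unroll(s, theta + e, num_steps)
    states, losses = jax.vmap(unroll_one)(states, eps)
    grad = jnp.sum(xis * losses) / (n * sigma**2)
    return grad, states, xis


def run_pes(theta=0.1, num_particles=2000, sigma=0.01, num_steps=5,
            num_truncations=3, seed=0):
    """Return one PES gradient estimate per truncation (theta held fixed)."""
    key = jax.random.PRNGKey(seed)
    states = jnp.ones((num_particles,))  # every inner run starts at x0 = 1
    xis = jnp.zeros((num_particles,))
    estimates = []
    for _ in range(num_truncations):
        key, sub = jax.random.split(key)
        g, states, xis = pes_step(sub, theta, states, xis, sigma, num_steps)
        estimates.append(float(g))
    return estimates


def truncation_loss(theta, j, num_steps=5, x0=1.0):
    """Exact loss of the j-th truncation (1-indexed) when all steps share theta."""
    x = jnp.asarray(x0)
    for _ in range(j - 1):
        x, _ = unroll(x, theta, num_steps)
    _, loss = unroll(x, theta, num_steps)
    return loss
```

Because ξ accumulates perturbations from earlier truncations, the estimate at the third truncation also captures how θ affected that truncation's loss through the carried-over inner state, which a truncated-ES estimate using only the current perturbation would miss. It can be compared against jax.grad(truncation_loss)(0.1, 3).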

See celo/train.py for all meta-training flags along with their descriptions.

Acknowledgements

Huge thanks to the teams behind google/learned_optimization for open-sourcing their work -- this project wouldn't have been possible without it. We’d also like to acknowledge the neat evaluation library google-research/rliable, which we used to compute IQM metrics.

License

This project is released under the MIT License. However, some source files are adapted from google/learned_optimization repository and are licensed under the Apache License 2.0. These files are located in celo/ directory and retain their original license.
