In [None]:
# timm is in colab by default now, but we need the main branch to show new optimizer features
!pip install git+https://github.com/huggingface/pytorch-image-models.git

Collecting git+https://github.com/huggingface/pytorch-image-models.git
  Cloning https://github.com/huggingface/pytorch-image-models.git to /tmp/pip-req-build-ttof0hei
  Running command git clone --filter=blob:none --quiet https://github.com/huggingface/pytorch-image-models.git /tmp/pip-req-build-ttof0hei
  Resolved https://github.com/huggingface/pytorch-image-models.git to commit 0b5264a108890f87317558e89adf643f4b330884
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: timm
  Building wheel for timm (pyproject.toml) ... [?25l[?25hdone
  Created wheel for timm: filename=timm-1.0.12.dev0-py3-none-any.whl size=2342090 sha256=2db512d47f33389eef15e185c7846b8fd6eceb5632362c675242acdb6579ce27
  Stored in directory: /tmp/pip-ephem-wheel-cache-7w4hkay2/wheels/db/23/a2/e7496a9eafb64fb93606c0ba3d59675246b74aa78939b80e39
Successfully bu

In [None]:
import torch
import torch.nn as nn
import timm.optim

In [None]:
# Enumerate every optimizer (together with its description) exposed by the timm factory.
# This covers torch.optim as well as select bitsandbytes (bnb) and APEX (fused) optimizers.
for opt_name, opt_description in timm.optim.list_optimizers(with_description=True):
    print(f'{opt_name}: {opt_description}')

adabelief: Adapts learning rate based on gradient prediction error
adadelta: torch.optim Adadelta, Adapts learning rates based on running windows of gradients
adafactor: Memory-efficient implementation of Adam with factored gradients
adafactorbv: Big Vision variant of Adafactor with factored gradients, half precision momentum
adagrad: torch.optim Adagrad, Adapts learning rates using cumulative squared gradients
adahessian: An Adaptive Second Order Optimizer
adam: torch.optim Adam (Adaptive Moment Estimation)
adamax: torch.optim Adamax, Adam with infinity norm for more stable updates
adamp: Adam with built-in projection to unit norm sphere
adamw: torch.optim Adam with decoupled weight decay regularization
adan: Adaptive Nesterov Momentum Algorithm
adanw: Adaptive Nesterov Momentum with decoupled weight decay
adopt: Modified Adam that can converge with any β2 with the optimal rate
adoptw: Modified AdamW (decoupled decay) that can converge with any β2 with the optimal rate
bnbadam: bitsan

In [None]:
# The timm factory accepts a model directly when creating an optimizer.
# NOTE: Handing the factory a model (nn.Module) rather than its parameters makes it
# auto-create param groups for weight-decay (or layer-decay if enabled).
model = nn.Sequential(
    nn.Linear(in_features=1, out_features=16),
)
opt = timm.optim.create_optimizer_v2(model, opt='adafactorbv')
opt

AdafactorBigVision (
Parameter Group 0
    beta2_cap: 0.999
    clipping_threshold: None
    decay_offset: 0
    decay_rate: 0.8
    eps: None
    foreach: False
    lr: 1.0
    min_dim_size_to_factor: 32
    momentum: 0.9
    momentum_dtype: torch.bfloat16
    unscaled_wd: False
    weight_decay: 0.0
)

In [None]:
# Optimizer classes can also be looked up dynamically from their string name, which
# allows config friendly use without going through the factory.
opt_class = timm.optim.get_optimizer_class(name='adoptw')
opt_class

functools.partial(<class 'timm.optim.adopt.Adopt'>, decoupled=True)

In [None]:
# Instantiate the class fetched above on the model's parameters and display the
# resulting optimizer.
opt2 = opt_class(params=model.parameters())
opt2

Adopt (
Parameter Group 0
    betas: (0.9, 0.9999)
    capturable: False
    decoupled: True
    differentiable: False
    eps: 1e-06
    foreach: None
    lr: 0.001
    maximize: False
    weight_decay: 0.0
)

In [None]:
# When the optimizer info specifies default arguments, the class lookup binds them.
# E.g. 'sgd' in `timm` has nesterov enabled by default in the factory, so that default
# is bound to the returned class unless binding is disabled.
SgdWithNesterov = timm.optim.get_optimizer_class('sgd', bind_defaults=True)
SgdUnbound = timm.optim.get_optimizer_class('sgd', bind_defaults=False)
print(SgdWithNesterov, SgdUnbound, sep='\n')

functools.partial(<class 'torch.optim.sgd.SGD'>, nesterov=True)
<class 'torch.optim.sgd.SGD'>


In [None]:
# The factory registration is backed by information dataclasses that can be queried;
# these could be expanded to cover more optimizer traits.
opt_info = timm.optim.get_optimizer_info(name='nadamw')
opt_info

OptimInfo(name='nadamw', opt_class=<class 'timm.optim.nadamw.NAdamW'>, description='Adam with Nesterov momentum and decoupled weight decay', has_eps=True, has_momentum=False, has_betas=True, num_betas=2, second_order=False, defaults=None)

In [None]:
# The optimizer classes are also usable directly, exactly like the ones in `torch.optim`.
opt3 = timm.optim.NAdamW(params=model.parameters())
opt3

NAdamW (
Parameter Group 0
    betas: (0.9, 0.999)
    capturable: False
    eps: 1e-08
    foreach: None
    lr: 0.001
    maximize: False
    weight_decay: 0.01
)