In [None]:
#|default_exp optimizer.finetune

In [None]:
#|exporti
# Contains code from:
# fastai - Apache License 2.0 - Copyright (c) 2023 fast.ai

# Fine-tuning Weight Decay
> Optimizers with fine-tuning weight decay from Katherine Crowson's [AdamWFineTune](https://gist.github.com/crowsonkb/f646976de8033b371ea17cb9b1c1561f).

`FineTuneOpt` adds an additional optional weight decay `ft_wd` towards the starting value, to prevent overfitting to the new dataset during fine-tuning. This version uses fastai splitters to only apply the fine-tuning weight decay to the pre-trained model body and not the new head.

All fastai optimizers are replicated here with the suffix FT to indicate they are `FineTuneOpt`.

Early experimental results suggest `AdamFT` without weight decay might be equivalent to `AdamW` in vision fine-tuning performance.

In [None]:
#|export
from __future__ import annotations

from fastcore.basics import GetAttr

from fastai.optimizer import (Optimizer, _update, weight_decay, l2_reg, average_grad, sgd_step,
                              momentum_step, average_sqr_grad, rms_prop_step, step_stat, adam_step,
                              radam_step, qhadam_step, larc_layer_lr, larc_step, lamb_step)
from fastxtend.imports import *

In [None]:
#|hide
from fastxtend.test_utils import *

In [None]:
#|hide
def tst_param(val, grad=None):
    "Build a float test tensor holding `val` whose `.grad` is `grad`, or `val/10` when `grad` is not given"
    param = tensor([val]).float()
    # Fall back to a gradient of one tenth of the value when none is supplied
    grad_val = grad if grad is not None else val/10
    param.grad = tensor([grad_val]).float()
    return param

r = L.range(4)
def tst_params():
    "Create one fresh test parameter (via `tst_param`) for each value in `r`"
    return r.map(tst_param)

## FineTuneOpt -

In [None]:
#|export
class FineTuneOpt(Optimizer):
    """
    Modification of the base optimizer class for the fastai library, updating `params` with `cbs`

    In combination with the `fine_tune_wd` callback, adds optional weight decay `ft_wd` towards the starting value,
    to prevent overfitting to the new dataset during fine-tuning.

    By default, will not apply to the fine-tuning head, just the pretrained body.

    The starting values are snapshotted into each parameter's state as `orig_p` on the first `step`
    after construction, and again on the first `step` after `clear_state`.

    From: https://gist.github.com/crowsonkb/f646976de8033b371ea17cb9b1c1561f
    """
    _keep_on_clear = ['force_train', 'do_wd']
    def __init__(self,
        params:Tensor, # Parameters and hyper parameters
        cbs:list, # `Optimizer` callbacks
        train_bn:bool=True, # Batch normalization is always trained
        wd_ft_head:bool=False, # Apply fine tuning weight decay to model head
        **defaults # Default values to set on hyper parameters
    ):
        super().__init__(params, cbs, train_bn, **defaults)
        # `set_orig_p` marks that the starting weights still need to be snapshotted on the next `step`
        self.wd_ft_head, self.set_orig_p = wd_ft_head, True

    @torch.no_grad()
    def step(self, closure=None):
        "Snapshot the starting weights as `orig_p` when pending, then run the base optimizer step"
        if self.set_orig_p:
            self.set_orig_p = False
            # Leave out the last parameter group (the head) unless `wd_ft_head` is set
            # or there is only a single group to begin with
            n = slice(None) if self.wd_ft_head or len(self.param_lists)<=1 else slice(None, -1)
            for p,pg,state,hyper in self.all_params(n):
                state['orig_p'] = p.detach().clone()
                self.state[p] = state
        super().step(closure)

    def clear_state(self):
        "Clear the optimizer state and schedule a fresh `orig_p` snapshot for the next `step`"
        self.set_orig_p = True
        super().clear_state()

    def state_dict(self):
        "Return the base optimizer state dict extended with the `set_orig_p` flag"
        state = super().state_dict()
        state.update({'set_orig_p': self.set_orig_p})
        return state

    def load_state_dict(self, sd):
        "Restore the `set_orig_p` flag and the base optimizer state from `sd`"
        # Work on a shallow copy so the caller's dict is left intact and can be loaded again
        sd = dict(sd)
        # A state dict without the flag was not produced by `FineTuneOpt.state_dict` and presumably
        # carries no `orig_p` entries, so schedule a snapshot on the next step instead of raising `KeyError`
        self.set_orig_p = sd.pop('set_orig_p', True)
        super().load_state_dict(sd)

In [None]:
#|hide
# Test the initialization of the FineTuneOpt optimizer
# A flat collection of params is expected to end up as a single parameter group
opt = FineTuneOpt([1,2,3], noop)
test_eq(opt.param_lists, [[1,2,3]])
opt = FineTuneOpt(range(3), noop)
test_eq(opt.param_lists, [[0,1,2]])
# A nested collection is expected to keep one parameter group per inner list
opt = FineTuneOpt([[1,2],[3]], noop)
test_eq(opt.param_lists, [[1,2],[3]])
# A generator of groups is expected to be materialized into parameter groups as well
opt = FineTuneOpt(([o,o+1] for o in range(0,4,2)), noop)
test_eq(opt.param_lists, [[0,1],[2,3]])

In [None]:
#|hide
# Test that callbacks have not changed
# Dummy callbacks: each returns `p` untouched and advertises hyperparameter defaults via `.defaults`
def tst_arg(p, lr=0, **kwargs): return p
tst_arg.defaults = dict(lr=1e-2)

def tst_arg2(p, lr2=0, **kwargs): return p
tst_arg2.defaults = dict(lr2=1e-3)

def tst_arg3(p, mom=0, **kwargs): return p
tst_arg3.defaults = dict(mom=0.9)

# NOTE(review): `tst_arg4` is not used by any assertion in this cell
def tst_arg4(p, **kwargs): return p

# The defaults of all callbacks are expected to be merged into the hypers
opt = FineTuneOpt([1,2,3], [tst_arg,tst_arg2, tst_arg3])
test_eq(opt.hypers, [{'lr2': 1e-3, 'mom': 0.9, 'lr': 1e-2}])
# An explicitly passed hyperparameter is expected to override the callback default
opt = FineTuneOpt([1,2,3], tst_arg, lr=0.1)
test_eq(opt.hypers, [{'lr': 0.1}])
# With several parameter groups, there should be one hypers dict per group
opt = FineTuneOpt([[1,2],[3]], tst_arg)
test_eq(opt.hypers, [{'lr': 1e-2}, {'lr': 1e-2}])
opt = FineTuneOpt([[1,2],[3]], tst_arg, lr=0.1)
test_eq(opt.hypers, [{'lr': 0.1}, {'lr': 0.1}])

# A list of values is expected to be assigned group by group
opt = FineTuneOpt([[1,2],[3]], tst_arg, lr=[0.1,0.2])
test_eq(opt.hypers, [{'lr': 0.1}, {'lr': 0.2}])
# `slice(stop)`: last group gets `stop`, the earlier groups get `stop/10` (per the expected values below)
opt = FineTuneOpt([[1,2],[3],[4]], tst_arg, lr=slice(1e-2))
test_eq(opt.hypers, [{'lr': 1e-3}, {'lr': 1e-3}, {'lr': 1e-2}])
# `slice(start, stop)`: values are spread from `start` to `stop` across the groups (1e-4, 1e-3, 1e-2 here)
opt = FineTuneOpt([[1,2],[3],[4]], tst_arg, lr=slice(1e-4,1e-2))
test_eq(opt.hypers, [{'lr': 1e-4}, {'lr': 1e-3}, {'lr': 1e-2}])
test_eq(opt.param_groups, [{'params': [1,2], 'lr': 1e-4}, {'params': [3], 'lr': 1e-3}, {'params': [4], 'lr': 1e-2}])
# Must fail; presumably because two values are supplied for three parameter groups
test_fail(lambda: FineTuneOpt([[1,2],[3],[4]], tst_arg, lr=np.array([0.1,0.2])))

## Fine Tune WD Step -

In [None]:
#|export
def fine_tune_wd(p, lr, ft_wd, orig_p=None, do_wd=True, **kwargs):
    "Pull `p` back towards its starting value `orig_p` by a fraction `lr*ft_wd`, in place"
    # Nothing to do when decay is disabled, the coefficient is zero, or no starting value was recorded
    if not do_wd or ft_wd == 0 or orig_p is None:
        return
    step_frac = lr*ft_wd
    p.lerp_(orig_p, step_frac)

In [None]:
#|hide
# With lr*ft_wd = 0.5, `p` should move halfway from 1.0 towards the starting value 0.5
p = tst_param(1., 0.1)
fine_tune_wd(p, 1., 0.5, tensor([0.5]))
test_eq(p, tensor([0.75]))
# The gradient must be left untouched by the decay step
test_eq(p.grad, tensor([0.1]))

# Optimizers

In [None]:
#|export
def SGDFT(params, lr, mom=0., wd=0., ft_wd=0., decouple_wd=True, wd_ft_head=False):
    "A `FineTuneOpt` for SGD with `lr` and `mom` and `params`"
    # Decoupled weight decay or classic L2 regularization, always followed by the fine-tuning decay
    decay_cb = weight_decay if decouple_wd else l2_reg
    # Plain SGD needs only its step; with momentum the gradient average is tracked first
    step_cbs = [average_grad, momentum_step] if mom != 0 else [sgd_step]
    return FineTuneOpt(params, [decay_cb, fine_tune_wd, *step_cbs], lr=lr, mom=mom, wd=wd,
                       ft_wd=ft_wd, wd_ft_head=wd_ft_head)

In [None]:
#|hide
#Vanilla SGD
# Test params hold values 0..3 with gradients of value/10, so each step at lr=0.1 removes 1% of the initial value
params = tst_params()
opt = SGDFT(params, lr=0.1)
opt.step()
test_close([p.item() for p in params], [i*0.99 for i in range(4)])
# Gradients are not recomputed between steps, so the second step subtracts the same amount again
opt.step()
test_close([p.item() for p in params], [i*0.98 for i in range(4)])

In [None]:
#|hide
#Vanilla SGD with FT_WD
params = tst_params()
opt = SGDFT(params, lr=0.1, ft_wd=0.1)
# First step: the starting values equal the current ones, so the fine-tuning decay has no effect yet
opt.step()
test_close([p.item() for p in params], [i*0.99 for i in range(4)])
# Second step: params are first pulled lr*ft_wd = 1% of the way back to the start (0.99 -> 0.9901), then stepped
opt.step()
test_close([p.item() for p in params], [i*0.9801 for i in range(4)])

## RMSPropFT -

In [None]:
#|export
def RMSPropFT(params, lr, sqr_mom=0.99, mom=0., wd=0., ft_wd=0., decouple_wd=True, wd_ft_head=False):
    "A `FineTuneOpt` for RMSProp with `lr`, `sqr_mom`, `mom` and `params`"
    cbs = [weight_decay] if decouple_wd else [l2_reg]
    # Register the fine-tuning decay unconditionally. It must not share an expression with the
    # conditional below: `a + b if c else d` parses as `(a + b) if c else d`, which would drop
    # `fine_tune_wd` whenever `mom != 0`.
    cbs.append(fine_tune_wd)
    # Only track the gradient average when momentum is in use
    cbs += [average_sqr_grad] if mom==0. else [average_grad, average_sqr_grad]
    cbs.append(rms_prop_step)
    return FineTuneOpt(params, cbs, lr=lr, mom=mom, sqr_mom=sqr_mom, wd=wd, ft_wd=ft_wd, wd_ft_head=wd_ft_head)

## AdamFT -

In [None]:
#|export
def AdamFT(params, lr, mom=0.9, sqr_mom=0.99, eps=1e-5, wd=0.01, ft_wd=0., decouple_wd=True, wd_ft_head=False):
    "A `FineTuneOpt` for Adam with `lr`, `mom`, `sqr_mom`, `eps` and `params`"
    # Decoupled weight decay (AdamW style) or classic L2 regularization
    decay_cb = weight_decay if decouple_wd else l2_reg
    cbs = [
        decay_cb,
        fine_tune_wd,
        partial(average_grad, dampening=True),
        average_sqr_grad,
        step_stat,
        adam_step,
    ]
    hypers = dict(lr=lr, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd, ft_wd=ft_wd)
    return FineTuneOpt(params, cbs, wd_ft_head=wd_ft_head, **hypers)

## RAdamFT -

In [None]:
#|export
def RAdamFT(params, lr, mom=0.9, sqr_mom=0.99, eps=1e-5, wd=0., ft_wd=0., beta=0., decouple_wd=True, wd_ft_head=False):
    "A `FineTuneOpt` for RAdam with `lr`, `mom`, `sqr_mom`, `eps`, `beta` and `params`"
    decay_cb = weight_decay if decouple_wd else l2_reg
    # Statistics consumed by the RAdam step: dampened gradient average, squared-gradient average, step count
    stat_cbs = [partial(average_grad, dampening=True), average_sqr_grad, step_stat]
    return FineTuneOpt(params, [decay_cb, fine_tune_wd, *stat_cbs, radam_step], lr=lr, mom=mom,
                       sqr_mom=sqr_mom, eps=eps, wd=wd, ft_wd=ft_wd, beta=beta, wd_ft_head=wd_ft_head)

## QHAdamFT -

In [None]:
#|export
def QHAdamFT(params, lr, mom=0.999, sqr_mom=0.999, nu_1=0.7, nu_2 = 1.0, eps=1e-8, wd=0., ft_wd=0., decouple_wd=True, wd_ft_head=False):
    "A `FineTuneOpt` for QHAdam with `lr`, `mom`, `sqr_mom`, `nu_1`, `nu_2`, `eps` and `params`"
    # Both running averages are computed with dampening for QHAdam
    grad_avg = partial(average_grad, dampening=True)
    sqr_avg = partial(average_sqr_grad, dampening=True)
    cbs = [weight_decay if decouple_wd else l2_reg, fine_tune_wd]
    cbs.extend([grad_avg, sqr_avg, step_stat, qhadam_step])
    return FineTuneOpt(params, cbs, lr=lr, nu_1=nu_1, nu_2=nu_2, mom=mom, sqr_mom=sqr_mom,
                       eps=eps, wd=wd, ft_wd=ft_wd, wd_ft_head=wd_ft_head)

## LarcFT -

In [None]:
#|export
def LarcFT(params, lr, mom=0.9, clip=True, trust_coeff=0.02, eps=1e-8, wd=0., ft_wd=0., decouple_wd=True, wd_ft_head=False):
    "A `FineTuneOpt` for LARC with `lr`, `mom`, `clip`, `trust_coeff`, `eps` and `params`"
    cbs = [weight_decay if decouple_wd else l2_reg, fine_tune_wd]
    # The gradient average is only tracked when momentum is in use
    if mom != 0.:
        cbs.append(average_grad)
    cbs.extend([partial(larc_layer_lr, clip=clip), larc_step])
    hypers = dict(lr=lr, mom=mom, trust_coeff=trust_coeff, eps=eps, wd=wd, ft_wd=ft_wd)
    return FineTuneOpt(params, cbs, wd_ft_head=wd_ft_head, **hypers)

## LambFT -

In [None]:
#|export
def LambFT(params, lr, mom=0.9, sqr_mom=0.99, eps=1e-5, wd=0., ft_wd=0., decouple_wd=True, wd_ft_head=False):
    "A `FineTuneOpt` for LAMB with `lr`, `mom`, `sqr_mom`, `eps` and `params`"
    cbs = [weight_decay] if decouple_wd else [l2_reg]
    cbs += [fine_tune_wd, partial(average_grad, dampening=True), average_sqr_grad, step_stat, lamb_step]
    # Must be a `FineTuneOpt`: it is the class that records `orig_p` for `fine_tune_wd` and takes
    # `wd_ft_head` as a constructor argument rather than as a hyper-parameter default.
    return FineTuneOpt(params, cbs, lr=lr, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd, ft_wd=ft_wd, wd_ft_head=wd_ft_head)

## LookaheadFT -

In [None]:
#|export
class LookaheadFT(FineTuneOpt, GetAttr):
    "Wrap a `FineTuneOpt` `opt` in a Lookahead optimizer"
    _default='opt'
    def __init__(self, opt, k=6, alpha=0.5):
        store_attr('opt,k,alpha')
        self._init_state()

    def _init_state(self):
        "Reset the step counter and discard any slow weights"
        self.count = 0
        self.slow_weights = None

    def _copy_weights(self):
        "Store detached copies of the current parameters as the slow weights"
        self.slow_weights = L(L(p.clone().detach() for p in pg) for pg in self.param_lists)

    # The parameter groups live on the wrapped optimizer; reads and writes are forwarded to it
    @property
    def param_lists(self): return self.opt.param_lists
    @param_lists.setter
    def param_lists(self, v): self.opt.param_lists = v

    def step(self, closure=None):
        "Step the wrapped optimizer; every `k` steps, move the slow weights by `alpha` and sync the fast ones"
        if closure is not None: raise NotImplementedError("fastai optimizers currently do not support closure")
        if self.slow_weights is None: self._copy_weights()
        self.opt.step()
        self.count += 1
        if self.count % self.k == 0:
            for slow_group,fast_group in zip(self.slow_weights, self.param_lists):
                for slow,fast in zip(slow_group, fast_group):
                    # slow <- slow + alpha*(fast - slow), then the fast weights restart from the slow ones
                    slow.data.add_(fast.data-slow.data, alpha=self.alpha)
                    fast.data.copy_(slow.data)

    def clear_state(self):
        "Clear the wrapped optimizer's state and reset the Lookahead bookkeeping"
        self.opt.clear_state()
        self._init_state()

    def state_dict(self):
        "Return the wrapped optimizer's state dict extended with `count` and `slow_weights`"
        sd = self.opt.state_dict()
        sd['count'] = self.count
        sd['slow_weights'] = self.slow_weights
        return sd

    def load_state_dict(self, sd):
        "Restore `count` and `slow_weights` from `sd`, then hand the remainder to the wrapped optimizer"
        self.count = sd.pop('count')
        self.slow_weights = sd.pop('slow_weights')
        self.opt.load_state_dict(sd)

## rangerFT -

In [None]:
#|export
@delegates(RAdamFT)
def rangerFT(p, lr, mom=0.95, wd=0.01, ft_wd=0., eps=1e-6, k=6, alpha=0.5, **kwargs):
    "Convenience method for `LookaheadFT` with `RAdamFT`"
    # `k` and `alpha` configure the `LookaheadFT` wrapper (defaults match its own);
    # everything else, including the remaining `kwargs`, is forwarded to `RAdamFT`
    inner_opt = RAdamFT(p, lr=lr, mom=mom, wd=wd, ft_wd=ft_wd, eps=eps, **kwargs)
    return LookaheadFT(inner_opt, k=k, alpha=alpha)