In [None]:
# default_exp optimizer

In [None]:
#export
from local.torch_basics import *
from local.test import *
from local.notebook.showdoc import show_doc

# Optimizer

> Define the general fastai optimizer and the variants

## Optimizer -

In [None]:
# export
class Optimizer():
    "Base optimizer class for the fastai library, updating `params` with `steppers`"
    def __init__(self, params, steppers, stats=None, train_bn=True, **defaults):
        steppers = L(steppers)
        params = L(params)
        self.stats = L(stats)
        self.state = {}
        self.train_bn = train_bn
        # Collect the `defaults` declared on the stats, then on the steppers. Values gathered
        # earlier (starting with the ones passed by the caller) win over the ones found later.
        for func in [*self.stats, *steppers]:
            defaults = {**getattr(func, 'defaults', {}), **defaults}
        # A flat collection of parameters is wrapped so that it forms a single param group
        already_grouped = isinstance(params[0], (L,list))
        self.param_groups = params if already_grouped else L([params])
        self.step_func = compose(*steppers)
        # Each param group gets its own copy of the hyper-parameters
        self.hypers = L(dict(defaults) for _ in self.param_groups)

    def _grad_params(self):
        "List the `(param, hypers)` pairs of every parameter whose `grad` is not None"
        pairs = []
        for group,hyper in zip(self.param_groups, self.hypers):
            for p in group:
                if p.grad is not None: pairs.append((p, hyper))
        return pairs

    def zero_grad(self):
        # Hyper-parameters are not needed here, only the parameters holding a grad
        for p,_ in self._grad_params():
            p.grad.detach_()
            p.grad.zero_()

    def step(self):
        for p,hyper in self._grad_params():
            # Parameters seen for the first time start from an empty state
            p_state = self.state.get(p, {})
            for stat in self.stats: p_state = stat(p_state, p, **hyper)
            # On a name clash, hyper-parameters take precedence over state entries
            step_kwargs = {**p_state, **hyper}
            self.step_func(p, **step_kwargs)
            self.state[p] = p_state

    def _set_require_grad(self, pg, rg):
        for p in pg:
            # Parameters flagged with `force_train` in their state stay trainable
            forced = self.state.get(p, {}).get('force_train', False)
            p.requires_grad_(rg or forced)

    def freeze_to(self, n):
        frozen,trainable = self.param_groups[:n],self.param_groups[n:]
        for pg in frozen: self._set_require_grad(pg, False)
        for pg in trainable: self._set_require_grad(pg, True)

    def freeze(self):
        # Freezing everything but the last group needs at least two groups
        assert(len(self.param_groups)>1)
        self.freeze_to(-1)

    def unfreeze(self): self.freeze_to(0)

In [None]:
# Attach the docstrings of the public methods of `Optimizer`
add_docs(Optimizer,
         zero_grad="Zero all the grad attributes of the parameters",
         step="Update the stats and execute the steppers on all parameters that have a grad",
         freeze_to="Freeze parameter groups up to `n`",
         freeze="Freeze up to last parameter group",
         unfreeze="Unfreeze the entire model")

### Initializing an Optimizer

`params` will be used to create the `param_groups` of the optimizer. If it's a collection (or a generator) of parameters, it will be a `L` containing one `L` with all the parameters. To define multiple parameter groups `params` should be passed as a collection (or a generator) of `L`s.

> Note: In PyTorch, `model.parameters()` returns a generator with all the parameters, that you can directly pass to `Optimizer`.

In [None]:
opt = Optimizer([1,2,3], noop)
test_eq(opt.param_groups, [[1,2,3]])
opt = Optimizer(range(3), noop)
test_eq(opt.param_groups, [[0,1,2]])
opt = Optimizer([[1,2],[3]], noop)
test_eq(opt.param_groups, [[1,2],[3]])
opt = Optimizer(([o,o+1] for o in range(0,4,2)), noop)
test_eq(opt.param_groups, [[0,1],[2,3]])

`steppers` is a list of functions that will be composed when applying the step. For instance, you can compose a function making the SGD step, with another one applying weight decay. Additionally, each `stepper` can have a `defaults` attribute that contains hyper-parameters and their default value. Those are all gathered at initialization, and new values can be passed to override those defaults with the `defaults` kwargs.

Once the defaults have all been pulled off, they are copied as many times as there are `param_groups` and stored in `hypers`. To apply different hyper-parameters to different groups (differential learning rates, or no weight decay for certain layers for instance), you will need to adjust those values after the init.

In [None]:
def tst_arg(p, lr=0, **kwargs): return p
tst_arg.defaults = dict(lr=1e-2)

opt = Optimizer([1,2,3], tst_arg)
test_eq(opt.hypers, [{'lr': 1e-2}])
opt = Optimizer([1,2,3], tst_arg, lr=0.1)
test_eq(opt.hypers, [{'lr': 0.1}])
opt = Optimizer([[1,2],[3]], tst_arg)
test_eq(opt.hypers, [{'lr': 1e-2}, {'lr': 1e-2}])
opt = Optimizer([[1,2],[3]], tst_arg, lr=0.1)
test_eq(opt.hypers, [{'lr': 0.1}, {'lr': 0.1}])

### Basic steppers

To be able to give examples of optimizer steps, we will need some steppers, like the following:

In [None]:
#export
def sgd_step(p, lr, **kwargs):
    "Step for plain SGD: subtract `lr` times the gradient of `p` from `p` in place, and return `p`"
    # Keyword `alpha` replaces the deprecated positional-scalar overload `add_(-lr, grad)`
    p.data.add_(p.grad.data, alpha=-lr)
    return p

In [None]:
def tst_param(val, grad):
    "Build a float test tensor holding `[val]`, whose `grad` attribute holds `[grad]`"
    param = tensor([val]).float()
    # The gradient is assigned by hand, no backward pass is involved
    param.grad = tensor([grad]).float()
    return param

In [None]:
p = tst_param(1., 0.1)
p = sgd_step(p, 1.)
test_eq(p, tensor([0.9]))
test_eq(p.grad, tensor([0.1]))

In [None]:
#export
def weight_decay(p, lr, wd, do_wd=True, **kwargs):
    "Decay the weights of `p` in place by the factor `1 - lr*wd` (nothing is done when `do_wd` is falsy)"
    if not do_wd: return p
    shrink = 1 - lr*wd
    p.data.mul_(shrink)
    return p
weight_decay.defaults = {'wd': 0.}

In [None]:
p = tst_param(1., 0.1)
p = weight_decay(p, 1., 0.1)
test_eq(p, tensor([0.9]))
test_eq(p.grad, tensor([0.1]))

In [None]:
#export
def l2_reg(p, lr, wd, do_wd=True, **kwargs):
    "L2 regularization as adding `wd*p` to `p.grad` (in place); skipped when `do_wd` is falsy"
    # `lr` is accepted but not used by this stepper.
    # Keyword `alpha` replaces the deprecated positional-scalar overload `add_(wd, p.data)`
    if do_wd: p.grad.data.add_(p.data, alpha=wd)
    return p
l2_reg.defaults = dict(wd=0.)

In [None]:
p = tst_param(1., 0.1)
p = l2_reg(p, 1., 0.1)
test_eq(p, tensor([1.]))
test_eq(p.grad, tensor([0.2]))

> Warning: Weight decay and L2 regularization are the same thing for basic SGD, but for more complex optimizers, they are very different. See [Decoupled Weight Decay Regularization](https://arxiv.org/abs/1711.05101) for more information.

### Making the step

In [None]:
show_doc(Optimizer.step)

<h4 id="Optimizer.step" class="doc_header"><code>Optimizer.step</code><a href="https://github.com/fastai/fastai_dev/tree/master/dev/__main__.py#L23" class="source_link" style="float:right">[source]</a></h4>

> <code>Optimizer.step</code>()

Update the stats and execute the steppers in on all parameters that have a grad

This method will loop over all param groups, then over all parameters for which `grad` is not None, and call each function in `steppers`, passing it the parameter `p` together with the state computed by the `stats` and the hyper-parameters in the corresponding dict in `hypers`.

In [None]:
#test basic step
def tst_params():
    "Four test parameters with values 0, 1, 2, 3, each with a gradient of 0.1 times its value"
    return [tst_param(val, 0.1*val) for val in range(4)]

params = tst_params()
opt = Optimizer(params, sgd_step, lr=0.1)
opt.step()
test_close([p.item() for p in params], [i*0.99 for i in range(4)])

In [None]:
#test two steps
params = tst_params()
opt = Optimizer(params, [weight_decay, sgd_step], lr=0.1, wd=0.1)
opt.step()
test_close([p.item() for p in params], [i*0.98 for i in range(4)])

In [None]:
#test None gradients are ignored
params = tst_params()
opt = Optimizer(params, sgd_step, lr=0.1)
params[-1].grad = None
opt.step()
test_close([p.item() for p in params], [0., 0.99, 1.98, 3.])

In [None]:
#test discriminative lrs
params = tst_params()
opt = Optimizer([params[:2], params[2:]], sgd_step, lr=0.1)
opt.hypers[0]['lr'] = 0.01
opt.step()
test_close([p.item() for p in params], [0., 0.999, 1.98, 2.97])

In [None]:
show_doc(Optimizer.zero_grad)

<h4 id="Optimizer.zero_grad" class="doc_header"><code>Optimizer.zero_grad</code><a href="https://github.com/fastai/fastai_dev/tree/master/dev/__main__.py#L18" class="source_link" style="float:right">[source]</a></h4>

> <code>Optimizer.zero_grad</code>()

Zero all the grad attributes of the parameters

In [None]:
params = tst_params()
opt = Optimizer(params, [weight_decay, sgd_step], lr=0.1, wd=0.1)
opt.zero_grad()
[test_eq(p.grad, tensor([0.])) for p in params];

In [None]:
#hide
params = tst_params()
opt = Optimizer([params[:2], params[2:]], sgd_step, lr=0.1)
opt.hypers[0]['lr'] = 0.01
test_eq(opt._grad_params(), [(tensor([0.]), {'lr': 0.01}),
                            (tensor([1.]), {'lr': 0.01}),
                            (tensor([2.]), {'lr': 0.1}),
                            (tensor([3.]), {'lr': 0.1})])

`Optimizer` has `stats` which are functions taking the state associated with a parameter.  `stats` use that parameter, plus the optimizer hyper-parameters, to update the state. 
That state can then be used by any stepper.  The best example is a momentum calculation. 
The state of a parameter is initialized to an empty dictionary the first time we try to access it; after that, each `stat` function is responsible for properly initializing the entries it needs.

In [None]:
def tst_stat(state, p, **kwargs):
    "Toy stat used in the tests: accumulate the running sum of `p` in `state['sum']`"
    # Start from zeros the first time this state is seen
    running = state['sum'] if 'sum' in state else torch.zeros_like(p)
    state['sum'] = running + p.data
    return state
tst_stat.defaults = dict(mom=0.9)

#Test Optimizer init
opt = Optimizer([1,2,3], noop, stats=tst_stat)
test_eq(opt.hypers, [{'mom': 0.9}])
opt = Optimizer([1,2,3], noop, stats=tst_stat, mom=0.99)
test_eq(opt.hypers, [{'mom': 0.99}])

#Test stat
x = torch.randn(4,5)
state = tst_stat({}, x)
assert 'sum' in state
test_eq(state['sum'], x)
state = tst_stat(state, x)
test_eq(state['sum'], 2*x)

## Statistics

In [None]:
# export
def average_grad(state, p, mom, dampening=False, **kwargs):
    "Keeps track of the avg grads of `p` in `state` with `mom`."
    # The running average is created lazily, the first time this `state` is used
    if 'grad_avg' not in state: state['grad_avg'] = torch.zeros_like(p.grad.data)
    # With dampening the incoming gradient is scaled by `1-mom`, otherwise it is added as is
    damp = 1-mom if dampening else 1.
    # Keyword `alpha` replaces the deprecated positional-scalar overload `add_(damp, grad)`
    state['grad_avg'].mul_(mom).add_(p.grad.data, alpha=damp)
    return state

average_grad.defaults = dict(mom=0.9)

`dampening=False` gives the classical formula for momentum in SGD: 
```
new_val = old_val * mom + grad
```
whereas `dampening=True` makes it an exponential moving average:
```
new_val = old_val * mom + grad * (1-mom)
```

In [None]:
p = tst_param([1,2,3], [4,5,6])
state = {}
state = average_grad(state, p, mom=0.9)
test_eq(state['grad_avg'], p.grad)
state = average_grad(state, p, mom=0.9)
test_eq(state['grad_avg'], p.grad * 1.9)
#Test dampening
state = {}
state = average_grad(state, p,  mom=0.9, dampening=True)
test_eq(state['grad_avg'], 0.1*p.grad)
state = average_grad(state, p, mom=0.9, dampening=True)
test_eq(state['grad_avg'], (0.1*0.9+0.1)*p.grad)

In [None]:
# export
def average_sqr_grad(state, p, sqr_mom, dampening=True, **kwargs):
    "Keeps track of the avg of the squared grads of `p` in `state` with `sqr_mom`."
    # The running average is created lazily, the first time this `state` is used
    if 'sqr_avg' not in state: state['sqr_avg'] = torch.zeros_like(p.grad.data)
    # With dampening the incoming squared gradient is scaled by `1-sqr_mom`
    damp = 1-sqr_mom if dampening else 1.
    # Keyword `value` replaces the deprecated positional-scalar overload `addcmul_(damp, g, g)`
    state['sqr_avg'].mul_(sqr_mom).addcmul_(p.grad.data, p.grad.data, value=damp)
    return state

average_sqr_grad.defaults = dict(sqr_mom=0.99)

`dampening=False` gives the same formula as classical momentum in SGD, applied to the squared gradients (`mom` below stands for `sqr_mom`):
```
new_val = old_val * mom + grad**2
```
whereas `dampening=True` makes it an exponential moving average:
```
new_val = old_val * mom + (grad**2) * (1-mom)
```

In [None]:
p = tst_param([1,2,3], [4,5,6])
state = {}
state = average_sqr_grad(state, p, sqr_mom=0.99, dampening=False)
test_eq(state['sqr_avg'], p.grad.pow(2))
state = average_sqr_grad(state, p, sqr_mom=0.99, dampening=False)
test_eq(state['sqr_avg'], p.grad.pow(2) * 1.99)
#Test dampening
state = {}
state = average_sqr_grad(state, p,  sqr_mom=0.99)
test_close(state['sqr_avg'], 0.01*p.grad.pow(2))
state = average_sqr_grad(state, p, sqr_mom=0.99)
test_close(state['sqr_avg'], (0.01*0.99+0.01)*p.grad.pow(2))

### Freezing part of the model

In [None]:
#Freezing the first layer
params = [tst_params(), tst_params(), tst_params()]
opt = Optimizer(params, sgd_step, lr=0.1)
opt.freeze_to(1)
test_eq(L(params[0]).mapped(lambda x:x.requires_grad), [False]*4)
for i in {1,2}: test_eq(L(params[i]).mapped(lambda x:x.requires_grad), [True]*4)
    
#Unfreezing
opt.unfreeze()
for i in range(2): test_eq(L(params[i]).mapped(lambda x:x.requires_grad), [True]*4)

Parameters such as batchnorm weights/bias can be marked to always be in training mode: just put `force_train=True` in their state.

In [None]:
params = [tst_params(), tst_params(), tst_params()]
opt = Optimizer(params, sgd_step, lr=0.1)
for p in L(params[1])[[1,3]]: opt.state[p] = {'force_train': True}
opt.freeze()
test_eq(L(params[0]).mapped(lambda x:x.requires_grad), [False]*4)
test_eq(L(params[1]).mapped(lambda x:x.requires_grad), [False, True, False, True])
test_eq(L(params[2]).mapped(lambda x:x.requires_grad), [True]*4)

## Optimizers

### SGD with momentum

In [None]:
#export
def momentum_step(p, lr, grad_avg, **kwargs):
    "Step for SGD with momentum with `lr`: subtract `lr*grad_avg` from `p` in place, and return `p`"
    # Keyword `alpha` replaces the deprecated positional-scalar overload `add_(-lr, grad_avg)`
    p.data.add_(grad_avg, alpha=-lr)
    return p

In [None]:
#export
def SGD(params, lr, mom=0., wd=0., true_wd=True):
    "Build an `Optimizer` doing SGD on `params` with `lr`, optional momentum `mom` and decay `wd`"
    # Optional decay stepper first: true weight decay or L2 regularization
    if wd == 0.: steppers = []
    elif true_wd: steppers = [weight_decay]
    else: steppers = [l2_reg]
    if mom != 0.:
        # The momentum step needs the running average of the gradients
        steppers.append(momentum_step)
        return Optimizer(params, steppers, stats=average_grad, lr=lr, mom=mom, wd=wd)
    steppers.append(sgd_step)
    return Optimizer(params, steppers, lr=lr, wd=wd)

Optional weight decay of `wd` is applied, as true weight decay (decay the weights directly) if `true_wd=True` else as L2 regularization (add the decay to the gradients).

In [None]:
#Vanilla SGD
params = tst_params()
opt = SGD(params, lr=0.1)
opt.step()
# Each gradient is 0.1 times the value, so one step with lr=0.1 removes 1% of the initial value
test_close([p.item() for p in params], [i*0.99 for i in range(4)])
opt.step()
# Gradients are unchanged between the two steps, so the same amount is removed again
test_close([p.item() for p in params], [i*0.98 for i in range(4)])

In [None]:
#SGD with momentum
params = tst_params()
opt = SGD(params, lr=0.1, mom=0.9)
assert isinstance(opt, Optimizer)
opt.step()
# On the first step the gradient average equals the gradient
test_close([p.item() for p in params], [i*0.99 for i in range(4)])
opt.step()
# On the second step the gradient average is 1.9 times the gradient
test_close([p.item() for p in params], [i*(1 - 0.1 * (0.1 + 0.1*1.9)) for i in range(4)])
for i,p in enumerate(params): test_close(opt.state[p]['grad_avg'].item(), i*0.19)

Test weight decay, notice how we can see that L2 regularization is different from weight decay even for simple SGD with momentum.

In [None]:
params = tst_params()
#Weight decay
opt = SGD(params, lr=0.1, mom=0.9, wd=0.1)
opt.step()
test_close([p.item() for p in params], [i*0.98 for i in range(4)])
#L2 reg
opt = SGD(params, lr=0.1, mom=0.9, wd=0.1, true_wd=False)
opt.step()
test_close([p.item() for p in params], [i*0.97 for i in range(4)])

### RMSProp

In [None]:
#export
def rms_prop_step(p, lr, sqr_avg, eps, grad_avg=None, **kwargs):
    "Step for RMSProp with `lr` on `p`, using `grad_avg` instead of `p.grad` when it is provided"
    denom = sqr_avg.sqrt().add_(eps)
    # Without momentum there is no `grad_avg` in the state: fall back to the raw gradient
    numer = p.grad if grad_avg is None else grad_avg
    # Keyword `value` replaces the deprecated positional-scalar overload `addcdiv_(-lr, a, b)`
    p.data.addcdiv_(numer, denom, value=-lr)
    return p

rms_prop_step.defaults = dict(eps=1e-8)

In [None]:
#export
def RMSProp(params, lr, sqr_mom=0.99, mom=0., wd=0., true_wd=True):
    "Build an `Optimizer` doing RMSProp on `params` with `lr`, `sqr_mom`, optional `mom` and decay `wd`"
    # Optional decay stepper first: true weight decay or L2 regularization
    if wd == 0.: steppers = []
    elif true_wd: steppers = [weight_decay]
    else: steppers = [l2_reg]
    steppers.append(rms_prop_step)
    stats = [average_sqr_grad]
    # The gradient average is only tracked when momentum is requested
    if mom != 0.: stats.insert(0, average_grad)
    hypers = dict(lr=lr, mom=mom, sqr_mom=sqr_mom, wd=wd)
    return Optimizer(params, steppers, stats=stats, **hypers)

RMSProp was introduced by Geoffrey Hinton in his [course](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf). What is named `sqr_mom` here is the `alpha` in the course. Optional weight decay of `wd` is applied, as true weight decay (decay the weights directly) if `true_wd=True` else as L2 regularization (add the decay to the gradients).

In [None]:
#Without momentum
import math
params = tst_param([1,2,3], [0.1,0.2,0.3])
opt = RMSProp(params, lr=0.1)
opt.step()
test_close(params[0], tensor([0.,1.,2.]))
opt.step()
step = - 0.1 * 0.1 / (math.sqrt((0.01*0.99+0.01) * 0.1**2) + 1e-8)
test_close(params[0], tensor([step, 1+step, 2+step]))

In [None]:
#With momentum
params = tst_param([1,2,3], [0.1,0.2,0.3])
opt = RMSProp(params, lr=0.1, mom=0.9)
opt.step()
test_close(params[0], tensor([0.,1.,2.]))
opt.step()
step = - 0.1 * (0.1 + 0.9*0.1) / (math.sqrt((0.01*0.99+0.01) * 0.1**2) + 1e-8)
test_close(params[0], tensor([step, 1+step, 2+step]))

### Adam

In [None]:
#export
def step_stat(state, p, **kwargs):
    "Count in `state['step']` how many times this stat has been applied (`p` itself is not read)"
    state['step'] = state.get('step', 0) + 1
    return state

In [None]:
p = tst_param(1,0.1)
state = {}
state = step_stat(state, p)
test_eq(state['step'], 1)
for _ in range(5): state = step_stat(state, p)
test_eq(state['step'], 6)

In [None]:
#export
def _debias(mom, damp, step): return damp * (1 - mom**step) / (1-mom)

def adam_step(p, lr, mom, step, sqr_mom, grad_avg, sqr_avg, eps, **kwargs):
    "Step for Adam with `lr` on `p`"
    debias1 = _debias(mom,     1-mom,     step)
    debias2 = _debias(sqr_mom, 1-sqr_mom, step)
    p.data.addcdiv_(-lr / debias1, grad_avg, (sqr_avg/debias2).sqrt() + eps)
    return p

adam_step._defaults = dict(eps=1e-5)

In [None]:
#export
def Adam(params, lr, mom=0.9, sqr_mom=0.99, eps=1e-5, wd=0., true_wd=True):
    "Build an `Optimizer` doing Adam on `params` with `lr`, `mom`, `sqr_mom`, `eps` and decay `wd`"
    # Optional decay stepper first: true weight decay or L2 regularization
    if wd == 0.: steppers = []
    elif true_wd: steppers = [weight_decay]
    else: steppers = [l2_reg]
    steppers.append(adam_step)
    # The gradient average is dampened here
    dampened_avg = partial(average_grad, dampening=True)
    stats = [dampened_avg, average_sqr_grad, step_stat]
    hypers = dict(lr=lr, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd)
    return Optimizer(params, steppers, stats=stats, **hypers)

Adam was introduced by Diederik P. Kingma and Jimmy Ba in [Adam: A Method for Stochastic Optimization](https://arxiv.org/abs/1412.6980). For consistency across optimizers, we renamed `beta1` and `beta2` in the paper to `mom` and  `sqr_mom`. Note that our defaults also differ from the paper (0.99 for `sqr_mom` or `beta2`, 1e-5 for `eps`). Those values seem to be better from our experiments in a wide range of situations.

Optional weight decay of `wd` is applied, as true weight decay (decay the weights directly) if `true_wd=True` else as L2 regularization (add the decay to the gradients).

> Note: Don't forget that `eps` is a hyper-parameter you can change. Some models won't train without a very high `eps` like 0.1 (intuitively, the higher `eps` is, the closer we are to normal SGD). The usual default of 1e-8 is often too extreme, in the sense that we don't manage to get results as good as with SGD.

In [None]:
params = tst_param([1,2,3], [0.1,0.2,0.3])
opt = Adam(params, lr=0.1)
opt.step()
step = -0.1 * 0.1 / (math.sqrt(0.1**2) + 1e-8)
test_close(params[0], tensor([1+step, 2+step, 3+step]))
opt.step()
test_close(params[0], tensor([1+2*step, 2+2*step, 3+2*step]), eps=1e-4)

### LARS/LARC

In [None]:
#export
def larc_layer_lr(state, p, lr, trust_coeff, wd, eps, clip=True, **kwargs):
    "Store in `state['local_lr']` the layer-wise lr computed from the norms of `p` and its grad"
    weight_norm = torch.norm(p.data)
    grad_norm = torch.norm(p.grad.data)
    denom = grad_norm + weight_norm * wd + eps
    layer_lr = lr*trust_coeff * (weight_norm) / denom
    # With `clip` the layer-wise lr never exceeds the global `lr`
    if clip: layer_lr = min(lr, layer_lr)
    state['local_lr'] = layer_lr
    return state
larc_layer_lr.defaults = {'trust_coeff': 0.02, 'wd': 0., 'eps': 1e-8}

In [None]:
#export
def larc_step(p, local_lr, grad_avg=None, **kwargs):
    "Step for LARC with `local_lr` on `p`, using `grad_avg` instead of the raw gradient when provided"
    # Without momentum there is no `grad_avg` in the state: fall back to the raw gradient
    direction = p.grad.data if grad_avg is None else grad_avg
    # Keyword `alpha` replaces the deprecated positional-scalar overload `add_(-local_lr, tensor)`
    p.data.add_(direction, alpha=-local_lr)
    return p

In [None]:
#export
def Larc(params, lr, mom=0.9, clip=True, trust_coeff=0.02, eps=1e-8, wd=0., true_wd=True):
    "A `Optimizer` for LARC (LARS when `clip=False`) with `lr`, `mom`, `trust_coeff`, `eps` and `params`"
    # Optional decay stepper first: true weight decay or L2 regularization
    steppers = [] if wd==0. else [weight_decay] if true_wd else [l2_reg]
    steppers.append(larc_step)
    # The gradient average is only tracked when momentum is requested
    stats = [] if mom==0. else [average_grad]
    stats.append(partial(larc_layer_lr, clip=clip))
    return Optimizer(params, steppers, stats=stats, lr=lr, mom=mom, trust_coeff=trust_coeff, eps=eps, wd=wd)

The LARS optimizer was first introduced in [Large Batch Training of Convolutional Networks](https://arxiv.org/abs/1708.03888) then refined in its LARC variant (original LARS is with `clip=False`). A learning rate is computed for each individual layer with a certain `trust_coeff`, then (when `clip=True`) clipped to be at most `lr`.

Optional weight decay of `wd` is applied, as true weight decay (decay the weights directly) if `true_wd=True` else as L2 regularization (add the decay to the gradients).

In [None]:
params = [tst_param([1,2,3], [0.1,0.2,0.3]), tst_param([1,2,3], [0.01,0.02,0.03])]
opt = Larc(params, lr=0.1)
opt.step()
#First param local lr is 0.02 < lr so it's not clipped
test_close(opt.state[params[0]]['local_lr'], 0.02)
#Second param local lr is 0.2 > lr so it's clipped
test_eq(opt.state[params[1]]['local_lr'], 0.1)
test_close(params[0], tensor([0.998,1.996,2.994]))
test_close(params[1], tensor([0.999,1.998,2.997]))

In [None]:
params = [tst_param([1,2,3], [0.1,0.2,0.3]), tst_param([1,2,3], [0.01,0.02,0.03])]
opt = Larc(params, lr=0.1, clip=False)
opt.step()
#No clipping
test_close(opt.state[params[0]]['local_lr'], 0.02)
test_close(opt.state[params[1]]['local_lr'], 0.2)
test_close(params[0], tensor([0.998,1.996,2.994]))
test_close(params[1], tensor([0.998,1.996,2.994]))

### LAMB

In [None]:
#export
def lamb_step(p, lr, mom, step, sqr_mom, grad_avg, sqr_avg, eps, **kwargs):
    "Step for LAMB with `lr` on `p`"
    debias1 = _debias(mom,     1-mom,     step)
    debias2 = _debias(sqr_mom, 1-sqr_mom, step)
    # Root mean square of the weights
    r1 = p.data.pow(2).mean().sqrt()
    # Adam-style update; named `update` so the `step` count argument is not shadowed
    update = (grad_avg/debias1) / ((sqr_avg/debias2).sqrt()+eps)
    r2 = update.pow(2).mean().sqrt()
    # Ratio of the two RMS values, capped at 10; 1 is used when either of them is zero
    q = 1 if r1 == 0 or r2 == 0 else min(r1/r2,10)
    # Keyword `alpha` replaces the deprecated positional-scalar overload `add_(scalar, tensor)`
    p.data.add_(update, alpha=-lr * q)
    return p
# `Optimizer` gathers hyper-parameter defaults from the `defaults` attribute of its steppers;
# `_defaults` is kept as an alias for code that read the previous attribute name.
lamb_step.defaults = lamb_step._defaults = dict(eps=1e-6, wd=0.)

In [None]:
#export
def Lamb(params, lr, mom=0.9, sqr_mom=0.99, eps=1e-5, wd=0., true_wd=True):
    "A `Optimizer` for LAMB with `lr`, `mom`, `sqr_mom`, `eps` and `params`"
    # Optional decay stepper first: true weight decay or L2 regularization
    steppers = [] if wd==0. else [weight_decay] if true_wd else [l2_reg]
    steppers.append(lamb_step)
    # Same statistics as in `Adam`: dampened gradient average, squared gradient average, step count
    stats = [partial(average_grad, dampening=True), average_sqr_grad, step_stat]
    return Optimizer(params, steppers, stats=stats, lr=lr, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd)

LAMB was introduced in [Large Batch Optimization for Deep Learning: Training BERT in 76 minutes](https://arxiv.org/abs/1904.00962). Intuitively, it's LARC applied to Adam. As in `Adam`, we renamed `beta1` and `beta2` in the paper to `mom` and  `sqr_mom`. Note that our defaults also differ from the paper (0.99 for `sqr_mom` or `beta2`, 1e-5 for `eps`). Those values seem to be better from our experiments in a wide range of situations.

Optional weight decay of `wd` is applied, as true weight decay (decay the weights directly) if `true_wd=True` else as L2 regularization (add the decay to the gradients).

In [None]:
params = tst_param([1,2,3], [0.1,0.2,0.3])
opt = Lamb(params, lr=0.1)
opt.step()
test_close(params[0], tensor([0.7840,1.7840,2.7840]), eps=1e-3)

## Export -

In [None]:
#hide
from local.notebook.export import *
notebook2script(all_fs=True)

Converted 00_test.ipynb.
Converted 01_core.ipynb.
Converted 01a_torch_core.ipynb.
Converted 01b_script.ipynb.
Converted 01c_dataloader.ipynb.
Converted 02_data_transforms.ipynb.
Converted 03_data_pipeline.ipynb.
Converted 05_data_core.ipynb.
Converted 06_data_source.ipynb.
Converted 07_vision_core.ipynb.
Converted 08_pets_tutorial.ipynb.
Converted 09_vision_augment.ipynb.
Converted 11_layers.ipynb.
Converted 11a_vision_models_xresnet.ipynb.
Converted 12_optimizer.ipynb.
Converted 13_learner.ipynb.
Converted 14_callback_schedule.ipynb.
Converted 15_callback_hook.ipynb.
Converted 16_callback_progress.ipynb.
Converted 17_callback_tracker.ipynb.
Converted 18_callback_fp16.ipynb.
Converted 19_callback_mixup.ipynb.
Converted 20_metrics.ipynb.
Converted 21_tutorial_imagenette.ipynb.
Converted 30_text_core.ipynb.
Converted 31_text_data.ipynb.
Converted 32_text_models_awdlstm.ipynb.
Converted 33_text_models_core.ipynb.
Converted 34_callback_rnn.ipynb.
Converted 35_tutorial_wikitext.ipynb.
Conve