# Optimizer tweaks

In [2]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [3]:
from exp.nb_08 import *

## Imagenette data

In [4]:
path = datasets.untar_data(datasets.URLs.IMAGENETTE_160)

In [5]:
# Item transforms handed to ImageList.from_files below.
# NOTE(review): presumably applied in list order (RGB -> resize to 128 ->
# byte tensor -> float tensor); confirm against ImageList.
tfms = [make_rgb, ResizeFixed(128), to_byte_tensor, to_float_tensor]
bs = 128  # batch size passed to to_databunch below

# Build the data: file list -> split -> labels -> databunch.
il = ImageList.from_files(path, tfms=tfms)
# NOTE(review): presumably items whose grandparent folder is 'val' become
# the validation set; verify against grandparent_splitter.
sd = SplitData.split_by_func(il, partial(grandparent_splitter, valid_name='val'))
# NOTE(review): presumably labels come from each file's parent folder name;
# verify against parent_labeler.
ll = label_by_func(sd, parent_labeler, proc_y=CategoryProcessor())
# 3 input channels, 10 output classes, 4 loader workers.
data = ll.to_databunch(bs, c_in=3, c_out=10, num_workers=4)

In [6]:
nfs = [32, 64, 128, 256]

In [7]:
# Callback classes/factories passed to get_learn_run through `cbs=`.
# `partial` pre-binds the first argument of the two that need one:
# `accuracy` for AvgStatsCallback, `norm_imagenette` for
# BatchTransformXCallback.
cbfs = [partial(AvgStatsCallback, accuracy),
        CudaCallback,
        partial(BatchTransformXCallback, norm_imagenette)]

In [8]:
# Build the learner and runner from the layer sizes, data and callbacks
# defined in the cells above.
# NOTE(review): the positional 0.4 is presumably the learning rate — confirm
# against get_learn_run's signature in exp.nb_08.
learn, run = get_learn_run(nfs, data, .4, conv_layer, cbs=cbfs)

In [9]:
run.fit(1, learn)

train: [1.6899575383899488, tensor(0.4159, device='cuda:0')]
valid: [1.348248046875, tensor(0.5620, device='cuda:0')]


## Refining the optimizer

In [10]:
class Optimizer():
    """Minimal optimizer that updates parameters through stepper functions.

    Parameters are held as a list of groups; each group has its own dict of
    hyper-parameters (a copy of ``defaults``) stored at the same index in
    ``self.hypers``.

    Args:
        params: Iterable of parameters, or a list of parameter lists (one
            inner list per group). A flat iterable becomes a single group.
        steppers: A stepper callable or an iterable of them. ``step`` hands
            them to ``compose`` together with each parameter.
        **defaults: Hyper-parameters (e.g. ``lr``) copied into every group.
    """

    def __init__(self, params, steppers, **defaults):
        # `params` may be a one-shot iterable, so materialise it first.
        self.param_groups = list(params)
        # Normalise to a list of lists. The emptiness check keeps an empty
        # `params` from raising IndexError on `[0]`; it becomes one empty group.
        if not self.param_groups or not isinstance(self.param_groups[0], list):
            self.param_groups = [self.param_groups]
        # One independent dict per group so a group can be tuned in isolation.
        self.hypers = [{**defaults} for _ in self.param_groups]
        # `step` reads `self.steppers`; without storing it here every call to
        # `step` fails with AttributeError. Accept None, a lone callable, or
        # any iterable of callables.
        if steppers is None:
            steppers = []
        elif callable(steppers):
            steppers = [steppers]
        self.steppers = list(steppers)

    def grad_params(self):
        """Return ``(param, hyper)`` pairs for every parameter that has a gradient."""
        return [(p, hyper) for pg, hyper in zip(self.param_groups, self.hypers)
                for p in pg if p.grad is not None]

    def zero_grad(self):
        """Detach and zero, in place, the gradient of every parameter that has one."""
        for p, _ in self.grad_params():
            p.grad.detach_()
            p.grad.zero_()

    def step(self):
        """Run the steppers on every parameter that has a gradient.

        The group's hyper-parameters are forwarded as keyword arguments.
        """
        for p, hyper in self.grad_params():
            # NOTE(review): `compose` comes from exp.nb_08; presumably it
            # threads `p` through each stepper in turn — confirm there.
            compose(p, self.steppers, **hyper)

In [11]:
def sgd_step(p, lr, **kwargs):
    """Apply one plain SGD update in place: ``p <- p - lr * p.grad``.

    Args:
        p: Parameter tensor whose ``.grad`` is already populated.
        lr: Learning rate.
        **kwargs: Other hyper-parameters; accepted and ignored.

    Returns:
        The same ``p`` object, updated in place.
    """
    # Use the keyword `alpha` form: the positional `add_(scalar, tensor)`
    # overload is deprecated in PyTorch. The result is identical.
    p.data.add_(p.grad.data, alpha=-lr)
    return p

In [12]:
# Optimizer factory with sgd_step as the only stepper; the caller still
# supplies the parameters and hyper-parameter defaults (sgd_step needs `lr`).
opt_func = partial(Optimizer, steppers=[sgd_step])

In [None]:
class Recorder(Callback):
    # It records the learning rate and the losses at
    # each iteration.
    # NOTE(review): `self.in_train`, `self.opt` and `self.loss` are not set in
    # this class; presumably the Callback base class forwards them from the
    # runner — confirm in exp.nb_08.
    
    # Start every fit with empty histories.
    def begin_fit(self): self.lrs, self.losses = [], []
        
    def after_batch(self):
        # Skip recording when not in training mode.
        if not self.in_train: return
        # Only the 'lr' of the last hyper-parameter group is recorded.
        self.lrs.append(self.opt.hypers[-1]['lr'])
        # Detach and move to CPU before storing the loss.
        self.losses.append(self.loss.detach().cpu())
    
    # NOTE(review): the three entries below are outline stubs with no
    # signature or body, so this class is not valid Python as written.
    def plot_lr
    def plot_loss
    def plot
    
class ParamScheduler(Callback):
    _order = 1
    def __init__
    def begin_batch
    
        