In [None]:
# all_slow

In [None]:
!pip install nbdev

In [None]:
!pip install wwf -q

[K     |████████████████████████████████| 194kB 21.9MB/s 
[K     |████████████████████████████████| 245kB 31.7MB/s 
[K     |████████████████████████████████| 61kB 8.4MB/s 
[?25h

In [None]:
from wwf.utils import state_versions

# Report the installed versions of the libraries this notebook relies on.
state_versions(['wwf', 'fastai', 'fastcore', 'nbdev'])


---
This article is also a Jupyter Notebook available to be run from the top down. There
will be code snippets that you can then run in any environment.

Below are the versions of `wwf`, `fastai`, `fastcore`, and `nbdev` currently running at the time of writing this:
* `wwf`: 0.0.13 
* `fastai`: 2.2.5 
* `fastcore`: 1.3.19 
* `nbdev`: 1.1.12 
---

# Goals for today:

Look at 2-3 Callbacks:

- `ShortEpochCallback`
- `SaveModelCallback`
- Teacher/Student (or advanced model inputs) example by [@goralpl](https://github.com/goralpl/learning_fastai/blob/master/seq2seq_fastai_datablocks_custom_model.ipynb)

In [None]:
from fastai.vision.all import *

In [None]:
from fastai.test_utils import synth_dbunch, synth_learner

In [None]:
from wwf.basics.training_loop import *

In [None]:
# Build a learner from fastai's test utilities to inspect the training loop with.
learn = synth_learner()

In [None]:
# Look up the documentation for `Learner.show_training_loop`.
doc(Learner.show_training_loop)

In [None]:
# Show the device of the training DataLoader. `Learner` itself has no `train`
# attribute holding a DataLoader, so go through `learn.dls` explicitly.
learn.dls.train.device

In [None]:
# Print every training-loop event along with what each attached callback does there.
learn.show_training_loop(verbose=True)

Start Fit
    - before_fit:
        - TrainEvalCallback: 
            - Set the iter and epoch counters to 0, put the model and the right device
        - Recorder: 
            - Prepare state for training
        - ProgressCallback: 
            - Setup the master bar over the epochs
   Start Epoch Loop
       - before_epoch:
           - Recorder: 
               - Set timer if `self.add_time=True`
           - ProgressCallback: 
               - Update the master bar
      Start Train
          - before_train:
              - TrainEvalCallback: 
                  - Set the model in training mode
              - Recorder: 
                  - Reset loss and metrics state
              - ProgressCallback: 
                  - Launch a progress bar over the training dataloader
         Start Batch Loop
             - before_batch:
             - after_pred:
             - after_loss:
             - before_backward:
             - before_step:
             - after_step:
             - 

In [None]:
from fastai.vision.all import *

In [None]:
class ShortEpochCallback(Callback):
    "Cut each epoch short: cancel once `iter/n_iter` reaches `pct`"
    def __init__(self, pct=0.01, short_valid=True):
        # `pct` is the threshold on `iter/n_iter`; `short_valid` extends the cut to validation.
        self.pct = pct
        self.short_valid = short_valid

    def after_batch(self):
        "Cancel the current phase once `iter/n_iter` is no longer below `pct`"
        progress = self.iter / self.n_iter
        if progress < self.pct:
            return
        # Past the threshold: stop training, or validation when `short_valid` is set.
        if self.training:
            raise CancelTrainException
        if self.short_valid:
            raise CancelValidException

In [None]:
path = untar_data(URLs.PETS)/'images'
# Label each image from its file name: the label is whether the first character is upper-case.
dls = ImageDataLoaders.from_name_func(
    path,
    get_image_files(path),
    label_func=lambda fname: fname[0].isupper(),
    valid_pct=0.2,
    item_tfms=Resize(224),
)

In [None]:
# Attach `ShortEpochCallback` with its defaults (`pct=0.01`, `short_valid=True`) to the learner.
learn = cnn_learner(dls, resnet18, metrics=accuracy, cbs=[ShortEpochCallback()])

In [None]:
# One epoch; the `ShortEpochCallback` attached to `learn` cancels it early.
learn.fit(1)

epoch,train_loss,valid_loss,accuracy,time
0,00:03,,,


In [None]:
# Train for four epochs with a `SaveModelCallback` using its default settings.
# NOTE(review): `learn` was built with `ShortEpochCallback` (default `short_valid=True`), which
# cancels validation early; the monitored `valid_loss` may therefore be missing — confirm this runs.
learn.fit_one_cycle(4, cbs=[SaveModelCallback()])

In [None]:
# Look up the documentation for `GradientAccumulation`.
doc(GradientAccumulation)

In [None]:
class TrackerCallback(Callback):
    "Base `Callback` that remembers the best value reached by the `monitor` metric."
    order = 60
    remove_on_fetch = True

    def __init__(self, monitor='valid_loss', comp=None, min_delta=0., reset_on_fit=True):
        # With no explicit comparison, metrics named like a loss or an error use `np.less`.
        if comp is None:
            lower_is_better = 'loss' in monitor or 'error' in monitor
            comp = np.less if lower_is_better else np.greater
        # Flip the margin's sign for `np.less`, since `after_epoch` compares `val - min_delta`.
        if comp == np.less: min_delta *= -1
        self.monitor = monitor
        self.comp = comp
        self.min_delta = min_delta
        self.reset_on_fit = reset_on_fit
        self.best = None

    def before_fit(self):
        "Decide whether to run, reset the best value if needed, and locate the monitored column"
        self.run = not (hasattr(self, "lr_finder") or hasattr(self, "gather_preds"))
        if self.reset_on_fit or self.best is None:
            # Start from the worst possible value for the chosen comparison.
            self.best = float('inf') if self.comp == np.less else -float('inf')
        # The first entry of `metric_names` is skipped, matching the indexing used in `after_epoch`.
        tracked = self.recorder.metric_names[1:]
        assert self.monitor in tracked
        self.idx = list(tracked).index(self.monitor)

    def after_epoch(self):
        "Record whether the latest monitored value beats the best one so far"
        latest = self.recorder.values[-1][self.idx]
        if self.comp(latest - self.min_delta, self.best):
            self.best = latest
            self.new_best = True
        else:
            self.new_best = False

    def after_fit(self):
        "Re-enable the callback once the fit is over"
        self.run = True

In [None]:
class SaveModelCallback(TrackerCallback):
    "A `TrackerCallback` that saves the model during training and, unless `every_epoch` is set, loads the saved best at the end."
    _only_train_loop = True

    def __init__(self, monitor='valid_loss', comp=None, min_delta=0., fname='model', every_epoch=False,
                 with_opt=False, reset_on_fit=True):
        super().__init__(monitor=monitor, comp=comp, min_delta=min_delta, reset_on_fit=reset_on_fit)
        store_attr('fname,every_epoch,with_opt')
        # Path returned by the most recent save, kept for loggers; `None` until something is saved.
        self.last_saved_path = None

    def _save(self, name):
        "Save the learner under `name` and remember the path that `save` returns"
        self.last_saved_path = self.learn.save(name, with_opt=self.with_opt)

    def after_epoch(self):
        "Save every epoch when `every_epoch` is set, otherwise only when the monitored value improves."
        if self.every_epoch:
            # One file per epoch, suffixed with the epoch number.
            self._save(f'{self.fname}_{self.epoch}')
            return
        # Otherwise let the tracker decide whether this epoch set a new best.
        super().after_epoch()
        if not self.new_best:
            return
        print(f'Better model found at epoch {self.epoch} with {self.monitor} value: {self.best}.')
        self._save(f'{self.fname}')

    def after_fit(self, **kwargs):
        "Load the model saved under `fname`, unless saving every epoch."
        if self.every_epoch:
            return
        self.learn.load(f'{self.fname}', with_opt=self.with_opt)

In [None]:
import pdb

In [None]:
class ReduceLROnPlateau(TrackerCallback):
    "A `TrackerCallback` that divides the learning rate by `factor` when the monitored metric stops improving."
    def __init__(self, monitor='valid_loss', comp=None, min_delta=0., patience=1, factor=10., min_lr=0, reset_on_fit=True):
        super().__init__(monitor=monitor, comp=comp, min_delta=min_delta, reset_on_fit=reset_on_fit)
        self.patience = patience
        self.factor = factor
        self.min_lr = min_lr

    def before_fit(self):
        "Reset the count of epochs without improvement, then run the tracker setup"
        self.wait = 0
        super().before_fit()

    def after_epoch(self):
        "Update the tracker; after `patience` epochs without a new best, divide each lr by `factor`."
        super().after_epoch()
        if self.new_best:
            self.wait = 0
            return
        self.wait += 1
        if not self.wait >= self.patience:
            return
        # Remember the last group's lr so we can tell whether the reduction changed anything.
        old_lr = self.opt.hypers[-1]['lr']
        for h in self.opt.hypers:
            # Never go below the `min_lr` floor.
            h['lr'] = max(h['lr'] / self.factor, self.min_lr)
        self.wait = 0
        if self.opt.hypers[-1]["lr"] < old_lr:
            print(f'Epoch {self.epoch}: reducing lr to {self.opt.hypers[-1]["lr"]}')

In [None]:
class EarlyStoppingCallback(TrackerCallback):
    "A `TrackerCallback` that cancels the fit once the monitored quantity has stopped improving."
    def __init__(self, monitor='valid_loss', comp=None, min_delta=0., patience=1, reset_on_fit=True):
        super().__init__(monitor=monitor, comp=comp, min_delta=min_delta, reset_on_fit=reset_on_fit)
        self.patience = patience

    def before_fit(self):
        "Reset the count of epochs without improvement, then run the tracker setup"
        self.wait = 0
        super().before_fit()

    def after_epoch(self):
        "Update the tracker; cancel the fit after `patience` epochs without a new best."
        super().after_epoch()
        if self.new_best:
            self.wait = 0
            return
        self.wait += 1
        if self.wait >= self.patience:
            print(f'No improvement since epoch {self.epoch-self.wait}: early stopping')
            raise CancelFitException()

In [None]:
class DebuggerCallback(Callback):
    "Drop into `pdb` after each prediction so the state of the batch can be inspected"

    def after_pred(self):
        "Pause in the debugger once `self.pred` is available"
        import pdb; pdb.set_trace()

    def after_loss(self):
        "Placeholder hook for inspecting the loss; intentionally does nothing"
        # The original method had no body, which made the whole cell a syntax error.
        pass

In [None]:
# Rebuild the learner without passing any extra callbacks or metrics.
learn = cnn_learner(dls, resnet18)

In [None]:
# Train for one epoch with the debugger attached; execution pauses in `after_pred`.
learn.fit(1, cbs=[DebuggerCallback()])

epoch,train_loss,valid_loss,time


--Return--
> <ipython-input-29-f982048ad433>(3)after_pred()->None
-> import pdb; pdb.set_trace()
(Pdb) self.pred
tensor([[ 0.9407, -0.3457],
        [ 2.5843,  1.1408],
        [-3.4780, -0.8390],
        [ 5.6596,  1.1011],
        [ 0.4419,  0.8785],
        [-0.5365, -0.6911],
        [-0.0663,  0.0168],
        [-2.9532, -1.2067],
        [-3.1945,  0.2682],
        [-3.4333, -0.3774],
        [ 1.0653,  0.3292],
        [-0.6266,  2.1911],
        [-0.7886, -1.1957],
        [ 2.0624, -1.7225],
        [ 1.9406, -0.3745],
        [-4.3735,  1.1237],
        [-0.0839, -2.6352],
        [ 1.1814,  2.2882],
        [ 1.5963, -0.0936],
        [-1.4546,  2.1967],
        [-0.8374,  1.4447],
        [-0.9292,  0.0943],
        [ 1.9357, -0.1770],
        [ 1.2264, -4.6810],
        [ 1.7194, -1.5016],
        [-0.9380, -1.6953],
        [ 2.0211, -2.0814],
        [ 1.0008,  0.4885],
        [-1.0704,  2.3088],
        [-2.7820, -1.1778],
        [-0.2106,  1.8410],
        [ 0.0425,  

BdbQuit: ignored

In [None]:
# `Learner.fit` requires the number of epochs as its first argument; one epoch per call here.
# The first fit runs with a `SaveModelCallback`, the second without it.
learn.fit(1, cbs=[SaveModelCallback()])
learn.fit(1)

In [None]:
# `Learner.fit` requires the number of epochs as its first argument; run two one-epoch fits.
learn.fit(1)
learn.fit(1)