In [None]:
#|default_exp callback.tracker

In [None]:
#|export
from __future__ import annotations

from fastai.callback.core import Callback, CancelFitException
from fastai.callback.tracker import SaveModelCallback, TrackerCallback

from fastxtend.imports import *

# Tracking Callbacks
> Additional callbacks which make decisions depending on how a monitored metric/loss behaves.

In [None]:
#|hide
import tempfile
from pathlib import Path
from nbdev.showdoc import *
from fastxtend.test_utils import *

## TerminateOnTrainNaN -

In [None]:
#|export
class TerminateOnTrainNaN(Callback):
    "A `Callback` that cancels the fit once the training loss is NaN or infinite; validation batches are not checked."
    order = -9
    run_valid = False
    def after_batch(self):
        "Raise `CancelFitException` if `self.loss` is not finite (NaN or +/-inf)."
        if not torch.isfinite(self.loss): raise CancelFitException

In [None]:
#|hide
learn = synth_learner()
learn.fit(10, lr=100, cbs=TerminateOnTrainNaN())

# Training must have stopped early: fewer losses recorded than 10 full epochs would produce
assert len(learn.recorder.losses) < 10 * len(learn.dls.train)
# No NaN or infinite loss may have been recorded
assert all(torch.isfinite(loss) for loss in learn.recorder.losses)

epoch,train_loss,valid_loss,time


## SaveModelAtEnd -

In [None]:
#|export
class SaveModelAtEnd(SaveModelCallback):
    "A `SaveModelCallback` that does nothing per epoch and saves the model once, when fitting finishes, so loggers can find it."
    order = TrackerCallback.order+1
    def __init__(self,
        fname='model', # Model filename
        with_opt=False # Include optimizer state
    ):
        store_attr()
        # Result of the most recent save, kept for loggers; stays `None` until `after_fit` runs
        self.last_saved_path = None

    def before_fit(self):
        "Intentionally a no-op override of the inherited `before_fit`."

    def after_epoch(self):
        "Intentionally a no-op override: nothing is saved per epoch."

    def after_fit(self, **kwargs):
        "Save the model as `fname` (with optimizer state if `with_opt`) and store the result in `last_saved_path`."
        saved_to = self.learn.save(f'{self.fname}', with_opt=self.with_opt)
        self.last_saved_path = saved_to

In [None]:
#|hide
with no_random():
    tmp_d = tempfile.TemporaryDirectory()
    # Use the directory as a context manager so it is removed even if the fit or the assertion fails
    with tmp_d:
        tmp_p = Path(tmp_d.name)
        learn = synth_learner(n_trn=2, path=tmp_p)
        learn.fit(n_epoch=2, cbs=SaveModelAtEnd())
        # `SaveModelAtEnd` defaults to `fname='model'`, so this is the file the test expects to exist
        assert (tmp_p/'models/model.pth').exists()

epoch,train_loss,valid_loss,time
0,15.257207,11.782478,00:00
1,15.102713,11.396034,00:00
