In [None]:
#|default_exp callback.tracker

In [None]:
#|exporti
# Contains code from:
# fastai - Apache License 2.0 - Copyright (c) 2023 fast.ai

# Tracking Callbacks
> Additional callbacks which make decisions depending on how a monitored metric/loss behaves.

In [None]:
#|export
from __future__ import annotations

from fastai.callback.core import Callback, CancelFitException
from fastai.callback.tracker import SaveModelCallback, TrackerCallback

from fastxtend.imports import *

In [None]:
#|hide
import tempfile
from pathlib import Path
from fastxtend.test_utils import *

## TerminateOnTrainNaN -

In [None]:
#|export
class TerminateOnTrainNaN(Callback):
    "A `Callback` that cancels the fit when the training loss becomes NaN or infinite, and ignores the validation loss."
    order = -9
    run_valid = False

    def after_batch(self):
        "Raise `CancelFitException` if the current loss is not a finite number."
        # `isfinite` is False exactly for NaN and +/-inf
        if not torch.isfinite(self.loss):
            raise CancelFitException

In [None]:
#|hide
# The fit is expected to stop early: fewer losses are recorded than a full 10-epoch run would produce
learn = synth_learner()
learn.fit(10, lr=100, cbs=TerminateOnTrainNaN())

n_recorded, n_full_run = len(learn.recorder.losses), 10 * len(learn.dls.train)
assert n_recorded < n_full_run
# Every loss that did get recorded must be a finite number
assert all(not (torch.isinf(loss) or torch.isnan(loss)) for loss in learn.recorder.losses)

epoch,train_loss,valid_loss,time


## SaveModelAtEnd -

In [None]:
#|export
class SaveModelAtEnd(SaveModelCallback):
    "A `SaveModelCallback` whose only action is to save the model once the fit ends, so loggers can find it."
    order = TrackerCallback.order+1

    def __init__(self,
        fname='model', # Model filename
        with_opt=False # Include optimizer state
    ):
        store_attr()
        # Path returned by the most recent save, kept for loggers; stays `None` until a fit has ended
        self.last_saved_path = None

    def before_fit(self):
        "Do nothing: overrides the inherited `before_fit` hook."

    def after_epoch(self):
        "Do nothing: overrides the inherited `after_epoch` hook."

    def after_fit(self, **kwargs):
        "Save the model as `fname` and remember the returned path in `last_saved_path`."
        self.last_saved_path = self.learn.save(f'{self.fname}', with_opt=self.with_opt)

In [None]:
#|hide
with no_random():
    # Entering the TemporaryDirectory as a context manager guarantees the directory is
    # removed even when the fit or the assertion below raises. `tmp_d` and `tmp_p` stay
    # bound after the block, as they were before.
    tmp_d = tempfile.TemporaryDirectory()
    with tmp_d:
        tmp_p = Path(tmp_d.name)
        learn = synth_learner(n_trn=2, path=tmp_p)
        learn.fit(n_epoch=2, cbs=SaveModelAtEnd())
        # The callback should have written the model with the default filename
        assert (tmp_p/'models/model.pth').exists()

epoch,train_loss,valid_loss,time
0,15.257207,11.782478,00:00
1,15.102713,11.396034,00:00


In [None]:
#|export
class LastMetricCallback(Callback):
    "A `Callback` which stores the last metric(s) value by name (or all if None) in the `Learner.lastmetric` dictionary"
    order,remove_on_fetch,_only_train_loop = 60,True,True
    def __init__(self, metrics:Listified[str]|None=None):
        # `None` means "store every metric"; the actual names are resolved in `before_fit`
        self._all_metrics = metrics is None
        self._last_metrics=L(metrics)

    def before_fit(self):
        "Reset `Learner.lastmetric` and resolve each monitored metric name to its index"
        self.run = not hasattr(self, "lr_finder") and not hasattr(self, "gather_preds")
        self.idx, self.learn.lastmetric = [], {}
        # Metric names without the first column; indices into this list are later used on `recorder.values` rows
        names = list(self.recorder.metric_names[1:])
        if self._all_metrics:
            self._last_metrics = L([m for m in names if m !='time'])
        for m in self._last_metrics:
            assert m in names, f'Metric {m} does not exist'
            self.idx.append(names.index(m))

    def after_fit(self):
        "Store the last recorded value of each monitored metric in `Learner.lastmetric`"
        values = self.recorder.values
        # When the fit is cancelled before any epoch has been recorded (for example by
        # `TerminateOnTrainNaN` during the first epoch) there is no row to read, so
        # `lastmetric` is left empty instead of raising an `IndexError`.
        if values:
            last_row = values[-1]
            for name, idx in zip(self._last_metrics, self.idx):
                self.learn.lastmetric[name] = last_row[idx]
        self.run = True

    def after_fit_exception(self):
        "Run `after_fit`, resetting `run` to True even if it raises"
        # NOTE(review): nothing in this file dispatches `after_fit_exception` - confirm which caller triggers it
        try:
            self.after_fit()
        finally:
            self.run = True

In [None]:
#|hide
with no_random():
    learn = synth_learner(n_trn=2, path=tmp_p)
    learn.fit(n_epoch=2, cbs=LastMetricCallback())
    print(learn.lastmetric)
    # With no metric names given the callback stores every non-time metric, so `valid_loss`
    # must be present and equal to the value in the recorder's last row
    valid_idx = list(learn.recorder.metric_names[1:]).index('valid_loss')
    assert learn.lastmetric['valid_loss'] == learn.recorder.values[-1][valid_idx]

epoch,train_loss,valid_loss,time
0,15.257207,11.782478,00:00
1,15.102713,11.396034,00:00


{'valid_loss': 11.396034240722656}
