# In [None]:
import functools

import torch
import torch.nn.functional as F
from fastprogress import master_bar, progress_bar

# In [None]:
def default_device():
    """Return ``torch.device("cuda")`` if PyTorch reports a usable GPU, else the CPU device."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


class ToDeviceCB:
    """Callback that moves tensors onto a single torch device.

    ``before_batch`` moves ``learner.xb`` and ``learner.yb``. ``before_fit``
    moves ``learner.model``, but only if the learner fires a ``before_fit``
    event.

    Args:
        device: target passed to ``.to(...)``. If ``None``, ``default_device()``
            is used. It is resolved when the callback is created, not once at
            class-definition time.
    """

    def __init__(self, device=None):
        self.device = default_device() if device is None else device

    def before_fit(self, learner):
        # Keep the parameters on the same device as the batches moved in
        # ``before_batch``; otherwise the forward pass mixes devices.
        learner.model.to(self.device)

    def before_batch(self, learner):
        learner.xb, learner.yb = learner.xb.to(self.device), learner.yb.to(self.device)


class with_cbs:
    """Method decorator that fires ``before_<nm>`` / ``after_<nm>`` events.

    The instance owning the decorated method must provide ``callback(name)``.
    The decorated call's return value is passed through, and the wrapper keeps
    the original function's name and docstring. If the decorated function
    raises, the ``after_<nm>`` event is not fired.

    Args:
        nm: event suffix; e.g. ``"batch"`` fires ``before_batch`` and ``after_batch``.
    """

    def __init__(self, nm): self.nm = nm

    def __call__(self, f):
        @functools.wraps(f)
        def _f(o, *args, **kwargs):
            o.callback(f'before_{self.nm}')
            result = f(o, *args, **kwargs)
            o.callback(f'after_{self.nm}')
            return result
        return _f

# Callbacks used by every Learner that is not given its own ``cbs``.
# The single ToDeviceCB instance is shared between those learners.
DEFAULT_CBS = [
    ToDeviceCB(),
]

class Learner:
    """Callback-driven training loop for a PyTorch model.

    ``fit`` runs ``n_epoch`` epochs. Each epoch is a training pass over
    ``dl_train`` followed by a gradient-free evaluation pass over ``dl_test``,
    and the mean per-sample loss of each pass is printed. ``fit`` is wrapped
    in ``before_fit``/``after_fit`` events and every batch in
    ``before_batch``/``after_batch`` events (dispatched by ``callback``).

    Args:
        model: module to train; ``model.train(flag)`` switches modes.
        dl_train: sized iterable of ``(xb, yb)`` training batches.
        dl_test: sized iterable of ``(xb, yb)`` evaluation batches.
        n_epoch: number of epochs.
        optimizer: optimizer instance. If ``None``, ``fit`` creates
            ``torch.optim.SGD(model.parameters(), lr=lr)``.
        loss_fn: called as ``loss_fn(model(xb), yb)``; must return a
            single-element tensor.
        lr: learning rate, used only when the default optimizer is created.
        cbs: iterable of callback objects, or ``None`` for no callbacks.
            A callback only needs to define the event methods it handles.
    """

    def __init__(self, model, dl_train, dl_test, n_epoch, optimizer=None,
                 loss_fn=F.mse_loss, lr=1e-3, cbs=DEFAULT_CBS):
        self.model, self.dl_train, self.dl_test, self.n_epoch = model, dl_train, dl_test, n_epoch
        self.optimizer, self.loss_fn, self.lr, self.cbs = optimizer, loss_fn, lr, cbs

    @with_cbs("fit")
    def fit(self):
        """Train for ``n_epoch`` epochs, evaluating on ``dl_test`` after each one."""
        if self.optimizer is None:
            # Without an optimizer the first training step would fail on
            # ``None.step()``; this is also the only place ``lr`` is used.
            self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.lr)
        self.mb = master_bar(range(self.n_epoch))
        for i in self.mb:
            train_loss = self.one_epoch(True)
            print(f"Epoch {i}: {train_loss}")
            with torch.no_grad():
                test_loss = self.one_epoch(False)
            print(f"Test loss: {test_loss}")

    def one_epoch(self, training):
        """Run one pass over the train (``training=True``) or test loader.

        Returns the mean loss over the pass, weighted by batch size
        (``len(xb)``) so a short final batch is not over-counted, or ``nan``
        if the loader yields no batches. Must be called after ``fit`` has set
        ``self.mb``.
        """
        self.model.train(training)
        dl = self.dl_train if training else self.dl_test
        total_loss, n_samples = 0.0, 0
        # The bar iterates batches, so its total is the number of batches.
        for b in progress_bar(dl, parent=self.mb, total=len(dl)):
            self.xb, self.yb = b
            self.one_batch()
            bs = len(self.xb)
            # Accumulate as a tensor; calling ``.item()`` every batch would
            # force a host/device sync on GPU.
            total_loss = total_loss + self.loss.detach() * bs
            n_samples += bs
        return float(total_loss) / n_samples if n_samples else float("nan")

    @with_cbs("batch")
    def one_batch(self):
        """Forward one batch, store ``self.loss``, and update weights when training."""
        y_hat = self.model(self.xb)
        self.loss = self.loss_fn(y_hat, self.yb)
        if self.model.training:
            self.loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()

    def callback(self, nm):
        """Call ``cb.<nm>(self)`` on every callback that defines method ``nm``."""
        for cb in self.cbs or ():
            method = getattr(cb, nm, None)
            # Callbacks implement only the events they care about (e.g. the
            # default ToDeviceCB has no ``after_batch``), so skip the rest.
            if method is not None:
                method(self)