In [None]:
# | default_exp training

In [None]:
# | export
import functools
import time
from pathlib import Path

import torch
import torch.nn.functional as F
from fastprogress import master_bar, progress_bar
from omegaconf import OmegaConf
from torch import optim

from looped_experiments.utils import def_device, to_cpu, to_device

In [None]:
# | export
class with_cbs:
    """Decorator factory that brackets a method call with callback hooks.

    Decorating a method with ``@with_cbs("x")`` makes every call run
    ``o.callback("before_x")`` first and ``o.callback("after_x")`` afterwards,
    where ``o`` is the first positional argument (the instance the method is
    bound to). ``o`` must therefore expose a ``callback(name)`` method.

    The wrapped method's return value is passed through to the caller and its
    metadata (``__name__``, ``__doc__``) is preserved.
    """

    def __init__(self, nm):
        # Event-name suffix used to build the "before_<nm>"/"after_<nm>" hook names.
        self.nm = nm

    def __call__(self, f):
        @functools.wraps(f)
        def _f(o, *args, **kwargs):
            o.callback(f'before_{self.nm}')
            result = f(o, *args, **kwargs)
            # The "after" hook only fires when `f` returned normally; an
            # exception raised by `f` propagates without it.
            o.callback(f'after_{self.nm}')
            return result
        return _f


class Learner:
    """Minimal training/evaluation loop with callback hooks.

    `fit`, `eval`, `one_epoch` and `one_batch` are wrapped by `with_cbs`, so
    every callback in `cbs` may define `before_fit`/`after_fit`,
    `before_epoch`/`after_epoch` and `before_batch`/`after_batch`; each hook is
    called with the learner as its only argument. Note that `eval` fires the
    same `fit` hooks as `fit` does.

    Args:
        model: module to train; invoked as `model(*xb)`.
        dl_train: iterable of `(xb, yb)` training batches; must support `len()`.
        dl_valid: iterable of `(xb, yb)` validation batches; must support `len()`.
        n_epoch: default number of epochs used by `fit`.
        optimizer: factory called as `optimizer(params, lr=..., weight_decay=...)`.
        loss_fn: callable `loss_fn(preds, yb)` returning the loss tensor.
        lr: initial learning rate.
        wd: weight decay passed to the optimizer.
        cbs: optional collection of callback objects.
    """

    def __init__(self, model, dl_train, dl_valid, n_epoch, optimizer=optim.Adam, loss_fn=F.mse_loss, lr=1e-3, wd=0.0, cbs=None):
        self.model, self.dl_train, self.dl_valid = model, dl_train, dl_valid
        self.n_epoch, self.loss_fn, self.cbs = n_epoch, loss_fn, cbs
        self.optimizer = optimizer(model.parameters(), lr=lr, weight_decay=wd)
        self.mb = None
        # Defined up front so callbacks that read `train_step` also work when
        # `eval` is called before any `fit`.
        self.train_step = 0

    @with_cbs("fit")
    def fit(self, lr=None, n_epoch=None):
        """Train for `n_epoch` epochs, validating after each one.

        Args:
            lr: if given, overrides the learning rate of every parameter group.
            n_epoch: if given, overrides the epoch count from the constructor.
        """
        if lr is not None:
            for group in self.optimizer.param_groups:
                group['lr'] = lr
        # `is None` rather than truthiness, so an explicit 0 is honoured.
        if n_epoch is None: n_epoch = self.n_epoch

        # The step counter restarts on every call to `fit`.
        self.train_step = 0
        self.mb = master_bar(range(n_epoch))
        for i in self.mb:
            self.one_epoch(True, self.dl_train)
            self.mb.write(f"Epoch {i} Loss: {self.loss.item()}")
            torch.no_grad()(self.one_epoch)(False, self.dl_valid)
            self.mb.write(f"Valid Loss: {self.loss.item()}")

    @with_cbs("fit")
    def eval(self, dl=None):
        """Run one gradient-free pass over `dl` (defaults to `dl_valid`)."""
        # `is None` rather than truthiness: an empty loader is falsy and would
        # otherwise be silently replaced by the validation loader.
        if dl is None: dl = self.dl_valid
        torch.no_grad()(self.one_epoch)(False, dl)

    @with_cbs("epoch")
    def one_epoch(self, training, dl):
        """Iterate once over `dl` with the model in train (`training=True`) or eval mode."""
        self.model.train(training)
        for b in progress_bar(dl, parent=self.mb, total=len(dl)):
            self.xb, self.yb = b
            self.one_batch()
            if self.mb:
                self.mb.child.comment = f"Loss: {self.loss.item()}"

    @with_cbs("batch")
    def one_batch(self):
        """Forward the current batch; in train mode also backpropagate and step the optimizer."""
        self.preds = self.model(*self.xb)
        self.loss = self.loss_fn(self.preds, self.yb)
        if self.model.training:
            self.loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()
            # Counts optimizer updates only; validation batches do not advance it.
            self.train_step += 1

    def callback(self, nm):
        """Call hook `nm` on every callback that defines it, in sorted callback order."""
        if self.cbs is None: return
        for cb in sorted(self.cbs):
            hook = getattr(cb, nm, None)
            if hook is not None: hook(self)

In [None]:
# | export


class Callback:
    """Base class for callbacks; instances are ordered by their `order` attribute."""

    order = 0

    def __gt__(self, other):
        """Return True when this callback's `order` is larger than `other`'s."""
        return other.order < self.order


class ToDeviceCB(Callback):
    """Move the model (once per fit) and every batch onto `device`."""

    def __init__(self, device=def_device):
        self.device = device

    def before_fit(self, learn):
        """Move the learner's model to the target device and keep a reference to it."""
        moved = learn.model.to(self.device)
        self.model = moved

    def before_batch(self, learn):
        """Replace the learner's current inputs and targets with device-resident copies."""
        target = self.device
        inputs = to_device(learn.xb, target)
        labels = to_device(learn.yb, target)
        learn.xb, learn.yb = inputs, labels


class FnCallback(Callback):
    """Turn a plain function into a callback by using an instance as a decorator.

    `FnCallback("before_batch")` applied to `fn` stores `fn` on the instance
    under the attribute name `before_batch` and returns the instance.
    """

    order = 1

    def __init__(self, nm):
        # Name of the hook attribute under which the decorated function is stored.
        self.nm = nm

    def __call__(self, fn):
        hook_name = self.nm
        setattr(self, hook_name, fn)
        return self


class SaveModelCB(Callback):
    """Checkpoint the model every `save_every_n` training steps and prune old checkpoints.

    Args:
        save_dir: directory for checkpoints; created if it does not exist.
        save_every_n: save whenever `learn.train_step` is a multiple of this.
        max_to_keep: number of most recent step checkpoints to retain.
        save_training: if True, save a dict with step, model and optimizer
            state to `model_<step>+train.pt`; otherwise save only the model
            state dict to `model_<step>.pt`.

    At the end of a fit the model state dict is also written to
    `model_last.pt`, which is never pruned.
    """

    order = 1

    def __init__(self, save_dir, save_every_n=1000, max_to_keep=5, save_training=True):
        self.dir, self.save_every_n, self.max_to_keep = Path(save_dir), save_every_n, max_to_keep
        self.save_training = save_training
        self.dir.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def _ckpt_step(path):
        """Return the integer step encoded in a checkpoint file name, or None if there is none."""
        # "model_1000+train" -> "1000+train" -> "1000"; "model_last" -> "last".
        tag = path.stem.split("_", 1)[1].split("+", 1)[0]
        try:
            return int(tag)
        except ValueError:
            return None

    def _prune(self):
        """Delete all but the `max_to_keep` highest-step checkpoints."""
        steps = {}
        for f in self.dir.glob("model_*.pt"):
            step = self._ckpt_step(f)
            # Files without a numeric step (e.g. model_last.pt) are left alone.
            if step is not None: steps[f] = step
        # Numeric ordering: a string sort would rank "9000" above "10000".
        newest_first = sorted(steps, key=steps.get, reverse=True)
        for f in newest_first[self.max_to_keep:]: f.unlink()

    def after_batch(self, learn):
        """Save a checkpoint after a training step that hits the save interval."""
        # The step counter does not move during validation, so saving there
        # would only rewrite the same checkpoint on every batch.
        if not learn.model.training: return
        if learn.train_step % self.save_every_n != 0: return
        if self.save_training:
            torch.save({
                'step': learn.train_step,
                'model_state_dict': learn.model.state_dict(),
                'optimizer_state_dict': learn.optimizer.state_dict()},
                self.dir / f"model_{learn.train_step}+train.pt")
        else:
            torch.save(learn.model.state_dict(), self.dir / f"model_{learn.train_step}.pt")
        # The set of files only changes after a save, so prune here only.
        self._prune()

    def after_fit(self, learn):
        """Write the final model weights to `model_last.pt`."""
        torch.save(learn.model.state_dict(), self.dir / "model_last.pt")


class TimerCB(Callback):
    """Record training throughput in `learn.stats['steps_per_s']`.

    The value is the reciprocal of the wall time between `before_batch` and
    `after_batch` of the most recent training batch. Validation batches are
    not timed.
    """

    def before_fit(self, learn):
        """Make sure the learner has a `stats` dict to write into."""
        if not hasattr(learn, "stats"):
            learn.stats = {}

    def before_batch(self, learn):
        """Start the timer for a training batch."""
        if learn.model.training:
            # perf_counter is monotonic and high-resolution, unlike time.time().
            self.start = time.perf_counter()

    def after_batch(self, learn):
        """Stop the timer and store the steps-per-second figure."""
        if not learn.model.training: return
        elapsed = time.perf_counter() - self.start
        # Skip the update instead of dividing by zero if the clock did not
        # advance; the previous value (if any) stays in place.
        if elapsed > 0:
            learn.stats['steps_per_s'] = 1 / elapsed


class WandbCB(Callback):
    """Log training and validation metrics to Weights & Biases.

    Args:
        config: experiment config. This class reads `config.wandb.project`,
            `config.wandb.name`, `config.wandb.log_every_steps` and
            `config.debug_mode`; the whole config is uploaded as the run config.

    Raises:
        ImportError: if `wandb` cannot be imported.
    """

    # Runs after SaveModelCB so `model_last.pt` exists when `after_fit` uploads it.
    order = SaveModelCB.order + 10

    def __init__(self, config):
        try:
            import wandb
            wandb.require("core")
        except ImportError as err:
            # Keep the original failure attached as the cause.
            raise ImportError("Please install wandb to use this callback") from err
        self.wandb = wandb
        self.cfg = config
        self.wcfg = config.wandb

    def before_fit(self, learn):
        """Reset `learn.stats` and start a wandb run (disabled in debug mode)."""
        learn.stats = {}
        self.run = self.wandb.init(project=self.wcfg.project,
                                   name=self.wcfg.name,
                                   config=OmegaConf.to_container(self.cfg, resolve=True),
                                   mode="disabled" if self.cfg.debug_mode else "online")

    def after_batch(self, learn):
        """Log training metrics every `log_every_steps` steps and validation loss on every validation batch."""
        if learn.model.training:
            if learn.train_step % self.wcfg.log_every_steps == 0:
                self.run.log({
                    "loss": learn.loss,
                    "lr": learn.optimizer.param_groups[0]['lr'],
                    **learn.stats
                }, step=learn.train_step)
        else:
            # `train_step` does not advance during validation, so gating on it
            # would log either every validation batch or none at all depending
            # on where training happened to stop. Always log instead.
            self.run.log({"valid_loss": learn.loss}, step=learn.train_step)

    def after_fit(self, learn):
        """Upload `model_last.pt` of every SaveModelCB on the learner, then finish the run."""
        for cb in learn.cbs:
            if isinstance(cb, SaveModelCB):
                self.run.save(str(cb.dir / "model_last.pt"))
        self.run.finish()


def repr_cbs(cbs):
    """Return the class names of `cbs`, in iteration order, separated by single spaces."""
    names = [cb.__class__.__name__ for cb in cbs]
    return " ".join(names)

Curriculum for tasks and loops

In [None]:
class CurriculumCB(Callback):
    """Adjust the task's number of points and active dimensions as training progresses.

    `curriculum_config` is indexed by schedule name ("points", "dims"); each
    entry must provide `start`, `end`, `inc` and `interval`.
    """

    def __init__(self, curriculum_config):
        self.cfg = curriculum_config

    def get_params(self, name, step):
        """Return the value of schedule `name` at `step`.

        The value starts at `start`, grows by `inc` once per `interval` steps
        and is capped at `end`.
        """
        sched = self.cfg[name]
        n_increments = step // sched.interval
        return min(sched.end, sched.start + n_increments * sched.inc)

    def update_task(self, learn, ds):
        """Write the scheduled point and dimension counts for the current step onto `ds.task`."""
        step = learn.train_step
        task = ds.task
        task.n_points = self.get_params("points", step)
        n_dims = self.get_params("dims", step)
        # The task stores how many dimensions are cut off, not how many are active.
        task.truncated_dims = task.n_dims - n_dims
        if hasattr(learn, 'stats'):
            learn.stats.update({'n_points': task.n_points, 'n_dims': n_dims})

    def after_batch(self, learn):
        """Update the dataset of whichever loader matches the model's current mode."""
        loader = learn.dl_train if learn.model.training else learn.dl_valid
        self.update_task(learn, loader.dataset)


class LoopCB(CurriculumCB):
    """Curriculum over the model's `n_loops` attribute, driven by the "loops" schedule."""

    def update_model(self, learn):
        """Set `learn.model.n_loops` from the schedule and mirror it into `learn.stats`."""
        n_loops = self.get_params("loops", learn.train_step)
        learn.model.n_loops = n_loops
        if hasattr(learn, 'stats'):
            learn.stats['n_loops'] = learn.model.n_loops

    def after_batch(self, learn):
        """Only the model is updated here; the task schedules are not touched."""
        self.update_model(learn)

In [None]:
from looped_experiments.models import get_loss, get_model
from looped_experiments.tasks import LinearRegression, dataloader
from looped_experiments.utils import get_config

In [None]:
# Load the experiment config, overriding the `model` and `training` entries with their "loop" variants.
cfg = get_config(overrides=["model=loop", "training=loop"])

In [None]:
# Build one LinearRegression task and wrap it in a training and an evaluation loader.
train = cfg.training
task = LinearRegression(train.batch_size, **cfg.task)
# Both loaders share the same `task` object, so a change made to it is seen through both.
# NOTE(review): the second argument is presumably the number of batches per pass — confirm in `dataloader`.
dl_train = dataloader(task, train.train_steps)
dl_eval = dataloader(task, train.eval_steps)

In [None]:
@FnCallback("before_batch")
def trans_input(learner):
    """Pack inputs and targets into one tuple so the model is called as `model(xb, yb)`."""
    learner.xb = (learner.xb, learner.yb)


# Seed torch's RNG before the model is constructed.
torch.manual_seed(1)

model = get_model(cfg.model)
# Callbacks run sorted by their `order` attribute (list position only breaks ties), so
# `trans_input` (order 1) runs after ToDeviceCB (order 0) has moved the batch to the device.
cbs = [ToDeviceCB(), CurriculumCB(cfg.training.curriculum), trans_input, LoopCB(cfg.model.curriculum)]
learn = Learner(model, dl_train, dl_eval, cfg.training.n_epoch, loss_fn=get_loss(cfg.model), cbs=cbs)

In [None]:
# Train with the configured learning rate; the epoch count comes from the Learner constructor.
learn.fit(lr=cfg.training.learning_rate)