In [None]:
# | default_exp training

In [None]:
# | export
import functools
from pathlib import Path

import torch
import torch.nn.functional as F
from fastprogress import master_bar, progress_bar
from torch import optim

In [None]:
# | export
class with_cbs:
    """Decorator factory that brackets a method with ``before_<nm>``/``after_<nm>`` callbacks.

    The decorated method's first argument must provide ``callback(name)``.
    ``after_<nm>`` is only fired when the method returns normally (not when it raises).
    """
    def __init__(self, nm):
        self.nm = nm

    def __call__(self, f):
        @functools.wraps(f)  # keep the wrapped method's name/docstring
        def _f(o, *args, **kwargs):
            o.callback(f'before_{self.nm}')
            res = f(o, *args, **kwargs)
            o.callback(f'after_{self.nm}')
            # propagate the method's result instead of silently discarding it
            return res
        return _f


class Learner:
    """Minimal callback-driven training loop.

    Callbacks are arbitrary objects in ``cbs``. On each event (``before_fit``,
    ``after_fit``, ``before_batch``, ``after_batch``) every callback that has an
    attribute of that name has it called with this learner, in list order.

    Args:
        model: module, called as ``model(*xb)``.
        dl_train: training loader; each item is unpacked as ``xb, yb``.
        dl_valid: validation loader with the same item format.
        n_epoch: default number of epochs for ``fit``.
        optimizer: factory called as ``optimizer(params, lr=lr, weight_decay=wd)``.
        loss_fn: called as ``loss_fn(y_hat, yb)``.
        lr: initial learning rate.
        wd: weight decay.
        cbs: optional list of callback objects.
    """
    def __init__(self, model, dl_train, dl_valid, n_epoch, optimizer=optim.Adam, loss_fn=F.mse_loss, lr=1e-3, wd=0.0, cbs=None):
        self.model, self.dl_train, self.dl_valid, self.n_epoch, self.loss_fn, self.cbs \
            = model, dl_train, dl_valid, n_epoch, loss_fn, cbs
        self.optimizer = optimizer(model.parameters(), lr=lr, weight_decay=wd)

    @with_cbs("fit")
    def fit(self, lr=None, n_epoch=None):
        """Train for ``n_epoch`` epochs, running a validation pass after each one.

        Args:
            lr: if given, overrides the learning rate of every optimizer param group.
            n_epoch: number of epochs; ``None`` uses ``self.n_epoch``.
        """
        if lr is not None:
            for group in self.optimizer.param_groups:
                group['lr'] = lr
        # `is None` rather than `or`, so an explicit n_epoch=0 is respected
        if n_epoch is None:
            n_epoch = self.n_epoch

        # reset here, i.e. after the before_fit callbacks have already run
        self.train_step = 0
        self.mb = master_bar(range(n_epoch))
        for i in self.mb:
            self.one_epoch(True)
            # reports the loss of the epoch's last batch, not an epoch average
            self.mb.write(f"Epoch {i} Loss: {self.loss.item()}")
            with torch.no_grad():
                self.one_epoch(False)
            self.mb.write(f"Valid Loss: {self.loss.item()}")

    def predict(self, dl=None):
        """Run the model in eval mode over ``dl`` and return the concatenated outputs on CPU.

        Args:
            dl: loader yielding ``(xb, _)`` pairs; ``None`` uses ``self.dl_valid``.

        NOTE(review): callbacks are not invoked here, so anything they normally do
        to a batch (device moves, input packing) is not applied — ``xb`` is passed
        to ``model(*xb)`` exactly as the loader yields it.
        """
        if dl is None:  # explicit check: truthiness would call len() and reject empty loaders
            dl = self.dl_valid
        self.model.eval()
        preds = []
        with torch.no_grad():  # inference only; avoid building autograd graphs
            for xb, _ in progress_bar(dl, total=len(dl)):
                preds.append(self.model(*xb).detach().cpu())
        return torch.cat(preds)

    def one_epoch(self, training):
        """Run one pass over the training (``training=True``) or validation loader.

        Must be called from ``fit``: it attaches its progress bar to ``self.mb``.
        """
        self.model.train(training)
        dl = self.dl_train if training else self.dl_valid
        for b in progress_bar(dl, parent=self.mb, total=len(dl)):
            self.xb, self.yb = b
            self.one_batch()
            self.mb.child.comment = f"Loss: {self.loss.item()}"

    @with_cbs("batch")
    def one_batch(self):
        """Compute ``self.loss`` for the current batch; in training mode also take an optimizer step.

        ``train_step`` is only incremented for training batches.
        """
        y_hat = self.model(*self.xb)
        self.loss = self.loss_fn(y_hat, self.yb)
        if self.model.training:
            self.loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()
            self.train_step += 1

    def callback(self, nm):
        """Call attribute ``nm`` (with this learner) on every callback that defines it, in order."""
        if self.cbs is not None:
            for cb in self.cbs:
                if hasattr(cb, nm): getattr(cb, nm)(self)

In [None]:
# | export

def default_device():
    """Return the CUDA device when CUDA is available, otherwise the CPU device."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


class ToDeviceCB:
    """Callback that moves the model (before fit) and every batch (before each batch) to a device.

    Args:
        device: target ``torch.device``. When omitted (or ``None``), ``default_device()``
            is resolved when the callback is constructed rather than once at import time.
    """
    def __init__(self, device=None):
        self.device = default_device() if device is None else device

    def before_fit(self, learn):
        # nn.Module.to moves parameters in place, so an optimizer built earlier still tracks them
        self.model = learn.model.to(self.device)

    def before_batch(self, learn):
        learn.xb, learn.yb = self._to_device(learn.xb), self._to_device(learn.yb)

    def _to_device(self, obj):
        """Move a tensor, or a (possibly nested) tuple/list of tensors, to ``self.device``; other objects pass through."""
        if isinstance(obj, torch.Tensor):
            return obj.to(self.device)
        if isinstance(obj, tuple):
            return tuple(self._to_device(o) for o in obj)
        if isinstance(obj, list):
            return [self._to_device(o) for o in obj]
        return obj


class FnCallback:
    """Decorator that turns a plain function into a single-event callback object.

    ``@FnCallback("before_batch")`` applied to ``fn`` returns this callback object,
    with ``fn`` stored as the instance attribute ``before_batch``. Because it is an
    instance attribute (not a bound method), it is called with only the learner.
    """
    def __init__(self, nm):
        self.nm = nm

    def __call__(self, fn):
        # plain instance-dict store, equivalent to setattr for this descriptor-free class
        self.__dict__[self.nm] = fn
        return self


class SaveModelCB:
    """Callback that saves the model's ``state_dict`` into ``save_dir``.

    Writes ``model_<train_step>.pt`` every ``save_every_n`` training steps, and
    ``model_last.pt`` in ``after_fit``. The directory (with parents) is created
    when the callback is constructed.
    """
    def __init__(self, save_dir, save_every_n=1000):
        self.dir, self.save_every_n = Path(save_dir), save_every_n
        self.dir.mkdir(parents=True, exist_ok=True)

    def after_batch(self, learn):
        # Only checkpoint after training batches: validation batches leave train_step
        # unchanged, so they would rewrite the same checkpoint file on every batch.
        if learn.model.training and learn.train_step % self.save_every_n == 0:
            torch.save(learn.model.state_dict(), self.dir / f"model_{learn.train_step}.pt")

    def after_fit(self, learn):
        torch.save(learn.model.state_dict(), self.dir / "model_last.pt")


class WandbCB:
    """Callback that logs training loss, learning rate and ``learn.stats`` to Weights & Biases.

    ``config`` must provide ``debug_mode`` (which disables wandb when true) and a
    ``wandb`` section with ``project`` and ``log_every_steps``.
    """
    def __init__(self, config):
        try:
            import wandb
        except ImportError as err:
            raise ImportError("Please install wandb to use this callback") from err
        wandb.require("core")
        self.wandb = wandb
        self.cfg = config
        self.wcfg = config.wandb

    def before_fit(self, learn):
        # other callbacks may add extra metrics here; they are merged into each log call
        learn.stats = {}
        self.run = self.wandb.init(project=self.wcfg.project,
                        mode="disabled" if self.cfg.debug_mode else "online")
        # NOTE(review): wandb normally receives config via wandb.init(config=...) or
        # run.config.update(...); confirm direct assignment works with the installed version.
        self.run.config = self.cfg

    def after_batch(self, learn):
        # Log training batches only: validation batches do not advance train_step, so
        # they would be recorded under the training "loss" key at the same step.
        if not learn.model.training:
            return
        if learn.train_step % self.wcfg.log_every_steps != 0:
            return
        self.run.log(
            {
                "loss": learn.loss.item(),  # plain float instead of a live tensor
                "lr": learn.optimizer.param_groups[0]['lr'],
                **learn.stats,
            },
            step=learn.train_step,
        )

    def after_fit(self, _): self.run.finish()
    

# NOTE: built at import time, so SaveModelCB creates the "models" directory as soon as
# this module is imported, and these callback instances are shared by every user of the list.
DEFAULT_CBS = [ToDeviceCB(), SaveModelCB(save_dir="models")]

In [None]:
# | export
class CurriculumCB:
    """Callback that schedules task difficulty (dims and points) by training step.

    ``curriculum_config[name]`` for ``name`` in ``("dims", "points")`` must provide
    ``start``, ``end``, ``inc`` and ``interval``. The value begins at ``start``,
    increases by ``inc`` every ``interval`` steps, and is capped at ``end``.
    """
    def __init__(self, curriculum_config):
        self.cfg = curriculum_config

    def get_params(self, name, step):
        """Return the scheduled value of ``name`` at training step ``step``."""
        cfg = self.cfg[name]
        return min(cfg.end, cfg.start + step // cfg.interval * cfg.inc)

    def update_task(self, learn, ds, step=None):
        """Apply the schedule at ``step`` (default ``learn.train_step``) to ``ds.task``.

        Also records the values in ``learn.stats`` when that dict exists.
        """
        if step is None:
            step = learn.train_step
        n_dims = self.get_params("dims", step)
        # the task is configured via how many of its n_dim dims to truncate
        ds.task.truncated_dims = ds.task.n_dim - n_dims
        ds.task.n_points = self.get_params("points", step)
        if hasattr(learn, 'stats'):
            learn.stats['n_dims'] = n_dims
            learn.stats['n_points'] = ds.task.n_points

    def before_fit(self, learn):
        # Apply the start of the curriculum before the first batch is drawn. Step 0 is
        # passed explicitly because learn.train_step may not be set yet at this point.
        self.update_task(learn, learn.dl_train.dataset, step=0)
        self.update_task(learn, learn.dl_valid.dataset, step=0)

    def after_batch(self, learn):
        ds = learn.dl_train.dataset if learn.model.training else learn.dl_valid.dataset
        self.update_task(learn, ds)

In [None]:
from hydra import compose, initialize

from looped_experiments.models import Transformer
from looped_experiments.tasks import *

In [None]:
# Load the "base" config from the ../configs directory.
with initialize(version_base=None, config_path="../configs"):
    cfg = compose(config_name="base")

In [None]:
train = cfg.training
task = LinearRegression(train.batch_size, **cfg.task)
# Both loaders wrap the same task instance (presumably so curriculum changes reach validation too).
dl_train, dl_eval = dataloader(task, train.train_steps), dataloader(task, train.eval_steps)

In [None]:
# Seed the global RNG before building the model so its initialisation is reproducible.
torch.manual_seed(seed=42)
model = Transformer(cfg.model)

In [None]:
@FnCallback("before_batch")
def trans_input(learner):
    # Pack targets into the input so Learner's model(*xb) call becomes model(xb, yb).
    learner.xb = (learner.xb, learner.yb)


# Callbacks fire in list order, so ToDeviceCB.before_batch sees the raw xb tensor
# before trans_input packs it into a tuple.
cbs = [
    WandbCB(cfg),
    ToDeviceCB(),
    CurriculumCB(cfg.training.curriculum),
    trans_input,
]
learn = Learner(model, dl_train, dl_eval, cfg.training.n_epoch, cbs=cbs)
learn.fit(lr=cfg.training.learning_rate)