# Core
> A minimal training framework that mimics [fastai](https://docs.fast.ai).

In [None]:
#| default_exp core

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from wafer.imports import *
from collections.abc import Mapping

## Utils

In [None]:
#| export
def to_device(xs, device):
    "Move `xs` to `device`. Tuples, lists and mappings are walked recursively; any other object must provide `.to(device=...)`."
    # Containers keep their type (tuple/list); mappings always come back as a plain `dict`.
    if isinstance(xs, (tuple, list)): return type(xs)(to_device(x, device) for x in xs)
    if isinstance(xs, Mapping): return {k: to_device(v, device) for k,v in xs.items()}
    return xs.to(device=device)

def to_detach(xs):
    "Call `.detach()` on `xs`. Tuples, lists and mappings are walked recursively; any other object must provide `.detach()`."
    # Containers keep their type (tuple/list); mappings always come back as a plain `dict`.
    if isinstance(xs, (tuple, list)): return type(xs)(to_detach(x) for x in xs)
    if isinstance(xs, Mapping): return {k: to_detach(v) for k,v in xs.items()}
    return xs.detach()

def to_cpu(xs):
    "Call `.cpu()` on `xs`. Tuples, lists and mappings are walked recursively; any other object must provide `.cpu()`."
    # Containers keep their type (tuple/list); mappings always come back as a plain `dict`.
    if isinstance(xs, (tuple, list)): return type(xs)(to_cpu(x) for x in xs)
    if isinstance(xs, Mapping): return {k: to_cpu(v) for k,v in xs.items()}
    return xs.cpu()

def get_device(device=None):
    "Pick the torch device: CUDA whenever it is available, otherwise `device` (or 'cpu' when `device` is falsy)."
    # NOTE: CUDA takes priority over an explicitly passed `device`; `device` only acts as the fallback.
    if torch.cuda.is_available(): device = 'cuda'
    elif not device: device = 'cpu'
    print(f"Using {device} device.")
    return torch.device(device)

In [None]:
#| export
def has_children(m):
    "Return `True` if the iterator from `m.children()` yields at least one item."
    _none = object()  # private sentinel so that any yielded value counts as a child
    return next(m.children(), _none) is not _none

def has_params(m):
    "Return `True` if `m.parameters()` yields at least one parameter."
    return any(True for _ in m.parameters())

def count_params(m):
    "Total number of elements over the parameters of `m` that have `requires_grad` set."
    total = 0
    for param in m.parameters():
        if param.requires_grad: total += param.numel()
    return total

In [None]:
#| export
def get_actn(actn: Union[str, nn.Module], default=Noop()):
    "Resolve an activation: a module is returned as is, a string is looked up by name on `F`, anything else gives `default`."
    if isinstance(actn, nn.Module): return actn
    if isinstance(actn, str): return getattr(F, actn)
    return default

## Callbacks

In [None]:
#| export
class Callback():
    "Base class of all callbacks; `Learner` sorts its callbacks by ascending `order`."
    order = -1

In [None]:
#| export
class DeviceCB(Callback):
    "Move the model and batch to the correct device for training."
    order = -10

    def before_fit(self):
        "Send the model to the learner's device once, at the start of fitting."
        lrn = self.learner
        lrn.model.to(device=lrn.device)

    def before_batch(self):
        "Replace the freshly drawn batch with its copy on the learner's device."
        lrn = self.learner
        lrn.batch = to_device(lrn.batch, lrn.device)

In [None]:
#| export
class BatchXfmCB(Callback):
    "Apply a data transform to a batch."
    order = -5

    def __init__(self, xfm):
        # `xfm` is called with the whole batch and must return the new batch.
        self.xfm = xfm

    def before_batch(self):
        "Overwrite the learner's current batch with the transformed one."
        lrn = self.learner
        lrn.batch = self.xfm(lrn.batch)

In [None]:
#| export
class MetricCB(Callback):
    "Using metrics from [`torcheval`](https://pytorch.org/torcheval/stable/) as callbacks."
    order = 10
    def __init__(self, metrics: Union[tuple, list],
                 names: Union[str, tuple[str], list[str]]=None,
                 train: bool=False # Compute metrics on train set; default on test set
                ):
        if not isinstance(metrics, (tuple, list)): metrics = [metrics]
        if isinstance(names, str): names = [names]
        if names is None: names = [type(metric).__name__ for metric in metrics]
        else:
            assert len(metrics) == len(names), "sizes of `names` and `metrics` do not match."
            # `ProgressCB` builds its columns with `list + names`, so a tuple here would raise TypeError.
            names = list(names)
        self.metrics,self.names,self.train = metrics,names,train
        self._log = []  # one row per epoch; created here so `log_df` also works before the first fit
        for metric,name in zip(self.metrics, self.names): metric.name = name
    
    def reset(self):
        "Reset the state of every metric."
        for metric in self.metrics: metric.reset()
    
    @property
    def log_df(self):
        "Per-epoch metric values as a `DataFrame` with one column per metric name."
        return pd.DataFrame(columns=self.names, data=self._log)

    def before_fit(self):
        self.reset()
        self._log = []
        self.learner.metrics = self  # lets other callbacks read `names` and `_log`

    def before_epoch(self): self.reset()

    def before_loss(self):
        # Update only in the phase selected by `train`: training batches if True, test batches otherwise.
        if bool(self.train) != bool(self.learner.training): return
        for metric in self.metrics:
            metric.update(self.learner.preds.detach().clone().cpu(), self.learner.yb.clone().cpu())
    
    def after_epoch(self):
        "Append this epoch's computed metric values as one row of the log."
        self._log.append([metric.compute().item() for metric in self.metrics])

In [None]:
#| export
class ProgressCB(Callback):
    "Log and display training infos."
    order = MetricCB.order + 1

    def before_fit(self):
        "Register on the learner and create the (empty) log tables."
        lrn = self.learner
        lrn.progress = self
        self._has_no_log = False
        # Column order: train loss, test loss, metric names, extra-log names.
        cols = []
        if lrn.dls[0] != []: cols = cols + ['train_loss']
        if lrn.dls[1] != []: cols = cols + ['test_loss']
        if hasattr(lrn, 'metrics'): cols = cols + lrn.metrics.names
        if hasattr(lrn, 'extra_log'): cols = cols + lrn.extra_log.all_names
        if not cols:
            # Nothing to record: every other event becomes a no-op.
            self._has_no_log = True
            return
        self._log = pd.DataFrame(columns=cols)
        self.disp_log = pd.DataFrame(columns=cols)
        self.disp_log_id = None

    def before_epoch(self):
        "Create the display handle on first use and reset the per-epoch loss buffers."
        if self._has_no_log: return
        if self.disp_log_id is None:
            self.disp_log_id = display(self.disp_log, display_id=True)
        self._one_epoch = [[],[]]  # [train batch losses, test batch losses]

    def before_backward(self):
        if self._has_no_log: return
        train_losses = self._one_epoch[0]
        train_losses.append(self.learner.loss.item()) # batch train_loss

    def after_loss(self):
        if self._has_no_log: return
        test_losses = self._one_epoch[1]
        test_losses.append(self.learner.loss.item()) # batch test_loss

    def after_epoch(self):
        "Write this epoch's row to the log and refresh the display every `disp_every` epochs."
        if self._has_no_log: return
        lrn = self.learner
        # Mean loss of each phase that produced at least one batch, then metrics, then extra entries.
        row = [np.mean(losses) for losses in self._one_epoch if losses != []]
        if hasattr(lrn, 'metrics'): row = row + lrn.metrics._log[-1]
        if hasattr(lrn, 'extra_log'): row = row + lrn.extra_log._data
        self._log.loc[lrn.epoch] = row
        if lrn.epoch % lrn.disp_every == 0:
            self.disp_log.loc[lrn.epoch] = self._log.loc[lrn.epoch]
            self.disp_log_id.update(self.disp_log)

    def after_fit(self):
        "Expose the full log on the learner as `learner.log`."
        if self._has_no_log: return
        self.learner.log = self._log

In [None]:
#| export
class BaseLogCB(Callback):
    "Base class. Log extra (other than standard train/test loss and metrics) infos."
    order = ProgressCB.order-1

    def __init__(self, names: list[list],  # Names of logged entries for each `func`
                 funcs: list[callable],    # Functions producing the logged entries; each takes the `learner` and returns a list/tuple of values or a single scalar/string
                 keep: bool=False          # Keep a local log; logs can be found in `self.learner.log`
                ):
        self.names = names
        self.funcs = funcs
        self.keep = keep
        self._data = []  # values collected for the current epoch

    @property
    def all_names(self):
        "All entry names, flattened across `names` in order."
        flat = []
        for group in self.names: flat.extend(group)
        return flat

    @property
    def log_df(self):
        "The locally kept log as a `DataFrame`; prints a notice and returns `None` when `keep` is off."
        if not self.keep:
            print('No log found.')
            return
        return pd.DataFrame(columns=self.all_names, data=self._log)

    def before_fit(self):
        self.learner.extra_log = self
        if self.keep: self._log = []

    def before_epoch(self):
        # Emptied in place, so the same list object is reused every epoch.
        del self._data[:]

    def after_epoch(self):
        "Call every function on the learner and collect its output into the current epoch's data."
        for func in self.funcs:
            out = func(self.learner)
            if isinstance(out, (tuple, list)):
                self._data.extend(out)
                continue
            if not isinstance(out, (int, float, bool, str)):
                raise TypeError('Log function output should be tuple/list/scalar-like/string.')
            self._data.append(out)
        if self.keep: self._log.append(list(self._data))
            

In [None]:
#| export
class ClipGradCB(Callback):
    "Clip the gradient norm."
    def __init__(self, max_norm=1, norm_type=2, error_if_nonfinite=True, **kwargs):
        # Bind the clipping settings once; `before_step` only has to supply the parameters.
        def _clip(params):
            return nn.utils.clip_grad_norm_(params, max_norm, norm_type, error_if_nonfinite, **kwargs)
        self._norm = _clip

    def before_step(self):
        "Clip the gradients of the learner's model just before the optimizer step."
        self._norm(self.learner.model.parameters())

In [None]:
#| export
class LRCB(Callback):
    "Learning rate callback."
    order = ProgressCB.order + 1

    def __init__(self, scheduler, on_batch=False):
        self.scheduler = scheduler
        self.on_batch = on_batch  # step after every optimizer step if True, otherwise once per epoch

    def after_step(self):
        if not self.on_batch: return
        self.scheduler.step()

    def after_epoch(self):
        if self.on_batch: return
        self.scheduler.step()

In [None]:
#| export
class ReduceLROnPlateauCB(LRCB):
    "ReduceLROnPlateau callback."
    def __init__(self, scheduler, metric='test_loss'):
        super().__init__(scheduler, on_batch=False)
        self.metric = metric  # column of the progress log passed to `scheduler.step`

    def before_fit(self):
        "Fail early with a `KeyError` if the monitored column is missing from the progress log."
        try:
            self.learner.progress._log[self.metric]
        except (AttributeError, KeyError) as err:
            # AttributeError: no progress callback / log table; KeyError: no such column.
            raise KeyError(f"'{self.metric}' not found.") from err

    def after_step(self): pass

    def after_epoch(self):
        "Step the scheduler with the latest logged value of the monitored column."
        val = self.learner.progress._log[self.metric].iloc[-1]
        self.scheduler.step(val)

In [None]:
#| export
class EarlyStoppingCB(Callback):
    "Refer to `fastai`."
    order = ProgressCB.order + 1

    def __init__(self, monitor: str='test_loss', # name of the value being monitored
                 comp = np.less,                 # numpy comparison operator, e.g. np.less or np.greater
                 min_delta: float=0.,            # minimum distance between the latest value and the best monitored value
                 patience: int=10,               # number of epochs to wait when no improvement of the model
                 init_val = np.inf,              # initial best value
                 reset_best: bool=True,          # reset best at each fit
                ):
        self.monitor,self.comp = monitor,comp
        self.min_delta,self.patience = min_delta,patience
        self._best = self.init_val = init_val
        self.reset_best = reset_best

    def before_fit(self):
        "Check that the monitored column exists, then reset the wait counter (and the best value if requested)."
        try:
            self.learner.progress._log[self.monitor]
        except (AttributeError, KeyError) as err:
            # AttributeError: no progress callback / log table; KeyError: no such column.
            raise KeyError(f"'{self.monitor}' to be monitored not found.") from err
        if self.reset_best: self._best = self.init_val
        self._wait = 0

    def after_epoch(self):
        "Track the best monitored value and stop the fit after `patience` epochs without improvement."
        val = self.learner.progress._log[self.monitor].iloc[-1]
        # `min_delta` must make improvement *harder*: below `best - min_delta` for "less"-like
        # comparators, above `best + min_delta` for "greater"-like ones. Probe `comp` for its direction.
        delta = -self.min_delta if self.comp(0, 1) else self.min_delta
        if self.comp(val, self._best + delta): self._best,self._wait = val,0
        else: self._wait += 1

        if self._wait >= self.patience:
            print(f'No improvement since epoch {self.learner.epoch-self._wait}: early stopping')
            raise CancelEpochException()

## Hook

In [None]:
#| export
class Hook():
    "From `fastai`. Register a hook to `m` with `func`."
    def __init__(self, m, func, forward=True, detach=True, cpu=True):
        # Forward hook by default, full backward hook otherwise.
        if forward: register = m.register_forward_hook
        else: register = m.register_full_backward_hook
        handle = register(self.hook_func)
        self.hook = handle
        self.func = func
        self.stored = None
        self.removed = False
        self.detach = detach
        self.cpu = cpu

    def hook_func(self, m, i, o):
        "Optionally detach / move inputs and outputs to CPU, then keep the result of `func(m, i, o)` in `stored`."
        if self.detach:
            i = to_detach(i)
            o = to_detach(o)
        if self.cpu:
            i = to_cpu(i)
            o = to_cpu(o)
        self.stored = self.func(m, i, o)

    def remove(self):
        "Remove the registered hook; repeated calls are no-ops."
        if self.removed: return
        self.hook.remove()
        self.removed = True

    # context manager: the hook is removed on exit
    def __enter__(self, *args): return self
    def __exit__(self, *args): self.remove()

In [None]:
#| export
class Hooks():
    "From `fastai`. Register `Hook` to models in `ms`."
    def __init__(self, ms, func, forward=True, detach=True, cpu=True):
        registered = []
        for m in ms:
            registered.append(Hook(m, func, forward, detach, cpu))
        self.hooks = registered

    def __getitem__(self, i):
        return self.hooks[i]

    def __len__(self):
        return len(self.hooks)

    def __iter__(self):
        return iter(self.hooks)

    @property
    def stored(self):
        "The `stored` value of every hook, in registration order."
        return [hook.stored for hook in self]

    def remove(self):
        "Remove every registered hook."
        for hook in self.hooks:
            hook.remove()

    # context manager: all hooks are removed on exit
    def __enter__(self, *args): return self
    def __exit__ (self, *args): self.remove()

## Learner
A list of events:

- `before_fit`: before fitting.
- `before_epoch`: before every epoch.
- `before_train`: before `model.train()`.
- `before_batch`: *just* after a batch is drawn from the current dataloader.
- `before_pred`: after unpacking a batch into `xb,yb`, before making predictions on `xb`.
- `before_loss`: before calculating the loss, after making `preds`.
- `before_backward`: before `optimizer.zero_grad()` and `loss.backward()`.
- `before_step`: before `optimizer.step()`.
- `after_step`: after `optimizer.step()`.
- `after_batch`: after one batch is done.
- `before_test`: before `model.eval()`.
- `after_loss`: after the loss is computed, *only* in the *testing* loop.
- `after_epoch`: after one epoch is done.
- `after_fit`: after fitting.

Attributes accessible to a `Callback` via `self.learner` (once it has been added to a `Learner`):

- `model`: the current model.
- `dls`: the dataloaders (for training and testing).
- `opt`: the optimizer.
- `batch`: the current batch.
- `xb`,`yb`: the current batched inputs and targets.
- `preds`: the predictions on the current `xb`.
- `loss`: the current loss (train/test).
- `n_epochs`: the total number of epochs.
- `epoch`: the current epoch.
- `n_iters`: the number of training batches processed since the start of `fit`.

In [None]:
#| export
class CancelEpochException(Exception):
    "Raised by a callback to stop training; `Learner.fit` catches it and leaves the epoch loop."

In [None]:
#| export
class Learner():
    "One place to train/test a model."
    # Templates only: each learner attaches its own shallow copies (see `__init__`), because a
    # callback stores a single `learner` reference and sharing instances would repoint it.
    _default_cbs = [DeviceCB(), ProgressCB()]
    def __init__(self, model: nn.Module,  # The model to be trained
                 dls: Union[tuple, list], # The dataloaders used for training and testing
                 opt,                     # The optimizer used to update model's parameters
                 loss_func,               # The loss function
                 cbs=None,                # Callbacks called in `order`; `None` means no extra callbacks
                 disp_every: int=1,       # Display log every `disp_every` epochs
                 device=torch.device('cpu')
                ):
        from copy import copy
        self.model,self.dls,self.opt,self.device = model,dls,opt,device
        self.disp_every = disp_every
        self.loss_func = loss_func
        default_cbs = [copy(cb) for cb in self._default_cbs]
        user_cbs = [] if cbs is None else list(cbs)
        self.cbs = sorted(user_cbs + default_cbs, key=lambda o: o.order) # sort callbacks according to their `order`
        for cb in self.cbs: cb.learner  = self
    @property
    def training(self):
        "Whether the model is currently in training mode."
        return self.model.training

    def do_one_batch(self, train):
        "Run one batch through the model; with `train=True` also backpropagate and step the optimizer."
        self('before_batch')
        self.xb,self.yb = self.batch
        self('before_pred')
        self.preds = self.model(self.xb)
        self('before_loss')
        self.loss = self.loss_func(self.preds, self.yb)
        if train:
            self('before_backward')
            self.opt.zero_grad()
            self.loss.backward()
            self('before_step')
            self.opt.step()
            self('after_step')
        else:
            self('after_loss')
        self('after_batch')

    def do_one_epoch(self):
        "One pass over the training dataloader followed by one pass over the test dataloader."
        self('before_epoch')
        ### training ###
        self('before_train')
        # for self.batch in tqdm(self.dls[0], desc='Train', leave=False):
        for self.batch in self.dls[0]:
            self.model.train()
            self.do_one_batch(True)
            self.n_iters += 1
        ### testing ###
        self('before_test')
        # for self.batch in tqdm(self.dls[1], desc='Test', leave=False):
        for self.batch in self.dls[1]:
            self.model.eval()
            with torch.inference_mode():
                self.do_one_batch(False)
        self('after_epoch')

    def fit(self, n_epochs: int=1):
        "Train for `n_epochs` epochs; a `CancelEpochException` raised by a callback ends the loop early."
        # Set the counters first so that `before_fit` callbacks can already read them.
        self.n_epochs = n_epochs
        self.n_iters = 0
        self('before_fit')
        for self.epoch in tqdm(range(n_epochs), desc='Epochs'):
            try: self.do_one_epoch()
            except CancelEpochException: break
        self('after_fit')

    def predict(self, x):
        "Predict on an input instance."
        self.model.eval()
        x = to_device(x, self.device)
        with torch.inference_mode():
            try:
                pred = self.model(x)
            except Exception:
                # Best effort: the model probably expects a batch dimension, so retry with one added.
                pred = self.model(x.unsqueeze(0))
            return pred

    def predict_batch(self, xb=None):
        "Predict on a batch."
        if xb is None:
            # Default to one batch of the test dataloader, or of the training one if there is no test set.
            xb = self.dls[0].one_batch()[0] if self.dls[1] == [] else self.dls[1].one_batch()[0]
        xb = to_device(xb, self.device)
        self.model.eval()
        with torch.inference_mode():
            preds = self.model(xb)
        return preds

    def plot_loss(self, ax=None, figsize=(3,3), title="", logscale=False, skip: int =0):
        "Plot the logged train/test losses, dropping the first `skip` epochs; does nothing if there is no log."
        if not hasattr(self, 'log'): return
        assert isinstance(skip, int) and skip >= 0, 'skip must be an integer >= 0'
        fig = None
        if ax is None: fig,ax = plt.subplots(figsize=figsize)
        # Plot only the loss columns that are actually present in the log.
        for col,color,label in (('train_loss', 'r', 'train'), ('test_loss', 'b', 'test')):
            if col in self.log: ax.plot(self.log[col].to_numpy()[skip:], c=color, label=label)
        if logscale: ax.set_yscale('log')
        ax.set_xlabel('epoch')
        ax.set_ylabel('loss')
        ax.set_title(title)
        ax.legend(loc=1)
        # Only a figure created here is laid out; a caller-supplied `ax` is left alone.
        if fig is not None: fig.tight_layout()

    def save(self, path:str, add_datetime: bool=True):
        "Save model's state dict and log."
        # NOTE: `add_datetime` is accepted but currently unused.
        torch.save(obj=self.model.state_dict(), f=path+'-model.pth')
        if hasattr(self, 'log'):
            self.log.to_csv(path+'-log.csv', index=False)
            print('Model and log all saved.')
        else:
            print('Model saved.')
        
    def load(self, path: str):
        "Load model from saved state dict and log from csv."
        if '-model.pth' in path: path = path.split('-model.pth')[0]
        # `map_location` lets a checkpoint saved on another device (e.g. GPU) load on this learner's device.
        state = torch.load(path+'-model.pth', map_location=self.device)
        msg = self.model.load_state_dict(state)
        print(msg)
        try:
            self.log = pd.read_csv(path+'-log.csv')
            print('Log loaded.')
        except (OSError, ValueError):
            # The log is optional: a missing, unreadable or unparsable csv is ignored.
            pass
        
    def __call__(self, name):
        "Fire event `name`: call that method on every callback (in `order`) that defines it."
        for cb in self.cbs: getattr(cb, name, noop)()

## Data

In [None]:
#| export
class Dataloader(DataLoader):
    "Extension to `torch.utils.data.DataLoader`, to work with huggingface's `Dataset`."
    def __init__(self, dataset, get_xy: callable, **kwargs):
        super().__init__(dataset=dataset, **kwargs)
        self.get_xy = get_xy  # maps one collated batch to its (input, target) parts

    def __iter__(self):
        "Iterate over collated batches, each passed through `get_xy` and returned as a tuple."
        batches = super().__iter__()
        return (tuple(self.get_xy(batch)) for batch in batches)

    def one_batch(self):
        "Return the first batch of a fresh iteration."
        fresh = iter(self)
        return next(fresh)

In [None]:
#| export
def mk_dls_from_ds(ds,                        # Huggingface dataset
                   get_xy: callable,          # A function to get (input, target) from a dict 
                   fields=['train', 'test'],  # Dict keys to split the dataset
                   bs=[64, 64],               # Batch sizes
                   shuffle=True,              # Shuffle the training set
                  ):
    "Create train-test dataloaders."
    # A tuple/list is taken as (train, test) directly; anything else is indexed by the `fields` keys.
    if isinstance(ds, (tuple, list)): train_ds,test_ds = ds[0],ds[1]
    else: train_ds,test_ds = ds[fields[0]],ds[fields[1]]
    train_dl = Dataloader(train_ds, get_xy, batch_size=bs[0], shuffle=shuffle)
    test_dl = Dataloader(test_ds, get_xy, batch_size=bs[1], shuffle=False)  # the test set is never shuffled
    return (train_dl, test_dl)

In [None]:
#| export
def mk_dls_from_hub(name: str,                 # Name/path of the dataset
                    get_xy: callable,          # A function to get (input, target) from a dict
                    fields=['train', 'test'],  # Dict keys to split the dataset
                    sz=[None, None],           # Sizes of train and test set, if `None` then returns all
                    bs=[64, 64],               # Batch sizes
                    shuffle=True,              # Shuffle the training set
                    device=None
                   ):
    "Convenience method to create train-test dataloaders from huggingface hub."
    # NOTE(review): `trust_remote_code=True` presumably allows the dataset repository's own loading
    # code to run locally; only use this with dataset sources you trust.
    ds = load_dataset(name, device=device, trust_remote_code=True)

    def _take(field, n):
        # Keep the first `n` rows of the split (all rows if `n` is None), formatted for torch.
        split = ds[field]
        if n is not None: split = Dataset.from_dict(split[:n])
        return split.with_format('torch')

    train = _take(fields[0], sz[0])
    test = _take(fields[1], sz[1])
    return mk_dls_from_ds([train, test], get_xy, bs=bs, shuffle=shuffle)

In [None]:
#| export
class Scaler():
    "Simple wrapper of `sklearn.preprocessing`'s scalers."
    def __init__(self, scaler, data: Union[list, np.ndarray],
                 shaper = lambda o: np.reshape(o, (-1, np.shape(o)[-1])),  # function to reshape the `data` into (num_sample, num_feature)
                 **kwargs):
        self.shaper = shaper
        flat = self.shaper(data)
        self.scaler = scaler.fit(flat, **kwargs)  # extra `kwargs` are forwarded to `fit`

    def _apply(self, op, data):
        "Run `op` on the reshaped `data`, then restore the original shape."
        orig_shape = np.shape(data)
        flat = self.shaper(data)
        return op(flat).reshape(orig_shape)

    def xfm(self, data):
        "Scale `data`; the result has the same shape as the input."
        return self._apply(self.scaler.transform, data)

    def inv_xfm(self, data):
        "Undo the scaling of `data`; the result has the same shape as the input."
        return self._apply(self.scaler.inverse_transform, data)

In [None]:
#| hide
# Run nbdev's export step (presumably writes the `#| export` cells to the module named by `default_exp` — confirm against the nbdev docs).
import nbdev; nbdev.nbdev_export()