In [None]:
#|default_exp callback.mesa

In [None]:
#|export
from __future__ import annotations
from types import FunctionType

try:
    import timm
except ImportError:
    raise ImportError("timm is required to use MESACallback. Install via `pip install timm`.")

from timm.utils.model_ema import ModelEmaV2

from fastai.callback.core import Callback
from fastai.callback.mixup import reduce_loss
from fastai.callback.fp16 import MixedPrecision
from fastai.layers import NoneReduce

from fastxtend.multiloss import MultiLoss, MultiLossCallback, MixHandlerX
from fastxtend.imports import *

In [None]:
#hide
from nbdev.showdoc import *
from fastxtend.test_utils import *

# Memory-Efficient Sharpness-Aware Callback
> A first-pass callback that adds [Memory-Efficient Sharpness-Aware](https://arxiv.org/abs/2205.14083) (MESA) training to fastai. The EMA implementation is from [timm](https://github.com/rwightman/pytorch-image-models/blob/master/timm/utils/model_ema.py).

In [None]:
#|export
class MESALoss(MultiLoss):
    "Adds a temperature-softened KL divergence term between `pred` and the MESA targets to `orig_loss`"
    def __init__(self,
        orig_loss:nn.Module|FunctionType, # Loss function to wrap
        temp:Number=3, # Soften MESA targets by this temperature. τ in paper
        weight:float=0.8, # Weight of MESA loss. λ in paper
        reduction:str='mean' # PyTorch loss reduction
    ):
        store_attr(but='reduction')
        # Stored as the reciprocal so the logits can simply be multiplied by it
        self.temp = 1/temp
        if hasattr(self.orig_loss, 'reduction'): self.orig_loss.reduction = reduction
        else: self.orig_loss = partial(self.orig_loss, reduction=reduction)
        self._mesa_loss = nn.KLDivLoss(log_target=True, reduction='batchmean' if reduction=='mean' else reduction)
        # Backing field for the `reduction` property; without it, reading
        # `reduction` before the setter is ever used raises AttributeError
        self._reduction = reduction
        # The MESA term stays at zero until this flag is switched on
        self.mesa_loss = False
        self.loss_names = L('orig_loss', 'mesa_loss')
        self.loss_funcs = self.loss_names # compatibility with MultiLossCallback
        # `_loss` holds the latest value of each term: 0 is `orig_loss`, 1 is the MESA term
        self._zero, self._loss = torch.tensor(0., requires_grad=False), {}
        if getattr(self.orig_loss, 'y_int', False): self.y_int = True

    def _mesa_term(self, pred, mesa_targ):
        "Weighted KL divergence between temperature-scaled `pred` and `mesa_targ`, or zero while `mesa_loss` is off"
        if not self.mesa_loss: return self._zero
        return self.weight*self._mesa_loss(F.log_softmax(self.temp*pred, dim=1), F.log_softmax(self.temp*mesa_targ, dim=1))

    def forward(self, pred, *targs):
        "Sum of `orig_loss(pred, targ)` and the MESA term. `targs` must be `(targ, mesa_targ)`"
        targ, mesa_targ = targs
        self._loss[0] = self.orig_loss(pred, targ)
        self._loss[1] = self._mesa_term(pred, mesa_targ)
        return self._loss[0] + self._loss[1]

    def forward_mixup(self, pred, *targs):
        "Mixup variant of `forward`. `targs` must be `(targ1, targ2, mesa_targ, lam)`"
        targ1, targ2, mesa_targ, lam = targs
        # Interpolate the unreduced losses of both targets, then reduce once
        with NoneReduce(self.orig_loss) as ol:
            loss = torch.lerp(ol(pred, targ1), ol(pred, targ2), lam)
        # A `partial`-wrapped function has no `reduction` attribute, so fall back
        # to the configured reduction rather than silently assuming 'mean'
        self._loss[0] = reduce_loss(loss, getattr(self.orig_loss, 'reduction', self._reduction))
        self._loss[1] = self._mesa_term(pred, mesa_targ)
        return self._loss[0] + self._loss[1]

    @property
    def reduction(self): return self._reduction

    @reduction.setter
    def reduction(self, r):
        # Keep the wrapped loss and the KL term in sync with the requested reduction
        if hasattr(self.orig_loss, 'reduction'): self.orig_loss.reduction = r
        else: self.orig_loss = partial(self.orig_loss, reduction=r)
        self._mesa_loss.reduction = 'batchmean' if r=='mean' else r
        self._reduction = r

    @delegates(Module.to)
    def to(self, *args, **kwargs):
        "Move the loss and its zero placeholder to the requested device, returning `self` like `nn.Module.to`"
        device, *_ = torch._C._nn._parse_to(*args, **kwargs)
        # `Tensor.to` is out-of-place, so the result has to be reassigned
        if device is not None: self._zero = self._zero.to(device)
        return super(Module, self).to(*args, **kwargs)

    def activation(self, pred):
        "Returns `orig_loss` `activation`"
        return getattr(self.orig_loss, 'activation', noop)(pred)

    def decodes(self, pred):
        "Returns `orig_loss` `decodes`"
        return getattr(self.orig_loss, 'decodes', noop)(pred)

In [None]:
#|export
class MESACallback(Callback):
    "Callback to implement Memory-Efficient Sharpness-Aware training from https://arxiv.org/abs/2205.14083"
    order = MixedPrecision.order+1
    def __init__(self,
        start_epoch:int=5, # Epoch to start MESA (index 1)
        temp:Number=3, # Soften MESA targets by this temperature. τ in paper
        weight:float=0.8, # Weight of MESA loss. λ in paper
        decay:float=0.9998, # EMA decay. β in paper
        reduction:str='mean', # PyTorch loss reduction
        cleanup:bool=True # Restore the original loss and remove `MESACallback` after fit
    ):
        store_attr()

    @torch.no_grad()
    def before_fit(self):
        "Wrap the learner's loss in `MESALoss` and create the EMA model, unless running `lr_find` or `get_preds`"
        if hasattr(self.learn, 'lr_finder') or hasattr(self, "gather_preds"):
            # The other hooks append an EMA prediction to `yb`, which only works
            # while `MESALoss` is installed. If it is not, disable this callback
            # for the run; fastai sets `run` back to True after `after_fit`.
            self.run = isinstance(self.learn.loss_func, MESALoss)
            return
        # Convert to the 0-indexed `epoch` in a private field so that repeated
        # fits do not keep shifting the user-supplied `start_epoch`
        self._start_epoch = self.start_epoch-1
        # Placeholder MESA target used until the EMA model takes over
        self._ema_pred = lambda x: 0
        self.orig_loss = self.learn.loss_func
        self.orig_loss_reduction = self.orig_loss.reduction if hasattr(self.orig_loss, 'reduction') else None
        self.learn.loss_func = MESALoss(self.orig_loss, self.temp, self.weight, self.reduction)
        self.learn.loss_func.to(getattr(self.dls, 'device', default_device()))
        self.ema_model = ModelEmaV2(self.learn.model, self.decay)
        self._mixup = len(self.learn._grab_cbs(MixHandlerX)) > 0 and getattr(self.orig_loss, 'y_int', False)

    def before_train(self):
        "Switch on the MESA loss term and the EMA targets once `start_epoch` is reached"
        if self._start_epoch == self.epoch:
            # NOTE(review): `loss_func_mixup` is not defined in this file;
            # presumably installed on the learner by `MixHandlerX` - confirm
            if self._mixup: self.learn.loss_func_mixup.mesa_loss = True
            else:           self.learn.loss_func.mesa_loss = True
            self._ema_pred = self.ema_model.module

    @torch.no_grad()
    def after_pred(self):
        "Append the EMA model's prediction to `yb` so `MESALoss` receives it as an extra target"
        self.learn.yb = tuple([self.y, self._ema_pred(*self.xb)])

    def after_loss(self):
        "Drop the EMA prediction from `yb` again once the loss has been computed"
        y, _ = self.yb
        self.learn.yb = tuple([y])

    def after_batch(self):
        "Update the EMA weights from the current model"
        self.ema_model.update(self.learn.model)

    @torch.no_grad()
    def after_fit(self):
        "If `cleanup`, restore the original loss and its reduction, drop the EMA model, and remove this callback"
        if self.cleanup:
            if hasattr(self.orig_loss, 'reduction'):
                self.orig_loss.reduction = self.orig_loss_reduction
            self.learn.loss_func = self.orig_loss
            self.ema_model = None
            self.remove_cb(MESACallback)

In [None]:
#|hide
#|slow
from fastcore.basics import num_cpus

from fastai.data.external import URLs, untar_data
from fastai.data.block import DataBlock, CategoryBlock
from fastai.data.transforms import GrandparentSplitter, get_image_files, parent_label, Normalize
from fastai.learner import Learner, Recorder
from fastai.vision.augment import Resize
from fastai.vision.core import imagenet_stats
from fastai.vision.data import ImageBlock
from fastai.vision.models.xresnet import xresnet18

from fastxtend.callback.cutmixup import MixUp

## Example

In [None]:
#|slow
#|cuda
with no_random():
    # Fetch the 160px Imagenette dataset
    imagenette = untar_data(URLs.IMAGENETTE_160)

    # Image -> category pipeline, 64px items normalized with ImageNet statistics
    dblock = DataBlock(
        blocks=(ImageBlock, CategoryBlock),
        splitter=GrandparentSplitter(valid_name='val'),
        get_items=get_image_files,
        get_y=parent_label,
        item_tfms=Resize(64),
        batch_tfms=Normalize.from_stats(*imagenet_stats),
    )
    dls = dblock.dataloaders(imagenette, bs=64, num_workers=num_cpus())

    # `MultiLossCallback` logs the `orig_loss` and `mesa_loss` columns separately
    learn = Learner(
        dls,
        xresnet18(n_out=dls.c),
        cbs=[MESACallback, MultiLossCallback],
    )

    # Replace the default `Recorder` with one that omits the time column
    learn.remove_cb(Recorder)
    learn.add_cb(Recorder(add_time=False))

    learn.fit_one_cycle(10, 1e-3)

epoch,train_loss,train_orig_loss,train_mesa_loss,valid_loss,valid_orig_loss,valid_mesa_loss
0,1.764903,1.764903,0.0,1.589739,1.589739,0.0
1,1.248024,1.248024,0.0,1.415202,1.415202,0.0
2,1.036043,1.036043,0.0,1.170383,1.170383,0.0
3,0.853409,0.853409,0.0,1.038804,1.038804,0.0
4,0.897851,0.743198,0.154654,1.098253,0.92733,0.170922
5,0.788958,0.628259,0.160699,0.966897,0.799375,0.167522
6,0.677702,0.511786,0.165916,0.941499,0.786114,0.155385
7,0.563589,0.397229,0.16636,0.881339,0.713294,0.168045
8,0.492636,0.325024,0.167612,0.865666,0.704752,0.160914
9,0.460393,0.293331,0.167062,0.860304,0.702629,0.157675


In [None]:
#|hide
#|slow
#|cuda

# mixup test
with no_random():
    # Fetch the 160px Imagenette dataset
    imagenette = untar_data(URLs.IMAGENETTE_160)

    # Image -> category pipeline, 64px items normalized with ImageNet statistics
    dblock = DataBlock(
        blocks=(ImageBlock, CategoryBlock),
        splitter=GrandparentSplitter(valid_name='val'),
        get_items=get_image_files,
        get_y=parent_label,
        item_tfms=Resize(64),
        batch_tfms=Normalize.from_stats(*imagenet_stats),
    )
    dls = dblock.dataloaders(imagenette, bs=64, num_workers=num_cpus())

    # Same setup as the plain example, with `MixUp` added to the callbacks
    learn = Learner(
        dls,
        xresnet18(n_out=dls.c),
        cbs=[MESACallback, MultiLossCallback, MixUp],
    )

    # Replace the default `Recorder` with one that omits the time column
    learn.remove_cb(Recorder)
    learn.add_cb(Recorder(add_time=False))

    learn.fit_one_cycle(10, 1e-3)

epoch,train_loss,train_orig_loss,train_mesa_loss,valid_loss,valid_orig_loss,valid_mesa_loss
0,1.936363,1.936363,0.0,1.637596,1.637596,0.0
1,1.596632,1.596632,0.0,1.494622,1.494622,0.0
2,1.423184,1.423184,0.0,1.211276,1.211276,0.0
3,1.326761,1.326761,0.0,1.170963,1.170963,0.0
4,1.352496,1.263226,0.08927,1.062338,0.940975,0.121363
5,1.273435,1.180923,0.092512,1.022477,0.905502,0.116975
6,1.196219,1.102559,0.093661,0.944279,0.825535,0.118744
7,1.125571,1.030896,0.094675,0.877916,0.752801,0.125115
8,1.067545,0.97296,0.094585,0.858698,0.745409,0.113289
9,1.062083,0.969335,0.092748,0.853712,0.740154,0.113558
