In [None]:
#|default_exp callback.mesa

In [None]:
#|export
from __future__ import annotations
from types import FunctionType

try:
    import timm
except ImportError:
    raise ImportError("timm is required to use MESACallback. Install via `pip install timm`.")

from timm.utils.model_ema import ModelEmaV2

from fastai.callback.core import Callback
from fastai.callback.mixup import reduce_loss
from fastai.callback.fp16 import MixedPrecision
from fastai.layers import NoneReduce

from fastxtend.multiloss import MultiLoss, MultiLossCallback, MixHandlerX
from fastxtend.imports import *

In [None]:
#hide
from nbdev.showdoc import *
from fastxtend.test_utils import *

# Memory-Efficient Sharpness-Aware Training
> A callback to add Memory-Efficient Sharpness-Aware Training from [Sharpness-Aware Training for Free](https://arxiv.org/abs/2205.14083) to fastai. EMA implementation from [timm](https://github.com/rwightman/pytorch-image-models/blob/master/timm/utils/model_ema.py).

Memory-Efficient Sharpness-Aware Training (MESA) adds a Kullback-Leibler divergence loss between the model's predictions and the exponential moving average (EMA) predictions to penalize the sharpness of current model weights. This process encourages the model to converge to a flat minimum.

Unlike [Sharpness-Aware Minimization](https://exeter-ecml.github.io/papers/0068-sharpness-aware-minimisation-for-efficiently-improving-generalization), which can double training computation, MESA requires only about 15% additional computation: a second forward pass using the EMA weights, plus the EMA update itself.

## MESALoss -

In [None]:
#|export
class MESALoss(MultiLoss):
    "Loss function for MESA. Automatically added to `Learner` by `MESACallback`"
    def __init__(self,
        orig_loss:nn.Module|FunctionType, # Original loss function from `Learner.loss_func`
        temp:Number=5, # Soften MESA targets by this temperature. τ in paper
        weight:float=0.8, # Weight of MESA loss. λ in paper
        reduction:str='mean' # PyTorch loss reduction
    ):
        store_attr(but='reduction')
        # Store the inverse temperature so `_mesa_kl` scales the logits by multiplication
        self.temp = 1/temp
        # Read `y_int` before `orig_loss` may be wrapped in a `partial`, which would hide the attribute
        if getattr(self.orig_loss, 'y_int', False): self.y_int = True
        if hasattr(self.orig_loss, 'reduction'): self.orig_loss.reduction = reduction
        else: self.orig_loss = partial(self.orig_loss, reduction=reduction)
        # Both KL inputs are log-probabilities, hence `log_target=True`; 'mean' is mapped to
        # 'batchmean', which is the reduction that matches the mathematical KL divergence
        self._mesa_loss = nn.KLDivLoss(log_target=True, reduction='batchmean' if reduction=='mean' else reduction)
        # Backing field for the `reduction` property, so it is readable before the setter is ever used
        self._reduction = reduction
        self.mesa_loss = False # Switched on by `MESACallback` once `start_epoch` is reached
        self.loss_names = L('orig_loss', 'mesa_loss')
        self.loss_funcs = self.loss_names # compatibility with MultiLossCallback
        # `_loss[0]` holds the original loss and `_loss[1]` the weighted MESA loss (`_zero` while disabled)
        self._zero, self._loss = torch.tensor(0., requires_grad=False), {}

    def _mesa_kl(self, pred, mesa_targ):
        "Weighted KL divergence between temperature-softened EMA (`mesa_targ`) and model (`pred`) predictions"
        # The temperature scales the logits *before* `log_softmax`, so both inputs stay normalized
        # log-probabilities. Scaling after `log_softmax` would yield unnormalized distributions.
        return self.weight*self._mesa_loss(F.log_softmax(self.temp*pred, dim=1), F.log_softmax(self.temp*mesa_targ, dim=1))

    def forward(self, pred, *targs):
        "Add MESA loss to `orig_loss` if `mesa_loss==True`. `targs` is `(targ, mesa_targ)`"
        targ, mesa_targ = targs
        self._loss[0] = self.orig_loss(pred, targ)
        self._loss[1] = self._mesa_kl(pred, mesa_targ) if self.mesa_loss else self._zero
        return self._loss[0] + self._loss[1]

    def forward_mixup(self, pred, *targs):
        "Used by `MixHandlerX` for MixUp, CutMix, etc. Otherwise, same as `forward`. `targs` is `(targ1, targ2, mesa_targ, lam)`"
        targ1, targ2, mesa_targ, lam = targs
        # Unreduced losses against both mixed targets, interpolated by `lam`, then reduced
        with NoneReduce(self.orig_loss) as ol:
            loss = torch.lerp(ol(pred, targ1), ol(pred, targ2), lam)
        # A `partial`-wrapped `orig_loss` has no `reduction` attribute; fall back to this loss's reduction
        self._loss[0] = reduce_loss(loss, getattr(self.orig_loss, 'reduction', self._reduction))
        self._loss[1] = self._mesa_kl(pred, mesa_targ) if self.mesa_loss else self._zero
        return self._loss[0] + self._loss[1]

    @property
    def reduction(self):
        "PyTorch loss reduction shared by `orig_loss` and the MESA KL loss"
        return self._reduction

    @reduction.setter
    def reduction(self, r):
        "Propagate reduction `r` to `orig_loss` and the MESA KL loss ('mean' becomes 'batchmean' for KL)"
        if hasattr(self.orig_loss, 'reduction'): self.orig_loss.reduction = r
        else: self.orig_loss = partial(self.orig_loss, reduction=r)
        self._mesa_loss.reduction = 'batchmean' if r=='mean' else r
        self._reduction = r

    @delegates(Module.to)
    def to(self, *args, **kwargs):
        "Move this loss, including the cached zero tensor. Returns `self`, like `nn.Module.to`"
        device = torch._C._nn._parse_to(*args, **kwargs)[0]
        # `Tensor.to` is out-of-place: the result must be reassigned or `_zero` never moves
        if device is not None: self._zero = self._zero.to(device)
        return super(Module, self).to(*args, **kwargs)

    def activation(self, pred):
        "Returns `orig_loss` `activation`"
        return getattr(self.orig_loss, 'activation', noop)(pred)

    def decodes(self, pred):
        "Returns `orig_loss` `decodes`"
        return getattr(self.orig_loss, 'decodes', noop)(pred)

## MESACallback -

In [None]:
#|export
class MESACallback(Callback):
    "Callback to implement Memory-Efficient Sharpness-Aware Training from https://arxiv.org/abs/2205.14083"
    order = MixedPrecision.order+1
    def __init__(self,
        start_epoch:int=4, # Epoch to start MESA (index 0)
        temp:Number=5, # Soften MESA targets by this temperature. τ in paper
        weight:float=0.8, # Weight of MESA loss. λ in paper
        decay:float=0.9998, # EMA decay. β in paper
        reduction:str='mean', # PyTorch loss reduction
        cleanup:bool=True # Remove `MESACallback` after training
    ):
        store_attr()

    @torch.no_grad()
    def before_fit(self):
        "Wrap `Learner.loss_func` in `MESALoss` and create the EMA copy of the model"
        if hasattr(self.learn, 'lr_finder') or hasattr(self, "gather_preds"):
            # Disable this callback for the whole fit (e.g. `lr_find`). A bare `return` would still
            # run `after_pred`/`after_batch`/`after_fit`, which use state that is only set up below.
            self.run = False
            return
        self.start_epoch = max(self.start_epoch, 0)
        # Placeholder until `start_epoch`: its output is ignored while `MESALoss.mesa_loss` is False
        self._ema_forward = lambda *args: 0
        self.orig_loss = self.learn.loss_func
        # Remember the original reduction so `after_fit` can restore it
        self.orig_loss_reduction = self.orig_loss.reduction if hasattr(self.orig_loss, 'reduction') else None
        self.learn.loss_func = MESALoss(self.orig_loss, self.temp, self.weight, self.reduction)
        self.learn.loss_func.to(getattr(self.dls, 'device', default_device()))
        self.ema_model = ModelEmaV2(self.learn.model, self.decay)
        mix = self.learn._grab_cbs(MixHandlerX)
        # With a stacking mix handler, the MESA flag must be set on `loss_func_mixup` instead
        self._mixup = len(mix) > 0 and mix[0].stack_y

    def before_train(self):
        "Start calculating MESA once `start_epoch` is reached"
        # `>=` rather than `==` so MESA is also enabled when training resumes past `start_epoch`
        if self.epoch >= self.start_epoch:
            loss_func = self.learn.loss_func_mixup if self._mixup else self.learn.loss_func
            loss_func.mesa_loss = True
            self._ema_forward = self.ema_model.module

    @torch.no_grad()
    def after_pred(self):
        "Create MESA targets from EMA prediction: `yb` becomes `(y, ema_pred)`"
        self.learn.yb = tuple([self.y, self._ema_forward(*self.xb)])

    def after_loss(self):
        "Remove MESA targets `yb` for metrics compatibility"
        y, _ = self.yb
        self.learn.yb = tuple([y])

    def after_batch(self):
        "Update model's EMA after each training batch"
        # Validation batches do not change the weights, so they must not advance the EMA
        if self.training: self.ema_model.update(self.learn.model)

    @torch.no_grad()
    def after_fit(self):
        "Optionally remove `MESACallback` from `Learner` post fit, restoring the original loss"
        if self.cleanup:
            if hasattr(self.orig_loss, 'reduction'):
                self.orig_loss.reduction = self.orig_loss_reduction
            self.learn.loss_func = self.orig_loss
            self.ema_model = None
            self.remove_cb(MESACallback)

Currently, `MESACallback` is incompatible with multi-loss or multi-target training via `MultiLoss` and `MultiTargetLoss`, respectively.

## Hyperparameters 

> Note: <code>MESACallback</code> defaults to the [reported hyperparameters](https://arxiv.org/abs/2205.14083) for training ResNets on ImageNet for 90 epochs with SGD.

Du et al. keep `weight` and `temp` constant across all reported CIFAR and ImageNet experiments.

The MESA `start_epoch` is 5 both for 90 epochs of ImageNet training and for 200 epochs of CIFAR10 and CIFAR100 training, in each case using ResNets and SGD. Du et al. report conflicting start epochs when training ViT for 300 epochs on ImageNet with AdamW: the paper body states the start epoch is 5, while the appendix reports it as 100. (Following fastai convention, `MESACallback` indexes epochs from 0, so its default is 4.)

On CIFAR10 and CIFAR100, Du et al. reduce MESA's EMA `decay` from ImageNet's 0.9998 to 0.9995.

In [None]:
#|hide
#|slow
from fastcore.basics import num_cpus

from fastai.data.external import URLs, untar_data
from fastai.data.block import DataBlock, CategoryBlock
from fastai.data.transforms import GrandparentSplitter, get_image_files, parent_label, Normalize
from fastai.learner import Learner, Recorder
from fastai.vision.augment import Resize
from fastai.vision.core import imagenet_stats
from fastai.vision.data import ImageBlock
from fastai.vision.models.xresnet import xresnet18

from fastxtend.callback.cutmixup import MixUp

## Example

To log both the original loss and MESA loss in addition to the combined loss, add both `MESACallback` and `MultiLossCallback` to the `Learner`.

In [None]:
#|slow
#|cuda
with no_random():
    imagenette = untar_data(URLs.IMAGENETTE_160)

    # Imagenette resized to 64px, split on its `train`/`val` grandparent folders
    dblock = DataBlock(
        blocks=(ImageBlock, CategoryBlock),
        splitter=GrandparentSplitter(valid_name='val'),
        get_items=get_image_files,
        get_y=parent_label,
        item_tfms=Resize(64),
        batch_tfms=Normalize.from_stats(*imagenet_stats),
    )
    dls = dblock.dataloaders(imagenette, bs=64, num_workers=num_cpus())

    # `MultiLossCallback` adds the separate `orig_loss` and `mesa_loss` columns to the log
    learn = Learner(dls, xresnet18(n_out=dls.c), cbs=[MESACallback, MultiLossCallback])
    # Replace the default Recorder with one that omits the time column
    learn.remove_cb(Recorder)
    learn.add_cb(Recorder(add_time=False))
    learn.fit_one_cycle(10, 1e-3)

epoch,train_loss,train_orig_loss,train_mesa_loss,valid_loss,valid_orig_loss,valid_mesa_loss
0,1.764903,1.764903,0.0,1.589739,1.589739,0.0
1,1.248024,1.248024,0.0,1.415202,1.415202,0.0
2,1.036043,1.036043,0.0,1.170383,1.170383,0.0
3,0.853409,0.853409,0.0,1.038804,1.038804,0.0
4,1.431591,1.149464,0.282127,1.517223,1.239665,0.277558
5,1.400509,1.099069,0.30144,1.48252,1.173118,0.309402
6,1.358856,1.035686,0.323169,1.4939,1.170899,0.323002
7,1.304363,0.961582,0.342781,1.458709,1.11078,0.347929
8,1.265853,0.914811,0.351041,1.451359,1.115306,0.336053
9,1.248852,0.889899,0.358953,1.450576,1.114508,0.336068


`MESACallback` works with `MixUp`, `CutMix`, `CutMixUp`, and `CutMixUpAugment`.

In [None]:
#|hide
#|slow
#|cuda

# mixup test: MESA combined with `MixUp`
with no_random():
    imagenette = untar_data(URLs.IMAGENETTE_160)

    dblock = DataBlock(
        blocks=(ImageBlock, CategoryBlock),
        splitter=GrandparentSplitter(valid_name='val'),
        get_items=get_image_files,
        get_y=parent_label,
        item_tfms=Resize(64),
        batch_tfms=Normalize.from_stats(*imagenet_stats),
    )
    dls = dblock.dataloaders(imagenette, bs=64, num_workers=num_cpus())

    learn = Learner(dls, xresnet18(n_out=dls.c), cbs=[MESACallback, MultiLossCallback, MixUp])
    # Replace the default Recorder with one that omits the time column
    learn.remove_cb(Recorder)
    learn.add_cb(Recorder(add_time=False))
    learn.fit_one_cycle(10, 1e-3)

epoch,train_loss,train_orig_loss,train_mesa_loss,valid_loss,valid_orig_loss,valid_mesa_loss
0,1.936363,1.936363,0.0,1.637596,1.637596,0.0
1,1.596632,1.596632,0.0,1.494622,1.494622,0.0
2,1.423184,1.423184,0.0,1.211276,1.211276,0.0
3,1.326761,1.326761,0.0,1.170963,1.170963,0.0
4,1.67802,1.489676,0.188344,1.522724,1.293728,0.228995
5,1.656398,1.450418,0.20598,1.495232,1.249833,0.2454
6,1.631453,1.4042,0.227253,1.512141,1.249805,0.262335
7,1.598414,1.358582,0.239832,1.456735,1.160747,0.295988
8,1.568128,1.318318,0.24981,1.449251,1.170657,0.278594
9,1.565907,1.313766,0.252141,1.451141,1.162868,0.288273
