In [None]:
#|default_exp multiloss

In [None]:
#|export
from __future__ import annotations
from types import FunctionType

from torch.distributions.beta import Beta

from fastai.callback.core import Callback
from fastai.learner import Recorder
from fastai.layers import NoneReduce

from fastxtend.basics import is_listish
from fastxtend.metrics import AvgLossX, AvgSmoothLossX
from fastxtend.imports import *

In [None]:
#hide
from nbdev.showdoc import *
from fastxtend.test_utils import *

# MultiLoss
> A loss wrapper and callback to calculate and log individual losses as fastxtend metrics.

In [None]:
#|exporti
def init_loss(l, **kwargs):
    """Initialize a loss class or partial loss function with `reduction='none'`.

    `kwargs` are forwarded in both cases: bound onto a plain function via `partialler`,
    or passed to the constructor when `l` is a loss class.
    """
    if isinstance(l, FunctionType):
        # A plain function cannot be instantiated, so bind its arguments for the later call
        return partialler(l, reduction='none', **kwargs)
    return l(reduction='none', **kwargs)

In [None]:
#|export
class MultiLoss(Module):
    """
    Combine multiple `loss_funcs` on one prediction & target via `reduction`, with optional weighting.

    Log `loss_funcs` as metrics via `MultiLossCallback`, optionally using `loss_names`.

    Every loss is created with `reduction='none'`. The weighted, unreduced value of each loss
    from the most recent call is kept in `losses`, keyed by its position in `loss_funcs`.
    """
    def __init__(self,
        loss_funcs:listy[Callable[...,nn.Module]|FunctionType], # Uninitialized loss functions or classes. Must support PyTorch `reduction` string.
        weights:listified[Number]|None=None, # Weight per loss. Defaults to uniform weighting.
        loss_kwargs:listy[dict[str,Any]]|None=None, # kwargs to pass to each loss function. Defaults to None.
        loss_names:listy[str]|None=None, # Loss names to log using `MultiLossCallback`. Defaults to loss `__name__`.
        reduction:str|None='mean' # PyTorch loss reduction
    ):
        store_attr(but='loss_names')
        assert is_listy(loss_funcs), "`loss_funcs` must be list-like"

        # Missing or empty weights mean every loss counts equally
        if weights is None or len(weights)==0:
            self.weights = [1]*len(loss_funcs)
        else:
            assert len(loss_funcs) == len(weights), "Must provide same number of `weights` as `loss_funcs`"
            self.weights = weights

        if loss_kwargs is None or len(loss_kwargs)==0: loss_kwargs = [{}]*len(loss_funcs)
        else: assert len(loss_funcs) == len(loss_kwargs), "Must provide same number of `loss_kwargs` as `loss_funcs`"

        # Names are taken from the uninitialized losses, before they are replaced by instances below
        if loss_names is None or len(loss_names)==0: loss_names = [l.__name__ for l in loss_funcs]
        else: assert len(loss_funcs) == len(loss_names), "Must provide same number of `loss_names` as `loss_funcs`"

        self.loss_funcs = [init_loss(l, **k) for l, k in zip(loss_funcs, loss_kwargs)]
        self.loss_names = L(loss_names)
        self._reduction,self._loss = reduction,{}

        # Expose `y_int` when any wrapped loss sets it; otherwise leave the attribute undefined
        for loss in self.loss_funcs:
            if getattr(loss, 'y_int', False):
                self.y_int = True

    def _apply_reduction(self, loss):
        "Reduce the combined `loss` by 'mean' or 'sum'; any other `reduction` returns it unreduced"
        if self._reduction=='mean': return loss.mean()
        if self._reduction=='sum': return loss.sum()
        return loss

    def forward(self, pred, targ):
        "Sum the weighted losses of `pred` against `targ`, store each one in `losses`, then reduce"
        for i, loss_func in enumerate(self.loss_funcs):
            l = self.weights[i]*loss_func(pred, targ)
            # The accumulator is shaped like the target, so each unreduced loss must add into that shape
            if i == 0: loss = torch.zeros_like(targ).float()
            loss += l
            self._loss[i] = l

        return self._apply_reduction(loss)

    def forward_mixup(self, pred, targ1, targ2, lam):
        "Like `forward`, but each loss is interpolated between `targ1` and `targ2` by `lam`"
        for i, loss_func in enumerate(self.loss_funcs):
            l = self.weights[i]*torch.lerp(loss_func(pred, targ1), loss_func(pred, targ2), lam)
            if i == 0: loss = torch.zeros_like(targ1).float()
            loss += l
            self._loss[i] = l

        return self._apply_reduction(loss)

    @property
    def losses(self):
        "Weighted, unreduced individual losses from the latest `forward`/`forward_mixup` call"
        return self._loss

    @property
    def reduction(self): return self._reduction

    @reduction.setter
    def reduction(self, r): self._reduction = r

    @delegates(Module.to)
    def to(self, *args, **kwargs):
        "Convert `weights` to a tensor on the requested device, then move the module and return it"
        device, *_ = torch._C._nn._parse_to(*args, **kwargs)
        if is_listish(self.weights) or not isinstance(self.weights, torch.Tensor): self.weights = torch.Tensor(self.weights)
        if self.weights.device != device: self.weights = self.weights.to(device=device)
        # Return the module, as `nn.Module.to` does, so `loss = loss.to(device)` keeps working
        return super().to(*args, **kwargs)

    def activation(self, pred):
        "Returns first `loss_funcs` `activation`"
        return getattr(self.loss_funcs[0], 'activation', noop)(pred)

    def decodes(self, pred):
        "Returns first `loss_funcs` `decodes`"
        return getattr(self.loss_funcs[0], 'decodes', noop)(pred)

`MultiLoss` is a simple multiple loss wrapper which allows logging each individual loss automatically using the `MultiLossCallback`.

Pass uninitialized loss functions to `loss_funcs`, optional per loss weighting via `weights`, any loss arguments via a list of dictionaries in `loss_kwargs`, and optional names for each individual loss via `loss_names`.

If passed, `weights`, `loss_kwargs`, & `loss_names` must be an iterable of the same length as `loss_funcs`.

Output from each loss function must be the same shape.

In [None]:
#|hide
# Check `MultiLoss` with two regression losses and the default (uniform) weights
losses = [nn.MSELoss, nn.L1Loss]
multiloss = MultiLoss(loss_funcs=losses)

output = torch.sigmoid(torch.randn(32, 5, 10))
target = torch.randint(0,2,(32, 5, 10))

with torch.no_grad():
    ml = multiloss(output, target)
    # The mean of each stored unreduced loss should match the loss built with default arguments
    for i, l in enumerate(losses):
        test_close(l()(output, target), multiloss.losses[i].mean())

In [None]:
#|hide
# Check `MultiLoss` with two classification losses on class-index targets
from fastai.losses import FocalLoss

losses = [nn.CrossEntropyLoss, FocalLoss]
multiloss = MultiLoss(loss_funcs=losses)

output = torch.randn(32, 5, 128, 128)
target = torch.randint(0, 5, (32, 128, 128))

with torch.no_grad():
    ml = multiloss(output, target)
    # The mean of each stored unreduced loss should match the loss built with default arguments
    for i, l in enumerate(losses):
        test_close(l()(output, target), multiloss.losses[i].mean())

# MultiTargetLoss -

In [None]:
#|export
class MultiTargetLoss(MultiLoss):
    """
    Combine `loss_funcs` from multiple predictions & targets via `reduction`, with optional weighting.

    Log `loss_funcs` as metrics via `MultiLossCallback`, optionally using `loss_names`.

    The i-th loss is applied to the i-th prediction and the i-th target.
    """
    def __init__(self,
        loss_funcs:listy[Callable[...,nn.Module]|FunctionType], # Uninitialized loss functions or classes. One per prediction and target. Must support PyTorch `reduction` string.
        weights:listified[Number]|None=None, # Weight per loss. Defaults to uniform weighting.
        loss_kwargs:listy[dict[str,Any]]|None=None, # kwargs to pass to each loss function. Defaults to None.
        loss_names:listy[str]|None=None, # Loss names to log using `MultiLossCallback`. Defaults to loss `__name__`.
        reduction:str|None='mean' # PyTorch loss reduction
    ):
        super().__init__(loss_funcs, weights, loss_kwargs, loss_names, reduction)

    def forward(self, preds, targs):
        "Sum the weighted loss of each (`pred`, `targ`) pair, store each one in `losses`, then reduce"
        # `zip` stops at the shortest of `loss_funcs`, `preds` and `targs`
        for i, (loss_func, pred, targ) in enumerate(zip(self.loss_funcs, preds, targs)):
            l = TensorBase(self.weights[i]*loss_func(pred, targ))
            # The accumulator is shaped like the first target, so every unreduced loss must add into that shape
            if i == 0: loss = TensorBase(torch.zeros_like(targ)).float()
            loss += l
            self._loss[i] = l

        return loss.mean() if self._reduction=='mean' else loss.sum() if self._reduction=='sum' else loss

    def forward_mixup(self, *args, **kwargs):
        "Mixup is unsupported; accepts any call signature so callers always get `NotImplementedError`"
        # `MixHandlerX.multi_lf` passes its arguments positionally, so positional args must be accepted too
        raise NotImplementedError("Mixup doesn't support Multi-Target training")

    def activation(self, preds):
        "Returns list of `activation`"
        return [getattr(self.loss_funcs[i], 'activation', noop)(pred) for i, pred in enumerate(preds)]

    def decodes(self, preds):
        "Returns list of `decodes`"
        return [getattr(self.loss_funcs[i], 'decodes', noop)(pred) for i, pred in enumerate(preds)]

`MultiTargetLoss` is the multiple-target version of `MultiLoss`, applying a single loss to each prediction and target pair. It is a simple multiple loss wrapper which allows logging each individual loss automatically using the `MultiLossCallback`.

Pass uninitialized loss functions to `loss_funcs`, optional per loss weighting via `weights`, any loss arguments via a list of dictionaries in `loss_kwargs`, and optional names for each individual loss via `loss_names`.

If passed, `weights`, `loss_kwargs`, & `loss_names` must be an iterable of the same length as `loss_funcs`.

Output from each loss function must be the same shape.

In [None]:
#|hide
# Check `MultiTargetLoss` with one regression loss per (output, target) pair
losses = [nn.MSELoss, nn.L1Loss]
multitargloss = MultiTargetLoss(loss_funcs=losses)

outputs = [torch.sigmoid(torch.randn(32, 5, 10)),torch.sigmoid(torch.randn(32, 5, 10))]
targets = [torch.randint(0,2,(32, 5, 10)),torch.randint(0,2,(32, 5, 10))]

with torch.no_grad():
    ml = multitargloss(outputs, targets)
    # Each stored loss should match its own loss applied to its own output and target
    for i, (l, out, targ) in enumerate(zip(losses, outputs, targets)):
        test_close(l()(out, targ), multitargloss.losses[i].mean())

In [None]:
#|hide
# Check `MultiTargetLoss` with one classification loss per (output, target) pair
from fastai.losses import FocalLoss

losses = [nn.CrossEntropyLoss, FocalLoss]
multitargloss = MultiTargetLoss(loss_funcs=losses)

outputs = [torch.randn(32, 5, 128, 128), torch.randn(32, 5, 128, 128)]
targets = [torch.randint(0, 5, (32, 128, 128)), torch.randint(0, 5, (32, 128, 128))]

with torch.no_grad():
    ml = multitargloss(outputs, targets)
    # Each stored loss should match its own loss applied to its own output and target
    for i, (l, out, targ) in enumerate(zip(losses, outputs, targets)):
        test_close(l()(out, targ), multitargloss.losses[i].mean())

## Multiloss Metrics -

In [None]:
#|exporti
class MultiAvgLoss(AvgLossX):
    "Batch-size weighted average of a single `MultiLoss` loss"
    def __init__(self,
        i, # `Multiloss` loss function location
        name, # Loss function name
        reduction:str|None='mean' # Override loss reduction for logging
    ):
        store_attr(but='name')
        self._name = name

    def accumulate(self, learn):
        "Add the current batch's `i`-th loss to the running total, weighted by batch size"
        n_items = find_bs(learn.yb)
        batch_loss = learn.loss_func.losses[self.i]
        # Only 'mean' and 'sum' reduce; any other setting keeps the unreduced loss
        if self.reduction=='mean':
            batch_loss = batch_loss.mean()
        elif self.reduction=='sum':
            batch_loss = batch_loss.sum()
        self.total += learn.to_detach(batch_loss)*n_items
        self.count += n_items

In [None]:
#|exporti
class MultiAvgSmoothLoss(AvgSmoothLossX):
    "Exponentially smoothed (by `beta`) average of a single `MultiLoss` loss"
    def __init__(self,
        i, # `Multiloss` loss function location
        name, # Loss function name
        beta:float=0.98, # Smoothing beta
        reduction:str|None='mean' # Override loss reduction for logging
    ):
        super().__init__()
        store_attr(but='name')
        self._name = name

    def accumulate(self, learn):
        "Blend the current batch's `i`-th loss from `learn.loss_func` into the smoothed value"
        self.count += 1
        batch_loss = learn.loss_func.losses[self.i]
        # Only 'mean' and 'sum' reduce; any other setting keeps the unreduced loss
        if self.reduction=='mean':
            batch_loss = batch_loss.mean()
        elif self.reduction=='sum':
            batch_loss = batch_loss.sum()
        detached = to_detach(batch_loss, gather=False)
        self.val = torch.lerp(detached, self.val, self.beta)

In [None]:
#|exporti
class MultiAvgSmoothLossMixup(AvgSmoothLossX):
    "Exponentially smoothed (by `beta`) average of a single `MultiLoss` loss, read from `learn.loss_func_mixup`"
    def __init__(self,
        i, # `Multiloss` loss function location
        name, # Loss function name
        beta:float=0.98, # Smoothing beta
        reduction:str|None='mean' # Override loss reduction for logging
    ):
        super().__init__()
        store_attr(but='name')
        self._name = name

    def accumulate(self, learn):
        "Blend the current batch's `i`-th loss from `learn.loss_func_mixup` into the smoothed value"
        self.count += 1
        batch_loss = learn.loss_func_mixup.losses[self.i]
        # Only 'mean' and 'sum' reduce; any other setting keeps the unreduced loss
        if self.reduction=='mean':
            batch_loss = batch_loss.mean()
        elif self.reduction=='sum':
            batch_loss = batch_loss.sum()
        detached = to_detach(batch_loss, gather=False)
        self.val = torch.lerp(detached, self.val, self.beta)

## MixHandlerX -

In [None]:
#|export
class MixHandlerX(Callback):
    "A handler class for implementing `MixUp` style scheduling. Like fastai's `MixHandler` but supports `MultiLoss`."
    run_valid = False
    def __init__(self,
        alpha:float=0.5, # Alpha & beta parametrization for `Beta` distribution
        interp_label:bool|None=None # Blend or stack labels. Defaults to `loss_func.y_int` if None
    ):
        store_attr()
        self.distrib = Beta(tensor(alpha), tensor(alpha))

    def _lf_swapped(self):
        "Whether `learn.loss_func` is currently one of this handler's wrapper methods"
        return self.learn.loss_func in (self.multi_lf, self.solo_lf)

    def before_fit(self):
        "Determine whether to stack or interpolate labels"
        self.multiloss = isinstance(self.learn.loss_func, MultiLoss)
        if self.interp_label is None:
            self.stack_y = getattr(self.learn.loss_func, 'y_int', False)
        else:
            self.stack_y = not self.interp_label

    def before_train(self):
        "Replace the loss function with the mixing wrapper when stacking y"
        # Skip if already swapped, otherwise the wrapper would be saved as the original loss
        if self.stack_y and not self._lf_swapped():
            if self.multiloss:
                self.learn.loss_func_mixup = self.learn.loss_func
                self.learn.loss_func = self.multi_lf
            else:
                self.old_lf = self.learn.loss_func
                self.learn.loss_func = self.solo_lf

    def after_train(self):
        "Set the loss function back to the original loss"
        # Also reached through the cancel events, possibly before `before_fit`/`before_train` ran
        # or after the loss was already restored: only undo a swap that is actually in place.
        if getattr(self, 'stack_y', False) and self._lf_swapped():
            if self.multiloss:
                self.learn.loss_func = self.learn.loss_func_mixup
            else:
                self.learn.loss_func = self.old_lf

    def after_cancel_train(self):
        "If training is canceled, still set the loss function back"
        self.after_train()

    def after_cancel_fit(self):
        "If fit is canceled, still set the loss function back"
        self.after_train()

    def solo_lf(self, pred, *yb):
        "Interpolates losses on stacked labels by `self.lam` during training"
        if not self.training: return self.old_lf(pred, *yb)
        with NoneReduce(self.old_lf) as lf:
            loss = torch.lerp(lf(pred,*self.yb1), lf(pred,*yb), self.lam)
        # Reduce inline with the same mean/sum/none rule used by `MultiLoss`, instead of
        # depending on a `reduce_loss` helper that is not among this module's explicit imports.
        reduction = getattr(self.old_lf, 'reduction', 'mean')
        return loss.mean() if reduction=='mean' else loss.sum() if reduction=='sum' else loss

    def multi_lf(self, pred, *yb):
        "Calls the `MultiLoss` `forward_mixup` with both sets of labels and `self.lam` during training"
        if not self.training:
            return self.learn.loss_func_mixup(pred, *yb)
        else:
            return self.learn.loss_func_mixup.forward_mixup(pred, *self.yb1, *yb, self.lam)

`MixHandlerX` is defined here to prevent a circular import between multiloss and cutmixup modules. If `interp_label` is true, then labels will be blended together. Use with losses that prefer floats as labels such as BCE. If `interp_label` is false, then `MixHandlerX` will call the loss function twice, once with each label, and blend the losses together. Use with losses that prefer class integers as labels such as CE. If `interp_label` is None, then it is set via `loss_func.y_int`.

## MultiLossCallback -

In [None]:
#|export
class MultiLossCallback(Callback):
    "Callback to automatically log and name `MultiLoss` losses as fastxtend metrics"
    run_valid,order = False,Recorder.order-1

    def __init__(self,
        beta:float=0.98, # Smoothing beta
        reduction:str|None='mean' # Override loss reduction for logging
    ):
        store_attr()

    def before_fit(self):
        "Prepend a smoothed and an averaged metric per `MultiLoss` loss to `learn.metrics`"
        if not isinstance(self.loss_func, MultiLoss):
            raise ValueError("`MultiLossCallback` requires loss to be `MultiLoss` class")

        # `MixHandlerX` swaps `learn.loss_func` (keeping the `MultiLoss` in `learn.loss_func_mixup`)
        # exactly when it stacks labels: `not interp_label`, or `loss_func.y_int` if `interp_label`
        # is None. Mirror that rule so the smoothed metrics read from where the `MultiLoss` lives.
        y_int = getattr(self.learn.loss_func, 'y_int', False)
        mixup = False
        for cb in self.learn._grab_cbs(MixHandlerX):
            interp_label = getattr(cb, 'interp_label', None)
            if y_int if interp_label is None else not interp_label: mixup = True

        mets= L()
        reduction = self.loss_func.reduction if self.reduction is None else self.reduction
        for i in range(len(self.loss_func.loss_funcs)):
            if mixup: mets += MultiAvgSmoothLossMixup(i, self.loss_func.loss_names[i], self.beta, reduction)
            else:     mets += MultiAvgSmoothLoss(i, self.loss_func.loss_names[i], self.beta, reduction)
            mets += MultiAvgLoss(i, self.loss_func.loss_names[i], reduction)

        # `before_fit` runs on every fit: drop metrics added by a previous fit so they are not logged twice.
        # NOTE(review): assumes `learn.metrics` keeps these metric instances unwrapped - confirm.
        multi_metrics = (MultiAvgLoss, MultiAvgSmoothLoss, MultiAvgSmoothLossMixup)
        user_metrics = L([m for m in self.learn.metrics if not isinstance(m, multi_metrics)])
        self.learn.metrics = mets + user_metrics

## Example

In [None]:
#|hide
#|slow
from fastai.learner import Learner
from fastai.optimizer import SGD
from fastxtend.metrics import RMSE

@delegates(Learner.__init__)
def synth_learner(n_trn=10, n_val=2, cuda=False, lr=1e-3, data=None, model=None, **kwargs):
    "Create a `Learner` on `synth_dbunch` data with a `RegModel`, unless `data`/`model` are passed, using SGD with `mom=0.9`"
    dls = synth_dbunch(n_train=n_trn, n_valid=n_val, cuda=cuda) if data is None else data
    net = RegModel() if model is None else model
    sgd_momentum = partial(SGD, mom=0.9)
    return Learner(dls, net, lr=lr, opt_func=sgd_momentum, **kwargs)

In [None]:
#|slow
# Example: train with a weighted MSE + L1 `MultiLoss` and log each loss under a custom name
with no_random():
    # L1 is weighted 3.5x relative to MSE; `loss_names` sets the names used when logging
    mloss = MultiLoss(loss_funcs=[nn.MSELoss, nn.L1Loss], 
                      weights=[1, 3.5],
                      loss_names=['mse_loss', 'l1_loss'])


    # `MultiLossCallback` adds the per-loss metrics alongside the user-provided `RMSE`
    learn = synth_learner(n_trn=5, loss_func=mloss, metrics=RMSE(), cbs=MultiLossCallback)
    learn.fit(5)

epoch,train_loss,train_mse_loss,train_l1_loss,valid_loss,valid_mse_loss,valid_l1_loss,valid_rmse,time
0,23.598301,12.719514,10.878788,17.910727,9.067028,8.843699,3.011151,00:00
1,22.448792,11.937573,10.511218,15.481797,7.46443,8.017367,2.732111,00:00
2,20.827835,10.837888,9.989948,12.756706,5.756156,7.00055,2.399199,00:00
3,19.028177,9.657351,9.370827,10.031281,4.145008,5.886274,2.035929,00:00
4,17.167393,8.481768,8.685625,7.58102,2.787561,4.793459,1.669599,00:00
