diff --git a/ignite/metrics/loss.py b/ignite/metrics/loss.py index 7182e7033d5..2be0a7d2387 100644 --- a/ignite/metrics/loss.py +++ b/ignite/metrics/loss.py @@ -29,6 +29,9 @@ class Loss(Metric): device: specifies which device updates are accumulated on. Setting the metric's device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By default, CPU. + skip_unrolling: specifies whether input should be unrolled or not before it is passed to to loss_fn. + Should be true for multi-output model, for example, if ``y_pred`` contains multi-ouput as + ``(y_pred_a, y_pred_b)`` Attributes: required_output_keys: dictionary defines required keys to be found in ``engine.state.output`` if the @@ -62,6 +65,8 @@ class Loss(Metric): -0.3499999... + .. versionchanged:: 0.5.1 + ``skip_unrolling`` argument is added. """ required_output_keys = ("y_pred", "y", "criterion_kwargs") @@ -73,8 +78,9 @@ def __init__( output_transform: Callable = lambda x: x, batch_size: Callable = len, device: Union[str, torch.device] = torch.device("cpu"), + skip_unrolling: bool = False, ): - super(Loss, self).__init__(output_transform, device=device) + super(Loss, self).__init__(output_transform, device=device, skip_unrolling=skip_unrolling) self._loss_fn = loss_fn self._batch_size = batch_size diff --git a/ignite/metrics/metric.py b/ignite/metrics/metric.py index 39e5cb74522..4ccfd8ea7af 100644 --- a/ignite/metrics/metric.py +++ b/ignite/metrics/metric.py @@ -233,6 +233,59 @@ class Metric(Serializable, metaclass=ABCMeta): device: specifies which device updates are accumulated on. Setting the metric's device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By default, CPU. + skip_unrolling: specifies whether output should be unrolled before being fed to update method. Should be + true for multi-output model, for example, if ``y_pred`` contains multi-ouput as ``(y_pred_a, y_pred_b)`` + Alternatively, ``output_transform`` can be used to handle this. + + Examples: + The following example shows a custom loss metric that expects input from a multi-output model. + + .. code-block:: python + + import torch + import torch.nn as nn + import torch.nn.functional as F + + from ignite.engine import create_supervised_evaluator + from ignite.metrics import Loss + + class MyLoss(nn.Module): + def __init__(self, ca: float = 1.0, cb: float = 1.0) -> None: + super().__init__() + self.ca = ca + self.cb = cb + + def forward(self, + y_pred: Tuple[torch.Tensor, torch.Tensor], + y_true: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: + a_true, b_true = y_true + a_pred, b_pred = y_pred + return self.ca * F.mse_loss(a_pred, a_true) + self.cb * F.cross_entropy(b_pred, b_true) + + + def prepare_batch(batch, device, non_blocking): + return torch.rand(4, 1), (torch.rand(4, 1), torch.rand(4, 2)) + + + class MyModel(nn.Module): + + def forward(self, x): + return torch.rand(4, 1), torch.rand(4, 2) + + + model = MyModel() + + device = "cpu" + loss = MyLoss(0.5, 1.0) + metrics = { + "Loss": Loss(loss, skip_unrolling=True) + } + train_evaluator = create_supervised_evaluator(model, metrics, device, prepare_batch=prepare_batch) + + + data = range(10) + train_evaluator.run(data) + train_evaluator.state.metrics["Loss"] Attributes: required_output_keys: dictionary defines required keys to be found in ``engine.state.output`` if the @@ -292,6 +345,9 @@ def compute(self): .. versionchanged:: 0.4.2 ``required_output_keys`` became public attribute. + + .. versionchanged:: 0.5.1 + ``skip_unrolling`` argument is added. """ # public class attribute @@ -300,7 +356,10 @@ def compute(self): _required_output_keys = required_output_keys def __init__( - self, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu") + self, + output_transform: Callable = lambda x: x, + device: Union[str, torch.device] = torch.device("cpu"), + skip_unrolling: bool = False, ): self._output_transform = output_transform @@ -309,6 +368,7 @@ def __init__( raise ValueError("Cannot create metric on an XLA device. Use device='cpu' instead.") self._device = torch.device(device) + self._skip_unrolling = skip_unrolling self.reset() @abstractmethod @@ -390,7 +450,11 @@ def iteration_completed(self, engine: Engine) -> None: ) output = tuple(output[k] for k in self.required_output_keys) - if isinstance(output, Sequence) and all([_is_list_of_tensors_or_numbers(o) for o in output]): + if ( + (not self._skip_unrolling) + and isinstance(output, Sequence) + and all([_is_list_of_tensors_or_numbers(o) for o in output]) + ): if not (len(output) == 2 and len(output[0]) == len(output[1])): raise ValueError( f"Output should have 2 items of the same length, " diff --git a/tests/ignite/metrics/test_loss.py b/tests/ignite/metrics/test_loss.py index 19cc68cd45c..0e945bec58c 100644 --- a/tests/ignite/metrics/test_loss.py +++ b/tests/ignite/metrics/test_loss.py @@ -1,11 +1,12 @@ import os +from typing import Tuple from unittest.mock import MagicMock import pytest import torch from numpy.testing import assert_almost_equal from torch import nn -from torch.nn.functional import nll_loss +from torch.nn.functional import mse_loss, nll_loss import ignite.distributed as idist from ignite.engine import State @@ -314,3 +315,50 @@ def compute(self): (torch.rand(4, 10), torch.randint(0, 3, size=(4,))), ] evaluator.run(data) + + +class CustomMultiMSELoss(nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward( + self, y_pred: Tuple[torch.Tensor, torch.Tensor], y_true: Tuple[torch.Tensor, torch.Tensor] + ) -> torch.Tensor: + a_true, b_true = y_true + a_pred, b_pred = y_pred + return mse_loss(a_pred, a_true) + mse_loss(b_pred, b_true) + + +class DummyLoss3(Loss): + def __init__(self, loss_fn, expected_loss, output_transform=lambda x: x, skip_unrolling=False): + super(DummyLoss3, self).__init__(loss_fn, output_transform=output_transform, skip_unrolling=skip_unrolling) + self._expected_loss = expected_loss + self._loss_fn = loss_fn + + def reset(self): + pass + + def compute(self): + pass + + def update(self, output): + y_pred, y_true = output + calculated_loss = self._loss_fn(y_pred=y_pred, y_true=y_true) + assert calculated_loss == self._expected_loss + + +def test_skip_unrolling_loss(): + a_pred = torch.rand(8, 1) + b_pred = torch.rand(8, 1) + y_pred = [a_pred, b_pred] + a_true = torch.rand(8, 1) + b_true = torch.rand(8, 1) + y_true = [a_true, b_true] + + multi_output_mse_loss = CustomMultiMSELoss() + expected_loss = multi_output_mse_loss(y_pred=y_pred, y_true=y_true) + + loss_metric = DummyLoss3(loss_fn=multi_output_mse_loss, expected_loss=expected_loss, skip_unrolling=True) + state = State(output=(y_pred, y_true)) + engine = MagicMock(state=state) + loss_metric.iteration_completed(engine) diff --git a/tests/ignite/metrics/test_metric.py b/tests/ignite/metrics/test_metric.py index f9db11b1a37..96c19d668d7 100644 --- a/tests/ignite/metrics/test_metric.py +++ b/tests/ignite/metrics/test_metric.py @@ -1416,3 +1416,33 @@ def wrapper(x, **kwargs): assert (output == expected).all(), (output, expected) else: assert output == expected, (output, expected) + + +class DummyMetric5(Metric): + def __init__(self, true_output, output_transform=lambda x: x, skip_unrolling=False): + super(DummyMetric5, self).__init__(output_transform=output_transform, skip_unrolling=skip_unrolling) + self.true_output = true_output + + def reset(self): + pass + + def compute(self): + pass + + def update(self, output): + assert output == self.true_output + + +def test_skip_unrolling(): + # y_pred and y are ouputs recieved from a multi_output model + a_pred = torch.rand(8, 1) + b_pred = torch.rand(8, 1) + y_pred = [a_pred, b_pred] + a_true = torch.rand(8, 1) + b_true = torch.rand(8, 1) + y_true = [a_true, b_true] + + metric = DummyMetric5(true_output=(y_pred, y_true), skip_unrolling=True) + state = State(output=(y_pred, y_true)) + engine = MagicMock(state=state) + metric.iteration_completed(engine)