Introduce a variable skip_unrolling in class Metric (#3258)
* Introduce a variable skip_unrolling in class Metric

* Add docstring for skip_unrolling, modify skip_unrolling clause

* Modify docstring

Co-authored-by: vfdev <vfdev.5@gmail.com>

* Apply suggestions from code review

Co-authored-by: vfdev <vfdev.5@gmail.com>

* Modify docstring, revert version tag

* Add test_skip_unrolling, DummyMetric5 class

* Add example usage of skip unrolling in Metric, Update Loss class with skip_unrolling arg

* Fix doc

* Add test for skip_unrolling in Loss

* Apply suggestions from code review

* Update ignite/metrics/metric.py

* Update docstring

* fix test_loss.py for python below 3.9

---------

Co-authored-by: vfdev <vfdev.5@gmail.com>
simeetnayan81 and vfdev-5 committed Jul 1, 2024
1 parent 5a66d9e commit d715807
Showing 4 changed files with 152 additions and 4 deletions.
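
In short: by default, ``Metric.iteration_completed`` unrolls an output that is a pair of equal-length lists and calls ``update`` once per element; ``skip_unrolling=True`` hands the output to ``update`` in a single call, which is what a multi-output model needs. Below is a minimal sketch of that behaviour, assuming a bare ``update`` callable rather than ignite's actual internals:

    import torch

    def feed(update, output, skip_unrolling=False):
        # Mirrors the decision iteration_completed makes, in spirit only.
        y_pred, y = output
        unrollable = isinstance(y_pred, list) and isinstance(y, list) and len(y_pred) == len(y)
        if not skip_unrolling and unrollable:
            for p, t in zip(y_pred, y):
                update((p, t))  # default: one update per (prediction, target) pair
        else:
            update(output)  # skip_unrolling=True: the whole output in one call

    # With a multi-output model, unrolling would split (y_pred_a, y_pred_b) apart,
    # so the tuples must reach update (and ultimately loss_fn) intact:
    y_pred = [torch.rand(8, 1), torch.rand(8, 2)]
    y_true = [torch.rand(8, 1), torch.rand(8, 2)]
    feed(lambda out: print(len(out)), (y_pred, y_true), skip_unrolling=True)  # prints 2 once
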
8 changes: 7 additions & 1 deletion ignite/metrics/loss.py
@@ -29,6 +29,9 @@ class Loss(Metric):
device: specifies which device updates are accumulated on. Setting the
metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
non-blocking. By default, CPU.
skip_unrolling: specifies whether the input should be unrolled before it is passed to ``loss_fn``.
    Should be true for multi-output models, for example, if ``y_pred`` contains multi-output as
    ``(y_pred_a, y_pred_b)``.
Attributes:
required_output_keys: dictionary defines required keys to be found in ``engine.state.output`` if the
@@ -62,6 +65,8 @@ class Loss(Metric):
-0.3499999...
.. versionchanged:: 0.5.1
    ``skip_unrolling`` argument is added.
"""

required_output_keys = ("y_pred", "y", "criterion_kwargs")
@@ -73,8 +78,9 @@ def __init__(
         output_transform: Callable = lambda x: x,
         batch_size: Callable = len,
         device: Union[str, torch.device] = torch.device("cpu"),
+        skip_unrolling: bool = False,
     ):
-        super(Loss, self).__init__(output_transform, device=device)
+        super(Loss, self).__init__(output_transform, device=device, skip_unrolling=skip_unrolling)
         self._loss_fn = loss_fn
         self._batch_size = batch_size
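
A usage sketch for the new argument, assuming a hypothetical multi-output criterion ``multi_mse`` (not part of this commit): with ``skip_unrolling=True`` the tuple-valued ``y_pred`` and ``y`` reach ``loss_fn`` intact.

    import torch
    import torch.nn.functional as F

    from ignite.metrics import Loss

    def multi_mse(y_pred, y_true):
        # Hypothetical criterion for a two-headed model; not part of this commit.
        (a_pred, b_pred), (a_true, b_true) = y_pred, y_true
        return F.mse_loss(a_pred, a_true) + F.mse_loss(b_pred, b_true)

    loss_metric = Loss(multi_mse, skip_unrolling=True)
    y_pred = (torch.rand(4, 1), torch.rand(4, 2))
    y_true = (torch.rand(4, 1), torch.rand(4, 2))
    loss_metric.update((y_pred, y_true))  # whole tuples reach multi_mse, un-unrolled
    print(loss_metric.compute())

Note that the default ``batch_size=len`` then counts the two heads of ``y`` rather than the sample count; a real setup would likely pass something like ``batch_size=lambda y: y[0].shape[0]``.
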
68 changes: 66 additions & 2 deletions ignite/metrics/metric.py
@@ -233,6 +233,59 @@ class Metric(Serializable, metaclass=ABCMeta):
device: specifies which device updates are accumulated on. Setting the
metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
non-blocking. By default, CPU.
skip_unrolling: specifies whether the output should be unrolled before being fed to the ``update`` method.
    Should be true for multi-output models, for example, if ``y_pred`` contains multi-output as
    ``(y_pred_a, y_pred_b)``. Alternatively, ``output_transform`` can be used to handle this.
Examples:
    The following example shows a custom loss metric that expects input from a multi-output model.

    .. code-block:: python

        from typing import Tuple

        import torch
        import torch.nn as nn
        import torch.nn.functional as F

        from ignite.engine import create_supervised_evaluator
        from ignite.metrics import Loss

        class MyLoss(nn.Module):
            def __init__(self, ca: float = 1.0, cb: float = 1.0) -> None:
                super().__init__()
                self.ca = ca
                self.cb = cb

            def forward(self,
                        y_pred: Tuple[torch.Tensor, torch.Tensor],
                        y_true: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
                a_true, b_true = y_true
                a_pred, b_pred = y_pred
                return self.ca * F.mse_loss(a_pred, a_true) + self.cb * F.cross_entropy(b_pred, b_true)

        def prepare_batch(batch, device, non_blocking):
            return torch.rand(4, 1), (torch.rand(4, 1), torch.rand(4, 2))

        class MyModel(nn.Module):
            def forward(self, x):
                return torch.rand(4, 1), torch.rand(4, 2)

        model = MyModel()
        device = "cpu"
        loss = MyLoss(0.5, 1.0)
        metrics = {
            "Loss": Loss(loss, skip_unrolling=True)
        }

        train_evaluator = create_supervised_evaluator(model, metrics, device, prepare_batch=prepare_batch)
        data = range(10)
        train_evaluator.run(data)

        train_evaluator.state.metrics["Loss"]
Attributes:
required_output_keys: dictionary defines required keys to be found in ``engine.state.output`` if the
@@ -292,6 +345,9 @@ def compute(self):
.. versionchanged:: 0.4.2
    ``required_output_keys`` became public attribute.

.. versionchanged:: 0.5.1
    ``skip_unrolling`` argument is added.
"""

# public class attribute
@@ -300,7 +356,10 @@ def compute(self):
     _required_output_keys = required_output_keys
 
     def __init__(
-        self, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu")
+        self,
+        output_transform: Callable = lambda x: x,
+        device: Union[str, torch.device] = torch.device("cpu"),
+        skip_unrolling: bool = False,
     ):
         self._output_transform = output_transform
 
@@ -309,6 +368,7 @@ def __init__(
             raise ValueError("Cannot create metric on an XLA device. Use device='cpu' instead.")
 
         self._device = torch.device(device)
+        self._skip_unrolling = skip_unrolling
         self.reset()
 
     @abstractmethod
@@ -390,7 +450,11 @@ def iteration_completed(self, engine: Engine) -> None:
             )
             output = tuple(output[k] for k in self.required_output_keys)
 
-        if isinstance(output, Sequence) and all([_is_list_of_tensors_or_numbers(o) for o in output]):
+        if (
+            (not self._skip_unrolling)
+            and isinstance(output, Sequence)
+            and all([_is_list_of_tensors_or_numbers(o) for o in output])
+        ):
             if not (len(output) == 2 and len(output[0]) == len(output[1])):
                 raise ValueError(
                     f"Output should have 2 items of the same length, "
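
For reference, a sketch of the helper the new guard calls; the real ``_is_list_of_tensors_or_numbers`` lives in ``ignite/metrics/metric.py`` and may differ in detail:

    from collections.abc import Sequence
    from numbers import Number

    import torch

    def is_list_of_tensors_or_numbers(x):
        # True when x is a sequence whose elements are all tensors or plain numbers,
        # i.e. the shape the default unrolling path expects for y_pred and y.
        return isinstance(x, Sequence) and all(isinstance(t, (torch.Tensor, Number)) for t in x)

    print(is_list_of_tensors_or_numbers([torch.rand(2), torch.rand(2)]))  # True
    print(is_list_of_tensors_or_numbers(torch.rand(2, 2)))                # False: a tensor, not a list
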
50 changes: 49 additions & 1 deletion tests/ignite/metrics/test_loss.py
@@ -1,11 +1,12 @@
 import os
+from typing import Tuple
 from unittest.mock import MagicMock
 
 import pytest
 import torch
 from numpy.testing import assert_almost_equal
 from torch import nn
-from torch.nn.functional import nll_loss
+from torch.nn.functional import mse_loss, nll_loss
 
 import ignite.distributed as idist
 from ignite.engine import State
@@ -314,3 +315,50 @@ def compute(self):
        (torch.rand(4, 10), torch.randint(0, 3, size=(4,))),
    ]
    evaluator.run(data)


class CustomMultiMSELoss(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(
        self, y_pred: Tuple[torch.Tensor, torch.Tensor], y_true: Tuple[torch.Tensor, torch.Tensor]
    ) -> torch.Tensor:
        a_true, b_true = y_true
        a_pred, b_pred = y_pred
        return mse_loss(a_pred, a_true) + mse_loss(b_pred, b_true)


class DummyLoss3(Loss):
    def __init__(self, loss_fn, expected_loss, output_transform=lambda x: x, skip_unrolling=False):
        super(DummyLoss3, self).__init__(loss_fn, output_transform=output_transform, skip_unrolling=skip_unrolling)
        self._expected_loss = expected_loss
        self._loss_fn = loss_fn

    def reset(self):
        pass

    def compute(self):
        pass

    def update(self, output):
        y_pred, y_true = output
        calculated_loss = self._loss_fn(y_pred=y_pred, y_true=y_true)
        assert calculated_loss == self._expected_loss


def test_skip_unrolling_loss():
    a_pred = torch.rand(8, 1)
    b_pred = torch.rand(8, 1)
    y_pred = [a_pred, b_pred]
    a_true = torch.rand(8, 1)
    b_true = torch.rand(8, 1)
    y_true = [a_true, b_true]

    multi_output_mse_loss = CustomMultiMSELoss()
    expected_loss = multi_output_mse_loss(y_pred=y_pred, y_true=y_true)

    loss_metric = DummyLoss3(loss_fn=multi_output_mse_loss, expected_loss=expected_loss, skip_unrolling=True)
    state = State(output=(y_pred, y_true))
    engine = MagicMock(state=state)
    loss_metric.iteration_completed(engine)
30 changes: 30 additions & 0 deletions tests/ignite/metrics/test_metric.py
@@ -1416,3 +1416,33 @@ def wrapper(x, **kwargs):
        assert (output == expected).all(), (output, expected)
    else:
        assert output == expected, (output, expected)


class DummyMetric5(Metric):
    def __init__(self, true_output, output_transform=lambda x: x, skip_unrolling=False):
        super(DummyMetric5, self).__init__(output_transform=output_transform, skip_unrolling=skip_unrolling)
        self.true_output = true_output

    def reset(self):
        pass

    def compute(self):
        pass

    def update(self, output):
        assert output == self.true_output


def test_skip_unrolling():
    # y_pred and y are outputs received from a multi-output model
    a_pred = torch.rand(8, 1)
    b_pred = torch.rand(8, 1)
    y_pred = [a_pred, b_pred]
    a_true = torch.rand(8, 1)
    b_true = torch.rand(8, 1)
    y_true = [a_true, b_true]

    metric = DummyMetric5(true_output=(y_pred, y_true), skip_unrolling=True)
    state = State(output=(y_pred, y_true))
    engine = MagicMock(state=state)
    metric.iteration_completed(engine)
