Add divergence metrics (#3232)
* add KLDivergence metric

* add JSDivergence

* fix variable name

* update docstring for JSDivergence

* Update ignite/metrics/js_divergence.py

Co-authored-by: vfdev <vfdev.5@gmail.com>

* Update ignite/metrics/kl_divergence.py

Co-authored-by: vfdev <vfdev.5@gmail.com>

* swap ground truth and prediction

* swap the definitions of p and q

---------

Co-authored-by: vfdev <vfdev.5@gmail.com>
kzkadc and vfdev-5 committed Apr 17, 2024
1 parent 2c79b7e commit f431e60
Showing 6 changed files with 512 additions and 0 deletions.
2 changes: 2 additions & 0 deletions docs/source/metrics.rst
@@ -353,6 +353,8 @@ Complete list of metrics
FID
CosineSimilarity
Entropy
KLDivergence
JSDivergence
AveragePrecision
CohenKappa
GpuInfo
4 changes: 4 additions & 0 deletions ignite/metrics/__init__.py
@@ -14,6 +14,8 @@
from ignite.metrics.gan.fid import FID
from ignite.metrics.gan.inception_score import InceptionScore
from ignite.metrics.gpu_info import GpuInfo
from ignite.metrics.js_divergence import JSDivergence
from ignite.metrics.kl_divergence import KLDivergence
from ignite.metrics.loss import Loss
from ignite.metrics.mean_absolute_error import MeanAbsoluteError
from ignite.metrics.mean_pairwise_distance import MeanPairwiseDistance
@@ -57,6 +59,8 @@
"InceptionScore",
"mIoU",
"JaccardIndex",
"JSDivergence",
"KLDivergence",
"MultiLabelConfusionMatrix",
"MutualInformation",
"Precision",
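With the additions above, both metrics are exported from the top-level `ignite.metrics` package. A minimal usage sketch (standalone, outside an `Engine`; the logit values are made up for illustration):

```python
import torch
from ignite.metrics import JSDivergence, KLDivergence

# hypothetical unnormalized logits of shape (B, C)
y_pred = torch.tensor([[0.0, 0.7, 1.1], [1.4, 1.6, 1.6]])
y_true = torch.tensor([[0.0, -2.3, -2.3], [1.4, 1.6, 1.6]])

kl, js = KLDivergence(), JSDivergence()
kl.update((y_pred, y_true))
js.update((y_pred, y_true))
print(kl.compute(), js.compute())  # mean divergences over the batch, in nats
```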
87 changes: 87 additions & 0 deletions ignite/metrics/js_divergence.py
@@ -0,0 +1,87 @@
import torch
import torch.nn.functional as F

from ignite.exceptions import NotComputableError
from ignite.metrics.kl_divergence import KLDivergence
from ignite.metrics.metric import sync_all_reduce

__all__ = ["JSDivergence"]


class JSDivergence(KLDivergence):
r"""Calculates the mean of `Jensen-Shannon (JS) divergence
<https://en.wikipedia.org/wiki/Jensen%E2%80%93Shannon_divergence>`_.
.. math::
\begin{align*}
D_\text{JS}(\mathbf{p}_i \| \mathbf{q}_i) &= \frac{1}{2} D_\text{KL}(\mathbf{p}_i \| \mathbf{m}_i)
+ \frac{1}{2} D_\text{KL}(\mathbf{q}_i \| \mathbf{m}_i), \\
\mathbf{m}_i &= \frac{1}{2}(\mathbf{p}_i + \mathbf{q}_i), \\
D_\text{KL}(\mathbf{p}_i \| \mathbf{q}_i) &= \sum_{c=1}^C p_{i,c} \log \frac{p_{i,c}}{q_{i,c}}.
\end{align*}
where :math:`\mathbf{p}_i` and :math:`\mathbf{q}_i` are the ground truth and prediction probability tensors,
and :math:`D_\text{KL}` is the KL-divergence.
- ``update`` must receive output of the form ``(y_pred, y)``.
- ``y_pred`` and ``y`` are expected to be the unnormalized logits for each class. :math:`(B, C)` (classification)
or :math:`(B, C, ...)` (e.g., image segmentation) shapes are allowed.
Args:
output_transform: a callable that is used to transform the
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
device: specifies which device updates are accumulated on. Setting the
metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
non-blocking. By default, CPU.
Examples:
To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
The output of the engine's ``process_function`` needs to be in the format of
``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``. If not, ``output_tranform`` can be added
to the metric to transform the output into the form expected by the metric.
For more information on how metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`.
.. include:: defaults.rst
:start-after: :orphan:
.. testcode::
metric = JSDivergence()
metric.attach(default_evaluator, 'js-div')
y_true = torch.tensor([
[ 0.0000, -2.3026, -2.3026],
[ 1.3863, 1.6094, 1.6094],
[ 0.0000, 0.6931, 1.0986]
])
y_pred = torch.tensor([
[ 0.0000, 0.6931, 1.0986],
[ 1.3863, 1.6094, 1.6094],
[ 0.0000, -2.3026, -2.3026]
])
state = default_evaluator.run([[y_pred, y_true]])
print(state.metrics['js-div'])
.. testoutput::
0.16266516844431558
"""

    def _update(self, y_pred: torch.Tensor, y: torch.Tensor) -> None:
        m_prob = (F.softmax(y_pred, dim=1) + F.softmax(y, dim=1)) / 2
        m_log = m_prob.log()
        y_pred = F.log_softmax(y_pred, dim=1)
        y = F.log_softmax(y, dim=1)
        self._sum_of_kl += (
            F.kl_div(m_log, y_pred, log_target=True, reduction="sum")
            + F.kl_div(m_log, y, log_target=True, reduction="sum")
        ).to(self._device)

    @sync_all_reduce("_sum_of_kl", "_num_examples")
    def compute(self) -> float:
        if self._num_examples == 0:
            raise NotComputableError("JSDivergence must have at least one example before it can be computed.")
        return self._sum_of_kl.item() / (self._num_examples * 2)
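For intuition, here is a short standalone sketch (not part of the commit) that computes the same quantity with plain tensor ops, mirroring `_update`/`compute` above: build the mixture `m`, take the two KL terms against it, and halve the sum. The logit values are made up for illustration.

```python
import torch
import torch.nn.functional as F

# hypothetical logits, shape (B, C)
y_pred = torch.tensor([[0.0, 0.6931, 1.0986], [1.3863, 1.6094, 1.6094]])
y_true = torch.tensor([[0.0, -2.3026, -2.3026], [1.3863, 1.6094, 1.6094]])

p = F.softmax(y_true, dim=1)  # ground-truth distribution
q = F.softmax(y_pred, dim=1)  # predicted distribution
m = (p + q) / 2               # mixture distribution

# D_JS = 0.5 * (D_KL(p || m) + D_KL(q || m)), then averaged over the batch
kl_pm = (p * (p / m).log()).sum(dim=1)
kl_qm = (q * (q / m).log()).sum(dim=1)
print(((kl_pm + kl_qm) / 2).mean().item())
```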
102 changes: 102 additions & 0 deletions ignite/metrics/kl_divergence.py
@@ -0,0 +1,102 @@
from typing import Sequence

import torch
import torch.nn.functional as F

from ignite.exceptions import NotComputableError
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce

__all__ = ["KLDivergence"]


class KLDivergence(Metric):
r"""Calculates the mean of `Kullback-Leibler (KL) divergence
<https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence>`_.
.. math:: D_\text{KL}(\mathbf{p}_i \| \mathbf{q}_i) = \sum_{c=1}^C p_{i,c} \log \frac{p_{i,c}}{q_{i,c}}
where :math:`\mathbf{p}_i` and :math:`\mathbf{q}_i` are the ground truth and prediction probability tensors.
- ``update`` must receive output of the form ``(y_pred, y)``.
- ``y_pred`` and ``y`` are expected to be the unnormalized logits for each class. :math:`(B, C)` (classification)
or :math:`(B, C, ...)` (e.g., image segmentation) shapes are allowed.
Args:
output_transform: a callable that is used to transform the
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
device: specifies which device updates are accumulated on. Setting the
metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
non-blocking. By default, CPU.
Examples:
To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
The output of the engine's ``process_function`` needs to be in the format of
``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``. If not, ``output_tranform`` can be added
to the metric to transform the output into the form expected by the metric.
For more information on how metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`.
.. include:: defaults.rst
:start-after: :orphan:
.. testcode::
metric = KLDivergence()
metric.attach(default_evaluator, 'kl-div')
y_true = torch.tensor([
[ 0.0000, -2.3026, -2.3026],
[ 1.3863, 1.6094, 1.6094],
[ 0.0000, 0.6931, 1.0986]
])
y_pred = torch.tensor([
[ 0.0000, 0.6931, 1.0986],
[ 1.3863, 1.6094, 1.6094],
[ 0.0000, -2.3026, -2.3026]
])
state = default_evaluator.run([[y_pred, y_true]])
print(state.metrics['kl-div'])
.. testoutput::
0.7220296859741211
"""

    _state_dict_all_req_keys = ("_sum_of_kl", "_num_examples")

    @reinit__is_reduced
    def reset(self) -> None:
        self._sum_of_kl = torch.tensor(0.0, device=self._device)
        self._num_examples = 0

    @reinit__is_reduced
    def update(self, output: Sequence[torch.Tensor]) -> None:
        y_pred, y = output[0].detach(), output[1].detach()
        if y_pred.shape != y.shape:
            raise ValueError(f"y_pred and y must be in the same shape, got {y_pred.shape} != {y.shape}.")

        if y_pred.ndim >= 3:
            num_classes = y_pred.shape[1]
            # (B, C, ...) -> (B, ..., C) -> (B*..., C)
            # treated as B*... independent predictions
            y_pred = y_pred.movedim(1, -1).reshape(-1, num_classes)
            y = y.movedim(1, -1).reshape(-1, num_classes)
        elif y_pred.ndim == 1:
            raise ValueError(f"y_pred must be in the shape of (B, C) or (B, C, ...), got {y_pred.shape}.")

        self._num_examples += y_pred.shape[0]
        self._update(y_pred, y)

    def _update(self, y_pred: torch.Tensor, y: torch.Tensor) -> None:
        y_pred = F.log_softmax(y_pred, dim=1)
        y = F.log_softmax(y, dim=1)
        kl_sum = F.kl_div(y_pred, y, log_target=True, reduction="sum")
        self._sum_of_kl += kl_sum.to(self._device)

    @sync_all_reduce("_sum_of_kl", "_num_examples")
    def compute(self) -> float:
        if self._num_examples == 0:
            raise NotComputableError("KLDivergence must have at least one example before it can be computed.")
        return self._sum_of_kl.item() / self._num_examples
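One subtle point worth flagging (likely what the "swap" entries in the commit message address): `torch.nn.functional.kl_div(input, target)` computes KL(target ‖ input), so `_update` passes the prediction log-probabilities as `input` and the ground truth as `target`, yielding :math:`D_\text{KL}(\mathbf{p} \| \mathbf{q})` with :math:`\mathbf{p}` the ground truth, as the docstring states. A standalone sketch of that equivalence, with made-up random logits:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
y_pred, y_true = torch.randn(4, 3), torch.randn(4, 3)  # hypothetical logits

log_q = F.log_softmax(y_pred, dim=1)  # prediction (q)
log_p = F.log_softmax(y_true, dim=1)  # ground truth (p)

# F.kl_div(input, target, log_target=True) sums target.exp() * (target - input),
# i.e. KL(target || input); passing (log_q, log_p) therefore gives D_KL(p || q).
kl_builtin = F.kl_div(log_q, log_p, log_target=True, reduction="sum")

# explicit D_KL(p || q) = sum_c p * (log p - log q)
kl_manual = (log_p.exp() * (log_p - log_q)).sum()

assert torch.allclose(kl_builtin, kl_manual)
print(kl_builtin.item() / y_pred.shape[0])  # per-example mean, as compute() returns
```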
159 changes: 159 additions & 0 deletions tests/ignite/metrics/test_js_divergence.py
@@ -0,0 +1,159 @@
from typing import Tuple

import numpy as np
import pytest
import torch
from scipy.spatial.distance import jensenshannon
from scipy.special import softmax
from torch import Tensor

import ignite.distributed as idist
from ignite.engine import Engine
from ignite.exceptions import NotComputableError
from ignite.metrics import JSDivergence


def scipy_js_div(np_y_pred: np.ndarray, np_y: np.ndarray) -> float:
    y_pred_prob = softmax(np_y_pred, axis=1)
    y_prob = softmax(np_y, axis=1)
    # jensenshannon computes the sqrt of the JS divergence
    js_mean = np.mean(np.square(jensenshannon(y_pred_prob, y_prob, axis=1)))
    return js_mean


def test_zero_sample():
    js_div = JSDivergence()
    with pytest.raises(
        NotComputableError, match=r"JSDivergence must have at least one example before it can be computed"
    ):
        js_div.compute()


def test_shape_mismatch():
    js_div = JSDivergence()
    y_pred = torch.tensor([[2.0, 3.0], [-2.0, 1.0]], dtype=torch.float)
    y = torch.tensor([[-2.0, 1.0]], dtype=torch.float)
    with pytest.raises(ValueError, match=r"y_pred and y must be in the same shape, got"):
        js_div.update((y_pred, y))


def test_invalid_shape():
    js_div = JSDivergence()
    y_pred = torch.tensor([2.0, 3.0], dtype=torch.float)
    y = torch.tensor([4.0, 5.0], dtype=torch.float)
    with pytest.raises(ValueError, match=r"y_pred must be in the shape of \(B, C\) or \(B, C, ...\), got"):
        js_div.update((y_pred, y))


@pytest.fixture(params=list(range(4)))
def test_case(request):
    return [
        (torch.randn((100, 10)), torch.rand((100, 10)), 1),
        (torch.rand((100, 500)), torch.randn((100, 500)), 1),
        # updated batches
        (torch.normal(0.0, 5.0, size=(100, 10)), torch.rand((100, 10)), 16),
        (torch.normal(5.0, 3.0, size=(100, 200)), torch.rand((100, 200)), 16),
        # image segmentation
        (torch.randn((100, 5, 32, 32)), torch.rand((100, 5, 32, 32)), 16),
        (torch.rand((100, 5, 224, 224)), torch.randn((100, 5, 224, 224)), 16),
    ][request.param]


@pytest.mark.parametrize("n_times", range(5))
def test_compute(n_times, test_case: Tuple[Tensor, Tensor, int]):
    y_pred, y, batch_size = test_case

    js_div = JSDivergence()

    js_div.reset()
    if batch_size > 1:
        n_iters = y.shape[0] // batch_size + 1
        for i in range(n_iters):
            idx = i * batch_size
            js_div.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
    else:
        js_div.update((y_pred, y))

    res = js_div.compute()

    np_y_pred = y_pred.numpy()
    np_y = y.numpy()

    np_res = scipy_js_div(np_y_pred, np_y)

    assert isinstance(res, float)
    assert pytest.approx(np_res, rel=1e-4) == res


def test_accumulator_detached():
    js_div = JSDivergence()

    y_pred = torch.tensor([[2.0, 3.0], [-2.0, 1.0]], dtype=torch.float)
    y = torch.tensor([[-2.0, 1.0], [2.0, 3.0]], dtype=torch.float)
    js_div.update((y_pred, y))

    assert not js_div._sum_of_kl.requires_grad


@pytest.mark.usefixtures("distributed")
class TestDistributed:
    def test_integration(self):
        tol = 1e-4
        n_iters = 100
        batch_size = 10
        n_dims = 100

        rank = idist.get_rank()
        torch.manual_seed(12 + rank)

        device = idist.device()
        metric_devices = [torch.device("cpu")]
        if device.type != "xla":
            metric_devices.append(device)

        for metric_device in metric_devices:
            y_true = torch.randn((n_iters * batch_size, n_dims)).float().to(device)
            y_preds = torch.normal(2.0, 3.0, size=(n_iters * batch_size, n_dims)).float().to(device)

            engine = Engine(
                lambda e, i: (
                    y_preds[i * batch_size : (i + 1) * batch_size],
                    y_true[i * batch_size : (i + 1) * batch_size],
                )
            )

            m = JSDivergence(device=metric_device)
            m.attach(engine, "js_div")

            data = list(range(n_iters))
            engine.run(data=data, max_epochs=1)

            y_preds = idist.all_gather(y_preds)
            y_true = idist.all_gather(y_true)

            assert "js_div" in engine.state.metrics
            res = engine.state.metrics["js_div"]

            y_true_np = y_true.cpu().numpy()
            y_preds_np = y_preds.cpu().numpy()
            true_res = scipy_js_div(y_preds_np, y_true_np)

            assert pytest.approx(true_res, rel=tol) == res

    def test_accumulator_device(self):
        device = idist.device()
        metric_devices = [torch.device("cpu")]
        if device.type != "xla":
            metric_devices.append(device)
        for metric_device in metric_devices:
            js_div = JSDivergence(device=metric_device)

            for dev in (js_div._device, js_div._sum_of_kl.device):
                assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"

            y_pred = torch.tensor([[2.0, 3.0], [-2.0, 1.0]]).float()
            y = torch.ones(2, 2).float()
            js_div.update((y_pred, y))

            for dev in (js_div._device, js_div._sum_of_kl.device):
                assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"
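The reference helper `scipy_js_div` above squares `scipy.spatial.distance.jensenshannon` because that function returns the JS *distance* (the square root of the divergence); with its default natural-log base the squared value is in nats, directly comparable with the metric. A standalone cross-check along the same lines (random made-up logits, not from the test suite):

```python
import numpy as np
import torch
from scipy.spatial.distance import jensenshannon
from scipy.special import softmax

from ignite.metrics import JSDivergence

torch.manual_seed(0)
y_pred, y_true = torch.randn(8, 5), torch.randn(8, 5)  # hypothetical logits

metric = JSDivergence()
metric.update((y_pred, y_true))

# square the JS distance to recover the divergence, then average over the batch
p = softmax(y_true.numpy(), axis=1)
q = softmax(y_pred.numpy(), axis=1)
reference = float(np.mean(jensenshannon(q, p, axis=1) ** 2))

assert abs(metric.compute() - reference) < 1e-4
```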