Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,8 @@ Complete list of metrics

.. autofunction:: mIoU

.. autofunction:: JaccardIndex

.. autoclass:: Loss

.. autoclass:: MeanAbsoluteError
Expand Down
3 changes: 2 additions & 1 deletion ignite/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from ignite.metrics.accumulation import Average, GeometricAverage, VariableAccumulation
from ignite.metrics.accuracy import Accuracy
from ignite.metrics.confusion_matrix import ConfusionMatrix, DiceCoefficient, IoU, mIoU
from ignite.metrics.confusion_matrix import ConfusionMatrix, DiceCoefficient, IoU, JaccardIndex, mIoU
from ignite.metrics.epoch_metric import EpochMetric
from ignite.metrics.fbeta import Fbeta
from ignite.metrics.frequency import Frequency
Expand Down Expand Up @@ -36,6 +36,7 @@
"GeometricAverage",
"IoU",
"mIoU",
"JaccardIndex",
"MultiLabelConfusionMatrix",
"Precision",
"PSNR",
Expand Down
31 changes: 30 additions & 1 deletion ignite/metrics/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce
from ignite.metrics.metrics_lambda import MetricsLambda

__all__ = ["ConfusionMatrix", "mIoU", "IoU", "DiceCoefficient", "cmAccuracy", "cmPrecision", "cmRecall"]
__all__ = ["ConfusionMatrix", "mIoU", "IoU", "DiceCoefficient", "cmAccuracy", "cmPrecision", "cmRecall", "JaccardIndex"]


class ConfusionMatrix(Metric):
Expand Down Expand Up @@ -323,3 +323,32 @@ def ignore_index_fn(dice_vector: torch.Tensor) -> torch.Tensor:
return MetricsLambda(ignore_index_fn, dice)
else:
return dice


def JaccardIndex(cm: ConfusionMatrix, ignore_index: Optional[int] = None) -> MetricsLambda:
    r"""Calculates the Jaccard Index using :class:`~ignite.metrics.ConfusionMatrix` metric.
    Implementation is based on :meth:`~ignite.metrics.IoU`.

    .. math:: \text{J}(A, B) = \frac{ \lvert A \cap B \rvert }{ \lvert A \cup B \rvert }

    Args:
        cm: instance of confusion matrix metric
        ignore_index: index to ignore, e.g. background index

    Returns:
        MetricsLambda

    Examples:

    .. code-block:: python

        train_evaluator = ...

        cm = ConfusionMatrix(num_classes=num_classes)
        JaccardIndex(cm, ignore_index=0).attach(train_evaluator, 'JaccardIndex')

        state = train_evaluator.run(train_dataset)
        # state.metrics['JaccardIndex'] -> tensor of shape (num_classes - 1, )

    """
    # The Jaccard index and intersection-over-union are the same quantity,
    # so this is a pure alias that delegates all work (including argument
    # validation) to IoU.
    return IoU(cm, ignore_index)
45 changes: 44 additions & 1 deletion tests/ignite/metrics/test_confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import ignite.distributed as idist
from ignite.exceptions import NotComputableError
from ignite.metrics import ConfusionMatrix, IoU, mIoU
from ignite.metrics import ConfusionMatrix, IoU, JaccardIndex, mIoU
from ignite.metrics.confusion_matrix import DiceCoefficient, cmAccuracy, cmPrecision, cmRecall

torch.manual_seed(12)
Expand Down Expand Up @@ -647,6 +647,49 @@ def _test_distrib_accumulator_device(device):
), f"{type(cm.confusion_matrix.device)}:{cm._num_correct.device} vs {type(metric_device)}:{metric_device}"


def test_jaccard_index():
    # Verifies that JaccardIndex (an alias of IoU) matches a reference
    # per-class Jaccard computation, honours ignore_index, and rejects
    # confusion matrices built with an unsupported `average` mode.
    def _test(average=None):

        y_true, y_pred = get_y_true_y_pred()
        th_y_true, th_y_logits = compute_th_y_true_y_logits(y_true, y_pred)

        # Reference values: per-class |A ∩ B| / |A ∪ B| computed directly
        # from the binary masks of each class index.
        true_res = []
        for klass in range(3):
            pred_mask = y_pred == klass
            gt_mask = y_true == klass
            overlap = gt_mask & pred_mask
            combined = gt_mask | pred_mask
            true_res.append(overlap.sum() / combined.sum())

        cm = ConfusionMatrix(num_classes=3, average=average)
        jaccard_index = JaccardIndex(cm)

        # Feed a single batch into the underlying confusion matrix.
        cm.update((th_y_logits, th_y_true))

        res = jaccard_index.compute().numpy()
        assert np.all(res == true_res)

        # With ignore_index set, the ignored class is dropped from the output.
        for ignore_index in range(3):
            cm = ConfusionMatrix(num_classes=3)
            jaccard_index_metric = JaccardIndex(cm, ignore_index=ignore_index)
            cm.update((th_y_logits, th_y_true))
            res = jaccard_index_metric.compute().numpy()
            true_res_ = true_res[:ignore_index] + true_res[ignore_index + 1 :]
            assert np.all(res == true_res_), f"{ignore_index}: {res} vs {true_res_}"

    _test()
    _test(average="samples")

    # Incompatible `average` mode must be rejected at construction time.
    with pytest.raises(ValueError, match=r"ConfusionMatrix should have average attribute either"):
        JaccardIndex(ConfusionMatrix(num_classes=3, average="precision"))


@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
Expand Down