diff --git a/ignite/contrib/metrics/__init__.py b/ignite/contrib/metrics/__init__.py
index e51efe469cac..a9163ea6a737 100644
--- a/ignite/contrib/metrics/__init__.py
+++ b/ignite/contrib/metrics/__init__.py
@@ -1,5 +1,6 @@
 import ignite.contrib.metrics.regression
 from ignite.contrib.metrics.average_precision import AveragePrecision
+from ignite.contrib.metrics.cohen_kappa import CohenKappa
 from ignite.contrib.metrics.gpu_info import GpuInfo
 from ignite.contrib.metrics.precision_recall_curve import PrecisionRecallCurve
 from ignite.contrib.metrics.roc_auc import ROC_AUC, RocCurve
diff --git a/ignite/contrib/metrics/cohen_kappa.py b/ignite/contrib/metrics/cohen_kappa.py
new file mode 100644
index 000000000000..289b844e2eee
--- /dev/null
+++ b/ignite/contrib/metrics/cohen_kappa.py
@@ -0,0 +1,75 @@
+from typing import Callable, Optional, Union
+
+import torch
+
+from ignite.metrics import EpochMetric
+
+
+class CohenKappa(EpochMetric):
+    """Compute different types of Cohen's Kappa: non-weighted, linear or quadratic.
+    Accumulates predictions and ground-truth labels over an epoch and applies
+    `sklearn.metrics.cohen_kappa_score
+    <https://scikit-learn.org/stable/modules/generated/sklearn.metrics.cohen_kappa_score.html>`_ .
+
+    Args:
+        output_transform: a callable that is used to transform the
+            :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
+            form expected by the metric. This can be useful if, for example, you have a multi-output model and
+            you want to compute the metric with respect to one of the outputs.
+        weights: a string defining the weighting type of Cohen's Kappa: ``None`` for non-weighted,
+            ``"linear"`` or ``"quadratic"``. Default, None.
+        check_compute_fn: Default False. If True, `cohen_kappa_score
+            <https://scikit-learn.org/stable/modules/generated/sklearn.metrics.cohen_kappa_score.html>`_
+            is run on the first batch of data to ensure there are
+            no issues. User will be warned in case there are any issues computing the function.
+        device: optional device specification for internal storage.
+
+    .. code-block:: python
+
+        def activated_output_transform(output):
+            y_pred, y = output
+            return y_pred, y
+
+        # weights can be None, "linear" or "quadratic"
+        cohen_kappa = CohenKappa(activated_output_transform, weights="linear")
+
+    """
+
+    def __init__(
+        self,
+        output_transform: Callable = lambda x: x,
+        weights: Optional[str] = None,
+        check_compute_fn: bool = False,
+        device: Union[str, torch.device] = torch.device("cpu"),
+    ):
+
+        try:
+            from sklearn.metrics import cohen_kappa_score
+        except ImportError:
+            raise RuntimeError("This contrib module requires scikit-learn to be installed.")
+
+        if weights not in (None, "linear", "quadratic"):
+            raise ValueError("Kappa Weighting type must be None or linear or quadratic.")
+
+        # initialize weights
+        self.weights = weights
+
+        self.cohen_kappa_compute = self.get_cohen_kappa_fn()
+
+        super(CohenKappa, self).__init__(
+            self.cohen_kappa_compute,
+            output_transform=output_transform,
+            check_compute_fn=check_compute_fn,
+            device=device,
+        )
+
+    def get_cohen_kappa_fn(self) -> Callable[[torch.Tensor, torch.Tensor], float]:
+        from sklearn.metrics import cohen_kappa_score
+
+        # EpochMetric invokes this as compute_fn(predictions, targets)
+        def wrapper(y_preds: torch.Tensor, y_targets: torch.Tensor) -> float:
+            y_pred = y_preds.cpu().numpy()
+            y_true = y_targets.cpu().numpy()
+            return cohen_kappa_score(y_true, y_pred, weights=self.weights)
+
+        return wrapper
diff --git a/tests/ignite/contrib/metrics/test_cohen_kappa.py b/tests/ignite/contrib/metrics/test_cohen_kappa.py
new file mode 100644
index 000000000000..26ac5f09ab23
--- /dev/null
+++ b/tests/ignite/contrib/metrics/test_cohen_kappa.py
@@ -0,0 +1,198 @@
+import os
+
+import numpy as np
+import pytest
+import torch
+from sklearn.metrics import cohen_kappa_score
+
+import ignite.distributed as idist
+from ignite.contrib.metrics import CohenKappa
+from ignite.engine import Engine
+from ignite.exceptions import NotComputableError
+
+
+def test_no_update():
+    ck = CohenKappa()
+
+    with pytest.raises(
+        NotComputableError, match=r"EpochMetric must have at least one example before it can be computed"
+    ):
+        ck.compute()
+
+
+def test_input_types():
+    ck = CohenKappa()
+    ck.reset()
+    output1 = (torch.rand(4, 3), torch.randint(0, 2, size=(4, 3), dtype=torch.long))
+    ck.update(output1)
+
+    with pytest.raises(ValueError, match=r"Incoherent types between input y_pred and stored predictions"):
+        ck.update((torch.randint(0, 5, size=(4, 3)), torch.randint(0, 2, size=(4, 3))))
+
+    with pytest.raises(ValueError, match=r"Incoherent types between input y and stored targets"):
+        ck.update((torch.rand(4, 3), torch.randint(0, 2, size=(4, 3)).to(torch.int32)))
+
+
+def test_check_shape():
+    ck = CohenKappa()
+
+    with pytest.raises(ValueError, match=r"Predictions should be of shape"):
+        ck._check_shape((torch.randint(0, 2, size=(10, 1, 5, 12)).long(), torch.randint(0, 2, size=(10, 5, 6)).long()))
+
+    with pytest.raises(ValueError, match=r"Predictions should be of shape"):
+        ck._check_shape((torch.randint(0, 2, size=(10, 1, 6)).long(), torch.randint(0, 2, size=(10, 5, 6)).long()))
+
+    with pytest.raises(ValueError, match=r"Targets should be of shape"):
+        ck._check_shape((torch.randint(0, 2, size=(10, 1)).long(), torch.randint(0, 2, size=(10, 5, 2)).long()))
+
+
+@pytest.mark.parametrize("weights", [None, "linear", "quadratic"])
+def test_cohen_kappa_all_weights(weights):
+    size = 100
+    np_y_pred = np.random.randint(0, 2, size=(size, 1), dtype=np.int64)
+    np_y = np.random.randint(0, 2, size=(size, 1), dtype=np.int64)
+    np_ck = cohen_kappa_score(np_y, np_y_pred, weights=weights)
+
+    ck_metric = CohenKappa(weights=weights)
+    y_pred = torch.from_numpy(np_y_pred)
+    y = torch.from_numpy(np_y)
+
+    ck_metric.reset()
+    ck_metric.update((y_pred, y))
+    ck = ck_metric.compute()
+
+    assert ck == pytest.approx(np_ck)
+
+
+def test_cohen_kappa_wrong_weights_type():
+    with pytest.raises(ValueError, match=r"Kappa Weighting type must be"):
+        ck = CohenKappa(weights=7)
+
+    with pytest.raises(ValueError, match=r"Kappa Weighting type must be"):
+        ck = CohenKappa(weights="dd")
+
+
+@pytest.mark.parametrize("weights", [None, "linear", "quadratic"])
+def test_cohen_kappa_all_weights_with_output_transform(weights):
+    np.random.seed(1)
+    size = 100
+    np_y_pred = np.random.randint(0, 2, size=(size, 1), dtype=np.int64)
+    np_y = np.zeros((size,), dtype=np.int64)
+    np_y[size // 2 :] = 1
+    np.random.shuffle(np_y)
+
+    ck_value_sk = cohen_kappa_score(np_y, np_y_pred, weights=weights)
+
+    batch_size = 10
+
+    def update_fn(engine, batch):
+        idx = (engine.state.iteration - 1) * batch_size
+        y_true_batch = np_y[idx : idx + batch_size]
+        y_pred_batch = np_y_pred[idx : idx + batch_size]
+        return idx, torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch)
+
+    engine = Engine(update_fn)
+
+    ck_metric = CohenKappa(output_transform=lambda x: (x[1], x[2]), weights=weights)
+    ck_metric.attach(engine, "cohen_kappa")
+
+    data = list(range(size // batch_size))
+    ck_value = engine.run(data, max_epochs=1).metrics["cohen_kappa"]
+
+    assert ck_value == pytest.approx(ck_value_sk)
+
+
+def _test_distrib_compute(device):
+    rank = idist.get_rank()
+
+    def _test(metric_device):
+        metric_device = torch.device(metric_device)
+        ck_metric = CohenKappa(device=metric_device)
+
+        torch.manual_seed(10 + rank)
+
+        y_pred = torch.randint(0, 2, size=(100, 1), device=device)
+        y = torch.randint(0, 2, size=(100, 1), device=device)
+
+        ck_metric.update((y_pred, y))
+
+        # gather y_pred, y across all ranks
+        y_pred = idist.all_gather(y_pred)
+        y = idist.all_gather(y)
+
+        np_y_pred = y_pred.cpu().numpy()
+        np_y = y.cpu().numpy()
+
+        np_ck = cohen_kappa_score(np_y, np_y_pred)
+
+        res = ck_metric.compute()
+        assert res == pytest.approx(np_ck)
+
+    for _ in range(3):
+        _test("cpu")
+        if device.type != "xla":
+            _test(idist.device())
+
+
+@pytest.mark.distributed
+@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
+@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
+def test_distrib_gpu(distributed_context_single_node_nccl):
+    device = torch.device(f"cuda:{distributed_context_single_node_nccl['local_rank']}")
+    _test_distrib_compute(device)
+
+
+@pytest.mark.distributed
+@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
+def test_distrib_cpu(distributed_context_single_node_gloo):
+
+    device = torch.device("cpu")
+    _test_distrib_compute(device)
+
+
+@pytest.mark.distributed
+@pytest.mark.skipif(not idist.has_hvd_support, reason="Skip if no Horovod dist support")
+@pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc")
+def test_distrib_hvd(gloo_hvd_executor):
+
+    device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")
+    nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count()
+
+    gloo_hvd_executor(_test_distrib_compute, (device,), np=nproc, do_init=True)
+
+
+@pytest.mark.multinode_distributed
+@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
+@pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
+def test_multinode_distrib_cpu(distributed_context_multi_node_gloo):
+    device = torch.device("cpu")
+    _test_distrib_compute(device)
+
+
+@pytest.mark.multinode_distributed
+@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
+@pytest.mark.skipif("GPU_MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
+def test_multinode_distrib_gpu(distributed_context_multi_node_nccl):
+    device = torch.device(f"cuda:{distributed_context_multi_node_nccl['local_rank']}")
+    _test_distrib_compute(device)
+
+
+@pytest.mark.tpu
+@pytest.mark.skipif("NUM_TPU_WORKERS" in os.environ, reason="Skip if NUM_TPU_WORKERS is in env vars")
+@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
+def test_distrib_single_device_xla():
+    device = idist.device()
+    _test_distrib_compute(device)
+
+
+def _test_distrib_xla_nprocs(index):
+    device = idist.device()
+    _test_distrib_compute(device)
+
+
+@pytest.mark.tpu
+@pytest.mark.skipif("NUM_TPU_WORKERS" not in os.environ, reason="Skip if no NUM_TPU_WORKERS in env vars")
+@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
+def test_distrib_xla_nprocs(xmp_executor):
+    n = int(os.environ["NUM_TPU_WORKERS"])
+    xmp_executor(_test_distrib_xla_nprocs, args=(), nprocs=n)
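Below is a minimal usage sketch for the metric added in this diff, in the spirit of the docstring example. It is illustrative only: `eval_step` and the random `data` are hypothetical stand-ins for a real evaluator and dataset.

```python
import torch

from ignite.contrib.metrics import CohenKappa
from ignite.engine import Engine


def eval_step(engine, batch):
    # A real evaluator would run a model here; this passes the batch through.
    y_pred, y = batch
    return y_pred, y


evaluator = Engine(eval_step)

# Quadratic weighting penalizes larger ordinal disagreements more heavily.
ck = CohenKappa(weights="quadratic")
ck.attach(evaluator, "cohen_kappa")

# Hard 0/1 predictions and binary targets, shaped (N, 1) as in the tests above.
data = [(torch.randint(0, 2, (10, 1)), torch.randint(0, 2, (10, 1))) for _ in range(5)]
state = evaluator.run(data, max_epochs=1)
print(state.metrics["cohen_kappa"])
```

And for reference, a quick check of what `sklearn.metrics.cohen_kappa_score` computes in the unweighted case, kappa = (p_o - p_e) / (1 - p_e), where p_o is the observed agreement and p_e the agreement expected by chance. This is standard Cohen's kappa, not an API introduced by this diff:

```python
import numpy as np
from sklearn.metrics import cohen_kappa_score

y_true = np.array([0, 0, 1, 1, 1, 0])
y_pred = np.array([0, 1, 1, 1, 0, 0])

p_o = np.mean(y_true == y_pred)  # observed agreement: 4/6
p_e = sum(np.mean(y_true == c) * np.mean(y_pred == c) for c in (0, 1))  # chance agreement: 1/2
kappa = (p_o - p_e) / (1 - p_e)  # (2/3 - 1/2) / (1/2) = 1/3

assert np.isclose(kappa, cohen_kappa_score(y_true, y_pred))
```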