Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ignite/contrib/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import ignite.contrib.metrics.regression
from ignite.contrib.metrics.average_precision import AveragePrecision
from ignite.contrib.metrics.cohen_kappa import CohenKappa
from ignite.contrib.metrics.gpu_info import GpuInfo
from ignite.contrib.metrics.precision_recall_curve import PrecisionRecallCurve
from ignite.contrib.metrics.roc_auc import ROC_AUC, RocCurve
75 changes: 75 additions & 0 deletions ignite/contrib/metrics/cohen_kappa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from typing import Callable, Optional, Union

import torch

from ignite.metrics import EpochMetric


class CohenKappa(EpochMetric):
"""Compute different types of Cohen's Kappa: Non-Wieghted, Linear, Quadratic.
Accumulating predictions and the ground-truth during an epoch and applying
`sklearn.metrics.cohen_kappa_score <https://scikit-learn.org/stable/modules/
generated/sklearn.metrics.cohen_kappa_score.html>`_ .

Args:
output_transform: a callable that is used to transform the
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
weights: a string is used to define the type of Cohen's Kappa whether Non-Weighted or Linear
or Quadratic. Default, None.
check_compute_fn: Default False. If True, `cohen_kappa_score
<https://scikit-learn.org/stable/modules/generated/sklearn.metrics.cohen_kappa_score.html>`_
is run on the first batch of data to ensure there are
no issues. User will be warned in case there are any issues computing the function.
device: optional device specification for internal storage.

.. code-block:: python

def activated_output_transform(output):
y_pred, y = output
return y_pred, y

weights = None or linear or quadratic

cohen_kappa = CohenKappa(activated_output_transform, weights)

"""

def __init__(
self,
output_transform: Callable = lambda x: x,
weights: Optional[str] = None,
check_compute_fn: bool = False,
device: Union[str, torch.device] = torch.device("cpu"),
):

try:
from sklearn.metrics import cohen_kappa_score
except ImportError:
raise RuntimeError("This contrib module requires sklearn to be installed.")

if weights not in (None, "linear", "quadratic"):
raise ValueError("Kappa Weighting type must be None or linear or quadratic.")

# initalize weights
self.weights = weights

self.cohen_kappa_compute = self.get_cohen_kappa_fn()

super(CohenKappa, self).__init__(
self.cohen_kappa_compute,
output_transform=output_transform,
check_compute_fn=check_compute_fn,
device=device,
)

def get_cohen_kappa_fn(self) -> Callable[[torch.Tensor, torch.Tensor], float]:
from sklearn.metrics import cohen_kappa_score

def wrapper(y_targets: torch.Tensor, y_preds: torch.Tensor) -> float:
y_true = y_targets.cpu().numpy()
y_pred = y_preds.cpu().numpy()
return cohen_kappa_score(y_true, y_pred, weights=self.weights)

return wrapper
198 changes: 198 additions & 0 deletions tests/ignite/contrib/metrics/test_cohen_kappa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
import os

import numpy as np
import pytest
import torch
from sklearn.metrics import cohen_kappa_score

import ignite.distributed as idist
from ignite.contrib.metrics import CohenKappa
from ignite.engine import Engine
from ignite.exceptions import NotComputableError


def test_no_update():
ck = CohenKappa()

with pytest.raises(
NotComputableError, match=r"EpochMetric must have at least one example before it can be computed"
):
ck.compute()


def test_input_types():
ck = CohenKappa()
ck.reset()
output1 = (torch.rand(4, 3), torch.randint(0, 2, size=(4, 3), dtype=torch.long))
ck.update(output1)

with pytest.raises(ValueError, match=r"Incoherent types between input y_pred and stored predictions"):
ck.update((torch.randint(0, 5, size=(4, 3)), torch.randint(0, 2, size=(4, 3))))

with pytest.raises(ValueError, match=r"Incoherent types between input y and stored targets"):
ck.update((torch.rand(4, 3), torch.randint(0, 2, size=(4, 3)).to(torch.int32)))


def test_check_shape():
ck = CohenKappa()

with pytest.raises(ValueError, match=r"Predictions should be of shape"):
ck._check_shape((torch.randint(0, 2, size=(10, 1, 5, 12)).long(), torch.randint(0, 2, size=(10, 5, 6)).long()))

with pytest.raises(ValueError, match=r"Predictions should be of shape"):
ck._check_shape((torch.randint(0, 2, size=(10, 1, 6)).long(), torch.randint(0, 2, size=(10, 5, 6)).long()))

with pytest.raises(ValueError, match=r"Targets should be of shape"):
ck._check_shape((torch.randint(0, 2, size=(10, 1)).long(), torch.randint(0, 2, size=(10, 5, 2)).long()))


@pytest.mark.parametrize("weights", [None, "linear", "quadratic"])
def test_cohen_kappa_all_weights(weights):
size = 100
np_y_pred = np.random.randint(0, 2, size=(size, 1), dtype=np.long)
np_y = np.random.randint(0, 2, size=(size, 1), dtype=np.long)
np_ck = cohen_kappa_score(np_y, np_y_pred)

ck_metric = CohenKappa(weights=weights)
y_pred = torch.from_numpy(np_y_pred)
y = torch.from_numpy(np_y)

ck_metric.reset()
ck_metric.update((y_pred, y))
ck = ck_metric.compute()

assert ck == pytest.approx(np_ck)


def test_cohen_kappa_wrong_weights_type():
with pytest.raises(ValueError, match=r"Kappa Weighting type must be"):
ck = CohenKappa(weights=7)

with pytest.raises(ValueError, match=r"Kappa Weighting type must be"):
ck = CohenKappa(weights="dd")


@pytest.mark.parametrize("weights", [None, "linear", "quadratic"])
def test_cohen_kappa_all_weights_with_output_transform(weights):
np.random.seed(1)
size = 100
np_y_pred = np.random.randint(0, 2, size=(size, 1), dtype=np.long)
np_y = np.zeros((size,), dtype=np.long)
np_y[size // 2 :] = 1
np.random.shuffle(np_y)

ck_value_sk = cohen_kappa_score(np_y, np_y_pred)

batch_size = 10

def update_fn(engine, batch):
idx = (engine.state.iteration - 1) * batch_size
y_true_batch = np_y[idx : idx + batch_size]
y_pred_batch = np_y_pred[idx : idx + batch_size]
return idx, torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch)

engine = Engine(update_fn)

ck_metric = CohenKappa(output_transform=lambda x: (x[1], x[2]), weights=weights)
ck_metric.attach(engine, "cohen_kappa")

data = list(range(size // batch_size))
ck_value = engine.run(data, max_epochs=1).metrics["cohen_kappa"]

assert ck_value == pytest.approx(ck_value_sk)


def _test_distrib_compute(device):
rank = idist.get_rank()

def _test(metric_device):
metric_device = torch.device(metric_device)
ck_metric = CohenKappa(device=metric_device)

torch.manual_seed(10 + rank)

y_pred = torch.randint(0, 2, size=(100, 1), device=device)
y = torch.randint(0, 2, size=(100, 1), device=device)

ck_metric.update((y_pred, y))

# gather y_pred, y
y_pred = idist.all_gather(y_pred)
y = idist.all_gather(y)

np_y_pred = y_pred.cpu().numpy()
np_y = y.cpu().numpy()

np_ck = cohen_kappa_score(np_y, np_y_pred)

res = ck_metric.compute()
assert res == pytest.approx(np_ck)

for _ in range(3):
_test("cpu")
if device.type != "xla":
_test(idist.device())


@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
def test_distrib_gpu(distributed_context_single_node_nccl):
device = torch.device(f"cuda:{distributed_context_single_node_nccl['local_rank']}")
_test_distrib_compute(device)


@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
def test_distrib_cpu(distributed_context_single_node_gloo):

device = torch.device("cpu")
_test_distrib_compute(device)


@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_hvd_support, reason="Skip if no Horovod dist support")
@pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc")
def test_distrib_hvd(gloo_hvd_executor):

device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")
nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count()

gloo_hvd_executor(_test_distrib_compute, (device,), np=nproc, do_init=True)


@pytest.mark.multinode_distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
def test_multinode_distrib_cpu(distributed_context_multi_node_gloo):
device = torch.device("cpu")
_test_distrib_compute(device)


@pytest.mark.multinode_distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif("GPU_MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
def test_multinode_distrib_gpu(distributed_context_multi_node_nccl):
device = torch.device(f"cuda:{distributed_context_multi_node_nccl['local_rank']}")
_test_distrib_compute(device)


@pytest.mark.tpu
@pytest.mark.skipif("NUM_TPU_WORKERS" in os.environ, reason="Skip if NUM_TPU_WORKERS is in env vars")
@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
def test_distrib_single_device_xla():
device = idist.device()
_test_distrib_compute(device)


def _test_distrib_xla_nprocs(index):
device = idist.device()
_test_distrib_compute(device)


@pytest.mark.tpu
@pytest.mark.skipif("NUM_TPU_WORKERS" not in os.environ, reason="Skip if no NUM_TPU_WORKERS in env vars")
@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
def test_distrib_xla_nprocs(xmp_executor):
n = int(os.environ["NUM_TPU_WORKERS"])
xmp_executor(_test_distrib_xla_nprocs, args=(), nprocs=n)