-
-
Notifications
You must be signed in to change notification settings - Fork 655
add cohen kappa #1690
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
add cohen kappa #1690
Changes from all commits
Commits
Show all changes
16 commits
Select commit
Hold shift + click to select a range
eef60c6
add cohen kappa
KickItLikeShika 5bc2f7a
add cohen kappa
KickItLikeShika 4603c1b
add cohen kappa
KickItLikeShika 32940dc
add updated cohen cappa
KickItLikeShika f27a840
Merge branch 'master' of https://github.com/pytorch/ignite into add-c…
KickItLikeShika 3037b3f
add cohen kappa and all tests
KickItLikeShika 20419c5
add cohen kappa and tests
KickItLikeShika 9f1fe62
update tests
KickItLikeShika 8907637
update tests
KickItLikeShika d2fb323
updated tests
KickItLikeShika 82086ec
updated tests
KickItLikeShika ff240fa
reformatting
KickItLikeShika cdfd9df
Merge branch 'master' into add-cohen-kappa
KickItLikeShika a1491b0
update tests
KickItLikeShika 2922271
update tests
KickItLikeShika 733fa83
Merge branch 'master' into add-cohen-kappa
vfdev-5 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
import ignite.contrib.metrics.regression | ||
from ignite.contrib.metrics.average_precision import AveragePrecision | ||
from ignite.contrib.metrics.cohen_kappa import CohenKappa | ||
from ignite.contrib.metrics.gpu_info import GpuInfo | ||
from ignite.contrib.metrics.precision_recall_curve import PrecisionRecallCurve | ||
from ignite.contrib.metrics.roc_auc import ROC_AUC, RocCurve |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
from typing import Callable, Optional, Union | ||
|
||
import torch | ||
|
||
from ignite.metrics import EpochMetric | ||
|
||
|
||
class CohenKappa(EpochMetric): | ||
"""Compute different types of Cohen's Kappa: Non-Wieghted, Linear, Quadratic. | ||
Accumulating predictions and the ground-truth during an epoch and applying | ||
`sklearn.metrics.cohen_kappa_score <https://scikit-learn.org/stable/modules/ | ||
generated/sklearn.metrics.cohen_kappa_score.html>`_ . | ||
|
||
Args: | ||
output_transform: a callable that is used to transform the | ||
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the | ||
form expected by the metric. This can be useful if, for example, you have a multi-output model and | ||
you want to compute the metric with respect to one of the outputs. | ||
weights: a string is used to define the type of Cohen's Kappa whether Non-Weighted or Linear | ||
or Quadratic. Default, None. | ||
check_compute_fn: Default False. If True, `cohen_kappa_score | ||
<https://scikit-learn.org/stable/modules/generated/sklearn.metrics.cohen_kappa_score.html>`_ | ||
is run on the first batch of data to ensure there are | ||
no issues. User will be warned in case there are any issues computing the function. | ||
device: optional device specification for internal storage. | ||
|
||
.. code-block:: python | ||
|
||
def activated_output_transform(output): | ||
y_pred, y = output | ||
return y_pred, y | ||
|
||
weights = None or linear or quadratic | ||
|
||
cohen_kappa = CohenKappa(activated_output_transform, weights) | ||
|
||
""" | ||
|
||
def __init__( | ||
self, | ||
output_transform: Callable = lambda x: x, | ||
weights: Optional[str] = None, | ||
check_compute_fn: bool = False, | ||
device: Union[str, torch.device] = torch.device("cpu"), | ||
): | ||
|
||
try: | ||
from sklearn.metrics import cohen_kappa_score | ||
except ImportError: | ||
raise RuntimeError("This contrib module requires sklearn to be installed.") | ||
|
||
if weights not in (None, "linear", "quadratic"): | ||
raise ValueError("Kappa Weighting type must be None or linear or quadratic.") | ||
|
||
# initalize weights | ||
self.weights = weights | ||
KickItLikeShika marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
self.cohen_kappa_compute = self.get_cohen_kappa_fn() | ||
|
||
super(CohenKappa, self).__init__( | ||
self.cohen_kappa_compute, | ||
output_transform=output_transform, | ||
check_compute_fn=check_compute_fn, | ||
device=device, | ||
) | ||
|
||
def get_cohen_kappa_fn(self) -> Callable[[torch.Tensor, torch.Tensor], float]: | ||
from sklearn.metrics import cohen_kappa_score | ||
|
||
def wrapper(y_targets: torch.Tensor, y_preds: torch.Tensor) -> float: | ||
y_true = y_targets.cpu().numpy() | ||
y_pred = y_preds.cpu().numpy() | ||
return cohen_kappa_score(y_true, y_pred, weights=self.weights) | ||
|
||
return wrapper |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,198 @@ | ||
import os | ||
|
||
import numpy as np | ||
import pytest | ||
import torch | ||
from sklearn.metrics import cohen_kappa_score | ||
|
||
import ignite.distributed as idist | ||
from ignite.contrib.metrics import CohenKappa | ||
from ignite.engine import Engine | ||
from ignite.exceptions import NotComputableError | ||
|
||
|
||
def test_no_update(): | ||
ck = CohenKappa() | ||
|
||
with pytest.raises( | ||
NotComputableError, match=r"EpochMetric must have at least one example before it can be computed" | ||
): | ||
ck.compute() | ||
|
||
|
||
def test_input_types(): | ||
ck = CohenKappa() | ||
ck.reset() | ||
output1 = (torch.rand(4, 3), torch.randint(0, 2, size=(4, 3), dtype=torch.long)) | ||
ck.update(output1) | ||
|
||
with pytest.raises(ValueError, match=r"Incoherent types between input y_pred and stored predictions"): | ||
ck.update((torch.randint(0, 5, size=(4, 3)), torch.randint(0, 2, size=(4, 3)))) | ||
|
||
with pytest.raises(ValueError, match=r"Incoherent types between input y and stored targets"): | ||
ck.update((torch.rand(4, 3), torch.randint(0, 2, size=(4, 3)).to(torch.int32))) | ||
|
||
|
||
def test_check_shape(): | ||
ck = CohenKappa() | ||
|
||
with pytest.raises(ValueError, match=r"Predictions should be of shape"): | ||
ck._check_shape((torch.randint(0, 2, size=(10, 1, 5, 12)).long(), torch.randint(0, 2, size=(10, 5, 6)).long())) | ||
|
||
with pytest.raises(ValueError, match=r"Predictions should be of shape"): | ||
ck._check_shape((torch.randint(0, 2, size=(10, 1, 6)).long(), torch.randint(0, 2, size=(10, 5, 6)).long())) | ||
|
||
with pytest.raises(ValueError, match=r"Targets should be of shape"): | ||
ck._check_shape((torch.randint(0, 2, size=(10, 1)).long(), torch.randint(0, 2, size=(10, 5, 2)).long())) | ||
|
||
|
||
@pytest.mark.parametrize("weights", [None, "linear", "quadratic"]) | ||
def test_cohen_kappa_all_weights(weights): | ||
size = 100 | ||
np_y_pred = np.random.randint(0, 2, size=(size, 1), dtype=np.long) | ||
np_y = np.random.randint(0, 2, size=(size, 1), dtype=np.long) | ||
np_ck = cohen_kappa_score(np_y, np_y_pred) | ||
|
||
ck_metric = CohenKappa(weights=weights) | ||
y_pred = torch.from_numpy(np_y_pred) | ||
y = torch.from_numpy(np_y) | ||
|
||
ck_metric.reset() | ||
ck_metric.update((y_pred, y)) | ||
ck = ck_metric.compute() | ||
|
||
assert ck == pytest.approx(np_ck) | ||
|
||
|
||
def test_cohen_kappa_wrong_weights_type(): | ||
with pytest.raises(ValueError, match=r"Kappa Weighting type must be"): | ||
ck = CohenKappa(weights=7) | ||
|
||
with pytest.raises(ValueError, match=r"Kappa Weighting type must be"): | ||
ck = CohenKappa(weights="dd") | ||
|
||
|
||
@pytest.mark.parametrize("weights", [None, "linear", "quadratic"]) | ||
def test_cohen_kappa_all_weights_with_output_transform(weights): | ||
np.random.seed(1) | ||
size = 100 | ||
np_y_pred = np.random.randint(0, 2, size=(size, 1), dtype=np.long) | ||
np_y = np.zeros((size,), dtype=np.long) | ||
np_y[size // 2 :] = 1 | ||
np.random.shuffle(np_y) | ||
|
||
ck_value_sk = cohen_kappa_score(np_y, np_y_pred) | ||
|
||
batch_size = 10 | ||
|
||
def update_fn(engine, batch): | ||
idx = (engine.state.iteration - 1) * batch_size | ||
y_true_batch = np_y[idx : idx + batch_size] | ||
y_pred_batch = np_y_pred[idx : idx + batch_size] | ||
return idx, torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch) | ||
|
||
engine = Engine(update_fn) | ||
|
||
ck_metric = CohenKappa(output_transform=lambda x: (x[1], x[2]), weights=weights) | ||
ck_metric.attach(engine, "cohen_kappa") | ||
|
||
data = list(range(size // batch_size)) | ||
ck_value = engine.run(data, max_epochs=1).metrics["cohen_kappa"] | ||
|
||
assert ck_value == pytest.approx(ck_value_sk) | ||
|
||
|
||
def _test_distrib_compute(device): | ||
rank = idist.get_rank() | ||
|
||
def _test(metric_device): | ||
metric_device = torch.device(metric_device) | ||
ck_metric = CohenKappa(device=metric_device) | ||
|
||
torch.manual_seed(10 + rank) | ||
|
||
y_pred = torch.randint(0, 2, size=(100, 1), device=device) | ||
y = torch.randint(0, 2, size=(100, 1), device=device) | ||
|
||
ck_metric.update((y_pred, y)) | ||
|
||
# gather y_pred, y | ||
y_pred = idist.all_gather(y_pred) | ||
y = idist.all_gather(y) | ||
|
||
np_y_pred = y_pred.cpu().numpy() | ||
np_y = y.cpu().numpy() | ||
|
||
np_ck = cohen_kappa_score(np_y, np_y_pred) | ||
|
||
res = ck_metric.compute() | ||
assert res == pytest.approx(np_ck) | ||
|
||
for _ in range(3): | ||
_test("cpu") | ||
if device.type != "xla": | ||
_test(idist.device()) | ||
|
||
|
||
@pytest.mark.distributed | ||
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") | ||
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU") | ||
def test_distrib_gpu(distributed_context_single_node_nccl): | ||
device = torch.device(f"cuda:{distributed_context_single_node_nccl['local_rank']}") | ||
_test_distrib_compute(device) | ||
|
||
|
||
@pytest.mark.distributed | ||
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") | ||
def test_distrib_cpu(distributed_context_single_node_gloo): | ||
|
||
device = torch.device("cpu") | ||
_test_distrib_compute(device) | ||
|
||
|
||
@pytest.mark.distributed | ||
@pytest.mark.skipif(not idist.has_hvd_support, reason="Skip if no Horovod dist support") | ||
@pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc") | ||
def test_distrib_hvd(gloo_hvd_executor): | ||
|
||
device = torch.device("cpu" if not torch.cuda.is_available() else "cuda") | ||
nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count() | ||
|
||
gloo_hvd_executor(_test_distrib_compute, (device,), np=nproc, do_init=True) | ||
|
||
|
||
@pytest.mark.multinode_distributed | ||
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") | ||
@pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed") | ||
def test_multinode_distrib_cpu(distributed_context_multi_node_gloo): | ||
device = torch.device("cpu") | ||
_test_distrib_compute(device) | ||
|
||
|
||
@pytest.mark.multinode_distributed | ||
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") | ||
@pytest.mark.skipif("GPU_MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed") | ||
def test_multinode_distrib_gpu(distributed_context_multi_node_nccl): | ||
device = torch.device(f"cuda:{distributed_context_multi_node_nccl['local_rank']}") | ||
_test_distrib_compute(device) | ||
|
||
|
||
@pytest.mark.tpu | ||
@pytest.mark.skipif("NUM_TPU_WORKERS" in os.environ, reason="Skip if NUM_TPU_WORKERS is in env vars") | ||
@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package") | ||
def test_distrib_single_device_xla(): | ||
device = idist.device() | ||
_test_distrib_compute(device) | ||
|
||
|
||
def _test_distrib_xla_nprocs(index): | ||
device = idist.device() | ||
_test_distrib_compute(device) | ||
|
||
|
||
@pytest.mark.tpu | ||
@pytest.mark.skipif("NUM_TPU_WORKERS" not in os.environ, reason="Skip if no NUM_TPU_WORKERS in env vars") | ||
@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package") | ||
def test_distrib_xla_nprocs(xmp_executor): | ||
n = int(os.environ["NUM_TPU_WORKERS"]) | ||
xmp_executor(_test_distrib_xla_nprocs, args=(), nprocs=n) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.