From 3714ebfb06646da923270ea3dc67f097c7c45a7d Mon Sep 17 00:00:00 2001 From: HoJinLee <2hjc12@gmail.com> Date: Thu, 31 Aug 2023 23:48:28 +0900 Subject: [PATCH 1/5] baseline --- ignite/metrics/__init__.py | 2 + ignite/metrics/recsys/__init__.py | 5 + ignite/metrics/recsys/ndcg.py | 122 ++++++++++++++ tests/ignite/metrics/test_ndcg.py | 266 ++++++++++++++++++++++++++++++ 4 files changed, 395 insertions(+) create mode 100644 ignite/metrics/recsys/__init__.py create mode 100644 ignite/metrics/recsys/ndcg.py create mode 100644 tests/ignite/metrics/test_ndcg.py diff --git a/ignite/metrics/__init__.py b/ignite/metrics/__init__.py index d001436a3ad..63537f017e3 100644 --- a/ignite/metrics/__init__.py +++ b/ignite/metrics/__init__.py @@ -23,6 +23,7 @@ from ignite.metrics.running_average import RunningAverage from ignite.metrics.ssim import SSIM from ignite.metrics.top_k_categorical_accuracy import TopKCategoricalAccuracy +from ignite.metrics.recsys.ndcg import NDCG __all__ = [ "Metric", @@ -58,4 +59,5 @@ "Rouge", "RougeN", "RougeL", + "NDCG" ] diff --git a/ignite/metrics/recsys/__init__.py b/ignite/metrics/recsys/__init__.py new file mode 100644 index 00000000000..98320b910f3 --- /dev/null +++ b/ignite/metrics/recsys/__init__.py @@ -0,0 +1,5 @@ +from ignite.metrics.recsys.ndcg import NDCG + +__all__ = [ + "NDCG", +] \ No newline at end of file diff --git a/ignite/metrics/recsys/ndcg.py b/ignite/metrics/recsys/ndcg.py new file mode 100644 index 00000000000..d4c83dc032e --- /dev/null +++ b/ignite/metrics/recsys/ndcg.py @@ -0,0 +1,122 @@ +from typing import Callable, Optional, Sequence, Union + +import torch + +from ignite.exceptions import NotComputableError +from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce + +__all__ = ["NDCG"] + + +def _tie_averaged_dcg( + y_pred: torch.Tensor, + y_true: torch.Tensor, + discount_cumsum: torch.Tensor, + device: Union[str, torch.device] = torch.device("cpu"), +) -> torch.Tensor: + + _, inv, counts = torch.unique(-y_pred, return_inverse=True, return_counts=True) + ranked = torch.zeros(counts.shape[0]).to(device) + ranked.index_put_([inv], y_true, accumulate=True) + ranked /= counts + groups = torch.cumsum(counts, dim=-1) - 1 + discount_sums = torch.empty(counts.shape[0]).to(device) + discount_sums[0] = discount_cumsum[groups[0]] + discount_sums[1:] = torch.diff(discount_cumsum[groups]) + + return torch.sum(torch.mul(ranked, discount_sums)) + + +def _dcg_sample_scores( + y_pred: torch.Tensor, + y_true: torch.Tensor, + k: Optional[int] = None, + log_base: Union[int, float] = 2, + ignore_ties: bool = False, + device: Union[str, torch.device] = torch.device("cpu"), +) -> torch.Tensor: + + discount = torch.log(torch.tensor(log_base)) / torch.log(torch.arange(y_true.shape[1]) + 2) + discount = discount.to(device) + + if k is not None: + discount[k:] = 0.0 + + if ignore_ties: + ranking = torch.argsort(y_pred, descending=True) + ranked = y_true[torch.arange(ranking.shape[0]).reshape(-1, 1), ranking].to(device) + discounted_gains = torch.mm(ranked, discount.reshape(-1, 1)) + + else: + discount_cumsum = torch.cumsum(discount, dim=-1) + discounted_gains = torch.tensor( + [_tie_averaged_dcg(y_p, y_t, discount_cumsum, device) for y_p, y_t in zip(y_pred, y_true)], device=device + ) + + return discounted_gains + + +def _ndcg_sample_scores( + y_pred: torch.Tensor, + y_true: torch.Tensor, + k: Optional[int] = None, + log_base: Union[int, float] = 2, + ignore_ties: bool = False, +) -> torch.Tensor: + + device = y_true.device + gain = 
_dcg_sample_scores(y_pred, y_true, k=k, log_base=log_base, ignore_ties=ignore_ties, device=device) + if not ignore_ties: + gain = gain.unsqueeze(dim=-1) + normalizing_gain = _dcg_sample_scores(y_true, y_true, k=k, log_base=log_base, ignore_ties=True, device=device) + all_relevant = normalizing_gain != 0 + normalized_gain = gain[all_relevant] / normalizing_gain[all_relevant] + return normalized_gain + + +class NDCG(Metric): + def __init__( + self, + output_transform: Callable = lambda x: x, + device: Union[str, torch.device] = torch.device("cpu"), + k: Optional[int] = None, + log_base: Union[int, float] = 2, + exponential: bool = False, + ignore_ties: bool = False, + ): + + if log_base == 1 or log_base <= 0: + raise ValueError(f"Argument log_base should positive and not equal one,but got {log_base}") + self.log_base = log_base + self.k = k + self.exponential = exponential + self.ignore_ties = ignore_ties + super(NDCG, self).__init__(output_transform=output_transform, device=device) + + @reinit__is_reduced + def reset(self) -> None: + + self.num_examples = 0 + self.ndcg = torch.tensor(0.0, device=self._device) + + @reinit__is_reduced + def update(self, output: Sequence[torch.Tensor]) -> None: + + y_pred, y_true = output[0].detach(), output[1].detach() + + y_pred = y_pred.to(torch.float32).to(self._device) + y_true = y_true.to(torch.float32).to(self._device) + + if self.exponential: + y_true = 2 ** y_true - 1 + + gain = _ndcg_sample_scores(y_pred, y_true, k=self.k, log_base=self.log_base, ignore_ties=self.ignore_ties) + self.ndcg += torch.sum(gain) + self.num_examples += y_pred.shape[0] + + @sync_all_reduce("ndcg", "num_examples") + def compute(self) -> float: + if self.num_examples == 0: + raise NotComputableError("NGCD must have at least one example before it can be computed.") + + return (self.ndcg / self.num_examples).item() \ No newline at end of file diff --git a/tests/ignite/metrics/test_ndcg.py b/tests/ignite/metrics/test_ndcg.py new file mode 100644 index 00000000000..b8acf277622 --- /dev/null +++ b/tests/ignite/metrics/test_ndcg.py @@ -0,0 +1,266 @@ +import os + +import numpy as np +import pytest +import torch +from sklearn.metrics import ndcg_score +from sklearn.metrics._ranking import _dcg_sample_scores + +import ignite.distributed as idist +from ignite.engine import Engine + +from ignite.exceptions import NotComputableError +from ignite.metrics.recsys.ndcg import NDCG + + +@pytest.fixture(params=[item for item in range(6)]) +def test_case(request): + + return [ + (torch.tensor([[3.7, 4.8, 3.9, 4.3, 4.9]]), torch.tensor([[2.9, 5.6, 3.8, 7.9, 6.2]])), + ( + torch.tensor([[3.7, 3.7, 3.7, 3.7, 3.7], [3.7, 3.7, 3.7, 3.7, 3.9]]), + torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0], [1.0, 2.0, 3.0, 4.0, 5.0]]), + ), + ][request.param % 2] + + +@pytest.mark.parametrize("k", [None, 2, 3]) +@pytest.mark.parametrize("exponential", [True, False]) +@pytest.mark.parametrize("ignore_ties, replacement", [(True, False), (False, True), (False, False)]) +def test_output_cpu(test_case, k, exponential, ignore_ties, replacement): + + device = "cpu" + y_pred_distribution, y_true = test_case + + y_pred = torch.multinomial(y_pred_distribution, 5, replacement=replacement) + + ndcg = NDCG(k=k, device=device, exponential=exponential, ignore_ties=ignore_ties) + ndcg.update([y_pred, y_true]) + result_ignite = ndcg.compute() + + if exponential: + y_true = 2 ** y_true - 1 + + result_sklearn = ndcg_score(y_true.numpy(), y_pred.numpy(), k=k, ignore_ties=ignore_ties) + + np.testing.assert_allclose(np.array(result_ignite), 
result_sklearn, rtol=2e-6) + + +@pytest.mark.parametrize("k", [None, 2, 3]) +@pytest.mark.parametrize("exponential", [True, False]) +@pytest.mark.parametrize("ignore_ties, replacement", [(True, False), (False, True), (False, False)]) +@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU") +def test_output_cuda(test_case, k, exponential, ignore_ties, replacement): + + device = "cuda" + y_pred_distribution, y_true = test_case + + y_pred = torch.multinomial(y_pred_distribution, 5, replacement=replacement) + + y_pred = y_pred.to(device) + y_true = y_true.to(device) + + ndcg = NDCG(k=k, device=device, exponential=exponential, ignore_ties=ignore_ties) + ndcg.update([y_pred, y_true]) + result_ignite = ndcg.compute() + + if exponential: + y_true = 2 ** y_true - 1 + + result_sklearn = ndcg_score(y_true.cpu().numpy(), y_pred.cpu().numpy(), k=k, ignore_ties=ignore_ties) + + np.testing.assert_allclose(np.array(result_ignite), result_sklearn, rtol=2e-6) + + +def test_reset(): + + y_true = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]]) + y_pred = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5]]) + ndcg = NDCG() + ndcg.update([y_pred, y_true]) + ndcg.reset() + + with pytest.raises(NotComputableError, match=r"NGCD must have at least one example before it can be computed."): + ndcg.compute() + + +def _ndcg_sample_scores(y_true, y_score, k=None, ignore_ties=False): + + gain = _dcg_sample_scores(y_true, y_score, k, ignore_ties=ignore_ties) + normalizing_gain = _dcg_sample_scores(y_true, y_true, k, ignore_ties=True) + all_irrelevant = normalizing_gain == 0 + gain[all_irrelevant] = 0 + gain[~all_irrelevant] /= normalizing_gain[~all_irrelevant] + return gain + + +@pytest.mark.parametrize("log_base", [2, 3, 10]) +def test_log_base(log_base): + def ndcg_score_with_log_base(y_true, y_score, *, k=None, sample_weight=None, ignore_ties=False, log_base=2): + + gain = _ndcg_sample_scores(y_true, y_score, k=k, ignore_ties=ignore_ties) + return np.average(gain, weights=sample_weight) + + y_true = torch.tensor([[3.7, 4.8, 3.9, 4.3, 4.9]]) + y_pred = torch.tensor([[2.9, 5.6, 3.8, 7.9, 6.2]]) + + ndcg = NDCG(log_base=log_base) + ndcg.update([y_pred, y_true]) + + result_ignite = ndcg.compute() + result_sklearn = ndcg_score_with_log_base(y_true.numpy(), y_pred.numpy(), log_base=log_base) + + np.testing.assert_allclose(np.array(result_ignite), result_sklearn, rtol=2e-6) + + +def test_update(test_case): + + y_pred, y_true = test_case + + y_pred = y_pred + y_true = y_true + + y1_pred = torch.multinomial(y_pred, 5, replacement=True) + y1_true = torch.multinomial(y_true, 5, replacement=True) + + y2_pred = torch.multinomial(y_pred, 5, replacement=True) + y2_true = torch.multinomial(y_true, 5, replacement=True) + + y_pred_combined = torch.cat((y1_pred, y2_pred)) + y_true_combined = torch.cat((y1_true, y2_true)) + + ndcg = NDCG() + + ndcg.update([y1_pred, y1_true]) + ndcg.update([y2_pred, y2_true]) + + result_ignite = ndcg.compute() + + result_sklearn = ndcg_score(y_true_combined.numpy(), y_pred_combined.numpy()) + + np.testing.assert_allclose(np.array(result_ignite), result_sklearn, rtol=2e-6) + + +def _test_distrib_output(device): + + rank = idist.get_rank() + + def _test(n_epochs, metric_device): + + metric_device = torch.device(metric_device) + + n_iters = 5 + batch_size = 8 + n_items = 5 + + torch.manual_seed(12 + rank) + + y_true = torch.rand((n_iters * batch_size, n_items)).to(device) + y_preds = torch.rand((n_iters * batch_size, n_items)).to(device) + + def update(_, i): + return ( + [v for v in y_preds[i * batch_size 
: (i + 1) * batch_size, ...]], + [v for v in y_true[i * batch_size : (i + 1) * batch_size]], + ) + + engine = Engine(update) + + ndcg = NDCG(device=metric_device) + ndcg.attach(engine, "ndcg") + + data = list(range(n_iters)) + engine.run(data=data, max_epochs=n_epochs) + + y_true = idist.all_gather(y_true) + y_preds = idist.all_gather(y_preds) + + assert ( + ndcg._device == metric_device + ), f"{type(ndcg._device)}:{ndcg._device} vs {type(metric_device)}:{metric_device}" + + assert "ndcg" in engine.state.metrics + res = engine.state.metrics["ndcg"] + if isinstance(res, torch.Tensor): + res = res.cpu().numpy() + + true_res = ndcg_score(y_true.cpu().numpy(), y_preds.cpu().numpy()) + assert pytest.approx(res) == true_res + + metric_devices = ["cpu"] + if device.type != "xla": + metric_devices.append(idist.device()) + for metric_device in metric_devices: + for _ in range(2): + _test(n_epochs=1, metric_device=metric_device) + _test(n_epochs=2, metric_device=metric_device) + + +@pytest.mark.multinode_distributed +@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") +@pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed") +def test_multinode_distrib_gloo_cpu_or_gpu(distributed_context_multi_node_gloo): + + device = idist.device() + _test_distrib_output(device) + + +@pytest.mark.multinode_distributed +@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") +@pytest.mark.skipif("GPU_MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed") +def test_multinode_distrib_nccl_gpu(distributed_context_multi_node_nccl): + + device = idist.device() + _test_distrib_output(device) + + +@pytest.mark.distributed +@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") +@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU") +def test_distrib_nccl_gpu(distributed_context_single_node_nccl): + + device = idist.device() + _test_distrib_output(device) + + +@pytest.mark.distributed +@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") +def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo): + + device = idist.device() + _test_distrib_output(device) + + +@pytest.mark.distributed +@pytest.mark.skipif(not idist.has_hvd_support, reason="Skip if no Horovod dist support") +@pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc") +def test_distrib_hvd(gloo_hvd_executor): + + device = torch.device("cpu" if not torch.cuda.is_available() else "cuda") + nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count() + + gloo_hvd_executor(_test_distrib_output, (device,), np=nproc, do_init=True) + + +@pytest.mark.tpu +@pytest.mark.skipif("NUM_TPU_WORKERS" in os.environ, reason="Skip if NUM_TPU_WORKERS is in env vars") +@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package") +def test_distrib_single_device_xla(): + + device = idist.device() + _test_distrib_output(device) + + +def _test_distrib_xla_nprocs(index): + + device = idist.device() + _test_distrib_output(device) + + +@pytest.mark.tpu +@pytest.mark.skipif("NUM_TPU_WORKERS" not in os.environ, reason="Skip if no NUM_TPU_WORKERS in env vars") +@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package") +def test_distrib_xla_nprocs(xmp_executor): + n = int(os.environ["NUM_TPU_WORKERS"]) + 
xmp_executor(_test_distrib_xla_nprocs, args=(), nprocs=n) \ No newline at end of file From d67303167f16de3a5c218f2cfb340b75f6b3a4b0 Mon Sep 17 00:00:00 2001 From: HoJinLee <2hjc12@gmail.com> Date: Thu, 31 Aug 2023 23:49:40 +0900 Subject: [PATCH 2/5] small changes in distributed tests --- tests/ignite/metrics/test_ndcg.py | 195 ++++++++++-------------------- 1 file changed, 65 insertions(+), 130 deletions(-) diff --git a/tests/ignite/metrics/test_ndcg.py b/tests/ignite/metrics/test_ndcg.py index b8acf277622..b72032f6a3b 100644 --- a/tests/ignite/metrics/test_ndcg.py +++ b/tests/ignite/metrics/test_ndcg.py @@ -31,18 +31,18 @@ def test_case(request): def test_output_cpu(test_case, k, exponential, ignore_ties, replacement): device = "cpu" - y_pred_distribution, y_true = test_case + y_pred_distribution, y = test_case y_pred = torch.multinomial(y_pred_distribution, 5, replacement=replacement) ndcg = NDCG(k=k, device=device, exponential=exponential, ignore_ties=ignore_ties) - ndcg.update([y_pred, y_true]) + ndcg.update([y_pred, y]) result_ignite = ndcg.compute() if exponential: - y_true = 2 ** y_true - 1 + y = 2 ** y - 1 - result_sklearn = ndcg_score(y_true.numpy(), y_pred.numpy(), k=k, ignore_ties=ignore_ties) + result_sklearn = ndcg_score(y.numpy(), y_pred.numpy(), k=k, ignore_ties=ignore_ties) np.testing.assert_allclose(np.array(result_ignite), result_sklearn, rtol=2e-6) @@ -54,41 +54,41 @@ def test_output_cpu(test_case, k, exponential, ignore_ties, replacement): def test_output_cuda(test_case, k, exponential, ignore_ties, replacement): device = "cuda" - y_pred_distribution, y_true = test_case + y_pred_distribution, y = test_case y_pred = torch.multinomial(y_pred_distribution, 5, replacement=replacement) y_pred = y_pred.to(device) - y_true = y_true.to(device) + y = y.to(device) ndcg = NDCG(k=k, device=device, exponential=exponential, ignore_ties=ignore_ties) - ndcg.update([y_pred, y_true]) + ndcg.update([y_pred, y]) result_ignite = ndcg.compute() if exponential: - y_true = 2 ** y_true - 1 + y = 2 ** y - 1 - result_sklearn = ndcg_score(y_true.cpu().numpy(), y_pred.cpu().numpy(), k=k, ignore_ties=ignore_ties) + result_sklearn = ndcg_score(y.cpu().numpy(), y_pred.cpu().numpy(), k=k, ignore_ties=ignore_ties) np.testing.assert_allclose(np.array(result_ignite), result_sklearn, rtol=2e-6) def test_reset(): - y_true = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]]) + y = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]]) y_pred = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5]]) ndcg = NDCG() - ndcg.update([y_pred, y_true]) + ndcg.update([y_pred, y]) ndcg.reset() with pytest.raises(NotComputableError, match=r"NGCD must have at least one example before it can be computed."): ndcg.compute() -def _ndcg_sample_scores(y_true, y_score, k=None, ignore_ties=False): +def _ndcg_sample_scores(y, y_score, k=None, ignore_ties=False): - gain = _dcg_sample_scores(y_true, y_score, k, ignore_ties=ignore_ties) - normalizing_gain = _dcg_sample_scores(y_true, y_true, k, ignore_ties=True) + gain = _dcg_sample_scores(y, y_score, k, ignore_ties=ignore_ties) + normalizing_gain = _dcg_sample_scores(y, y, k, ignore_ties=True) all_irrelevant = normalizing_gain == 0 gain[all_irrelevant] = 0 gain[~all_irrelevant] /= normalizing_gain[~all_irrelevant] @@ -97,38 +97,38 @@ def _ndcg_sample_scores(y_true, y_score, k=None, ignore_ties=False): @pytest.mark.parametrize("log_base", [2, 3, 10]) def test_log_base(log_base): - def ndcg_score_with_log_base(y_true, y_score, *, k=None, sample_weight=None, ignore_ties=False, log_base=2): + def 
ndcg_score_with_log_base(y, y_score, *, k=None, sample_weight=None, ignore_ties=False, log_base=2): - gain = _ndcg_sample_scores(y_true, y_score, k=k, ignore_ties=ignore_ties) + gain = _ndcg_sample_scores(y, y_score, k=k, ignore_ties=ignore_ties) return np.average(gain, weights=sample_weight) - y_true = torch.tensor([[3.7, 4.8, 3.9, 4.3, 4.9]]) + y = torch.tensor([[3.7, 4.8, 3.9, 4.3, 4.9]]) y_pred = torch.tensor([[2.9, 5.6, 3.8, 7.9, 6.2]]) ndcg = NDCG(log_base=log_base) - ndcg.update([y_pred, y_true]) + ndcg.update([y_pred, y]) result_ignite = ndcg.compute() - result_sklearn = ndcg_score_with_log_base(y_true.numpy(), y_pred.numpy(), log_base=log_base) + result_sklearn = ndcg_score_with_log_base(y.numpy(), y_pred.numpy(), log_base=log_base) np.testing.assert_allclose(np.array(result_ignite), result_sklearn, rtol=2e-6) def test_update(test_case): - y_pred, y_true = test_case + y_pred, y = test_case y_pred = y_pred - y_true = y_true + y = y y1_pred = torch.multinomial(y_pred, 5, replacement=True) - y1_true = torch.multinomial(y_true, 5, replacement=True) + y1_true = torch.multinomial(y, 5, replacement=True) y2_pred = torch.multinomial(y_pred, 5, replacement=True) - y2_true = torch.multinomial(y_true, 5, replacement=True) + y2_true = torch.multinomial(y, 5, replacement=True) y_pred_combined = torch.cat((y1_pred, y2_pred)) - y_true_combined = torch.cat((y1_true, y2_true)) + y_combined = torch.cat((y1_true, y2_true)) ndcg = NDCG() @@ -137,130 +137,65 @@ def test_update(test_case): result_ignite = ndcg.compute() - result_sklearn = ndcg_score(y_true_combined.numpy(), y_pred_combined.numpy()) + result_sklearn = ndcg_score(y_combined.numpy(), y_pred_combined.numpy()) np.testing.assert_allclose(np.array(result_ignite), result_sklearn, rtol=2e-6) - -def _test_distrib_output(device): +@pytest.mark.parametrize("metric_device", ["cpu", "process_device"]) +def test_distrib_integration(distributed, metric_device): + from ignite.engine import Engine rank = idist.get_rank() - - def _test(n_epochs, metric_device): - - metric_device = torch.device(metric_device) - - n_iters = 5 - batch_size = 8 - n_items = 5 - - torch.manual_seed(12 + rank) - - y_true = torch.rand((n_iters * batch_size, n_items)).to(device) - y_preds = torch.rand((n_iters * batch_size, n_items)).to(device) - - def update(_, i): - return ( - [v for v in y_preds[i * batch_size : (i + 1) * batch_size, ...]], - [v for v in y_true[i * batch_size : (i + 1) * batch_size]], - ) - - engine = Engine(update) - - ndcg = NDCG(device=metric_device) - ndcg.attach(engine, "ndcg") - - data = list(range(n_iters)) - engine.run(data=data, max_epochs=n_epochs) - - y_true = idist.all_gather(y_true) - y_preds = idist.all_gather(y_preds) - - assert ( - ndcg._device == metric_device - ), f"{type(ndcg._device)}:{ndcg._device} vs {type(metric_device)}:{metric_device}" - - assert "ndcg" in engine.state.metrics - res = engine.state.metrics["ndcg"] - if isinstance(res, torch.Tensor): - res = res.cpu().numpy() - - true_res = ndcg_score(y_true.cpu().numpy(), y_preds.cpu().numpy()) - assert pytest.approx(res) == true_res - - metric_devices = ["cpu"] - if device.type != "xla": - metric_devices.append(idist.device()) - for metric_device in metric_devices: - for _ in range(2): - _test(n_epochs=1, metric_device=metric_device) - _test(n_epochs=2, metric_device=metric_device) - - -@pytest.mark.multinode_distributed -@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") -@pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, 
reason="Skip if not multi-node distributed") -def test_multinode_distrib_gloo_cpu_or_gpu(distributed_context_multi_node_gloo): - + torch.manual_seed(12 + rank) + n_iters = 5 + batch_size = 8 device = idist.device() - _test_distrib_output(device) - + if metric_device == "process_device": + metric_device = device if device.type != "xla" else "cpu" -@pytest.mark.multinode_distributed -@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") -@pytest.mark.skipif("GPU_MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed") -def test_multinode_distrib_nccl_gpu(distributed_context_multi_node_nccl): + #10 items + y = torch.rand((n_iters * batch_size, 10)).to(device) + y_preds = torch.rand((n_iters * batch_size, 10)).to(device) - device = idist.device() - _test_distrib_output(device) + def update(engine, i): + return ( + y_preds[i * batch_size : (i + 1) * batch_size, ...], + y[i * batch_size : (i + 1) * batch_size, ...], + ) + engine = Engine(update) + NDCG(device=metric_device).attach(engine, "ndcg") -@pytest.mark.distributed -@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") -@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU") -def test_distrib_nccl_gpu(distributed_context_single_node_nccl): + data = list(range(n_iters)) + engine.run(data=data, max_epochs=1) - device = idist.device() - _test_distrib_output(device) + y_preds = idist.all_gather(y_preds) + y = idist.all_gather(y) + assert "ndcg" in engine.state.metrics + res = engine.state.metrics["ndcg"] -@pytest.mark.distributed -@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") -def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo): + true_res = ndcg_score(y.cpu().numpy(), y_preds.cpu().numpy()) + + tol = 1e-3 if device.type == "xla" else 1e-4 # Isn't better to ask `distributed` about backend info? 
- device = idist.device() - _test_distrib_output(device) + assert pytest.approx(res, abs=tol) == true_res -@pytest.mark.distributed -@pytest.mark.skipif(not idist.has_hvd_support, reason="Skip if no Horovod dist support") -@pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc") -def test_distrib_hvd(gloo_hvd_executor): - - device = torch.device("cpu" if not torch.cuda.is_available() else "cuda") - nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count() - - gloo_hvd_executor(_test_distrib_output, (device,), np=nproc, do_init=True) - - -@pytest.mark.tpu -@pytest.mark.skipif("NUM_TPU_WORKERS" in os.environ, reason="Skip if NUM_TPU_WORKERS is in env vars") -@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package") -def test_distrib_single_device_xla(): - +@pytest.mark.parametrize("metric_device", [torch.device("cpu"), "process_device"]) +def test_distrib_accumulator_device(distributed, metric_device): device = idist.device() - _test_distrib_output(device) - + if metric_device == "process_device": + metric_device = torch.device(device if device.type != "xla" else "cpu") -def _test_distrib_xla_nprocs(index): + ndcg = NDCG(device=metric_device) - device = idist.device() - _test_distrib_output(device) + + assert ndcg._device == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}" + y_pred = torch.rand((2, 10)).to(device) + y = torch.rand((2, 10)).to(device) + ndcg.update((y_pred, y)) -@pytest.mark.tpu -@pytest.mark.skipif("NUM_TPU_WORKERS" not in os.environ, reason="Skip if no NUM_TPU_WORKERS in env vars") -@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package") -def test_distrib_xla_nprocs(xmp_executor): - n = int(os.environ["NUM_TPU_WORKERS"]) - xmp_executor(_test_distrib_xla_nprocs, args=(), nprocs=n) \ No newline at end of file + dev = ndcg.ndcg.device + assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}" From a34fe2c934534fd899e978d627399381cffa2efc Mon Sep 17 00:00:00 2001 From: HoJinLee <2hjc12@gmail.com> Date: Fri, 1 Sep 2023 01:16:31 +0900 Subject: [PATCH 3/5] deleted test_output_cuda & code format alligned --- ignite/metrics/__init__.py | 4 +-- ignite/metrics/recsys/__init__.py | 2 +- ignite/metrics/recsys/ndcg.py | 10 ++----- tests/ignite/metrics/test_ndcg.py | 46 ++++++------------------------- 4 files changed, 14 insertions(+), 48 deletions(-) diff --git a/ignite/metrics/__init__.py b/ignite/metrics/__init__.py index 63537f017e3..d8905b71b8d 100644 --- a/ignite/metrics/__init__.py +++ b/ignite/metrics/__init__.py @@ -19,11 +19,11 @@ from ignite.metrics.precision import Precision from ignite.metrics.psnr import PSNR from ignite.metrics.recall import Recall +from ignite.metrics.recsys.ndcg import NDCG from ignite.metrics.root_mean_squared_error import RootMeanSquaredError from ignite.metrics.running_average import RunningAverage from ignite.metrics.ssim import SSIM from ignite.metrics.top_k_categorical_accuracy import TopKCategoricalAccuracy -from ignite.metrics.recsys.ndcg import NDCG __all__ = [ "Metric", @@ -59,5 +59,5 @@ "Rouge", "RougeN", "RougeL", - "NDCG" + "NDCG", ] diff --git a/ignite/metrics/recsys/__init__.py b/ignite/metrics/recsys/__init__.py index 98320b910f3..71e737cc0bd 100644 --- a/ignite/metrics/recsys/__init__.py +++ b/ignite/metrics/recsys/__init__.py @@ -2,4 +2,4 @@ __all__ = [ "NDCG", -] \ No newline at end of file +] diff --git a/ignite/metrics/recsys/ndcg.py b/ignite/metrics/recsys/ndcg.py 
index d4c83dc032e..e2ebd417c8e 100644 --- a/ignite/metrics/recsys/ndcg.py +++ b/ignite/metrics/recsys/ndcg.py @@ -14,7 +14,6 @@ def _tie_averaged_dcg( discount_cumsum: torch.Tensor, device: Union[str, torch.device] = torch.device("cpu"), ) -> torch.Tensor: - _, inv, counts = torch.unique(-y_pred, return_inverse=True, return_counts=True) ranked = torch.zeros(counts.shape[0]).to(device) ranked.index_put_([inv], y_true, accumulate=True) @@ -35,7 +34,6 @@ def _dcg_sample_scores( ignore_ties: bool = False, device: Union[str, torch.device] = torch.device("cpu"), ) -> torch.Tensor: - discount = torch.log(torch.tensor(log_base)) / torch.log(torch.arange(y_true.shape[1]) + 2) discount = discount.to(device) @@ -63,7 +61,6 @@ def _ndcg_sample_scores( log_base: Union[int, float] = 2, ignore_ties: bool = False, ) -> torch.Tensor: - device = y_true.device gain = _dcg_sample_scores(y_pred, y_true, k=k, log_base=log_base, ignore_ties=ignore_ties, device=device) if not ignore_ties: @@ -84,7 +81,6 @@ def __init__( exponential: bool = False, ignore_ties: bool = False, ): - if log_base == 1 or log_base <= 0: raise ValueError(f"Argument log_base should positive and not equal one,but got {log_base}") self.log_base = log_base @@ -95,20 +91,18 @@ def __init__( @reinit__is_reduced def reset(self) -> None: - self.num_examples = 0 self.ndcg = torch.tensor(0.0, device=self._device) @reinit__is_reduced def update(self, output: Sequence[torch.Tensor]) -> None: - y_pred, y_true = output[0].detach(), output[1].detach() y_pred = y_pred.to(torch.float32).to(self._device) y_true = y_true.to(torch.float32).to(self._device) if self.exponential: - y_true = 2 ** y_true - 1 + y_true = 2**y_true - 1 gain = _ndcg_sample_scores(y_pred, y_true, k=self.k, log_base=self.log_base, ignore_ties=self.ignore_ties) self.ndcg += torch.sum(gain) @@ -119,4 +113,4 @@ def compute(self) -> float: if self.num_examples == 0: raise NotComputableError("NGCD must have at least one example before it can be computed.") - return (self.ndcg / self.num_examples).item() \ No newline at end of file + return (self.ndcg / self.num_examples).item() diff --git a/tests/ignite/metrics/test_ndcg.py b/tests/ignite/metrics/test_ndcg.py index b72032f6a3b..693b14fbb92 100644 --- a/tests/ignite/metrics/test_ndcg.py +++ b/tests/ignite/metrics/test_ndcg.py @@ -15,7 +15,6 @@ @pytest.fixture(params=[item for item in range(6)]) def test_case(request): - return [ (torch.tensor([[3.7, 4.8, 3.9, 4.3, 4.9]]), torch.tensor([[2.9, 5.6, 3.8, 7.9, 6.2]])), ( @@ -28,32 +27,8 @@ def test_case(request): @pytest.mark.parametrize("k", [None, 2, 3]) @pytest.mark.parametrize("exponential", [True, False]) @pytest.mark.parametrize("ignore_ties, replacement", [(True, False), (False, True), (False, False)]) -def test_output_cpu(test_case, k, exponential, ignore_ties, replacement): - - device = "cpu" - y_pred_distribution, y = test_case - - y_pred = torch.multinomial(y_pred_distribution, 5, replacement=replacement) - - ndcg = NDCG(k=k, device=device, exponential=exponential, ignore_ties=ignore_ties) - ndcg.update([y_pred, y]) - result_ignite = ndcg.compute() - - if exponential: - y = 2 ** y - 1 - - result_sklearn = ndcg_score(y.numpy(), y_pred.numpy(), k=k, ignore_ties=ignore_ties) - - np.testing.assert_allclose(np.array(result_ignite), result_sklearn, rtol=2e-6) - - -@pytest.mark.parametrize("k", [None, 2, 3]) -@pytest.mark.parametrize("exponential", [True, False]) -@pytest.mark.parametrize("ignore_ties, replacement", [(True, False), (False, True), (False, False)]) 
-@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU") -def test_output_cuda(test_case, k, exponential, ignore_ties, replacement): - - device = "cuda" +def test_output(available_device, test_case, k, exponential, ignore_ties, replacement): + device = available_device y_pred_distribution, y = test_case y_pred = torch.multinomial(y_pred_distribution, 5, replacement=replacement) @@ -66,7 +41,7 @@ def test_output_cuda(test_case, k, exponential, ignore_ties, replacement): result_ignite = ndcg.compute() if exponential: - y = 2 ** y - 1 + y = 2**y - 1 result_sklearn = ndcg_score(y.cpu().numpy(), y_pred.cpu().numpy(), k=k, ignore_ties=ignore_ties) @@ -74,7 +49,6 @@ def test_output_cuda(test_case, k, exponential, ignore_ties, replacement): def test_reset(): - y = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]]) y_pred = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5]]) ndcg = NDCG() @@ -86,7 +60,6 @@ def test_reset(): def _ndcg_sample_scores(y, y_score, k=None, ignore_ties=False): - gain = _dcg_sample_scores(y, y_score, k, ignore_ties=ignore_ties) normalizing_gain = _dcg_sample_scores(y, y, k, ignore_ties=True) all_irrelevant = normalizing_gain == 0 @@ -98,7 +71,6 @@ def _ndcg_sample_scores(y, y_score, k=None, ignore_ties=False): @pytest.mark.parametrize("log_base", [2, 3, 10]) def test_log_base(log_base): def ndcg_score_with_log_base(y, y_score, *, k=None, sample_weight=None, ignore_ties=False, log_base=2): - gain = _ndcg_sample_scores(y, y_score, k=k, ignore_ties=ignore_ties) return np.average(gain, weights=sample_weight) @@ -115,7 +87,6 @@ def ndcg_score_with_log_base(y, y_score, *, k=None, sample_weight=None, ignore_t def test_update(test_case): - y_pred, y = test_case y_pred = y_pred @@ -141,8 +112,10 @@ def test_update(test_case): np.testing.assert_allclose(np.array(result_ignite), result_sklearn, rtol=2e-6) + @pytest.mark.parametrize("metric_device", ["cpu", "process_device"]) -def test_distrib_integration(distributed, metric_device): +@pytest.mark.parametrize("num_epochs", [1, 2]) +def test_distrib_integration(distributed, num_epochs, metric_device): from ignite.engine import Engine rank = idist.get_rank() @@ -153,7 +126,7 @@ def test_distrib_integration(distributed, metric_device): if metric_device == "process_device": metric_device = device if device.type != "xla" else "cpu" - #10 items + # 10 items y = torch.rand((n_iters * batch_size, 10)).to(device) y_preds = torch.rand((n_iters * batch_size, 10)).to(device) @@ -167,7 +140,7 @@ def update(engine, i): NDCG(device=metric_device).attach(engine, "ndcg") data = list(range(n_iters)) - engine.run(data=data, max_epochs=1) + engine.run(data=data, max_epochs=num_epochs) y_preds = idist.all_gather(y_preds) y = idist.all_gather(y) @@ -176,7 +149,7 @@ def update(engine, i): res = engine.state.metrics["ndcg"] true_res = ndcg_score(y.cpu().numpy(), y_preds.cpu().numpy()) - + tol = 1e-3 if device.type == "xla" else 1e-4 # Isn't better to ask `distributed` about backend info? 
assert pytest.approx(res, abs=tol) == true_res @@ -190,7 +163,6 @@ def test_distrib_accumulator_device(distributed, metric_device): ndcg = NDCG(device=metric_device) - assert ndcg._device == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}" y_pred = torch.rand((2, 10)).to(device) From 3ce5a5d5a65c90b7a5a49b0ac82485cae5786c5a Mon Sep 17 00:00:00 2001 From: HoJinLee <2hjc12@gmail.com> Date: Fri, 1 Sep 2023 22:28:34 +0900 Subject: [PATCH 4/5] error fix --- tests/ignite/metrics/test_ndcg.py | 5 ----- tests/run_code_style.sh | 0 2 files changed, 5 deletions(-) mode change 100755 => 100644 tests/run_code_style.sh diff --git a/tests/ignite/metrics/test_ndcg.py b/tests/ignite/metrics/test_ndcg.py index 693b14fbb92..de42533154d 100644 --- a/tests/ignite/metrics/test_ndcg.py +++ b/tests/ignite/metrics/test_ndcg.py @@ -1,5 +1,3 @@ -import os - import numpy as np import pytest import torch @@ -7,7 +5,6 @@ from sklearn.metrics._ranking import _dcg_sample_scores import ignite.distributed as idist -from ignite.engine import Engine from ignite.exceptions import NotComputableError from ignite.metrics.recsys.ndcg import NDCG @@ -163,8 +160,6 @@ def test_distrib_accumulator_device(distributed, metric_device): ndcg = NDCG(device=metric_device) - assert ndcg._device == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}" - y_pred = torch.rand((2, 10)).to(device) y = torch.rand((2, 10)).to(device) ndcg.update((y_pred, y)) diff --git a/tests/run_code_style.sh b/tests/run_code_style.sh old mode 100755 new mode 100644 From 3920b2e3da322ef5a2a3f75e9f39526687214654 Mon Sep 17 00:00:00 2001 From: HoJinLee <2hjc12@gmail.com> Date: Thu, 21 Sep 2023 01:54:33 +0900 Subject: [PATCH 5/5] docstrings ing.. --- ignite/metrics/recsys/ndcg.py | 121 +++++++++++++++++++++------------- 1 file changed, 76 insertions(+), 45 deletions(-) diff --git a/ignite/metrics/recsys/ndcg.py b/ignite/metrics/recsys/ndcg.py index e2ebd417c8e..0c15990e948 100644 --- a/ignite/metrics/recsys/ndcg.py +++ b/ignite/metrics/recsys/ndcg.py @@ -8,12 +8,88 @@ __all__ = ["NDCG"] +class NDCG(Metric): + """Computes ndcg + `Normalized DCG(DCG) `_. + + .. math:: + \text{nDCG}_\text{p} = \frac{\text{DCG}_p}{\text{nDCG}_p} + + where :math: \text{DCG}_\text{p} = \sum_{i = 1}^p \frac{2^{rel_i} - 1}{\log_2{(i + 1)}} + :math: \text{IDCG}_\text{p} = \sum_{i = 1}^{|REL_p|} \frac{2^{rel_i} - 1}{\log_2{(i + 1)}} + :math: \text{$rel_i \in \{0, 1\}$ : graded relevance of the result at position $i$} + + + - ``update`` must receive output of the form ``(y_pred, y)``. + + + Args: + + output_transform: A callable that is used to transform the Engine’s + process_function’s output into the form expected by the metric. + device: specifies which device updates are accumulated on. + Setting the metric’s device to be the same as your update arguments ensures + the update method is non-blocking. By default, CPU. + k: Only consider the highest k scores in the ranking. If None, use all outputs. + log_base: Base of logarithm used in computation + exponential: If True, computes exponential gain + ignore_ties: Assume that there are no ties in y_score (which is likely to be the case if y_score is continuous) for efficiency gains. 
+ + Examples: + + """ + + def __init__( + self, + output_transform: Callable = lambda x: x, + device: Union[str, torch.device] = torch.device("cpu"), + k: Optional[int] = None, + log_base: Union[int, float] = 2, + exponential: bool = False, + ignore_ties: bool = False, + ): + if log_base == 1 or log_base <= 0: + raise ValueError(f"Argument log_base should positive and not equal one,but got {log_base}") + self.log_base = log_base + self.k = k + self.exponential = exponential + self.ignore_ties = ignore_ties + super(NDCG, self).__init__(output_transform=output_transform, device=device) + + @reinit__is_reduced + def reset(self) -> None: + self.num_examples = 0 + self.ndcg = torch.tensor(0.0, device=self._device) + + @reinit__is_reduced + def update(self, output: Sequence[torch.Tensor]) -> None: + y_pred, y_true = output[0].detach(), output[1].detach() + + y_pred = y_pred.to(torch.float32).to(self._device) + y_true = y_true.to(torch.float32).to(self._device) + + if self.exponential: + y_true = 2**y_true - 1 + + gain = _ndcg_sample_scores(y_pred, y_true, k=self.k, log_base=self.log_base, ignore_ties=self.ignore_ties) + self.ndcg += torch.sum(gain) + self.num_examples += y_pred.shape[0] + + @sync_all_reduce("ndcg", "num_examples") + def compute(self) -> float: + if self.num_examples == 0: + raise NotComputableError("NGCD must have at least one example before it can be computed.") + + return (self.ndcg / self.num_examples).item() + + def _tie_averaged_dcg( y_pred: torch.Tensor, y_true: torch.Tensor, discount_cumsum: torch.Tensor, device: Union[str, torch.device] = torch.device("cpu"), ) -> torch.Tensor: + _, inv, counts = torch.unique(-y_pred, return_inverse=True, return_counts=True) ranked = torch.zeros(counts.shape[0]).to(device) ranked.index_put_([inv], y_true, accumulate=True) @@ -69,48 +145,3 @@ def _ndcg_sample_scores( all_relevant = normalizing_gain != 0 normalized_gain = gain[all_relevant] / normalizing_gain[all_relevant] return normalized_gain - - -class NDCG(Metric): - def __init__( - self, - output_transform: Callable = lambda x: x, - device: Union[str, torch.device] = torch.device("cpu"), - k: Optional[int] = None, - log_base: Union[int, float] = 2, - exponential: bool = False, - ignore_ties: bool = False, - ): - if log_base == 1 or log_base <= 0: - raise ValueError(f"Argument log_base should positive and not equal one,but got {log_base}") - self.log_base = log_base - self.k = k - self.exponential = exponential - self.ignore_ties = ignore_ties - super(NDCG, self).__init__(output_transform=output_transform, device=device) - - @reinit__is_reduced - def reset(self) -> None: - self.num_examples = 0 - self.ndcg = torch.tensor(0.0, device=self._device) - - @reinit__is_reduced - def update(self, output: Sequence[torch.Tensor]) -> None: - y_pred, y_true = output[0].detach(), output[1].detach() - - y_pred = y_pred.to(torch.float32).to(self._device) - y_true = y_true.to(torch.float32).to(self._device) - - if self.exponential: - y_true = 2**y_true - 1 - - gain = _ndcg_sample_scores(y_pred, y_true, k=self.k, log_base=self.log_base, ignore_ties=self.ignore_ties) - self.ndcg += torch.sum(gain) - self.num_examples += y_pred.shape[0] - - @sync_all_reduce("ndcg", "num_examples") - def compute(self) -> float: - if self.num_examples == 0: - raise NotComputableError("NGCD must have at least one example before it can be computed.") - - return (self.ndcg / self.num_examples).item()
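
A minimal usage sketch of the NDCG metric as exposed by this series (from ignite.metrics import NDCG), together with a manual cross-check of the default computation (linear gain, base-2 logarithm). The tensor values and the helper variables in the cross-check are illustrative only and are not taken from the patches:

    import torch

    from ignite.metrics import NDCG

    # One "query" with four items: model scores and graded relevance labels
    # (illustrative values only).
    y_pred = torch.tensor([[0.1, 0.4, 0.3, 0.2]])
    y_true = torch.tensor([[1.0, 0.0, 2.0, 0.0]])

    # nDCG@3 with the default settings (linear gain, log base 2, ties averaged).
    ndcg = NDCG(k=3)
    ndcg.update((y_pred, y_true))
    print(ndcg.compute())  # approximately 0.48

    # Cross-check against the definition nDCG@k = DCG@k / IDCG@k, where DCG@k
    # sums rel_i / log2(i + 1) over the first k positions of the predicted
    # ranking and IDCG@k is the same sum over the ideal (sorted-by-relevance)
    # ranking.
    discount = 1.0 / torch.log2(torch.arange(3, dtype=torch.float32) + 2.0)
    rel_by_pred = y_true[0][torch.argsort(y_pred[0], descending=True)][:3]
    rel_ideal = torch.sort(y_true[0], descending=True).values[:3]
    print((rel_by_pred * discount).sum() / (rel_ideal * discount).sum())  # same value

With exponential=True the relevance labels are mapped to 2**rel - 1 before the same computation, and log_base replaces the base-2 logarithm in the discount; both can be combined with k and ignore_ties as exercised in the tests above.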