Distributed ndcg #3054

Open · wants to merge 9 commits into master
2 changes: 2 additions & 0 deletions ignite/metrics/__init__.py
@@ -19,6 +19,7 @@
from ignite.metrics.precision import Precision
from ignite.metrics.psnr import PSNR
from ignite.metrics.recall import Recall
from ignite.metrics.recsys.ndcg import NDCG
from ignite.metrics.root_mean_squared_error import RootMeanSquaredError
from ignite.metrics.running_average import RunningAverage
from ignite.metrics.ssim import SSIM
@@ -58,4 +59,5 @@
"Rouge",
"RougeN",
"RougeL",
"NDCG",
]
5 changes: 5 additions & 0 deletions ignite/metrics/recsys/__init__.py
@@ -0,0 +1,5 @@
from ignite.metrics.recsys.ndcg import NDCG

__all__ = [
"NDCG",
]
116 changes: 116 additions & 0 deletions ignite/metrics/recsys/ndcg.py
@@ -0,0 +1,116 @@
from typing import Callable, Optional, Sequence, Union

import torch

from ignite.exceptions import NotComputableError
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce

__all__ = ["NDCG"]


def _tie_averaged_dcg(
y_pred: torch.Tensor,
y_true: torch.Tensor,
discount_cumsum: torch.Tensor,
device: Union[str, torch.device] = torch.device("cpu"),
) -> torch.Tensor:
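# Tie-aware DCG for a single sample: unique(-y_pred) groups equal predictions
# in descending score order, `inv` maps each item to its tie group and
# `counts` gives the group sizes. Every group gets the average relevance of
# its members times the total discount mass of the ranks it spans.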
_, inv, counts = torch.unique(-y_pred, return_inverse=True, return_counts=True)
ranked = torch.zeros(counts.shape[0]).to(device)
ranked.index_put_([inv], y_true, accumulate=True)
ranked /= counts
groups = torch.cumsum(counts, dim=-1) - 1
discount_sums = torch.empty(counts.shape[0]).to(device)
discount_sums[0] = discount_cumsum[groups[0]]
discount_sums[1:] = torch.diff(discount_cumsum[groups])

return torch.sum(torch.mul(ranked, discount_sums))


def _dcg_sample_scores(
y_pred: torch.Tensor,
y_true: torch.Tensor,
k: Optional[int] = None,
log_base: Union[int, float] = 2,
ignore_ties: bool = False,
device: Union[str, torch.device] = torch.device("cpu"),
) -> torch.Tensor:
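# Logarithmic rank discount: position i (0-based) is weighted by
# 1 / log_base(i + 2); zeroing the tail below implements DCG@k.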
discount = torch.log(torch.tensor(log_base)) / torch.log(torch.arange(y_true.shape[1]) + 2)
discount = discount.to(device)

if k is not None:
discount[k:] = 0.0

if ignore_ties:
ranking = torch.argsort(y_pred, descending=True)
ranked = y_true[torch.arange(ranking.shape[0]).reshape(-1, 1), ranking].to(device)
discounted_gains = torch.mm(ranked, discount.reshape(-1, 1))

else:
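# Tied predictions are handled by averaging DCG over all permutations of the
# ties (same convention as sklearn's ndcg_score).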
discount_cumsum = torch.cumsum(discount, dim=-1)
discounted_gains = torch.tensor(
[_tie_averaged_dcg(y_p, y_t, discount_cumsum, device) for y_p, y_t in zip(y_pred, y_true)], device=device
)
Comment on lines +126 to +128
Author (ili0820, Sep 3, 2023):
@vfdev-5 While studying this, I found that sklearn uses a for loop too. Is there any particular reason why this for loop needs to be changed?

Collaborator:

It is purely for perf reasons. For example, computing s1 below will be faster than s2:

tensor = torch.rand(100)
s1 = tensor.sum()

s2 = 0
for v in tensor:
    s2 += v


return discounted_gains


def _ndcg_sample_scores(
y_pred: torch.Tensor,
y_true: torch.Tensor,
k: Optional[int] = None,
log_base: Union[int, float] = 2,
ignore_ties: bool = False,
) -> torch.Tensor:
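# Per-sample DCG is normalized by the ideal DCG (y_true ranked by itself);
# samples with no relevant items are dropped.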
device = y_true.device
gain = _dcg_sample_scores(y_pred, y_true, k=k, log_base=log_base, ignore_ties=ignore_ties, device=device)
if not ignore_ties:
gain = gain.unsqueeze(dim=-1)
normalizing_gain = _dcg_sample_scores(y_true, y_true, k=k, log_base=log_base, ignore_ties=True, device=device)
all_relevant = normalizing_gain != 0
normalized_gain = gain[all_relevant] / normalizing_gain[all_relevant]
return normalized_gain


class NDCG(Metric):
Collaborator:

Also, we need to write a docstring like here:

class Accuracy(_BaseClassification):

Please read this section of the contributing guide: https://github.com/pytorch/ignite/blob/master/CONTRIBUTING.md#writing-documentation, especially the part about .. versionadded::
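For reference, a rough sketch of the shape such a docstring could take; the wording and the version placeholder below are illustrative only, not part of this PR:

r"""Computes the Normalized Discounted Cumulative Gain (NDCG).

Args:
    output_transform: a callable used to transform the engine's output into
        the ``(y_pred, y_true)`` pair expected by the metric.
    device: device on which the internal accumulators are kept.
    k: if given, only the top ``k`` positions of each ranking contribute.
    log_base: base of the logarithmic rank discount.
    exponential: if True, relevance ``r`` is transformed to ``2**r - 1``.
    ignore_ties: if True, assumes ``y_pred`` contains no tied scores.

.. versionadded:: <next release>
"""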

def __init__(
self,
output_transform: Callable = lambda x: x,
device: Union[str, torch.device] = torch.device("cpu"),
k: Optional[int] = None,
log_base: Union[int, float] = 2,
exponential: bool = False,
ignore_ties: bool = False,
):
if log_base == 1 or log_base <= 0:
raise ValueError(f"Argument log_base should be positive and not equal to one, but got {log_base}")
self.log_base = log_base
self.k = k
self.exponential = exponential
self.ignore_ties = ignore_ties
super(NDCG, self).__init__(output_transform=output_transform, device=device)

@reinit__is_reduced
def reset(self) -> None:
self.num_examples = 0
self.ndcg = torch.tensor(0.0, device=self._device)

@reinit__is_reduced
def update(self, output: Sequence[torch.Tensor]) -> None:
y_pred, y_true = output[0].detach(), output[1].detach()

y_pred = y_pred.to(torch.float32).to(self._device)
y_true = y_true.to(torch.float32).to(self._device)

if self.exponential:
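# Exponential gain: relevance r contributes 2**r - 1 instead of r.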
y_true = 2**y_true - 1

gain = _ndcg_sample_scores(y_pred, y_true, k=self.k, log_base=self.log_base, ignore_ties=self.ignore_ties)
self.ndcg += torch.sum(gain)
self.num_examples += y_pred.shape[0]

@sync_all_reduce("ndcg", "num_examples")
def compute(self) -> float:
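# In distributed runs, sync_all_reduce sums `ndcg` and `num_examples`
# across processes before this method executes.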
if self.num_examples == 0:
raise NotComputableError("NDCG must have at least one example before it can be computed.")

return (self.ndcg / self.num_examples).item()
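For orientation, a minimal usage sketch of the metric added in this PR, with made-up scores (not part of the diff):

import torch

from ignite.metrics import NDCG

# Relevance of five items for two queries (made-up values).
y_true = torch.tensor([[3.0, 2.0, 3.0, 0.0, 1.0], [2.0, 0.0, 2.0, 1.0, 3.0]])
# Predicted scores used to rank those items.
y_pred = torch.tensor([[0.1, 0.4, 0.35, 0.8, 0.05], [0.9, 0.1, 0.2, 0.3, 0.7]])

metric = NDCG(k=3)
metric.update([y_pred, y_true])
print(metric.compute())  # average NDCG@3 over the two queries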
168 changes: 168 additions & 0 deletions tests/ignite/metrics/test_ndcg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
import numpy as np
import pytest
import torch
from sklearn.metrics import ndcg_score
from sklearn.metrics._ranking import _dcg_sample_scores

import ignite.distributed as idist

from ignite.exceptions import NotComputableError
from ignite.metrics.recsys.ndcg import NDCG


@pytest.fixture(params=list(range(6)))
def test_case(request):
return [
(torch.tensor([[3.7, 4.8, 3.9, 4.3, 4.9]]), torch.tensor([[2.9, 5.6, 3.8, 7.9, 6.2]])),
(
torch.tensor([[3.7, 3.7, 3.7, 3.7, 3.7], [3.7, 3.7, 3.7, 3.7, 3.9]]),
torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0], [1.0, 2.0, 3.0, 4.0, 5.0]]),
),
][request.param % 2]


@pytest.mark.parametrize("k", [None, 2, 3])
@pytest.mark.parametrize("exponential", [True, False])
@pytest.mark.parametrize("ignore_ties, replacement", [(True, False), (False, True), (False, False)])
def test_output(available_device, test_case, k, exponential, ignore_ties, replacement):
device = available_device
y_pred_distribution, y = test_case

y_pred = torch.multinomial(y_pred_distribution, 5, replacement=replacement)

y_pred = y_pred.to(device)
y = y.to(device)

ndcg = NDCG(k=k, device=device, exponential=exponential, ignore_ties=ignore_ties)
ndcg.update([y_pred, y])
result_ignite = ndcg.compute()

if exponential:
y = 2**y - 1

result_sklearn = ndcg_score(y.cpu().numpy(), y_pred.cpu().numpy(), k=k, ignore_ties=ignore_ties)

np.testing.assert_allclose(np.array(result_ignite), result_sklearn, rtol=2e-6)


def test_reset():
y = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])
y_pred = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5]])
ndcg = NDCG()
ndcg.update([y_pred, y])
ndcg.reset()

with pytest.raises(NotComputableError, match=r"NDCG must have at least one example before it can be computed."):
ndcg.compute()


def _ndcg_sample_scores(y, y_score, k=None, ignore_ties=False):
gain = _dcg_sample_scores(y, y_score, k, ignore_ties=ignore_ties)
normalizing_gain = _dcg_sample_scores(y, y, k, ignore_ties=True)
all_irrelevant = normalizing_gain == 0
gain[all_irrelevant] = 0
gain[~all_irrelevant] /= normalizing_gain[~all_irrelevant]
return gain


@pytest.mark.parametrize("log_base", [2, 3, 10])
def test_log_base(log_base):
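# NDCG is invariant to the discount log base (the base cancels between DCG
# and the ideal DCG), so sklearn's base-2 helpers still serve as a reference.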
def ndcg_score_with_log_base(y, y_score, *, k=None, sample_weight=None, ignore_ties=False, log_base=2):
gain = _ndcg_sample_scores(y, y_score, k=k, ignore_ties=ignore_ties)
return np.average(gain, weights=sample_weight)

y = torch.tensor([[3.7, 4.8, 3.9, 4.3, 4.9]])
y_pred = torch.tensor([[2.9, 5.6, 3.8, 7.9, 6.2]])

ndcg = NDCG(log_base=log_base)
ndcg.update([y_pred, y])

result_ignite = ndcg.compute()
result_sklearn = ndcg_score_with_log_base(y.numpy(), y_pred.numpy(), log_base=log_base)

np.testing.assert_allclose(np.array(result_ignite), result_sklearn, rtol=2e-6)


def test_update(test_case):
y_pred, y = test_case

y1_pred = torch.multinomial(y_pred, 5, replacement=True)
y1_true = torch.multinomial(y, 5, replacement=True)

y2_pred = torch.multinomial(y_pred, 5, replacement=True)
y2_true = torch.multinomial(y, 5, replacement=True)

y_pred_combined = torch.cat((y1_pred, y2_pred))
y_combined = torch.cat((y1_true, y2_true))

ndcg = NDCG()

ndcg.update([y1_pred, y1_true])
ndcg.update([y2_pred, y2_true])

result_ignite = ndcg.compute()

result_sklearn = ndcg_score(y_combined.numpy(), y_pred_combined.numpy())

np.testing.assert_allclose(np.array(result_ignite), result_sklearn, rtol=2e-6)


@pytest.mark.parametrize("metric_device", ["cpu", "process_device"])
@pytest.mark.parametrize("num_epochs", [1, 2])
def test_distrib_integration(distributed, num_epochs, metric_device):
from ignite.engine import Engine

rank = idist.get_rank()
torch.manual_seed(12 + rank)
n_iters = 5
batch_size = 8
device = idist.device()
if metric_device == "process_device":
metric_device = device if device.type != "xla" else "cpu"

# 10 items
y = torch.rand((n_iters * batch_size, 10)).to(device)
y_preds = torch.rand((n_iters * batch_size, 10)).to(device)

def update(engine, i):
return (
y_preds[i * batch_size : (i + 1) * batch_size, ...],
y[i * batch_size : (i + 1) * batch_size, ...],
)

engine = Engine(update)
NDCG(device=metric_device).attach(engine, "ndcg")

data = list(range(n_iters))
engine.run(data=data, max_epochs=num_epochs)
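# Gather targets and predictions from all processes so the single-process
# sklearn reference sees the same data the metric accumulated.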

y_preds = idist.all_gather(y_preds)
y = idist.all_gather(y)

assert "ndcg" in engine.state.metrics
res = engine.state.metrics["ndcg"]

true_res = ndcg_score(y.cpu().numpy(), y_preds.cpu().numpy())

tol = 1e-3 if device.type == "xla" else 1e-4  # Wouldn't it be better to ask `distributed` about backend info?

assert pytest.approx(res, abs=tol) == true_res


@pytest.mark.parametrize("metric_device", [torch.device("cpu"), "process_device"])
def test_distrib_accumulator_device(distributed, metric_device):
device = idist.device()
if metric_device == "process_device":
metric_device = torch.device(device if device.type != "xla" else "cpu")

ndcg = NDCG(device=metric_device)

y_pred = torch.rand((2, 10)).to(device)
y = torch.rand((2, 10)).to(device)
ndcg.update((y_pred, y))

dev = ndcg.ndcg.device
assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"
tests/run_code_style.sh
File mode changed 100755 → 100644 (no content changes)