Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added the ndcg metric [WIP] #2632

Draft
wants to merge 38 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
cd6ec7f
Added the ndcg metric [WIP]
kamalojasv181 Jul 23, 2022
6421879
Merge branch 'master' into ndcg
kamalojasv181 Jul 23, 2022
4535af1
added GPU support, corrected mypy errors, and minor fixes
kamalojasv181 Jul 23, 2022
eb73c99
Merge branch 'ndcg' of https://github.com/kamalojasv181/ignite into ndcg
kamalojasv181 Jul 23, 2022
70d06e5
Incorporated the suggested changes
kamalojasv181 Jul 24, 2022
6a86f5f
Fixed mypy error
kamalojasv181 Jul 24, 2022
7b7ed6f
Fixed bugs in NDCG and added tests for output and reset
kamalojasv181 Jul 24, 2022
2c87ee1
Fixed mypy error
kamalojasv181 Jul 24, 2022
f4c628a
Added the exponential form on https://en.wikipedia.org/wiki/Discounte…
kamalojasv181 Jul 25, 2022
e72b59e
Corrected true, pred order and corresponding tests
kamalojasv181 Jul 25, 2022
ef63d85
Added ties case, exponential tests, log_base tests, corresponding tes…
kamalojasv181 Jul 26, 2022
115501b
Added GPU check on top
kamalojasv181 Jul 26, 2022
189b579
Put tensors on GPU inside the function to prevent error
kamalojasv181 Jul 26, 2022
c509456
Improved tests and minor bugfixes
kamalojasv181 Jul 27, 2022
84900f0
Removed device hyperparam from _ndcg_sample_scores
kamalojasv181 Jul 27, 2022
9bfc06e
Skipped GPU tests for CPU only systems
kamalojasv181 Jul 27, 2022
477e096
Changed Error message
kamalojasv181 Jul 27, 2022
44329d7
Merge branch 'master' of https://github.com/pytorch/ignite into ndcg
kamalojasv181 Aug 27, 2022
5ba7fb7
Merge branch 'pytorch:master' into ndcg
kamalojasv181 Aug 27, 2022
ac800ff
Made tests randomised from deterministic and introduced ignore_ties_f…
kamalojasv181 Aug 27, 2022
691e89a
Merge branch 'ndcg' of https://github.com/kamalojasv181/ignite into ndcg
kamalojasv181 Aug 27, 2022
962bcef
Merge branch 'master' of https://github.com/pytorch/ignite into ndcg
kamalojasv181 Aug 29, 2022
79979cc
Merge branch 'pytorch:master' into ndcg
kamalojasv181 Aug 29, 2022
0c1d6fd
Changed test name to test_output_cuda from test_output_gpu
kamalojasv181 Aug 29, 2022
85cdcaf
Merge branch 'ndcg' of https://github.com/kamalojasv181/ignite into ndcg
kamalojasv181 Aug 29, 2022
fdf7877
Merge branch 'master' into ndcg
kamalojasv181 Aug 29, 2022
2931d20
Changed variable names to replacement and ignore_ties and removed red…
kamalojasv181 Aug 30, 2022
c308e41
Changed variable names to replacement and ignore_ties and removed red…
kamalojasv181 Aug 30, 2022
95ede6c
Merge branch 'master' of https://github.com/pytorch/ignite into ndcg
kamalojasv181 Aug 30, 2022
3a4d2af
Merge branch 'ndcg' of https://github.com/kamalojasv181/ignite into ndcg
kamalojasv181 Aug 30, 2022
6e66273
Removed redundant test cases and removed the redundant if statement
kamalojasv181 Aug 30, 2022
eb75afa
Added distributed tests, added multiple test cases corresponding to o…
kamalojasv181 Aug 30, 2022
dcf276d
Made the tests work in ddp configuration
kamalojasv181 Aug 31, 2022
cb273e7
Merge branch 'master' of https://github.com/pytorch/ignite into ndcg
kamalojasv181 Aug 31, 2022
b0f449b
Merge branch 'pytorch:master' into ndcg
kamalojasv181 Aug 31, 2022
388db23
Merge branch 'ndcg' of https://github.com/kamalojasv181/ignite into ndcg
kamalojasv181 Aug 31, 2022
b3b6b28
Merge branch 'pytorch:master' into ndcg
kamalojasv181 Sep 1, 2022
6dcf3b2
Returning tuple of two tensors instead of tuple of list of tensors
kamalojasv181 Sep 1, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ignite/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from ignite.metrics.precision import Precision
from ignite.metrics.psnr import PSNR
from ignite.metrics.recall import Recall
from ignite.metrics.recsys.ndcg import NDCG
from ignite.metrics.root_mean_squared_error import RootMeanSquaredError
from ignite.metrics.running_average import RunningAverage
from ignite.metrics.ssim import SSIM
Expand Down
5 changes: 5 additions & 0 deletions ignite/metrics/recsys/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from ignite.metrics.recsys.ndcg import NDCG

__all__ = [
"NDCG",
]
122 changes: 122 additions & 0 deletions ignite/metrics/recsys/ndcg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
from typing import Callable, Optional, Sequence, Union

import torch

from ignite.exceptions import NotComputableError
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce

__all__ = ["NDCG"]


def _tie_averaged_dcg(
y_pred: torch.Tensor,
y_true: torch.Tensor,
discount_cumsum: torch.Tensor,
device: Union[str, torch.device] = torch.device("cpu"),
) -> torch.Tensor:

_, inv, counts = torch.unique(-y_pred, return_inverse=True, return_counts=True)
ranked = torch.zeros(counts.shape[0]).to(device)
ranked.index_put_([inv], y_true, accumulate=True)
ranked /= counts
groups = torch.cumsum(counts, dim=-1) - 1
discount_sums = torch.empty(counts.shape[0]).to(device)
discount_sums[0] = discount_cumsum[groups[0]]
discount_sums[1:] = torch.diff(discount_cumsum[groups])

return torch.sum(torch.mul(ranked, discount_sums))


def _dcg_sample_scores(
y_pred: torch.Tensor,
y_true: torch.Tensor,
k: Optional[int] = None,
log_base: Union[int, float] = 2,
ignore_ties: bool = False,
device: Union[str, torch.device] = torch.device("cpu"),
) -> torch.Tensor:

discount = torch.log(torch.tensor(log_base)) / torch.log(torch.arange(y_true.shape[1]) + 2)
discount = discount.to(device)

if k is not None:
discount[k:] = 0.0

if ignore_ties:
ranking = torch.argsort(y_pred, descending=True)
ranked = y_true[torch.arange(ranking.shape[0]).reshape(-1, 1), ranking].to(device)
discounted_gains = torch.mm(ranked, discount.reshape(-1, 1))

else:
discount_cumsum = torch.cumsum(discount, dim=-1)
discounted_gains = torch.tensor(
[_tie_averaged_dcg(y_p, y_t, discount_cumsum, device) for y_p, y_t in zip(y_pred, y_true)], device=device
)
Comment on lines +52 to +54
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, there is no way to make it vectorized == without for-loop ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I havent checked yet. For now I have added this implementation. It's a TODO


return discounted_gains


def _ndcg_sample_scores(
    y_pred: torch.Tensor,
    y_true: torch.Tensor,
    k: Optional[int] = None,
    log_base: Union[int, float] = 2,
    ignore_ties: bool = False,
) -> torch.Tensor:
    """Per-sample NDCG: DCG of the predicted ordering divided by the ideal DCG.

    Fix: samples with no relevant items (ideal DCG == 0) previously were
    silently dropped from the output by the boolean-index filter; they now
    score 0 (sklearn semantics), so the result always has one entry per
    sample.  The accumulated metric (a plain sum over the output) is unchanged.
    """
    device = y_true.device
    gain = _dcg_sample_scores(y_pred, y_true, k=k, log_base=log_base, ignore_ties=ignore_ties, device=device)
    # The tie-aware path returns a flat tensor; align shapes with the (n, 1)
    # output of the ignore_ties path before normalizing.
    if not ignore_ties:
        gain = gain.unsqueeze(dim=-1)
    # Ideal DCG: rank by the true gains themselves (ties are irrelevant here).
    normalizing_gain = _dcg_sample_scores(y_true, y_true, k=k, log_base=log_base, ignore_ties=True, device=device)
    all_relevant = normalizing_gain != 0
    # Avoid division by zero: all-irrelevant samples get NDCG 0.
    normalized_gain = torch.zeros_like(gain.reshape(-1))
    normalized_gain[all_relevant.reshape(-1)] = gain[all_relevant] / normalizing_gain[all_relevant]
    return normalized_gain


class NDCG(Metric):
    """Computes the Normalized Discounted Cumulative Gain (NDCG) over batches
    of ``(y_pred, y_true)`` relevance-score tensors of shape ``(n_samples, n_items)``.

    Args:
        output_transform: callable that transforms the
            :class:`~ignite.engine.engine.Engine`'s ``process_function`` output
            into the ``(y_pred, y_true)`` form expected by this metric.
        device: device on which the metric accumulators live.
        k: if given, only the top ``k`` ranked items contribute to the score.
        log_base: base of the logarithmic position discount; must be positive
            and not equal to 1.
        exponential: if ``True``, transform relevances as ``2 ** y_true - 1``
            before scoring.
        ignore_ties: if ``True``, assume ``y_pred`` contains no ties (faster path).
    """

    def __init__(
        self,
        output_transform: Callable = lambda x: x,
        device: Union[str, torch.device] = torch.device("cpu"),
        k: Optional[int] = None,
        log_base: Union[int, float] = 2,
        exponential: bool = False,
        ignore_ties: bool = False,
    ):
        # log is undefined/degenerate for base <= 0 and base == 1.
        if log_base == 1 or log_base <= 0:
            raise ValueError(f"Argument log_base should be positive and not equal one, but got {log_base}")
        self.log_base = log_base
        self.k = k
        self.exponential = exponential
        self.ignore_ties = ignore_ties
        super(NDCG, self).__init__(output_transform=output_transform, device=device)

    @reinit__is_reduced
    def reset(self) -> None:
        """Reset the accumulated NDCG sum and the example count."""
        self.num_examples = 0
        self.ndcg = torch.tensor(0.0, device=self._device)

    @reinit__is_reduced
    def update(self, output: Sequence[torch.Tensor]) -> None:
        """Accumulate the summed NDCG of one ``(y_pred, y_true)`` batch."""
        y_pred, y_true = output[0].detach(), output[1].detach()

        y_pred = y_pred.to(torch.float32).to(self._device)
        y_true = y_true.to(torch.float32).to(self._device)

        if self.exponential:
            y_true = 2 ** y_true - 1

        gain = _ndcg_sample_scores(y_pred, y_true, k=self.k, log_base=self.log_base, ignore_ties=self.ignore_ties)
        self.ndcg += torch.sum(gain)
        self.num_examples += y_pred.shape[0]

    @sync_all_reduce("ndcg", "num_examples")
    def compute(self) -> float:
        """Return the mean NDCG over all examples seen since the last reset.

        Raises:
            NotComputableError: if no examples have been processed.
        """
        if self.num_examples == 0:
            raise NotComputableError("NDCG must have at least one example before it can be computed.")

        return (self.ndcg / self.num_examples).item()
Loading