Skip to content

Commit

Permalink
Update ndcg tests
Browse files Browse the repository at this point in the history
  • Loading branch information
shashist committed Aug 6, 2020
1 parent 37a155b commit fa1c898
Showing 1 changed file with 22 additions and 35 deletions.
57 changes: 22 additions & 35 deletions catalyst/utils/metrics/tests/test_ndcg.py
@@ -1,47 +1,34 @@
import numpy as np

import torch

from catalyst.utils import metrics


def test_ndcg():
def test_zero_ndcg():
"""
Tests for catalyst.utils.metrics.ndcg metric.
"""
# check 0: common values
assert (
metrics.ndcg(torch.tensor([2, 1, 0]), torch.tensor([1, 0, 0])) == 1.0
)
assert np.allclose(
metrics.ndcg(torch.tensor([2, 1, 0]), torch.tensor([0, 1, 0])), 0.63093
)
assert np.allclose(
metrics.ndcg(torch.tensor([2, 1, 0]), torch.tensor([0, 0, 1])), 0.5
)
assert np.allclose(
metrics.ndcg(torch.tensor([2, 1, 0]), torch.tensor([1, 0, 1])), 0.91972
)
assert np.allclose(
metrics.ndcg(torch.tensor([2, 1, 0]), torch.tensor([1, 0, 1]), k=2),
0.61315,
)
ndcg_at1, ndcg_at3, ndcg_at7 = metrics.ndcg(
torch.tensor([6, 5, 4, 3, 2, 1, 0]),
torch.tensor([0, 0, 0, 0, 0, 0, 1]),
topk=(1, 3, 7),
)
assert torch.isclose(ndcg_at1, torch.tensor(0.0))
assert torch.isclose(ndcg_at3, torch.tensor(0.0))
assert torch.isclose(ndcg_at7, torch.tensor(3.0))

# check 1: ordering invariance
assert np.allclose(
metrics.ndcg(torch.tensor([2, 1, 0]), torch.tensor([0, 0, 1])),
metrics.ndcg(torch.tensor([0, 1, 2]), torch.tensor([1, 0, 0])),
)
assert np.allclose(
metrics.ndcg(torch.tensor([2, 1, 0]), torch.tensor([0, 0, 1])),
metrics.ndcg(torch.tensor([2, 0, 1]), torch.tensor([0, 1, 0])),
)

# check2: zero ndcg
assert (
metrics.ndcg(torch.tensor([2, 1, 0]), torch.tensor([0, 0, 0])) == 0.0
def test_ndcg_ordering_invariance():
"""
Tests for catalyst.utils.metrics.ndcg metric.
"""
[in_order] = metrics.ndcg(
torch.tensor([2, 1, 0]), torch.tensor([1, 0, 0]), topk=(1,)
)
[first_last] = metrics.ndcg(
torch.tensor([0, 1, 2]), torch.tensor([0, 0, 1]), topk=(1,)
)
assert (
metrics.ndcg(torch.tensor([2, 1, 0]), torch.tensor([0, 0, 1]), k=2)
== 0.0
[first_middle] = metrics.ndcg(
torch.tensor([1, 2, 0]), torch.tensor([0, 1, 0]), topk=(1,)
)
assert torch.isclose(in_order, first_last)
assert torch.isclose(first_last, first_middle)

0 comments on commit fa1c898

Please sign in to comment.