From fa1c898803c693993b3f9f74abf904ab50f33902 Mon Sep 17 00:00:00 2001 From: Alexey Grishanov Date: Thu, 6 Aug 2020 19:06:33 +0300 Subject: [PATCH] Update ndcg tests --- catalyst/utils/metrics/tests/test_ndcg.py | 57 +++++++++-------------- 1 file changed, 22 insertions(+), 35 deletions(-) diff --git a/catalyst/utils/metrics/tests/test_ndcg.py b/catalyst/utils/metrics/tests/test_ndcg.py index ced4840059..0f6313fd83 100644 --- a/catalyst/utils/metrics/tests/test_ndcg.py +++ b/catalyst/utils/metrics/tests/test_ndcg.py @@ -1,47 +1,34 @@ -import numpy as np - import torch from catalyst.utils import metrics -def test_ndcg(): +def test_zero_ndcg(): """ Tests for catalyst.utils.metrics.ndcg metric. """ - # check 0: common values - assert ( - metrics.ndcg(torch.tensor([2, 1, 0]), torch.tensor([1, 0, 0])) == 1.0 - ) - assert np.allclose( - metrics.ndcg(torch.tensor([2, 1, 0]), torch.tensor([0, 1, 0])), 0.63093 - ) - assert np.allclose( - metrics.ndcg(torch.tensor([2, 1, 0]), torch.tensor([0, 0, 1])), 0.5 - ) - assert np.allclose( - metrics.ndcg(torch.tensor([2, 1, 0]), torch.tensor([1, 0, 1])), 0.91972 - ) - assert np.allclose( - metrics.ndcg(torch.tensor([2, 1, 0]), torch.tensor([1, 0, 1]), k=2), - 0.61315, - ) + ndcg_at1, ndcg_at3, ndcg_at7 = metrics.ndcg( + torch.tensor([6, 5, 4, 3, 2, 1, 0]), + torch.tensor([0, 0, 0, 0, 0, 0, 1]), + topk=(1, 3, 7), + ) + assert torch.isclose(ndcg_at1, torch.tensor(0.0)) + assert torch.isclose(ndcg_at3, torch.tensor(0.0)) + assert torch.isclose(ndcg_at7, torch.tensor(3.0)) - # check 1: ordering invariance - assert np.allclose( - metrics.ndcg(torch.tensor([2, 1, 0]), torch.tensor([0, 0, 1])), - metrics.ndcg(torch.tensor([0, 1, 2]), torch.tensor([1, 0, 0])), - ) - assert np.allclose( - metrics.ndcg(torch.tensor([2, 1, 0]), torch.tensor([0, 0, 1])), - metrics.ndcg(torch.tensor([2, 0, 1]), torch.tensor([0, 1, 0])), - ) - # check2: zero ndcg - assert ( - metrics.ndcg(torch.tensor([2, 1, 0]), torch.tensor([0, 0, 0])) == 0.0 +def test_ndcg_ordering_invariance(): + """ + Tests for catalyst.utils.metrics.ndcg metric. + """ + [in_order] = metrics.ndcg( + torch.tensor([2, 1, 0]), torch.tensor([1, 0, 0]), topk=(1,) + ) + [first_last] = metrics.ndcg( + torch.tensor([0, 1, 2]), torch.tensor([0, 0, 1]), topk=(1,) ) - assert ( - metrics.ndcg(torch.tensor([2, 1, 0]), torch.tensor([0, 0, 1]), k=2) - == 0.0 + [first_middle] = metrics.ndcg( + torch.tensor([1, 2, 0]), torch.tensor([0, 1, 0]), topk=(1,) ) + assert torch.isclose(in_order, first_last) + assert torch.isclose(first_last, first_middle)