From b9f679659720f9f1c074d5f09de0e77904db5d6c Mon Sep 17 00:00:00 2001 From: Robert Turnbull Date: Mon, 29 Apr 2024 15:36:46 +1000 Subject: [PATCH] :zap: adding class based greedy accuracy metric --- hierarchicalsoftmax/metrics.py | 16 ++++++++++++++++ tests/test_metrics.py | 16 ++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/hierarchicalsoftmax/metrics.py b/hierarchicalsoftmax/metrics.py index 4b9ec79..8e1689f 100644 --- a/hierarchicalsoftmax/metrics.py +++ b/hierarchicalsoftmax/metrics.py @@ -137,3 +137,19 @@ def greedy_accuracy_parent(prediction_tensor, target_tensor, root, max_depth=Non return (prediction_parent_ids.to(target_parent_ids.device) == target_parent_ids).float().mean() +class GreedyAccuracy(): + name:str = "greedy" + + def __init__(self, root:nodes.SoftmaxNode, name="greedy_accuracy", max_depth=None): + self.max_depth = max_depth + self.name = name + self.root = root + + @property + def __name__(self): + """ For using as a FastAI metric. """ + return self.name + + def __call__(self, predictions, targets): + return greedy_accuracy(predictions, targets, self.root, max_depth=self.max_depth) + diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 5ff7ab7..6bc9c01 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -8,6 +8,7 @@ greedy_accuracy_parent, greedy_precision, greedy_recall, + GreedyAccuracy, ) from torch.testing import assert_allclose @@ -33,6 +34,9 @@ def test_greedy_accuracy(): assert_allclose(greedy_accuracy(predictions, target_tensor, root=root), 0.75) + metric = GreedyAccuracy(root=root) + assert_allclose(metric(predictions, target_tensor), 0.75) + def test_greedy_f1_score(): root, targets = depth_two_tree_and_targets_three_children() @@ -119,6 +123,18 @@ def test_greedy_accuracy_max_depth_simple(): assert greedy_accuracy_depth_two(predictions_rearranged, target_tensor, root=root) < 0.01 assert greedy_accuracy(predictions_rearranged, target_tensor, root=root) < 0.01 + depth_one = GreedyAccuracy(root=root, max_depth=1, name="depth_one") + assert 0.99 < depth_one(predictions_rearranged, target_tensor) + depth_two = GreedyAccuracy(root=root, max_depth=2, name="depth_two") + assert depth_two(predictions_rearranged, target_tensor) < 0.01 + + assert depth_one.name == "depth_one" + assert depth_one.__name__ == "depth_one" + assert depth_two.name == "depth_two" + assert depth_two.__name__ == "depth_two" + + + def test_greedy_accuracy_max_depth_complex(): root, targets = depth_three_tree_and_targets()