Skip to content

Commit

Permalink
⚡ adding class based greedy accuracy metric
Browse files Browse the repository at this point in the history
  • Loading branch information
rbturnbull committed Apr 29, 2024
1 parent 6fa945e commit b9f6796
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 0 deletions.
16 changes: 16 additions & 0 deletions hierarchicalsoftmax/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,19 @@ def greedy_accuracy_parent(prediction_tensor, target_tensor, root, max_depth=Non
return (prediction_parent_ids.to(target_parent_ids.device) == target_parent_ids).float().mean()


class GreedyAccuracy():
name:str = "greedy"

def __init__(self, root:nodes.SoftmaxNode, name="greedy_accuracy", max_depth=None):
self.max_depth = max_depth
self.name = name
self.root = root

@property
def __name__(self):
""" For using as a FastAI metric. """
return self.name

def __call__(self, predictions, targets):
return greedy_accuracy(predictions, targets, self.root, max_depth=self.max_depth)

16 changes: 16 additions & 0 deletions tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
greedy_accuracy_parent,
greedy_precision,
greedy_recall,
GreedyAccuracy,
)
from torch.testing import assert_allclose

Expand All @@ -33,6 +34,9 @@ def test_greedy_accuracy():

assert_allclose(greedy_accuracy(predictions, target_tensor, root=root), 0.75)

metric = GreedyAccuracy(root=root)
assert_allclose(metric(predictions, target_tensor), 0.75)


def test_greedy_f1_score():
root, targets = depth_two_tree_and_targets_three_children()
Expand Down Expand Up @@ -119,6 +123,18 @@ def test_greedy_accuracy_max_depth_simple():
assert greedy_accuracy_depth_two(predictions_rearranged, target_tensor, root=root) < 0.01
assert greedy_accuracy(predictions_rearranged, target_tensor, root=root) < 0.01

depth_one = GreedyAccuracy(root=root, max_depth=1, name="depth_one")
assert 0.99 < depth_one(predictions_rearranged, target_tensor)
depth_two = GreedyAccuracy(root=root, max_depth=2, name="depth_two")
assert depth_two(predictions_rearranged, target_tensor) < 0.01

assert depth_one.name == "depth_one"
assert depth_one.__name__ == "depth_one"
assert depth_two.name == "depth_two"
assert depth_two.__name__ == "depth_two"




def test_greedy_accuracy_max_depth_complex():
root, targets = depth_three_tree_and_targets()
Expand Down

0 comments on commit b9f6796

Please sign in to comment.