Skip to content

Commit

Permalink
⚡ adding metric to test the accuracy of getting the correct parent
Browse files Browse the repository at this point in the history
  • Loading branch information
rbturnbull committed Jun 9, 2023
1 parent b880d77 commit 8a59f05
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 3 deletions.
29 changes: 27 additions & 2 deletions hierarchicalsoftmax/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ def greedy_accuracy(prediction_tensor, target_tensor, root, max_depth=None):
prediction_node_ids = root.get_node_ids_tensor(prediction_nodes)

if max_depth:
target_node_max_depths = [root.node_list[target].path[:max_depth+1][-1] for target in target_tensor]
target_tensor = root.get_node_ids_tensor(target_node_max_depths)
max_depth_target_nodes = [root.node_list[target].path[:max_depth+1][-1] for target in target_tensor]
target_tensor = root.get_node_ids_tensor(max_depth_target_nodes)

return (prediction_node_ids.to(target_tensor.device) == target_tensor).float().mean()

Expand Down Expand Up @@ -58,3 +58,28 @@ def greedy_f1_score(prediction_tensor:torch.Tensor, target_tensor:torch.Tensor,

return f1_score(target_tensor.cpu(), prediction_node_ids.cpu(), average=average)


def greedy_accuracy_parent(prediction_tensor, target_tensor, root, max_depth=None):
"""
Gives the accuracy of predicting the parent of the target in a hierarchy tree.
Predictions use the `greedy` method which means that it chooses the greatest prediction score at each level of the tree.
Args:
prediction_tensor (torch.Tensor): A tensor with the raw scores for each node in the tree. Shape: (samples, root.layer_size)
target_tensor (torch.Tensor): A tensor with the target node indexes. Shape: (samples,).
root (SoftmaxNode): The root of the hierarchy tree.
Returns:
float: The accuracy value (i.e. the number that are correct divided by the total number of samples)
"""
prediction_nodes = inference.greedy_predictions(prediction_tensor=prediction_tensor, root=root, max_depth=max_depth)
prediction_parents = [node.parent for node in prediction_nodes]
prediction_parent_ids = root.get_node_ids_tensor(prediction_parents)

target_parents = [root.node_list[target].parent for target in target_tensor]
target_parent_ids = root.get_node_ids_tensor(target_parents)

return (prediction_parent_ids.to(target_parent_ids.device) == target_parent_ids).float().mean()


29 changes: 28 additions & 1 deletion tests/test_metrics.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
import torch
from hierarchicalsoftmax.nodes import SoftmaxNode
from hierarchicalsoftmax.metrics import greedy_accuracy, greedy_f1_score, greedy_accuracy_depth_one, greedy_accuracy_depth_two
from hierarchicalsoftmax.metrics import (
greedy_accuracy,
greedy_f1_score,
greedy_accuracy_depth_one,
greedy_accuracy_depth_two,
greedy_accuracy_parent,
)

from .util import depth_two_tree_and_targets, depth_three_tree_and_targets

Expand Down Expand Up @@ -97,3 +103,24 @@ def test_greedy_accuracy_max_depth_complex():
assert greedy_accuracy_depth_two(predictions_rearranged, target_tensor, root=root) < 0.01
assert greedy_accuracy(predictions_rearranged, target_tensor, root=root) < 0.01


def test_greedy_accuracy_parent():
root, targets = depth_three_tree_and_targets()

root.set_indexes()
target_tensor = root.get_node_ids_tensor(targets)

# set up predictions
prediction_nodes = targets.copy()
aaa, aab, aba, abb, baa, bab, bba, bbb = targets
prediction_nodes[0] = aab # correct parent
prediction_nodes[7] = bba # correct parent
prediction_nodes[1] = aba # incorrect parent

predictions = torch.zeros( (len(prediction_nodes), root.layer_size) )
for prediction_index, prediction in enumerate(prediction_nodes):
while prediction.parent:
predictions[ prediction_index, prediction.parent.softmax_start_index + prediction.index_in_parent ] = 20.0
prediction = prediction.parent

assert 0.874 < greedy_accuracy_parent(predictions, target_tensor, root=root) < 0.876

0 comments on commit 8a59f05

Please sign in to comment.