From 948bbb78f5cdf1460d69d119eca80d760060f3fe Mon Sep 17 00:00:00 2001 From: Cengiz Oztireli Date: Wed, 15 Dec 2021 10:46:17 -0800 Subject: [PATCH] Adds typing information to the nn.metric package. PiperOrigin-RevId: 416596386 --- tensorflow_graphics/nn/metric/fscore.py | 13 ++++++++----- .../nn/metric/intersection_over_union.py | 9 +++++---- tensorflow_graphics/nn/metric/precision.py | 15 +++++++++------ tensorflow_graphics/nn/metric/recall.py | 15 +++++++++------ 4 files changed, 31 insertions(+), 21 deletions(-) diff --git a/tensorflow_graphics/nn/metric/fscore.py b/tensorflow_graphics/nn/metric/fscore.py index d979eee3e..8bb696c39 100644 --- a/tensorflow_graphics/nn/metric/fscore.py +++ b/tensorflow_graphics/nn/metric/fscore.py @@ -17,6 +17,8 @@ from __future__ import division from __future__ import print_function +from typing import Any, Callable + import tensorflow as tf from tensorflow_graphics.nn.metric import precision as precision_module @@ -24,13 +26,14 @@ from tensorflow_graphics.util import export_api from tensorflow_graphics.util import safe_ops from tensorflow_graphics.util import shape +from tensorflow_graphics.util import type_alias -def evaluate(ground_truth, - prediction, - precision_function=precision_module.evaluate, - recall_function=recall_module.evaluate, - name="fscore_evaluate"): +def evaluate(ground_truth: type_alias.TensorLike, + prediction: type_alias.TensorLike, + precision_function: Callable[..., Any] = precision_module.evaluate, + recall_function: Callable[..., Any] = recall_module.evaluate, + name: str = "fscore_evaluate") -> tf.Tensor: """Computes the fscore metric for the given ground truth and predicted labels. The fscore is calculated as 2 * (precision * recall) / (precision + recall) diff --git a/tensorflow_graphics/nn/metric/intersection_over_union.py b/tensorflow_graphics/nn/metric/intersection_over_union.py index 5b1896649..1fe718057 100644 --- a/tensorflow_graphics/nn/metric/intersection_over_union.py +++ b/tensorflow_graphics/nn/metric/intersection_over_union.py @@ -23,12 +23,13 @@ from tensorflow_graphics.util import asserts from tensorflow_graphics.util import export_api from tensorflow_graphics.util import shape +from tensorflow_graphics.util import type_alias -def evaluate(ground_truth_labels, - predicted_labels, - grid_size=1, - name="intersection_over_union_evaluate"): +def evaluate(ground_truth_labels: type_alias.TensorLike, + predicted_labels: type_alias.TensorLike, + grid_size: int = 1, + name: str = "intersection_over_union_evaluate") -> tf.Tensor: """Computes the Intersection-Over-Union metric for the given ground truth and predicted labels. Note: diff --git a/tensorflow_graphics/nn/metric/precision.py b/tensorflow_graphics/nn/metric/precision.py index 913801456..814a82fb1 100644 --- a/tensorflow_graphics/nn/metric/precision.py +++ b/tensorflow_graphics/nn/metric/precision.py @@ -17,23 +17,26 @@ from __future__ import division from __future__ import print_function +from typing import Any, Callable, List, Optional, Union, Tuple + import tensorflow as tf from tensorflow_graphics.util import export_api from tensorflow_graphics.util import safe_ops from tensorflow_graphics.util import shape +from tensorflow_graphics.util import type_alias def _cast_to_int(prediction): return tf.cast(x=prediction, dtype=tf.int32) -def evaluate(ground_truth, - prediction, - classes=None, - reduce_average=True, - prediction_to_category_function=_cast_to_int, - name="precision_evaluate"): +def evaluate(ground_truth: type_alias.TensorLike, + prediction: type_alias.TensorLike, + classes: Optional[Union[int, List[int], Tuple[int]]] = None, + reduce_average: bool = True, + prediction_to_category_function: Callable[..., Any] = _cast_to_int, + name: str = "precision_evaluate") -> tf.Tensor: """Computes the precision metric for the given ground truth and predictions. Note: diff --git a/tensorflow_graphics/nn/metric/recall.py b/tensorflow_graphics/nn/metric/recall.py index 250f9e544..7e82be870 100644 --- a/tensorflow_graphics/nn/metric/recall.py +++ b/tensorflow_graphics/nn/metric/recall.py @@ -17,23 +17,26 @@ from __future__ import division from __future__ import print_function +from typing import Any, Callable, List, Optional, Tuple, Union + import tensorflow as tf from tensorflow_graphics.util import export_api from tensorflow_graphics.util import safe_ops from tensorflow_graphics.util import shape +from tensorflow_graphics.util import type_alias def _cast_to_int(prediction): return tf.cast(x=prediction, dtype=tf.int32) -def evaluate(ground_truth, - prediction, - classes=None, - reduce_average=True, - prediction_to_category_function=_cast_to_int, - name="recall_evaluate"): +def evaluate(ground_truth: type_alias.TensorLike, + prediction: type_alias.TensorLike, + classes: Optional[Union[int, List[int], Tuple[int]]] = None, + reduce_average: bool = True, + prediction_to_category_function: Callable[..., Any] = _cast_to_int, + name: str = "recall_evaluate") -> tf.Tensor: """Computes the recall metric for the given ground truth and predictions. Note: