Skip to content

Commit

Permalink
Adds typing information to the nn.metric package.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 416596386
  • Loading branch information
G4G authored and Copybara-Service committed Jan 6, 2022
1 parent 3b030f8 commit 948bbb7
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 21 deletions.
13 changes: 8 additions & 5 deletions tensorflow_graphics/nn/metric/fscore.py
Expand Up @@ -17,20 +17,23 @@
from __future__ import division
from __future__ import print_function

from typing import Any, Callable

import tensorflow as tf

from tensorflow_graphics.nn.metric import precision as precision_module
from tensorflow_graphics.nn.metric import recall as recall_module
from tensorflow_graphics.util import export_api
from tensorflow_graphics.util import safe_ops
from tensorflow_graphics.util import shape
from tensorflow_graphics.util import type_alias


def evaluate(ground_truth,
prediction,
precision_function=precision_module.evaluate,
recall_function=recall_module.evaluate,
name="fscore_evaluate"):
def evaluate(ground_truth: type_alias.TensorLike,
prediction: type_alias.TensorLike,
precision_function: Callable[..., Any] = precision_module.evaluate,
recall_function: Callable[..., Any] = recall_module.evaluate,
name: str = "fscore_evaluate") -> tf.Tensor:
"""Computes the fscore metric for the given ground truth and predicted labels.
The fscore is calculated as 2 * (precision * recall) / (precision + recall)
Expand Down
9 changes: 5 additions & 4 deletions tensorflow_graphics/nn/metric/intersection_over_union.py
Expand Up @@ -23,12 +23,13 @@
from tensorflow_graphics.util import asserts
from tensorflow_graphics.util import export_api
from tensorflow_graphics.util import shape
from tensorflow_graphics.util import type_alias


def evaluate(ground_truth_labels,
predicted_labels,
grid_size=1,
name="intersection_over_union_evaluate"):
def evaluate(ground_truth_labels: type_alias.TensorLike,
predicted_labels: type_alias.TensorLike,
grid_size: int = 1,
name: str = "intersection_over_union_evaluate") -> tf.Tensor:
"""Computes the Intersection-Over-Union metric for the given ground truth and predicted labels.
Note:
Expand Down
15 changes: 9 additions & 6 deletions tensorflow_graphics/nn/metric/precision.py
Expand Up @@ -17,23 +17,26 @@
from __future__ import division
from __future__ import print_function

from typing import Any, Callable, List, Optional, Union, Tuple

import tensorflow as tf

from tensorflow_graphics.util import export_api
from tensorflow_graphics.util import safe_ops
from tensorflow_graphics.util import shape
from tensorflow_graphics.util import type_alias


def _cast_to_int(prediction: type_alias.TensorLike) -> tf.Tensor:
  """Casts `prediction` to an int32 tensor.

  Serves as the default `prediction_to_category_function` for `evaluate`,
  mapping (possibly continuous) prediction values to integer category
  labels via `tf.cast` truncation.

  Args:
    prediction: A tensor-like object with numeric values.

  Returns:
    A `tf.Tensor` of dtype `tf.int32` with the same shape as `prediction`.
  """
  return tf.cast(x=prediction, dtype=tf.int32)


def evaluate(ground_truth,
prediction,
classes=None,
reduce_average=True,
prediction_to_category_function=_cast_to_int,
name="precision_evaluate"):
def evaluate(ground_truth: type_alias.TensorLike,
prediction: type_alias.TensorLike,
classes: Optional[Union[int, List[int], Tuple[int]]] = None,
reduce_average: bool = True,
prediction_to_category_function: Callable[..., Any] = _cast_to_int,
name: str = "precision_evaluate") -> tf.Tensor:
"""Computes the precision metric for the given ground truth and predictions.
Note:
Expand Down
15 changes: 9 additions & 6 deletions tensorflow_graphics/nn/metric/recall.py
Expand Up @@ -17,23 +17,26 @@
from __future__ import division
from __future__ import print_function

from typing import Any, Callable, List, Optional, Tuple, Union

import tensorflow as tf

from tensorflow_graphics.util import export_api
from tensorflow_graphics.util import safe_ops
from tensorflow_graphics.util import shape
from tensorflow_graphics.util import type_alias


def _cast_to_int(prediction: type_alias.TensorLike) -> tf.Tensor:
  """Casts `prediction` to an int32 tensor.

  Serves as the default `prediction_to_category_function` for `evaluate`,
  mapping (possibly continuous) prediction values to integer category
  labels via `tf.cast` truncation.

  Args:
    prediction: A tensor-like object with numeric values.

  Returns:
    A `tf.Tensor` of dtype `tf.int32` with the same shape as `prediction`.
  """
  return tf.cast(x=prediction, dtype=tf.int32)


def evaluate(ground_truth,
prediction,
classes=None,
reduce_average=True,
prediction_to_category_function=_cast_to_int,
name="recall_evaluate"):
def evaluate(ground_truth: type_alias.TensorLike,
prediction: type_alias.TensorLike,
classes: Optional[Union[int, List[int], Tuple[int]]] = None,
reduce_average: bool = True,
prediction_to_category_function: Callable[..., Any] = _cast_to_int,
name: str = "recall_evaluate") -> tf.Tensor:
"""Computes the recall metric for the given ground truth and predictions.
Note:
Expand Down

0 comments on commit 948bbb7

Please sign in to comment.