diff --git a/tensorflow_graphics/nn/layer/pointnet.py b/tensorflow_graphics/nn/layer/pointnet.py index 30e8ff5f1..ca9a2ea81 100644 --- a/tensorflow_graphics/nn/layer/pointnet.py +++ b/tensorflow_graphics/nn/layer/pointnet.py @@ -38,6 +38,8 @@ `C`: The number of feature channels. """ +from typing import Optional + import tensorflow as tf from tensorflow_graphics.util import export_api @@ -62,13 +64,15 @@ def __init__(self, channels, momentum): self.channels = channels self.momentum = momentum - def build(self, input_shape): + def build(self, input_shape: tf.Tensor): """Builds the layer with a specified input_shape.""" self.conv = tf.keras.layers.Conv2D( self.channels, (1, 1), input_shape=input_shape) self.bn = tf.keras.layers.BatchNormalization(momentum=self.momentum) - def call(self, inputs, training=None): # pylint: disable=arguments-differ + def call(self, + inputs: tf.Tensor, + training: Optional[bool] = None) -> tf.Tensor: # pylint: disable=arguments-differ """Executes the convolution. Args: @@ -96,12 +100,14 @@ def __init__(self, channels, momentum): self.momentum = momentum self.channels = channels - def build(self, input_shape): + def build(self, input_shape: tf.Tensor): """Builds the layer with a specified input_shape.""" self.dense = tf.keras.layers.Dense(self.channels, input_shape=input_shape) self.bn = tf.keras.layers.BatchNormalization(momentum=self.momentum) - def call(self, inputs, training=None): # pylint: disable=arguments-differ + def call(self, + inputs: tf.Tensor, + training: Optional[bool] = None) -> tf.Tensor: # pylint: disable=arguments-differ """Executes the convolution. Args: @@ -125,7 +131,7 @@ class VanillaEncoder(tf.keras.layers.Layer): https://github.com/charlesq34/pointnet/blob/master/models/pointnet_cls_basic.py """ - def __init__(self, momentum=.5): + def __init__(self, momentum: float = .5): """Constructs a VanillaEncoder keras layer. Args: @@ -138,7 +144,9 @@ def __init__(self, momentum=.5): self.conv4 = PointNetConv2Layer(128, momentum) self.conv5 = PointNetConv2Layer(1024, momentum) - def call(self, inputs, training=None): # pylint: disable=arguments-differ + def call(self, + inputs: tf.Tensor, + training: Optional[bool] = None) -> tf.Tensor: # pylint: disable=arguments-differ """Computes the PointNet features. Args: @@ -166,7 +174,10 @@ class ClassificationHead(tf.keras.layers.Layer): logits of the num_classes classes. """ - def __init__(self, num_classes=40, momentum=0.5, dropout_rate=0.3): + def __init__(self, + num_classes: int = 40, + momentum: float = 0.5, + dropout_rate: float = 0.3): """Constructor. Args: @@ -180,7 +191,9 @@ def __init__(self, num_classes=40, momentum=0.5, dropout_rate=0.3): self.dropout = tf.keras.layers.Dropout(dropout_rate) self.dense3 = tf.keras.layers.Dense(num_classes, activation="linear") - def call(self, inputs, training=None): # pylint: disable=arguments-differ + def call(self, + inputs: tf.Tensor, + training: Optional[bool] = None) -> tf.Tensor: # pylint: disable=arguments-differ """Computes the classifiation logits given features (note: without softmax). Args: @@ -199,7 +212,10 @@ def call(self, inputs, training=None): # pylint: disable=arguments-differ class PointNetVanillaClassifier(tf.keras.layers.Layer): """The PointNet 'Vanilla' classifier (i.e. without spatial transformer).""" - def __init__(self, num_classes=40, momentum=.5, dropout_rate=.3): + def __init__(self, + num_classes: int = 40, + momentum: float = .5, + dropout_rate: float = .3): """Constructor. Args: @@ -212,7 +228,9 @@ def __init__(self, num_classes=40, momentum=.5, dropout_rate=.3): self.classifier = ClassificationHead( num_classes=num_classes, momentum=momentum, dropout_rate=dropout_rate) - def call(self, points, training=None): # pylint: disable=arguments-differ + def call(self, + points: tf.Tensor, + training: Optional[bool] = None) -> tf.Tensor: # pylint: disable=arguments-differ """Computes the classifiation logits of a point set. Args: @@ -227,7 +245,8 @@ def call(self, points, training=None): # pylint: disable=arguments-differ return logits @staticmethod - def loss(labels, logits): + def loss(labels: tf.Tensor, + logits: tf.Tensor) -> tf.Tensor: """The classification model training loss. Note: @@ -236,6 +255,9 @@ def loss(labels, logits): Args: labels: a tensor with shape `[B,]` logits: a tensor with shape `[B,num_classes]` + + Returns: + A tensor with the same shape as labels and of the same type as logits. """ cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits residual = cross_entropy(labels, logits) diff --git a/tensorflow_graphics/nn/loss/chamfer_distance.py b/tensorflow_graphics/nn/loss/chamfer_distance.py index 61c733f1f..e22e40425 100644 --- a/tensorflow_graphics/nn/loss/chamfer_distance.py +++ b/tensorflow_graphics/nn/loss/chamfer_distance.py @@ -21,9 +21,12 @@ from tensorflow_graphics.util import export_api from tensorflow_graphics.util import shape +from tensorflow_graphics.util import type_alias -def evaluate(point_set_a, point_set_b, name="chamfer_distance_evaluate"): +def evaluate(point_set_a: type_alias.TensorLike, + point_set_b: type_alias.TensorLike, + name: str = "chamfer_distance_evaluate") -> tf.Tensor: """Computes the Chamfer distance for the given two point sets. Note: diff --git a/tensorflow_graphics/nn/loss/hausdorff_distance.py b/tensorflow_graphics/nn/loss/hausdorff_distance.py index 46fb0946f..279f1bbbd 100644 --- a/tensorflow_graphics/nn/loss/hausdorff_distance.py +++ b/tensorflow_graphics/nn/loss/hausdorff_distance.py @@ -21,9 +21,12 @@ from tensorflow_graphics.util import export_api from tensorflow_graphics.util import shape +from tensorflow_graphics.util import type_alias -def evaluate(point_set_a, point_set_b, name="hausdorff_distance_evaluate"): +def evaluate(point_set_a: type_alias.TensorLike, + point_set_b: type_alias.TensorLike, + name: str = "hausdorff_distance_evaluate") -> tf.Tensor: """Computes the Hausdorff distance from point_set_a to point_set_b. Note: diff --git a/tensorflow_graphics/nn/metric/fscore.py b/tensorflow_graphics/nn/metric/fscore.py index d979eee3e..0e40b3e23 100644 --- a/tensorflow_graphics/nn/metric/fscore.py +++ b/tensorflow_graphics/nn/metric/fscore.py @@ -17,6 +17,8 @@ from __future__ import division from __future__ import print_function +from typing import Any, Callable + import tensorflow as tf from tensorflow_graphics.nn.metric import precision as precision_module @@ -24,13 +26,14 @@ from tensorflow_graphics.util import export_api from tensorflow_graphics.util import safe_ops from tensorflow_graphics.util import shape +from tensorflow_graphics.util import type_alias -def evaluate(ground_truth, - prediction, - precision_function=precision_module.evaluate, - recall_function=recall_module.evaluate, - name="fscore_evaluate"): +def evaluate(ground_truth: type_alias.TensorLike, + prediction: type_alias.TensorLike, + precision_function: Callable[..., Any] = precision_module.evaluate, + recall_function: Callable[..., Any] = recall_module.evaluate, + name: str = "fscore_evaluate"): """Computes the fscore metric for the given ground truth and predicted labels. The fscore is calculated as 2 * (precision * recall) / (precision + recall) diff --git a/tensorflow_graphics/nn/metric/intersection_over_union.py b/tensorflow_graphics/nn/metric/intersection_over_union.py index 5b1896649..1105af64e 100644 --- a/tensorflow_graphics/nn/metric/intersection_over_union.py +++ b/tensorflow_graphics/nn/metric/intersection_over_union.py @@ -23,12 +23,13 @@ from tensorflow_graphics.util import asserts from tensorflow_graphics.util import export_api from tensorflow_graphics.util import shape +from tensorflow_graphics.util import type_alias -def evaluate(ground_truth_labels, - predicted_labels, - grid_size=1, - name="intersection_over_union_evaluate"): +def evaluate(ground_truth_labels: type_alias.TensorLike, + predicted_labels: type_alias.TensorLike, + grid_size: int = 1, + name: str = "intersection_over_union_evaluate"): """Computes the Intersection-Over-Union metric for the given ground truth and predicted labels. Note: diff --git a/tensorflow_graphics/nn/metric/precision.py b/tensorflow_graphics/nn/metric/precision.py index 913801456..a8c7591bf 100644 --- a/tensorflow_graphics/nn/metric/precision.py +++ b/tensorflow_graphics/nn/metric/precision.py @@ -17,23 +17,26 @@ from __future__ import division from __future__ import print_function +from typing import Any, Callable, List, Optional, Union, Tuple + import tensorflow as tf from tensorflow_graphics.util import export_api from tensorflow_graphics.util import safe_ops from tensorflow_graphics.util import shape +from tensorflow_graphics.util import type_alias def _cast_to_int(prediction): return tf.cast(x=prediction, dtype=tf.int32) -def evaluate(ground_truth, - prediction, - classes=None, - reduce_average=True, - prediction_to_category_function=_cast_to_int, - name="precision_evaluate"): +def evaluate(ground_truth: type_alias.TensorLike, + prediction: type_alias.TensorLike, + classes: Optional[Union[int, List[int], Tuple[int]]] = None, + reduce_average: bool = True, + prediction_to_category_function: Callable[..., Any] = _cast_to_int, + name: str = "precision_evaluate"): """Computes the precision metric for the given ground truth and predictions. Note: diff --git a/tensorflow_graphics/nn/metric/recall.py b/tensorflow_graphics/nn/metric/recall.py index 250f9e544..283cfd844 100644 --- a/tensorflow_graphics/nn/metric/recall.py +++ b/tensorflow_graphics/nn/metric/recall.py @@ -17,23 +17,26 @@ from __future__ import division from __future__ import print_function +from typing import Any, Callable, List, Optional, Tuple, Union + import tensorflow as tf from tensorflow_graphics.util import export_api from tensorflow_graphics.util import safe_ops from tensorflow_graphics.util import shape +from tensorflow_graphics.util import type_alias def _cast_to_int(prediction): return tf.cast(x=prediction, dtype=tf.int32) -def evaluate(ground_truth, - prediction, - classes=None, - reduce_average=True, - prediction_to_category_function=_cast_to_int, - name="recall_evaluate"): +def evaluate(ground_truth: type_alias.TensorLike, + prediction: type_alias.TensorLike, + classes: Optional[Union[int, List[int], Tuple[int]]] = None, + reduce_average: bool = True, + prediction_to_category_function: Callable[..., Any] = _cast_to_int, + name: str = "recall_evaluate"): """Computes the recall metric for the given ground truth and predictions. Note: