Skip to content

Commit

Permalink
Adds typing information to the nn.metric package.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 416596386
  • Loading branch information
G4G authored and Copybara-Service committed Dec 15, 2021
1 parent 75fe1a1 commit b8400e1
Show file tree
Hide file tree
Showing 7 changed files with 72 additions and 34 deletions.
44 changes: 33 additions & 11 deletions tensorflow_graphics/nn/layer/pointnet.py
Expand Up @@ -38,6 +38,8 @@
`C`: The number of feature channels.
"""

from typing import Optional

import tensorflow as tf
from tensorflow_graphics.util import export_api

Expand All @@ -62,13 +64,15 @@ def __init__(self, channels, momentum):
self.channels = channels
self.momentum = momentum

def build(self, input_shape):
def build(self, input_shape: tf.Tensor):
"""Builds the layer with a specified input_shape."""
self.conv = tf.keras.layers.Conv2D(
self.channels, (1, 1), input_shape=input_shape)
self.bn = tf.keras.layers.BatchNormalization(momentum=self.momentum)

def call(self, inputs, training=None): # pylint: disable=arguments-differ
def call(self,
inputs: tf.Tensor,
training: Optional[bool] = None) -> tf.Tensor: # pylint: disable=arguments-differ
"""Executes the convolution.
Args:
Expand Down Expand Up @@ -96,12 +100,14 @@ def __init__(self, channels, momentum):
self.momentum = momentum
self.channels = channels

def build(self, input_shape):
def build(self, input_shape: tf.Tensor):
"""Builds the layer with a specified input_shape."""
self.dense = tf.keras.layers.Dense(self.channels, input_shape=input_shape)
self.bn = tf.keras.layers.BatchNormalization(momentum=self.momentum)

def call(self, inputs, training=None): # pylint: disable=arguments-differ
def call(self,
inputs: tf.Tensor,
training: Optional[bool] = None) -> tf.Tensor: # pylint: disable=arguments-differ
"""Executes the convolution.
Args:
Expand All @@ -125,7 +131,7 @@ class VanillaEncoder(tf.keras.layers.Layer):
https://github.com/charlesq34/pointnet/blob/master/models/pointnet_cls_basic.py
"""

def __init__(self, momentum=.5):
def __init__(self, momentum: float = .5):
"""Constructs a VanillaEncoder keras layer.
Args:
Expand All @@ -138,7 +144,9 @@ def __init__(self, momentum=.5):
self.conv4 = PointNetConv2Layer(128, momentum)
self.conv5 = PointNetConv2Layer(1024, momentum)

def call(self, inputs, training=None): # pylint: disable=arguments-differ
def call(self,
inputs: tf.Tensor,
training: Optional[bool] = None) -> tf.Tensor: # pylint: disable=arguments-differ
"""Computes the PointNet features.
Args:
Expand Down Expand Up @@ -166,7 +174,10 @@ class ClassificationHead(tf.keras.layers.Layer):
logits of the num_classes classes.
"""

def __init__(self, num_classes=40, momentum=0.5, dropout_rate=0.3):
def __init__(self,
num_classes: int = 40,
momentum: float = 0.5,
dropout_rate: float = 0.3):
"""Constructor.
Args:
Expand All @@ -180,7 +191,9 @@ def __init__(self, num_classes=40, momentum=0.5, dropout_rate=0.3):
self.dropout = tf.keras.layers.Dropout(dropout_rate)
self.dense3 = tf.keras.layers.Dense(num_classes, activation="linear")

def call(self, inputs, training=None): # pylint: disable=arguments-differ
def call(self,
inputs: tf.Tensor,
training: Optional[bool] = None) -> tf.Tensor: # pylint: disable=arguments-differ
"""Computes the classifiation logits given features (note: without softmax).
Args:
Expand All @@ -199,7 +212,10 @@ def call(self, inputs, training=None): # pylint: disable=arguments-differ
class PointNetVanillaClassifier(tf.keras.layers.Layer):
"""The PointNet 'Vanilla' classifier (i.e. without spatial transformer)."""

def __init__(self, num_classes=40, momentum=.5, dropout_rate=.3):
def __init__(self,
num_classes: int = 40,
momentum: float = .5,
dropout_rate: float = .3):
"""Constructor.
Args:
Expand All @@ -212,7 +228,9 @@ def __init__(self, num_classes=40, momentum=.5, dropout_rate=.3):
self.classifier = ClassificationHead(
num_classes=num_classes, momentum=momentum, dropout_rate=dropout_rate)

def call(self, points, training=None): # pylint: disable=arguments-differ
def call(self,
points: tf.Tensor,
training: Optional[bool] = None) -> tf.Tensor: # pylint: disable=arguments-differ
"""Computes the classifiation logits of a point set.
Args:
Expand All @@ -227,7 +245,8 @@ def call(self, points, training=None): # pylint: disable=arguments-differ
return logits

@staticmethod
def loss(labels, logits):
def loss(labels: tf.Tensor,
logits: tf.Tensor) -> tf.Tensor:
"""The classification model training loss.
Note:
Expand All @@ -236,6 +255,9 @@ def loss(labels, logits):
Args:
labels: a tensor with shape `[B,]`
logits: a tensor with shape `[B,num_classes]`
Returns:
A tensor with the same shape as labels and of the same type as logits.
"""
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits
residual = cross_entropy(labels, logits)
Expand Down
5 changes: 4 additions & 1 deletion tensorflow_graphics/nn/loss/chamfer_distance.py
Expand Up @@ -21,9 +21,12 @@

from tensorflow_graphics.util import export_api
from tensorflow_graphics.util import shape
from tensorflow_graphics.util import type_alias


def evaluate(point_set_a, point_set_b, name="chamfer_distance_evaluate"):
def evaluate(point_set_a: type_alias.TensorLike,
point_set_b: type_alias.TensorLike,
name: str = "chamfer_distance_evaluate") -> tf.Tensor:
"""Computes the Chamfer distance for the given two point sets.
Note:
Expand Down
5 changes: 4 additions & 1 deletion tensorflow_graphics/nn/loss/hausdorff_distance.py
Expand Up @@ -21,9 +21,12 @@

from tensorflow_graphics.util import export_api
from tensorflow_graphics.util import shape
from tensorflow_graphics.util import type_alias


def evaluate(point_set_a, point_set_b, name="hausdorff_distance_evaluate"):
def evaluate(point_set_a: type_alias.TensorLike,
point_set_b: type_alias.TensorLike,
name: str = "hausdorff_distance_evaluate") -> tf.Tensor:
"""Computes the Hausdorff distance from point_set_a to point_set_b.
Note:
Expand Down
13 changes: 8 additions & 5 deletions tensorflow_graphics/nn/metric/fscore.py
Expand Up @@ -17,20 +17,23 @@
from __future__ import division
from __future__ import print_function

from typing import Any, Callable

import tensorflow as tf

from tensorflow_graphics.nn.metric import precision as precision_module
from tensorflow_graphics.nn.metric import recall as recall_module
from tensorflow_graphics.util import export_api
from tensorflow_graphics.util import safe_ops
from tensorflow_graphics.util import shape
from tensorflow_graphics.util import type_alias


def evaluate(ground_truth,
prediction,
precision_function=precision_module.evaluate,
recall_function=recall_module.evaluate,
name="fscore_evaluate"):
def evaluate(ground_truth: type_alias.TensorLike,
prediction: type_alias.TensorLike,
precision_function: Callable[..., Any] = precision_module.evaluate,
recall_function: Callable[..., Any] = recall_module.evaluate,
name: str = "fscore_evaluate"):
"""Computes the fscore metric for the given ground truth and predicted labels.
The fscore is calculated as 2 * (precision * recall) / (precision + recall)
Expand Down
9 changes: 5 additions & 4 deletions tensorflow_graphics/nn/metric/intersection_over_union.py
Expand Up @@ -23,12 +23,13 @@
from tensorflow_graphics.util import asserts
from tensorflow_graphics.util import export_api
from tensorflow_graphics.util import shape
from tensorflow_graphics.util import type_alias


def evaluate(ground_truth_labels,
predicted_labels,
grid_size=1,
name="intersection_over_union_evaluate"):
def evaluate(ground_truth_labels: type_alias.TensorLike,
predicted_labels: type_alias.TensorLike,
grid_size: int = 1,
name: str = "intersection_over_union_evaluate"):
"""Computes the Intersection-Over-Union metric for the given ground truth and predicted labels.
Note:
Expand Down
15 changes: 9 additions & 6 deletions tensorflow_graphics/nn/metric/precision.py
Expand Up @@ -17,23 +17,26 @@
from __future__ import division
from __future__ import print_function

from typing import Any, Callable, List, Optional, Union, Tuple

import tensorflow as tf

from tensorflow_graphics.util import export_api
from tensorflow_graphics.util import safe_ops
from tensorflow_graphics.util import shape
from tensorflow_graphics.util import type_alias


def _cast_to_int(prediction):
return tf.cast(x=prediction, dtype=tf.int32)


def evaluate(ground_truth,
prediction,
classes=None,
reduce_average=True,
prediction_to_category_function=_cast_to_int,
name="precision_evaluate"):
def evaluate(ground_truth: type_alias.TensorLike,
prediction: type_alias.TensorLike,
classes: Optional[Union[int, List[int], Tuple[int]]] = None,
reduce_average: bool = True,
prediction_to_category_function: Callable[..., Any] = _cast_to_int,
name: str = "precision_evaluate"):
"""Computes the precision metric for the given ground truth and predictions.
Note:
Expand Down
15 changes: 9 additions & 6 deletions tensorflow_graphics/nn/metric/recall.py
Expand Up @@ -17,23 +17,26 @@
from __future__ import division
from __future__ import print_function

from typing import Any, Callable, List, Optional, Tuple, Union

import tensorflow as tf

from tensorflow_graphics.util import export_api
from tensorflow_graphics.util import safe_ops
from tensorflow_graphics.util import shape
from tensorflow_graphics.util import type_alias


def _cast_to_int(prediction):
return tf.cast(x=prediction, dtype=tf.int32)


def evaluate(ground_truth,
prediction,
classes=None,
reduce_average=True,
prediction_to_category_function=_cast_to_int,
name="recall_evaluate"):
def evaluate(ground_truth: type_alias.TensorLike,
prediction: type_alias.TensorLike,
classes: Optional[Union[int, List[int], Tuple[int]]] = None,
reduce_average: bool = True,
prediction_to_category_function: Callable[..., Any] = _cast_to_int,
name: str = "recall_evaluate"):
"""Computes the recall metric for the given ground truth and predictions.
Note:
Expand Down

0 comments on commit b8400e1

Please sign in to comment.