Skip to content

Commit

Permalink
Adds typing information to the nn.loss package.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 420046706
  • Loading branch information
G4G authored and Copybara-Service committed Jan 6, 2022
1 parent bab2352 commit 3b030f8
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
5 changes: 4 additions & 1 deletion tensorflow_graphics/nn/loss/chamfer_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,12 @@

from tensorflow_graphics.util import export_api
from tensorflow_graphics.util import shape
from tensorflow_graphics.util import type_alias


def evaluate(point_set_a, point_set_b, name="chamfer_distance_evaluate"):
def evaluate(point_set_a: type_alias.TensorLike,
point_set_b: type_alias.TensorLike,
name: str = "chamfer_distance_evaluate") -> tf.Tensor:
"""Computes the Chamfer distance for the given two point sets.
Note:
Expand Down
5 changes: 4 additions & 1 deletion tensorflow_graphics/nn/loss/hausdorff_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,12 @@

from tensorflow_graphics.util import export_api
from tensorflow_graphics.util import shape
from tensorflow_graphics.util import type_alias


def evaluate(point_set_a, point_set_b, name="hausdorff_distance_evaluate"):
def evaluate(point_set_a: type_alias.TensorLike,
point_set_b: type_alias.TensorLike,
name: str = "hausdorff_distance_evaluate") -> tf.Tensor:
"""Computes the Hausdorff distance from point_set_a to point_set_b.
Note:
Expand Down

0 comments on commit 3b030f8

Please sign in to comment.