Skip to content

Commit

Permalink
Adds typing information to the module util.safe_ops.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 424056349
  • Loading branch information
G4G authored and Copybara-Service committed Jan 25, 2022
1 parent 8699795 commit 70d95be
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 13 deletions.
4 changes: 2 additions & 2 deletions tensorflow_graphics/opensource_only.files
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
tensorflow_graphics/rendering/kernels/BUILD
tensorflow_graphics/rendering/opengl/BUILD
tensorflow_graphics/rendering/kernels/BUILD:
tensorflow_graphics/rendering/opengl/BUILD:
44 changes: 33 additions & 11 deletions tensorflow_graphics/util/safe_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,18 @@
from __future__ import division
from __future__ import print_function

from typing import Optional

import numpy as np
import tensorflow as tf

from tensorflow_graphics.util import asserts
from tensorflow_graphics.util import type_alias


def nonzero_sign(x, name='nonzero_sign'):
def nonzero_sign(
x: type_alias.TensorLike,
name: str = 'nonzero_sign') -> tf.Tensor:
"""Returns the sign of x with sign(0) defined as 1 instead of 0."""
with tf.name_scope(name):
x = tf.convert_to_tensor(value=x)
Expand All @@ -40,7 +45,11 @@ def nonzero_sign(x, name='nonzero_sign'):
return tf.where(tf.greater_equal(x, 0.0), one, -one)


def safe_cospx_div_cosx(theta, factor, eps=None, name='safe_cospx_div_cosx'):
def safe_cospx_div_cosx(
theta: type_alias.TensorLike,
factor: type_alias.TensorLike,
eps: Optional[type_alias.Float] = None,
name: str = 'safe_cospx_div_cosx') -> tf.Tensor:
"""Calculates cos(factor * theta)/cos(theta) safely.
The term `cos(factor * theta)/cos(theta)` has periodic edge cases with
Expand Down Expand Up @@ -84,12 +93,13 @@ def safe_cospx_div_cosx(theta, factor, eps=None, name='safe_cospx_div_cosx'):
return asserts.assert_no_infs_or_nans(div)


def safe_shrink(vector,
minval=None,
maxval=None,
open_bounds=False,
eps=None,
name='safe_shrink'):
def safe_shrink(
vector: type_alias.TensorLike,
minval: Optional[type_alias.TensorLike] = None,
maxval: Optional[type_alias.TensorLike] = None,
open_bounds: bool = False,
eps: Optional[type_alias.Float] = None,
name: str = 'safe_shrink') -> tf.Tensor:
"""Shrinks vector by (1.0 - eps) based on its dtype.
This function shrinks the input vector by a very small amount to ensure that
Expand Down Expand Up @@ -141,7 +151,11 @@ def safe_shrink(vector,
return vector


def safe_signed_div(a, b, eps=None, name='safe_signed_div'):
def safe_signed_div(
a: type_alias.TensorLike,
b: type_alias.TensorLike,
eps: Optional[type_alias.Float] = None,
name: str = 'safe_signed_div') -> tf.Tensor:
"""Calculates a/b safely.
If the tf-graphics debug flag is set to `True`, this function adds assertions
Expand Down Expand Up @@ -177,7 +191,11 @@ def safe_signed_div(a, b, eps=None, name='safe_signed_div'):
return asserts.assert_no_infs_or_nans(a / (b + nonzero_sign(b) * eps))


def safe_sinpx_div_sinx(theta, factor, eps=None, name='safe_sinpx_div_sinx'):
def safe_sinpx_div_sinx(
theta: type_alias.TensorLike,
factor: type_alias.TensorLike,
eps: Optional[type_alias.Float] = None,
name: str = 'safe_sinpx_div_sinx') -> tf.Tensor:
"""Calculates sin(factor * theta)/sin(theta) safely.
The term `sin(factor * theta)/sin(theta)` appears when calculating spherical
Expand Down Expand Up @@ -221,7 +239,11 @@ def safe_sinpx_div_sinx(theta, factor, eps=None, name='safe_sinpx_div_sinx'):
return asserts.assert_no_infs_or_nans(div)


def safe_unsigned_div(a, b, eps=None, name='safe_unsigned_div'):
def safe_unsigned_div(
a: type_alias.TensorLike,
b: type_alias.TensorLike,
eps: Optional[type_alias.Float] = None,
name: str = 'safe_unsigned_div') -> tf.Tensor:
"""Calculates a/b with b >= 0 safely.
If the tfg debug flag TFG_ADD_ASSERTS_TO_GRAPH defined in tfg_flags.py
Expand Down

0 comments on commit 70d95be

Please sign in to comment.