Skip to content

Commit

Permalink
Adds typing information to the nn.loss package.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 413486464
  • Loading branch information
G4G authored and Copybara-Service committed Dec 1, 2021
1 parent 38771b1 commit 64f0733
Show file tree
Hide file tree
Showing 8 changed files with 137 additions and 80 deletions.
5 changes: 3 additions & 2 deletions tensorflow_graphics/math/feature_representation.py
Expand Up @@ -20,11 +20,12 @@
import tensorflow as tf

from tensorflow_graphics.util import export_api
from tensorflow_graphics.util.type_alias import TensorLike


def positional_encoding(features: tf.Tensor,
def positional_encoding(features: TensorLike,
num_frequencies: int,
name="positional_encoding") -> tf.Tensor:
name: str = "positional_encoding") -> TensorLike:
"""Positional enconding of a tensor as described in the NeRF paper (https://arxiv.org/abs/2003.08934).
Args:
Expand Down
27 changes: 14 additions & 13 deletions tensorflow_graphics/math/math_helpers.py
Expand Up @@ -19,17 +19,18 @@

import numpy as np
import tensorflow as tf

from tensorflow_graphics.util import asserts
from tensorflow_graphics.util import export_api
from tensorflow_graphics.util import safe_ops
from tensorflow_graphics.util import shape
from tensorflow_graphics.util.type_alias import Float
from tensorflow_graphics.util.type_alias import TensorLike


def cartesian_to_spherical_coordinates(point_cartesian,
eps=None,
name="cartesian_to_spherical_coordinates"
):
def cartesian_to_spherical_coordinates(
point_cartesian: TensorLike,
eps: Float = None,
name: str = "cartesian_to_spherical_coordinates") -> tf.Tensor:
"""Function to transform Cartesian coordinates to spherical coordinates.
This function assumes a right handed coordinate system with `z` pointing up.
Expand Down Expand Up @@ -77,7 +78,7 @@ def _double_factorial_loop_condition(n, result, two):
return tf.cast(tf.math.count_nonzero(tf.greater_equal(n, two)), tf.bool)


def double_factorial(n):
def double_factorial(n: TensorLike) -> TensorLike:
"""Computes the double factorial of `n`.
Note:
Expand All @@ -100,7 +101,7 @@ def double_factorial(n):
return result


def factorial(n):
def factorial(n: TensorLike) -> TensorLike:
"""Computes the factorial of `n`.
Note:
Expand All @@ -117,9 +118,9 @@ def factorial(n):
return tf.exp(tf.math.lgamma(n + 1))


def spherical_to_cartesian_coordinates(point_spherical,
name="spherical_to_cartesian_coordinates"
):
def spherical_to_cartesian_coordinates(
point_spherical: TensorLike,
name: str = "spherical_to_cartesian_coordinates") -> TensorLike:
"""Function to transform Cartesian coordinates to spherical coordinates.
Note:
Expand Down Expand Up @@ -156,9 +157,9 @@ def spherical_to_cartesian_coordinates(point_spherical,
return tf.stack((x, y, z), axis=-1)


def square_to_spherical_coordinates(point_2d,
name="math_square_to_spherical_coordinates"
):
def square_to_spherical_coordinates(
point_2d: TensorLike,
name: str = "math_square_to_spherical_coordinates") -> TensorLike:
"""Maps points from a unit square to a unit sphere.
Note:
Expand Down
47 changes: 29 additions & 18 deletions tensorflow_graphics/math/spherical_harmonics.py
Expand Up @@ -17,6 +17,8 @@
from __future__ import division
from __future__ import print_function

from typing import Tuple

import numpy as np
from six.moves import range
import tensorflow as tf
Expand All @@ -26,12 +28,14 @@
from tensorflow_graphics.util import asserts
from tensorflow_graphics.util import export_api
from tensorflow_graphics.util import shape
from tensorflow_graphics.util.type_alias import TensorLike


def integration_product(harmonics1,
harmonics2,
keepdims=True,
name="spherical_harmonics_convolution"):
def integration_product(
harmonics1: TensorLike,
harmonics2: TensorLike,
keepdims: bool = True,
name: str = "spherical_harmonics_convolution") -> TensorLike:
"""Computes the integral of harmonics1.harmonics2 over the sphere.
Note:
Expand Down Expand Up @@ -72,7 +76,8 @@ def integration_product(harmonics1,


def generate_l_m_permutations(
max_band, name="spherical_harmonics_generate_l_m_permutations"):
max_band: int,
name: str = "spherical_harmonics_generate_l_m_permutations") -> Tuple[TensorLike, TensorLike]: # pylint: disable=line-too-long
"""Generates permutations of degree l and order m for spherical harmonics.
Args:
Expand All @@ -94,7 +99,9 @@ def generate_l_m_permutations(
tf.convert_to_tensor(value=order_m))


def generate_l_m_zonal(max_band, name="spherical_harmonics_generate_l_m_zonal"):
def generate_l_m_zonal(
max_band: int,
name: str = "spherical_harmonics_generate_l_m_zonal") -> Tuple[TensorLike, TensorLike]: # pylint: disable=line-too-long
"""Generates l and m coefficients for zonal harmonics.
Args:
Expand Down Expand Up @@ -154,7 +161,9 @@ def _evaluate_legendre_polynomial_branch(l, m, x, pmm):
return res


def evaluate_legendre_polynomial(degree_l, order_m, x):
def evaluate_legendre_polynomial(degree_l: TensorLike,
order_m: TensorLike,
x: TensorLike) -> TensorLike:
"""Evaluates the Legendre polynomial of degree l and order m at x.
Note:
Expand Down Expand Up @@ -227,11 +236,11 @@ def _evaluate_spherical_harmonics_branch(degree,


def evaluate_spherical_harmonics(
degree_l,
order_m,
theta,
phi,
name="spherical_harmonics_evaluate_spherical_harmonics"):
degree_l: TensorLike,
order_m: TensorLike,
theta: TensorLike,
phi: TensorLike,
name: str = "spherical_harmonics_evaluate_spherical_harmonics") -> TensorLike: # pylint: disable=line-too-long
"""Evaluates a point sample of a Spherical Harmonic basis function.
Note:
Expand Down Expand Up @@ -305,10 +314,11 @@ def evaluate_spherical_harmonics(
return tf.where(tf.equal(order_m, zeros), result_m_zero, result_branch)


def rotate_zonal_harmonics(zonal_coeffs,
theta,
phi,
name="spherical_harmonics_rotate_zonal_harmonics"):
def rotate_zonal_harmonics(
zonal_coeffs: TensorLike,
theta: TensorLike,
phi: TensorLike,
name: str = "spherical_harmonics_rotate_zonal_harmonics") -> TensorLike:
"""Rotates zonal harmonics.
Note:
Expand Down Expand Up @@ -356,8 +366,9 @@ def rotate_zonal_harmonics(zonal_coeffs,
l_broadcasted, m_broadcasted, theta, phi)


def tile_zonal_coefficients(coefficients,
name="spherical_harmonics_tile_zonal_coefficients"):
def tile_zonal_coefficients(
coefficients: TensorLike,
name: str = "spherical_harmonics_tile_zonal_coefficients") -> TensorLike:
"""Tiles zonal coefficients.
Zonal Harmonics only contains the harmonics where m=0. This function returns
Expand Down
17 changes: 14 additions & 3 deletions tensorflow_graphics/math/vector.py
Expand Up @@ -22,9 +22,13 @@
from tensorflow_graphics.util import asserts
from tensorflow_graphics.util import export_api
from tensorflow_graphics.util import shape
from tensorflow_graphics.util.type_alias import TensorLike


def cross(vector1, vector2, axis=-1, name="vector_cross"):
def cross(vector1: TensorLike,
vector2: TensorLike,
axis: int = -1,
name: str = "vector_cross") -> TensorLike:
"""Computes the cross product between two tensors along an axis.
Note:
Expand Down Expand Up @@ -62,7 +66,11 @@ def cross(vector1, vector2, axis=-1, name="vector_cross"):
return tf.stack((n_x, n_y, n_z), axis=axis)


def dot(vector1, vector2, axis=-1, keepdims=True, name="vector_dot"):
def dot(vector1: TensorLike,
vector2: TensorLike,
axis: int = -1,
keepdims: bool = True,
name: str = "vector_dot") -> TensorLike:
"""Computes the dot product between two tensors along an axis.
Note:
Expand Down Expand Up @@ -97,7 +105,10 @@ def dot(vector1, vector2, axis=-1, keepdims=True, name="vector_dot"):
input_tensor=vector1 * vector2, axis=axis, keepdims=keepdims)


def reflect(vector, normal, axis=-1, name="vector_reflect"):
def reflect(vector: TensorLike,
normal: TensorLike,
axis: int = -1,
name: str = "vector_reflect") -> TensorLike:
r"""Computes the reflection direction for an incident vector.
For an incident vector \\(\mathbf{v}\\) and normal $$\mathbf{n}$$ this
Expand Down
71 changes: 39 additions & 32 deletions tensorflow_graphics/nn/layer/graph_convolution.py
Expand Up @@ -17,22 +17,26 @@
from __future__ import division
from __future__ import print_function

from typing import Any, Callable, List, Optional

import tensorflow as tf

import tensorflow_graphics.geometry.convolution.graph_convolution as gc
from tensorflow_graphics.util import export_api
from tensorflow_graphics.util import type_alias


def feature_steered_convolution_layer(
data,
neighbors,
sizes,
translation_invariant=True,
num_weight_matrices=8,
num_output_channels=None,
initializer=tf.keras.initializers.TruncatedNormal(stddev=0.1),
name='graph_convolution_feature_steered_convolution',
var_name=None):
data: type_alias.TensorLike,
neighbors: tf.SparseTensor,
sizes: type_alias.TensorLike,
translation_invariant: bool = True,
num_weight_matrices: int = 8,
num_output_channels: Optional[int] = None,
initializer: tf.keras.initializers.Initializer = tf.keras.initializers
.TruncatedNormal(stddev=0.1),
name: str = 'graph_convolution_feature_steered_convolution',
var_name: Optional[str] = None) -> tf.Tensor:
# pyformat: disable
"""Wraps the function `feature_steered_convolution` as a TensorFlow layer.
Expand Down Expand Up @@ -142,11 +146,11 @@ class FeatureSteeredConvolutionKerasLayer(tf.keras.layers.Layer):
"""Wraps the function `feature_steered_convolution` as a Keras layer."""

def __init__(self,
translation_invariant=True,
num_weight_matrices=8,
num_output_channels=None,
initializer=None,
name=None,
translation_invariant: bool = True,
num_weight_matrices: int = 8,
num_output_channels: Optional[int] = None,
initializer: Optional[tf.keras.initializers.Initializer] = None,
name: Optional[str] = None,
**kwargs):
"""Initializes FeatureSteeredConvolutionKerasLayer.
Expand All @@ -173,7 +177,7 @@ def __init__(self,
else:
self._initializer = initializer

def build(self, input_shape):
def build(self, input_shape: type_alias.TensorLike):
"""Initializes the trainable weights."""
in_channels = tf.TensorShape(input_shape[0]).as_list()[-1]
if self._num_output_channels is None:
Expand Down Expand Up @@ -215,7 +219,9 @@ def build(self, input_shape):
name='b',
trainable=True)

def call(self, inputs, **kwargs):
def call(self,
inputs: List[type_alias.TensorLike],
**kwargs) -> tf.Tensor:
# pyformat: disable
"""Executes the convolution.
Expand Down Expand Up @@ -290,20 +296,21 @@ class DynamicGraphConvolutionKerasLayer(tf.keras.layers.Layer):
input.
"""

def __init__(self,
num_output_channels,
reduction,
activation=None,
use_bias=True,
kernel_initializer='glorot_uniform',
bias_initializer='zeros',
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
name=None,
**kwargs):
def __init__(
self,
num_output_channels: int,
reduction: str,
activation: Optional[Callable[[Any], Any]] = None,
use_bias: bool = True,
kernel_initializer: str = 'glorot_uniform',
bias_initializer: str = 'zeros',
kernel_regularizer: Optional[tf.keras.initializers.Initializer] = None,
bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
activity_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
kernel_constraint: Optional[tf.keras.constraints.Constraint] = None,
bias_constraint: Optional[tf.keras.constraints.Constraint] = None,
name: Optional[str] = None,
**kwargs):
"""Initializes DynamicGraphConvolutionKerasLayer.
Args:
Expand Down Expand Up @@ -342,7 +349,7 @@ def __init__(self,
self._kernel_constraint = kernel_constraint
self._bias_constraint = bias_constraint

def build(self, input_shape): # pylint: disable=unused-argument
def build(self, input_shape: type_alias.TensorLike): # pylint: disable=unused-argument
"""Initializes the layer weights."""
self._conv1d_layer = tf.keras.layers.Conv1D(
filters=self._num_output_channels,
Expand All @@ -359,7 +366,7 @@ def build(self, input_shape): # pylint: disable=unused-argument
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint)

def call(self, inputs, **kwargs):
def call(self, inputs: List[type_alias.TensorLike], **kwargs) -> tf.Tensor:
# pyformat: disable
"""Executes the convolution.
Expand Down

0 comments on commit 64f0733

Please sign in to comment.