diff --git a/tensorflow_graphics/geometry/transformation/linear_blend_skinning.py b/tensorflow_graphics/geometry/transformation/linear_blend_skinning.py index db2ac923a..58870a6d7 100644 --- a/tensorflow_graphics/geometry/transformation/linear_blend_skinning.py +++ b/tensorflow_graphics/geometry/transformation/linear_blend_skinning.py @@ -22,13 +22,14 @@ from tensorflow_graphics.geometry.transformation import rotation_matrix_3d from tensorflow_graphics.util import export_api from tensorflow_graphics.util import shape +from tensorflow_graphics.util import type_alias -def blend(points, - skinning_weights, - bone_rotations, - bone_translations, - name="linear_blend_skinning_blend"): +def blend(points: type_alias.TensorLike, + skinning_weights: type_alias.TensorLike, + bone_rotations: type_alias.TensorLike, + bone_translations: type_alias.TensorLike, + name: str = "linear_blend_skinning_blend") -> tf.Tensor: """Transforms the points using Linear Blend Skinning. Note: diff --git a/tensorflow_graphics/geometry/transformation/look_at.py b/tensorflow_graphics/geometry/transformation/look_at.py index e8c648270..91f7267ed 100644 --- a/tensorflow_graphics/geometry/transformation/look_at.py +++ b/tensorflow_graphics/geometry/transformation/look_at.py @@ -22,9 +22,12 @@ from tensorflow_graphics.math import vector from tensorflow_graphics.util import export_api from tensorflow_graphics.util import shape +from tensorflow_graphics.util import type_alias -def right_handed(camera_position, look_at, up_vector, name="right_handed"): +def right_handed(camera_position: type_alias.TensorLike, + look_at, up_vector: type_alias.TensorLike, + name: str = "right_handed") -> tf.Tensor: """Builds a right handed look at view matrix. Note: diff --git a/tensorflow_graphics/geometry/transformation/quaternion.py b/tensorflow_graphics/geometry/transformation/quaternion.py index 8d03cb59e..fb53504c7 100644 --- a/tensorflow_graphics/geometry/transformation/quaternion.py +++ b/tensorflow_graphics/geometry/transformation/quaternion.py @@ -30,6 +30,8 @@ from __future__ import division from __future__ import print_function +from typing import List + from six.moves import range import tensorflow as tf @@ -39,9 +41,12 @@ from tensorflow_graphics.util import export_api from tensorflow_graphics.util import safe_ops from tensorflow_graphics.util import shape +from tensorflow_graphics.util import type_alias -def _build_quaternion_from_sines_and_cosines(sin_half_angles, cos_half_angles): +def _build_quaternion_from_sines_and_cosines( + sin_half_angles: type_alias.TensorLike, + cos_half_angles: type_alias.TensorLike) -> tf.Tensor: """Builds a quaternion from sines and cosines of half Euler angles. Note: @@ -66,9 +71,10 @@ def _build_quaternion_from_sines_and_cosines(sin_half_angles, cos_half_angles): return tf.stack((x, y, z, w), axis=-1) -def between_two_vectors_3d(vector1, - vector2, - name="quaternion_between_two_vectors_3d"): +def between_two_vectors_3d(vector1: type_alias.TensorLike, + vector2: type_alias.TensorLike, + name: str = "quaternion_between_two_vectors_3d" + ) -> tf.Tensor: """Computes quaternion over the shortest arc between two vectors. Result quaternion describes shortest geodesic rotation from @@ -129,7 +135,8 @@ def between_two_vectors_3d(vector1, return tf.nn.l2_normalize(rot, axis=-1) -def conjugate(quaternion, name="quaternion_conjugate"): +def conjugate(quaternion: type_alias.TensorLike, + name: str = "quaternion_conjugate") -> tf.Tensor: """Computes the conjugate of a quaternion. Note: @@ -157,7 +164,10 @@ def conjugate(quaternion, name="quaternion_conjugate"): return tf.concat((-xyz, w), axis=-1) -def from_axis_angle(axis, angle, name="quaternion_from_axis_angle"): +def from_axis_angle(axis: type_alias.TensorLike, + angle: type_alias.TensorLike, + name: str = "quaternion_from_axis_angle" + ) -> tf.Tensor: """Converts an axis-angle representation to a quaternion. Note: @@ -194,7 +204,9 @@ def from_axis_angle(axis, angle, name="quaternion_from_axis_angle"): return tf.concat((xyz, w), axis=-1) -def from_euler(angles, name="quaternion_from_euler"): +def from_euler(angles: type_alias.TensorLike, + name: str = "quaternion_from_euler" + ) -> tf.Tensor: """Converts an Euler angle representation to a quaternion. Note: @@ -230,8 +242,9 @@ def from_euler(angles, name="quaternion_from_euler"): cos_half_angles) -def from_euler_with_small_angles_approximation(angles, - name="quaternion_from_euler"): +def from_euler_with_small_angles_approximation( + angles: type_alias.TensorLike, + name: str = "quaternion_from_euler") -> tf.Tensor: r"""Converts small Euler angles to quaternions. Under the small angle assumption, $$\sin(x)$$ and $$\cos(x)$$ can be @@ -274,8 +287,9 @@ def from_euler_with_small_angles_approximation(angles, return tf.nn.l2_normalize(quaternion, axis=-1) -def from_rotation_matrix(rotation_matrix, - name="quaternion_from_rotation_matrix"): +def from_rotation_matrix(rotation_matrix: type_alias.TensorLike, + name: str = "quaternion_from_rotation_matrix" + ) -> tf.Tensor: """Converts a rotation matrix representation to a quaternion. Warning: @@ -361,7 +375,9 @@ def cond_idx(cond): return quat -def inverse(quaternion, name="quaternion_inverse"): +def inverse(quaternion: type_alias.TensorLike, + name: str = "quaternion_inverse" + ) -> tf.Tensor: """Computes the inverse of a quaternion. Note: @@ -391,7 +407,10 @@ def inverse(quaternion, name="quaternion_inverse"): return safe_ops.safe_unsigned_div(conjugate(quaternion), squared_norm) -def is_normalized(quaternion, atol=1e-3, name="quaternion_is_normalized"): +def is_normalized(quaternion: type_alias.TensorLike, + atol: type_alias.Float = 1e-3, + name: str = "quaternion_is_normalized" + ) -> tf.Tensor: """Determines if quaternion is normalized quaternion or not. Note: @@ -422,7 +441,10 @@ def is_normalized(quaternion, atol=1e-3, name="quaternion_is_normalized"): tf.zeros_like(norms, dtype=bool)) -def normalize(quaternion, eps=1e-12, name="quaternion_normalize"): +def normalize(quaternion: type_alias.TensorLike, + eps: type_alias.Float = 1e-12, + name: str = "quaternion_normalize" + ) -> tf.Tensor: """Normalizes a quaternion. Note: @@ -450,7 +472,10 @@ def normalize(quaternion, eps=1e-12, name="quaternion_normalize"): return tf.math.l2_normalize(quaternion, axis=-1, epsilon=eps) -def multiply(quaternion1, quaternion2, name="quaternion_multiply"): +def multiply(quaternion1: type_alias.TensorLike, + quaternion2: type_alias.TensorLike, + name: str = "quaternion_multiply" + ) -> tf.Tensor: """Multiplies two quaternions. Note: @@ -487,8 +512,9 @@ def multiply(quaternion1, quaternion2, name="quaternion_multiply"): return tf.stack((x, y, z, w), axis=-1) -def normalized_random_uniform(quaternion_shape, - name="quaternion_normalized_random_uniform"): +def normalized_random_uniform(quaternion_shape: List[int], + name: str = "quaternion_normalized_random_uniform" + ) -> tf.Tensor: """Random normalized quaternion following a uniform distribution law on SO(3). Args: @@ -545,7 +571,10 @@ def _initializer(shape, dtype=tf.float32, partition_info=None): # pylint: enable=redefined-outer-name -def rotate(point, quaternion, name="quaternion_rotate"): +def rotate(point: type_alias.TensorLike, + quaternion: type_alias.TensorLike, + name: str = "quaternion_rotate" + ) -> tf.Tensor: """Rotates a point using a quaternion. Note: @@ -586,7 +615,10 @@ def rotate(point, quaternion, name="quaternion_rotate"): return xyz -def relative_angle(quaternion1, quaternion2, name="quaternion_relative_angle"): +def relative_angle(quaternion1: type_alias.TensorLike, + quaternion2: type_alias.TensorLike, + name: str = "quaternion_relative_angle" + ) -> tf.Tensor: r"""Computes the unsigned relative rotation angle between 2 unit quaternions. Given two normalized quanternions $$\mathbf{q}_1$$ and $$\mathbf{q}_2$$, the diff --git a/tensorflow_graphics/geometry/transformation/rotation_matrix_2d.py b/tensorflow_graphics/geometry/transformation/rotation_matrix_2d.py index 3556533ab..27a6b3837 100644 --- a/tensorflow_graphics/geometry/transformation/rotation_matrix_2d.py +++ b/tensorflow_graphics/geometry/transformation/rotation_matrix_2d.py @@ -39,9 +39,11 @@ from tensorflow_graphics.geometry.transformation import rotation_matrix_common from tensorflow_graphics.util import export_api from tensorflow_graphics.util import shape +from tensorflow_graphics.util import type_alias -def from_euler(angle, name="rotation_matrix_2d_from_euler_angle"): +def from_euler(angle: type_alias.TensorLike, + name: str = "rotation_matrix_2d_from_euler_angle") -> tf.Tensor: r"""Converts an angle to a 2d rotation matrix. Converts an angle $$\theta$$ to a 2d rotation matrix following the equation @@ -89,8 +91,9 @@ def from_euler(angle, name="rotation_matrix_2d_from_euler_angle"): def from_euler_with_small_angles_approximation( - angles, - name="rotation_matrix_2d_from_euler_with_small_angles_approximation"): + angles: type_alias.TensorLike, + name: str = "rotation_matrix_2d_from_euler_with_small_angles_approximation" +) -> tf.Tensor: r"""Converts an angle to a 2d rotation matrix under the small angle assumption. Under the small angle assumption, $$\sin(x)$$ and $$\cos(x)$$ can be @@ -142,7 +145,8 @@ def from_euler_with_small_angles_approximation( return tf.reshape(matrix, shape=output_shape) -def inverse(matrix, name="rotation_matrix_2d_inverse"): +def inverse(matrix: type_alias.TensorLike, + name: str = "rotation_matrix_2d_inverse") -> tf.Tensor: """Computes the inverse of a 2D rotation matrix. Note: @@ -174,7 +178,9 @@ def inverse(matrix, name="rotation_matrix_2d_inverse"): return tf.transpose(a=matrix, perm=perm) -def is_valid(matrix, atol=1e-3, name="rotation_matrix_2d_is_valid"): +def is_valid(matrix: type_alias.TensorLike, + atol: type_alias.Float = 1e-3, + name: str = "rotation_matrix_2d_is_valid") -> tf.Tensor: r"""Determines if a matrix is a valid rotation matrix. Determines if a matrix $$\mathbf{R}$$ is a valid rotation matrix by checking @@ -205,7 +211,9 @@ def is_valid(matrix, atol=1e-3, name="rotation_matrix_2d_is_valid"): return rotation_matrix_common.is_valid(matrix, atol) -def rotate(point, matrix, name="rotation_matrix_2d_rotate"): +def rotate(point: type_alias.TensorLike, + matrix: type_alias.TensorLike, + name: str = "rotation_matrix_2d_rotate") -> tf.Tensor: """Rotates a 2d point using a 2d rotation matrix. Note: diff --git a/tensorflow_graphics/geometry/transformation/rotation_matrix_3d.py b/tensorflow_graphics/geometry/transformation/rotation_matrix_3d.py index f235461f1..4ff9bd4c3 100644 --- a/tensorflow_graphics/geometry/transformation/rotation_matrix_3d.py +++ b/tensorflow_graphics/geometry/transformation/rotation_matrix_3d.py @@ -30,11 +30,14 @@ from tensorflow_graphics.util import export_api from tensorflow_graphics.util import shape from tensorflow_graphics.util import tfg_flags +from tensorflow_graphics.util import type_alias FLAGS = flags.FLAGS -def _build_matrix_from_sines_and_cosines(sin_angles, cos_angles): +def _build_matrix_from_sines_and_cosines( + sin_angles: type_alias.TensorLike, + cos_angles: type_alias.TensorLike) -> tf.Tensor: """Builds a rotation matrix from sines and cosines of Euler angles. Note: @@ -71,9 +74,10 @@ def _build_matrix_from_sines_and_cosines(sin_angles, cos_angles): return tf.reshape(matrix, shape=output_shape) -def assert_rotation_matrix_normalized(matrix, - eps=1e-3, - name="assert_rotation_matrix_normalized"): +def assert_rotation_matrix_normalized( + matrix: type_alias.TensorLike, + eps: type_alias.Float = 1e-3, + name: str = "assert_rotation_matrix_normalized") -> tf.Tensor: """Checks whether a matrix is a rotation matrix. Note: @@ -113,7 +117,10 @@ def assert_rotation_matrix_normalized(matrix, return tf.identity(matrix) -def from_axis_angle(axis, angle, name="rotation_matrix_3d_from_axis_angle"): +def from_axis_angle( + axis: type_alias.TensorLike, + angle: type_alias.TensorLike, + name: str = "rotation_matrix_3d_from_axis_angle") -> tf.Tensor: """Convert an axis-angle representation to a rotation matrix. Note: @@ -174,7 +181,8 @@ def from_axis_angle(axis, angle, name="rotation_matrix_3d_from_axis_angle"): return tf.reshape(matrix, shape=output_shape) -def from_euler(angles, name="rotation_matrix_3d_from_euler"): +def from_euler(angles: type_alias.TensorLike, + name: str = "rotation_matrix_3d_from_euler") -> tf.Tensor: r"""Convert an Euler angle representation to a rotation matrix. The resulting matrix is $$\mathbf{R} = \mathbf{R}_z\mathbf{R}_y\mathbf{R}_x$$. @@ -208,7 +216,8 @@ def from_euler(angles, name="rotation_matrix_3d_from_euler"): def from_euler_with_small_angles_approximation( - angles, name="rotation_matrix_3d_from_euler_with_small_angles"): + angles: type_alias.TensorLike, + name: str = "rotation_matrix_3d_from_euler_with_small_angles") -> tf.Tensor: r"""Convert an Euler angle representation to a rotation matrix. The resulting matrix is $$\mathbf{R} = \mathbf{R}_z\mathbf{R}_y\mathbf{R}_x$$. @@ -246,7 +255,9 @@ def from_euler_with_small_angles_approximation( return _build_matrix_from_sines_and_cosines(sin_angles, cos_angles) -def from_quaternion(quaternion, name="rotation_matrix_3d_from_quaternion"): +def from_quaternion( + quaternion: type_alias.TensorLike, + name: str = "rotation_matrix_3d_from_quaternion") -> tf.Tensor: """Convert a quaternion to a rotation matrix. Note: @@ -293,7 +304,8 @@ def from_quaternion(quaternion, name="rotation_matrix_3d_from_quaternion"): return tf.reshape(matrix, shape=output_shape) -def inverse(matrix, name="rotation_matrix_3d_inverse"): +def inverse(matrix: type_alias.TensorLike, + name: str = "rotation_matrix_3d_inverse") -> tf.Tensor: """Computes the inverse of a 3D rotation matrix. Note: @@ -326,7 +338,9 @@ def inverse(matrix, name="rotation_matrix_3d_inverse"): return tf.transpose(a=matrix, perm=perm) -def is_valid(matrix, atol=1e-3, name="rotation_matrix_3d_is_valid"): +def is_valid(matrix: type_alias.TensorLike, + atol: type_alias.Float = 1e-3, + name: str = "rotation_matrix_3d_is_valid") -> tf.Tensor: """Determines if a matrix is a valid rotation matrix. Note: @@ -354,7 +368,9 @@ def is_valid(matrix, atol=1e-3, name="rotation_matrix_3d_is_valid"): return rotation_matrix_common.is_valid(matrix, atol) -def rotate(point, matrix, name="rotation_matrix_3d_rotate"): +def rotate(point: type_alias.TensorLike, + matrix: type_alias.TensorLike, + name: str = "rotation_matrix_3d_rotate") -> tf.Tensor: """Rotate a point using a rotation matrix 3d. Note: diff --git a/tensorflow_graphics/geometry/transformation/rotation_matrix_common.py b/tensorflow_graphics/geometry/transformation/rotation_matrix_common.py index 14ed119e8..45cf7e1eb 100644 --- a/tensorflow_graphics/geometry/transformation/rotation_matrix_common.py +++ b/tensorflow_graphics/geometry/transformation/rotation_matrix_common.py @@ -22,9 +22,12 @@ from tensorflow_graphics.util import export_api from tensorflow_graphics.util import shape +from tensorflow_graphics.util import type_alias -def is_valid(matrix, atol=1e-3, name="rotation_matrix_common_is_valid"): +def is_valid(matrix: type_alias.TensorLike, + atol: type_alias.Float = 1e-3, + name: str = "rotation_matrix_common_is_valid") -> tf.Tensor: r"""Determines if a matrix in K-dimensions is a valid rotation matrix. Determines if a matrix $$\mathbf{R}$$ is a valid rotation matrix by checking diff --git a/tensorflow_graphics/image/color_space/linear_rgb.py b/tensorflow_graphics/image/color_space/linear_rgb.py index 61b8e0272..f9524e0e7 100644 --- a/tensorflow_graphics/image/color_space/linear_rgb.py +++ b/tensorflow_graphics/image/color_space/linear_rgb.py @@ -22,6 +22,7 @@ from tensorflow_graphics.util import asserts from tensorflow_graphics.util import export_api from tensorflow_graphics.util import shape +from tensorflow_graphics.util import type_alias # Conversion constants following the naming convention from the 'theory of the # transformation' section at https://en.wikipedia.org/wiki/SRGB. @@ -31,7 +32,8 @@ _GAMMA = constants.srgb_gamma["GAMMA"] -def from_srgb(srgb, name="linear_rgb_from_srgb"): +def from_srgb(srgb: type_alias.TensorLike, + name: str = "linear_rgb_from_srgb") -> tf.Tensor: """Converts sRGB colors to linear colors. Note: diff --git a/tensorflow_graphics/image/color_space/srgb.py b/tensorflow_graphics/image/color_space/srgb.py index f65cdb4e1..8c5552edf 100644 --- a/tensorflow_graphics/image/color_space/srgb.py +++ b/tensorflow_graphics/image/color_space/srgb.py @@ -28,6 +28,7 @@ from tensorflow_graphics.util import asserts from tensorflow_graphics.util import export_api from tensorflow_graphics.util import shape +from tensorflow_graphics.util import type_alias # Conversion constants following the naming convention from the 'theory of the # transformation' section at https://en.wikipedia.org/wiki/SRGB. @@ -37,7 +38,8 @@ _GAMMA = constants.srgb_gamma["GAMMA"] -def from_linear_rgb(linear_rgb, name="srgb_from_linear_rgb"): +def from_linear_rgb(linear_rgb: type_alias.TensorLike, + name: str = "srgb_from_linear_rgb") -> tf.Tensor: """Converts linear RGB to sRGB colors. Note: diff --git a/tensorflow_graphics/image/matting.py b/tensorflow_graphics/image/matting.py index 6bb621b73..4788efc5d 100644 --- a/tensorflow_graphics/image/matting.py +++ b/tensorflow_graphics/image/matting.py @@ -17,6 +17,8 @@ from __future__ import division from __future__ import print_function +from typing import List, Tuple, Union + import numpy as np import tensorflow as tf @@ -24,9 +26,11 @@ from tensorflow_graphics.util import asserts from tensorflow_graphics.util import export_api from tensorflow_graphics.util import shape +from tensorflow_graphics.util import type_alias -def _shape(batch_shape, *shapes): +def _shape(batch_shape: Union[type_alias.TensorLike, List[int]], + *shapes) -> tf.Tensor: """Creates a new shape concatenating batch_shape and shapes. Args: @@ -39,7 +43,8 @@ def _shape(batch_shape, *shapes): return tf.concat((batch_shape, shapes), axis=-1) -def _quadratic_form(matrix, vector): +def _quadratic_form(matrix: type_alias.TensorLike, + vector: type_alias.TensorLike) -> tf.Tensor: """Computes the quadratic form between a matrix and a vector. The quadratic form between a matrix A and a vector x can be written as @@ -60,7 +65,7 @@ def _quadratic_form(matrix, vector): return vector_matrix_vector -def _image_patches(image, size): +def _image_patches(image: type_alias.TensorLike, size: int) -> tf.Tensor: """Extracts square image patches. Args: @@ -79,7 +84,7 @@ def _image_patches(image, size): padding="VALID") -def _image_average(image, size): +def _image_average(image: type_alias.TensorLike, size: int) -> tf.Tensor: """Computes average over image patches. Args: @@ -96,7 +101,11 @@ def _image_average(image, size): padding="VALID") -def build_matrices(image, size=3, eps=1e-5, name="matting_build_matrices"): +def build_matrices(image: type_alias.TensorLike, + size: int = 3, + eps: type_alias.Float = 1e-5, + name: str = "matting_build_matrices" + ) -> Tuple[tf.Tensor, tf.Tensor]: """Generates the closed form matting Laplacian. Generates the closed form matting Laplacian as proposed by Levin et @@ -157,7 +166,8 @@ def build_matrices(image, size=3, eps=1e-5, name="matting_build_matrices"): def linear_coefficients(matte, pseudo_inverse, - name="matting_linear_coefficients"): + name="matting_linear_coefficients" + ) -> Tuple[tf.Tensor, tf.Tensor]: """Computes the matting linear coefficients. Computes the matting linear coefficients (a, b) based on the `pseudo_inverse` @@ -209,7 +219,9 @@ def linear_coefficients(matte, return tf.split(coeffs, (-1, 1), axis=-1) -def loss(matte, laplacian, name="matting_loss"): +def loss(matte: type_alias.TensorLike, + laplacian: type_alias.TensorLike, + name: str = "matting_loss") -> tf.Tensor: """Computes the matting loss function based on the matting Laplacian. Computes the matting loss function based on the `laplacian` generated by the @@ -249,7 +261,10 @@ def loss(matte, laplacian, name="matting_loss"): return tf.reduce_mean(input_tensor=losses) -def reconstruct(image, coeff_mul, coeff_add, name="matting_reconstruct"): +def reconstruct(image: type_alias.TensorLike, + coeff_mul: type_alias.TensorLike, + coeff_add: type_alias.TensorLike, + name: str = "matting_reconstruct") -> tf.Tensor: """Reconstruct the matte from the image using the linear coefficients. Reconstruct the matte from the image using the linear coefficients (a, b) diff --git a/tensorflow_graphics/image/pyramid.py b/tensorflow_graphics/image/pyramid.py index 86bf08a54..07f529029 100644 --- a/tensorflow_graphics/image/pyramid.py +++ b/tensorflow_graphics/image/pyramid.py @@ -21,15 +21,19 @@ from __future__ import division from __future__ import print_function +from typing import List, Optional, Tuple + import numpy as np from six.moves import range import tensorflow as tf from tensorflow_graphics.util import export_api from tensorflow_graphics.util import shape +from tensorflow_graphics.util import type_alias -def _downsample(image, kernel): +def _downsample(image: type_alias.TensorLike, + kernel: type_alias.TensorLike) -> tf.Tensor: """Downsamples the image using a convolution with stride 2. Args: @@ -48,7 +52,8 @@ def _downsample(image, kernel): input=image, filters=kernel, strides=[1, 2, 2, 1], padding="SAME") -def _binomial_kernel(num_channels, dtype=tf.float32): +def _binomial_kernel(num_channels: int, + dtype: tf.DType = tf.float32) -> tf.Tensor: """Creates a 5x5 binomial kernel. Args: @@ -65,7 +70,9 @@ def _binomial_kernel(num_channels, dtype=tf.float32): return tf.constant(kernel, dtype=dtype) * tf.eye(num_channels, dtype=dtype) -def _build_pyramid(image, sampler, num_levels): +def _build_pyramid(image: type_alias.TensorLike, + sampler, + num_levels: int) -> List[tf.Tensor]: """Creates the different levels of the pyramid. Args: @@ -87,7 +94,8 @@ def _build_pyramid(image, sampler, num_levels): return levels -def _split(image, kernel): +def _split(image: type_alias.TensorLike, + kernel: type_alias.TensorLike) -> Tuple[tf.Tensor, tf.Tensor]: """Splits the image into high and low frequencies. This is achieved by smoothing the input image and substracting the smoothed @@ -112,7 +120,10 @@ def _split(image, kernel): return high, low -def _upsample(image, kernel, output_shape=None): +def _upsample(image: type_alias.TensorLike, + kernel: type_alias.TensorLike, + output_shape: Optional[type_alias.TensorLike] = None + ) -> tf.Tensor: """Upsamples the image using a transposed convolution with stride 2. Args: @@ -139,7 +150,9 @@ def _upsample(image, kernel, output_shape=None): padding="SAME") -def downsample(image, num_levels, name="pyramid_downsample"): +def downsample(image: type_alias.TensorLike, + num_levels: int, + name: str = "pyramid_downsample") -> List[tf.Tensor]: """Generates the different levels of the pyramid (downsampling). Args: @@ -165,7 +178,8 @@ def downsample(image, num_levels, name="pyramid_downsample"): return _build_pyramid(image, _downsample, num_levels) -def merge(levels, name="pyramid_merge"): +def merge(levels: List[type_alias.TensorLike], + name: str = "pyramid_merge") -> tf.Tensor: """Merges the different levels of the pyramid back to an image. Args: @@ -196,7 +210,9 @@ def merge(levels, name="pyramid_merge"): return image -def split(image, num_levels, name="pyramid_split"): +def split(image: type_alias.TensorLike, + num_levels: int, + name: str = "pyramid_split") -> List[tf.Tensor]: """Generates the different levels of the pyramid. Args: @@ -228,7 +244,9 @@ def split(image, num_levels, name="pyramid_split"): return levels -def upsample(image, num_levels, name="pyramid_upsample"): +def upsample(image: type_alias.TensorLike, + num_levels: int, + name: str = "pyramid_upsample") -> List[tf.Tensor]: """Generates the different levels of the pyramid (upsampling). Args: diff --git a/tensorflow_graphics/math/math_helpers.py b/tensorflow_graphics/math/math_helpers.py index f58fab324..5334c7c42 100644 --- a/tensorflow_graphics/math/math_helpers.py +++ b/tensorflow_graphics/math/math_helpers.py @@ -24,12 +24,13 @@ from tensorflow_graphics.util import export_api from tensorflow_graphics.util import safe_ops from tensorflow_graphics.util import shape +from tensorflow_graphics.util import type_alias -def cartesian_to_spherical_coordinates(point_cartesian, - eps=None, - name="cartesian_to_spherical_coordinates" - ): +def cartesian_to_spherical_coordinates( + point_cartesian: type_alias.TensorLike, + eps: type_alias.Float = None, + name: str = "cartesian_to_spherical_coordinates") -> tf.Tensor: """Function to transform Cartesian coordinates to spherical coordinates. This function assumes a right handed coordinate system with `z` pointing up. @@ -77,7 +78,7 @@ def _double_factorial_loop_condition(n, result, two): return tf.cast(tf.math.count_nonzero(tf.greater_equal(n, two)), tf.bool) -def double_factorial(n): +def double_factorial(n: type_alias.TensorLike) -> tf.Tensor: """Computes the double factorial of `n`. Note: @@ -100,7 +101,7 @@ def double_factorial(n): return result -def factorial(n): +def factorial(n: type_alias.TensorLike) -> tf.Tensor: """Computes the factorial of `n`. Note: @@ -117,9 +118,9 @@ def factorial(n): return tf.exp(tf.math.lgamma(n + 1)) -def spherical_to_cartesian_coordinates(point_spherical, - name="spherical_to_cartesian_coordinates" - ): +def spherical_to_cartesian_coordinates( + point_spherical: type_alias.TensorLike, + name: str = "spherical_to_cartesian_coordinates") -> tf.Tensor: """Function to transform Cartesian coordinates to spherical coordinates. Note: @@ -156,9 +157,9 @@ def spherical_to_cartesian_coordinates(point_spherical, return tf.stack((x, y, z), axis=-1) -def square_to_spherical_coordinates(point_2d, - name="math_square_to_spherical_coordinates" - ): +def square_to_spherical_coordinates( + point_2d: type_alias.TensorLike, + name: str = "math_square_to_spherical_coordinates") -> tf.Tensor: """Maps points from a unit square to a unit sphere. Note: diff --git a/tensorflow_graphics/math/vector.py b/tensorflow_graphics/math/vector.py index f320984d9..58639232f 100644 --- a/tensorflow_graphics/math/vector.py +++ b/tensorflow_graphics/math/vector.py @@ -22,9 +22,13 @@ from tensorflow_graphics.util import asserts from tensorflow_graphics.util import export_api from tensorflow_graphics.util import shape +from tensorflow_graphics.util import type_alias -def cross(vector1, vector2, axis=-1, name="vector_cross"): +def cross(vector1: type_alias.TensorLike, + vector2: type_alias.TensorLike, + axis: int = -1, + name: str = "vector_cross") -> tf.Tensor: """Computes the cross product between two tensors along an axis. Note: @@ -62,7 +66,11 @@ def cross(vector1, vector2, axis=-1, name="vector_cross"): return tf.stack((n_x, n_y, n_z), axis=axis) -def dot(vector1, vector2, axis=-1, keepdims=True, name="vector_dot"): +def dot(vector1: type_alias.TensorLike, + vector2: type_alias.TensorLike, + axis: int = -1, + keepdims: bool = True, + name: str = "vector_dot") -> tf.Tensor: """Computes the dot product between two tensors along an axis. Note: @@ -97,7 +105,10 @@ def dot(vector1, vector2, axis=-1, keepdims=True, name="vector_dot"): input_tensor=vector1 * vector2, axis=axis, keepdims=keepdims) -def reflect(vector, normal, axis=-1, name="vector_reflect"): +def reflect(vector: type_alias.TensorLike, + normal: type_alias.TensorLike, + axis: int = -1, + name: str = "vector_reflect") -> tf.Tensor: r"""Computes the reflection direction for an incident vector. For an incident vector \\(\mathbf{v}\\) and normal $$\mathbf{n}$$ this