diff --git a/tensorflow_graphics/geometry/convolution/graph_convolution.py b/tensorflow_graphics/geometry/convolution/graph_convolution.py index 7e56bbea3..c1a74985c 100644 --- a/tensorflow_graphics/geometry/convolution/graph_convolution.py +++ b/tensorflow_graphics/geometry/convolution/graph_convolution.py @@ -17,24 +17,26 @@ from __future__ import division from __future__ import print_function +from typing import Any, Callable, Dict from six.moves import zip import tensorflow as tf from tensorflow_graphics.geometry.convolution import utils from tensorflow_graphics.util import export_api from tensorflow_graphics.util import shape +from tensorflow_graphics.util import type_alias def feature_steered_convolution( - data, - neighbors, - sizes, - var_u, - var_v, - var_c, - var_w, - var_b, - name="graph_convolution_feature_steered_convolution"): + data: type_alias.TensorLike, + neighbors: tf.sparse.SparseTensor, + sizes: type_alias.TensorLike, + var_u: type_alias.TensorLike, + var_v: type_alias.TensorLike, + var_c: type_alias.TensorLike, + var_w: type_alias.TensorLike, + var_b: type_alias.TensorLike, + name="graph_convolution_feature_steered_convolution") -> tf.Tensor: # pyformat: disable """Implements the Feature Steered graph convolution. @@ -160,13 +162,14 @@ def feature_steered_convolution( def edge_convolution_template( - data, - neighbors, - sizes, - edge_function, - reduction, - edge_function_kwargs, - name="graph_convolution_edge_convolution_template"): + data: type_alias.TensorLike, + neighbors: tf.sparse.SparseTensor, + sizes: type_alias.TensorLike, + edge_function: Callable[[type_alias.TensorLike, type_alias.TensorLike], + type_alias.TensorLike], + reduction: str, + edge_function_kwargs: Dict[str, Any], + name: str = "graph_convolution_edge_convolution_template") -> tf.Tensor: # pyformat: disable r"""A template for edge convolutions. diff --git a/tensorflow_graphics/geometry/convolution/graph_pooling.py b/tensorflow_graphics/geometry/convolution/graph_pooling.py index 705d4e006..d93a6f1d3 100644 --- a/tensorflow_graphics/geometry/convolution/graph_pooling.py +++ b/tensorflow_graphics/geometry/convolution/graph_pooling.py @@ -17,13 +17,20 @@ from __future__ import division from __future__ import print_function +from typing import Callable + import tensorflow as tf from tensorflow_graphics.geometry.convolution import utils from tensorflow_graphics.util import export_api +from tensorflow_graphics.util import type_alias -def pool(data, pool_map, sizes, algorithm='max', name='graph_pooling_pool'): +def pool(data: type_alias.TensorLike, + pool_map: type_alias.TensorLike, + sizes: type_alias.TensorLike, + algorithm: str = 'max', + name: str = 'graph_pooling_pool') -> tf.Tensor: # pyformat: disable """Implements graph pooling. @@ -109,7 +116,10 @@ def pool(data, pool_map, sizes, algorithm='max', name='graph_pooling_pool'): return pooled -def unpool(data, pool_map, sizes, name='graph_pooling_unpool'): +def unpool(data: type_alias.TensorLike, + pool_map: type_alias.TensorLike, + sizes: type_alias.TensorLike, + name: str = 'graph_pooling_unpool') -> tf.Tensor: # pyformat: disable r"""Graph upsampling by inverting the pooling map. @@ -176,12 +186,13 @@ def unpool(data, pool_map, sizes, name='graph_pooling_unpool'): def upsample_transposed_convolution( - data, - pool_map, - sizes, - kernel_size, - transposed_convolution_op, - name='graph_pooling_upsample_transposed_convolution'): + data: type_alias.TensorLike, + pool_map: type_alias.TensorLike, + sizes: type_alias.TensorLike, + kernel_size: int, + transposed_convolution_op: Callable[[type_alias.TensorLike], + type_alias.TensorLike], + name: str = 'graph_pooling_upsample_transposed_convolution') -> tf.Tensor: # pyformat: disable r"""Graph upsampling by transposed convolution. diff --git a/tensorflow_graphics/geometry/convolution/utils.py b/tensorflow_graphics/geometry/convolution/utils.py index f6bef95a4..25ec9f233 100644 --- a/tensorflow_graphics/geometry/convolution/utils.py +++ b/tensorflow_graphics/geometry/convolution/utils.py @@ -17,12 +17,16 @@ from __future__ import division from __future__ import print_function +from typing import Any, List, Optional, Tuple, Union + import tensorflow as tf from tensorflow_graphics.util import shape +from tensorflow_graphics.util import type_alias -def _is_dynamic_shape(tensors): +def _is_dynamic_shape(tensors: Union[List[type_alias.TensorLike], + Tuple[Any, tf.sparse.SparseTensor]]): """Helper function to test if any tensor in a list has a dynamic shape. Args: @@ -36,7 +40,9 @@ def _is_dynamic_shape(tensors): return not all([shape.is_static(tensor.shape) for tensor in tensors]) -def check_valid_graph_convolution_input(data, neighbors, sizes): +def check_valid_graph_convolution_input(data: type_alias.TensorLike, + neighbors: tf.sparse.SparseTensor, + sizes: type_alias.TensorLike): """Checks that the inputs are valid for graph convolution ops. Note: @@ -86,7 +92,9 @@ def check_valid_graph_convolution_input(data, neighbors, sizes): broadcast_compatible=False) -def check_valid_graph_pooling_input(data, pool_map, sizes): +def check_valid_graph_pooling_input(data: type_alias.TensorLike, + pool_map: tf.sparse.SparseTensor, + sizes: type_alias.TensorLike): """Checks that the inputs are valid for graph pooling. Note: @@ -136,7 +144,9 @@ def check_valid_graph_pooling_input(data, pool_map, sizes): broadcast_compatible=False) -def check_valid_graph_unpooling_input(data, pool_map, sizes): +def check_valid_graph_unpooling_input(data: type_alias.TensorLike, + pool_map: tf.sparse.SparseTensor, + sizes: type_alias.TensorLike): """Checks that the inputs are valid for graph unpooling. Note: @@ -187,7 +197,9 @@ def check_valid_graph_unpooling_input(data, pool_map, sizes): broadcast_compatible=False) -def flatten_batch_to_2d(data, sizes=None, name="utils_flatten_batch_to_2d"): +def flatten_batch_to_2d(data: type_alias.TensorLike, + sizes: type_alias.TensorLike = None, + name: str = "utils_flatten_batch_to_2d"): """Reshapes a batch of 2d Tensors by flattening across the batch dimensions. Note: @@ -286,10 +298,10 @@ def unflatten(flat, name="utils_unflatten"): return flat, unflatten -def unflatten_2d_to_batch(data, - sizes, - max_rows=None, - name="utils_unflatten_2d_to_batch"): +def unflatten_2d_to_batch(data: type_alias.TensorLike, + sizes: type_alias.TensorLike, + max_rows: Optional[int] = None, + name: str = "utils_unflatten_2d_to_batch"): r"""Reshapes a 2d Tensor into a batch of 2d Tensors. The `data` tensor with shape `[D1, D2]` will be mapped to a tensor with shape @@ -369,10 +381,10 @@ def unflatten_2d_to_batch(data, return tf.scatter_nd(indices=mask_indices, updates=data, shape=output_shape) -def convert_to_block_diag_2d(data, - sizes=None, - validate_indices=False, - name="utils_convert_to_block_diag_2d"): +def convert_to_block_diag_2d(data: tf.sparse.SparseTensor, + sizes: Optional[type_alias.TensorLike] = None, + validate_indices: bool = False, + name: str = "utils_convert_to_block_diag_2d"): """Convert a batch of 2d SparseTensors to a 2d block diagonal SparseTensor. Note: diff --git a/tensorflow_graphics/geometry/deformation_energy/as_conformal_as_possible.py b/tensorflow_graphics/geometry/deformation_energy/as_conformal_as_possible.py index 472cc1051..85fc3fc7e 100644 --- a/tensorflow_graphics/geometry/deformation_energy/as_conformal_as_possible.py +++ b/tensorflow_graphics/geometry/deformation_energy/as_conformal_as_possible.py @@ -17,23 +17,26 @@ from __future__ import division from __future__ import print_function +from typing import Optional + import tensorflow as tf from tensorflow_graphics.geometry.transformation import quaternion from tensorflow_graphics.math import vector from tensorflow_graphics.util import export_api from tensorflow_graphics.util import shape - - -def energy(vertices_rest_pose, - vertices_deformed_pose, - quaternions, - edges, - vertex_weight=None, - edge_weight=None, - conformal_energy=True, - aggregate_loss=True, - name=None): +from tensorflow_graphics.util import type_alias + + +def energy(vertices_rest_pose: type_alias.TensorLike, + vertices_deformed_pose: type_alias.TensorLike, + quaternions: type_alias.TensorLike, + edges: type_alias.TensorLike, + vertex_weight: Optional[type_alias.TensorLike] = None, + edge_weight: Optional[type_alias.TensorLike] = None, + conformal_energy: bool = True, + aggregate_loss: bool = True, + name: str = "as_conformal_as_possible_energy"): """Estimates an As Conformal As Possible (ACAP) fitting energy. For a given mesh in rest pose, this function evaluates a variant of the ACAP @@ -92,10 +95,7 @@ def energy(vertices_rest_pose, ValueError: if the shape of `vertices_rest_pose`, `vertices_deformed_pose`, `quaternions`, `edges`, `vertex_weight`, or `edge_weight` is not supported. """ - with tf.compat.v1.name_scope(name, "as_conformal_as_possible_energy", [ - vertices_rest_pose, vertices_deformed_pose, quaternions, edges, - conformal_energy, vertex_weight, edge_weight - ]): + with tf.name_scope(name): vertices_rest_pose = tf.convert_to_tensor(value=vertices_rest_pose) vertices_deformed_pose = tf.convert_to_tensor(value=vertices_deformed_pose) quaternions = tf.convert_to_tensor(value=quaternions) diff --git a/tensorflow_graphics/rendering/triangle_rasterizer.py b/tensorflow_graphics/rendering/triangle_rasterizer.py index 1fd5f663c..184ec62d1 100644 --- a/tensorflow_graphics/rendering/triangle_rasterizer.py +++ b/tensorflow_graphics/rendering/triangle_rasterizer.py @@ -30,7 +30,7 @@ def _dim_value(dim): - return 1 if dim is None else tf.compat.v1.dimension_value(dim) + return 1 if dim is None else tf.compat.dimension_value(dim) def _merge_batch_dims(tensor, last_axis): @@ -50,7 +50,7 @@ def rasterize(vertices, image_size, enable_cull_face=True, backend=rasterization_backend.RasterizationBackends.OPENGL, - name=None): + name="triangle_rasterizer_rasterize"): """Rasterizes the scene. Note: @@ -73,7 +73,7 @@ def rasterize(vertices, False. backend: A rasterization_backend.RasterizationBackends enum containing the backend method to use for rasterization. - name: A name for this op. Defaults to 'triangle_rasterizer_rasterize'. + name: A name for this op. Defaults to "triangle_rasterizer_rasterize". Returns: A dictionary. The key "mask" is of shape `[A1, ..., An, height, width, 1]` @@ -83,9 +83,7 @@ def rasterize(vertices, the dictionary contains perspective correct interpolated attributes of shape `[A1, ..., An, height, width, K]` per entry in the `attributes` dictionary. """ - with tf.compat.v1.name_scope( - name, "triangle_rasterizer_rasterize", - (vertices, triangles, attributes, view_projection_matrix)): + with tf.name_scope(name): vertices = tf.convert_to_tensor(value=vertices) triangles = tf.convert_to_tensor(value=triangles) view_projection_matrix = tf.convert_to_tensor(value=view_projection_matrix)