diff --git a/tensorflow_graphics/geometry/convolution/utils.py b/tensorflow_graphics/geometry/convolution/utils.py index f6bef95a4..25ec9f233 100644 --- a/tensorflow_graphics/geometry/convolution/utils.py +++ b/tensorflow_graphics/geometry/convolution/utils.py @@ -17,12 +17,16 @@ from __future__ import division from __future__ import print_function +from typing import Any, List, Optional, Tuple, Union + import tensorflow as tf from tensorflow_graphics.util import shape +from tensorflow_graphics.util import type_alias -def _is_dynamic_shape(tensors): +def _is_dynamic_shape(tensors: Union[List[type_alias.TensorLike], + Tuple[Any, tf.sparse.SparseTensor]]): """Helper function to test if any tensor in a list has a dynamic shape. Args: @@ -36,7 +40,9 @@ def _is_dynamic_shape(tensors): return not all([shape.is_static(tensor.shape) for tensor in tensors]) -def check_valid_graph_convolution_input(data, neighbors, sizes): +def check_valid_graph_convolution_input(data: type_alias.TensorLike, + neighbors: tf.sparse.SparseTensor, + sizes: type_alias.TensorLike): """Checks that the inputs are valid for graph convolution ops. Note: @@ -86,7 +92,9 @@ def check_valid_graph_convolution_input(data, neighbors, sizes): broadcast_compatible=False) -def check_valid_graph_pooling_input(data, pool_map, sizes): +def check_valid_graph_pooling_input(data: type_alias.TensorLike, + pool_map: tf.sparse.SparseTensor, + sizes: type_alias.TensorLike): """Checks that the inputs are valid for graph pooling. Note: @@ -136,7 +144,9 @@ def check_valid_graph_pooling_input(data, pool_map, sizes): broadcast_compatible=False) -def check_valid_graph_unpooling_input(data, pool_map, sizes): +def check_valid_graph_unpooling_input(data: type_alias.TensorLike, + pool_map: tf.sparse.SparseTensor, + sizes: type_alias.TensorLike): """Checks that the inputs are valid for graph unpooling. Note: @@ -187,7 +197,9 @@ def check_valid_graph_unpooling_input(data, pool_map, sizes): broadcast_compatible=False) -def flatten_batch_to_2d(data, sizes=None, name="utils_flatten_batch_to_2d"): +def flatten_batch_to_2d(data: type_alias.TensorLike, + sizes: type_alias.TensorLike = None, + name: str = "utils_flatten_batch_to_2d"): """Reshapes a batch of 2d Tensors by flattening across the batch dimensions. Note: @@ -286,10 +298,10 @@ def unflatten(flat, name="utils_unflatten"): return flat, unflatten -def unflatten_2d_to_batch(data, - sizes, - max_rows=None, - name="utils_unflatten_2d_to_batch"): +def unflatten_2d_to_batch(data: type_alias.TensorLike, + sizes: type_alias.TensorLike, + max_rows: Optional[int] = None, + name: str = "utils_unflatten_2d_to_batch"): r"""Reshapes a 2d Tensor into a batch of 2d Tensors. The `data` tensor with shape `[D1, D2]` will be mapped to a tensor with shape @@ -369,10 +381,10 @@ def unflatten_2d_to_batch(data, return tf.scatter_nd(indices=mask_indices, updates=data, shape=output_shape) -def convert_to_block_diag_2d(data, - sizes=None, - validate_indices=False, - name="utils_convert_to_block_diag_2d"): +def convert_to_block_diag_2d(data: tf.sparse.SparseTensor, + sizes: Optional[type_alias.TensorLike] = None, + validate_indices: bool = False, + name: str = "utils_convert_to_block_diag_2d"): """Convert a batch of 2d SparseTensors to a 2d block diagonal SparseTensor. Note: