diff --git a/tensorflow_graphics/geometry/convolution/graph_pooling.py b/tensorflow_graphics/geometry/convolution/graph_pooling.py index 705d4e006..d93a6f1d3 100644 --- a/tensorflow_graphics/geometry/convolution/graph_pooling.py +++ b/tensorflow_graphics/geometry/convolution/graph_pooling.py @@ -17,13 +17,20 @@ from __future__ import division from __future__ import print_function +from typing import Callable + import tensorflow as tf from tensorflow_graphics.geometry.convolution import utils from tensorflow_graphics.util import export_api +from tensorflow_graphics.util import type_alias -def pool(data, pool_map, sizes, algorithm='max', name='graph_pooling_pool'): +def pool(data: type_alias.TensorLike, + pool_map: type_alias.TensorLike, + sizes: type_alias.TensorLike, + algorithm: str = 'max', + name: str = 'graph_pooling_pool') -> tf.Tensor: # pyformat: disable """Implements graph pooling. @@ -109,7 +116,10 @@ def pool(data, pool_map, sizes, algorithm='max', name='graph_pooling_pool'): return pooled -def unpool(data, pool_map, sizes, name='graph_pooling_unpool'): +def unpool(data: type_alias.TensorLike, + pool_map: type_alias.TensorLike, + sizes: type_alias.TensorLike, + name: str = 'graph_pooling_unpool') -> tf.Tensor: # pyformat: disable r"""Graph upsampling by inverting the pooling map. @@ -176,12 +186,13 @@ def unpool(data, pool_map, sizes, name='graph_pooling_unpool'): def upsample_transposed_convolution( - data, - pool_map, - sizes, - kernel_size, - transposed_convolution_op, - name='graph_pooling_upsample_transposed_convolution'): + data: type_alias.TensorLike, + pool_map: type_alias.TensorLike, + sizes: type_alias.TensorLike, + kernel_size: int, + transposed_convolution_op: Callable[[type_alias.TensorLike], + type_alias.TensorLike], + name: str = 'graph_pooling_upsample_transposed_convolution') -> tf.Tensor: # pyformat: disable r"""Graph upsampling by transposed convolution.