Adds typing information to geometry/convolution/utils.py.
PiperOrigin-RevId: 379965032
G4G authored and Copybara-Service committed Jun 17, 2021
1 parent 1d33eb6 commit bd33a91
Showing 3 changed files with 63 additions and 37 deletions.
35 changes: 19 additions & 16 deletions tensorflow_graphics/geometry/convolution/graph_convolution.py
@@ -17,24 +17,26 @@
from __future__ import division
from __future__ import print_function

+from typing import Any, Callable, Dict
from six.moves import zip
import tensorflow as tf

from tensorflow_graphics.geometry.convolution import utils
from tensorflow_graphics.util import export_api
from tensorflow_graphics.util import shape
+from tensorflow_graphics.util import type_alias


def feature_steered_convolution(
-    data,
-    neighbors,
-    sizes,
-    var_u,
-    var_v,
-    var_c,
-    var_w,
-    var_b,
-    name="graph_convolution_feature_steered_convolution"):
+    data: type_alias.TensorLike,
+    neighbors: tf.sparse.SparseTensor,
+    sizes: type_alias.TensorLike,
+    var_u: type_alias.TensorLike,
+    var_v: type_alias.TensorLike,
+    var_c: type_alias.TensorLike,
+    var_w: type_alias.TensorLike,
+    var_b: type_alias.TensorLike,
+    name="graph_convolution_feature_steered_convolution") -> tf.Tensor:
# pyformat: disable
"""Implements the Feature Steered graph convolution.
@@ -160,13 +162,14 @@ def feature_steered_convolution(


def edge_convolution_template(
-    data,
-    neighbors,
-    sizes,
-    edge_function,
-    reduction,
-    edge_function_kwargs,
-    name="graph_convolution_edge_convolution_template"):
+    data: type_alias.TensorLike,
+    neighbors: tf.sparse.SparseTensor,
+    sizes: type_alias.TensorLike,
+    edge_function: Callable[[type_alias.TensorLike, type_alias.TensorLike],
+                             type_alias.TensorLike],
+    reduction: str,
+    edge_function_kwargs: Dict[str, Any],
+    name: str = "graph_convolution_edge_convolution_template") -> tf.Tensor:
# pyformat: disable
r"""A template for edge convolutions.
27 changes: 19 additions & 8 deletions tensorflow_graphics/geometry/convolution/graph_pooling.py
@@ -17,13 +17,20 @@
from __future__ import division
from __future__ import print_function

+from typing import Callable

import tensorflow as tf

from tensorflow_graphics.geometry.convolution import utils
from tensorflow_graphics.util import export_api
+from tensorflow_graphics.util import type_alias


-def pool(data, pool_map, sizes, algorithm='max', name='graph_pooling_pool'):
+def pool(data: type_alias.TensorLike,
+         pool_map: type_alias.TensorLike,
+         sizes: type_alias.TensorLike,
+         algorithm: str = 'max',
+         name: str = 'graph_pooling_pool') -> tf.Tensor:
# pyformat: disable
"""Implements graph pooling.
@@ -109,7 +116,10 @@ def pool(data, pool_map, sizes, algorithm='max', name='graph_pooling_pool'):
return pooled


-def unpool(data, pool_map, sizes, name='graph_pooling_unpool'):
+def unpool(data: type_alias.TensorLike,
+           pool_map: type_alias.TensorLike,
+           sizes: type_alias.TensorLike,
+           name: str = 'graph_pooling_unpool') -> tf.Tensor:
# pyformat: disable
r"""Graph upsampling by inverting the pooling map.
@@ -176,12 +186,13 @@ def unpool(data, pool_map, sizes, name='graph_pooling_unpool'):


def upsample_transposed_convolution(
-    data,
-    pool_map,
-    sizes,
-    kernel_size,
-    transposed_convolution_op,
-    name='graph_pooling_upsample_transposed_convolution'):
+    data: type_alias.TensorLike,
+    pool_map: type_alias.TensorLike,
+    sizes: type_alias.TensorLike,
+    kernel_size: int,
+    transposed_convolution_op: Callable[[type_alias.TensorLike],
+                                         type_alias.TensorLike],
+    name: str = 'graph_pooling_upsample_transposed_convolution') -> tf.Tensor:
# pyformat: disable
r"""Graph upsampling by transposed convolution.
38 changes: 25 additions & 13 deletions tensorflow_graphics/geometry/convolution/utils.py
@@ -17,12 +17,16 @@
from __future__ import division
from __future__ import print_function

+from typing import Any, List, Optional, Tuple, Union

import tensorflow as tf

from tensorflow_graphics.util import shape
+from tensorflow_graphics.util import type_alias


-def _is_dynamic_shape(tensors):
+def _is_dynamic_shape(tensors: Union[List[type_alias.TensorLike],
+                                     Tuple[Any, tf.sparse.SparseTensor]]):
"""Helper function to test if any tensor in a list has a dynamic shape.
Args:
@@ -36,7 +40,9 @@ def _is_dynamic_shape(tensors):
return not all([shape.is_static(tensor.shape) for tensor in tensors])


-def check_valid_graph_convolution_input(data, neighbors, sizes):
+def check_valid_graph_convolution_input(data: type_alias.TensorLike,
+                                        neighbors: tf.sparse.SparseTensor,
+                                        sizes: type_alias.TensorLike):
"""Checks that the inputs are valid for graph convolution ops.
Note:
@@ -86,7 +92,9 @@ def check_valid_graph_convolution_input(data, neighbors, sizes):
broadcast_compatible=False)


-def check_valid_graph_pooling_input(data, pool_map, sizes):
+def check_valid_graph_pooling_input(data: type_alias.TensorLike,
+                                    pool_map: tf.sparse.SparseTensor,
+                                    sizes: type_alias.TensorLike):
"""Checks that the inputs are valid for graph pooling.
Note:
@@ -136,7 +144,9 @@ def check_valid_graph_pooling_input(data, pool_map, sizes):
broadcast_compatible=False)


-def check_valid_graph_unpooling_input(data, pool_map, sizes):
+def check_valid_graph_unpooling_input(data: type_alias.TensorLike,
+                                      pool_map: tf.sparse.SparseTensor,
+                                      sizes: type_alias.TensorLike):
"""Checks that the inputs are valid for graph unpooling.
Note:
@@ -187,7 +197,9 @@ def check_valid_graph_unpooling_input(data, pool_map, sizes):
broadcast_compatible=False)


-def flatten_batch_to_2d(data, sizes=None, name="utils_flatten_batch_to_2d"):
+def flatten_batch_to_2d(data: type_alias.TensorLike,
+                        sizes: type_alias.TensorLike = None,
+                        name: str = "utils_flatten_batch_to_2d"):
"""Reshapes a batch of 2d Tensors by flattening across the batch dimensions.
Note:
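A small sketch of the `flatten_batch_to_2d` / `unflatten` round trip implied by the `return flat, unflatten` shown in the next hunk; the padding behavior noted in the comments is my reading and should be checked against the docstring:

```python
import tensorflow as tf

from tensorflow_graphics.geometry.convolution import utils

# Two graphs padded to 3 vertices each: the first has 2 real vertices,
# the second has 3.
data = tf.constant([[[1.0], [2.0], [0.0]],
                    [[3.0], [4.0], [5.0]]])  # [2, 3, 1]
sizes = tf.constant([2, 3])

flat, unflatten = utils.flatten_batch_to_2d(data, sizes=sizes)
print(flat.shape)             # expected: (5, 1), padding rows dropped
print(unflatten(flat).shape)  # expected: (2, 3, 1), padding restored
```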
@@ -286,10 +298,10 @@ def unflatten(flat, name="utils_unflatten"):
return flat, unflatten


-def unflatten_2d_to_batch(data,
-                          sizes,
-                          max_rows=None,
-                          name="utils_unflatten_2d_to_batch"):
+def unflatten_2d_to_batch(data: type_alias.TensorLike,
+                          sizes: type_alias.TensorLike,
+                          max_rows: Optional[int] = None,
+                          name: str = "utils_unflatten_2d_to_batch"):
r"""Reshapes a 2d Tensor into a batch of 2d Tensors.
The `data` tensor with shape `[D1, D2]` will be mapped to a tensor with shape
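A sketch of `unflatten_2d_to_batch` with the new `Optional[int]` annotation for `max_rows`; the zero-padding shown in the comments is an assumption drawn from the shape description above:

```python
import tensorflow as tf

from tensorflow_graphics.geometry.convolution import utils

data = tf.constant([[1.0], [2.0], [3.0]])  # [D1=3, D2=1]
sizes = tf.constant([2, 1])                # rows per output batch element

batched = utils.unflatten_2d_to_batch(data, sizes)
print(batched.numpy())
# expected:
# [[[1.], [2.]],
#  [[3.], [0.]]]   rows grouped per `sizes`, zero-padded to max(sizes) rows
```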
@@ -369,10 +381,10 @@ def unflatten_2d_to_batch(data,
return tf.scatter_nd(indices=mask_indices, updates=data, shape=output_shape)


-def convert_to_block_diag_2d(data,
-                             sizes=None,
-                             validate_indices=False,
-                             name="utils_convert_to_block_diag_2d"):
+def convert_to_block_diag_2d(data: tf.sparse.SparseTensor,
+                             sizes: type_alias.TensorLike = None,
+                             validate_indices: bool = False,
+                             name: str = "utils_convert_to_block_diag_2d"):
"""Convert a batch of 2d SparseTensors to a 2d block diagonal SparseTensor.
Note:
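A sketch of `convert_to_block_diag_2d` with the new `tf.sparse.SparseTensor` annotation on `data`; the batch-of-two example and the dense rendering are illustrative assumptions:

```python
import tensorflow as tf

from tensorflow_graphics.geometry.convolution import utils

# A batch of two 2x2 sparse matrices.
batch = tf.sparse.from_dense(
    tf.constant([[[1.0, 0.0],
                  [0.0, 2.0]],
                 [[3.0, 0.0],
                  [0.0, 4.0]]]))  # [2, 2, 2]

block_diag = utils.convert_to_block_diag_2d(batch)
print(tf.sparse.to_dense(block_diag).numpy())
# expected:
# [[1. 0. 0. 0.]
#  [0. 2. 0. 0.]
#  [0. 0. 3. 0.]
#  [0. 0. 0. 4.]]
```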
