Skip to content

Commit

Permalink
Merge 77b7c2b into d4f2f6a
Browse files Browse the repository at this point in the history
  • Loading branch information
copybara-service[bot] committed Jun 18, 2021
2 parents d4f2f6a + 77b7c2b commit 3ba8341
Showing 1 changed file with 25 additions and 13 deletions.
38 changes: 25 additions & 13 deletions tensorflow_graphics/geometry/convolution/utils.py
Expand Up @@ -17,12 +17,16 @@
from __future__ import division
from __future__ import print_function

from typing import Any, List, Optional, Tuple, Union

import tensorflow as tf

from tensorflow_graphics.util import shape
from tensorflow_graphics.util import type_alias


def _is_dynamic_shape(tensors):
def _is_dynamic_shape(tensors: Union[List[type_alias.TensorLike],
Tuple[Any, tf.sparse.SparseTensor]]):
"""Helper function to test if any tensor in a list has a dynamic shape.
Args:
Expand All @@ -36,7 +40,9 @@ def _is_dynamic_shape(tensors):
return not all([shape.is_static(tensor.shape) for tensor in tensors])


def check_valid_graph_convolution_input(data, neighbors, sizes):
def check_valid_graph_convolution_input(data: type_alias.TensorLike,
neighbors: tf.sparse.SparseTensor,
sizes: type_alias.TensorLike):
"""Checks that the inputs are valid for graph convolution ops.
Note:
Expand Down Expand Up @@ -86,7 +92,9 @@ def check_valid_graph_convolution_input(data, neighbors, sizes):
broadcast_compatible=False)


def check_valid_graph_pooling_input(data, pool_map, sizes):
def check_valid_graph_pooling_input(data: type_alias.TensorLike,
pool_map: tf.sparse.SparseTensor,
sizes: type_alias.TensorLike):
"""Checks that the inputs are valid for graph pooling.
Note:
Expand Down Expand Up @@ -136,7 +144,9 @@ def check_valid_graph_pooling_input(data, pool_map, sizes):
broadcast_compatible=False)


def check_valid_graph_unpooling_input(data, pool_map, sizes):
def check_valid_graph_unpooling_input(data: type_alias.TensorLike,
pool_map: tf.sparse.SparseTensor,
sizes: type_alias.TensorLike):
"""Checks that the inputs are valid for graph unpooling.
Note:
Expand Down Expand Up @@ -187,7 +197,9 @@ def check_valid_graph_unpooling_input(data, pool_map, sizes):
broadcast_compatible=False)


def flatten_batch_to_2d(data, sizes=None, name="utils_flatten_batch_to_2d"):
def flatten_batch_to_2d(data: type_alias.TensorLike,
sizes: type_alias.TensorLike = None,
name: str = "utils_flatten_batch_to_2d"):
"""Reshapes a batch of 2d Tensors by flattening across the batch dimensions.
Note:
Expand Down Expand Up @@ -286,10 +298,10 @@ def unflatten(flat, name="utils_unflatten"):
return flat, unflatten


def unflatten_2d_to_batch(data,
sizes,
max_rows=None,
name="utils_unflatten_2d_to_batch"):
def unflatten_2d_to_batch(data: type_alias.TensorLike,
sizes: type_alias.TensorLike,
max_rows: Optional[int] = None,
name: str = "utils_unflatten_2d_to_batch"):
r"""Reshapes a 2d Tensor into a batch of 2d Tensors.
The `data` tensor with shape `[D1, D2]` will be mapped to a tensor with shape
Expand Down Expand Up @@ -369,10 +381,10 @@ def unflatten_2d_to_batch(data,
return tf.scatter_nd(indices=mask_indices, updates=data, shape=output_shape)


def convert_to_block_diag_2d(data,
sizes=None,
validate_indices=False,
name="utils_convert_to_block_diag_2d"):
def convert_to_block_diag_2d(data: tf.sparse.SparseTensor,
sizes: Optional[type_alias.TensorLike] = None,
validate_indices: bool = False,
name: str = "utils_convert_to_block_diag_2d"):
"""Convert a batch of 2d SparseTensors to a 2d block diagonal SparseTensor.
Note:
Expand Down

0 comments on commit 3ba8341

Please sign in to comment.