Adds typing information to the module util.shape.
PiperOrigin-RevId: 425878021
G4G authored and Copybara-Service committed Feb 2, 2022
1 parent 70d95be commit 8538149
Showing 1 changed file with 29 additions and 16 deletions.
45 changes: 29 additions & 16 deletions tensorflow_graphics/util/shape.py
@@ -18,6 +18,7 @@
 from __future__ import print_function
 
 import itertools
+from typing import Any, List, Optional, Tuple, Union
 
 import numpy as np
 import six
@@ -26,7 +27,8 @@
 import tensorflow as tf
 
 
-def _broadcast_shape_helper(shape_x, shape_y):
+def _broadcast_shape_helper(shape_x: tf.TensorShape,
+                            shape_y: tf.TensorShape) -> Optional[List[Any]]:
   """Helper function for is_broadcast_compatible and broadcast_shape.
 
   Args:
@@ -74,7 +76,8 @@ def _broadcast_shape_helper(shape_x, shape_y):
   return return_dims
 
 
-def is_broadcast_compatible(shape_x, shape_y):
+def is_broadcast_compatible(shape_x: tf.TensorShape,
+                            shape_y: tf.TensorShape) -> bool:
   """Returns True if `shape_x` and `shape_y` are broadcast compatible.
 
   Args:
@@ -90,7 +93,8 @@
   return _broadcast_shape_helper(shape_x, shape_y) is not None
 
 
-def get_broadcasted_shape(shape_x, shape_y):
+def get_broadcasted_shape(shape_x: tf.TensorShape,
+                          shape_y: tf.TensorShape) -> Optional[List[Any]]:
   """Returns the common shape for broadcast compatible shapes.
 
   Args:
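
Usage note: `is_broadcast_compatible` and `get_broadcasted_shape` operate on static `tf.TensorShape` values; a minimal sketch with illustrative shapes (not part of this commit):

    import tensorflow as tf
    from tensorflow_graphics.util import shape

    x = tf.ones((3, 1, 2))
    y = tf.ones((5, 2))

    # (3, 1, 2) and (5, 2) broadcast to (3, 5, 2).
    print(shape.is_broadcast_compatible(x.shape, y.shape))  # True
    print(shape.get_broadcasted_shape(x.shape, y.shape))    # [3, 5, 2]
    # Incompatible shapes yield None, hence the Optional[List[Any]] return.
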
@@ -135,14 +139,15 @@ def _get_dim(tensor, axis):
   return tf.compat.dimension_value(tensor.shape[axis])
 
 
-def check_static(tensor,
-                 has_rank=None,
-                 has_rank_greater_than=None,
-                 has_rank_less_than=None,
+def check_static(tensor: tf.Tensor,
+                 has_rank: Optional[int] = None,
+                 has_rank_greater_than: Optional[int] = None,
+                 has_rank_less_than: Optional[int] = None,
                  has_dim_equals=None,
                  has_dim_greater_than=None,
                  has_dim_less_than=None,
-                 tensor_name='tensor'):
+                 tensor_name: str = 'tensor') -> None:
+  # TODO(cengizo): Typing for has_dim_equals, has_dim_greater(less)_than.
   """Checks static shapes for rank and dimension constraints.
 
   This function can be used to check a tensor's shape for multiple rank and
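
Usage note: a sketch of calling the newly annotated `check_static`; the tensor is hypothetical, and the `(axis, dim)` tuple passed to `has_dim_equals` follows the convention described in the function's docstring:

    import tensorflow as tf
    from tensorflow_graphics.util import shape

    points = tf.ones((8, 100, 3))
    # Passes silently: rank is exactly 3 and the last axis has dimension 3.
    shape.check_static(points, has_rank=3, has_dim_equals=(-1, 3),
                       tensor_name='points')
    # Raises ValueError: the rank is 3, not greater than 3.
    shape.check_static(points, has_rank_greater_than=3, tensor_name='points')
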
@@ -276,11 +281,12 @@ def _raise_error(tensor_names, batch_shapes):
       'Not all batch dimensions are identical: {}'.format(formatted_list))
 
 
-def compare_batch_dimensions(tensors,
-                             last_axes,
-                             broadcast_compatible,
-                             initial_axes=0,
-                             tensor_names=None):
+def compare_batch_dimensions(
+    tensors: Union[List[tf.Tensor], Tuple[tf.Tensor]],
+    last_axes: Union[int, List[int], Tuple[int]],
+    broadcast_compatible: bool,
+    initial_axes: Union[int, List[int], Tuple[int]] = 0,
+    tensor_names: Optional[Union[List[str], Tuple[str]]] = None) -> None:
   """Compares batch dimensions for tensors with static shapes.
 
   Args:
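
Usage note: a sketch of `compare_batch_dimensions` in its broadcast-compatible mode, with hypothetical tensors:

    import tensorflow as tf
    from tensorflow_graphics.util import shape

    a = tf.ones((4, 1, 10, 3))
    b = tf.ones((4, 7, 5, 3))
    # Compares axes 0 through 1 of each tensor: batch shapes (4, 1) and
    # (4, 7) are broadcast compatible, so this passes silently.
    shape.compare_batch_dimensions(
        tensors=(a, b), last_axes=1, broadcast_compatible=True)
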
@@ -347,7 +353,10 @@ def compare_batch_dimensions(tensors,
       ]))
 
 
-def compare_dimensions(tensors, axes, tensor_names=None):
+def compare_dimensions(
+    tensors: Union[List[tf.Tensor], Tuple[tf.Tensor]],
+    axes: Union[int, List[int], Tuple[int]],
+    tensor_names: Optional[Union[List[str], Tuple[str]]] = None) -> None:
   """Compares dimensions of tensors with static or dynamic shapes.
 
   Args:
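
Usage note: `compare_dimensions` checks that the given axes have matching dimensions across tensors; a sketch with hypothetical tensors:

    import tensorflow as tf
    from tensorflow_graphics.util import shape

    u = tf.ones((2, 3))
    v = tf.ones((5, 3))
    # Passes: axis 1 has dimension 3 in both tensors.
    shape.compare_dimensions(tensors=(u, v), axes=1, tensor_names=('u', 'v'))
    # Raises ValueError: dimensions along axis 0 differ (2 vs. 5).
    shape.compare_dimensions(tensors=(u, v), axes=0, tensor_names=('u', 'v'))
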
@@ -376,15 +385,19 @@ def compare_dimensions(tensors, axes, tensor_names=None):
           list(tensor_names), list(axes), list(dimensions)))
 
 
-def is_static(tensor_shape):
+def is_static(
+    tensor_shape: Union[List[Any], Tuple[Any], tf.TensorShape]) -> bool:
   """Checks if the given tensor shape is static."""
   if isinstance(tensor_shape, (list, tuple)):
     return None not in tensor_shape
   else:
     return None not in tensor_shape.as_list()
 
 
-def add_batch_dimensions(tensor, tensor_name, batch_shape, last_axis=None):
+def add_batch_dimensions(tensor: tf.Tensor,
+                         tensor_name: str,
+                         batch_shape: List[int],
+                         last_axis: Optional[int] = None) -> tf.Tensor:
   """Broadcasts tensor to match batch dimensions.
 
   It will either broadcast to all provided batch dimensions, therefore
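
Usage note: sketches for the two remaining annotated functions; the `add_batch_dimensions` call is illustrative, and the expected result shape is inferred from the docstring rather than verified:

    import tensorflow as tf
    from tensorflow_graphics.util import shape

    print(shape.is_static([2, None, 3]))            # False: unknown dimension.
    print(shape.is_static(tf.TensorShape([2, 3])))  # True: fully defined.

    t = tf.ones((5, 3))
    # With last_axis=None the tensor is broadcast to all provided batch
    # dimensions; the result is expected to have shape (4, 2, 5, 3).
    t_batched = shape.add_batch_dimensions(t, 'point', batch_shape=[4, 2])
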
