Skip to content

Commit

Permalink
Migrate several left-off files to TensorFlow 2 and Python 3.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 380068484
  • Loading branch information
tensorflower-gardener authored and Copybara-Service committed Jun 17, 2021
1 parent 1d33eb6 commit 1ed0d01
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 58 deletions.
35 changes: 19 additions & 16 deletions tensorflow_graphics/geometry/convolution/graph_convolution.py
Expand Up @@ -17,24 +17,26 @@
from __future__ import division
from __future__ import print_function

from typing import Any, Callable, Dict
from six.moves import zip
import tensorflow as tf

from tensorflow_graphics.geometry.convolution import utils
from tensorflow_graphics.util import export_api
from tensorflow_graphics.util import shape
from tensorflow_graphics.util import type_alias


def feature_steered_convolution(
data,
neighbors,
sizes,
var_u,
var_v,
var_c,
var_w,
var_b,
name="graph_convolution_feature_steered_convolution"):
data: type_alias.TensorLike,
neighbors: tf.sparse.SparseTensor,
sizes: type_alias.TensorLike,
var_u: type_alias.TensorLike,
var_v: type_alias.TensorLike,
var_c: type_alias.TensorLike,
var_w: type_alias.TensorLike,
var_b: type_alias.TensorLike,
name="graph_convolution_feature_steered_convolution") -> tf.Tensor:
# pyformat: disable
"""Implements the Feature Steered graph convolution.
Expand Down Expand Up @@ -160,13 +162,14 @@ def feature_steered_convolution(


def edge_convolution_template(
data,
neighbors,
sizes,
edge_function,
reduction,
edge_function_kwargs,
name="graph_convolution_edge_convolution_template"):
data: type_alias.TensorLike,
neighbors: tf.sparse.SparseTensor,
sizes: type_alias.TensorLike,
edge_function: Callable[[type_alias.TensorLike, type_alias.TensorLike],
type_alias.TensorLike],
reduction: str,
edge_function_kwargs: Dict[str, Any],
name: str = "graph_convolution_edge_convolution_template") -> tf.Tensor:
# pyformat: disable
r"""A template for edge convolutions.
Expand Down
27 changes: 19 additions & 8 deletions tensorflow_graphics/geometry/convolution/graph_pooling.py
Expand Up @@ -17,13 +17,20 @@
from __future__ import division
from __future__ import print_function

from typing import Callable

import tensorflow as tf

from tensorflow_graphics.geometry.convolution import utils
from tensorflow_graphics.util import export_api
from tensorflow_graphics.util import type_alias


def pool(data, pool_map, sizes, algorithm='max', name='graph_pooling_pool'):
def pool(data: type_alias.TensorLike,
pool_map: type_alias.TensorLike,
sizes: type_alias.TensorLike,
algorithm: str = 'max',
name: str = 'graph_pooling_pool') -> tf.Tensor:
# pyformat: disable
"""Implements graph pooling.
Expand Down Expand Up @@ -109,7 +116,10 @@ def pool(data, pool_map, sizes, algorithm='max', name='graph_pooling_pool'):
return pooled


def unpool(data, pool_map, sizes, name='graph_pooling_unpool'):
def unpool(data: type_alias.TensorLike,
pool_map: type_alias.TensorLike,
sizes: type_alias.TensorLike,
name: str = 'graph_pooling_unpool') -> tf.Tensor:
# pyformat: disable
r"""Graph upsampling by inverting the pooling map.
Expand Down Expand Up @@ -176,12 +186,13 @@ def unpool(data, pool_map, sizes, name='graph_pooling_unpool'):


def upsample_transposed_convolution(
data,
pool_map,
sizes,
kernel_size,
transposed_convolution_op,
name='graph_pooling_upsample_transposed_convolution'):
data: type_alias.TensorLike,
pool_map: type_alias.TensorLike,
sizes: type_alias.TensorLike,
kernel_size: int,
transposed_convolution_op: Callable[[type_alias.TensorLike],
type_alias.TensorLike],
name: str = 'graph_pooling_upsample_transposed_convolution') -> tf.Tensor:
# pyformat: disable
r"""Graph upsampling by transposed convolution.
Expand Down
38 changes: 25 additions & 13 deletions tensorflow_graphics/geometry/convolution/utils.py
Expand Up @@ -17,12 +17,16 @@
from __future__ import division
from __future__ import print_function

from typing import Any, List, Optional, Tuple, Union

import tensorflow as tf

from tensorflow_graphics.util import shape
from tensorflow_graphics.util import type_alias


def _is_dynamic_shape(tensors):
def _is_dynamic_shape(tensors: Union[List[type_alias.TensorLike],
Tuple[Any, tf.sparse.SparseTensor]]):
"""Helper function to test if any tensor in a list has a dynamic shape.
Args:
Expand All @@ -36,7 +40,9 @@ def _is_dynamic_shape(tensors):
return not all([shape.is_static(tensor.shape) for tensor in tensors])


def check_valid_graph_convolution_input(data, neighbors, sizes):
def check_valid_graph_convolution_input(data: type_alias.TensorLike,
neighbors: tf.sparse.SparseTensor,
sizes: type_alias.TensorLike):
"""Checks that the inputs are valid for graph convolution ops.
Note:
Expand Down Expand Up @@ -86,7 +92,9 @@ def check_valid_graph_convolution_input(data, neighbors, sizes):
broadcast_compatible=False)


def check_valid_graph_pooling_input(data, pool_map, sizes):
def check_valid_graph_pooling_input(data: type_alias.TensorLike,
pool_map: tf.sparse.SparseTensor,
sizes: type_alias.TensorLike):
"""Checks that the inputs are valid for graph pooling.
Note:
Expand Down Expand Up @@ -136,7 +144,9 @@ def check_valid_graph_pooling_input(data, pool_map, sizes):
broadcast_compatible=False)


def check_valid_graph_unpooling_input(data, pool_map, sizes):
def check_valid_graph_unpooling_input(data: type_alias.TensorLike,
pool_map: tf.sparse.SparseTensor,
sizes: type_alias.TensorLike):
"""Checks that the inputs are valid for graph unpooling.
Note:
Expand Down Expand Up @@ -187,7 +197,9 @@ def check_valid_graph_unpooling_input(data, pool_map, sizes):
broadcast_compatible=False)


def flatten_batch_to_2d(data, sizes=None, name="utils_flatten_batch_to_2d"):
def flatten_batch_to_2d(data: type_alias.TensorLike,
sizes: type_alias.TensorLike = None,
name: str = "utils_flatten_batch_to_2d"):
"""Reshapes a batch of 2d Tensors by flattening across the batch dimensions.
Note:
Expand Down Expand Up @@ -286,10 +298,10 @@ def unflatten(flat, name="utils_unflatten"):
return flat, unflatten


def unflatten_2d_to_batch(data,
sizes,
max_rows=None,
name="utils_unflatten_2d_to_batch"):
def unflatten_2d_to_batch(data: type_alias.TensorLike,
sizes: type_alias.TensorLike,
max_rows: Optional[int] = None,
name: str = "utils_unflatten_2d_to_batch"):
r"""Reshapes a 2d Tensor into a batch of 2d Tensors.
The `data` tensor with shape `[D1, D2]` will be mapped to a tensor with shape
Expand Down Expand Up @@ -369,10 +381,10 @@ def unflatten_2d_to_batch(data,
return tf.scatter_nd(indices=mask_indices, updates=data, shape=output_shape)


def convert_to_block_diag_2d(data,
sizes=None,
validate_indices=False,
name="utils_convert_to_block_diag_2d"):
def convert_to_block_diag_2d(data: tf.sparse.SparseTensor,
sizes: Optional[type_alias.TensorLike] = None,
validate_indices: bool = False,
name: str = "utils_convert_to_block_diag_2d"):
"""Convert a batch of 2d SparseTensors to a 2d block diagonal SparseTensor.
Note:
Expand Down
Expand Up @@ -17,23 +17,26 @@
from __future__ import division
from __future__ import print_function

from typing import Optional

import tensorflow as tf

from tensorflow_graphics.geometry.transformation import quaternion
from tensorflow_graphics.math import vector
from tensorflow_graphics.util import export_api
from tensorflow_graphics.util import shape


def energy(vertices_rest_pose,
vertices_deformed_pose,
quaternions,
edges,
vertex_weight=None,
edge_weight=None,
conformal_energy=True,
aggregate_loss=True,
name=None):
from tensorflow_graphics.util import type_alias


def energy(vertices_rest_pose: type_alias.TensorLike,
vertices_deformed_pose: type_alias.TensorLike,
quaternions: type_alias.TensorLike,
edges: type_alias.TensorLike,
vertex_weight: Optional[type_alias.TensorLike] = None,
edge_weight: Optional[type_alias.TensorLike] = None,
conformal_energy: bool = True,
aggregate_loss: bool = True,
name: str = "as_conformal_as_possible_energy"):
"""Estimates an As Conformal As Possible (ACAP) fitting energy.
For a given mesh in rest pose, this function evaluates a variant of the ACAP
Expand Down Expand Up @@ -92,10 +95,7 @@ def energy(vertices_rest_pose,
ValueError: if the shape of `vertices_rest_pose`, `vertices_deformed_pose`,
`quaternions`, `edges`, `vertex_weight`, or `edge_weight` is not supported.
"""
with tf.compat.v1.name_scope(name, "as_conformal_as_possible_energy", [
vertices_rest_pose, vertices_deformed_pose, quaternions, edges,
conformal_energy, vertex_weight, edge_weight
]):
with tf.name_scope(name):
vertices_rest_pose = tf.convert_to_tensor(value=vertices_rest_pose)
vertices_deformed_pose = tf.convert_to_tensor(value=vertices_deformed_pose)
quaternions = tf.convert_to_tensor(value=quaternions)
Expand Down
10 changes: 4 additions & 6 deletions tensorflow_graphics/rendering/triangle_rasterizer.py
Expand Up @@ -30,7 +30,7 @@


def _dim_value(dim):
return 1 if dim is None else tf.compat.v1.dimension_value(dim)
return 1 if dim is None else tf.compat.dimension_value(dim)


def _merge_batch_dims(tensor, last_axis):
Expand All @@ -50,7 +50,7 @@ def rasterize(vertices,
image_size,
enable_cull_face=True,
backend=rasterization_backend.RasterizationBackends.OPENGL,
name=None):
name="triangle_rasterizer_rasterize"):
"""Rasterizes the scene.
Note:
Expand All @@ -73,7 +73,7 @@ def rasterize(vertices,
False.
backend: A rasterization_backend.RasterizationBackends enum containing the
backend method to use for rasterization.
name: A name for this op. Defaults to 'triangle_rasterizer_rasterize'.
name: A name for this op. Defaults to "triangle_rasterizer_rasterize".
Returns:
A dictionary. The key "mask" is of shape `[A1, ..., An, height, width, 1]`
Expand All @@ -83,9 +83,7 @@ def rasterize(vertices,
the dictionary contains perspective correct interpolated attributes of shape
`[A1, ..., An, height, width, K]` per entry in the `attributes` dictionary.
"""
with tf.compat.v1.name_scope(
name, "triangle_rasterizer_rasterize",
(vertices, triangles, attributes, view_projection_matrix)):
with tf.name_scope(name):
vertices = tf.convert_to_tensor(value=vertices)
triangles = tf.convert_to_tensor(value=triangles)
view_projection_matrix = tf.convert_to_tensor(value=view_projection_matrix)
Expand Down

0 comments on commit 1ed0d01

Please sign in to comment.