diff --git a/tensorflow_graphics/geometry/convolution/graph_convolution.py b/tensorflow_graphics/geometry/convolution/graph_convolution.py index ab66e3db4..0f45c31e4 100644 --- a/tensorflow_graphics/geometry/convolution/graph_convolution.py +++ b/tensorflow_graphics/geometry/convolution/graph_convolution.py @@ -32,6 +32,7 @@ def feature_steered_convolution(data, var_c, var_w, var_b, + ordered=False, name=None): # pyformat: disable """Implements the Feature Steered graph convolution. @@ -81,6 +82,8 @@ def feature_steered_convolution(data, var_c: A 1-D tensor with shape `[W]`. var_w: A 3-D tensor with shape `[W, C, D]`. var_b: A 1-D tensor with shape `[D]`. + ordered: if True, neighbors is assumed to be ordered in the canonical, + row-major ordering, e.g. via `tf.sparse.reorder` name: A name for this op. Defaults to `graph_convolution_feature_steered_convolution`. @@ -146,8 +149,12 @@ def feature_steered_convolution(data, for q_m, w_m in zip(q_m_list, w_m_list): # Compute `y_i_m = sum_{j in neighborhood(i)} q_m(x_i, x_j) * w_m * x_j`. q_m = tf.expand_dims(q_m, axis=-1) - p_sum = utils.partition_sums_2d(q_m * x_sep, adjacency_ind_0, - adjacency.values) + terms = q_m * x_sep * tf.expand_dims(adjacency.values, axis=-1) + if ordered: + p_sum = tf.math.segment_sum(terms, adjacency_ind_0) + else: + p_sum = tf.math.unsorted_segment_sum( + terms, adjacency_ind_0, tf.reduce_max(adjacency_ind_0) + 1) y_i_m.append(tf.matmul(p_sum, w_m)) y_out = tf.add_n(inputs=y_i_m) + tf.reshape(var_b, [1, -1]) if data_ndims > 2: @@ -161,6 +168,7 @@ def edge_convolution_template(data, edge_function, reduction, edge_function_kwargs, + ordered=False, name=None): # pyformat: disable r"""A template for edge convolutions. @@ -219,7 +227,9 @@ def edge_convolution_template(data, in the equation above. For 'max' the reduction is a max over features in which case the weights $$w_{ij}$$ are ignored. edge_function_kwargs: A dict containing any additional keyword arguments to - be passed to `edge_function`. + be passed to `edge_function`. + ordered: if True, neighbors is assumed to be ordered in the canonical, + row-major ordering, e.g. via `tf.sparse.reorder` name: A name for this op. Defaults to `graph_convolution_edge_convolution_template`. @@ -262,17 +272,29 @@ def edge_convolution_template(data, **edge_function_kwargs) if reduction == "weighted": - features = utils.partition_sums_2d(edge_features, adjacency_ind_0, - adjacency.values) + edge_features = tf.expand_dims(adjacency.values, axis=-1) * edge_features + if ordered: + features = tf.math.segment_sum(edge_features, adjacency_ind_0) + else: + features = tf.math.unsorted_segment_sum( + data=edge_features, + segment_ids=adjacency_ind_0, + num_segments=tf.reduce_max(adjacency_ind_0) + 1) elif reduction == "max": - features = tf.math.segment_max(data=edge_features, - segment_ids=adjacency_ind_0) - features.set_shape(features.shape.merge_with( - (tf.compat.v1.dimension_value(x_flat.shape[0]), - tf.compat.v1.dimension_value(edge_features.shape[-1])))) + if ordered: + features = tf.math.segment_max( + edge_features, adjacency_ind_0) + else: + features = tf.math.unsorted_segment_max( + data=edge_features, + segment_ids=adjacency_ind_0, + num_segments=tf.reduce_max(adjacency_ind_0) + 1) else: raise ValueError("The reduction method must be 'weighted' or 'max'") + features.set_shape(features.shape.merge_with( + (tf.compat.v1.dimension_value(x_flat.shape[0]), + tf.compat.v1.dimension_value(edge_features.shape[-1])))) if data_ndims > 2: features = unflatten(features) return features diff --git a/tensorflow_graphics/geometry/convolution/tests/graph_convolution_test.py b/tensorflow_graphics/geometry/convolution/tests/graph_convolution_test.py index 27f28cdbb..6436c12bd 100644 --- a/tensorflow_graphics/geometry/convolution/tests/graph_convolution_test.py +++ b/tensorflow_graphics/geometry/convolution/tests/graph_convolution_test.py @@ -27,23 +27,38 @@ from tensorflow_graphics.util import test_case -def _dense_to_sparse(data): +def _shuffle_sparse(st): + indices = st.indices + values = st.values + order = tf.range(tf.shape(indices, out_type=tf.int64)[-2]) + order = tf.random.shuffle(order) + indices = tf.gather(indices, order, axis=-2) + values = tf.gather(values, order, axis=0) + return tf.SparseTensor(indices, values, st.shape) + + +def _dense_to_sparse(data, ordered=True): """Convert a numpy array to a tf.SparseTensor.""" indices = np.where(data) - return tf.SparseTensor( + st = tf.SparseTensor( np.stack(indices, axis=-1), data[indices], dense_shape=data.shape) + if not ordered: + st = _shuffle_sparse(st) + return st -def _dummy_data(batch_size, num_vertices, num_channels): +def _dummy_data(batch_size, num_vertices, num_channels, ordered=True): """Create inputs for feature_steered_convolution.""" if batch_size > 0: data = np.zeros( shape=(batch_size, num_vertices, num_channels), dtype=np.float32) neighbors = _dense_to_sparse( - np.tile(np.eye(num_vertices, dtype=np.float32), (batch_size, 1, 1))) + np.tile(np.eye(num_vertices, dtype=np.float32), (batch_size, 1, 1)), + ordered=ordered) else: data = np.zeros(shape=(num_vertices, num_channels), dtype=np.float32) - neighbors = _dense_to_sparse(np.eye(num_vertices, dtype=np.float32)) + neighbors = _dense_to_sparse( + np.eye(num_vertices, dtype=np.float32), ordered=ordered) return data, neighbors @@ -64,7 +79,8 @@ def _random_data(batch_size, only_self_edges, data_type=np.float32, neighbors_type=np.float32, - sizes_type=np.int32): + sizes_type=np.int32, + ordered=True): """Create random inputs for feature_steered_convolution.""" def _random_data_2d(padding): @@ -94,15 +110,15 @@ def _random_data_2d(padding): neighbors = np.stack([i[1] for i in list_2d], 0).astype(neighbors_type) if padding: sizes = np.stack([i[2] for i in list_2d], 0).astype(sizes_type) - return data, _dense_to_sparse(neighbors), sizes + return data, _dense_to_sparse(neighbors, ordered=ordered), sizes else: - return data, _dense_to_sparse(neighbors) + return data, _dense_to_sparse(neighbors, ordered=ordered) else: if padding: raise ValueError("Padding only allowed with batched data.") data, neighbors = _random_data_2d(padding=False) return data.astype(data_type), _dense_to_sparse( - neighbors.astype(neighbors_type)) + neighbors.astype(neighbors_type), ordered=ordered) def _random_variables(in_channels, @@ -358,21 +374,26 @@ def test_feature_steered_convolution_padding_preset(self, data, neighbors, u, self.assertAllClose(y, expected) @parameterized.parameters( - (1, 5, 1, 1, 1), - (2, 6, 3, 6, 5), - (5, 15, 6, 12, 8), + (1, 5, 1, 1, 1, True), + (2, 6, 3, 6, 5, True), + (5, 15, 6, 12, 8, True), + (1, 5, 1, 1, 1, False), + (2, 6, 3, 6, 5, False), + (5, 15, 6, 12, 8, False), ) def test_feature_steered_convolution_padding_random(self, batch_size, num_vertices, in_channels, out_channels, - num_weight_matrices): + num_weight_matrices, + ordered): """Test mixed topology batches (random vertices and neighbors).""" data, neighbors, sizes = _random_data( batch_size, num_vertices, in_channels, padding=True, - only_self_edges=False) + only_self_edges=False, + ordered=ordered) u, v, c, w, b = _random_variables(in_channels, out_channels, num_weight_matrices) @@ -509,7 +530,7 @@ def _pass_through(self, vertex_features, neighbor_features): """A callable for `edge_convolution_template`.""" return neighbor_features - def _circular_2d_data(self, num_vertices, include_normals=False): + def _circular_2d_data(self, num_vertices, include_normals=False, ordered=True): """Create data for a circle graph.""" # Vertices are points distributed uniformly on a circle, with each point # connected to its closest neighbor on either side. @@ -520,7 +541,7 @@ def _circular_2d_data(self, num_vertices, include_normals=False): eye = np.eye(num_vertices) neighbors = np.maximum(np.roll(eye, 1, axis=1), np.roll(eye, -1, axis=1)) * 0.5 - return data, _dense_to_sparse(neighbors) + return data, _dense_to_sparse(neighbors, ordered=ordered) def _edge_curvature_2d(self, vertex_features, neighbor_features): """A callable for `edge_convolution_template` that computes curvature.""" @@ -664,16 +685,35 @@ def test_edge_convolution_template_output_shape(self, batch_size, self.assertAllEqual(y_shape[:-1], data.shape[:-1]) @parameterized.parameters( - (1, 10, 3, True, "weighted"), - (3, 6, 1, True, "weighted"), - (0, 10, 5, False, "weighted"), - (1, 10, 3, False, "max"), - (3, 6, 1, False, "max"), - (0, 10, 5, False, "max"), + (1, 10, 3, True, "weighted", True), + (3, 6, 1, True, "weighted", True), + (0, 10, 5, False, "weighted", True), + (1, 10, 3, False, "max", True), + (3, 6, 1, False, "max", True), + (0, 10, 5, False, "max", True), + (1, 10, 3, True, "weighted", True), + (3, 6, 1, True, "weighted", True), + (0, 10, 5, False, "weighted", True), + (1, 10, 3, False, "max", True), + (3, 6, 1, False, "max", True), + (0, 10, 5, False, "max", True), + (1, 10, 3, True, "weighted", False), + (3, 6, 1, True, "weighted", False), + (0, 10, 5, False, "weighted", False), + (1, 10, 3, False, "max", False), + (3, 6, 1, False, "max", False), + (0, 10, 5, False, "max", False), + (1, 10, 3, True, "weighted", False), + (3, 6, 1, True, "weighted", False), + (0, 10, 5, False, "weighted", False), + (1, 10, 3, False, "max", False), + (3, 6, 1, False, "max", False), + (0, 10, 5, False, "max", False), ) def test_edge_convolution_template_jacobian_random(self, batch_size, num_vertices, in_channels, - padding, reduction): + padding, reduction, + ordered): """Test the jacobian for random input data.""" random_data = _random_data( batch_size, @@ -682,7 +722,8 @@ def test_edge_convolution_template_jacobian_random(self, batch_size, padding, only_self_edges=False, data_type=np.float64, - neighbors_type=np.float64) + neighbors_type=np.float64, + ordered=ordered) data_init = random_data[0] neighbors = random_data[1] sizes = None if not padding else random_data[2] @@ -694,7 +735,8 @@ def test_edge_convolution_template_jacobian_random(self, batch_size, sizes=sizes, edge_function=self._pass_through, reduction=reduction, - edge_function_kwargs=dict()) + edge_function_kwargs=dict(), + ordered=ordered) self.assert_jacobian_is_correct(data, data_init, y) @@ -753,7 +795,8 @@ def test_edge_convolution_template_jacobian_preset(self, num_vertices, self.assert_jacobian_is_correct(data, data_init, y) - def test_edge_convolution_template_laplacian_smoothing(self): + @parameterized.parameters((True,), (False)) + def test_edge_convolution_template_laplacian_smoothing(self, ordered): r"""Test the expected result with laplacian smoothing. Laplacian smoothing for meshes is defined as @@ -780,7 +823,7 @@ def test_edge_convolution_template_laplacian_smoothing(self): with self.subTest(name="circular_2d"): num_vertices = 500 - data, neighbors = self._circular_2d_data(num_vertices) + data, neighbors = self._circular_2d_data(num_vertices, ordered=ordered) data_smoothed = gc.edge_convolution_template( data=data, @@ -794,7 +837,8 @@ def test_edge_convolution_template_laplacian_smoothing(self): self.assertAllClose(data, data_smoothed_normalized) - def test_edge_convolution_template_curvature(self): + @parameterized.parameters((True,), (False)) + def test_edge_convolution_template_curvature(self, ordered): r"""Test the expected result with curvature. (Approximate) curvature for meshes is defined as @@ -811,7 +855,8 @@ def test_edge_convolution_template_curvature(self): """ # We can reuse `self._edge_curvature_2d` as the curvature functional. num_vertices = 500 - data, neighbors = self._circular_2d_data(num_vertices, include_normals=True) + data, neighbors = self._circular_2d_data( + num_vertices, include_normals=True, ordered=ordered) data_curvature = gc.edge_convolution_template( data=data, diff --git a/tensorflow_graphics/geometry/convolution/utils.py b/tensorflow_graphics/geometry/convolution/utils.py index 9f383002f..3faf89779 100644 --- a/tensorflow_graphics/geometry/convolution/utils.py +++ b/tensorflow_graphics/geometry/convolution/utils.py @@ -456,13 +456,15 @@ def convert_to_block_diag_2d(data, return block_diag -def partition_sums_2d(data, group_ids, row_weights=None, name=None): +def partition_sums_2d( + data, group_ids, row_weights=None, is_sorted=False, name=None): """Sum over subsets of rows in a 2-D tensor. Args: data: 2-D tensor with shape `[D1, D2]`. group_ids: 1-D `int` tensor with shape `[D1]`. row_weights: 1-D tensor with shape `[D1]`. Can be `None`. + is_sorted: if True, group_ids are known to be sorted. name: A name for this op. Defaults to 'utils_partition_sums_2d'. Returns: @@ -481,27 +483,29 @@ def partition_sums_2d(data, group_ids, row_weights=None, name=None): raise TypeError("'group_ids' must be an integer tensor.") elif group_ids.dtype != tf.int64: group_ids = tf.cast(group_ids, dtype=tf.int64) - if row_weights is None: - row_weights = tf.ones_like(group_ids, dtype=data.dtype) - else: - row_weights = tf.convert_to_tensor(value=row_weights) - if row_weights.dtype != data.dtype: - raise TypeError("'data' and 'row_weights' must have the same type.") shape.check_static(tensor=data, tensor_name="data", has_rank=2) shape.check_static(tensor=group_ids, tensor_name="group_ids", has_rank=1) - shape.check_static( + + if row_weights is None: + shape.compare_dimensions( + tensors=(data, group_ids), + tensor_names=("data", "group_ids"), + axes=0) + else: + row_weights = tf.convert_to_tensor(value=row_weights) + shape.check_static( tensor=row_weights, tensor_name="row_weights", has_rank=1) - shape.compare_dimensions( + if row_weights.dtype != data.dtype: + raise TypeError("'data' and 'row_weights' must have the same type.") + shape.compare_dimensions( tensors=(data, group_ids, row_weights), tensor_names=("data", "group_ids", "row_weights"), axes=0) + data = data * tf.expand_dims(row_weights, axis=-1) - num_rows = tf.size(input=group_ids, out_type=tf.int64) - sparse_indices = tf.stack((group_ids, tf.range(num_rows)), axis=1) - out_shape = (tf.reduce_max(input_tensor=group_ids) + 1, num_rows) - sparse = tf.SparseTensor(sparse_indices, row_weights, dense_shape=out_shape) - return tf.sparse.sparse_dense_matmul(sparse, data) + return tf.math.unsorted_segment_sum( + data, group_ids, tf.reduce_max(group_ids) + 1) # API contains all public functions and classes. diff --git a/tensorflow_graphics/util/shape.py b/tensorflow_graphics/util/shape.py index 096af1880..5a6a39d09 100644 --- a/tensorflow_graphics/util/shape.py +++ b/tensorflow_graphics/util/shape.py @@ -367,9 +367,9 @@ def compare_dimensions(tensors, axes, tensor_names=None): tensor_names = _give_default_names(tensors, 'tensor') if not tf.executing_eagerly(): dimensions = [ - int(tensor.shape[axis]) if tensor.shape[axis].value is not None else 1 - for tensor, axis in zip(tensors, axes) - ] + 1 if dim is None else dim for dim in + (tf.compat.v1.dimension_value(tensor.shape[axis]) + for tensor, axis in zip(tensors, axes))] else: # In eager mode tensor.shape[axis].value doesn't exist if v2 behavior is # enabled. Therefore we can use tf.shape() to return the shape as a tensor.